remove garbage. RPyC over UTCP full functional version

master
Роман Бородин 2019-04-12 13:48:08 +03:00
parent 328692bd20
commit ed57cbcb57
5 changed files with 97 additions and 726 deletions

View File

@ -1,237 +0,0 @@
# Based on rpyc (4.0.2)
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
import rpyc
from rpyc.utils.server import ThreadedServer, spawn
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
import socket
from mbedtls import tls
import datetime as dt
from mbedtls import hash as hashlib
from mbedtls import pk
from mbedtls import x509
from uuid import uuid4
from contextlib import suppress
import sys
from rpyc.lib import safe_import
zlib = safe_import("zlib")
import errno
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
def block(callback, *args, **kwargs):
while True:
with suppress(tls.WantReadError, tls.WantWriteError):
return callback(*args, **kwargs)
class DTLSCerts:
def __init__(self):
now = dt.datetime.utcnow()
self.ca0_key = pk.RSA()
_ = self.ca0_key.generate()
ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
self.ca0_crt = x509.CRT.selfsign(
ca0_csr, self.ca0_key,
not_before=now, not_after=now + dt.timedelta(days=3650),
serial_number=0x123456,
basic_constraints=x509.BasicConstraints(True, 1))
self.ca1_key = pk.ECC()
_ = self.ca1_key.generate()
ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
self.ca1_crt = self.ca0_crt.sign(
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
self.srv_crt, self.srv_key = self.server_cert()
def server_cert(self):
now = dt.datetime.utcnow()
ee0_key = pk.ECC()
_ = ee0_key.generate()
ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
ee0_crt = self.ca1_crt.sign(
ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), 0x987654)
return ee0_crt, ee0_key
dtls_certs = DTLSCerts()
trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)
srv_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
validate_certificates=False,
)
cli_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
validate_certificates=False,
)
MAX_IO_CHUNK = 8192
class DTLSSocketStream(SocketStream):
MAX_IO_CHUNK = MAX_IO_CHUNK
@classmethod
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
if kwargs.pop('ipv6', False):
family = socket.AF_INET6
else:
family = socket.AF_INET
#tls._enable_debug_output(cli_ctx_conf)
#tls._set_debug_level(10)
dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
dtls_cli = dtls_cli_ctx.wrap_socket(
socket.socket(family, socket.SOCK_DGRAM),
server_hostname=None,
)
dtls_cli.settimeout(timeout)
dtls_cli.connect((host, port))
block(dtls_cli.do_handshake)
return cls(dtls_cli)
def read(self, count):
while True:
try:
buf = block(self.sock.recv, self.MAX_IO_CHUNK)
except socket.timeout:
continue
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in retry_errnos:
# windows just has to be a bitch
# inpos: I agree
continue
self.close()
raise EOFError(ex)
else:
break
if not buf:
self.close()
raise EOFError("connection closed by peer")
return buf
def write(self, data):
try:
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
except socket.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
class DTLSChannel(Channel):
MAX_IO_CHUNK = MAX_IO_CHUNK
def recv(self):
header = self.stream.read(self.MAX_IO_CHUNK)
length, compressed = self.FRAME_HEADER.unpack(header)
length += len(self.FLUSHER)
data = b''
while length:
dat = self.stream.read(self.MAX_IO_CHUNK)
data += dat
length -= len(dat)
data = data[:-len(self.FLUSHER)]
if compressed:
data = zlib.decompress(data)
return data
def send(self, data):
if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
compressed = 1
data = zlib.compress(data, self.COMPRESSION_LEVEL)
else:
compressed = 0
header = self.FRAME_HEADER.pack(len(data), compressed)
self.stream.write(header)
data = data + self.FLUSHER
data = [data[i:i + self.MAX_IO_CHUNK] for i in range(0, len(data), self.MAX_IO_CHUNK)]
for chunk in data:
self.stream.write(chunk)
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(DTLSChannel(stream), service=service, config=config)
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
cert_reqs=None, ssl_version=None, ciphers=None,
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
ssl_kwargs = {'server_side' : False}
if ciphers is not None:
ssl_kwargs['ciphers'] = ciphers
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
return connect_stream(s, service, config)
class DTLSThreadedServer(ThreadedServer):
def __init__(self, service, hostname = "", ipv6 = False, port = 0,
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
socket_path=socket_path)
self.listener.close()
self.host = hostname
self.port = port
#tls._enable_debug_output(srv_ctx_conf)
#tls._set_debug_level(10)
dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
if reuse_addr and sys.platform != 'win32':
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
dtls_srv.bind((hostname, port))
# dtls_srv.settimeout(listener_timeout)
self.listener = dtls_srv
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
if credentials:
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)
def _listen(self):
if self.active:
return
#self.listener.listen(self.backlog) #####################
if not self.port:
self.port = self.listener.getsockname()[1]
self.logger.info('server started on [%s]:%s', self.host, self.port)
self.active = True
def accept(self):
while self.active:
try:
print('Accepting new connections!')
sock, addrinfo = self.listener.accept()
except socket.timeout:
pass
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
pass
else:
raise EOFError()
raise
else:
break
if not self.active:
return
sock.setblocking(True)
addr = sock.getpeername()
sock.setcookieparam(addr[0].encode())
with suppress(tls.HelloVerifyRequest):
block(sock.do_handshake)
sock2, addr = sock.accept()
sock.close()
sock2.setblocking(True)
sock2.setcookieparam(addr[0].encode())
block(sock2.do_handshake)
self.logger.info("accepted %s with fd %s", addrinfo, sock2.fileno())
print("accepted %s with fd %s" % (addrinfo, sock2.fileno()))
self.clients.add(sock2)
self._accept_method(sock2)

View File

@ -1,15 +1,10 @@
import rpyc import rpyc
from rpyc.utils.server import ThreadedServer from rpyc.utils.server import ThreadedServer
from rpyc.core import SocketStream, Channel from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos from rpyc.utils.factory import connect_stream
from rpyc.utils.factory import connect_channel
from rpyc.lib import Timeout from rpyc.lib import Timeout
from rpyc.lib.compat import select_error
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
from . import utcp from . import utcp
import sys import sys
import errno
import socket
class UTCPSocketStream(SocketStream): class UTCPSocketStream(SocketStream):
MAX_IO_CHUNK = utcp.DATA_LENGTH MAX_IO_CHUNK = utcp.DATA_LENGTH
@ -20,24 +15,22 @@ class UTCPSocketStream(SocketStream):
return cls(sock) return cls(sock)
def poll(self, timeout): def poll(self, timeout):
timeout = Timeout(timeout) timeout = Timeout(timeout)
return self.sock.poll(timeout.timeleft())
def read(self, count):
try: try:
while True: return self.sock.recv(count)
try: except EOFError:
rl = self.sock.poll(timeout.timeleft()) self.close()
except select_error: raise EOFError
ex = sys.exc_info()[1] def write(self, data):
if ex.args[0] == errno.EINTR: try:
continue self.sock.send(data)
else: except EOFError:
raise
else:
break
except ValueError:
ex = sys.exc_info()[1] ex = sys.exc_info()[1]
raise select_error(str(ex)) self.close()
return rl raise EOFError(ex)
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(Channel(stream), service=service, config=config)
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw): def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
s = UTCPSocketStream.utcp_connect(host, port, **kw) s = UTCPSocketStream.utcp_connect(host, port, **kw)
return connect_stream(s, service, config) return connect_stream(s, service, config)
@ -71,4 +64,4 @@ class UTCPThreadedServer(ThreadedServer):
self._handle_connection(conn) self._handle_connection(conn)
finally: finally:
self.logger.info("goodbye %s", addrinfo) self.logger.info("goodbye %s", addrinfo)

View File

@ -1,356 +0,0 @@
import socketserver
import pickle
import simplecrypto
from uuid import uuid4
import datetime, io
import threading
import time
from struct import Struct
try:
import zlib
except:
zlib = None
peers = {}
PORT = 16386
RETRANSMIT_RETRIES = 3
DATAGRAM_MAX_SIZE = 9000
RAW_DATA_MAX_SIZE = 8000
PACKET_NUM_SEQ_TTL = 300
SOCK_SEND_TIMEOUT = 60
PACKET_TYPE_HELLO = 0x00
PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01
PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02
PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03
PACKET_TYPE_PACKET = 0xa0
PACKET_TYPE_CONFIRM_RECV = 0xa1
PACKET_TYPE_GOODBUY = 0xff
class InvalidPacket(Exception): pass
class OldPacket(Exception): pass
def pickle_data(data):
return pickle.dumps(data, protocol=4)
################
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow datetime
if module == "datetime" and name == 'datetime':
return getattr(datetime, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_pickle_loads(s):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(s)).load()
################
# From rpyc.lib
class Timeout:
def __init__(self, timeout):
if isinstance(timeout, Timeout):
self.finite = timeout.finite
self.tmax = timeout.tmax
else:
self.finite = timeout is not None and timeout >= 0
self.tmax = time.time()+timeout if self.finite else None
def expired(self):
return self.finite and time.time() >= self.tmax
def timeleft(self):
return max((0, self.tmax - time.time())) if self.finite else None
def sleep(self, interval):
time.sleep(min(interval, self.timeleft()) if self.finite else interval)
class Packet:
def __init__(self, packet_payload):
try:
d = restricted_pickle_loads(packet_payload)
self.sid = d['sid']
self.type = d['type']
self.reset_timestamp = d['reset_timestamp']
self.num = d['num']
self.data = d['data']
except:
raise InvalidPacket
class Peer:
def __init__(self, sock, endpoint):
self.sid = None
self.sock = sock
self.endpoint = endpoint
self.my_key = None
self.peer_pub_key = None
self.buf = []
self.confirm_wait_packet = None
self.last_packet = None
self.request_lock = threading.Lock()
self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL)
self.last_sent_packet_num = -1
self.last_sent_packet_num_reset_time = datetime.datetime.utcnow()
self.last_received_packet_num = -1
self.last_received_packet_num_reset_time = None
self.retransmit_count = 0
def next_packet_num(self):
new_time = datetime.datetime.utcnow()
if (new_time - self.last_sent_packet_num_reset_time) >= self.num_seq_ttl:
self.last_sent_packet_num = -1
self.last_sent_packet_num += 1
return self.last_sent_packet_num
def poll(self):
return bool(len(self.buf))
def get_next_block(self):
if not len(self.buf):
return None
return self.buf.pop()
def put_block(self, data):
self.buf.insert(0, data)
def send(self, d, encrypted=False, confirm=False):
if 'sid' not in d: d['sid'] = self.sid
if 'num' not in d: d['num'] = None
if 'reset_timestamp' not in d: d['reset_timestamp'] = None
if 'data' not in d: d['data'] = b''
if encrypted: d['data'] = self.peer_pub_key.encrypt_raw(d['data'])
data = pickle_data(d)
if confirm:
self.last_packet = data
self.confirm_wait_packet = (d['reset_timestamp'], d['num'])
self.sock.sendto(data, self.endpoint)
def mark_packet(self, d):
d['num'] = self.next_packet_num()
d['reset_timestamp'] = self.last_sent_packet_num_reset_time
return d
def retransmit(self):
self.retransmit_count += 1
if self.retransmit_count > RETRANSMIT_RETRIES:
raise EOFError('retransmit limit reached')
self.sock.sendto(self.last_packet, self.endpoint)
def reply_my_pub_key(self, packet):
try:
self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data)
except:
raise EOFError('invalid pubkey data')
self.my_key = simplecrypto.RsaKeypair()
d = {
'type': PACKET_TYPE_PEER_PUB_KEY_REPLY,
'data': self.my_key.publickey.serialize()
}
self.send(d, encrypted=True)
def request_peer_bub_key(self, packet):
self.sid = packet.sid
self.my_key = simplecrypto.RsaKeypair()
d = {
'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST,
'data': self.my_key.publickey.serialize()
}
self.send(d)
def confirm_packet_recv(self, packet):
self.confirm_wait_packet = None
self.last_packet = None
d = {
'type': PACKET_TYPE_CONFIRM_RECV,
'num': packet.num,
'reset_timestamp': packet.reset_timestamp
}
self.send(d)
def check_received_packet(self, packet):
if self.last_received_packet_num_reset_time:
if self.last_received_packet_num_reset_time > packet.reset_timestamp:
raise OldPacket('packet from past')
elif self.last_received_packet_num_reset_time < packet.reset_timestamp:
self.last_received_packet_num_reset_time = packet.reset_timestamp
if (self.last_received_packet_num + 1) != packet.num:
raise EOFError('packet sequence corrupt')
else:
self.last_received_packet_num_reset_time = packet.reset_timestamp
self.last_received_packet_num = packet.num
def send_recv_confirmation(self, packet):
pass
def hello(self):
self.sid = uuid4().hex
d = {
'type': PACKET_TYPE_HELLO,
}
self.sock.sendto(pickle_data(d))
def recv_packet(self, packet_payload):
with self.request_lock:
try:
packet = Packet(packet_payload)
if packet.type == PACKET_TYPE_GOODBUY:
raise EOFError('connection closed')
except:
raise EOFError('invalid packet')
if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid):
self.hello()
return
############################################
if not self.peer_pub_key:
if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY:
try:
self.peer_pub_key = simplecrypto.RsaPublicKey(self.my_key.decrypt_raw(packet.data))
return
except:
raise EOFError('create pubkey failed')
elif packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST:
self.reply_my_pub_key(packet)
return
elif packet.type == PACKET_TYPE_HELLO:
self.request_peer_bub_key(packet)
return
############################################
if self.confirm_wait_packet:
if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet and packet.type == PACKET_TYPE_CONFIRM_RECV:
self.confirm_packet_recv(packet)
return
else:
self.retransmit()
return
############################################
else:
if packet.type == PACKET_TYPE_PACKET:
try:
self.check_received_packet(packet)
except OldPacket:
return
try:
raw = self.my_key.decrypt_raw(packet.data)
except:
raise EOFError('decrypt packet error')
self.put_block(raw)
else:
raise EOFError('connection lost')
def send_packet(self, raw):
if self.confirm_wait_packet:
timeout = Timeout(SOCK_SEND_TIMEOUT)
while timeout.timeleft():
if not self.confirm_wait_packet: break
if self.confirm_wait_packet:
raise EOFError('connection lost')
d = {
'type': PACKET_TYPE_PACKET,
'data': raw
}
self.send(self.mark_packet(d), encrypted=True, confirm=True)
class UDPRequestHandler(socketserver.DatagramRequestHandler):
def finish(self):
'''Don't send anything'''
pass
def handle(self):
datagram = self.rfile.read(DATAGRAM_MAX_SIZE)
peer_addr = self.client_address
if peer_addr not in peers: peers[peer_addr] = Peer(self.socket, peer_addr)
try:
peers[peer_addr].recv_packet(datagram)
except EOFError:
del peers[peer_addr]
class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer):
pass
udpserver = ThreadingUDPServer(('0.0.0.0', PORT), UDPRequestHandler)
udpserver_thread = threading.Thread(target=udpserver.serve_forever)
udpserver_thread.start()
class EncryptedUDPStream:
def __init__(self, sock, peer_addr):
self.peer_addr = peer_addr
self.sock = sock
@classmethod
def _connect(cls, host, port):
peers[(host, port)] = Peer(udpserver.socket, (host, port))
peers[(host, port)].hello()
return udpserver.socket
@classmethod
def connect(cls, host, port, **kwargs):
return cls(cls._connect(host, port), (host, port))
def poll(self, timeout):
timeout = Timeout(timeout)
while timeout.timeleft():
try:
rl = peers[self.peer_addr].poll()
if rl: break
except:
raise EOFError
return rl
def close(self):
if self.peer_addr in peers: del peers[self.peer_addr]
@property
def closed(self):
return self.peer_addr not in peers
def fileno(self):
try:
return self.sock.fileno()
except:
self.close()
raise EOFError
def read(self):
try:
buf = peers[self.peer_addr].get_next_block()
except:
raise EOFError
return buf
def write(self, data):
try:
peers[self.peer_addr].send_packet(data)
except:
raise EOFError
class Channel(object):
MAX_IO_CHUNK = 8000
COMPRESSION_THRESHOLD = 3000
COMPRESSION_LEVEL = 1
FRAME_HEADER = Struct("!LB")
FLUSHER = b'\n'
__slots__ = ["stream", "compress"]
def __init__(self, stream, compress = True):
self.stream = stream
if not zlib:
compress = False
self.compress = compress
def close(self):
self.stream.close()
@property
def closed(self):
return self.stream.closed
def fileno(self):
return self.stream.fileno()
def poll(self, timeout):
return self.stream.poll(timeout)
def recv(self):
header = self.stream.read()
if len(header) != self.FRAME_HEADER.size:
raise EOFError('CHANNEL: Not a header received')
length, compressed = self.FRAME_HEADER.unpack(header)
block_len = length + len(self.FLUSHER)
full_block = b''.join((self.stream.read() for x in range(0, block_len, self.MAX_IO_CHUNK)))
if len(full_block) != block_len:
raise EOFError('CHANNEL: Received block with wrong size')
data = full_block[:-len(self.FLUSHER)]
if compressed:
data = zlib.decompress(data)
return data
def send(self, data):
if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
compressed = 1
data = zlib.compress(data, self.COMPRESSION_LEVEL)
else:
compressed = 0
header = self.FRAME_HEADER.pack(len(data), compressed)
self.stream.write(header)
buf = data + self.FLUSHER
for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK):
self.stream.write(buf[chunk_start:self.MAX_IO_CHUNK])
import rpyc.utils.server
import rpyc.utils.factory
import rpyc.Service

View File

@ -1,78 +0,0 @@
import socketserver
import pickle
import simplecrypto
from uuid import uuid4
BUFSIZE = 8192
BLOCKSIZE = 4096
PACKET_TYPE_RECV_RESULT = 'recv_result'
PACKET_TYPE_DATA_FRAGMENT = 'data_fragment'
PACKET_TYPE_SVC_MESSAGE = 'svc_message'
PACKET_TYPE_NEW_CONNECTION = 'new_connection'
DATA_TYPE_FILE_CHUNK = 'file_chunk'
DATA_TYPE_CMD = 'cmd'
DATA_TYPE_SEARCH_QUERY = 'search_query'
DATA_TYPE_SEARCH_RESULT = 'search_result'
DATA_TYPE_PEER_PUBKEY = 'peer_pubkey'
SVC_MESSAGE_BAD_PACKET = 'bad_packet'
SVC_MESSAGE_DECRYPT_ERROR = 'decrypt_error'
SVC_MESSAGE_YOU_ARE_STRANGER = 'you_are_stranger'
RECV_OK_CODE = 0
RECV_ERROR_CODE = 1
HEADER_SVC_YOU_ARE_STRANGER = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': SVC_MESSAGE_YOU_ARE_STRANGER}
HEADER_RECV_OK = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_OK_CODE}
HEADER_RECV_ERROR = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_ERROR_CODE}
key = None
peers = {}
class Peer:
def __init__(self, peer_addr, pubkey):
self.addr = peer_addr
self.pubkey = pubkey
def get_key():
global key
if not key:
key = simplecrypto.RsaKeypair()
return key
def pickle_data(data):
return pickle.dumps(data, protocol=4)
def write_svc_msg(fobj, svc_msg, extra_data={}):
MSG = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': svc_msg}
if extra_data:
MSG.update(extra_data)
fobj.write(pickle_data(MSG))
class UDPRequestHandler(socketserver.DatagramRequestHandler):
def handle(self):
datagram = self.rfile.read(BUFSIZE)
if self.client_address not in peers:
write_svc_msg(self.wfile, SVC_MESSAGE_YOU_ARE_STRANGER)
return
try:
unpickled_datagram = pickle.dumps(datagram)
except pickle.UnpicklingError:
write_svc_msg(self.wfile, SVC_MESSAGE_BAD_PACKET)
return
pk = get_key()
try:
data = pickle.loads(pk.decrypt_raw(unpickled_datagram['packet']))
except:
write_svc_msg(self.wfile, SVC_MESSAGE_DECRYPT_ERROR, {'packet_id': unpickled_datagram['packet_id']})
return
block_id = data['block_id']
fragment_num = data['num']
cmd = data['']

View File

@ -10,6 +10,8 @@ from struct import Struct
import uuid import uuid
import bisect import bisect
from datetime import datetime
DATA_DIVIDE_LENGTH = 8000 DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512 # Pickle service info PACKET_HEADER_SIZE = 512 # Pickle service info
DATA_LENGTH = DATA_DIVIDE_LENGTH DATA_LENGTH = DATA_DIVIDE_LENGTH
@ -19,6 +21,7 @@ class Connection:
SMALLEST_STARTING_SEQ = 0 SMALLEST_STARTING_SEQ = 0
HIGHEST_STARTING_SEQ = 4294967295 HIGHEST_STARTING_SEQ = 4294967295
def __init__(self, remote, encrypted=False): def __init__(self, remote, encrypted=False):
self.fileno = 0
self.peer_addr = remote self.peer_addr = remote
self.seq = Connection.gen_starting_seq_num() self.seq = Connection.gen_starting_seq_num()
self.recv_seq = -1 self.recv_seq = -1
@ -153,6 +156,10 @@ class ConnectedSOCK(object):
if self.closed: if self.closed:
raise EOFError raise EOFError
return self.low_sock.send(data, self.client_addr) return self.low_sock.send(data, self.client_addr)
def sendall(self, data):
if self.closed:
raise EOFError
self.low_sock.sendall(data, self.client_addr)
def recv(self, size): def recv(self, size):
if self.closed: if self.closed:
raise EOFError raise EOFError
@ -166,17 +173,13 @@ class ConnectedSOCK(object):
def shutdown(self, *a, **kw): def shutdown(self, *a, **kw):
self.close() self.close()
def poll(self, timeout): def poll(self, timeout):
if not self.closed: return self.low_sock.poll(timeout, self.client_addr)
conn = self.low_sock.connections[self.client_addr] def packets_arrived(self, packet_type):
with conn.recv_lock: return self.low_sock.packets_arrived(packet_type, self.client_addr)
has_data = len(conn.packet_buffer['DATA']) def fileno(self):
if has_data: if self.closed:
return True raise EOFError
else: return self.low_sock.fileno(self.client_addr)
self.incoming_packet_event.wait(timeout)
with conn.recv_lock:
return len(conn.packet_buffer['DATA'])
return False
class UTCP(object): class UTCP(object):
host = None host = None
@ -195,18 +198,40 @@ class UTCP(object):
self.connections = {} self.connections = {}
self.connection_queue = [] self.connection_queue = []
self.syn_received = {} self.syn_received = {}
def poll(self, timeout): self.fileno_seq = 40000000
if len(self.connections): def next_fileno(self):
connection = list(self.connections.keys())[0] self.fileno_seq += 1
return self.fileno_seq
def packets_arrived(self, packet_type, connection=None):
try:
conn = self.connections[connection] conn = self.connections[connection]
with conn.recv_lock: except:
has_data = bool(len(conn.packet_buffer['DATA'])) raise EOFError
with conn.recv_lock:
return bool(len(conn.packet_buffer[packet_type]))
def poll(self, timeout, connection=None):
if connection not in list(self.connections.keys()):
if connection is None:
connection = list(self.connections.keys())[0]
else:
raise EOFError('Connection not in connected devices')
if not self.closed:
has_data = self.packets_arrived('DATA', connection)
if has_data: if has_data:
return True return True
else: else:
self.incoming_packet_event.wait(timeout) if not timeout:
with conn.recv_lock: timeout = 0.5
return bool(len(conn.packet_buffer['DATA'])) while True and not self.closed:
self.incoming_packet_event.wait(timeout)
has_data = self.packets_arrived('DATA', connection)
if not has_data:
continue
else:
return has_data
else:
self.incoming_packet_event.wait(timeout)
return self.packets_arrived('DATA', connection)
return False return False
def get_free_port(self): def get_free_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -256,6 +281,8 @@ class UTCP(object):
return len(data) return len(data)
except socket.error as error: except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error) raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def sendall(self, data, connection=None):
_ = self.send(data, connection)
def recv(self, size, connection=None): def recv(self, size, connection=None):
if self.closed: if self.closed:
raise EOFError raise EOFError
@ -302,6 +329,7 @@ class UTCP(object):
with self.queue_lock: with self.queue_lock:
if len(self.connection_queue) < max_connections: if len(self.connection_queue) < max_connections:
conn = Connection(address, self.encrypted) conn = Connection(address, self.encrypted)
conn.fileno = self.next_fileno()
if self.encrypted: if self.encrypted:
try: try:
conn.peer_pub = simplecrypto.RsaPublicKey(answer.pubkey) conn.peer_pub = simplecrypto.RsaPublicKey(answer.pubkey)
@ -315,6 +343,8 @@ class UTCP(object):
self.own_socket.sendto(b'Connections full', address) self.own_socket.sendto(b'Connections full', address)
except KeyError: except KeyError:
continue continue
except TypeError:
continue
except socket.error as error: except socket.error as error:
raise EOFError('Something went wrong in listen_handler func! Error is: %s.' + str(error)) raise EOFError('Something went wrong in listen_handler func! Error is: %s.' + str(error))
@ -330,6 +360,11 @@ class UTCP(object):
def stop(self): def stop(self):
self.own_socket.close() self.own_socket.close()
self.status = 0 self.status = 0
def shutdown(self, *a, **kw):
self.close()
self.status = 0
self.connections = {}
self.stop()
def accept(self): def accept(self):
while self.status: while self.status:
try: try:
@ -354,6 +389,8 @@ class UTCP(object):
raise EOFError('Something went wrong in accept func: ' + str(error)) raise EOFError('Something went wrong in accept func: ' + str(error))
def connect(self, server_address=('127.0.0.1', 10000)): def connect(self, server_address=('127.0.0.1', 10000)):
if server_address in self.connections:
raise EOFError('Already connected to peer')
try: try:
self.bind(('', self.get_free_port())) self.bind(('', self.get_free_port()))
self.status = 1 self.status = 1
@ -379,13 +416,19 @@ class UTCP(object):
ack = Ack(answer.id) ack = Ack(answer.id)
self.__send_packet(server_address, ack) self.__send_packet(server_address, ack)
self.channel = UTCPChannel(self) self.channel = UTCPChannel(self)
conn.fileno = self.next_fileno()
except socket.error as error: except socket.error as error:
self.own_socket.close() self.own_socket.close()
self.connections = {} self.connections = {}
self.status = 0 self.status = 0
raise EOFError('The socket was closed. Error:' + str(error)) raise EOFError('The socket was closed. Error:' + str(error))
def fileno(self): def fileno(self, connection=None):
return self.own_socket.fileno() if connection not in list(self.connections.keys()):
if connection is None:
connection = list(self.connections.keys())[0]
else:
raise EOFError('Connection not in connected devices')
return self.connections[connection].fileno
@property @property
def closed(self): def closed(self):
return not bool(len(self.connections)) return not bool(len(self.connections))
@ -397,7 +440,10 @@ class UTCP(object):
try: try:
if connection not in list(self.connections.keys()): if connection not in list(self.connections.keys()):
if connection is None: if connection is None:
connection = list(self.connections.keys())[0] if len(self.connections):
connection = list(self.connections.keys())[0]
else:
return
else: else:
raise EOFError('Connection not in connected devices') raise EOFError('Connection not in connected devices')
fin = Fin() fin = Fin()
@ -421,8 +467,7 @@ class UTCP(object):
self.__send_packet(connection, fin_ack) self.__send_packet(connection, fin_ack)
except: except:
pass pass
with self.connection_lock: self.drop_connection(connection)
self.connections.pop(connection)
except Exception as error: except Exception as error:
raise EOFError('Something went wrong in disconnect func:%s ' % error) raise EOFError('Something went wrong in disconnect func:%s ' % error)
@ -445,14 +490,17 @@ class UTCP(object):
if address[0] == 'Any': if address[0] == 'Any':
order = self.syn_received.popitem() # to reverse the tuple received order = self.syn_received.popitem() # to reverse the tuple received
return order[1], order[0] return order[1], order[0]
conn = self.connections[address] try:
conn = self.connections[address]
except:
break
if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']: if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']:
tries += 1 tries += 1
if condition == 'DATA': if condition == 'DATA':
if len(conn.packet_buffer[condition]): if self.poll(0.1, address):
data = b'' data = b''
while size: while size:
if not self.poll(0.5): if not self.poll(0.1, address):
continue continue
with conn.recv_lock: with conn.recv_lock:
packet = conn.packet_buffer[condition][0] packet = conn.packet_buffer[condition][0]
@ -470,11 +518,10 @@ class UTCP(object):
else: else:
raise KeyError raise KeyError
else: else:
with conn.recv_lock: if self.packets_arrived(condition, address):
if len(conn.packet_buffer[condition]): packet = conn.packet_buffer[condition].pop()
packet = conn.packet_buffer[condition].pop() else:
else: raise KeyError
raise KeyError
if want_id and packet.id != want_id: if want_id and packet.id != want_id:
raise KeyError raise KeyError
return packet return packet
@ -513,6 +560,8 @@ class UTCP(object):
if isinstance(packet, Fin): if isinstance(packet, Fin):
self.disconnect(address, packet.id) self.disconnect(address, packet.id)
elif isinstance(packet, Syn): elif isinstance(packet, Syn):
if address in self.connections:
return
if packet.id not in map(lambda x: x.id, self.syn_received.values()): if packet.id not in map(lambda x: x.id, self.syn_received.values()):
self.syn_received[address] = packet self.syn_received[address] = packet
else: else: