diff --git a/mods/rpyc_dtls.py b/mods/rpyc_dtls.py deleted file mode 100644 index 9dbf606..0000000 --- a/mods/rpyc_dtls.py +++ /dev/null @@ -1,237 +0,0 @@ -# Based on rpyc (4.0.2) - -# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!! - -import rpyc -from rpyc.utils.server import ThreadedServer, spawn -from rpyc.core import SocketStream, Channel -from rpyc.core.stream import retry_errnos -from rpyc.utils.factory import connect_channel -import socket -from mbedtls import tls -import datetime as dt -from mbedtls import hash as hashlib -from mbedtls import pk -from mbedtls import x509 -from uuid import uuid4 -from contextlib import suppress -import sys -from rpyc.lib import safe_import -zlib = safe_import("zlib") -import errno -from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL - -def block(callback, *args, **kwargs): - while True: - with suppress(tls.WantReadError, tls.WantWriteError): - return callback(*args, **kwargs) - -class DTLSCerts: - def __init__(self): - now = dt.datetime.utcnow() - self.ca0_key = pk.RSA() - _ = self.ca0_key.generate() - ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256()) - self.ca0_crt = x509.CRT.selfsign( - ca0_csr, self.ca0_key, - not_before=now, not_after=now + dt.timedelta(days=3650), - serial_number=0x123456, - basic_constraints=x509.BasicConstraints(True, 1)) - self.ca1_key = pk.ECC() - _ = self.ca1_key.generate() - ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256()) - self.ca1_crt = self.ca0_crt.sign( - ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456, - basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3)) - self.srv_crt, self.srv_key = self.server_cert() - def server_cert(self): - now = dt.datetime.utcnow() - ee0_key = pk.ECC() - _ = ee0_key.generate() - ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256()) - ee0_crt = self.ca1_crt.sign( - ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), 0x987654) - return ee0_crt, ee0_key - -dtls_certs = DTLSCerts() -trust_store = tls.TrustStore() -trust_store.add(dtls_certs.ca0_crt) - -srv_ctx_conf = tls.DTLSConfiguration( - trust_store=trust_store, - certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key), - validate_certificates=False, - ) -cli_ctx_conf = tls.DTLSConfiguration( - trust_store=trust_store, - validate_certificates=False, - ) -MAX_IO_CHUNK = 8192 - -class DTLSSocketStream(SocketStream): - MAX_IO_CHUNK = MAX_IO_CHUNK - @classmethod - def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs): - if kwargs.pop('ipv6', False): - family = socket.AF_INET6 - else: - family = socket.AF_INET - #tls._enable_debug_output(cli_ctx_conf) - #tls._set_debug_level(10) - dtls_cli_ctx = tls.ClientContext(cli_ctx_conf) - dtls_cli = dtls_cli_ctx.wrap_socket( - socket.socket(family, socket.SOCK_DGRAM), - server_hostname=None, - ) - dtls_cli.settimeout(timeout) - dtls_cli.connect((host, port)) - block(dtls_cli.do_handshake) - return cls(dtls_cli) - def read(self, count): - while True: - try: - buf = block(self.sock.recv, self.MAX_IO_CHUNK) - except socket.timeout: - continue - except socket.error: - ex = sys.exc_info()[1] - if get_exc_errno(ex) in retry_errnos: - # windows just has to be a bitch - # inpos: I agree - continue - self.close() - raise EOFError(ex) - else: - break - if not buf: - self.close() - raise EOFError("connection closed by peer") - return buf - def write(self, data): - try: - _ = block(self.sock.send, data[:self.MAX_IO_CHUNK]) - except socket.error: - ex = sys.exc_info()[1] - self.close() - raise EOFError(ex) - -class DTLSChannel(Channel): - MAX_IO_CHUNK = MAX_IO_CHUNK - def recv(self): - header = self.stream.read(self.MAX_IO_CHUNK) - length, compressed = self.FRAME_HEADER.unpack(header) - length += len(self.FLUSHER) - data = b'' - while length: - dat = self.stream.read(self.MAX_IO_CHUNK) - data += dat - length -= len(dat) - data = data[:-len(self.FLUSHER)] - if compressed: - data = zlib.decompress(data) - return data - def send(self, data): - if self.compress and len(data) > self.COMPRESSION_THRESHOLD: - compressed = 1 - data = zlib.compress(data, self.COMPRESSION_LEVEL) - else: - compressed = 0 - header = self.FRAME_HEADER.pack(len(data), compressed) - self.stream.write(header) - data = data + self.FLUSHER - data = [data[i:i + self.MAX_IO_CHUNK] for i in range(0, len(data), self.MAX_IO_CHUNK)] - for chunk in data: - self.stream.write(chunk) - -def connect_stream(stream, service=rpyc.VoidService, config={}): - return connect_channel(DTLSChannel(stream), service=service, config=config) - -def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None, - cert_reqs=None, ssl_version=None, ciphers=None, - service=rpyc.VoidService, config={}, ipv6=False, keepalive=False): - ssl_kwargs = {'server_side' : False} - if ciphers is not None: - ssl_kwargs['ciphers'] = ciphers - s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive) - return connect_stream(s, service, config) - -class DTLSThreadedServer(ThreadedServer): - def __init__(self, service, hostname = "", ipv6 = False, port = 0, - backlog = 10, reuse_addr = True, authenticator = None, registrar = None, - auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5, - socket_path = None): - ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port, - backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar, - auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout, - socket_path=socket_path) - self.listener.close() - - self.host = hostname - self.port = port - #tls._enable_debug_output(srv_ctx_conf) - #tls._set_debug_level(10) - dtls_srv_ctx = tls.ServerContext(srv_ctx_conf) - dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) - if reuse_addr and sys.platform != 'win32': - dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - dtls_srv.bind((hostname, port)) - # dtls_srv.settimeout(listener_timeout) - self.listener = dtls_srv - sockname = self.listener.getsockname() - self.host, self.port = sockname[0], sockname[1] - def _serve_client(self, sock, credentials): - addrinfo = sock.getpeername() - if credentials: - self.logger.info("welcome %s (%r)", addrinfo, credentials) - else: - self.logger.info("welcome %s", addrinfo) - try: - config = dict(self.protocol_config, credentials = credentials, - endpoints = (sock.getsockname(), addrinfo), logger = self.logger) - conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config) - self._handle_connection(conn) - finally: - self.logger.info("goodbye %s", addrinfo) - def _listen(self): - if self.active: - return - #self.listener.listen(self.backlog) ##################### - if not self.port: - self.port = self.listener.getsockname()[1] - self.logger.info('server started on [%s]:%s', self.host, self.port) - self.active = True - def accept(self): - while self.active: - try: - print('Accepting new connections!') - sock, addrinfo = self.listener.accept() - except socket.timeout: - pass - except socket.error: - ex = sys.exc_info()[1] - if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN): - pass - else: - raise EOFError() - raise - else: - break - - if not self.active: - return - - sock.setblocking(True) - addr = sock.getpeername() - sock.setcookieparam(addr[0].encode()) - with suppress(tls.HelloVerifyRequest): - block(sock.do_handshake) - sock2, addr = sock.accept() - sock.close() - sock2.setblocking(True) - sock2.setcookieparam(addr[0].encode()) - block(sock2.do_handshake) - self.logger.info("accepted %s with fd %s", addrinfo, sock2.fileno()) - print("accepted %s with fd %s" % (addrinfo, sock2.fileno())) - self.clients.add(sock2) - self._accept_method(sock2) - diff --git a/mods/rpyc_utcp.py b/mods/rpyc_utcp.py index daf9725..6f2a524 100644 --- a/mods/rpyc_utcp.py +++ b/mods/rpyc_utcp.py @@ -1,15 +1,10 @@ import rpyc from rpyc.utils.server import ThreadedServer from rpyc.core import SocketStream, Channel -from rpyc.core.stream import retry_errnos -from rpyc.utils.factory import connect_channel +from rpyc.utils.factory import connect_stream from rpyc.lib import Timeout -from rpyc.lib.compat import select_error -from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL from . import utcp import sys -import errno -import socket class UTCPSocketStream(SocketStream): MAX_IO_CHUNK = utcp.DATA_LENGTH @@ -20,24 +15,22 @@ class UTCPSocketStream(SocketStream): return cls(sock) def poll(self, timeout): timeout = Timeout(timeout) + return self.sock.poll(timeout.timeleft()) + + def read(self, count): try: - while True: - try: - rl = self.sock.poll(timeout.timeleft()) - except select_error: - ex = sys.exc_info()[1] - if ex.args[0] == errno.EINTR: - continue - else: - raise - else: - break - except ValueError: + return self.sock.recv(count) + except EOFError: + self.close() + raise EOFError + def write(self, data): + try: + self.sock.send(data) + except EOFError: + ex = sys.exc_info()[1] - raise select_error(str(ex)) - return rl -def connect_stream(stream, service=rpyc.VoidService, config={}): - return connect_channel(Channel(stream), service=service, config=config) + self.close() + raise EOFError(ex) def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw): s = UTCPSocketStream.utcp_connect(host, port, **kw) return connect_stream(s, service, config) @@ -71,4 +64,4 @@ class UTCPThreadedServer(ThreadedServer): self._handle_connection(conn) finally: self.logger.info("goodbye %s", addrinfo) - \ No newline at end of file + diff --git a/mods/stream.py b/mods/stream.py deleted file mode 100644 index cafffc2..0000000 --- a/mods/stream.py +++ /dev/null @@ -1,356 +0,0 @@ -import socketserver -import pickle -import simplecrypto -from uuid import uuid4 -import datetime, io -import threading -import time -from struct import Struct -try: - import zlib -except: - zlib = None - -peers = {} - -PORT = 16386 - -RETRANSMIT_RETRIES = 3 -DATAGRAM_MAX_SIZE = 9000 -RAW_DATA_MAX_SIZE = 8000 -PACKET_NUM_SEQ_TTL = 300 - -SOCK_SEND_TIMEOUT = 60 - -PACKET_TYPE_HELLO = 0x00 -PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01 -PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02 -PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03 -PACKET_TYPE_PACKET = 0xa0 -PACKET_TYPE_CONFIRM_RECV = 0xa1 -PACKET_TYPE_GOODBUY = 0xff - -class InvalidPacket(Exception): pass -class OldPacket(Exception): pass - -def pickle_data(data): - return pickle.dumps(data, protocol=4) -################ -class RestrictedUnpickler(pickle.Unpickler): - def find_class(self, module, name): - # Only allow datetime - if module == "datetime" and name == 'datetime': - return getattr(datetime, name) - # Forbid everything else. - raise pickle.UnpicklingError("global '%s.%s' is forbidden" % - (module, name)) -def restricted_pickle_loads(s): - """Helper function analogous to pickle.loads().""" - return RestrictedUnpickler(io.BytesIO(s)).load() -################ - - -# From rpyc.lib -class Timeout: - def __init__(self, timeout): - if isinstance(timeout, Timeout): - self.finite = timeout.finite - self.tmax = timeout.tmax - else: - self.finite = timeout is not None and timeout >= 0 - self.tmax = time.time()+timeout if self.finite else None - def expired(self): - return self.finite and time.time() >= self.tmax - def timeleft(self): - return max((0, self.tmax - time.time())) if self.finite else None - def sleep(self, interval): - time.sleep(min(interval, self.timeleft()) if self.finite else interval) - -class Packet: - def __init__(self, packet_payload): - try: - d = restricted_pickle_loads(packet_payload) - self.sid = d['sid'] - self.type = d['type'] - self.reset_timestamp = d['reset_timestamp'] - self.num = d['num'] - self.data = d['data'] - except: - raise InvalidPacket - -class Peer: - def __init__(self, sock, endpoint): - self.sid = None - self.sock = sock - self.endpoint = endpoint - self.my_key = None - self.peer_pub_key = None - self.buf = [] - self.confirm_wait_packet = None - self.last_packet = None - self.request_lock = threading.Lock() - self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL) - self.last_sent_packet_num = -1 - self.last_sent_packet_num_reset_time = datetime.datetime.utcnow() - self.last_received_packet_num = -1 - self.last_received_packet_num_reset_time = None - self.retransmit_count = 0 - def next_packet_num(self): - new_time = datetime.datetime.utcnow() - if (new_time - self.last_sent_packet_num_reset_time) >= self.num_seq_ttl: - self.last_sent_packet_num = -1 - self.last_sent_packet_num += 1 - return self.last_sent_packet_num - def poll(self): - return bool(len(self.buf)) - def get_next_block(self): - if not len(self.buf): - return None - return self.buf.pop() - def put_block(self, data): - self.buf.insert(0, data) - def send(self, d, encrypted=False, confirm=False): - if 'sid' not in d: d['sid'] = self.sid - if 'num' not in d: d['num'] = None - if 'reset_timestamp' not in d: d['reset_timestamp'] = None - if 'data' not in d: d['data'] = b'' - if encrypted: d['data'] = self.peer_pub_key.encrypt_raw(d['data']) - data = pickle_data(d) - if confirm: - self.last_packet = data - self.confirm_wait_packet = (d['reset_timestamp'], d['num']) - self.sock.sendto(data, self.endpoint) - def mark_packet(self, d): - d['num'] = self.next_packet_num() - d['reset_timestamp'] = self.last_sent_packet_num_reset_time - return d - def retransmit(self): - self.retransmit_count += 1 - if self.retransmit_count > RETRANSMIT_RETRIES: - raise EOFError('retransmit limit reached') - self.sock.sendto(self.last_packet, self.endpoint) - def reply_my_pub_key(self, packet): - try: - self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data) - except: - raise EOFError('invalid pubkey data') - self.my_key = simplecrypto.RsaKeypair() - d = { - 'type': PACKET_TYPE_PEER_PUB_KEY_REPLY, - 'data': self.my_key.publickey.serialize() - } - self.send(d, encrypted=True) - def request_peer_bub_key(self, packet): - self.sid = packet.sid - self.my_key = simplecrypto.RsaKeypair() - d = { - 'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST, - 'data': self.my_key.publickey.serialize() - } - self.send(d) - def confirm_packet_recv(self, packet): - self.confirm_wait_packet = None - self.last_packet = None - d = { - 'type': PACKET_TYPE_CONFIRM_RECV, - 'num': packet.num, - 'reset_timestamp': packet.reset_timestamp - } - self.send(d) - def check_received_packet(self, packet): - if self.last_received_packet_num_reset_time: - if self.last_received_packet_num_reset_time > packet.reset_timestamp: - raise OldPacket('packet from past') - elif self.last_received_packet_num_reset_time < packet.reset_timestamp: - self.last_received_packet_num_reset_time = packet.reset_timestamp - if (self.last_received_packet_num + 1) != packet.num: - raise EOFError('packet sequence corrupt') - else: - self.last_received_packet_num_reset_time = packet.reset_timestamp - self.last_received_packet_num = packet.num - def send_recv_confirmation(self, packet): - pass - def hello(self): - self.sid = uuid4().hex - d = { - 'type': PACKET_TYPE_HELLO, - } - self.sock.sendto(pickle_data(d)) - def recv_packet(self, packet_payload): - with self.request_lock: - try: - packet = Packet(packet_payload) - if packet.type == PACKET_TYPE_GOODBUY: - raise EOFError('connection closed') - except: - raise EOFError('invalid packet') - if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid): - self.hello() - return - ############################################ - if not self.peer_pub_key: - if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY: - try: - self.peer_pub_key = simplecrypto.RsaPublicKey(self.my_key.decrypt_raw(packet.data)) - return - except: - raise EOFError('create pubkey failed') - elif packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST: - self.reply_my_pub_key(packet) - return - elif packet.type == PACKET_TYPE_HELLO: - self.request_peer_bub_key(packet) - return - ############################################ - if self.confirm_wait_packet: - if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet and packet.type == PACKET_TYPE_CONFIRM_RECV: - self.confirm_packet_recv(packet) - return - else: - self.retransmit() - return - ############################################ - else: - if packet.type == PACKET_TYPE_PACKET: - try: - self.check_received_packet(packet) - except OldPacket: - return - try: - raw = self.my_key.decrypt_raw(packet.data) - except: - raise EOFError('decrypt packet error') - self.put_block(raw) - else: - raise EOFError('connection lost') - def send_packet(self, raw): - if self.confirm_wait_packet: - timeout = Timeout(SOCK_SEND_TIMEOUT) - while timeout.timeleft(): - if not self.confirm_wait_packet: break - if self.confirm_wait_packet: - raise EOFError('connection lost') - d = { - 'type': PACKET_TYPE_PACKET, - 'data': raw - } - self.send(self.mark_packet(d), encrypted=True, confirm=True) - -class UDPRequestHandler(socketserver.DatagramRequestHandler): - def finish(self): - '''Don't send anything''' - pass - def handle(self): - datagram = self.rfile.read(DATAGRAM_MAX_SIZE) - peer_addr = self.client_address - if peer_addr not in peers: peers[peer_addr] = Peer(self.socket, peer_addr) - try: - peers[peer_addr].recv_packet(datagram) - except EOFError: - del peers[peer_addr] - -class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer): - pass - -udpserver = ThreadingUDPServer(('0.0.0.0', PORT), UDPRequestHandler) -udpserver_thread = threading.Thread(target=udpserver.serve_forever) -udpserver_thread.start() - -class EncryptedUDPStream: - def __init__(self, sock, peer_addr): - self.peer_addr = peer_addr - self.sock = sock - @classmethod - def _connect(cls, host, port): - peers[(host, port)] = Peer(udpserver.socket, (host, port)) - peers[(host, port)].hello() - return udpserver.socket - @classmethod - def connect(cls, host, port, **kwargs): - return cls(cls._connect(host, port), (host, port)) - def poll(self, timeout): - timeout = Timeout(timeout) - while timeout.timeleft(): - try: - rl = peers[self.peer_addr].poll() - if rl: break - except: - raise EOFError - return rl - def close(self): - if self.peer_addr in peers: del peers[self.peer_addr] - @property - def closed(self): - return self.peer_addr not in peers - def fileno(self): - try: - return self.sock.fileno() - except: - self.close() - raise EOFError - def read(self): - try: - buf = peers[self.peer_addr].get_next_block() - except: - raise EOFError - return buf - def write(self, data): - try: - peers[self.peer_addr].send_packet(data) - except: - raise EOFError - -class Channel(object): - MAX_IO_CHUNK = 8000 - COMPRESSION_THRESHOLD = 3000 - COMPRESSION_LEVEL = 1 - FRAME_HEADER = Struct("!LB") - FLUSHER = b'\n' - __slots__ = ["stream", "compress"] - - def __init__(self, stream, compress = True): - self.stream = stream - if not zlib: - compress = False - self.compress = compress - def close(self): - self.stream.close() - - @property - def closed(self): - return self.stream.closed - def fileno(self): - return self.stream.fileno() - - def poll(self, timeout): - return self.stream.poll(timeout) - - def recv(self): - header = self.stream.read() - if len(header) != self.FRAME_HEADER.size: - raise EOFError('CHANNEL: Not a header received') - length, compressed = self.FRAME_HEADER.unpack(header) - block_len = length + len(self.FLUSHER) - full_block = b''.join((self.stream.read() for x in range(0, block_len, self.MAX_IO_CHUNK))) - if len(full_block) != block_len: - raise EOFError('CHANNEL: Received block with wrong size') - data = full_block[:-len(self.FLUSHER)] - if compressed: - data = zlib.decompress(data) - return data - - def send(self, data): - if self.compress and len(data) > self.COMPRESSION_THRESHOLD: - compressed = 1 - data = zlib.compress(data, self.COMPRESSION_LEVEL) - else: - compressed = 0 - header = self.FRAME_HEADER.pack(len(data), compressed) - self.stream.write(header) - buf = data + self.FLUSHER - for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK): - self.stream.write(buf[chunk_start:self.MAX_IO_CHUNK]) - -import rpyc.utils.server -import rpyc.utils.factory -import rpyc.Service diff --git a/mods/udp_srv.py b/mods/udp_srv.py deleted file mode 100644 index edf7db4..0000000 --- a/mods/udp_srv.py +++ /dev/null @@ -1,78 +0,0 @@ -import socketserver -import pickle -import simplecrypto -from uuid import uuid4 - -BUFSIZE = 8192 -BLOCKSIZE = 4096 - -PACKET_TYPE_RECV_RESULT = 'recv_result' -PACKET_TYPE_DATA_FRAGMENT = 'data_fragment' -PACKET_TYPE_SVC_MESSAGE = 'svc_message' -PACKET_TYPE_NEW_CONNECTION = 'new_connection' - -DATA_TYPE_FILE_CHUNK = 'file_chunk' -DATA_TYPE_CMD = 'cmd' -DATA_TYPE_SEARCH_QUERY = 'search_query' -DATA_TYPE_SEARCH_RESULT = 'search_result' -DATA_TYPE_PEER_PUBKEY = 'peer_pubkey' - -SVC_MESSAGE_BAD_PACKET = 'bad_packet' -SVC_MESSAGE_DECRYPT_ERROR = 'decrypt_error' -SVC_MESSAGE_YOU_ARE_STRANGER = 'you_are_stranger' - -RECV_OK_CODE = 0 -RECV_ERROR_CODE = 1 - -HEADER_SVC_YOU_ARE_STRANGER = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': SVC_MESSAGE_YOU_ARE_STRANGER} - -HEADER_RECV_OK = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_OK_CODE} -HEADER_RECV_ERROR = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_ERROR_CODE} - -key = None - -peers = {} - -class Peer: - def __init__(self, peer_addr, pubkey): - self.addr = peer_addr - self.pubkey = pubkey - -def get_key(): - global key - if not key: - key = simplecrypto.RsaKeypair() - return key - -def pickle_data(data): - return pickle.dumps(data, protocol=4) - -def write_svc_msg(fobj, svc_msg, extra_data={}): - MSG = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': svc_msg} - if extra_data: - MSG.update(extra_data) - fobj.write(pickle_data(MSG)) - -class UDPRequestHandler(socketserver.DatagramRequestHandler): - def handle(self): - datagram = self.rfile.read(BUFSIZE) - if self.client_address not in peers: - write_svc_msg(self.wfile, SVC_MESSAGE_YOU_ARE_STRANGER) - return - try: - unpickled_datagram = pickle.dumps(datagram) - except pickle.UnpicklingError: - write_svc_msg(self.wfile, SVC_MESSAGE_BAD_PACKET) - return - pk = get_key() - try: - data = pickle.loads(pk.decrypt_raw(unpickled_datagram['packet'])) - except: - write_svc_msg(self.wfile, SVC_MESSAGE_DECRYPT_ERROR, {'packet_id': unpickled_datagram['packet_id']}) - return - block_id = data['block_id'] - fragment_num = data['num'] - cmd = data[''] - - - diff --git a/mods/utcp.py b/mods/utcp.py index 6b36adf..5e82da9 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -10,6 +10,8 @@ from struct import Struct import uuid import bisect +from datetime import datetime + DATA_DIVIDE_LENGTH = 8000 PACKET_HEADER_SIZE = 512 # Pickle service info DATA_LENGTH = DATA_DIVIDE_LENGTH @@ -19,6 +21,7 @@ class Connection: SMALLEST_STARTING_SEQ = 0 HIGHEST_STARTING_SEQ = 4294967295 def __init__(self, remote, encrypted=False): + self.fileno = 0 self.peer_addr = remote self.seq = Connection.gen_starting_seq_num() self.recv_seq = -1 @@ -153,6 +156,10 @@ class ConnectedSOCK(object): if self.closed: raise EOFError return self.low_sock.send(data, self.client_addr) + def sendall(self, data): + if self.closed: + raise EOFError + self.low_sock.sendall(data, self.client_addr) def recv(self, size): if self.closed: raise EOFError @@ -166,17 +173,13 @@ class ConnectedSOCK(object): def shutdown(self, *a, **kw): self.close() def poll(self, timeout): - if not self.closed: - conn = self.low_sock.connections[self.client_addr] - with conn.recv_lock: - has_data = len(conn.packet_buffer['DATA']) - if has_data: - return True - else: - self.incoming_packet_event.wait(timeout) - with conn.recv_lock: - return len(conn.packet_buffer['DATA']) - return False + return self.low_sock.poll(timeout, self.client_addr) + def packets_arrived(self, packet_type): + return self.low_sock.packets_arrived(packet_type, self.client_addr) + def fileno(self): + if self.closed: + raise EOFError + return self.low_sock.fileno(self.client_addr) class UTCP(object): host = None @@ -195,18 +198,40 @@ class UTCP(object): self.connections = {} self.connection_queue = [] self.syn_received = {} - def poll(self, timeout): - if len(self.connections): - connection = list(self.connections.keys())[0] + self.fileno_seq = 40000000 + def next_fileno(self): + self.fileno_seq += 1 + return self.fileno_seq + def packets_arrived(self, packet_type, connection=None): + try: conn = self.connections[connection] - with conn.recv_lock: - has_data = bool(len(conn.packet_buffer['DATA'])) + except: + raise EOFError + with conn.recv_lock: + return bool(len(conn.packet_buffer[packet_type])) + def poll(self, timeout, connection=None): + if connection not in list(self.connections.keys()): + if connection is None: + connection = list(self.connections.keys())[0] + else: + raise EOFError('Connection not in connected devices') + if not self.closed: + has_data = self.packets_arrived('DATA', connection) if has_data: return True else: - self.incoming_packet_event.wait(timeout) - with conn.recv_lock: - return bool(len(conn.packet_buffer['DATA'])) + if not timeout: + timeout = 0.5 + while True and not self.closed: + self.incoming_packet_event.wait(timeout) + has_data = self.packets_arrived('DATA', connection) + if not has_data: + continue + else: + return has_data + else: + self.incoming_packet_event.wait(timeout) + return self.packets_arrived('DATA', connection) return False def get_free_port(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -256,6 +281,8 @@ class UTCP(object): return len(data) except socket.error as error: raise EOFError('Socket was closed before executing command. Error is: %s.' % error) + def sendall(self, data, connection=None): + _ = self.send(data, connection) def recv(self, size, connection=None): if self.closed: raise EOFError @@ -302,6 +329,7 @@ class UTCP(object): with self.queue_lock: if len(self.connection_queue) < max_connections: conn = Connection(address, self.encrypted) + conn.fileno = self.next_fileno() if self.encrypted: try: conn.peer_pub = simplecrypto.RsaPublicKey(answer.pubkey) @@ -315,6 +343,8 @@ class UTCP(object): self.own_socket.sendto(b'Connections full', address) except KeyError: continue + except TypeError: + continue except socket.error as error: raise EOFError('Something went wrong in listen_handler func! Error is: %s.' + str(error)) @@ -330,6 +360,11 @@ class UTCP(object): def stop(self): self.own_socket.close() self.status = 0 + def shutdown(self, *a, **kw): + self.close() + self.status = 0 + self.connections = {} + self.stop() def accept(self): while self.status: try: @@ -354,6 +389,8 @@ class UTCP(object): raise EOFError('Something went wrong in accept func: ' + str(error)) def connect(self, server_address=('127.0.0.1', 10000)): + if server_address in self.connections: + raise EOFError('Already connected to peer') try: self.bind(('', self.get_free_port())) self.status = 1 @@ -379,13 +416,19 @@ class UTCP(object): ack = Ack(answer.id) self.__send_packet(server_address, ack) self.channel = UTCPChannel(self) + conn.fileno = self.next_fileno() except socket.error as error: self.own_socket.close() self.connections = {} self.status = 0 raise EOFError('The socket was closed. Error:' + str(error)) - def fileno(self): - return self.own_socket.fileno() + def fileno(self, connection=None): + if connection not in list(self.connections.keys()): + if connection is None: + connection = list(self.connections.keys())[0] + else: + raise EOFError('Connection not in connected devices') + return self.connections[connection].fileno @property def closed(self): return not bool(len(self.connections)) @@ -397,7 +440,10 @@ class UTCP(object): try: if connection not in list(self.connections.keys()): if connection is None: - connection = list(self.connections.keys())[0] + if len(self.connections): + connection = list(self.connections.keys())[0] + else: + return else: raise EOFError('Connection not in connected devices') fin = Fin() @@ -421,8 +467,7 @@ class UTCP(object): self.__send_packet(connection, fin_ack) except: pass - with self.connection_lock: - self.connections.pop(connection) + self.drop_connection(connection) except Exception as error: raise EOFError('Something went wrong in disconnect func:%s ' % error) @@ -445,14 +490,17 @@ class UTCP(object): if address[0] == 'Any': order = self.syn_received.popitem() # to reverse the tuple received return order[1], order[0] - conn = self.connections[address] + try: + conn = self.connections[address] + except: + break if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']: tries += 1 if condition == 'DATA': - if len(conn.packet_buffer[condition]): + if self.poll(0.1, address): data = b'' while size: - if not self.poll(0.5): + if not self.poll(0.1, address): continue with conn.recv_lock: packet = conn.packet_buffer[condition][0] @@ -470,11 +518,10 @@ class UTCP(object): else: raise KeyError else: - with conn.recv_lock: - if len(conn.packet_buffer[condition]): - packet = conn.packet_buffer[condition].pop() - else: - raise KeyError + if self.packets_arrived(condition, address): + packet = conn.packet_buffer[condition].pop() + else: + raise KeyError if want_id and packet.id != want_id: raise KeyError return packet @@ -513,6 +560,8 @@ class UTCP(object): if isinstance(packet, Fin): self.disconnect(address, packet.id) elif isinstance(packet, Syn): + if address in self.connections: + return if packet.id not in map(lambda x: x.id, self.syn_received.values()): self.syn_received[address] = packet else: