"""Encrypted, acknowledged datagram transport intended to carry an rpyc channel.

Peers exchange pickled dict packets over one shared UDP socket.  Every packet
carries a session id (``sid``), a ``type``, and a per-sender sequence number
(``num``) qualified by the time that sequence was last restarted
(``reset_timestamp``).  Data packets are encrypted with the receiver's RSA
public key (``simplecrypto``) and each must be confirmed with a
``PACKET_TYPE_CONFIRM_RECV`` before the sender transmits the next one.

NOTE(review): importing this module binds UDP port ``PORT`` on all interfaces
and starts a serving thread as a side effect.
"""
import collections
import datetime
import io
import pickle
import socketserver
import threading
import time
from struct import Struct
from uuid import uuid4

try:
    import zlib
except ImportError:
    zlib = None

import rpyc  # rpyc.Service is a class exported by the package, not a submodule
import rpyc.utils.factory
import rpyc.utils.server
import simplecrypto

# (host, port) -> Peer. Shared by the server handler threads and EncryptedUDPStream.
peers = {}

PORT = 16386
RETRANSMIT_RETRIES = 3
DATAGRAM_MAX_SIZE = 9000
RAW_DATA_MAX_SIZE = 8000
PACKET_NUM_SEQ_TTL = 300  # seconds before a sender restarts its sequence numbering
SOCK_SEND_TIMEOUT = 60  # seconds to wait for the previous packet's confirmation
POLL_INTERVAL = 0.01  # seconds to sleep between checks while waiting on another thread

PACKET_TYPE_HELLO = 0x00
PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01
PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02
PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03
PACKET_TYPE_PACKET = 0xa0
PACKET_TYPE_CONFIRM_RECV = 0xa1
PACKET_TYPE_GOODBUY = 0xff


class InvalidPacket(Exception):
    """Raised when a datagram cannot be decoded into a Packet."""


class OldPacket(Exception):
    """Raised for a data packet that was already received or belongs to an older sequence."""


def pickle_data(data):
    """Serialize a packet dict for the wire."""
    return pickle.dumps(data, protocol=4)


################

class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler for untrusted network data: only ``datetime.datetime`` may be resolved."""

    def find_class(self, module, name):
        # Only allow datetime
        if module == "datetime" and name == 'datetime':
            return getattr(datetime, name)
        # Forbid everything else.
        raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))


def restricted_pickle_loads(s):
    """Helper function analogous to pickle.loads()."""
    return RestrictedUnpickler(io.BytesIO(s)).load()

################


# From rpyc.lib
class Timeout:
    """Deadline helper; a ``None`` or negative timeout means "never expires"."""

    def __init__(self, timeout):
        if isinstance(timeout, Timeout):
            self.finite = timeout.finite
            self.tmax = timeout.tmax
        else:
            self.finite = timeout is not None and timeout >= 0
            self.tmax = time.time() + timeout if self.finite else None

    def expired(self):
        return self.finite and time.time() >= self.tmax

    def timeleft(self):
        return max((0, self.tmax - time.time())) if self.finite else None

    def sleep(self, interval):
        time.sleep(min(interval, self.timeleft()) if self.finite else interval)


class Packet:
    """A decoded wire packet.

    Raises:
        InvalidPacket: the payload cannot be unpickled with RestrictedUnpickler
            or lacks one of the required keys.
    """

    def __init__(self, packet_payload):
        try:
            d = restricted_pickle_loads(packet_payload)
            self.sid = d['sid']
            self.type = d['type']
            self.reset_timestamp = d['reset_timestamp']
            self.num = d['num']
            self.data = d['data']
        except Exception as err:
            raise InvalidPacket('malformed packet') from err


class Peer:
    """State of one remote endpoint: handshake keys, sequence numbers, pending
    confirmation and the buffer of decrypted received blocks."""

    def __init__(self, sock, endpoint):
        self.sid = None
        self.sock = sock
        self.endpoint = endpoint
        self.my_key = None
        self.peer_pub_key = None
        # Received blocks: put_block adds on the left, get_next_block pops the right (FIFO).
        self.buf = collections.deque()
        # (reset_timestamp, num) of our last data packet that is not yet confirmed.
        self.confirm_wait_packet = None
        # Pickled bytes of that packet, kept for retransmission.
        self.last_packet = None
        self.request_lock = threading.Lock()
        self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL)
        self.last_sent_packet_num = -1
        self.last_sent_packet_num_reset_time = datetime.datetime.utcnow()
        self.last_received_packet_num = -1
        self.last_received_packet_num_reset_time = None
        self.retransmit_count = 0

    def next_packet_num(self):
        """Return the next outgoing sequence number.

        Numbering restarts at 0 once PACKET_NUM_SEQ_TTL seconds have passed
        since the last restart. The restart time becomes the
        ``reset_timestamp`` that mark_packet attaches.
        """
        now = datetime.datetime.utcnow()
        if now - self.last_sent_packet_num_reset_time >= self.num_seq_ttl:
            self.last_sent_packet_num = -1
            # Record the restart; otherwise every later call would reset to 0 again.
            self.last_sent_packet_num_reset_time = now
        self.last_sent_packet_num += 1
        return self.last_sent_packet_num

    def poll(self):
        """Return True if a received block is buffered."""
        return bool(self.buf)

    def get_next_block(self):
        """Pop the oldest received block, or return None if the buffer is empty."""
        try:
            return self.buf.pop()
        except IndexError:
            return None

    def put_block(self, data):
        """Buffer a decrypted received block."""
        self.buf.appendleft(data)

    def send(self, d, encrypted=False, confirm=False):
        """Fill in missing packet fields, optionally encrypt ``data``, and send ``d``.

        With ``confirm=True`` the packet is remembered for retransmission until
        confirm_packet_recv clears it.
        """
        d.setdefault('sid', self.sid)
        d.setdefault('num', None)
        d.setdefault('reset_timestamp', None)
        d.setdefault('data', b'')
        if encrypted:
            d['data'] = self.peer_pub_key.encrypt_raw(d['data'])
        data = pickle_data(d)
        if confirm:
            self.last_packet = data
            self.confirm_wait_packet = (d['reset_timestamp'], d['num'])
        self.sock.sendto(data, self.endpoint)

    def mark_packet(self, d):
        """Stamp ``d`` with the next sequence number and its reset timestamp."""
        # next_packet_num may update the reset time, so call it first.
        d['num'] = self.next_packet_num()
        d['reset_timestamp'] = self.last_sent_packet_num_reset_time
        return d

    def retransmit(self):
        """Resend the unconfirmed packet.

        Raises:
            EOFError: after RETRANSMIT_RETRIES attempts for the same packet.
        """
        self.retransmit_count += 1
        if self.retransmit_count > RETRANSMIT_RETRIES:
            raise EOFError('retransmit limit reached')
        self.sock.sendto(self.last_packet, self.endpoint)

    def reply_my_pub_key(self, packet):
        """Store the peer's public key from ``packet`` and answer with ours, encrypted to it."""
        try:
            self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data)
        except Exception as err:
            raise EOFError('invalid pubkey data') from err
        self.my_key = simplecrypto.RsaKeypair()
        d = {
            'type': PACKET_TYPE_PEER_PUB_KEY_REPLY,
            'data': self.my_key.publickey.serialize()
        }
        self.send(d, encrypted=True)

    def request_peer_bub_key(self, packet):
        """Join the session announced by a HELLO and send our public key. (Name kept as-is for callers.)"""
        self.sid = packet.sid
        self.my_key = simplecrypto.RsaKeypair()
        d = {
            'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST,
            'data': self.my_key.publickey.serialize()
        }
        self.send(d)

    def confirm_packet_recv(self, packet):
        """Mark our pending data packet as delivered."""
        self.confirm_wait_packet = None
        self.last_packet = None
        # Retries are counted per packet, not per connection lifetime.
        self.retransmit_count = 0

    def check_received_packet(self, packet):
        """Validate an incoming data packet's position in the sender's sequence and record it.

        Raises:
            OldPacket: the packet belongs to an older sequence or was already
                received (e.g. a retransmission after a lost confirmation).
            EOFError: a packet is missing from the sequence.
        """
        last_reset = self.last_received_packet_num_reset_time
        if last_reset is None:
            # First data packet from this peer: adopt its sequence as-is.
            self.last_received_packet_num_reset_time = packet.reset_timestamp
        else:
            if last_reset > packet.reset_timestamp:
                raise OldPacket('packet from past')
            if last_reset < packet.reset_timestamp:
                # Sender restarted its numbering (see next_packet_num): expect 0 next.
                self.last_received_packet_num_reset_time = packet.reset_timestamp
                self.last_received_packet_num = -1
            elif packet.num <= self.last_received_packet_num:
                raise OldPacket('duplicate packet')
            if packet.num != self.last_received_packet_num + 1:
                raise EOFError('packet sequence corrupt')
        self.last_received_packet_num = packet.num

    def send_recv_confirmation(self, packet):
        """Acknowledge a received data packet so its sender may transmit the next one."""
        d = {
            'type': PACKET_TYPE_CONFIRM_RECV,
            'num': packet.num,
            'reset_timestamp': packet.reset_timestamp
        }
        self.send(d)

    def hello(self):
        """Start a new session: choose a fresh session id and announce it."""
        self.sid = uuid4().hex
        # send() fills in sid/num/reset_timestamp/data (Packet requires all of
        # them) and addresses the datagram to self.endpoint.
        self.send({'type': PACKET_TYPE_HELLO})

    def recv_packet(self, packet_payload):
        """Process one datagram received from this peer.

        Drives the handshake (HELLO -> PUB_KEY_REQUEST -> PUB_KEY_REPLY),
        handles confirmations of our data packets, and decrypts incoming data
        packets into the receive buffer, acknowledging each one.

        Raises:
            EOFError: the connection should be dropped.
        """
        with self.request_lock:
            try:
                packet = Packet(packet_payload)
            except InvalidPacket as err:
                raise EOFError('invalid packet') from err
            if packet.type == PACKET_TYPE_GOODBUY:
                raise EOFError('connection closed')

            # Everything except HELLO must belong to our current session.
            if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid):
                self.hello()
                return

            # Handshake: exchange public keys before data can flow.
            if not self.peer_pub_key:
                if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY:
                    try:
                        self.peer_pub_key = simplecrypto.RsaPublicKey(
                            self.my_key.decrypt_raw(packet.data))
                    except Exception as err:
                        raise EOFError('create pubkey failed') from err
                    return
                if packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST:
                    self.reply_my_pub_key(packet)
                    return
                if packet.type == PACKET_TYPE_HELLO:
                    self.request_peer_bub_key(packet)
                    return

            if packet.type == PACKET_TYPE_CONFIRM_RECV:
                self._handle_confirmation(packet)
                return
            if packet.type == PACKET_TYPE_PACKET:
                # Handled even while our own packet awaits confirmation, so that
                # traffic in both directions cannot deadlock.
                self._handle_data_packet(packet)
                return

            # Unexpected packet type.
            if self.confirm_wait_packet:
                self.retransmit()
                return
            raise EOFError('connection lost')

    def _handle_confirmation(self, packet):
        """Clear the pending packet if ``packet`` confirms it; otherwise retransmit."""
        if not self.confirm_wait_packet:
            # Late duplicate confirmation (e.g. both a packet and its retransmission were acked).
            return
        if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet:
            self.confirm_packet_recv(packet)
        else:
            self.retransmit()

    def _handle_data_packet(self, packet):
        """Decrypt, buffer and acknowledge a data packet; re-acknowledge duplicates."""
        try:
            self.check_received_packet(packet)
        except OldPacket:
            # Already handled: do not buffer it twice, but re-acknowledge in
            # case our earlier confirmation was lost.
            self.send_recv_confirmation(packet)
            return
        try:
            raw = self.my_key.decrypt_raw(packet.data)
        except Exception as err:
            raise EOFError('decrypt packet error') from err
        self.put_block(raw)
        self.send_recv_confirmation(packet)

    def send_packet(self, raw):
        """Encrypt and send one data block after the previous block is confirmed.

        Raises:
            EOFError: the previous packet is still unconfirmed after
                SOCK_SEND_TIMEOUT seconds.
        """
        # NOTE(review): retransmission is only triggered by incoming traffic
        # (see recv_packet); a lost packet with a silent peer ends in this timeout.
        if self.confirm_wait_packet:
            timeout = Timeout(SOCK_SEND_TIMEOUT)
            while self.confirm_wait_packet and not timeout.expired():
                # Sleep instead of spinning; the server thread clears the flag.
                timeout.sleep(POLL_INTERVAL)
            if self.confirm_wait_packet:
                raise EOFError('connection lost')
        d = {
            'type': PACKET_TYPE_PACKET,
            'data': raw
        }
        self.send(self.mark_packet(d), encrypted=True, confirm=True)


class UDPRequestHandler(socketserver.DatagramRequestHandler):
    """Routes each incoming datagram to the Peer for its source address."""

    def finish(self):
        '''Don't send anything'''
        # The base DatagramRequestHandler.finish() would send wfile's contents
        # back to the client; replies are sent explicitly through Peer.send.
        pass

    def handle(self):
        datagram = self.rfile.read(DATAGRAM_MAX_SIZE)
        peer_addr = self.client_address
        peer = peers.get(peer_addr)
        if peer is None:
            # setdefault keeps a Peer created concurrently by another thread.
            peer = peers.setdefault(peer_addr, Peer(self.socket, peer_addr))
        try:
            peer.recv_packet(datagram)
        except EOFError:
            # The entry may already be gone (e.g. EncryptedUDPStream.close()).
            peers.pop(peer_addr, None)


class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer):
    pass


udpserver = ThreadingUDPServer(('0.0.0.0', PORT), UDPRequestHandler)
udpserver_thread = threading.Thread(target=udpserver.serve_forever)
udpserver_thread.start()


class EncryptedUDPStream:
    """Stream interface (poll/read/write/close) over the Peer stored in ``peers[peer_addr]``."""

    def __init__(self, sock, peer_addr):
        self.peer_addr = peer_addr
        self.sock = sock

    @classmethod
    def _connect(cls, host, port):
        # NOTE(review): replies are keyed by the datagram's source address, so
        # ``host`` presumably has to be a numeric IP rather than a hostname.
        endpoint = (host, port)
        peer = Peer(udpserver.socket, endpoint)
        peers[endpoint] = peer
        peer.hello()
        return udpserver.socket

    @classmethod
    def connect(cls, host, port, **kwargs):
        return cls(cls._connect(host, port), (host, port))

    def poll(self, timeout):
        """Wait up to ``timeout`` seconds (None = forever) for a received block.

        Returns True if a block is available, False on timeout.
        Raises EOFError if the peer has been removed.
        """
        timeout = Timeout(timeout)
        while True:
            try:
                ready = peers[self.peer_addr].poll()
            except KeyError:
                raise EOFError
            if ready or timeout.expired():
                return ready
            timeout.sleep(POLL_INTERVAL)

    def close(self):
        peers.pop(self.peer_addr, None)

    @property
    def closed(self):
        return self.peer_addr not in peers

    def fileno(self):
        try:
            return self.sock.fileno()
        except Exception:
            self.close()
            raise EOFError

    def read(self):
        """Return the next received block, or None if nothing is buffered yet."""
        try:
            peer = peers[self.peer_addr]
        except KeyError:
            raise EOFError
        return peer.get_next_block()

    def write(self, data):
        """Send one block; any failure (closed stream, crypto, socket) becomes EOFError."""
        try:
            peers[self.peer_addr].send_packet(data)
        except EOFError:
            raise
        except Exception as err:
            raise EOFError(repr(err)) from err


class Channel(object):
    """Framing layer: a header block, then the payload split into MAX_IO_CHUNK-sized blocks."""
    MAX_IO_CHUNK = 8000
    COMPRESSION_THRESHOLD = 3000
    COMPRESSION_LEVEL = 1
    FRAME_HEADER = Struct("!LB")
    FLUSHER = b'\n'
    READ_TIMEOUT = SOCK_SEND_TIMEOUT  # seconds to wait for each block of a frame
    __slots__ = ["stream", "compress"]

    def __init__(self, stream, compress=True):
        self.stream = stream
        if not zlib:
            compress = False
        self.compress = compress

    def close(self):
        self.stream.close()

    @property
    def closed(self):
        return self.stream.closed

    def fileno(self):
        return self.stream.fileno()

    def poll(self, timeout):
        return self.stream.poll(timeout)

    def _read_block(self):
        """Return the next stream block, waiting up to READ_TIMEOUT seconds for it."""
        # Blocks of one frame arrive one by one (each waits for a confirmation),
        # while stream.read() returns None when nothing is buffered yet.
        if self.stream.poll(self.READ_TIMEOUT):
            block = self.stream.read()
            if block is not None:
                return block
        raise EOFError('CHANNEL: timed out waiting for data')

    def recv(self):
        header = self._read_block()
        if len(header) != self.FRAME_HEADER.size:
            raise EOFError('CHANNEL: Not a header received')
        length, compressed = self.FRAME_HEADER.unpack(header)
        block_len = length + len(self.FLUSHER)
        full_block = b''.join(
            self._read_block() for _ in range(0, block_len, self.MAX_IO_CHUNK))
        if len(full_block) != block_len:
            raise EOFError('CHANNEL: Received block with wrong size')
        data = full_block[:-len(self.FLUSHER)]
        if compressed:
            data = zlib.decompress(data)
        return data

    def send(self, data):
        if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
            compressed = 1
            data = zlib.compress(data, self.COMPRESSION_LEVEL)
        else:
            compressed = 0
        header = self.FRAME_HEADER.pack(len(data), compressed)
        self.stream.write(header)
        buf = data + self.FLUSHER
        for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK):
            # The slice end is relative to chunk_start, so every chunk carries data.
            self.stream.write(buf[chunk_start:chunk_start + self.MAX_IO_CHUNK])