diff --git a/mods/stream.py b/mods/stream.py index 5818bb8..052af0e 100644 --- a/mods/stream.py +++ b/mods/stream.py @@ -2,16 +2,15 @@ import socketserver import pickle import simplecrypto from uuid import uuid4 -from datetime import datetime, timedelta +import datetime, io import threading import time -from rpyc.lib import safe_import, Timeout - peers = {} RETRANSMIT_RETRIES = 3 DATAGRAM_MAX_SIZE = 9000 +RAW_DATA_MAX_SIZE = 8000 PACKET_NUM_SEQ_TTL = 300 SOCK_SEND_TIMEOUT = 60 @@ -28,6 +27,20 @@ class InvalidPacket(Exception): pass def pickle_data(data): return pickle.dumps(data, protocol=4) +################ +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + # Only allow datetime + if module == "datetime" and name == 'datetime': + return getattr(datetime, name) + # Forbid everything else. + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) +def restricted_pickle_loads(s): + """Helper function analogous to pickle.loads().""" + return RestrictedUnpickler(io.BytesIO(s)).load() +################ + # From rpyc.lib class Timeout: @@ -48,7 +61,7 @@ class Timeout: class Packet: def __init__(self, packet_payload): try: - d = pickle.loads(packet_payload) + d = restricted_pickle_loads(packet_payload) self.sid = d['sid'] self.type = d['type'] self.reset_timestamp = d['reset_timestamp'] @@ -68,14 +81,14 @@ class Peer: self.confirm_wait_packet = None self.last_packet = None self.request_lock = threading.Lock() - self.num_seq_ttl = timedelta(seconds=PACKET_NUM_SEQ_TTL) + self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL) self.last_sent_packet_num = -1 - self.last_sent_packet_num_reset_time = datetime.utcnow() + self.last_sent_packet_num_reset_time = datetime.datetime.utcnow() self.last_received_packet_num = -1 self.last_received_packet_num_reset_time = None self.retransmit_count = 0 def next_packet_num(self): - new_time = datetime.utcnow() + new_time = datetime.datetime.utcnow() if (new_time - self.last_sent_packet_num_reset_time) >= self.num_seq_ttl: self.last_sent_packet_num = -1 self.last_sent_packet_num += 1 @@ -98,18 +111,21 @@ class Peer: if confirm: self.last_packet = data self.confirm_wait_packet = (d['reset_timestamp'], d['num']) - self.sock.sendto(pickle_data(d), self.endpoint) + self.sock.sendto(data, self.endpoint) def mark_packet(self, d): d['num'] = self.next_packet_num() d['reset_timestamp'] = self.last_sent_packet_num_reset_time return d def retransmit(self): - pass + self.retransmit_count += 1 + if self.retransmit_count > RETRANSMIT_RETRIES: + raise EOFError('retransmit limit reached') + self.sock.sendto(self.last_packet, self.endpoint) def reply_my_pub_key(self, packet): try: self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data) except: - raise EOFError + raise EOFError('invalid pubkey data') self.my_key = simplecrypto.RsaKeypair() d = { 'type': PACKET_TYPE_PEER_PUB_KEY_REPLY, @@ -133,12 +149,23 @@ class Peer: 'reset_timestamp': packet.reset_timestamp } self.send(d) + def check_received_packet(self, packet): + if self.last_received_packet_num_reset_time: + if self.last_received_packet_num_reset_time > packet.reset_timestamp: + raise EOFError('packet from past') + elif self.last_received_packet_num_reset_time < packet.reset_timestamp: + self.last_received_packet_num_reset_time = packet.reset_timestamp + if (self.last_received_packet_num + 1) != packet.num: + raise EOFError('packet sequence corrupt') + else: + self.last_received_packet_num_reset_time = packet.reset_timestamp + self.last_received_packet_num = packet.num + def send_recv_confirmation(self, packet): + pass def hello(self): self.sid = uuid4().hex d = { 'type': PACKET_TYPE_HELLO, - 'reset_timestamp': None, - 'num': None } self.sock.sendto(pickle_data(d)) def recv_packet(self, packet_payload): @@ -146,9 +173,9 @@ class Peer: try: packet = Packet(packet_payload) if packet.type == PACKET_TYPE_GOODBUY: - raise EOFError + raise EOFError('connection closed') except: - raise EOFError + raise EOFError('invalid packet') if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid): self.hello() return @@ -159,7 +186,7 @@ class Peer: self.peer_pub_key = simplecrypto.RsaPublicKey(self.my_key.decrypt_raw(packet.data)) return except: - raise EOFError + raise EOFError('create pubkey failed') elif packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST: self.reply_my_pub_key(packet) return @@ -168,14 +195,48 @@ class Peer: return ############################################ if self.confirm_wait_packet: - if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet: + if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet and packet.type == PACKET_TYPE_CONFIRM_RECV: self.confirm_packet_recv(packet) + return + else: + self.retransmit() + return + ############################################ else: - pass - - + if packet.type == PACKET_TYPE_PACKET: + try: + raw = self.my_key.decrypt_raw(packet.data) + except: + raise EOFError('decrypt packet error') + self.check_received_packet(packet) + self.put_block(raw) + else: + raise EOFError('connection lost') + def send_packet(self, raw): + if self.confirm_wait_packet: + timeout = Timeout(SOCK_SEND_TIMEOUT) + while timeout.timeleft(): + if not self.confirm_wait_packet: break + if self.confirm_wait_packet: + raise EOFError('connection lost') + d = { + 'type': PACKET_TYPE_PACKET, + 'data': raw + } + self.send(self.mark_packet(d), encrypted=True, confirm=True) class UDPRequestHandler(socketserver.DatagramRequestHandler): + def finish(self): + '''Don't send anything''' + pass def handle(self): - datagram = self.rfile.read(BUFSIZE) - if self.client_address not in peers: + datagram = self.rfile.read(DATAGRAM_MAX_SIZE) + peer_addr = self.client_address + if peer_addr not in peers: peers[peer_addr] = Peer(self.socket, peer_addr) + try: + peers[peer_addr].recv_packet(datagram) + except EOFError: + del peers[peer_addr] + +class EncryptedUDPStream: +