# Listing metadata from the source viewer: 357 lines, 12 KiB, Python.
import datetime
import io
import pickle
import socketserver
import threading
import time
from collections import deque
from struct import Struct
from uuid import uuid4

import simplecrypto

# zlib is optional: Channel disables compression when it is unavailable.
# Only a failed import should trigger the fallback, not arbitrary errors.
try:
    import zlib
except ImportError:
    zlib = None

# Registry of live Peer objects keyed by (host, port) endpoint tuples.
peers = {}

# UDP port the local server binds to.
PORT = 16386

# Transport limits.
RETRANSMIT_RETRIES = 3      # resends allowed before Peer.retransmit gives up
DATAGRAM_MAX_SIZE = 9000    # max bytes read from one incoming datagram
RAW_DATA_MAX_SIZE = 8000    # NOTE(review): not referenced by the visible code
PACKET_NUM_SEQ_TTL = 300    # seconds before send-side packet numbering restarts

# Seconds a sender waits for the previous packet's confirmation.
SOCK_SEND_TIMEOUT = 60

# Wire packet type codes (the 'type' field of every pickled packet dict).
PACKET_TYPE_HELLO = 0x00
PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01
PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02
PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03  # NOTE(review): not referenced by the visible code
PACKET_TYPE_PACKET = 0xa0
PACKET_TYPE_CONFIRM_RECV = 0xa1
PACKET_TYPE_GOODBUY = 0xff  # (sic) "goodbye": the peer is closing the connection

class InvalidPacket(Exception):
    """Raised when an incoming datagram cannot be decoded into a Packet."""

class OldPacket(Exception):
    """Raised for a received packet that is older than the current sequence state."""


def pickle_data(data):
    """Serialize *data* to bytes using pickle protocol 4 (the wire format)."""
    wire_protocol = 4
    serialized = pickle.dumps(data, protocol=wire_protocol)
    return serialized

################
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only the global ``datetime.datetime``.

    Any other global referenced by the pickle stream raises
    ``pickle.UnpicklingError`` instead of being imported and called.
    """

    def find_class(self, module, name):
        is_allowed = (module, name) == ("datetime", "datetime")
        if not is_allowed:
            # Forbid everything else.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        return datetime.datetime

def restricted_pickle_loads(s):
    """Unpickle bytes *s* like ``pickle.loads``, but via RestrictedUnpickler."""
    stream = io.BytesIO(s)
    unpickler = RestrictedUnpickler(stream)
    return unpickler.load()
################


# From rpyc.lib
class Timeout:
    """Deadline helper (adapted from rpyc.lib).

    A non-negative number of seconds sets a finite deadline measured from
    construction; ``None`` or a negative value means "no deadline".  Passing
    another ``Timeout`` copies its deadline.
    """

    def __init__(self, timeout):
        if isinstance(timeout, Timeout):
            self.finite, self.tmax = timeout.finite, timeout.tmax
            return
        self.finite = timeout is not None and timeout >= 0
        self.tmax = (time.time() + timeout) if self.finite else None

    def expired(self):
        """Return True once a finite deadline has passed; False otherwise."""
        if not self.finite:
            return self.finite
        return time.time() >= self.tmax

    def timeleft(self):
        """Seconds remaining (never negative), or None without a deadline."""
        if not self.finite:
            return None
        remaining = self.tmax - time.time()
        return max(0, remaining)

    def sleep(self, interval):
        """Sleep *interval* seconds, capped at the remaining time if finite."""
        if self.finite:
            interval = min(interval, self.timeleft())
        time.sleep(interval)


class Packet:
    """Decoded view of one wire datagram.

    The payload must unpickle (via ``restricted_pickle_loads``) to a mapping
    with the keys 'sid', 'type', 'reset_timestamp', 'num' and 'data'; each
    value is copied to the attribute of the same name.

    NOTE: this unpickles data received from the network.  Safety relies
    entirely on RestrictedUnpickler's global whitelist.

    Raises:
        InvalidPacket: if the payload cannot be unpickled or lacks a key.
    """

    def __init__(self, packet_payload):
        try:
            d = restricted_pickle_loads(packet_payload)
            self.sid = d['sid']
            self.type = d['type']
            self.reset_timestamp = d['reset_timestamp']
            self.num = d['num']
            self.data = d['data']
        except Exception as exc:
            # Any decode/shape failure (forbidden global, truncated pickle,
            # non-mapping payload, missing key) means this is not a valid
            # packet.  Chain the cause for debugging, but do not swallow
            # KeyboardInterrupt/SystemExit as the bare except did.
            raise InvalidPacket(str(exc)) from exc


class Peer:
    """Protocol state for one remote UDP endpoint.

    Handshake, as implemented below:
      * The initiating side calls ``hello()``.
      * The receiver answers HELLO with PEER_PUB_KEY_REQUEST, carrying its
        own RSA public key.
      * The initiator answers with PEER_PUB_KEY_REPLY, whose data is its
        public key encrypted under the receiver's key.

    After that, data travels as PACKET datagrams encrypted with the
    recipient's public key, one at a time.  Each must be acknowledged by a
    CONFIRM_RECV carrying the same (reset_timestamp, num) before
    ``send_packet`` sends the next one.
    """

    # Seconds to sleep between checks while send_packet waits for a
    # confirmation (avoids spinning a CPU core in a tight loop).
    CONFIRM_POLL_INTERVAL = 0.001

    def __init__(self, sock, endpoint):
        self.sid = None                       # session id shared by both sides
        self.sock = sock                      # socket used for sendto()
        self.endpoint = endpoint              # address datagrams are sent to
        self.my_key = None                    # our keypair; decrypts incoming data
        self.peer_pub_key = None              # peer's public key; encrypts outgoing data
        self.buf = deque()                    # decrypted incoming blocks, oldest first
        self.confirm_wait_packet = None       # (reset_timestamp, num) awaiting CONFIRM_RECV
        self.last_packet = None               # pickled datagram kept for retransmission
        self.request_lock = threading.Lock()  # serializes recv_packet calls
        self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL)
        self.last_sent_packet_num = -1
        # Naive UTC on purpose: timestamps travel inside pickles, and the
        # restricted unpickler only admits datetime.datetime (an aware value
        # would also reference datetime.timezone).
        self.last_sent_packet_num_reset_time = datetime.datetime.utcnow()
        self.last_received_packet_num = -1
        self.last_received_packet_num_reset_time = None
        self.retransmit_count = 0

    def next_packet_num(self):
        """Return the next outgoing sequence number.

        Once the current numbering epoch is older than ``num_seq_ttl``, a new
        epoch starts: numbering restarts at 0 and the epoch timestamp moves
        to now, so the receiver can tell the restart from a corrupt sequence.
        """
        now = datetime.datetime.utcnow()
        if now - self.last_sent_packet_num_reset_time >= self.num_seq_ttl:
            self.last_sent_packet_num = -1
            # Without this the epoch never advances: every later call would
            # reset again and all packets would be numbered 0.
            self.last_sent_packet_num_reset_time = now
        self.last_sent_packet_num += 1
        return self.last_sent_packet_num

    def poll(self):
        """Return True if at least one received block is waiting."""
        return bool(self.buf)

    def get_next_block(self):
        """Pop and return the oldest received block, or None if none is queued."""
        if not self.buf:
            return None
        return self.buf.popleft()

    def put_block(self, data):
        """Queue a decrypted block for the reader (FIFO order)."""
        self.buf.append(data)

    def send(self, d, encrypted=False, confirm=False):
        """Complete, serialize and send packet dict *d* (mutated in place).

        Missing header fields are filled in.  With ``encrypted=True`` the data
        is encrypted with the peer's public key.  With ``confirm=True`` the
        datagram is kept for retransmission and its (reset_timestamp, num) is
        recorded as awaiting a CONFIRM_RECV.
        """
        if 'sid' not in d: d['sid'] = self.sid
        if 'num' not in d: d['num'] = None
        if 'reset_timestamp' not in d: d['reset_timestamp'] = None
        if 'data' not in d: d['data'] = b''
        if encrypted: d['data'] = self.peer_pub_key.encrypt_raw(d['data'])
        data = pickle_data(d)
        if confirm:
            self.last_packet = data
            self.confirm_wait_packet = (d['reset_timestamp'], d['num'])
        self.sock.sendto(data, self.endpoint)

    def mark_packet(self, d):
        """Stamp *d* with the next sequence number and its epoch; return *d*."""
        d['num'] = self.next_packet_num()
        # Read after next_packet_num(), which may have started a new epoch.
        d['reset_timestamp'] = self.last_sent_packet_num_reset_time
        return d

    def retransmit(self):
        """Resend the pending datagram.

        Raises:
            EOFError: once more than RETRANSMIT_RETRIES resends have been
                attempted for the current pending packet.
        """
        self.retransmit_count += 1
        if self.retransmit_count > RETRANSMIT_RETRIES:
            raise EOFError('retransmit limit reached')
        self.sock.sendto(self.last_packet, self.endpoint)

    def reply_my_pub_key(self, packet):
        """Store the peer's public key from a PUB_KEY_REQUEST and answer with ours.

        Our new public key is sent encrypted under the peer's key.
        """
        try:
            self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data)
        except Exception as exc:
            raise EOFError('invalid pubkey data') from exc
        self.my_key = simplecrypto.RsaKeypair()
        d = {
            'type': PACKET_TYPE_PEER_PUB_KEY_REPLY,
            'data': self.my_key.publickey.serialize()
        }
        self.send(d, encrypted=True)

    def request_peer_bub_key(self, packet):
        """Answer a HELLO: adopt its sid and send our public key in a request."""
        self.sid = packet.sid
        self.my_key = simplecrypto.RsaKeypair()
        d = {
            'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST,
            'data': self.my_key.publickey.serialize()
        }
        self.send(d)

    def confirm_packet_recv(self, packet):
        """Handle the peer's confirmation of our pending data packet.

        Clears the pending state so send_packet can proceed, and restores the
        retransmission budget for the next packet.  Nothing is sent back: an
        echoed CONFIRM_RECV would reach a peer that is not waiting for one.
        """
        self.confirm_wait_packet = None
        self.last_packet = None
        self.retransmit_count = 0

    def check_received_packet(self, packet):
        """Validate a data packet's (reset_timestamp, num) and record it.

        Raises:
            OldPacket: the packet is from an earlier epoch or was already
                accepted (for example, a retransmission).
            EOFError: the packet skips ahead in the current epoch.
        """
        last_reset = self.last_received_packet_num_reset_time
        if last_reset is None:
            # First data packet: adopt the sender's epoch and number as-is.
            self.last_received_packet_num_reset_time = packet.reset_timestamp
        else:
            if last_reset > packet.reset_timestamp:
                raise OldPacket('packet from past')
            if last_reset < packet.reset_timestamp:
                # New sender epoch: numbering restarts at 0.
                self.last_received_packet_num_reset_time = packet.reset_timestamp
                self.last_received_packet_num = -1
            if packet.num <= self.last_received_packet_num:
                raise OldPacket('duplicate packet')
            if packet.num != self.last_received_packet_num + 1:
                raise EOFError('packet sequence corrupt')
        self.last_received_packet_num = packet.num

    def send_recv_confirmation(self, packet):
        """Acknowledge a data packet by echoing its (reset_timestamp, num)."""
        d = {
            'type': PACKET_TYPE_CONFIRM_RECV,
            'num': packet.num,
            'reset_timestamp': packet.reset_timestamp,
        }
        self.send(d)

    def hello(self):
        """Start a new session: pick a fresh sid and send HELLO to the endpoint."""
        self.sid = uuid4().hex
        # send() addresses the datagram and fills in sid/num/reset_timestamp/
        # data, all of which Packet requires on the receiving side.
        self.send({'type': PACKET_TYPE_HELLO})

    def recv_packet(self, packet_payload):
        """Process one incoming datagram from this peer.

        Raises:
            EOFError: the connection must be dropped (invalid packet, peer
                said goodbye, handshake/decrypt failure, protocol violation).
        """
        with self.request_lock:
            try:
                packet = Packet(packet_payload)
            except InvalidPacket:
                raise EOFError('invalid packet') from None
            if packet.type == PACKET_TYPE_GOODBUY:
                raise EOFError('connection closed')

            # Any non-HELLO packet must carry our session id; otherwise
            # restart the session.
            if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid):
                self.hello()
                return

            ############################################
            # Key exchange.
            if not self.peer_pub_key:
                if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY:
                    try:
                        self.peer_pub_key = simplecrypto.RsaPublicKey(self.my_key.decrypt_raw(packet.data))
                    except Exception as exc:
                        raise EOFError('create pubkey failed') from exc
                    return
                if packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST:
                    self.reply_my_pub_key(packet)
                    return
                if packet.type == PACKET_TYPE_HELLO:
                    self.request_peer_bub_key(packet)
                    return

            ############################################
            # We have an unconfirmed packet in flight.
            if self.confirm_wait_packet:
                if (packet.type == PACKET_TYPE_CONFIRM_RECV
                        and (packet.reset_timestamp, packet.num) == self.confirm_wait_packet):
                    self.confirm_packet_recv(packet)
                else:
                    # NOTE(review): any other packet (including the peer's
                    # own data) is dropped and triggers a resend, so both
                    # sides cannot stream data at the same time.
                    self.retransmit()
                return

            ############################################
            # Idle: accept data packets.
            if packet.type == PACKET_TYPE_PACKET:
                try:
                    self.check_received_packet(packet)
                except OldPacket:
                    # Already handled (or stale): the sender may have missed
                    # our confirmation, so acknowledge it again.
                    self.send_recv_confirmation(packet)
                    return
                try:
                    raw = self.my_key.decrypt_raw(packet.data)
                except Exception as exc:
                    raise EOFError('decrypt packet error') from exc
                self.put_block(raw)
                self.send_recv_confirmation(packet)
            elif packet.type == PACKET_TYPE_CONFIRM_RECV:
                # Late duplicate confirmation of a packet already confirmed.
                return
            else:
                raise EOFError('connection lost')

    def send_packet(self, raw):
        """Encrypt and send one data block, waiting for the previous confirmation.

        Raises:
            EOFError: the previous packet was not confirmed within
                SOCK_SEND_TIMEOUT seconds.
        """
        if self.confirm_wait_packet:
            timeout = Timeout(SOCK_SEND_TIMEOUT)
            while self.confirm_wait_packet and not timeout.expired():
                timeout.sleep(self.CONFIRM_POLL_INTERVAL)
            if self.confirm_wait_packet:
                raise EOFError('connection lost')
        d = {
            'type': PACKET_TYPE_PACKET,
            'data': raw
        }
        self.send(self.mark_packet(d), encrypted=True, confirm=True)


class UDPRequestHandler(socketserver.DatagramRequestHandler):
    """Dispatch each incoming datagram to the Peer for its source address.

    A Peer is created on first contact and discarded when processing raises
    EOFError.
    """

    def finish(self):
        """Send nothing back: replies are sent explicitly by Peer."""
        pass

    def handle(self):
        datagram = self.rfile.read(DATAGRAM_MAX_SIZE)
        peer_addr = self.client_address
        peer = peers.get(peer_addr)
        if peer is None:
            # setdefault keeps whichever Peer was registered first if another
            # handler thread raced us (relies on dict.setdefault being atomic,
            # as it is in CPython).
            peer = peers.setdefault(peer_addr, Peer(self.socket, peer_addr))
        try:
            peer.recv_packet(datagram)
        except EOFError:
            # pop() tolerates the entry already having been removed by a
            # concurrent handler or EncryptedUDPStream.close().
            peers.pop(peer_addr, None)


class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer):
    """UDP server that handles each incoming datagram in its own thread."""


# Start the shared UDP endpoint as a side effect of importing this module.
# EncryptedUDPStream._connect reuses udpserver.socket for outgoing traffic.
udpserver = ThreadingUDPServer(
    server_address=('0.0.0.0', PORT),
    RequestHandlerClass=UDPRequestHandler,
)
# NOTE(review): this thread is not a daemon, so it keeps the process alive
# after the main thread finishes unless udpserver.shutdown() is called.
udpserver_thread = threading.Thread(target=udpserver.serve_forever)
udpserver_thread.start()


class EncryptedUDPStream:
    """Stream adapter over a registered Peer.

    Provides the interface Channel uses (read/write/poll/close/closed/fileno).
    Every operation looks the Peer up in the module-level ``peers`` registry.
    A missing entry means the connection is gone, and EOFError is raised.
    """

    # Seconds to sleep between checks while poll() waits for data.
    POLL_INTERVAL = 0.001

    def __init__(self, sock, peer_addr):
        self.peer_addr = peer_addr
        self.sock = sock

    @classmethod
    def _connect(cls, host, port):
        """Register a Peer for (host, port) on the server socket and send HELLO."""
        endpoint = (host, port)
        peer = Peer(udpserver.socket, endpoint)
        peers[endpoint] = peer
        peer.hello()
        return udpserver.socket

    @classmethod
    def connect(cls, host, port, **kwargs):
        """Open a stream to (host, port).

        Extra keyword arguments are accepted and ignored (presumably for
        rpyc factory compatibility; TODO confirm).
        """
        return cls(cls._connect(host, port), (host, port))

    def _get_peer(self):
        """Return this stream's Peer, or raise EOFError if it was dropped."""
        try:
            return peers[self.peer_addr]
        except KeyError:
            raise EOFError('connection closed') from None

    def poll(self, timeout):
        """Return True once a block is available, or False when *timeout* expires.

        *timeout* is seconds, ``None``/negative for no limit, or a Timeout.
        ``poll(0)`` checks once without waiting.  (The previous loop never
        ran for 0/None and then raised UnboundLocalError.)
        """
        timeout = Timeout(timeout)
        while True:
            if self._get_peer().poll():
                return True
            if timeout.expired():
                return False
            timeout.sleep(self.POLL_INTERVAL)

    def close(self):
        """Forget the Peer; safe to call more than once or concurrently."""
        peers.pop(self.peer_addr, None)

    @property
    def closed(self):
        return self.peer_addr not in peers

    def fileno(self):
        try:
            return self.sock.fileno()
        except Exception as exc:
            self.close()
            raise EOFError(str(exc)) from exc

    def read(self):
        """Return the oldest received block, or None if nothing is queued."""
        return self._get_peer().get_next_block()

    def write(self, data):
        """Send *data* as one encrypted packet; failures surface as EOFError."""
        peer = self._get_peer()
        try:
            peer.send_packet(data)
        except EOFError:
            raise
        except Exception as exc:
            # Socket/crypto failures are reported to callers as a closed
            # connection, as before, but now with the cause chained.
            raise EOFError(str(exc)) from exc


class Channel(object):
    """Length-prefixed, optionally zlib-compressed framing over a block stream.

    A frame is written as one header block followed by the payload.
      * The header is FRAME_HEADER: a 4-byte big-endian payload length and a
        1-byte "compressed" flag.
      * The payload, plus FLUSHER, is split into blocks of at most
        MAX_IO_CHUNK bytes.
    ``stream.read()`` must return one written block per call, or None when
    nothing is available.
    """

    MAX_IO_CHUNK = 8000
    COMPRESSION_THRESHOLD = 3000
    COMPRESSION_LEVEL = 1
    FRAME_HEADER = Struct("!LB")
    FLUSHER = b'\n'

    __slots__ = ["stream", "compress"]

    def __init__(self, stream, compress = True):
        self.stream = stream
        # Compression is silently disabled when zlib is unavailable.
        if not zlib:
            compress = False
        self.compress = compress

    def close(self):
        self.stream.close()

    @property
    def closed(self):
        return self.stream.closed

    def fileno(self):
        return self.stream.fileno()

    def poll(self, timeout):
        return self.stream.poll(timeout)

    def _read_block(self):
        """Read one block, treating "nothing available" (None) as end of stream."""
        block = self.stream.read()
        if block is None:
            raise EOFError('CHANNEL: Received block with wrong size')
        return block

    def recv(self):
        """Read and return one frame's payload (decompressed if needed).

        Raises:
            EOFError: missing/short header, or a payload of the wrong size.
        """
        header = self.stream.read()
        if header is None or len(header) != self.FRAME_HEADER.size:
            raise EOFError('CHANNEL: Not a header received')
        length, compressed = self.FRAME_HEADER.unpack(header)
        block_len = length + len(self.FLUSHER)
        # The sender splits the block into ceil(block_len / MAX_IO_CHUNK) chunks.
        chunk_count = len(range(0, block_len, self.MAX_IO_CHUNK))
        full_block = b''.join([self._read_block() for _ in range(chunk_count)])
        if len(full_block) != block_len:
            raise EOFError('CHANNEL: Received block with wrong size')
        data = full_block[:-len(self.FLUSHER)]
        if compressed:
            data = zlib.decompress(data)
        return data

    def send(self, data):
        """Write *data* as one frame, compressing it above COMPRESSION_THRESHOLD."""
        if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
            compressed = 1
            data = zlib.compress(data, self.COMPRESSION_LEVEL)
        else:
            compressed = 0
        header = self.FRAME_HEADER.pack(len(data), compressed)
        self.stream.write(header)
        buf = data + self.FLUSHER
        step = self.MAX_IO_CHUNK
        for chunk_start in range(0, len(buf), step):
            # Slice end is relative to chunk_start; the old fixed end
            # (MAX_IO_CHUNK) made every chunk after the first empty.
            self.stream.write(buf[chunk_start:chunk_start + step])


import rpyc.utils.server
import rpyc.utils.factory

# `rpyc.Service` is a class exported by the rpyc package, not a submodule,
# so `import rpyc.Service` fails with ModuleNotFoundError.  Import the class
# by name instead (the `rpyc` package name stays bound via the lines above).
from rpyc import Service