black-mamba/mods/stream.py

357 lines
12 KiB
Python
Raw Permalink Normal View History

2019-04-04 20:38:48 +03:00
import socketserver
import pickle
import simplecrypto
from uuid import uuid4
2019-04-04 22:25:49 +03:00
import datetime, io
2019-04-04 20:38:48 +03:00
import threading
import time
2019-04-05 09:50:06 +03:00
from struct import Struct
try:
import zlib
except:
zlib = None
2019-04-04 20:38:48 +03:00
# Registry of live Peer objects, keyed by the remote address tuple.
peers = {}

# UDP port the module-level server listens on.
PORT = 16386

# Transport tuning knobs.
RETRANSMIT_RETRIES = 3        # retransmissions allowed before the peer is dropped
DATAGRAM_MAX_SIZE = 9000      # bytes read from each incoming datagram
RAW_DATA_MAX_SIZE = 8000
PACKET_NUM_SEQ_TTL = 300      # seconds before outgoing sequence numbering restarts
SOCK_SEND_TIMEOUT = 60        # seconds to wait for an outstanding confirmation

# Values of the ``type`` field carried by every packet.
# Handshake / key exchange:
PACKET_TYPE_HELLO = 0x00
PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01
PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02
PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03
# Data transfer:
PACKET_TYPE_PACKET = 0xa0
PACKET_TYPE_CONFIRM_RECV = 0xa1
# Connection teardown:
PACKET_TYPE_GOODBUY = 0xff
class InvalidPacket(Exception):
    """Raised when a received payload cannot be decoded into a Packet."""
2019-04-05 09:50:06 +03:00
class OldPacket(Exception):
    """Raised for a data packet that predates the current sequence epoch."""
2019-04-04 20:38:48 +03:00
def pickle_data(data):
    """Serialise *data* to bytes using pickle protocol 4."""
    serialized = pickle.dumps(data, 4)
    return serialized
2019-04-04 22:25:49 +03:00
################
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that refuses every global except ``datetime.datetime``."""

    def find_class(self, module, name):
        """Resolve a pickled global reference.

        Returns:
            The ``datetime.datetime`` class when that is what was requested.

        Raises:
            pickle.UnpicklingError: for any other ``module``/``name`` pair.
        """
        if (module, name) != ("datetime", "datetime"):
            # Anything outside the allow-list is rejected outright.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        return getattr(datetime, name)
def restricted_pickle_loads(s):
    """Deserialise the bytes *s* through RestrictedUnpickler.

    Mirrors ``pickle.loads()`` but only the globals permitted by
    ``RestrictedUnpickler.find_class`` can be resolved.
    """
    source = io.BytesIO(s)
    unpickler = RestrictedUnpickler(source)
    return unpickler.load()
################
2019-04-04 20:38:48 +03:00
# From rpyc.lib
class Timeout:
    """Deadline helper built on ``time.time()``.

    ``None`` or a negative number produces a timeout that never expires.
    Passing another ``Timeout`` copies its absolute deadline.
    """

    def __init__(self, timeout):
        if isinstance(timeout, Timeout):
            # Copy: share the same deadline rather than restarting the clock.
            self.finite, self.tmax = timeout.finite, timeout.tmax
            return
        self.finite = timeout is not None and timeout >= 0
        if self.finite:
            self.tmax = time.time() + timeout
        else:
            self.tmax = None

    def expired(self):
        """Return True once a finite deadline has been reached."""
        if not self.finite:
            return False
        return time.time() >= self.tmax

    def timeleft(self):
        """Return the seconds remaining (never below 0), or None if infinite."""
        if not self.finite:
            return None
        remaining = self.tmax - time.time()
        return max(0, remaining)

    def sleep(self, interval):
        """Sleep for *interval* seconds, but not past a finite deadline."""
        duration = interval
        if self.finite:
            duration = min(interval, self.timeleft())
        time.sleep(duration)
class Packet:
    """Decoded view of one received packet.

    Attributes set from the unpickled dictionary: ``sid``, ``type``,
    ``reset_timestamp``, ``num`` and ``data``.
    """

    def __init__(self, packet_payload):
        """Decode *packet_payload* (bytes) into packet fields.

        Raises:
            InvalidPacket: if the payload cannot be unpickled or any of the
                required keys is missing.
        """
        try:
            # NOTE(security): the payload arrives from the network. It is only
            # ever decoded through the restricted unpickler; never switch this
            # to plain pickle.loads().
            d = restricted_pickle_loads(packet_payload)
            self.sid = d['sid']
            self.type = d['type']
            self.reset_timestamp = d['reset_timestamp']
            self.num = d['num']
            self.data = d['data']
        except Exception as err:
            # Any decoding problem means the packet is unusable. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate.
            raise InvalidPacket from err
class Peer:
    """Connection state for one remote endpoint of the encrypted UDP transport.

    Tracks the session id, the local RSA key pair and the peer's public key,
    a buffer of received payload blocks, and the sequence / confirmation
    bookkeeping for outgoing and incoming data packets.

    NOTE(review): while a confirmation is outstanding, every other incoming
    packet only triggers a retransmission and is otherwise dropped, so two
    sides sending data at the same moment will stall each other.
    """

    def __init__(self, sock, endpoint):
        """Create an idle peer.

        Args:
            sock: socket object whose ``sendto`` is used for outgoing packets.
            endpoint: address passed to ``sock.sendto``.
        """
        self.sid = None
        self.sock = sock
        self.endpoint = endpoint
        self.my_key = None
        self.peer_pub_key = None
        self.buf = []
        # (reset_timestamp, num) of the sent packet awaiting confirmation.
        self.confirm_wait_packet = None
        # Serialised copy of that packet, kept for retransmission.
        self.last_packet = None
        self.request_lock = threading.Lock()
        self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL)
        self.last_sent_packet_num = -1
        self.last_sent_packet_num_reset_time = datetime.datetime.utcnow()
        self.last_received_packet_num = -1
        self.last_received_packet_num_reset_time = None
        self.retransmit_count = 0

    def next_packet_num(self):
        """Return the next outgoing sequence number.

        Once ``num_seq_ttl`` has elapsed since the last reset, numbering
        restarts at 0 and the reset time is moved to now.
        """
        # Naive UTC on purpose: an aware datetime would pickle a tzinfo
        # global that RestrictedUnpickler rejects.
        new_time = datetime.datetime.utcnow()
        if (new_time - self.last_sent_packet_num_reset_time) >= self.num_seq_ttl:
            self.last_sent_packet_num = -1
            # Start a new epoch; without this every later call would reset
            # the counter again and keep returning 0.
            self.last_sent_packet_num_reset_time = new_time
        self.last_sent_packet_num += 1
        return self.last_sent_packet_num

    def poll(self):
        """Return True when at least one received block is buffered."""
        return bool(self.buf)

    def get_next_block(self):
        """Pop and return the oldest buffered block, or None if empty."""
        if not self.buf:
            return None
        return self.buf.pop()

    def put_block(self, data):
        """Queue a received block; blocks come back out in arrival order."""
        self.buf.insert(0, data)

    def send(self, d, encrypted=False, confirm=False):
        """Complete, serialise and transmit the packet dictionary *d*.

        Missing ``sid``/``num``/``reset_timestamp``/``data`` keys are filled
        in (note that *d* is modified in place).

        Args:
            d: packet dictionary; must contain ``type``.
            encrypted: encrypt ``d['data']`` with the peer's public key.
            confirm: remember the packet so it can be retransmitted until
                the peer confirms it.
        """
        if 'sid' not in d:
            d['sid'] = self.sid
        if 'num' not in d:
            d['num'] = None
        if 'reset_timestamp' not in d:
            d['reset_timestamp'] = None
        if 'data' not in d:
            d['data'] = b''
        if encrypted:
            d['data'] = self.peer_pub_key.encrypt_raw(d['data'])
        data = pickle_data(d)
        if confirm:
            self.last_packet = data
            self.confirm_wait_packet = (d['reset_timestamp'], d['num'])
            # Each awaited packet gets its own retry budget.
            self.retransmit_count = 0
        self.sock.sendto(data, self.endpoint)

    def mark_packet(self, d):
        """Stamp *d* with the next sequence number and current epoch."""
        d['num'] = self.next_packet_num()
        d['reset_timestamp'] = self.last_sent_packet_num_reset_time
        return d

    def retransmit(self):
        """Resend the packet awaiting confirmation.

        Raises:
            EOFError: once more than RETRANSMIT_RETRIES attempts were made.
        """
        self.retransmit_count += 1
        if self.retransmit_count > RETRANSMIT_RETRIES:
            raise EOFError('retransmit limit reached')
        self.sock.sendto(self.last_packet, self.endpoint)

    def reply_my_pub_key(self, packet):
        """Answer a public-key request with a freshly generated key.

        Stores the key carried by *packet* as the peer's public key and sends
        our own public key back, encrypted with the peer's key.

        Raises:
            EOFError: if *packet* does not carry a usable public key.
        """
        try:
            self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data)
        except Exception as err:
            raise EOFError('invalid pubkey data') from err
        self.my_key = simplecrypto.RsaKeypair()
        d = {
            'type': PACKET_TYPE_PEER_PUB_KEY_REPLY,
            'data': self.my_key.publickey.serialize()
        }
        self.send(d, encrypted=True)

    def request_peer_bub_key(self, packet):
        """Adopt the session id of *packet* and ask the peer for its key.

        A new local key pair is generated and its public half is sent,
        unencrypted, in the request.
        """
        self.sid = packet.sid
        self.my_key = simplecrypto.RsaKeypair()
        d = {
            'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST,
            'data': self.my_key.publickey.serialize()
        }
        self.send(d)

    def confirm_packet_recv(self, packet):
        """Record that the peer confirmed the packet we were waiting on.

        Nothing is sent back: echoing a confirmation would reach a peer that
        is not waiting for one and make it drop the connection.
        """
        self.confirm_wait_packet = None
        self.last_packet = None
        self.retransmit_count = 0

    def check_received_packet(self, packet):
        """Validate the sequence position of a received data packet.

        Raises:
            OldPacket: the packet belongs to an earlier epoch, or repeats a
                sequence number that was already accepted.
            EOFError: the packet skips ahead of the expected number.
        """
        known_epoch = self.last_received_packet_num_reset_time
        if known_epoch:
            if known_epoch > packet.reset_timestamp:
                raise OldPacket('packet from past')
            if known_epoch < packet.reset_timestamp:
                # The sender restarted its numbering, so restart ours too;
                # its first packet of the new epoch carries number 0.
                self.last_received_packet_num_reset_time = packet.reset_timestamp
                self.last_received_packet_num = -1
            expected = self.last_received_packet_num + 1
            if packet.num != expected:
                if isinstance(packet.num, int) and packet.num < expected:
                    # A retransmission of something already delivered.
                    raise OldPacket('duplicate packet')
                raise EOFError('packet sequence corrupt')
        else:
            self.last_received_packet_num_reset_time = packet.reset_timestamp
        self.last_received_packet_num = packet.num

    def send_recv_confirmation(self, packet):
        """Tell the peer that its data packet *packet* has arrived."""
        d = {
            'type': PACKET_TYPE_CONFIRM_RECV,
            'num': packet.num,
            'reset_timestamp': packet.reset_timestamp
        }
        self.send(d)

    def hello(self):
        """Start a new session: pick a fresh session id and announce it."""
        self.sid = uuid4().hex
        # Go through send() so the datagram is addressed to the endpoint and
        # carries every field the receiving Packet parser requires.
        self.send({'type': PACKET_TYPE_HELLO})

    def _handle_handshake(self, packet):
        """Process a key-exchange packet; return True when it was consumed."""
        if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY:
            try:
                self.peer_pub_key = simplecrypto.RsaPublicKey(
                    self.my_key.decrypt_raw(packet.data))
            except Exception as err:
                raise EOFError('create pubkey failed') from err
            return True
        if packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST:
            self.reply_my_pub_key(packet)
            return True
        if packet.type == PACKET_TYPE_HELLO:
            self.request_peer_bub_key(packet)
            return True
        return False

    def _accept_data(self, packet):
        """Validate, decrypt, buffer and confirm one data packet."""
        try:
            self.check_received_packet(packet)
        except OldPacket:
            # Already seen; our earlier confirmation may have been lost, so
            # confirm again instead of letting the sender keep retrying.
            self.send_recv_confirmation(packet)
            return
        try:
            raw = self.my_key.decrypt_raw(packet.data)
        except Exception as err:
            raise EOFError('decrypt packet error') from err
        self.put_block(raw)
        self.send_recv_confirmation(packet)

    def recv_packet(self, packet_payload):
        """Handle one raw datagram received from this peer.

        Raises:
            EOFError: the payload is not a valid packet, the peer said
                goodbye, or the protocol state no longer allows the
                connection to continue.
        """
        with self.request_lock:
            try:
                packet = Packet(packet_payload)
            except InvalidPacket as err:
                raise EOFError('invalid packet') from err
            if packet.type == PACKET_TYPE_GOODBUY:
                raise EOFError('connection closed')
            if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid):
                # Unknown session: ask the other side to start over.
                self.hello()
                return
            if not self.peer_pub_key and self._handle_handshake(packet):
                return
            if self.confirm_wait_packet:
                awaited = (packet.reset_timestamp, packet.num) == self.confirm_wait_packet
                if awaited and packet.type == PACKET_TYPE_CONFIRM_RECV:
                    self.confirm_packet_recv(packet)
                else:
                    self.retransmit()
                return
            if packet.type == PACKET_TYPE_CONFIRM_RECV:
                # Late or repeated confirmation for a packet that is no
                # longer awaited; nothing to do.
                return
            if packet.type != PACKET_TYPE_PACKET:
                raise EOFError('connection lost')
            self._accept_data(packet)

    def send_packet(self, raw):
        """Encrypt and send *raw* as a data packet that must be confirmed.

        Waits up to SOCK_SEND_TIMEOUT seconds for a previously sent packet
        to be confirmed first.

        Raises:
            EOFError: the previous packet was not confirmed in time.
        """
        if self.confirm_wait_packet:
            timeout = Timeout(SOCK_SEND_TIMEOUT)
            while self.confirm_wait_packet and not timeout.expired():
                # Yield the CPU instead of spinning while we wait.
                timeout.sleep(0.01)
            if self.confirm_wait_packet:
                raise EOFError('connection lost')
        d = {
            'type': PACKET_TYPE_PACKET,
            'data': raw
        }
        self.send(self.mark_packet(d), encrypted=True, confirm=True)
2019-04-04 20:38:48 +03:00
class UDPRequestHandler(socketserver.DatagramRequestHandler):
    """Route every incoming datagram to the Peer for its source address."""

    def finish(self):
        """Send nothing back.

        The base class would reply with the contents of ``wfile``; replies
        are sent by the Peer itself instead.
        """
        pass

    def handle(self):
        """Read one datagram and hand it to the matching Peer.

        A Peer is created on first contact. When the Peer reports EOFError
        it is removed from the registry.
        """
        datagram = self.rfile.read(DATAGRAM_MAX_SIZE)
        peer_addr = self.client_address
        peer = peers.get(peer_addr)
        if peer is None:
            # Handlers run on separate threads: setdefault ensures that two
            # first datagrams from one address end up sharing a single Peer.
            peer = peers.setdefault(peer_addr, Peer(self.socket, peer_addr))
        try:
            peer.recv_packet(datagram)
        except EOFError:
            # pop() rather than del: another thread (or a stream's close())
            # may already have removed the entry.
            peers.pop(peer_addr, None)
class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer):
    """UDP server that processes each request in its own thread."""
# Module-level UDP server bound to every interface on PORT. Its socket is
# also the one EncryptedUDPStream._connect hands to outgoing Peers.
# NOTE: the server is created and started as a side effect of importing
# this module.
udpserver = ThreadingUDPServer(('0.0.0.0', PORT), UDPRequestHandler)
# serve_forever() blocks, so it runs on a dedicated thread.
# NOTE(review): the thread is not explicitly marked daemon and
# udpserver.shutdown() is never called in this file — confirm how the
# process is expected to exit.
udpserver_thread = threading.Thread(target=udpserver.serve_forever)
udpserver_thread.start()
2019-04-04 22:25:49 +03:00
class EncryptedUDPStream:
    """Stream-like facade over the Peer registered for one remote address.

    Every operation looks the Peer up in the module-level ``peers``
    registry; once the Peer is gone the stream counts as closed and I/O
    raises EOFError.
    """

    def __init__(self, sock, peer_addr):
        self.peer_addr = peer_addr
        self.sock = sock

    @classmethod
    def _connect(cls, host, port):
        """Register a Peer for ``(host, port)``, greet it, return the socket."""
        endpoint = (host, port)
        peer = peers[endpoint] = Peer(udpserver.socket, endpoint)
        peer.hello()
        return udpserver.socket

    @classmethod
    def connect(cls, host, port, **kwargs):
        """Return a stream connected to ``(host, port)``.

        Extra keyword arguments are accepted and ignored.
        """
        return cls(cls._connect(host, port), (host, port))

    def poll(self, timeout):
        """Wait until a received block is available.

        Args:
            timeout: seconds to wait, a Timeout, or None to wait forever.

        Returns:
            True if a block is ready, False if the timeout ran out first.

        Raises:
            EOFError: the peer is no longer registered.
        """
        timeout = Timeout(timeout)
        while True:
            try:
                ready = peers[self.peer_addr].poll()
            except Exception as err:
                raise EOFError('peer is gone') from err
            # Always check at least once, so a zero timeout is a plain
            # non-blocking poll and an infinite one keeps waiting.
            if ready or timeout.expired():
                return ready
            timeout.sleep(0.01)

    def close(self):
        """Drop the peer from the registry; closing twice is harmless."""
        peers.pop(self.peer_addr, None)

    @property
    def closed(self):
        """True when no Peer is registered for this stream's address."""
        return self.peer_addr not in peers

    def fileno(self):
        """Return the socket's file descriptor.

        Raises:
            EOFError: the descriptor is unavailable; the stream is closed
                as a side effect.
        """
        try:
            return self.sock.fileno()
        except Exception as err:
            self.close()
            raise EOFError('socket is unusable') from err

    def read(self):
        """Return the next received block, or None if nothing is buffered.

        Raises:
            EOFError: the peer is no longer registered.
        """
        try:
            buf = peers[self.peer_addr].get_next_block()
        except Exception as err:
            raise EOFError('peer is gone') from err
        return buf

    def write(self, data):
        """Send *data* to the peer as one confirmed, encrypted packet.

        Raises:
            EOFError: the peer is gone or the packet could not be sent.
        """
        try:
            peers[self.peer_addr].send_packet(data)
        except Exception as err:
            raise EOFError('write failed') from err
2019-04-05 09:50:06 +03:00
class Channel(object):
    """Frame-oriented channel on top of a block stream.

    A frame is one header block (payload length and a compression flag,
    packed with FRAME_HEADER) followed by the payload plus FLUSHER, split
    into blocks of at most MAX_IO_CHUNK bytes.
    """

    MAX_IO_CHUNK = 8000
    COMPRESSION_THRESHOLD = 3000
    COMPRESSION_LEVEL = 1
    FRAME_HEADER = Struct("!LB")
    FLUSHER = b'\n'
    __slots__ = ["stream", "compress"]

    def __init__(self, stream, compress = True):
        """Wrap *stream*; compression is turned off when zlib is missing."""
        self.stream = stream
        if not zlib:
            compress = False
        self.compress = compress

    def close(self):
        """Close the underlying stream."""
        self.stream.close()

    @property
    def closed(self):
        """Whether the underlying stream reports itself closed."""
        return self.stream.closed

    def fileno(self):
        """Return the underlying stream's file descriptor."""
        return self.stream.fileno()

    def poll(self, timeout):
        """Delegate to the underlying stream's ``poll``."""
        return self.stream.poll(timeout)

    def recv(self):
        """Read one frame and return its (decompressed) payload.

        Raises:
            EOFError: the header block is missing or malformed, a payload
                block is missing, or the payload size does not match the
                header.
        """
        header = self.stream.read()
        if header is None or len(header) != self.FRAME_HEADER.size:
            raise EOFError('CHANNEL: Not a header received')
        length, compressed = self.FRAME_HEADER.unpack(header)
        block_len = length + len(self.FLUSHER)
        chunks = []
        for _ in range(0, block_len, self.MAX_IO_CHUNK):
            chunk = self.stream.read()
            if chunk is None:
                # The stream had nothing buffered: report a short frame
                # instead of failing with a TypeError while joining.
                raise EOFError('CHANNEL: Received block with wrong size')
            chunks.append(chunk)
        full_block = b''.join(chunks)
        if len(full_block) != block_len:
            raise EOFError('CHANNEL: Received block with wrong size')
        data = full_block[:-len(self.FLUSHER)]
        if compressed:
            data = zlib.decompress(data)
        return data

    def send(self, data):
        """Write *data* as one frame, compressing it when worthwhile."""
        if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
            compressed = 1
            data = zlib.compress(data, self.COMPRESSION_LEVEL)
        else:
            compressed = 0
        header = self.FRAME_HEADER.pack(len(data), compressed)
        self.stream.write(header)
        buf = data + self.FLUSHER
        for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK):
            # The slice end is relative to chunk_start; a fixed end of
            # MAX_IO_CHUNK would make every chunk after the first empty.
            self.stream.write(buf[chunk_start:chunk_start + self.MAX_IO_CHUNK])
import rpyc.utils.server
import rpyc.utils.factory
import rpyc.Service