black-mamba/mods/utcp.py

542 lines
20 KiB
Python
Raw Normal View History

# based on https://github.com/ethay012/TCP-over-UDP
import random
import socket
import pickle
import threading
import io
import hashlib
import simplecrypto
from struct import Struct
import uuid
2019-04-08 18:10:20 +03:00
import bisect
# Payload bytes carried by one DATA packet (chunk size used by data_divider).
DATA_DIVIDE_LENGTH = 8000
# Space budgeted for the pickled packet metadata ("service info").
PACKET_HEADER_SIZE = 512
DATA_LENGTH = DATA_DIVIDE_LENGTH
# Largest datagram passed to recvfrom(): metadata budget + payload + 272,
# because an encrypted chunk always comes out 272 bytes longer.
SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272
LAST_CONNECTION, FIRST = -1, 0
PACKET_END = b'___+++^^^END^^^+++___'
2019-04-08 08:29:24 +03:00
class Connection:
    """Per-peer state: sequence counters, optional RSA keys and receive buffers."""

    SMALLEST_STARTING_SEQ = 0
    HIGHEST_STARTING_SEQ = 4294967295

    def __init__(self, remote, encrypted=False):
        """Create state for the peer at address ``remote``.

        When ``encrypted`` is true a fresh RSA key pair is generated and its
        serialized public half is stored in ``pubkey``; otherwise ``my_key``
        and ``pubkey`` are None.
        """
        self.peer_addr = remote
        self.seq = Connection.gen_starting_seq_num()
        # Highest sequence number accepted from the peer so far.
        self.recv_seq = 0
        # Set via set_ack(); defaulted so the attribute always exists.
        self.ack = None
        self.my_key = None
        # Defaulted so unencrypted connections also expose the attribute.
        self.pubkey = None
        if encrypted:
            self.my_key = simplecrypto.RsaKeypair()
            self.pubkey = self.my_key.publickey.serialize()
        # The peer's public key; filled in during the handshake when encrypted.
        self.peer_pub = None
        self.recv_lock = threading.Lock()
        # Guards seq_inc(): packets are sent from both the caller's thread and
        # the background receive thread (ACKs), each bumping ``seq``.
        self.send_lock = threading.Lock()
        # Received packets awaiting consumption, keyed by packet type.
        self.packet_buffer = {
            'ACK': [],
            'SYN-ACK': [],
            'DATA': [],
            'FIN-ACK': [],
        }

    @staticmethod
    def gen_starting_seq_num():
        """Return a random initial sequence number in [0, 2**32 - 1]."""
        return random.randint(Connection.SMALLEST_STARTING_SEQ, Connection.HIGHEST_STARTING_SEQ)

    def seq_inc(self, inc=1):
        """Advance the local sequence number by ``inc`` and return the new value.

        ``self.seq += inc`` is a read-modify-write, so two concurrent senders
        could otherwise be handed the same number; the lock prevents that.
        """
        with self.send_lock:
            self.seq += inc
            return self.seq

    def set_ack(self, ack):
        """Record ``ack`` and return it unchanged."""
        self.ack = ack
        return ack
2019-04-08 18:10:20 +03:00
class UTCPPacket:
    """Base class for every UTCP packet; packets are ordered by ``seq``.

    The ordering is what lets ``bisect.insort`` keep each receive buffer
    sorted. Comparing against a non-packet returns NotImplemented, so
    ``==`` falls back to identity and ``<``/``>`` raise TypeError instead
    of an AttributeError on ``other.seq``.
    """

    def __cmp__(self, other):
        # Python 2 protocol; Python 3 never calls it. Kept for compatibility.
        return (self.seq > other.seq) - (self.seq < other.seq)

    def __lt__(self, other):
        if not isinstance(other, UTCPPacket):
            return NotImplemented
        return self.seq < other.seq

    def __gt__(self, other):
        if not isinstance(other, UTCPPacket):
            return NotImplemented
        return self.seq > other.seq

    def __eq__(self, other):
        if not isinstance(other, UTCPPacket):
            return NotImplemented
        return self.seq == other.seq

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None
2019-04-08 18:10:20 +03:00
class Ack(UTCPPacket):
    """Acknowledges the packet whose ``id`` equals ``id_``."""

    type = 'ACK'

    def __init__(self, id_):
        self.seq = 0  # real value is stamped on when the packet is sent
        self.id = id_
2019-04-08 18:10:20 +03:00
class Fin(UTCPPacket):
    """Connection-teardown request, identified by a fresh random id."""

    type = 'FIN'

    def __init__(self):
        self.seq = 0
        self.id = uuid.uuid4().bytes  # 16 random bytes
2019-04-08 18:10:20 +03:00
class FinAck(UTCPPacket):
    """Confirms a FIN; ``id_`` is the id of the FIN being confirmed."""

    type = 'FIN-ACK'

    def __init__(self, id_):
        self.seq = 0
        self.id = id_
2019-04-08 18:10:20 +03:00
class Syn(UTCPPacket):
    """Connection-opening handshake packet with a fresh random id."""

    type = 'SYN'
    checksum = None  # stays None unless set_pub() is called

    def __init__(self):
        self.seq = 0
        self.id = uuid.uuid4().bytes

    def set_pub(self, pubkey):
        """Attach ``pubkey`` together with its TCP.checksum digest."""
        digest = TCP.checksum(pubkey)
        self.checksum = digest
        self.pubkey = pubkey
2019-04-08 18:10:20 +03:00
class SynAck(UTCPPacket):
    """Handshake reply to a SYN; ``id_`` is the id of that SYN."""

    type = 'SYN-ACK'
    checksum = None  # stays None unless set_pub() is called

    def __init__(self, id_):
        self.seq = 0
        self.id = id_

    def set_pub(self, pubkey):
        """Attach ``pubkey`` together with its TCP.checksum digest."""
        digest = TCP.checksum(pubkey)
        self.checksum = digest
        self.pubkey = pubkey
2019-04-08 18:10:20 +03:00
class Packet(UTCPPacket):
    """DATA packet carrying one chunk of application payload."""

    type = 'DATA'

    def __init__(self):
        self.id = uuid.uuid4().bytes
        self.checksum = 0
        self.data = b''
        self.seq = 0

    def __repr__(self):
        return f'Packet(seq={self.seq})'

    def __str__(self):
        # '%d' rendered the UUID as a 39-digit integer; show its canonical form.
        return 'UUID: %s, DATA:%s' % (uuid.UUID(bytes=self.id), self.data)

    def set_data(self, data):
        """Store ``data`` (bytes) and its TCP.checksum digest for verification."""
        self.checksum = TCP.checksum(data)
        self.data = data
2019-04-08 17:48:13 +03:00
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler whose only resolvable globals are the six UTCP packet classes.

    It is used on raw datagrams, so ``find_class`` rejects every other
    global (including anything from ``builtins``) with UnpicklingError.
    """

    def find_class(self, module, name):
        permitted = ('Packet', 'Ack', 'Fin', 'FinAck', 'Syn', 'SynAck')
        if module == 'builtins' or name not in permitted:
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        # Always resolve to this module's own class, whatever module path
        # the sender's pickle recorded.
        return globals()[name]
def restricted_pickle_loads(s):
    """Deserialize the bytes ``s`` using :class:`RestrictedUnpickler`."""
    stream = io.BytesIO(s)
    unpickler = RestrictedUnpickler(stream)
    return unpickler.load()
class UTCPChannel:
    """Length-prefixed message framing over a socket-like object.

    A frame is a 4-byte big-endian length (``HEADER``), the payload, and
    ``TERMINATOR``. ``sock`` must provide ``recv(n)``, ``send(data)`` and
    ``poll(timeout)``.
    """

    HEADER = Struct('!I')
    TERMINATOR = b'\n'
    POLL_TIMEOUT = 0.1

    def __init__(self, sock):
        self.sock = sock

    def poll(self):
        """Return ``sock.poll(POLL_TIMEOUT)``."""
        return self.sock.poll(self.POLL_TIMEOUT)

    def _recv_exact(self, count):
        """Read exactly ``count`` bytes from ``sock``, looping over short reads.

        The UTCP sockets return whatever is buffered (possibly less than
        asked), so a single recv() is not enough.

        Raises:
            EOFError: ``sock.recv`` returned nothing before ``count`` was met.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.sock.recv(remaining)
            if not chunk:
                raise EOFError('Connection closed while reading a frame')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def recv(self):
        """Read one complete frame and return its payload."""
        header = self._recv_exact(self.HEADER.size)
        data_len, = self.HEADER.unpack(header)
        terminated_data = self._recv_exact(data_len + len(self.TERMINATOR))
        return terminated_data[:-len(self.TERMINATOR)]

    def send(self, data):
        """Frame ``data`` and hand it to ``sock.send`` in one call."""
        header = self.HEADER.pack(len(data))
        self.sock.send(header + data + self.TERMINATOR)
class ConnectedSOCK(object):
    """Socket-like view of one accepted peer on a listening ``TCP`` object.

    ``send``/``recv``/``close`` are forwarded to ``low_sock`` together with
    ``client_addr``; any attribute not found on this object is read from
    ``low_sock``.
    """

    def __init__(self, low_sock, client_addr):
        self.client_addr = client_addr
        self.low_sock = low_sock
        self.channel = UTCPChannel(self)

    def __getattr__(self, att):
        """Fall back to ``low_sock`` for attributes this object lacks."""
        # Only invoked after normal lookup fails, so own attributes win.
        # Refusing 'low_sock' itself avoids infinite recursion on an
        # instance whose __init__ has not run (e.g. during copy/unpickle).
        if att == 'low_sock':
            raise AttributeError(att)
        return getattr(self.low_sock, att)

    def getpeername(self):
        return self.client_addr

    def send(self, data):
        """Send ``data`` to this peer; raises EOFError once closed."""
        if self.closed:
            raise EOFError
        return self.low_sock.send(data, self.client_addr)

    def recv(self, size):
        """Receive up to ``size`` bytes from this peer; raises EOFError once closed."""
        if self.closed:
            raise EOFError
        return self.low_sock.recv(size, self.client_addr)

    @property
    def closed(self):
        """True once the UDP socket is closed or this peer has been dropped."""
        # NOTE: relies on socket.socket's private ``_closed`` flag.
        return self.low_sock.own_socket._closed or self.client_addr not in self.low_sock.connections

    def close(self):
        if self.client_addr in self.low_sock.connections:
            self.low_sock.close(self.client_addr)

    def shutdown(self, *a, **kw):
        """Alias for close(); arguments are ignored."""
        self.close()

    def poll(self, timeout):
        """Return True if DATA is buffered now or arrives within ``timeout`` seconds."""
        if self.closed:
            return False
        conn = self.low_sock.connections.get(self.client_addr)
        if conn is None:
            # Dropped by another thread after the closed check.
            return False
        if self._has_data(conn):
            return True
        self.low_sock.incoming_packet_event.wait(timeout)
        return self._has_data(conn)

    @staticmethod
    def _has_data(conn):
        """Return True if ``conn`` has at least one DATA packet queued."""
        with conn.recv_lock:
            return bool(conn.packet_buffer['DATA'])
class TCP(object):
    """Connection-oriented, acknowledged transport built on one UDP socket.

    ``listen``/``connect`` start a daemon receive thread that unpickles every
    datagram and files it into per-connection buffers (``sort_answers``); the
    public calls (``accept``, ``send``, ``recv``, ``close``) consume those
    buffers. DATA, SYN, SYN-ACK and FIN packets are retransmitted up to three
    times until the matching reply arrives. With ``encrypted=True`` the peers
    exchange RSA public keys (simplecrypto) during the handshake and every
    DATA chunk is encrypted. Socket-like helpers (bind, fileno, settimeout,
    getsockname, ...) are provided for callers that expect a socket.
    """

    host = None
    port = None
    client = False

    def __init__(self, af_type=None, sock_type=None, encrypted=False):
        """Create an unbound endpoint.

        ``af_type`` and ``sock_type`` are accepted but ignored; the underlying
        socket is always ``AF_INET``/``SOCK_DGRAM``.
        """
        self.encrypted = encrypted
        # 0 until listen()/connect(); background loops run while it is truthy.
        self.status = 0
        self.incoming_packet_event = threading.Event()
        self.new_conn_event = threading.Event()
        self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.settimeout()
        self.connection_lock = threading.Lock()
        self.queue_lock = threading.Lock()
        self.channel = None
        # peer address -> Connection
        self.connections = {}
        # (SYN packet, Connection) pairs waiting for accept()
        self.connection_queue = []
        # peer address -> SYN packet not yet taken by listen_handler
        self.syn_received = {}

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _has_buffered_data(conn):
        """Return True if ``conn`` has at least one DATA packet queued."""
        with conn.recv_lock:
            return bool(conn.packet_buffer['DATA'])

    def _resolve_connection(self, connection):
        """Map None to the first connection and reject unknown addresses.

        Raises IndexError if ``connection`` is None and nothing is connected,
        EOFError if ``connection`` is an address that is not connected.
        """
        if connection in self.connections:
            return connection
        if connection is None:
            return list(self.connections.keys())[0]
        raise EOFError('Connection not in connected devices')

    def _reset_client(self):
        """Close the UDP socket and forget all connections after a failed connect."""
        self.own_socket.close()
        self.connections = {}
        self.status = 0

    # --------------------------------------------------------- socket-like API

    def poll(self, timeout):
        """Return True if the first connection has DATA now or within ``timeout`` seconds."""
        conns = list(self.connections.values())
        if not conns:
            return False
        conn = conns[0]
        if self._has_buffered_data(conn):
            return True
        self.incoming_packet_event.wait(timeout)
        return self._has_buffered_data(conn)

    def get_free_port(self):
        """Ask the OS for a currently unused UDP port number."""
        # The context manager closes the probe socket even if bind() fails.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.bind(('', 0))
            return probe.getsockname()[1]

    def setsockopt(self, *a, **kw):
        """Accepted for socket compatibility; does nothing."""
        pass

    def settimeout(self, timeout=5):
        self.own_socket.settimeout(timeout)

    def setblocking(self, mode):
        self.own_socket.setblocking(mode)

    def __repr__(self):
        return 'TCP()'

    def __str__(self):
        return 'Connections: %s' % str(self.connections)

    def getsockname(self):
        return self.own_socket.getsockname()

    def getpeername(self):
        """Return the address of the first connection; EOFError if none."""
        peers = list(self.connections.keys())
        if not peers:
            raise EOFError('Not connected')
        return peers[0]

    def bind(self, addr):
        self.host = addr[0]
        self.port = addr[1]
        self.own_socket.bind(addr)

    def send(self, data, connection=None):
        """Reliably send ``data`` to ``connection`` (default: first connection).

        ``data`` is split into DATA packets, each acknowledged before the next
        is sent. Returns ``len(data)``.

        Raises:
            EOFError: not connected, unknown connection, peer stopped
                acknowledging, or the UDP socket failed.
        """
        if self.closed:
            raise EOFError
        try:
            connection = self._resolve_connection(connection)
            conn = self.connections[connection]
            for data_part in TCP.data_divider(data):
                data_chunk = conn.peer_pub.encrypt_raw(data_part) if self.encrypted else data_part
                packet = Packet()
                packet.set_data(data_chunk)
                self.__send_packet(connection, packet, retransmit=True)
            return len(data)
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error) from error

    def recv(self, size, connection=None):
        """Return up to ``size`` buffered bytes from ``connection`` (default: first).

        Blocks until at least one DATA packet is buffered.

        Raises:
            EOFError: not connected, unknown connection, shutting down, the
                connection was dropped while waiting, or the socket failed.
        """
        if self.closed:
            raise EOFError
        try:
            connection = self._resolve_connection(connection)
            data = self.find_correct_packet('DATA', connection, size)
            if not self.status:
                raise EOFError('Disconnecting')
            if data is None:
                # find_correct_packet gave up because the connection vanished.
                raise EOFError('Connection closed')
            return data
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error) from error

    def __send_packet(self, peer_addr, packet, retransmit=False, wait_cond='ACK'):
        """Stamp ``packet`` with the next sequence number and send it to ``peer_addr``.

        Without ``retransmit`` the datagram is sent once and None is returned.
        With ``retransmit`` it is sent up to three times, each time waiting for
        a ``wait_cond`` packet carrying the same id; that reply is returned.

        Raises:
            EOFError: no reply after three attempts (the connection is dropped).
        """
        conn = self.connections[peer_addr]
        packet.seq = conn.seq_inc()
        packet_to_send = pickle.dumps(packet)
        if not retransmit:
            self.own_socket.sendto(packet_to_send, peer_addr)
            return None
        for _attempt in range(3):
            try:
                self.own_socket.sendto(packet_to_send, peer_addr)
                answer = self.find_correct_packet(wait_cond, peer_addr, want_id=packet.id)
            except socket.timeout:
                # Count a timeout as a failed attempt so retries stay bounded.
                continue
            if answer:
                return answer
        self.drop_connection(peer_addr)
        raise EOFError('Connection lost')

    # ------------------------------------------------------ server-side setup

    def listen_handler(self, max_connections):
        """Listener-thread loop: queue a pending Connection for each new SYN."""
        try:
            while self.status:
                try:
                    found = self.find_correct_packet('SYN')
                    if found is None:
                        # Only happens once status has dropped to 0.
                        break
                    answer, address = found
                    with self.queue_lock:
                        if len(self.connection_queue) < max_connections:
                            conn = Connection(address, self.encrypted)
                            if self.encrypted:
                                try:
                                    conn.peer_pub = simplecrypto.RsaPublicKey(answer.pubkey)
                                except Exception:
                                    # A single malformed SYN must not stop the listener.
                                    continue
                            self.connection_queue.append((answer, conn))
                            self.blink_new_conn_event()
                        else:
                            if any(queued.id == answer.id for queued, _ in self.connection_queue):
                                continue
                            self.own_socket.sendto(b'Connections full', address)
                except KeyError:
                    continue
        except socket.error as error:
            raise EOFError('Something went wrong in listen_handler func! Error is: %s.' % error) from error

    def listen(self, max_connections=1):
        """Start the receive thread and a listener thread accepting SYNs."""
        self.status = 1
        self.central_receive()
        try:
            t = threading.Thread(target=self.listen_handler, args=(max_connections,))
            t.daemon = True
            t.start()
        except Exception as error:
            raise EOFError('Something went wrong in listen func! Error is: %s.' % str(error)) from error

    def stop(self):
        """Close the UDP socket and stop the background loops."""
        self.own_socket.close()
        self.status = 0

    def accept(self):
        """Finish the handshake with the next queued peer.

        Returns ``(ConnectedSOCK, peer_addr)``, or None once listening stops.
        Peers that do not acknowledge the SYN-ACK are dropped and skipped.

        Raises:
            EOFError: any unexpected error during the handshake.
        """
        while self.status:
            conn = None
            try:
                self.new_conn_event.wait(0.1)
                with self.queue_lock:
                    if not self.connection_queue:
                        continue
                    answer, conn = self.connection_queue.pop()
                # The queue lock is released before the (slow) handshake so the
                # listener thread is not blocked meanwhile.
                with self.connection_lock:
                    self.connections[conn.peer_addr] = conn
                syn_ack = SynAck(answer.id)
                if self.encrypted:
                    syn_ack.set_pub(conn.peer_pub.encrypt_raw(conn.pubkey))
                self.__send_packet(conn.peer_addr, syn_ack, retransmit=True)
                return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr
            except EOFError:
                if conn is not None and conn.peer_addr in self.connections:
                    self.close(conn.peer_addr)
                continue
            except Exception as error:
                if conn is not None and conn.peer_addr in self.connections:
                    self.close(conn.peer_addr)
                raise EOFError('Something went wrong in accept func: ' + str(error)) from error

    # ------------------------------------------------------ client-side setup

    def connect(self, server_address=('127.0.0.1', 10000)):
        """Bind to a free local port and run the SYN / SYN-ACK / ACK handshake.

        Raises:
            EOFError: the server never answered ('Remote peer unreachable') or a
                socket / key-exchange error occurred; the UDP socket is closed.
        """
        try:
            self.bind(('', self.get_free_port()))
            self.status = 1
            self.client = True
            self.central_receive()
            conn = Connection(server_address, self.encrypted)
            self.connections[server_address] = conn
            syn = Syn()
            if self.encrypted:
                syn.set_pub(conn.pubkey)
            try:
                answer = self.__send_packet(server_address, syn, retransmit=True, wait_cond='SYN-ACK')
            except EOFError:
                # Close the socket so the receive thread exits instead of
                # lingering on a connection that never came up.
                self._reset_client()
                raise EOFError('Remote peer unreachable')
            if isinstance(answer, str):  # == 'Connections full'
                raise socket.error('Server cant receive any connections right now.')
            if self.encrypted:
                try:
                    peer_pub = conn.my_key.decrypt_raw(answer.pubkey)
                    conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub)
                except Exception:
                    raise socket.error('Decrypt peer public key error')
            self.__send_packet(server_address, Ack(answer.id))
            self.channel = UTCPChannel(self)
        except socket.error as error:
            self._reset_client()
            raise EOFError('The socket was closed. Error:' + str(error)) from error

    def fileno(self):
        return self.own_socket.fileno()

    @property
    def closed(self):
        """True when there are no live connections."""
        return not self.connections

    # --------------------------------------------------------------- teardown

    def drop_connection(self, connection):
        """Forget ``connection`` locally; a missing connection is ignored."""
        with self.connection_lock:
            self.connections.pop(connection, None)

    def close(self, connection=None):
        """Send FIN to ``connection`` (default: first) and wait for FIN-ACK.

        On success the connection is dropped, and a client socket is stopped
        once its last connection is gone.

        Raises:
            EOFError: wrapping any failure.
        """
        try:
            connection = self._resolve_connection(connection)
            fin = Fin()
            self.__send_packet(connection, fin, retransmit=True)
            answer = self.find_correct_packet('FIN-ACK', connection, want_id=fin.id)
            if not answer:
                raise Exception('The receiver didn\'t send the fin packet')
            self.drop_connection(connection)
            if len(self.connections) == 0 and self.client:
                self.stop()
        except Exception as error:
            raise EOFError('Something went wrong in the close func! Error is: %s.' % error) from error

    def disconnect(self, connection, fin_id):
        """Answer a peer's FIN with ACK and FIN-ACK, then forget the connection."""
        try:
            self.__send_packet(connection, Ack(fin_id))
            fin_ack = FinAck(fin_id)
            try:
                self.__send_packet(connection, fin_ack)
            except Exception:
                # Best effort: the peer may already have gone away.
                pass
            with self.connection_lock:
                self.connections.pop(connection, None)
        except Exception as error:
            raise EOFError('Something went wrong in disconnect func:%s ' % error) from error

    # --------------------------------------------------------------- utilities

    @staticmethod
    def data_divider(data):
        """Split ``data`` into consecutive chunks of at most DATA_DIVIDE_LENGTH bytes.

        Empty input yields an empty list.
        """
        return [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)]

    @staticmethod
    def checksum(source_bytes):
        """Return the SHA-1 digest of ``source_bytes``."""
        return hashlib.sha1(source_bytes).digest()

    @staticmethod
    def _take_data(conn, size):
        """Consume up to ``size`` payload bytes from ``conn``'s DATA buffer.

        Packets are drained front to back; fully consumed packets are removed.
        Raises KeyError when nothing is buffered so the caller keeps waiting.
        """
        buffer = conn.packet_buffer['DATA']
        if not buffer:
            raise KeyError('DATA')
        chunks = []
        while size:
            with conn.recv_lock:
                if not buffer:
                    break
                packet = buffer[0]
                chunk = packet.data[:size]
                packet.data = packet.data[size:]
                if not packet.data:
                    buffer.pop(0)
            chunks.append(chunk)
            size -= len(chunk)
        return b''.join(chunks)

    def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH, want_id=None):
        """Wait for and take a buffered packet matching ``condition``.

        * ``address[0] == 'Any'`` (the default): pop any pending SYN and
          return ``(packet, address)``.
        * ``condition == 'DATA'``: return up to ``size`` payload bytes.
        * otherwise: pop the highest-seq packet of that type; if ``want_id``
          is given and does not match, that packet is discarded and the wait
          continues.

        'ACK', 'SYN-ACK' and 'FIN-ACK' give up after two 0.5 s waits; other
        conditions wait while ``self.status`` is truthy. Returns None on
        give-up, on shutdown, or when the connection no longer exists.
        """
        tries = 0
        while tries < 2 and self.status:
            try:
                if address[0] == 'Any':
                    # syn_received maps address -> packet; callers want (packet, address).
                    peer, packet = self.syn_received.popitem()
                    return packet, peer
                conn = self.connections.get(address)
                if conn is None:
                    # Connection dropped meanwhile; waiting could never succeed.
                    return None
                if condition in ('ACK', 'SYN-ACK', 'FIN-ACK'):
                    tries += 1
                if condition == 'DATA':
                    return self._take_data(conn, size)
                with conn.recv_lock:
                    buffer = conn.packet_buffer[condition]
                    if not buffer:
                        raise KeyError(condition)
                    packet = buffer.pop()
                if want_id and packet.id != want_id:
                    raise KeyError(want_id)
                return packet
            except KeyError:
                self.incoming_packet_event.wait(0.5)
        return None

    def blink_incoming_packet_event(self):
        """Wake threads currently waiting on incoming_packet_event.

        Threads that start waiting after the clear() are not woken.
        """
        self.incoming_packet_event.set()
        self.incoming_packet_event.clear()

    def blink_new_conn_event(self):
        """Wake threads currently waiting on new_conn_event."""
        self.new_conn_event.set()
        self.new_conn_event.clear()

    def sort_answers(self, packet, address):
        """Route one decoded packet from ``address`` (runs on the receive thread).

        SYNs go to ``syn_received``; FINs trigger ``disconnect``; every other
        packet is de-duplicated by seq/id and inserted, in seq order, into
        its connection's buffer for its type. DATA packets are acknowledged
        once buffered, and duplicates are re-acknowledged.
        """
        is_syn = isinstance(packet, Syn)
        conn = self.connections.get(address)
        if conn is None and not is_syn:
            return  # not from a connected peer
        if not is_syn:
            if isinstance(packet, Packet) and packet.checksum != TCP.checksum(packet.data):
                # Drop corrupt DATA *before* recording its seq; otherwise the
                # sender's retransmission (same seq) would look like a
                # duplicate and be ACKed but discarded.
                return
            if conn.recv_seq == packet.seq:
                # Duplicate (e.g. our ACK was lost): acknowledge again only.
                self.__send_packet(address, Ack(packet.id))
                return
            if conn.recv_seq > packet.seq:
                return  # stale packet
            conn.recv_seq = packet.seq
            if isinstance(packet, Packet) and self.encrypted:
                packet.data = conn.my_key.decrypt_raw(packet.data)
        if isinstance(packet, Fin):
            self.disconnect(address, packet.id)
        elif is_syn:
            # Copy values first: the listener thread may popitem() concurrently.
            pending_ids = [pending.id for pending in list(self.syn_received.values())]
            if packet.id not in pending_ids:
                self.syn_received[address] = packet
        else:
            with conn.recv_lock:
                buffer = conn.packet_buffer[packet.type]
                if packet.id not in (queued.id for queued in buffer):
                    bisect.insort(buffer, packet)
            if isinstance(packet, Packet):
                self.__send_packet(address, Ack(packet.id))
        self.blink_incoming_packet_event()

    def central_receive_handler(self):
        """Receive-thread loop: decode every datagram and pass it to sort_answers."""
        while self.status:
            try:
                raw, address = self.own_socket.recvfrom(SENT_SIZE)
                # NOTE(security): datagrams from any sender are unpickled here.
                # RestrictedUnpickler limits which classes can be built, but
                # their attribute values remain attacker-controlled.
                try:
                    packet = restricted_pickle_loads(raw)
                except Exception:
                    # Not a UTCP packet; unpickling can raise many exception
                    # types (UnpicklingError, EOFError, ValueError, ...).
                    continue
                self.sort_answers(packet, address)
            except socket.timeout:
                continue
            except socket.error:
                self.own_socket.close()
                self.status = 0
            except Exception:
                # One malformed or unexpected packet must not kill the thread
                # that serves every connection; drop it and keep receiving.
                continue

    def central_receive(self):
        """Start the daemon thread running central_receive_handler."""
        t = threading.Thread(target=self.central_receive_handler)
        t.daemon = True
        t.start()