493 lines
21 KiB
Python
493 lines
21 KiB
Python
|
# based on https://github.com/ethay012/TCP-over-UDP
|
||
|
import random
|
||
|
import socket
|
||
|
import pickle
|
||
|
import threading
|
||
|
import io
|
||
|
import hashlib
|
||
|
import simplecrypto
|
||
|
from datetime import datetime
|
||
|
|
||
|
|
||
|
# --- Packet sizing -----------------------------------------------------------
# Maximum number of payload bytes placed in one data packet.
DATA_DIVIDE_LENGTH = 8000
# Room reserved for the pickled packet's service information.
PACKET_HEADER_SIZE = 512  # Pickle service info
DATA_LENGTH = DATA_DIVIDE_LENGTH
# Size of the receive buffer for a single datagram.
SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272  # Encrypted data always 272 bytes bigger

# --- Indexing helpers --------------------------------------------------------
LAST_CONNECTION = -1
FIRST = 0

# Payload of the final chunk of every message; marks the end of the message.
PACKET_END = b'___+++^^^END^^^+++___'

# --- Placeholders that emulate the names exported by the socket module -------
# need for emulate
AF_INET = None
SOCK_STREAM = None
|
||
|
|
||
|
class KeyPair:
    """Holds the key material associated with one peer.

    Attributes:
        my_key: the local key object passed to the constructor.
        peer_pub: the peer's public key; stays ``None`` until the code that
            performs the handshake assigns it.
    """

    # Class-level defaults so both attributes always exist on an instance.
    my_key = None
    peer_pub = None

    def __init__(self, sec):
        """Store *sec* as the local key; ``peer_pub`` is left at its default."""
        self.my_key = sec
|
||
|
|
||
|
class TCPPacket(object):
    """In-memory model of a TCP segment header plus payload.

    Instances are plain attribute containers. Every header field starts at
    zero except ``seq``, which starts at a random 32-bit value, and ``data``,
    which starts as empty bytes. Only the ACK, SYN and FIN flags are
    interpreted by :meth:`packet_type` and changed by :meth:`set_flags`.

    Packets order by sequence number: ``__lt__`` makes ``sorted()``, ``min()``
    and ``max()`` work on Python 3, which ignores the legacy ``__cmp__`` hook
    (kept for callers that invoke it explicitly). Equality and hashing are
    left at the default identity semantics.
    """

    SMALLEST_STARTING_SEQ = 0
    HIGHEST_STARTING_SEQ = 4294967295

    def __init__(self):
        """Create a packet with a random ``seq`` and all other fields cleared."""
        # self.src_port = src_port  # 16bit
        # self.dst_port = dst_port  # 16bit
        self.seq = TCPPacket.gen_starting_seq_num()  # 32bit
        self.ack = 0  # 32bit
        self.data_offset = 0  # 4 bits
        self.reserved_field = 0  # 3 bits saved for future use, must be zero
        # Flags
        self.flag_ns = 0  # 1bit
        self.flag_cwr = 0  # 1bit
        self.flag_ece = 0  # 1bit
        self.flag_urg = 0  # 1bit
        self.flag_ack = 0  # 1bit
        self.flag_psh = 0  # 1bit
        self.flag_rst = 0  # 1bit
        self.flag_syn = 0  # 1bit
        self.flag_fin = 0  # 1bit
        # Window size
        self.window_size = 0  # 16bit
        # Checksum
        self.checksum = 0  # 16bit
        # Urgent pointer
        self.urgent_pointer = 0  # 16bit
        # Options
        self.options = 0  # 0-320bits, divisible by 32
        # Padding - a TCP packet must end on a 32-bit boundary; the padding
        # that ensures this is filled with zeros.
        self.padding = 0  # as much as needed
        self.data = b''

    def __repr__(self):
        return 'TCPpacket()'

    def __str__(self):
        return 'SEQ Number: %d, ACK Number: %d, ACK:%d, SYN:%d, FIN:%d, TYPE:%s, DATA:%s' \
            % (self.seq, self.ack, self.flag_ack, self.flag_syn, self.flag_fin, self.packet_type(), self.data)

    def __cmp__(self, other):
        """Three-way comparison by ``seq`` (-1, 0 or 1); not used by Python 3 itself."""
        return (self.seq > other.seq) - (self.seq < other.seq)

    def __lt__(self, other):
        """Order packets by sequence number."""
        try:
            return self.seq < other.seq
        except AttributeError:
            # Not a packet-like object: let Python try the reflected operation.
            return NotImplemented

    def packet_type(self):
        """Classify the packet from its ACK/SYN/FIN flags and payload.

        Returns:
            One of ``'SYN-ACK'``, ``'FIN-ACK'``, ``'SYN'``, ``'ACK'``,
            ``'FIN'``, ``'DATA'``, or ``''`` when no flag is set and the
            payload is empty. Flag combinations are tested before single
            flags, and flags before the payload.
        """
        if self.flag_syn == 1 and self.flag_ack == 1:
            return 'SYN-ACK'
        if self.flag_ack == 1 and self.flag_fin == 1:
            return 'FIN-ACK'
        if self.flag_syn == 1:
            return 'SYN'
        if self.flag_ack == 1:
            return 'ACK'
        if self.flag_fin == 1:
            return 'FIN'
        if self.data != b'':
            return 'DATA'
        return ''

    def set_flags(self, ack=False, syn=False, fin=False):
        """Set the ACK, SYN and FIN flags to 1 or 0.

        All three flags are overwritten on every call, so calling with no
        arguments clears them.
        """
        self.flag_ack = 1 if ack else 0
        self.flag_syn = 1 if syn else 0
        self.flag_fin = 1 if fin else 0

    @staticmethod
    def gen_starting_seq_num():
        """Return a random integer in [SMALLEST_STARTING_SEQ, HIGHEST_STARTING_SEQ]."""
        return random.randint(TCPPacket.SMALLEST_STARTING_SEQ, TCPPacket.HIGHEST_STARTING_SEQ)
|
||
|
|
||
|
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves exactly one global: ``utcp.TCPPacket``.

    Any other global referenced by a pickle stream is rejected with
    ``pickle.UnpicklingError``.
    """

    def find_class(self, module, name):
        """Return ``TCPPacket`` for ``utcp.TCPPacket``; refuse everything else."""
        is_packet_class = (module == "utcp") and (name == 'TCPPacket')
        if not is_packet_class:
            # Forbid everything that is not the packet class.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        return TCPPacket
|
||
|
def restricted_pickle_loads(s):
    """Deserialize the bytes *s* with RestrictedUnpickler (a pickle.loads() analogue)."""
    stream = io.BytesIO(s)
    unpickler = RestrictedUnpickler(stream)
    return unpickler.load()
|
||
|
|
||
|
class ConnectedSOCK(object):
    """Per-peer view of a shared low-level socket object.

    ``send``, ``recv`` and ``close`` forward to the low-level socket with this
    wrapper's ``client_addr`` filled in. Any other public attribute that
    exists in the low-level socket's instance ``__dict__`` is read from there.
    """

    def __init__(self, low_sock, client_addr):
        self.client_addr = client_addr
        self.low_sock = low_sock

    def __getattribute__(self, att):
        """Resolve *att* locally, or on ``low_sock`` when it is one of its instance attributes."""
        handled_here = att.startswith('_') or att in ['client_addr', 'low_sock', 'send', 'recv', 'close', 'closed']
        if not handled_here:
            delegate = object.__getattribute__(self, 'low_sock')
            # Only the instance dictionary is consulted, not class attributes.
            if att in delegate.__dict__:
                return getattr(delegate, att)
        return object.__getattribute__(self, att)

    def send(self, data):
        """Send *data* to this wrapper's peer; raises EOFError when closed."""
        if self.closed:
            raise EOFError
        self.low_sock.send(data, self.client_addr)

    def recv(self, size=None):
        """Receive one message from this wrapper's peer.

        When *size* is truthy only the first *size* items of the message are
        returned; the remainder is not kept. Raises EOFError when closed.
        """
        if self.closed:
            raise EOFError
        message = self.low_sock.recv(self.client_addr)
        return message[:size] if size else message

    @property
    def closed(self):
        """Truthy when the UDP socket is closed, the peer is unknown, or its FIN flag is set."""
        transport = self.low_sock
        socket_closed = transport.own_socket._closed
        if socket_closed:
            return socket_closed
        if self.client_addr not in transport.connections:
            return True
        return transport.connections[self.client_addr].flag_fin

    def close(self):
        """Close the connection to this wrapper's peer on the low-level socket."""
        self.low_sock.close(self.client_addr)
|
||
|
|
||
|
class TCP(object):
    """Socket-like object that emulates TCP connections on top of one UDP socket.

    ``connections`` maps each peer address to a ``TCPPacket`` that carries the
    running seq/ack state for that peer and is pickled whenever something is
    sent. A daemon thread started by :meth:`central_receive` reads datagrams
    and files them by packet type in ``packets_received``; the other methods
    collect them from there through :meth:`find_correct_packet`.

    With ``encrypted=True`` payloads pass through the per-peer key objects in
    ``peer_keypair`` (created with ``simplecrypto``).

    Incoming datagrams are untrusted network input; they are only ever
    deserialized through ``restricted_pickle_loads``.
    """

    host = None
    port = None
    client = False

    # Raw (un-pickled) datagram a server sends when its accept queue is full.
    _CONNECTIONS_FULL = b'Connections full'

    def __init__(self, af_type=None, sock_type=None, encrypted=False):
        """Create the UDP socket and the empty bookkeeping structures.

        Args:
            af_type: ignored; accepted so the call looks like ``socket.socket``.
            sock_type: ignored; accepted so the call looks like ``socket.socket``.
            encrypted: when true, payloads are encrypted with per-peer keys.
        """
        self.encrypted = encrypted
        # 0 until listen()/connect() start the receiver; set explicitly so that
        # methods reading it before that do not fail with AttributeError.
        self.status = 0
        self.incoming_packet_event = threading.Event()
        # seq will have the last packet sent and ack the next packet expected.
        self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)  # UDP socket used for communication.
        self.settimeout()
        self.peer_keypair = {}
        self.connections = {}
        self.connection_queue = []
        self.connection_lock = threading.Lock()
        self.queue_lock = threading.Lock()
        # Each condition has a dictionary of an address and its corresponding packet.
        self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}

    def get_free_port(self):
        """Return a UDP port number the OS handed out to a temporary socket."""
        probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            probe.bind(('', 0))
            return probe.getsockname()[1]
        finally:
            probe.close()

    def setsockopt(self, *a, **kw):
        """Accepted for socket API compatibility; does nothing."""
        pass

    def settimeout(self, timeout=5):
        """Set the timeout of the underlying UDP socket."""
        self.own_socket.settimeout(timeout)

    def __repr__(self):
        return 'TCP()'

    def __str__(self):
        return 'Connections: %s' \
            % str(self.connections)

    def getsockname(self):
        """Return the ``(host, port)`` pair recorded by :meth:`bind`."""
        return (self.host, self.port)

    def bind(self, addr):
        """Record *addr* as host/port and bind the UDP socket to it."""
        self.host = addr[0]
        self.port = addr[1]
        self.own_socket.bind(addr)

    def _resolve_connection(self, connection):
        """Return the peer address to use, defaulting to the first connection.

        Raises:
            EOFError: *connection* is unknown, or it is ``None`` and there is
                no connection to default to.
        """
        if connection in self.connections:
            return connection
        if connection is None and self.connections:
            return next(iter(self.connections))
        raise EOFError('Connection not in connected devices')

    def send(self, data, connection=None):
        """Send *data* to a peer in chunks, followed by the end marker.

        Args:
            data: bytes to send.
            connection: peer address; ``None`` selects the first connection.

        Raises:
            EOFError: unknown connection, no connection, or a socket error.
        """
        try:
            connection = self._resolve_connection(connection)
            for data_part in TCP.data_divider(data):
                data_chunk = data_part if not self.encrypted else self.peer_keypair[connection].peer_pub.encrypt_raw(data_part)
                state = self.connections[connection]
                state.checksum = TCP.checksum(data_chunk)
                state.data = data_chunk
                state.set_flags()
                packet_to_send = pickle.dumps(state)
                # The state object is reused, so drop the payload again.
                state.checksum = 0
                state.data = b''
                data_not_received = True
                while data_not_received:
                    data_not_received = False
                    try:
                        self.own_socket.sendto(packet_to_send, connection)
                        # NOTE(review): find_correct_packet returns None when no
                        # ACK shows up; it does not raise socket.timeout, so a
                        # lost chunk is not retransmitted here. recv() has no
                        # duplicate detection, so retrying on None would need a
                        # matching change on the receiving side.
                        self.find_correct_packet('ACK', connection)
                    except socket.timeout:
                        data_not_received = True
                self.connections[connection].seq += len(data_part)
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error)

    def recv(self, connection=None):
        """Receive one complete message (all chunks up to the end marker).

        Args:
            connection: peer address; ``None`` selects the first connection.

        Returns:
            The concatenated payload bytes (empty if the socket is not running).

        Raises:
            EOFError: unknown connection, no connection, the peer sent FIN,
                the socket stopped, or a socket error occurred.
        """
        try:
            data = b''
            connection = self._resolve_connection(connection)

            while self.status:
                data_part = self.find_correct_packet('DATA or FIN', connection)
                if not self.status or data_part is None:
                    raise EOFError('Disconnected')
                if data_part.packet_type() == 'FIN':
                    self.disconnect(connection)
                    raise EOFError('Disconnected')

                # Discard packets whose payload does not match their checksum.
                while TCP.checksum(data_part.data) != data_part.checksum:
                    data_part = self.find_correct_packet('DATA or FIN', connection)
                    if data_part is None:
                        # Only happens once the socket has stopped.
                        raise EOFError('Disconnected')

                data_chunk = data_part.data if not self.encrypted else self.peer_keypair[connection].my_key.decrypt_raw(data_part.data)
                if data_chunk != PACKET_END:
                    data += data_chunk

                state = self.connections[connection]
                state.ack = data_part.seq + len(data_part.data)
                state.seq += 1  # syn flag is 1 byte
                state.set_flags(ack=True)
                state.data = b''
                packet_to_send = pickle.dumps(state)
                self.own_socket.sendto(packet_to_send, connection)  # after receiving correct info sends ack
                state.set_flags()

                if data_chunk == PACKET_END:
                    break

            return data
        except socket.error as error:
            raise EOFError('Socket was closed before executing command. Error is: %s.' % error)

    # conditions = ['SYN', 'SYN-ACK', 'ACK', 'FIN', 'FIN-ACK', 'DATA']
    def listen_handler(self, max_connections):
        """Thread body: queue incoming SYN packets until the socket stops.

        When the queue already holds *max_connections* entries the sender is
        told so with a raw ``Connections full`` datagram.

        Raises:
            EOFError: a socket error occurred (this ends the listener thread).
        """
        try:
            while self.status:
                try:
                    found = self.find_correct_packet('SYN')
                    if found is None:
                        # The socket stopped while waiting for a SYN.
                        break
                    answer, address = found
                    with self.queue_lock:
                        if len(self.connection_queue) < max_connections:
                            if self.encrypted:
                                try:
                                    peer_pub = answer.data
                                    self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair())  # FIXME: for some reason slowly creates a key
                                    self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
                                except Exception:
                                    # The entry may not exist yet if key creation failed.
                                    self.peer_keypair.pop(address, None)
                                    # NOTE(review): this ends the listener thread
                                    # for every client, not just the offending one.
                                    raise socket.error('Init peer public key error')
                            self.connection_queue.append((answer, address))
                        else:
                            # sendto() needs bytes on Python 3.
                            self.own_socket.sendto(self._CONNECTIONS_FULL, address)
                except KeyError:
                    continue
        except socket.error as error:
            raise EOFError('Something went wrong in listen_handler func! Error is: %s.' % str(error))

    def listen(self, max_connections=1):
        """Start the receiver thread and the SYN listener thread."""
        self.status = 1
        self.central_receive()
        try:
            t = threading.Thread(target=self.listen_handler, args=(max_connections,))
            t.daemon = True
            t.start()
        except Exception as error:
            raise EOFError('Something went wrong in listen func! Error is: %s.' % str(error))

    def accept(self):
        """Wait for a queued SYN, finish the handshake and return the connection.

        Returns:
            ``(ConnectedSOCK, address)`` for the accepted peer.

        Raises:
            EOFError: anything went wrong during the handshake.
        """
        address = None
        try:
            while True:
                if not self.connection_queue:
                    # Pause instead of spinning while nothing is queued.
                    self.incoming_packet_event.wait(0.1)
                    continue
                with self.queue_lock:
                    answer, address = self.connection_queue.pop()
                state = TCPPacket()
                self.connections[address] = state
                state.ack = answer.seq + 1
                state.seq += 1
                state.set_flags(ack=True, syn=True)
                if self.encrypted:
                    # Send our public key, encrypted with the peer's key.
                    pubkey = self.peer_keypair[address].my_key.publickey.serialize()
                    state.data = self.peer_keypair[address].peer_pub.encrypt_raw(pubkey)
                    state.checksum = TCP.checksum(state.data)
                packet_to_send = pickle.dumps(state)
                if self.encrypted:
                    state.data = b''
                    state.checksum = 0
                # lock address, connections dictionary?
                packet_not_sent_correctly = True
                # Resend the SYN-ACK until the peer's ACK is found.
                while packet_not_sent_correctly or answer is None:
                    if not self.status:
                        # find_correct_packet returns None at once when the
                        # socket has stopped; bail out instead of spinning.
                        raise EOFError('Socket was closed during the handshake')
                    try:
                        packet_not_sent_correctly = False
                        self.own_socket.sendto(packet_to_send, address)
                        answer = self.find_correct_packet('ACK', address)
                    except socket.timeout:
                        packet_not_sent_correctly = True
                state.set_flags()
                state.ack = answer.seq + 1
                return ConnectedSOCK(self, address), address
        except Exception as error:
            if address is not None:
                try:
                    self.close(address)
                except EOFError:
                    # Report the original failure, not the failed cleanup.
                    pass
            raise EOFError('Something went wrong in accept func: ' + str(error))

    def connect(self, server_address=('127.0.0.1', 10000)):
        """Bind to a free local port and perform the handshake with a server.

        Args:
            server_address: ``(host, port)`` of the listening peer.

        Raises:
            EOFError: a socket error occurred or the server refused; the UDP
                socket is closed and all connection state is reset.
        """
        try:
            self.bind(('', self.get_free_port()))
            self.status = 1
            self.client = True
            self.central_receive()
            state = TCPPacket()
            self.connections[server_address] = state
            state.set_flags(syn=True)
            if self.encrypted:
                self.peer_keypair[server_address] = KeyPair(simplecrypto.RsaKeypair())  # FIXME: but here it creates the key quickly
                pubkey = self.peer_keypair[server_address].my_key.publickey.serialize()
                state.checksum = TCP.checksum(pubkey)
                state.data = pubkey
            first_packet_to_send = pickle.dumps(state)
            state.data = b''
            state.checksum = 0

            # Address the server explicitly rather than "the first connection".
            self.own_socket.sendto(first_packet_to_send, server_address)
            state.set_flags()
            answer = self.find_correct_packet('SYN-ACK', server_address)
            if answer is None:
                # Only happens once the socket has stopped.
                raise socket.error('Socket was closed before the server answered.')
            if isinstance(answer, str):  # == 'Connections full'
                raise socket.error('Server cant receive any connections right now.')
            if self.encrypted:
                try:
                    peer_pub = self.peer_keypair[server_address].my_key.decrypt_raw(answer.data)
                    self.peer_keypair[server_address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
                except Exception:
                    raise socket.error('Decrypt peer public key error')
                if not peer_pub or answer.checksum != TCP.checksum(answer.data):
                    raise socket.error('Invalid peer public key')

            state.ack = answer.seq + 1
            state.seq += 1
            state.set_flags(ack=True)
            second_packet_to_send = pickle.dumps(state)
            self.own_socket.sendto(second_packet_to_send, server_address)
            state.set_flags()

        except socket.error as error:
            self.own_socket.close()
            self.connections = {}
            self.peer_keypair = {}
            self.status = 0
            raise EOFError('The socket was closed. Error:' + str(error))

    def shutdown(self, *a, **kw):
        """Close the UDP socket and stop the background loops."""
        self.own_socket.close()
        self.status = 0

    def fileno(self):
        """Return the file descriptor of the UDP socket."""
        return self.own_socket.fileno()

    @property
    def closed(self):
        """Whether the underlying UDP socket has been closed."""
        return self.own_socket._closed

    def close(self, connection=None):
        """Close one connection with a FIN / ACK / FIN-ACK / ACK exchange.

        Args:
            connection: peer address; ``None`` selects the first connection.

        Raises:
            EOFError: the connection is unknown or the exchange failed.
        """
        try:
            connection = self._resolve_connection(connection)
            state = self.connections[connection]
            state.set_flags(fin=True)
            state.seq += 1
            packet_to_send = pickle.dumps(state)
            self.own_socket.sendto(packet_to_send, connection)
            self.find_correct_packet('ACK', connection)  # may be None; the value is not used
            state.ack += 1
            answer = self.find_correct_packet('FIN-ACK', connection)
            if answer is None or answer.flag_fin != 1:
                raise Exception('The receiver didn\'t send the fin packet')
            state.ack += 1
            state.seq += 1
            state.set_flags(ack=True)
            packet_to_send = pickle.dumps(state)
            self.own_socket.sendto(packet_to_send, connection)
            with self.connection_lock:
                self.connections.pop(connection, None)
                self.peer_keypair.pop(connection, None)
            # if len(self.connections) == 0 and self.client:
            #     self.own_socket.close()
            #     self.status = 0
        except Exception as error:
            raise EOFError('Something went wrong in the close func! Error is: %s.' % error)

    def disconnect(self, connection):
        """Answer a peer's FIN with ACK and FIN-ACK, then forget the peer.

        Raises:
            EOFError: anything went wrong during the exchange.
        """
        try:
            state = self.connections[connection]
            state.ack += 1
            state.seq += 1
            state.set_flags(ack=True)
            packet_to_send = pickle.dumps(state)
            self.own_socket.sendto(packet_to_send, connection)
            state.set_flags(fin=True, ack=True)
            state.seq += 1
            packet_to_send = pickle.dumps(state)
            self.own_socket.sendto(packet_to_send, connection)
            self.find_correct_packet('ACK', connection)
            with self.connection_lock:
                self.connections.pop(connection)
                # Drop the key pair too so it does not outlive the connection.
                self.peer_keypair.pop(connection, None)
        except Exception as error:
            raise EOFError('Something went wrong in disconnect func:%s ' % error)

    @staticmethod
    def data_divider(data):
        """Split *data* into chunks of at most DATA_DIVIDE_LENGTH and append PACKET_END."""
        data = [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)]
        data.append(PACKET_END)
        return data

    @staticmethod
    def checksum(source_bytes):
        """Return the SHA-1 digest (bytes) of *source_bytes*."""
        return hashlib.sha1(source_bytes).digest()

    def find_correct_packet(self, condition, address=('Any',)):
        """Take a received packet of type *condition* out of ``packets_received``.

        Polls in 0.1 s steps while the socket is running. For ``'ACK'`` at
        most two lookups are made; other conditions keep polling.

        Args:
            condition: a key of ``packets_received``.
            address: peer address, or the default ``('Any',)`` to take a
                packet from any peer.

        Returns:
            The packet; ``(packet, address)`` when called with ``('Any',)``;
            or ``None`` when nothing was found before polling stopped.
        """
        not_found = True
        tries = 0
        while not_found and tries < 2 and self.status:
            try:
                not_found = False
                if address[0] == 'Any':
                    order = self.packets_received[condition].popitem()  # to reverse the tuple received
                    return order[1], order[0]
                if condition == 'ACK':
                    tries += 1
                if condition == 'DATA or FIN':
                    packet = self.packets_received[condition][address].pop()
                    if not len(self.packets_received[condition][address]):
                        del self.packets_received[condition][address]
                else:
                    packet = self.packets_received[condition].pop(address)
                return packet
            except (KeyError, IndexError):
                # Nothing stored yet (IndexError: the per-address list is
                # momentarily empty); wait for the receiver thread.
                not_found = True
                self.incoming_packet_event.wait(0.1)
        return None

    def blink_event(self):
        """Wake threads currently waiting on ``incoming_packet_event``."""
        self.incoming_packet_event.set()
        self.incoming_packet_event.clear()

    def sort_answers(self, packet, address):
        """File a received packet under its type and sender address.

        DATA and FIN packets are kept in a per-address list; for every other
        type only the most recent packet per address is kept. Packets with an
        empty type are dropped.
        """
        kind = packet.packet_type()
        if kind == 'DATA' or kind == 'FIN':
            # New packets go to the front; find_correct_packet pops from the end.
            self.packets_received['DATA or FIN'].setdefault(address, []).insert(0, packet)
            self.blink_event()
        elif kind == '':
            # print('redundant packet found', packet)
            pass
        else:
            self.packets_received[kind][address] = packet
            self.blink_event()

    def central_receive_handler(self):
        """Thread body: read datagrams and hand valid packets to sort_answers.

        Datagrams that cannot be unpickled, or that do not unpickle to a
        ``TCPPacket``, are ignored so that stray or hostile traffic cannot
        end the thread. A socket error closes the socket and stops the loop.
        """
        while self.status:
            try:
                raw, address = self.own_socket.recvfrom(SENT_SIZE)
            except socket.timeout:
                continue
            except socket.error:
                self.own_socket.close()
                self.status = 0
                # print('An error has occured: Socket error %s' % error)
                break
            if raw == self._CONNECTIONS_FULL:
                # Deliver the refusal as the str that connect() checks for.
                self.packets_received['SYN-ACK'][address] = 'Connections full'
                self.blink_event()
                continue
            try:
                packet = restricted_pickle_loads(raw)
            except Exception:
                # Unpickling untrusted bytes can raise many exception types.
                continue
            if not isinstance(packet, TCPPacket):
                continue
            self.sort_answers(packet, address)

    def central_receive(self):
        """Start central_receive_handler in a daemon thread."""
        t = threading.Thread(target=self.central_receive_handler)
        t.daemon = True
        t.start()
|