diff --git a/mods/utcp.py b/mods/utcp.py index 68dd347..6d63d30 100644 --- a/mods/utcp.py +++ b/mods/utcp.py @@ -25,43 +25,30 @@ class KeyPair: def __init__(self, sec): self.my_key = sec -class TCPPacket(object): - ''' - Add Documentation here - ''' +class Connection: SMALLEST_STARTING_SEQ = 0 HIGHEST_STARTING_SEQ = 4294967295 + def __init__(self, remote, encrypted=False): + self.peer_addr = remote + self.seq = Connection.gen_starting_seq_num() + self.my_key + @staticmethod + def gen_starting_seq_num(): + return random.randint(Connection.SMALLEST_STARTING_SEQ, Connection.HIGHEST_STARTING_SEQ) + def seq_inc(self, inc=1): + self.seq += inc + return self.seq - def __init__(self): - # self.src_port = src_port # 16bit - # self.dst_port = dst_port # 16bit - self.seq = TCPPacket.gen_starting_seq_num() # 32bit - self.ack = 0 # 32bit - self.data_offset = 0 # 4 bits - self.reserved_field = 0 # 3bits saved for future use must be zero assert self.reserved_field = 0 - #FLAGS - self.flag_ns = 0 # 1bit - self.flag_cwr = 0 # 1bit - self.flag_ece = 0 # 1bit - self.flag_urg = 0 # 1bit - self.flag_ack = 0 # 1bit - self.flag_psh = 0 # 1bit - self.flag_rst = 0 # 1bit - self.flag_syn = 0 # 1bit - self.flag_fin = 0 # 1bit - #window size - self.window_size = 0 # 16bit - #checksum - self.checksum = 0 # 16bit - #urgent pointer - self.urgent_pointer = 0 # 16bit - #options - self.options = 0 # 0-320bits, divisible by 32 - #padding - TCP packet must be on a 32bit boundary this ensures that it is the padding is filled with 0's - self.padding = 0 # as much as needed +class TCPPacket(object): + def __init__(self, seq): + self.seq = seq + self.ack = 0 + self.flag_ack = 0 + self.flag_syn = 0 + self.checksum = 0 self.data = b'' def __repr__(self): - return 'TCPpacket()' + return f'TCPpacket(type={self.packet_type()})' def __str__(self): return 'SEQ Number: %d, ACK Number: %d, ACK:%d, SYN:%d, FIN:%d, TYPE:%s, DATA:%s' \ @@ -100,20 +87,15 @@ class TCPPacket(object): else: self.flag_fin = 0 - @staticmethod - def gen_starting_seq_num(): - return random.randint(TCPPacket.SMALLEST_STARTING_SEQ, TCPPacket.HIGHEST_STARTING_SEQ) + class RestrictedUnpickler(pickle.Unpickler): def find_class(self, module, name): - # Only allow TCPPacket if name == 'TCPPacket': return TCPPacket - # Forbid everything else. raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name)) def restricted_pickle_loads(s): - """Helper function analogous to pickle.loads().""" return RestrictedUnpickler(io.BytesIO(s)).load() class ConnectedSOCK(object): @@ -155,24 +137,18 @@ class TCP(object): host = None port = None client = False - peer_keypair = {} - connections = {} - connection_queue = [] - packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} def __init__(self, af_type=None, sock_type=None, encrypted=False): self.encrypted = encrypted self.incoming_packet_event = threading.Event() self.new_conn_event = threading.Event() - #seq will have the last packet send and ack will have the next packet waiting to receive - self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication. + self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.settimeout() - #self.peer_keypair = {} - #self.connections = {} - #self.connection_queue = [] self.connection_lock = threading.Lock() self.queue_lock = threading.Lock() - # each condition will have a dictionary of an address and it's corresponding packet. - #self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} + self.peer_keypair = {} + self.connections = {} + self.connection_queue = [] + self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} def poll(self, timeout): if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']: return True @@ -312,6 +288,7 @@ class TCP(object): if self.connection_queue: with self.queue_lock: answer, address = self.connection_queue.pop() + self.connections[address] = TCPPacket() self.connections[address].ack = answer.seq + 1 self.connections[address].seq += 1