encrypted "tcp" over udp

utcp_objects
Роман Бородин 2019-04-08 08:10:23 +03:00
parent 20e70a5ed3
commit f7a4da82bd
4 changed files with 315 additions and 97 deletions

0
mods/__init__.py 100644
View File

View File

@ -1,6 +1,12 @@
# Based on rpyc (4.0.2)
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
import rpyc import rpyc
from rpyc.utils.server import ThreadedServer, spawn from rpyc.utils.server import ThreadedServer, spawn
from rpyc.core.stream import SocketStream from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
import socket import socket
from mbedtls import tls from mbedtls import tls
import datetime as dt import datetime as dt
@ -10,6 +16,10 @@ from mbedtls import x509
from uuid import uuid4 from uuid import uuid4
from contextlib import suppress from contextlib import suppress
import sys import sys
from rpyc.lib import safe_import
zlib = safe_import("zlib")
import errno
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
def block(callback, *args, **kwargs): def block(callback, *args, **kwargs):
while True: while True:
@ -33,6 +43,7 @@ class DTLSCerts:
self.ca1_crt = self.ca0_crt.sign( self.ca1_crt = self.ca0_crt.sign(
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456, ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3)) basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
self.srv_crt, self.srv_key = self.server_cert()
def server_cert(self): def server_cert(self):
now = dt.datetime.utcnow() now = dt.datetime.utcnow()
ee0_key = pk.ECC() ee0_key = pk.ECC()
@ -46,17 +57,28 @@ dtls_certs = DTLSCerts()
trust_store = tls.TrustStore() trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt) trust_store.add(dtls_certs.ca0_crt)
srv_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
validate_certificates=False,
)
cli_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
validate_certificates=False,
)
MAX_IO_CHUNK = 20971520
class DTLSSocketStream(SocketStream): class DTLSSocketStream(SocketStream):
MAX_IO_CHUNK = MAX_IO_CHUNK
@classmethod @classmethod
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs): def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
if kwargs.pop('ipv6', False): if kwargs.pop('ipv6', False):
family = socket.AF_INET6 family = socket.AF_INET6
else: else:
family = socket.AF_INET family = socket.AF_INET
dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration( #tls._enable_debug_output(cli_ctx_conf)
trust_store=trust_store, #tls._set_debug_level(10)
validate_certificates=False, dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
))
dtls_cli = dtls_cli_ctx.wrap_socket( dtls_cli = dtls_cli_ctx.wrap_socket(
socket.socket(family, socket.SOCK_DGRAM), socket.socket(family, socket.SOCK_DGRAM),
server_hostname=None, server_hostname=None,
@ -66,33 +88,128 @@ class DTLSSocketStream(SocketStream):
block(dtls_cli.do_handshake) block(dtls_cli.do_handshake)
return cls(dtls_cli) return cls(dtls_cli)
def read(self, count): def read(self, count):
return block(SocketStream.read(self, count)) while True:
try:
buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
except socket.timeout:
continue
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in retry_errnos:
# windows just has to be a bitch
# inpos: I agree
continue
self.close()
raise EOFError(ex)
else:
break
if not buf:
self.close()
raise EOFError("connection closed by peer")
return buf
def write(self, data): def write(self, data):
block(SocketStream.write(self, data)) try:
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
except socket.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
class DTLSChannel(Channel):
MAX_IO_CHUNK = MAX_IO_CHUNK
def recv(self):
raw_data = self.stream.read(self.MAX_IO_CHUNK)
header = raw_data[:self.FRAME_HEADER.size]
raw_data = raw_data[self.FRAME_HEADER.size:]
length, compressed = self.FRAME_HEADER.unpack(header)
data = raw_data[:length]
if compressed:
data = zlib.decompress(data)
return data
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(DTLSChannel(stream), service=service, config=config)
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
cert_reqs=None, ssl_version=None, ciphers=None,
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
ssl_kwargs = {'server_side' : False}
if ciphers is not None:
ssl_kwargs['ciphers'] = ciphers
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
return connect_stream(s, service, config)
class DTLSThreadedServer(ThreadedServer): class DTLSThreadedServer(ThreadedServer):
def dtls(self, listener_timeout = 0.5, reuse_addr = True): def __init__(self, service, hostname = "", ipv6 = False, port = 0,
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
socket_path=socket_path)
self.listener.close() self.listener.close()
srv_crt, srv_key = dtls_certs.server_cert()
dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration( self.host = hostname
trust_store=trust_store, self.port = port
certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key), #tls._enable_debug_output(srv_ctx_conf)
validate_certificates=False, #tls._set_debug_level(10)
)) dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
if reuse_addr and sys.platform != 'win32': if reuse_addr and sys.platform != 'win32':
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
dtls_srv.bind((self.host, self.port)) dtls_srv.bind((hostname, port))
dtls_srv.settimeout(listener_timeout) # dtls_srv.settimeout(listener_timeout)
self.listener = dtls_srv self.listener = dtls_srv
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
if credentials:
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)
def _listen(self): def _listen(self):
if self.active: if self.active:
return return
#self.listener.listen(self.backlog) #####################
if not self.port: if not self.port:
self.port = self.listener.getsockname()[1] self.port = self.listener.getsockname()[1]
self.logger.info('server started on [%s]:%s', self.host, self.port) self.logger.info('server started on [%s]:%s', self.host, self.port)
self.active = True self.active = True
def _authenticate_and_serve_client(self, sock): def accept(self):
while self.active:
try:
print('Accepting new connections!')
sock, addrinfo = self.listener.accept()
except socket.timeout:
pass
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
pass
else:
raise EOFError()
raise
else:
break
if not self.active:
return
sock.setblocking(True)
self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
print("accepted %s with fd %s" % (addrinfo, sock.fileno()))
self.clients.add(sock)
self._accept_method(sock)
def _accept_method(self, sock):
addr = sock.getpeername() addr = sock.getpeername()
sock.setcookieparam(addr[0].encode()) sock.setcookieparam(addr[0].encode())
with suppress(tls.HelloVerifyRequest): with suppress(tls.HelloVerifyRequest):
@ -100,6 +217,6 @@ class DTLSThreadedServer(ThreadedServer):
sock, addr = sock.accept() sock, addr = sock.accept()
sock.setcookieparam(addr[0].encode()) sock.setcookieparam(addr[0].encode())
block(sock.do_handshake) block(sock.do_handshake)
ThreadedServer._authenticate_and_serve_client(self, sock) spawn(self._authenticate_and_serve_client, sock)

74
mods/rpyc_utcp.py 100644
View File

@ -0,0 +1,74 @@
import rpyc
from rpyc.utils.server import ThreadedServer
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
from rpyc.lib import Timeout
from rpyc.lib.compat import select_error
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
from . import utcp
import sys
import errno
import socket
class UTCPSocketStream(SocketStream):
MAX_IO_CHUNK = utcp.DATA_LENGTH
@classmethod
def utcp_connect(cls, host, port, *a, **kw):
sock = utcp.TCP(encrypted=True)
sock.connect((host, port))
return cls(sock)
def poll(self, timeout):
timeout = Timeout(timeout)
try:
while True:
try:
rl = self.sock.poll(timeout.timeleft())
except select_error:
ex = sys.exc_info()[1]
if ex.args[0] == errno.EINTR:
continue
else:
raise
else:
break
except ValueError:
ex = sys.exc_info()[1]
raise select_error(str(ex))
return rl
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(Channel(stream), service=service, config=config)
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
s = UTCPSocketStream.utcp_connect(host, port, **kw)
return connect_stream(s, service, config)
class UTCPThreadedServer(ThreadedServer):
def __init__(self, service, hostname = '', ipv6 = False, port = 0,
backlog = 1, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
backlog = 1
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
socket_path=socket_path)
self.listener.close()
self.listener = None
##########
self.listener = utcp.TCP(encrypted=True)
self.listener.bind((hostname, port))
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
if credentials:
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
conn = self.service._connect(Channel(UTCPSocketStream(sock)), config)
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)

View File

@ -6,8 +6,6 @@ import threading
import io import io
import hashlib import hashlib
import simplecrypto import simplecrypto
from datetime import datetime
DATA_DIVIDE_LENGTH = 8000 DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512 # Pickle service info PACKET_HEADER_SIZE = 512 # Pickle service info
@ -109,7 +107,7 @@ class TCPPacket(object):
class RestrictedUnpickler(pickle.Unpickler): class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name): def find_class(self, module, name):
# Only allow TCPPacket # Only allow TCPPacket
if module == "utcp" and name == 'TCPPacket': if name == 'TCPPacket':
return TCPPacket return TCPPacket
# Forbid everything else. # Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" % raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
@ -123,44 +121,67 @@ class ConnectedSOCK(object):
self.client_addr = client_addr self.client_addr = client_addr
self.low_sock = low_sock self.low_sock = low_sock
def __getattribute__(self, att): def __getattribute__(self, att):
if not att.startswith('_') and not att in ['client_addr', 'low_sock', 'send', 'recv', 'close', 'closed']: try:
if att in self.low_sock.__dict__:
return getattr(self.low_sock, att)
return object.__getattribute__(self, att) return object.__getattribute__(self, att)
except AttributeError:
return getattr(self.low_sock, att)
def getpeername(self):
return self.client_addr
def send(self, data): def send(self, data):
if self.closed: if self.closed:
raise EOFError raise EOFError
self.low_sock.send(data, self.client_addr) return self.low_sock.send(data, self.client_addr)
def recv(self, size=None): def recv(self, size):
if self.closed: if self.closed:
raise EOFError raise EOFError
if size: return self.low_sock.recv(size, self.client_addr)
return self.low_sock.recv(self.client_addr)[:size]
else:
return self.low_sock.recv(self.client_addr)
@property @property
def closed(self): def closed(self):
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin) return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin)
def close(self): def close(self):
if self.client_addr in self.low_sock.connections:
self.low_sock.close(self.client_addr) self.low_sock.close(self.client_addr)
def shutdown(self, *a, **kw):
self.close()
def poll(self, timeout):
if self.client_addr in self.packets_received['DATA or FIN']:
return True
else:
self.incoming_packet_event.wait(timeout)
return self.client_addr in self.packets_received['DATA or FIN']
return False
class TCP(object): class TCP(object):
host = None host = None
port = None port = None
client = False client = False
peer_keypair = {}
connections = {}
connection_queue = []
packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
def __init__(self, af_type=None, sock_type=None, encrypted=False): def __init__(self, af_type=None, sock_type=None, encrypted=False):
self.encrypted = encrypted self.encrypted = encrypted
self.incoming_packet_event = threading.Event() self.incoming_packet_event = threading.Event()
self.new_conn_event = threading.Event()
#seq will have the last packet send and ack will have the next packet waiting to receive #seq will have the last packet send and ack will have the next packet waiting to receive
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication. self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication.
self.settimeout() self.settimeout()
self.peer_keypair = {} #self.peer_keypair = {}
self.connections = {} #self.connections = {}
self.connection_queue = [] #self.connection_queue = []
self.connection_lock = threading.Lock() self.connection_lock = threading.Lock()
self.queue_lock = threading.Lock() self.queue_lock = threading.Lock()
# each condition will have a dictionary of an address and it's corresponding packet. # each condition will have a dictionary of an address and it's corresponding packet.
self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}} #self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
def poll(self, timeout):
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
return True
else:
self.incoming_packet_event.wait(timeout)
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
return True
return False
def get_free_port(self): def get_free_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.bind(('', 0)) s.bind(('', 0))
@ -171,6 +192,8 @@ class TCP(object):
pass pass
def settimeout(self, timeout=5): def settimeout(self, timeout=5):
self.own_socket.settimeout(timeout) self.own_socket.settimeout(timeout)
def setblocking(self, mode):
self.own_socket.setblocking(mode)
def __repr__(self): def __repr__(self):
return 'TCP()' return 'TCP()'
@ -178,7 +201,12 @@ class TCP(object):
return 'Connections: %s' \ return 'Connections: %s' \
% str(self.connections) % str(self.connections)
def getsockname(self): def getsockname(self):
return (self.host, self.port) return self.own_socket.getsockname()
def getpeername(self):
if len(self.connections):
return list(self.connections.keys())[0]
else:
raise EOFError('Not connected')
def bind(self, addr): def bind(self, addr):
self.host = addr[0] self.host = addr[0]
self.port = addr[1] self.port = addr[1]
@ -201,60 +229,48 @@ class TCP(object):
packet_to_send = pickle.dumps(self.connections[connection]) packet_to_send = pickle.dumps(self.connections[connection])
self.connections[connection].checksum = 0 self.connections[connection].checksum = 0
self.connections[connection].data = b'' self.connections[connection].data = b''
while data_not_received: retransmit_count = 0
while data_not_received and retransmit_count < 3:
data_not_received = False data_not_received = False
try: try:
self.own_socket.sendto(packet_to_send, connection) self.own_socket.sendto(packet_to_send, connection)
answer = self.find_correct_packet('ACK', connection) answer = self.find_correct_packet('ACK', connection)
if not answer:
data_not_received = True
retransmit_count += 1
except socket.timeout: except socket.timeout:
#print('timeout') #print('timeout')
data_not_received = True data_not_received = True
if not answer:
self.drop_connection(connection)
raise EOFError('Connection lost')
self.connections[connection].seq += len(data_part) self.connections[connection].seq += len(data_part)
return len(data)
except socket.error as error: except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error) raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def recv(self, connection=None): def recv(self, size, connection=None):
try: try:
data = b''
if connection not in list(self.connections.keys()): if connection not in list(self.connections.keys()):
if connection is None: if connection is None:
connection = list(self.connections.keys())[0] connection = list(self.connections.keys())[0]
else: else:
return 'Connection not in connected devices' raise EOFError('Connection not in connected devices')
while True and self.status: data = self.find_correct_packet('DATA or FIN', connection, size)
data_part = self.find_correct_packet('DATA or FIN', connection)
if not self.status: if not self.status:
# print('I am disconnectiong cause sock is dead') raise EOFError('Disconnecting')
raise EOFError('Disconnected') return data
if data_part.packet_type() == 'FIN': except socket.error as error:
self.disconnect(connection) raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
raise EOFError('Disconnected') def send_ack(self, connection, ack):
checksum_value = TCP.checksum(data_part.data) self.connections[connection].ack = ack
self.connections[connection].seq += 1
while checksum_value != data_part.checksum:
data_part = self.find_correct_packet('DATA or FIN', connection)
checksum_value = TCP.checksum(data_part.data)
data_chunk = data_part.data if not self.encrypted else self.peer_keypair[connection].my_key.decrypt_raw(data_part.data)
if data_chunk != PACKET_END:
data += data_chunk
self.connections[connection].ack = data_part.seq + len(data_part.data)
self.connections[connection].seq += 1 # syn flag is 1 byte
self.connections[connection].set_flags(ack=True) self.connections[connection].set_flags(ack=True)
self.connections[connection].data = b'' self.connections[connection].data = b''
packet_to_send = pickle.dumps(self.connections[connection]) packet_to_send = pickle.dumps(self.connections[connection])
self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack
self.connections[connection].set_flags() self.connections[connection].set_flags()
if data_chunk == PACKET_END:
break
return data
except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
# conditions = ['SYN', 'SYN-ACK', 'ACK', 'FIN', 'FIN-ACK', 'DATA']
def listen_handler(self, max_connections): def listen_handler(self, max_connections):
try: try:
while True and self.status: while True and self.status:
@ -265,12 +281,13 @@ class TCP(object):
if self.encrypted: if self.encrypted:
try: try:
peer_pub = answer.data peer_pub = answer.data
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key (~5 sec)
self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub) self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
except: except:
self.peer_keypair.pop(address) self.peer_keypair.pop(address)
raise socket.error('Init peer public key error') raise socket.error('Init peer public key error')
self.connection_queue.append((answer, address)) self.connection_queue.append((answer, address))
self.blink_new_conn_event()
else: else:
self.own_socket.sendto('Connections full', address) self.own_socket.sendto('Connections full', address)
except KeyError: except KeyError:
@ -291,6 +308,7 @@ class TCP(object):
def accept(self): def accept(self):
try: try:
while True: while True:
self.new_conn_event.wait(0.1)
if self.connection_queue: if self.connection_queue:
with self.queue_lock: with self.queue_lock:
answer, address = self.connection_queue.pop() answer, address = self.connection_queue.pop()
@ -367,14 +385,17 @@ class TCP(object):
self.peer_keypair = {} self.peer_keypair = {}
self.status = 0 self.status = 0
raise EOFError('The socket was closed. Error:' + str(error)) raise EOFError('The socket was closed. Error:' + str(error))
def shutdown(self, *a, **kw):
self.own_socket.close()
self.status = 0
def fileno(self): def fileno(self):
return self.own_socket.fileno() return self.own_socket.fileno()
@property @property
def closed(self): def closed(self):
return self.own_socket._closed return self.own_socket._closed
def drop_connection(self, connection):
with self.connection_lock:
if len(self.connections):
self.connections.pop(connection)
if len(self.peer_keypair):
self.peer_keypair.pop(connection)
def close(self, connection=None): def close(self, connection=None):
try: try:
if connection not in list(self.connections.keys()): if connection not in list(self.connections.keys()):
@ -392,19 +413,11 @@ class TCP(object):
if answer.flag_fin != 1: if answer.flag_fin != 1:
raise Exception('The receiver didn\'t send the fin packet') raise Exception('The receiver didn\'t send the fin packet')
else: else:
self.connections[connection].ack += 1 self.send_ack(connection, self.connections[connection].ack + 1)
self.connections[connection].seq += 1 self.drop_connection(connection)
self.connections[connection].set_flags(ack=True) if len(self.connections) == 0 and self.client:
packet_to_send = pickle.dumps(self.connections[connection]) self.own_socket.close()
self.own_socket.sendto(packet_to_send, connection) self.status = 0
with self.connection_lock:
if len(self.connections):
self.connections.pop(connection)
if len(self.peer_keypair):
self.peer_keypair.pop(connection)
#if len(self.connections) == 0 and self.client:
# self.own_socket.close()
# self.status = 0
except Exception as error: except Exception as error:
raise EOFError('Something went wrong in the close func! Error is: %s.' % error) raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
@ -436,7 +449,7 @@ class TCP(object):
def checksum(source_bytes): def checksum(source_bytes):
return hashlib.sha1(source_bytes).digest() return hashlib.sha1(source_bytes).digest()
def find_correct_packet(self, condition, address=('Any',)): def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH):
not_found = True not_found = True
tries = 0 tries = 0
while not_found and tries < 2 and self.status: while not_found and tries < 2 and self.status:
@ -448,7 +461,9 @@ class TCP(object):
if condition == 'ACK': if condition == 'ACK':
tries += 1 tries += 1
if condition == 'DATA or FIN': if condition == 'DATA or FIN':
packet = self.packets_received[condition][address].pop() with self.connection_lock:
packet = self.packets_received[condition][address][:size]
self.packets_received[condition][address] = self.packets_received[condition][address][size:]
if not len(self.packets_received[condition][address]): if not len(self.packets_received[condition][address]):
del self.packets_received[condition][address] del self.packets_received[condition][address]
else: else:
@ -457,21 +472,33 @@ class TCP(object):
except KeyError: except KeyError:
not_found = True not_found = True
self.incoming_packet_event.wait(0.1) self.incoming_packet_event.wait(0.1)
def blink_event(self): def blink_incoming_packet_event(self):
self.incoming_packet_event.set() self.incoming_packet_event.set()
self.incoming_packet_event.clear() self.incoming_packet_event.clear()
def blink_new_conn_event(self):
self.new_conn_event.set()
self.new_conn_event.clear()
def sort_answers(self, packet, address): def sort_answers(self, packet, address):
if packet.packet_type() == 'DATA' or packet.packet_type() == 'FIN': if address not in self.connections and packet.packet_type() != 'SYN':
return
if packet.packet_type() == 'FIN':
self.disconnect(address)
elif packet.packet_type() == 'DATA':
if packet.checksum == TCP.checksum(packet.data):
data_chunk = packet.data if not self.encrypted else self.peer_keypair[address].my_key.decrypt_raw(packet.data)
if data_chunk != PACKET_END:
with self.connection_lock:
if address not in self.packets_received['DATA or FIN']: if address not in self.packets_received['DATA or FIN']:
self.packets_received['DATA or FIN'][address] = [] self.packets_received['DATA or FIN'][address] = b''
self.packets_received['DATA or FIN'][address].insert(0, packet) self.packets_received['DATA or FIN'][address] += data_chunk
self.blink_event() self.send_ack(address, packet.seq + len(packet.data))
self.blink_incoming_packet_event()
elif packet.packet_type() == '': elif packet.packet_type() == '':
#print('redundant packet found', packet) #print('redundant packet found', packet)
pass pass
else: else:
self.packets_received[packet.packet_type()][address] = packet self.packets_received[packet.packet_type()][address] = packet
self.blink_event() self.blink_incoming_packet_event()
def central_receive_handler(self): def central_receive_handler(self):
while True and self.status: while True and self.status: