encrypted "tcp" over udp
parent
20e70a5ed3
commit
f7a4da82bd
|
@ -1,6 +1,12 @@
|
|||
# Based on rpyc (4.0.2)
|
||||
|
||||
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
|
||||
|
||||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer, spawn
|
||||
from rpyc.core.stream import SocketStream
|
||||
from rpyc.core import SocketStream, Channel
|
||||
from rpyc.core.stream import retry_errnos
|
||||
from rpyc.utils.factory import connect_channel
|
||||
import socket
|
||||
from mbedtls import tls
|
||||
import datetime as dt
|
||||
|
@ -10,6 +16,10 @@ from mbedtls import x509
|
|||
from uuid import uuid4
|
||||
from contextlib import suppress
|
||||
import sys
|
||||
from rpyc.lib import safe_import
|
||||
zlib = safe_import("zlib")
|
||||
import errno
|
||||
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
||||
|
||||
def block(callback, *args, **kwargs):
|
||||
while True:
|
||||
|
@ -33,6 +43,7 @@ class DTLSCerts:
|
|||
self.ca1_crt = self.ca0_crt.sign(
|
||||
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
|
||||
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
|
||||
self.srv_crt, self.srv_key = self.server_cert()
|
||||
def server_cert(self):
|
||||
now = dt.datetime.utcnow()
|
||||
ee0_key = pk.ECC()
|
||||
|
@ -46,17 +57,28 @@ dtls_certs = DTLSCerts()
|
|||
trust_store = tls.TrustStore()
|
||||
trust_store.add(dtls_certs.ca0_crt)
|
||||
|
||||
srv_ctx_conf = tls.DTLSConfiguration(
|
||||
trust_store=trust_store,
|
||||
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
|
||||
validate_certificates=False,
|
||||
)
|
||||
cli_ctx_conf = tls.DTLSConfiguration(
|
||||
trust_store=trust_store,
|
||||
validate_certificates=False,
|
||||
)
|
||||
MAX_IO_CHUNK = 20971520
|
||||
|
||||
class DTLSSocketStream(SocketStream):
|
||||
MAX_IO_CHUNK = MAX_IO_CHUNK
|
||||
@classmethod
|
||||
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
|
||||
if kwargs.pop('ipv6', False):
|
||||
family = socket.AF_INET6
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration(
|
||||
trust_store=trust_store,
|
||||
validate_certificates=False,
|
||||
))
|
||||
#tls._enable_debug_output(cli_ctx_conf)
|
||||
#tls._set_debug_level(10)
|
||||
dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
|
||||
dtls_cli = dtls_cli_ctx.wrap_socket(
|
||||
socket.socket(family, socket.SOCK_DGRAM),
|
||||
server_hostname=None,
|
||||
|
@ -66,33 +88,128 @@ class DTLSSocketStream(SocketStream):
|
|||
block(dtls_cli.do_handshake)
|
||||
return cls(dtls_cli)
|
||||
def read(self, count):
|
||||
return block(SocketStream.read(self, count))
|
||||
while True:
|
||||
try:
|
||||
buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
|
||||
except socket.timeout:
|
||||
continue
|
||||
except socket.error:
|
||||
ex = sys.exc_info()[1]
|
||||
if get_exc_errno(ex) in retry_errnos:
|
||||
# windows just has to be a bitch
|
||||
# inpos: I agree
|
||||
continue
|
||||
self.close()
|
||||
raise EOFError(ex)
|
||||
else:
|
||||
break
|
||||
if not buf:
|
||||
self.close()
|
||||
raise EOFError("connection closed by peer")
|
||||
return buf
|
||||
def write(self, data):
|
||||
block(SocketStream.write(self, data))
|
||||
try:
|
||||
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
|
||||
except socket.error:
|
||||
ex = sys.exc_info()[1]
|
||||
self.close()
|
||||
raise EOFError(ex)
|
||||
|
||||
class DTLSChannel(Channel):
|
||||
MAX_IO_CHUNK = MAX_IO_CHUNK
|
||||
def recv(self):
|
||||
raw_data = self.stream.read(self.MAX_IO_CHUNK)
|
||||
header = raw_data[:self.FRAME_HEADER.size]
|
||||
raw_data = raw_data[self.FRAME_HEADER.size:]
|
||||
length, compressed = self.FRAME_HEADER.unpack(header)
|
||||
data = raw_data[:length]
|
||||
if compressed:
|
||||
data = zlib.decompress(data)
|
||||
return data
|
||||
|
||||
def connect_stream(stream, service=rpyc.VoidService, config={}):
|
||||
return connect_channel(DTLSChannel(stream), service=service, config=config)
|
||||
|
||||
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
|
||||
cert_reqs=None, ssl_version=None, ciphers=None,
|
||||
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
|
||||
ssl_kwargs = {'server_side' : False}
|
||||
if ciphers is not None:
|
||||
ssl_kwargs['ciphers'] = ciphers
|
||||
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
|
||||
return connect_stream(s, service, config)
|
||||
|
||||
class DTLSThreadedServer(ThreadedServer):
|
||||
def dtls(self, listener_timeout = 0.5, reuse_addr = True):
|
||||
def __init__(self, service, hostname = "", ipv6 = False, port = 0,
|
||||
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
|
||||
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
|
||||
socket_path = None):
|
||||
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
|
||||
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
|
||||
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
|
||||
socket_path=socket_path)
|
||||
self.listener.close()
|
||||
srv_crt, srv_key = dtls_certs.server_cert()
|
||||
dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration(
|
||||
trust_store=trust_store,
|
||||
certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key),
|
||||
validate_certificates=False,
|
||||
))
|
||||
|
||||
self.host = hostname
|
||||
self.port = port
|
||||
#tls._enable_debug_output(srv_ctx_conf)
|
||||
#tls._set_debug_level(10)
|
||||
dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
|
||||
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
|
||||
if reuse_addr and sys.platform != 'win32':
|
||||
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
dtls_srv.bind((self.host, self.port))
|
||||
dtls_srv.settimeout(listener_timeout)
|
||||
dtls_srv.bind((hostname, port))
|
||||
# dtls_srv.settimeout(listener_timeout)
|
||||
self.listener = dtls_srv
|
||||
sockname = self.listener.getsockname()
|
||||
self.host, self.port = sockname[0], sockname[1]
|
||||
def _serve_client(self, sock, credentials):
|
||||
addrinfo = sock.getpeername()
|
||||
if credentials:
|
||||
self.logger.info("welcome %s (%r)", addrinfo, credentials)
|
||||
else:
|
||||
self.logger.info("welcome %s", addrinfo)
|
||||
try:
|
||||
config = dict(self.protocol_config, credentials = credentials,
|
||||
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
|
||||
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
|
||||
self._handle_connection(conn)
|
||||
finally:
|
||||
self.logger.info("goodbye %s", addrinfo)
|
||||
def _listen(self):
|
||||
if self.active:
|
||||
return
|
||||
#self.listener.listen(self.backlog) #####################
|
||||
if not self.port:
|
||||
self.port = self.listener.getsockname()[1]
|
||||
self.logger.info('server started on [%s]:%s', self.host, self.port)
|
||||
self.active = True
|
||||
def _authenticate_and_serve_client(self, sock):
|
||||
def accept(self):
|
||||
while self.active:
|
||||
try:
|
||||
print('Accepting new connections!')
|
||||
sock, addrinfo = self.listener.accept()
|
||||
except socket.timeout:
|
||||
pass
|
||||
except socket.error:
|
||||
ex = sys.exc_info()[1]
|
||||
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
|
||||
pass
|
||||
else:
|
||||
raise EOFError()
|
||||
raise
|
||||
else:
|
||||
break
|
||||
|
||||
if not self.active:
|
||||
return
|
||||
|
||||
sock.setblocking(True)
|
||||
self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
|
||||
print("accepted %s with fd %s" % (addrinfo, sock.fileno()))
|
||||
self.clients.add(sock)
|
||||
self._accept_method(sock)
|
||||
def _accept_method(self, sock):
|
||||
addr = sock.getpeername()
|
||||
sock.setcookieparam(addr[0].encode())
|
||||
with suppress(tls.HelloVerifyRequest):
|
||||
|
@ -100,6 +217,6 @@ class DTLSThreadedServer(ThreadedServer):
|
|||
sock, addr = sock.accept()
|
||||
sock.setcookieparam(addr[0].encode())
|
||||
block(sock.do_handshake)
|
||||
ThreadedServer._authenticate_and_serve_client(self, sock)
|
||||
spawn(self._authenticate_and_serve_client, sock)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
from rpyc.core import SocketStream, Channel
|
||||
from rpyc.core.stream import retry_errnos
|
||||
from rpyc.utils.factory import connect_channel
|
||||
from rpyc.lib import Timeout
|
||||
from rpyc.lib.compat import select_error
|
||||
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
||||
from . import utcp
|
||||
import sys
|
||||
import errno
|
||||
import socket
|
||||
|
||||
class UTCPSocketStream(SocketStream):
|
||||
MAX_IO_CHUNK = utcp.DATA_LENGTH
|
||||
@classmethod
|
||||
def utcp_connect(cls, host, port, *a, **kw):
|
||||
sock = utcp.TCP(encrypted=True)
|
||||
sock.connect((host, port))
|
||||
return cls(sock)
|
||||
def poll(self, timeout):
|
||||
timeout = Timeout(timeout)
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
rl = self.sock.poll(timeout.timeleft())
|
||||
except select_error:
|
||||
ex = sys.exc_info()[1]
|
||||
if ex.args[0] == errno.EINTR:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
break
|
||||
except ValueError:
|
||||
ex = sys.exc_info()[1]
|
||||
raise select_error(str(ex))
|
||||
return rl
|
||||
def connect_stream(stream, service=rpyc.VoidService, config={}):
|
||||
return connect_channel(Channel(stream), service=service, config=config)
|
||||
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
|
||||
s = UTCPSocketStream.utcp_connect(host, port, **kw)
|
||||
return connect_stream(s, service, config)
|
||||
class UTCPThreadedServer(ThreadedServer):
|
||||
def __init__(self, service, hostname = '', ipv6 = False, port = 0,
|
||||
backlog = 1, reuse_addr = True, authenticator = None, registrar = None,
|
||||
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
|
||||
socket_path = None):
|
||||
backlog = 1
|
||||
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
|
||||
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
|
||||
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
|
||||
socket_path=socket_path)
|
||||
self.listener.close()
|
||||
self.listener = None
|
||||
##########
|
||||
self.listener = utcp.TCP(encrypted=True)
|
||||
self.listener.bind((hostname, port))
|
||||
sockname = self.listener.getsockname()
|
||||
self.host, self.port = sockname[0], sockname[1]
|
||||
def _serve_client(self, sock, credentials):
|
||||
addrinfo = sock.getpeername()
|
||||
if credentials:
|
||||
self.logger.info("welcome %s (%r)", addrinfo, credentials)
|
||||
else:
|
||||
self.logger.info("welcome %s", addrinfo)
|
||||
try:
|
||||
config = dict(self.protocol_config, credentials = credentials,
|
||||
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
|
||||
conn = self.service._connect(Channel(UTCPSocketStream(sock)), config)
|
||||
self._handle_connection(conn)
|
||||
finally:
|
||||
self.logger.info("goodbye %s", addrinfo)
|
||||
|
173
mods/utcp.py
173
mods/utcp.py
|
@ -6,8 +6,6 @@ import threading
|
|||
import io
|
||||
import hashlib
|
||||
import simplecrypto
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
DATA_DIVIDE_LENGTH = 8000
|
||||
PACKET_HEADER_SIZE = 512 # Pickle service info
|
||||
|
@ -109,7 +107,7 @@ class TCPPacket(object):
|
|||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
# Only allow TCPPacket
|
||||
if module == "utcp" and name == 'TCPPacket':
|
||||
if name == 'TCPPacket':
|
||||
return TCPPacket
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
|
@ -123,44 +121,67 @@ class ConnectedSOCK(object):
|
|||
self.client_addr = client_addr
|
||||
self.low_sock = low_sock
|
||||
def __getattribute__(self, att):
|
||||
if not att.startswith('_') and not att in ['client_addr', 'low_sock', 'send', 'recv', 'close', 'closed']:
|
||||
if att in self.low_sock.__dict__:
|
||||
return getattr(self.low_sock, att)
|
||||
try:
|
||||
return object.__getattribute__(self, att)
|
||||
except AttributeError:
|
||||
return getattr(self.low_sock, att)
|
||||
def getpeername(self):
|
||||
return self.client_addr
|
||||
def send(self, data):
|
||||
if self.closed:
|
||||
raise EOFError
|
||||
self.low_sock.send(data, self.client_addr)
|
||||
def recv(self, size=None):
|
||||
return self.low_sock.send(data, self.client_addr)
|
||||
def recv(self, size):
|
||||
if self.closed:
|
||||
raise EOFError
|
||||
if size:
|
||||
return self.low_sock.recv(self.client_addr)[:size]
|
||||
else:
|
||||
return self.low_sock.recv(self.client_addr)
|
||||
return self.low_sock.recv(size, self.client_addr)
|
||||
@property
|
||||
def closed(self):
|
||||
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin)
|
||||
def close(self):
|
||||
if self.client_addr in self.low_sock.connections:
|
||||
self.low_sock.close(self.client_addr)
|
||||
def shutdown(self, *a, **kw):
|
||||
self.close()
|
||||
def poll(self, timeout):
|
||||
if self.client_addr in self.packets_received['DATA or FIN']:
|
||||
return True
|
||||
else:
|
||||
self.incoming_packet_event.wait(timeout)
|
||||
return self.client_addr in self.packets_received['DATA or FIN']
|
||||
return False
|
||||
|
||||
class TCP(object):
|
||||
host = None
|
||||
port = None
|
||||
client = False
|
||||
peer_keypair = {}
|
||||
connections = {}
|
||||
connection_queue = []
|
||||
packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
||||
def __init__(self, af_type=None, sock_type=None, encrypted=False):
|
||||
self.encrypted = encrypted
|
||||
self.incoming_packet_event = threading.Event()
|
||||
self.new_conn_event = threading.Event()
|
||||
#seq will have the last packet send and ack will have the next packet waiting to receive
|
||||
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication.
|
||||
self.settimeout()
|
||||
self.peer_keypair = {}
|
||||
self.connections = {}
|
||||
self.connection_queue = []
|
||||
#self.peer_keypair = {}
|
||||
#self.connections = {}
|
||||
#self.connection_queue = []
|
||||
self.connection_lock = threading.Lock()
|
||||
self.queue_lock = threading.Lock()
|
||||
# each condition will have a dictionary of an address and it's corresponding packet.
|
||||
self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
||||
#self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
||||
def poll(self, timeout):
|
||||
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
|
||||
return True
|
||||
else:
|
||||
self.incoming_packet_event.wait(timeout)
|
||||
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_free_port(self):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.bind(('', 0))
|
||||
|
@ -171,6 +192,8 @@ class TCP(object):
|
|||
pass
|
||||
def settimeout(self, timeout=5):
|
||||
self.own_socket.settimeout(timeout)
|
||||
def setblocking(self, mode):
|
||||
self.own_socket.setblocking(mode)
|
||||
def __repr__(self):
|
||||
return 'TCP()'
|
||||
|
||||
|
@ -178,7 +201,12 @@ class TCP(object):
|
|||
return 'Connections: %s' \
|
||||
% str(self.connections)
|
||||
def getsockname(self):
|
||||
return (self.host, self.port)
|
||||
return self.own_socket.getsockname()
|
||||
def getpeername(self):
|
||||
if len(self.connections):
|
||||
return list(self.connections.keys())[0]
|
||||
else:
|
||||
raise EOFError('Not connected')
|
||||
def bind(self, addr):
|
||||
self.host = addr[0]
|
||||
self.port = addr[1]
|
||||
|
@ -201,60 +229,48 @@ class TCP(object):
|
|||
packet_to_send = pickle.dumps(self.connections[connection])
|
||||
self.connections[connection].checksum = 0
|
||||
self.connections[connection].data = b''
|
||||
while data_not_received:
|
||||
retransmit_count = 0
|
||||
while data_not_received and retransmit_count < 3:
|
||||
data_not_received = False
|
||||
try:
|
||||
self.own_socket.sendto(packet_to_send, connection)
|
||||
answer = self.find_correct_packet('ACK', connection)
|
||||
if not answer:
|
||||
data_not_received = True
|
||||
retransmit_count += 1
|
||||
except socket.timeout:
|
||||
#print('timeout')
|
||||
data_not_received = True
|
||||
if not answer:
|
||||
self.drop_connection(connection)
|
||||
raise EOFError('Connection lost')
|
||||
self.connections[connection].seq += len(data_part)
|
||||
return len(data)
|
||||
except socket.error as error:
|
||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
||||
|
||||
def recv(self, connection=None):
|
||||
def recv(self, size, connection=None):
|
||||
try:
|
||||
data = b''
|
||||
if connection not in list(self.connections.keys()):
|
||||
if connection is None:
|
||||
connection = list(self.connections.keys())[0]
|
||||
else:
|
||||
return 'Connection not in connected devices'
|
||||
raise EOFError('Connection not in connected devices')
|
||||
|
||||
while True and self.status:
|
||||
data_part = self.find_correct_packet('DATA or FIN', connection)
|
||||
data = self.find_correct_packet('DATA or FIN', connection, size)
|
||||
if not self.status:
|
||||
# print('I am disconnectiong cause sock is dead')
|
||||
raise EOFError('Disconnected')
|
||||
if data_part.packet_type() == 'FIN':
|
||||
self.disconnect(connection)
|
||||
raise EOFError('Disconnected')
|
||||
checksum_value = TCP.checksum(data_part.data)
|
||||
|
||||
while checksum_value != data_part.checksum:
|
||||
data_part = self.find_correct_packet('DATA or FIN', connection)
|
||||
checksum_value = TCP.checksum(data_part.data)
|
||||
|
||||
data_chunk = data_part.data if not self.encrypted else self.peer_keypair[connection].my_key.decrypt_raw(data_part.data)
|
||||
if data_chunk != PACKET_END:
|
||||
data += data_chunk
|
||||
self.connections[connection].ack = data_part.seq + len(data_part.data)
|
||||
self.connections[connection].seq += 1 # syn flag is 1 byte
|
||||
raise EOFError('Disconnecting')
|
||||
return data
|
||||
except socket.error as error:
|
||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
||||
def send_ack(self, connection, ack):
|
||||
self.connections[connection].ack = ack
|
||||
self.connections[connection].seq += 1
|
||||
self.connections[connection].set_flags(ack=True)
|
||||
self.connections[connection].data = b''
|
||||
packet_to_send = pickle.dumps(self.connections[connection])
|
||||
self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack
|
||||
self.connections[connection].set_flags()
|
||||
|
||||
if data_chunk == PACKET_END:
|
||||
break
|
||||
|
||||
return data
|
||||
except socket.error as error:
|
||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
||||
|
||||
# conditions = ['SYN', 'SYN-ACK', 'ACK', 'FIN', 'FIN-ACK', 'DATA']
|
||||
def listen_handler(self, max_connections):
|
||||
try:
|
||||
while True and self.status:
|
||||
|
@ -265,12 +281,13 @@ class TCP(object):
|
|||
if self.encrypted:
|
||||
try:
|
||||
peer_pub = answer.data
|
||||
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key
|
||||
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key (~5 sec)
|
||||
self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
|
||||
except:
|
||||
self.peer_keypair.pop(address)
|
||||
raise socket.error('Init peer public key error')
|
||||
self.connection_queue.append((answer, address))
|
||||
self.blink_new_conn_event()
|
||||
else:
|
||||
self.own_socket.sendto('Connections full', address)
|
||||
except KeyError:
|
||||
|
@ -291,6 +308,7 @@ class TCP(object):
|
|||
def accept(self):
|
||||
try:
|
||||
while True:
|
||||
self.new_conn_event.wait(0.1)
|
||||
if self.connection_queue:
|
||||
with self.queue_lock:
|
||||
answer, address = self.connection_queue.pop()
|
||||
|
@ -367,14 +385,17 @@ class TCP(object):
|
|||
self.peer_keypair = {}
|
||||
self.status = 0
|
||||
raise EOFError('The socket was closed. Error:' + str(error))
|
||||
def shutdown(self, *a, **kw):
|
||||
self.own_socket.close()
|
||||
self.status = 0
|
||||
def fileno(self):
|
||||
return self.own_socket.fileno()
|
||||
@property
|
||||
def closed(self):
|
||||
return self.own_socket._closed
|
||||
def drop_connection(self, connection):
|
||||
with self.connection_lock:
|
||||
if len(self.connections):
|
||||
self.connections.pop(connection)
|
||||
if len(self.peer_keypair):
|
||||
self.peer_keypair.pop(connection)
|
||||
def close(self, connection=None):
|
||||
try:
|
||||
if connection not in list(self.connections.keys()):
|
||||
|
@ -392,19 +413,11 @@ class TCP(object):
|
|||
if answer.flag_fin != 1:
|
||||
raise Exception('The receiver didn\'t send the fin packet')
|
||||
else:
|
||||
self.connections[connection].ack += 1
|
||||
self.connections[connection].seq += 1
|
||||
self.connections[connection].set_flags(ack=True)
|
||||
packet_to_send = pickle.dumps(self.connections[connection])
|
||||
self.own_socket.sendto(packet_to_send, connection)
|
||||
with self.connection_lock:
|
||||
if len(self.connections):
|
||||
self.connections.pop(connection)
|
||||
if len(self.peer_keypair):
|
||||
self.peer_keypair.pop(connection)
|
||||
#if len(self.connections) == 0 and self.client:
|
||||
# self.own_socket.close()
|
||||
# self.status = 0
|
||||
self.send_ack(connection, self.connections[connection].ack + 1)
|
||||
self.drop_connection(connection)
|
||||
if len(self.connections) == 0 and self.client:
|
||||
self.own_socket.close()
|
||||
self.status = 0
|
||||
except Exception as error:
|
||||
raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
|
||||
|
||||
|
@ -436,7 +449,7 @@ class TCP(object):
|
|||
def checksum(source_bytes):
|
||||
return hashlib.sha1(source_bytes).digest()
|
||||
|
||||
def find_correct_packet(self, condition, address=('Any',)):
|
||||
def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH):
|
||||
not_found = True
|
||||
tries = 0
|
||||
while not_found and tries < 2 and self.status:
|
||||
|
@ -448,7 +461,9 @@ class TCP(object):
|
|||
if condition == 'ACK':
|
||||
tries += 1
|
||||
if condition == 'DATA or FIN':
|
||||
packet = self.packets_received[condition][address].pop()
|
||||
with self.connection_lock:
|
||||
packet = self.packets_received[condition][address][:size]
|
||||
self.packets_received[condition][address] = self.packets_received[condition][address][size:]
|
||||
if not len(self.packets_received[condition][address]):
|
||||
del self.packets_received[condition][address]
|
||||
else:
|
||||
|
@ -457,21 +472,33 @@ class TCP(object):
|
|||
except KeyError:
|
||||
not_found = True
|
||||
self.incoming_packet_event.wait(0.1)
|
||||
def blink_event(self):
|
||||
def blink_incoming_packet_event(self):
|
||||
self.incoming_packet_event.set()
|
||||
self.incoming_packet_event.clear()
|
||||
def blink_new_conn_event(self):
|
||||
self.new_conn_event.set()
|
||||
self.new_conn_event.clear()
|
||||
def sort_answers(self, packet, address):
|
||||
if packet.packet_type() == 'DATA' or packet.packet_type() == 'FIN':
|
||||
if address not in self.connections and packet.packet_type() != 'SYN':
|
||||
return
|
||||
if packet.packet_type() == 'FIN':
|
||||
self.disconnect(address)
|
||||
elif packet.packet_type() == 'DATA':
|
||||
if packet.checksum == TCP.checksum(packet.data):
|
||||
data_chunk = packet.data if not self.encrypted else self.peer_keypair[address].my_key.decrypt_raw(packet.data)
|
||||
if data_chunk != PACKET_END:
|
||||
with self.connection_lock:
|
||||
if address not in self.packets_received['DATA or FIN']:
|
||||
self.packets_received['DATA or FIN'][address] = []
|
||||
self.packets_received['DATA or FIN'][address].insert(0, packet)
|
||||
self.blink_event()
|
||||
self.packets_received['DATA or FIN'][address] = b''
|
||||
self.packets_received['DATA or FIN'][address] += data_chunk
|
||||
self.send_ack(address, packet.seq + len(packet.data))
|
||||
self.blink_incoming_packet_event()
|
||||
elif packet.packet_type() == '':
|
||||
#print('redundant packet found', packet)
|
||||
pass
|
||||
else:
|
||||
self.packets_received[packet.packet_type()][address] = packet
|
||||
self.blink_event()
|
||||
self.blink_incoming_packet_event()
|
||||
|
||||
def central_receive_handler(self):
|
||||
while True and self.status:
|
||||
|
|
Loading…
Reference in New Issue