encrypted "tcp" over udp
parent
20e70a5ed3
commit
f7a4da82bd
|
@ -1,6 +1,12 @@
|
||||||
|
# Based on rpyc (4.0.2)
|
||||||
|
|
||||||
|
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
|
||||||
|
|
||||||
import rpyc
|
import rpyc
|
||||||
from rpyc.utils.server import ThreadedServer, spawn
|
from rpyc.utils.server import ThreadedServer, spawn
|
||||||
from rpyc.core.stream import SocketStream
|
from rpyc.core import SocketStream, Channel
|
||||||
|
from rpyc.core.stream import retry_errnos
|
||||||
|
from rpyc.utils.factory import connect_channel
|
||||||
import socket
|
import socket
|
||||||
from mbedtls import tls
|
from mbedtls import tls
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
@ -10,6 +16,10 @@ from mbedtls import x509
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
import sys
|
import sys
|
||||||
|
from rpyc.lib import safe_import
|
||||||
|
zlib = safe_import("zlib")
|
||||||
|
import errno
|
||||||
|
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
||||||
|
|
||||||
def block(callback, *args, **kwargs):
|
def block(callback, *args, **kwargs):
|
||||||
while True:
|
while True:
|
||||||
|
@ -33,6 +43,7 @@ class DTLSCerts:
|
||||||
self.ca1_crt = self.ca0_crt.sign(
|
self.ca1_crt = self.ca0_crt.sign(
|
||||||
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
|
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
|
||||||
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
|
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
|
||||||
|
self.srv_crt, self.srv_key = self.server_cert()
|
||||||
def server_cert(self):
|
def server_cert(self):
|
||||||
now = dt.datetime.utcnow()
|
now = dt.datetime.utcnow()
|
||||||
ee0_key = pk.ECC()
|
ee0_key = pk.ECC()
|
||||||
|
@ -46,17 +57,28 @@ dtls_certs = DTLSCerts()
|
||||||
trust_store = tls.TrustStore()
|
trust_store = tls.TrustStore()
|
||||||
trust_store.add(dtls_certs.ca0_crt)
|
trust_store.add(dtls_certs.ca0_crt)
|
||||||
|
|
||||||
|
srv_ctx_conf = tls.DTLSConfiguration(
|
||||||
|
trust_store=trust_store,
|
||||||
|
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
|
||||||
|
validate_certificates=False,
|
||||||
|
)
|
||||||
|
cli_ctx_conf = tls.DTLSConfiguration(
|
||||||
|
trust_store=trust_store,
|
||||||
|
validate_certificates=False,
|
||||||
|
)
|
||||||
|
MAX_IO_CHUNK = 20971520
|
||||||
|
|
||||||
class DTLSSocketStream(SocketStream):
|
class DTLSSocketStream(SocketStream):
|
||||||
|
MAX_IO_CHUNK = MAX_IO_CHUNK
|
||||||
@classmethod
|
@classmethod
|
||||||
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
|
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
|
||||||
if kwargs.pop('ipv6', False):
|
if kwargs.pop('ipv6', False):
|
||||||
family = socket.AF_INET6
|
family = socket.AF_INET6
|
||||||
else:
|
else:
|
||||||
family = socket.AF_INET
|
family = socket.AF_INET
|
||||||
dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration(
|
#tls._enable_debug_output(cli_ctx_conf)
|
||||||
trust_store=trust_store,
|
#tls._set_debug_level(10)
|
||||||
validate_certificates=False,
|
dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
|
||||||
))
|
|
||||||
dtls_cli = dtls_cli_ctx.wrap_socket(
|
dtls_cli = dtls_cli_ctx.wrap_socket(
|
||||||
socket.socket(family, socket.SOCK_DGRAM),
|
socket.socket(family, socket.SOCK_DGRAM),
|
||||||
server_hostname=None,
|
server_hostname=None,
|
||||||
|
@ -66,33 +88,128 @@ class DTLSSocketStream(SocketStream):
|
||||||
block(dtls_cli.do_handshake)
|
block(dtls_cli.do_handshake)
|
||||||
return cls(dtls_cli)
|
return cls(dtls_cli)
|
||||||
def read(self, count):
|
def read(self, count):
|
||||||
return block(SocketStream.read(self, count))
|
while True:
|
||||||
|
try:
|
||||||
|
buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
|
||||||
|
except socket.timeout:
|
||||||
|
continue
|
||||||
|
except socket.error:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
if get_exc_errno(ex) in retry_errnos:
|
||||||
|
# windows just has to be a bitch
|
||||||
|
# inpos: I agree
|
||||||
|
continue
|
||||||
|
self.close()
|
||||||
|
raise EOFError(ex)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if not buf:
|
||||||
|
self.close()
|
||||||
|
raise EOFError("connection closed by peer")
|
||||||
|
return buf
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
block(SocketStream.write(self, data))
|
try:
|
||||||
|
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
|
||||||
|
except socket.error:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
self.close()
|
||||||
|
raise EOFError(ex)
|
||||||
|
|
||||||
|
class DTLSChannel(Channel):
|
||||||
|
MAX_IO_CHUNK = MAX_IO_CHUNK
|
||||||
|
def recv(self):
|
||||||
|
raw_data = self.stream.read(self.MAX_IO_CHUNK)
|
||||||
|
header = raw_data[:self.FRAME_HEADER.size]
|
||||||
|
raw_data = raw_data[self.FRAME_HEADER.size:]
|
||||||
|
length, compressed = self.FRAME_HEADER.unpack(header)
|
||||||
|
data = raw_data[:length]
|
||||||
|
if compressed:
|
||||||
|
data = zlib.decompress(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def connect_stream(stream, service=rpyc.VoidService, config={}):
|
||||||
|
return connect_channel(DTLSChannel(stream), service=service, config=config)
|
||||||
|
|
||||||
|
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
|
||||||
|
cert_reqs=None, ssl_version=None, ciphers=None,
|
||||||
|
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
|
||||||
|
ssl_kwargs = {'server_side' : False}
|
||||||
|
if ciphers is not None:
|
||||||
|
ssl_kwargs['ciphers'] = ciphers
|
||||||
|
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
|
||||||
|
return connect_stream(s, service, config)
|
||||||
|
|
||||||
class DTLSThreadedServer(ThreadedServer):
|
class DTLSThreadedServer(ThreadedServer):
|
||||||
def dtls(self, listener_timeout = 0.5, reuse_addr = True):
|
def __init__(self, service, hostname = "", ipv6 = False, port = 0,
|
||||||
|
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
|
||||||
|
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
|
||||||
|
socket_path = None):
|
||||||
|
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
|
||||||
|
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
|
||||||
|
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
|
||||||
|
socket_path=socket_path)
|
||||||
self.listener.close()
|
self.listener.close()
|
||||||
srv_crt, srv_key = dtls_certs.server_cert()
|
|
||||||
dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration(
|
self.host = hostname
|
||||||
trust_store=trust_store,
|
self.port = port
|
||||||
certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key),
|
#tls._enable_debug_output(srv_ctx_conf)
|
||||||
validate_certificates=False,
|
#tls._set_debug_level(10)
|
||||||
))
|
dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
|
||||||
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
|
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
|
||||||
if reuse_addr and sys.platform != 'win32':
|
if reuse_addr and sys.platform != 'win32':
|
||||||
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
dtls_srv.bind((self.host, self.port))
|
dtls_srv.bind((hostname, port))
|
||||||
dtls_srv.settimeout(listener_timeout)
|
# dtls_srv.settimeout(listener_timeout)
|
||||||
self.listener = dtls_srv
|
self.listener = dtls_srv
|
||||||
|
sockname = self.listener.getsockname()
|
||||||
|
self.host, self.port = sockname[0], sockname[1]
|
||||||
|
def _serve_client(self, sock, credentials):
|
||||||
|
addrinfo = sock.getpeername()
|
||||||
|
if credentials:
|
||||||
|
self.logger.info("welcome %s (%r)", addrinfo, credentials)
|
||||||
|
else:
|
||||||
|
self.logger.info("welcome %s", addrinfo)
|
||||||
|
try:
|
||||||
|
config = dict(self.protocol_config, credentials = credentials,
|
||||||
|
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
|
||||||
|
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
|
||||||
|
self._handle_connection(conn)
|
||||||
|
finally:
|
||||||
|
self.logger.info("goodbye %s", addrinfo)
|
||||||
def _listen(self):
|
def _listen(self):
|
||||||
if self.active:
|
if self.active:
|
||||||
return
|
return
|
||||||
|
#self.listener.listen(self.backlog) #####################
|
||||||
if not self.port:
|
if not self.port:
|
||||||
self.port = self.listener.getsockname()[1]
|
self.port = self.listener.getsockname()[1]
|
||||||
self.logger.info('server started on [%s]:%s', self.host, self.port)
|
self.logger.info('server started on [%s]:%s', self.host, self.port)
|
||||||
self.active = True
|
self.active = True
|
||||||
def _authenticate_and_serve_client(self, sock):
|
def accept(self):
|
||||||
|
while self.active:
|
||||||
|
try:
|
||||||
|
print('Accepting new connections!')
|
||||||
|
sock, addrinfo = self.listener.accept()
|
||||||
|
except socket.timeout:
|
||||||
|
pass
|
||||||
|
except socket.error:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise EOFError()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not self.active:
|
||||||
|
return
|
||||||
|
|
||||||
|
sock.setblocking(True)
|
||||||
|
self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
|
||||||
|
print("accepted %s with fd %s" % (addrinfo, sock.fileno()))
|
||||||
|
self.clients.add(sock)
|
||||||
|
self._accept_method(sock)
|
||||||
|
def _accept_method(self, sock):
|
||||||
addr = sock.getpeername()
|
addr = sock.getpeername()
|
||||||
sock.setcookieparam(addr[0].encode())
|
sock.setcookieparam(addr[0].encode())
|
||||||
with suppress(tls.HelloVerifyRequest):
|
with suppress(tls.HelloVerifyRequest):
|
||||||
|
@ -100,6 +217,6 @@ class DTLSThreadedServer(ThreadedServer):
|
||||||
sock, addr = sock.accept()
|
sock, addr = sock.accept()
|
||||||
sock.setcookieparam(addr[0].encode())
|
sock.setcookieparam(addr[0].encode())
|
||||||
block(sock.do_handshake)
|
block(sock.do_handshake)
|
||||||
ThreadedServer._authenticate_and_serve_client(self, sock)
|
spawn(self._authenticate_and_serve_client, sock)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
import rpyc
|
||||||
|
from rpyc.utils.server import ThreadedServer
|
||||||
|
from rpyc.core import SocketStream, Channel
|
||||||
|
from rpyc.core.stream import retry_errnos
|
||||||
|
from rpyc.utils.factory import connect_channel
|
||||||
|
from rpyc.lib import Timeout
|
||||||
|
from rpyc.lib.compat import select_error
|
||||||
|
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
||||||
|
from . import utcp
|
||||||
|
import sys
|
||||||
|
import errno
|
||||||
|
import socket
|
||||||
|
|
||||||
|
class UTCPSocketStream(SocketStream):
|
||||||
|
MAX_IO_CHUNK = utcp.DATA_LENGTH
|
||||||
|
@classmethod
|
||||||
|
def utcp_connect(cls, host, port, *a, **kw):
|
||||||
|
sock = utcp.TCP(encrypted=True)
|
||||||
|
sock.connect((host, port))
|
||||||
|
return cls(sock)
|
||||||
|
def poll(self, timeout):
|
||||||
|
timeout = Timeout(timeout)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
rl = self.sock.poll(timeout.timeleft())
|
||||||
|
except select_error:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
if ex.args[0] == errno.EINTR:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
ex = sys.exc_info()[1]
|
||||||
|
raise select_error(str(ex))
|
||||||
|
return rl
|
||||||
|
def connect_stream(stream, service=rpyc.VoidService, config={}):
|
||||||
|
return connect_channel(Channel(stream), service=service, config=config)
|
||||||
|
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
|
||||||
|
s = UTCPSocketStream.utcp_connect(host, port, **kw)
|
||||||
|
return connect_stream(s, service, config)
|
||||||
|
class UTCPThreadedServer(ThreadedServer):
|
||||||
|
def __init__(self, service, hostname = '', ipv6 = False, port = 0,
|
||||||
|
backlog = 1, reuse_addr = True, authenticator = None, registrar = None,
|
||||||
|
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
|
||||||
|
socket_path = None):
|
||||||
|
backlog = 1
|
||||||
|
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
|
||||||
|
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
|
||||||
|
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
|
||||||
|
socket_path=socket_path)
|
||||||
|
self.listener.close()
|
||||||
|
self.listener = None
|
||||||
|
##########
|
||||||
|
self.listener = utcp.TCP(encrypted=True)
|
||||||
|
self.listener.bind((hostname, port))
|
||||||
|
sockname = self.listener.getsockname()
|
||||||
|
self.host, self.port = sockname[0], sockname[1]
|
||||||
|
def _serve_client(self, sock, credentials):
|
||||||
|
addrinfo = sock.getpeername()
|
||||||
|
if credentials:
|
||||||
|
self.logger.info("welcome %s (%r)", addrinfo, credentials)
|
||||||
|
else:
|
||||||
|
self.logger.info("welcome %s", addrinfo)
|
||||||
|
try:
|
||||||
|
config = dict(self.protocol_config, credentials = credentials,
|
||||||
|
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
|
||||||
|
conn = self.service._connect(Channel(UTCPSocketStream(sock)), config)
|
||||||
|
self._handle_connection(conn)
|
||||||
|
finally:
|
||||||
|
self.logger.info("goodbye %s", addrinfo)
|
||||||
|
|
173
mods/utcp.py
173
mods/utcp.py
|
@ -6,8 +6,6 @@ import threading
|
||||||
import io
|
import io
|
||||||
import hashlib
|
import hashlib
|
||||||
import simplecrypto
|
import simplecrypto
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
DATA_DIVIDE_LENGTH = 8000
|
DATA_DIVIDE_LENGTH = 8000
|
||||||
PACKET_HEADER_SIZE = 512 # Pickle service info
|
PACKET_HEADER_SIZE = 512 # Pickle service info
|
||||||
|
@ -109,7 +107,7 @@ class TCPPacket(object):
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
def find_class(self, module, name):
|
def find_class(self, module, name):
|
||||||
# Only allow TCPPacket
|
# Only allow TCPPacket
|
||||||
if module == "utcp" and name == 'TCPPacket':
|
if name == 'TCPPacket':
|
||||||
return TCPPacket
|
return TCPPacket
|
||||||
# Forbid everything else.
|
# Forbid everything else.
|
||||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||||
|
@ -123,44 +121,67 @@ class ConnectedSOCK(object):
|
||||||
self.client_addr = client_addr
|
self.client_addr = client_addr
|
||||||
self.low_sock = low_sock
|
self.low_sock = low_sock
|
||||||
def __getattribute__(self, att):
|
def __getattribute__(self, att):
|
||||||
if not att.startswith('_') and not att in ['client_addr', 'low_sock', 'send', 'recv', 'close', 'closed']:
|
try:
|
||||||
if att in self.low_sock.__dict__:
|
|
||||||
return getattr(self.low_sock, att)
|
|
||||||
return object.__getattribute__(self, att)
|
return object.__getattribute__(self, att)
|
||||||
|
except AttributeError:
|
||||||
|
return getattr(self.low_sock, att)
|
||||||
|
def getpeername(self):
|
||||||
|
return self.client_addr
|
||||||
def send(self, data):
|
def send(self, data):
|
||||||
if self.closed:
|
if self.closed:
|
||||||
raise EOFError
|
raise EOFError
|
||||||
self.low_sock.send(data, self.client_addr)
|
return self.low_sock.send(data, self.client_addr)
|
||||||
def recv(self, size=None):
|
def recv(self, size):
|
||||||
if self.closed:
|
if self.closed:
|
||||||
raise EOFError
|
raise EOFError
|
||||||
if size:
|
return self.low_sock.recv(size, self.client_addr)
|
||||||
return self.low_sock.recv(self.client_addr)[:size]
|
|
||||||
else:
|
|
||||||
return self.low_sock.recv(self.client_addr)
|
|
||||||
@property
|
@property
|
||||||
def closed(self):
|
def closed(self):
|
||||||
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin)
|
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin)
|
||||||
def close(self):
|
def close(self):
|
||||||
|
if self.client_addr in self.low_sock.connections:
|
||||||
self.low_sock.close(self.client_addr)
|
self.low_sock.close(self.client_addr)
|
||||||
|
def shutdown(self, *a, **kw):
|
||||||
|
self.close()
|
||||||
|
def poll(self, timeout):
|
||||||
|
if self.client_addr in self.packets_received['DATA or FIN']:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.incoming_packet_event.wait(timeout)
|
||||||
|
return self.client_addr in self.packets_received['DATA or FIN']
|
||||||
|
return False
|
||||||
|
|
||||||
class TCP(object):
|
class TCP(object):
|
||||||
host = None
|
host = None
|
||||||
port = None
|
port = None
|
||||||
client = False
|
client = False
|
||||||
|
peer_keypair = {}
|
||||||
|
connections = {}
|
||||||
|
connection_queue = []
|
||||||
|
packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
||||||
def __init__(self, af_type=None, sock_type=None, encrypted=False):
|
def __init__(self, af_type=None, sock_type=None, encrypted=False):
|
||||||
self.encrypted = encrypted
|
self.encrypted = encrypted
|
||||||
self.incoming_packet_event = threading.Event()
|
self.incoming_packet_event = threading.Event()
|
||||||
|
self.new_conn_event = threading.Event()
|
||||||
#seq will have the last packet send and ack will have the next packet waiting to receive
|
#seq will have the last packet send and ack will have the next packet waiting to receive
|
||||||
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication.
|
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication.
|
||||||
self.settimeout()
|
self.settimeout()
|
||||||
self.peer_keypair = {}
|
#self.peer_keypair = {}
|
||||||
self.connections = {}
|
#self.connections = {}
|
||||||
self.connection_queue = []
|
#self.connection_queue = []
|
||||||
self.connection_lock = threading.Lock()
|
self.connection_lock = threading.Lock()
|
||||||
self.queue_lock = threading.Lock()
|
self.queue_lock = threading.Lock()
|
||||||
# each condition will have a dictionary of an address and it's corresponding packet.
|
# each condition will have a dictionary of an address and it's corresponding packet.
|
||||||
self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
#self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
||||||
|
def poll(self, timeout):
|
||||||
|
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.incoming_packet_event.wait(timeout)
|
||||||
|
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_free_port(self):
|
def get_free_port(self):
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
s.bind(('', 0))
|
s.bind(('', 0))
|
||||||
|
@ -171,6 +192,8 @@ class TCP(object):
|
||||||
pass
|
pass
|
||||||
def settimeout(self, timeout=5):
|
def settimeout(self, timeout=5):
|
||||||
self.own_socket.settimeout(timeout)
|
self.own_socket.settimeout(timeout)
|
||||||
|
def setblocking(self, mode):
|
||||||
|
self.own_socket.setblocking(mode)
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return 'TCP()'
|
return 'TCP()'
|
||||||
|
|
||||||
|
@ -178,7 +201,12 @@ class TCP(object):
|
||||||
return 'Connections: %s' \
|
return 'Connections: %s' \
|
||||||
% str(self.connections)
|
% str(self.connections)
|
||||||
def getsockname(self):
|
def getsockname(self):
|
||||||
return (self.host, self.port)
|
return self.own_socket.getsockname()
|
||||||
|
def getpeername(self):
|
||||||
|
if len(self.connections):
|
||||||
|
return list(self.connections.keys())[0]
|
||||||
|
else:
|
||||||
|
raise EOFError('Not connected')
|
||||||
def bind(self, addr):
|
def bind(self, addr):
|
||||||
self.host = addr[0]
|
self.host = addr[0]
|
||||||
self.port = addr[1]
|
self.port = addr[1]
|
||||||
|
@ -201,60 +229,48 @@ class TCP(object):
|
||||||
packet_to_send = pickle.dumps(self.connections[connection])
|
packet_to_send = pickle.dumps(self.connections[connection])
|
||||||
self.connections[connection].checksum = 0
|
self.connections[connection].checksum = 0
|
||||||
self.connections[connection].data = b''
|
self.connections[connection].data = b''
|
||||||
while data_not_received:
|
retransmit_count = 0
|
||||||
|
while data_not_received and retransmit_count < 3:
|
||||||
data_not_received = False
|
data_not_received = False
|
||||||
try:
|
try:
|
||||||
self.own_socket.sendto(packet_to_send, connection)
|
self.own_socket.sendto(packet_to_send, connection)
|
||||||
answer = self.find_correct_packet('ACK', connection)
|
answer = self.find_correct_packet('ACK', connection)
|
||||||
|
if not answer:
|
||||||
|
data_not_received = True
|
||||||
|
retransmit_count += 1
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
#print('timeout')
|
#print('timeout')
|
||||||
data_not_received = True
|
data_not_received = True
|
||||||
|
if not answer:
|
||||||
|
self.drop_connection(connection)
|
||||||
|
raise EOFError('Connection lost')
|
||||||
self.connections[connection].seq += len(data_part)
|
self.connections[connection].seq += len(data_part)
|
||||||
|
return len(data)
|
||||||
except socket.error as error:
|
except socket.error as error:
|
||||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
||||||
|
|
||||||
def recv(self, connection=None):
|
def recv(self, size, connection=None):
|
||||||
try:
|
try:
|
||||||
data = b''
|
|
||||||
if connection not in list(self.connections.keys()):
|
if connection not in list(self.connections.keys()):
|
||||||
if connection is None:
|
if connection is None:
|
||||||
connection = list(self.connections.keys())[0]
|
connection = list(self.connections.keys())[0]
|
||||||
else:
|
else:
|
||||||
return 'Connection not in connected devices'
|
raise EOFError('Connection not in connected devices')
|
||||||
|
|
||||||
while True and self.status:
|
data = self.find_correct_packet('DATA or FIN', connection, size)
|
||||||
data_part = self.find_correct_packet('DATA or FIN', connection)
|
|
||||||
if not self.status:
|
if not self.status:
|
||||||
# print('I am disconnectiong cause sock is dead')
|
raise EOFError('Disconnecting')
|
||||||
raise EOFError('Disconnected')
|
return data
|
||||||
if data_part.packet_type() == 'FIN':
|
except socket.error as error:
|
||||||
self.disconnect(connection)
|
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
||||||
raise EOFError('Disconnected')
|
def send_ack(self, connection, ack):
|
||||||
checksum_value = TCP.checksum(data_part.data)
|
self.connections[connection].ack = ack
|
||||||
|
self.connections[connection].seq += 1
|
||||||
while checksum_value != data_part.checksum:
|
|
||||||
data_part = self.find_correct_packet('DATA or FIN', connection)
|
|
||||||
checksum_value = TCP.checksum(data_part.data)
|
|
||||||
|
|
||||||
data_chunk = data_part.data if not self.encrypted else self.peer_keypair[connection].my_key.decrypt_raw(data_part.data)
|
|
||||||
if data_chunk != PACKET_END:
|
|
||||||
data += data_chunk
|
|
||||||
self.connections[connection].ack = data_part.seq + len(data_part.data)
|
|
||||||
self.connections[connection].seq += 1 # syn flag is 1 byte
|
|
||||||
self.connections[connection].set_flags(ack=True)
|
self.connections[connection].set_flags(ack=True)
|
||||||
self.connections[connection].data = b''
|
self.connections[connection].data = b''
|
||||||
packet_to_send = pickle.dumps(self.connections[connection])
|
packet_to_send = pickle.dumps(self.connections[connection])
|
||||||
self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack
|
self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack
|
||||||
self.connections[connection].set_flags()
|
self.connections[connection].set_flags()
|
||||||
|
|
||||||
if data_chunk == PACKET_END:
|
|
||||||
break
|
|
||||||
|
|
||||||
return data
|
|
||||||
except socket.error as error:
|
|
||||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
|
||||||
|
|
||||||
# conditions = ['SYN', 'SYN-ACK', 'ACK', 'FIN', 'FIN-ACK', 'DATA']
|
|
||||||
def listen_handler(self, max_connections):
|
def listen_handler(self, max_connections):
|
||||||
try:
|
try:
|
||||||
while True and self.status:
|
while True and self.status:
|
||||||
|
@ -265,12 +281,13 @@ class TCP(object):
|
||||||
if self.encrypted:
|
if self.encrypted:
|
||||||
try:
|
try:
|
||||||
peer_pub = answer.data
|
peer_pub = answer.data
|
||||||
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key
|
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key (~5 sec)
|
||||||
self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
|
self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
|
||||||
except:
|
except:
|
||||||
self.peer_keypair.pop(address)
|
self.peer_keypair.pop(address)
|
||||||
raise socket.error('Init peer public key error')
|
raise socket.error('Init peer public key error')
|
||||||
self.connection_queue.append((answer, address))
|
self.connection_queue.append((answer, address))
|
||||||
|
self.blink_new_conn_event()
|
||||||
else:
|
else:
|
||||||
self.own_socket.sendto('Connections full', address)
|
self.own_socket.sendto('Connections full', address)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -291,6 +308,7 @@ class TCP(object):
|
||||||
def accept(self):
|
def accept(self):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
self.new_conn_event.wait(0.1)
|
||||||
if self.connection_queue:
|
if self.connection_queue:
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
answer, address = self.connection_queue.pop()
|
answer, address = self.connection_queue.pop()
|
||||||
|
@ -367,14 +385,17 @@ class TCP(object):
|
||||||
self.peer_keypair = {}
|
self.peer_keypair = {}
|
||||||
self.status = 0
|
self.status = 0
|
||||||
raise EOFError('The socket was closed. Error:' + str(error))
|
raise EOFError('The socket was closed. Error:' + str(error))
|
||||||
def shutdown(self, *a, **kw):
|
|
||||||
self.own_socket.close()
|
|
||||||
self.status = 0
|
|
||||||
def fileno(self):
|
def fileno(self):
|
||||||
return self.own_socket.fileno()
|
return self.own_socket.fileno()
|
||||||
@property
|
@property
|
||||||
def closed(self):
|
def closed(self):
|
||||||
return self.own_socket._closed
|
return self.own_socket._closed
|
||||||
|
def drop_connection(self, connection):
|
||||||
|
with self.connection_lock:
|
||||||
|
if len(self.connections):
|
||||||
|
self.connections.pop(connection)
|
||||||
|
if len(self.peer_keypair):
|
||||||
|
self.peer_keypair.pop(connection)
|
||||||
def close(self, connection=None):
|
def close(self, connection=None):
|
||||||
try:
|
try:
|
||||||
if connection not in list(self.connections.keys()):
|
if connection not in list(self.connections.keys()):
|
||||||
|
@ -392,19 +413,11 @@ class TCP(object):
|
||||||
if answer.flag_fin != 1:
|
if answer.flag_fin != 1:
|
||||||
raise Exception('The receiver didn\'t send the fin packet')
|
raise Exception('The receiver didn\'t send the fin packet')
|
||||||
else:
|
else:
|
||||||
self.connections[connection].ack += 1
|
self.send_ack(connection, self.connections[connection].ack + 1)
|
||||||
self.connections[connection].seq += 1
|
self.drop_connection(connection)
|
||||||
self.connections[connection].set_flags(ack=True)
|
if len(self.connections) == 0 and self.client:
|
||||||
packet_to_send = pickle.dumps(self.connections[connection])
|
self.own_socket.close()
|
||||||
self.own_socket.sendto(packet_to_send, connection)
|
self.status = 0
|
||||||
with self.connection_lock:
|
|
||||||
if len(self.connections):
|
|
||||||
self.connections.pop(connection)
|
|
||||||
if len(self.peer_keypair):
|
|
||||||
self.peer_keypair.pop(connection)
|
|
||||||
#if len(self.connections) == 0 and self.client:
|
|
||||||
# self.own_socket.close()
|
|
||||||
# self.status = 0
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
|
raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
|
||||||
|
|
||||||
|
@ -436,7 +449,7 @@ class TCP(object):
|
||||||
def checksum(source_bytes):
|
def checksum(source_bytes):
|
||||||
return hashlib.sha1(source_bytes).digest()
|
return hashlib.sha1(source_bytes).digest()
|
||||||
|
|
||||||
def find_correct_packet(self, condition, address=('Any',)):
|
def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH):
|
||||||
not_found = True
|
not_found = True
|
||||||
tries = 0
|
tries = 0
|
||||||
while not_found and tries < 2 and self.status:
|
while not_found and tries < 2 and self.status:
|
||||||
|
@ -448,7 +461,9 @@ class TCP(object):
|
||||||
if condition == 'ACK':
|
if condition == 'ACK':
|
||||||
tries += 1
|
tries += 1
|
||||||
if condition == 'DATA or FIN':
|
if condition == 'DATA or FIN':
|
||||||
packet = self.packets_received[condition][address].pop()
|
with self.connection_lock:
|
||||||
|
packet = self.packets_received[condition][address][:size]
|
||||||
|
self.packets_received[condition][address] = self.packets_received[condition][address][size:]
|
||||||
if not len(self.packets_received[condition][address]):
|
if not len(self.packets_received[condition][address]):
|
||||||
del self.packets_received[condition][address]
|
del self.packets_received[condition][address]
|
||||||
else:
|
else:
|
||||||
|
@ -457,21 +472,33 @@ class TCP(object):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
not_found = True
|
not_found = True
|
||||||
self.incoming_packet_event.wait(0.1)
|
self.incoming_packet_event.wait(0.1)
|
||||||
def blink_event(self):
|
def blink_incoming_packet_event(self):
|
||||||
self.incoming_packet_event.set()
|
self.incoming_packet_event.set()
|
||||||
self.incoming_packet_event.clear()
|
self.incoming_packet_event.clear()
|
||||||
|
def blink_new_conn_event(self):
|
||||||
|
self.new_conn_event.set()
|
||||||
|
self.new_conn_event.clear()
|
||||||
def sort_answers(self, packet, address):
|
def sort_answers(self, packet, address):
|
||||||
if packet.packet_type() == 'DATA' or packet.packet_type() == 'FIN':
|
if address not in self.connections and packet.packet_type() != 'SYN':
|
||||||
|
return
|
||||||
|
if packet.packet_type() == 'FIN':
|
||||||
|
self.disconnect(address)
|
||||||
|
elif packet.packet_type() == 'DATA':
|
||||||
|
if packet.checksum == TCP.checksum(packet.data):
|
||||||
|
data_chunk = packet.data if not self.encrypted else self.peer_keypair[address].my_key.decrypt_raw(packet.data)
|
||||||
|
if data_chunk != PACKET_END:
|
||||||
|
with self.connection_lock:
|
||||||
if address not in self.packets_received['DATA or FIN']:
|
if address not in self.packets_received['DATA or FIN']:
|
||||||
self.packets_received['DATA or FIN'][address] = []
|
self.packets_received['DATA or FIN'][address] = b''
|
||||||
self.packets_received['DATA or FIN'][address].insert(0, packet)
|
self.packets_received['DATA or FIN'][address] += data_chunk
|
||||||
self.blink_event()
|
self.send_ack(address, packet.seq + len(packet.data))
|
||||||
|
self.blink_incoming_packet_event()
|
||||||
elif packet.packet_type() == '':
|
elif packet.packet_type() == '':
|
||||||
#print('redundant packet found', packet)
|
#print('redundant packet found', packet)
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.packets_received[packet.packet_type()][address] = packet
|
self.packets_received[packet.packet_type()][address] = packet
|
||||||
self.blink_event()
|
self.blink_incoming_packet_event()
|
||||||
|
|
||||||
def central_receive_handler(self):
|
def central_receive_handler(self):
|
||||||
while True and self.status:
|
while True and self.status:
|
||||||
|
|
Loading…
Reference in New Issue