encrypted "tcp" over udp

utcp_objects
Роман Бородин 2019-04-08 08:10:23 +03:00
parent 20e70a5ed3
commit f7a4da82bd
4 changed files with 315 additions and 97 deletions

0
mods/__init__.py 100644
View File

View File

@ -1,6 +1,12 @@
# Based on rpyc (4.0.2)
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
import rpyc
from rpyc.utils.server import ThreadedServer, spawn
from rpyc.core.stream import SocketStream
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
import socket
from mbedtls import tls
import datetime as dt
@ -10,6 +16,10 @@ from mbedtls import x509
from uuid import uuid4
from contextlib import suppress
import sys
from rpyc.lib import safe_import
zlib = safe_import("zlib")
import errno
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
def block(callback, *args, **kwargs):
while True:
@ -33,6 +43,7 @@ class DTLSCerts:
self.ca1_crt = self.ca0_crt.sign(
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
self.srv_crt, self.srv_key = self.server_cert()
def server_cert(self):
now = dt.datetime.utcnow()
ee0_key = pk.ECC()
@ -46,17 +57,28 @@ dtls_certs = DTLSCerts()
trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)
srv_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
validate_certificates=False,
)
cli_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
validate_certificates=False,
)
MAX_IO_CHUNK = 20971520
class DTLSSocketStream(SocketStream):
MAX_IO_CHUNK = MAX_IO_CHUNK
@classmethod
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
if kwargs.pop('ipv6', False):
family = socket.AF_INET6
else:
family = socket.AF_INET
dtls_cli_ctx = tls.ClientContext(tls.DTLSConfiguration(
trust_store=trust_store,
validate_certificates=False,
))
#tls._enable_debug_output(cli_ctx_conf)
#tls._set_debug_level(10)
dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
dtls_cli = dtls_cli_ctx.wrap_socket(
socket.socket(family, socket.SOCK_DGRAM),
server_hostname=None,
@ -66,33 +88,128 @@ class DTLSSocketStream(SocketStream):
block(dtls_cli.do_handshake)
return cls(dtls_cli)
def read(self, count):
return block(SocketStream.read(self, count))
while True:
try:
buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
except socket.timeout:
continue
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in retry_errnos:
# windows just has to be a bitch
# inpos: I agree
continue
self.close()
raise EOFError(ex)
else:
break
if not buf:
self.close()
raise EOFError("connection closed by peer")
return buf
def write(self, data):
block(SocketStream.write(self, data))
try:
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
except socket.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
class DTLSChannel(Channel):
MAX_IO_CHUNK = MAX_IO_CHUNK
def recv(self):
raw_data = self.stream.read(self.MAX_IO_CHUNK)
header = raw_data[:self.FRAME_HEADER.size]
raw_data = raw_data[self.FRAME_HEADER.size:]
length, compressed = self.FRAME_HEADER.unpack(header)
data = raw_data[:length]
if compressed:
data = zlib.decompress(data)
return data
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(DTLSChannel(stream), service=service, config=config)
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
cert_reqs=None, ssl_version=None, ciphers=None,
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
ssl_kwargs = {'server_side' : False}
if ciphers is not None:
ssl_kwargs['ciphers'] = ciphers
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
return connect_stream(s, service, config)
class DTLSThreadedServer(ThreadedServer):
def dtls(self, listener_timeout = 0.5, reuse_addr = True):
def __init__(self, service, hostname = "", ipv6 = False, port = 0,
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
socket_path=socket_path)
self.listener.close()
srv_crt, srv_key = dtls_certs.server_cert()
dtls_srv_ctx = tls.ServerContext(tls.DTLSConfiguration(
trust_store=trust_store,
certificate_chain=([srv_crt, dtls_certs.ca1_crt], srv_key),
validate_certificates=False,
))
self.host = hostname
self.port = port
#tls._enable_debug_output(srv_ctx_conf)
#tls._set_debug_level(10)
dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
if reuse_addr and sys.platform != 'win32':
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
dtls_srv.bind((self.host, self.port))
dtls_srv.settimeout(listener_timeout)
dtls_srv.bind((hostname, port))
# dtls_srv.settimeout(listener_timeout)
self.listener = dtls_srv
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
if credentials:
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)
def _listen(self):
if self.active:
return
#self.listener.listen(self.backlog) #####################
if not self.port:
self.port = self.listener.getsockname()[1]
self.logger.info('server started on [%s]:%s', self.host, self.port)
self.active = True
def _authenticate_and_serve_client(self, sock):
def accept(self):
while self.active:
try:
print('Accepting new connections!')
sock, addrinfo = self.listener.accept()
except socket.timeout:
pass
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
pass
else:
raise EOFError()
raise
else:
break
if not self.active:
return
sock.setblocking(True)
self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
print("accepted %s with fd %s" % (addrinfo, sock.fileno()))
self.clients.add(sock)
self._accept_method(sock)
def _accept_method(self, sock):
addr = sock.getpeername()
sock.setcookieparam(addr[0].encode())
with suppress(tls.HelloVerifyRequest):
@ -100,6 +217,6 @@ class DTLSThreadedServer(ThreadedServer):
sock, addr = sock.accept()
sock.setcookieparam(addr[0].encode())
block(sock.do_handshake)
ThreadedServer._authenticate_and_serve_client(self, sock)
spawn(self._authenticate_and_serve_client, sock)

74
mods/rpyc_utcp.py 100644
View File

@ -0,0 +1,74 @@
import rpyc
from rpyc.utils.server import ThreadedServer
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
from rpyc.lib import Timeout
from rpyc.lib.compat import select_error
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
from . import utcp
import sys
import errno
import socket
class UTCPSocketStream(SocketStream):
MAX_IO_CHUNK = utcp.DATA_LENGTH
@classmethod
def utcp_connect(cls, host, port, *a, **kw):
sock = utcp.TCP(encrypted=True)
sock.connect((host, port))
return cls(sock)
def poll(self, timeout):
timeout = Timeout(timeout)
try:
while True:
try:
rl = self.sock.poll(timeout.timeleft())
except select_error:
ex = sys.exc_info()[1]
if ex.args[0] == errno.EINTR:
continue
else:
raise
else:
break
except ValueError:
ex = sys.exc_info()[1]
raise select_error(str(ex))
return rl
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(Channel(stream), service=service, config=config)
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
s = UTCPSocketStream.utcp_connect(host, port, **kw)
return connect_stream(s, service, config)
class UTCPThreadedServer(ThreadedServer):
def __init__(self, service, hostname = '', ipv6 = False, port = 0,
backlog = 1, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
backlog = 1
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
socket_path=socket_path)
self.listener.close()
self.listener = None
##########
self.listener = utcp.TCP(encrypted=True)
self.listener.bind((hostname, port))
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
if credentials:
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
conn = self.service._connect(Channel(UTCPSocketStream(sock)), config)
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)

View File

@ -6,8 +6,6 @@ import threading
import io
import hashlib
import simplecrypto
from datetime import datetime
DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512 # Pickle service info
@ -109,7 +107,7 @@ class TCPPacket(object):
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow TCPPacket
if module == "utcp" and name == 'TCPPacket':
if name == 'TCPPacket':
return TCPPacket
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
@ -123,44 +121,67 @@ class ConnectedSOCK(object):
self.client_addr = client_addr
self.low_sock = low_sock
def __getattribute__(self, att):
if not att.startswith('_') and not att in ['client_addr', 'low_sock', 'send', 'recv', 'close', 'closed']:
if att in self.low_sock.__dict__:
return getattr(self.low_sock, att)
try:
return object.__getattribute__(self, att)
except AttributeError:
return getattr(self.low_sock, att)
def getpeername(self):
return self.client_addr
def send(self, data):
if self.closed:
raise EOFError
self.low_sock.send(data, self.client_addr)
def recv(self, size=None):
return self.low_sock.send(data, self.client_addr)
def recv(self, size):
if self.closed:
raise EOFError
if size:
return self.low_sock.recv(self.client_addr)[:size]
else:
return self.low_sock.recv(self.client_addr)
return self.low_sock.recv(size, self.client_addr)
@property
def closed(self):
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin)
def close(self):
if self.client_addr in self.low_sock.connections:
self.low_sock.close(self.client_addr)
def shutdown(self, *a, **kw):
self.close()
def poll(self, timeout):
if self.client_addr in self.packets_received['DATA or FIN']:
return True
else:
self.incoming_packet_event.wait(timeout)
return self.client_addr in self.packets_received['DATA or FIN']
return False
class TCP(object):
host = None
port = None
client = False
peer_keypair = {}
connections = {}
connection_queue = []
packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
def __init__(self, af_type=None, sock_type=None, encrypted=False):
self.encrypted = encrypted
self.incoming_packet_event = threading.Event()
self.new_conn_event = threading.Event()
#seq will have the last packet send and ack will have the next packet waiting to receive
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # UDP socket used for communication.
self.settimeout()
self.peer_keypair = {}
self.connections = {}
self.connection_queue = []
#self.peer_keypair = {}
#self.connections = {}
#self.connection_queue = []
self.connection_lock = threading.Lock()
self.queue_lock = threading.Lock()
# each condition will have a dictionary of an address and it's corresponding packet.
self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
#self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
def poll(self, timeout):
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
return True
else:
self.incoming_packet_event.wait(timeout)
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
return True
return False
def get_free_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.bind(('', 0))
@ -171,6 +192,8 @@ class TCP(object):
pass
def settimeout(self, timeout=5):
self.own_socket.settimeout(timeout)
def setblocking(self, mode):
self.own_socket.setblocking(mode)
def __repr__(self):
return 'TCP()'
@ -178,7 +201,12 @@ class TCP(object):
return 'Connections: %s' \
% str(self.connections)
def getsockname(self):
return (self.host, self.port)
return self.own_socket.getsockname()
def getpeername(self):
if len(self.connections):
return list(self.connections.keys())[0]
else:
raise EOFError('Not connected')
def bind(self, addr):
self.host = addr[0]
self.port = addr[1]
@ -201,60 +229,48 @@ class TCP(object):
packet_to_send = pickle.dumps(self.connections[connection])
self.connections[connection].checksum = 0
self.connections[connection].data = b''
while data_not_received:
retransmit_count = 0
while data_not_received and retransmit_count < 3:
data_not_received = False
try:
self.own_socket.sendto(packet_to_send, connection)
answer = self.find_correct_packet('ACK', connection)
if not answer:
data_not_received = True
retransmit_count += 1
except socket.timeout:
#print('timeout')
data_not_received = True
if not answer:
self.drop_connection(connection)
raise EOFError('Connection lost')
self.connections[connection].seq += len(data_part)
return len(data)
except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def recv(self, connection=None):
def recv(self, size, connection=None):
try:
data = b''
if connection not in list(self.connections.keys()):
if connection is None:
connection = list(self.connections.keys())[0]
else:
return 'Connection not in connected devices'
raise EOFError('Connection not in connected devices')
while True and self.status:
data_part = self.find_correct_packet('DATA or FIN', connection)
data = self.find_correct_packet('DATA or FIN', connection, size)
if not self.status:
# print('I am disconnectiong cause sock is dead')
raise EOFError('Disconnected')
if data_part.packet_type() == 'FIN':
self.disconnect(connection)
raise EOFError('Disconnected')
checksum_value = TCP.checksum(data_part.data)
while checksum_value != data_part.checksum:
data_part = self.find_correct_packet('DATA or FIN', connection)
checksum_value = TCP.checksum(data_part.data)
data_chunk = data_part.data if not self.encrypted else self.peer_keypair[connection].my_key.decrypt_raw(data_part.data)
if data_chunk != PACKET_END:
data += data_chunk
self.connections[connection].ack = data_part.seq + len(data_part.data)
self.connections[connection].seq += 1 # syn flag is 1 byte
raise EOFError('Disconnecting')
return data
except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def send_ack(self, connection, ack):
self.connections[connection].ack = ack
self.connections[connection].seq += 1
self.connections[connection].set_flags(ack=True)
self.connections[connection].data = b''
packet_to_send = pickle.dumps(self.connections[connection])
self.own_socket.sendto(packet_to_send, connection) # after receiving correct info sends ack
self.connections[connection].set_flags()
if data_chunk == PACKET_END:
break
return data
except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
# conditions = ['SYN', 'SYN-ACK', 'ACK', 'FIN', 'FIN-ACK', 'DATA']
def listen_handler(self, max_connections):
try:
while True and self.status:
@ -265,12 +281,13 @@ class TCP(object):
if self.encrypted:
try:
peer_pub = answer.data
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key
self.peer_keypair[address] = KeyPair(simplecrypto.RsaKeypair()) # FIXME: for some reason slowly creates a key (~5 sec)
self.peer_keypair[address].peer_pub = simplecrypto.RsaPublicKey(peer_pub)
except:
self.peer_keypair.pop(address)
raise socket.error('Init peer public key error')
self.connection_queue.append((answer, address))
self.blink_new_conn_event()
else:
self.own_socket.sendto('Connections full', address)
except KeyError:
@ -291,6 +308,7 @@ class TCP(object):
def accept(self):
try:
while True:
self.new_conn_event.wait(0.1)
if self.connection_queue:
with self.queue_lock:
answer, address = self.connection_queue.pop()
@ -367,14 +385,17 @@ class TCP(object):
self.peer_keypair = {}
self.status = 0
raise EOFError('The socket was closed. Error:' + str(error))
def shutdown(self, *a, **kw):
self.own_socket.close()
self.status = 0
def fileno(self):
return self.own_socket.fileno()
@property
def closed(self):
return self.own_socket._closed
def drop_connection(self, connection):
with self.connection_lock:
if len(self.connections):
self.connections.pop(connection)
if len(self.peer_keypair):
self.peer_keypair.pop(connection)
def close(self, connection=None):
try:
if connection not in list(self.connections.keys()):
@ -392,19 +413,11 @@ class TCP(object):
if answer.flag_fin != 1:
raise Exception('The receiver didn\'t send the fin packet')
else:
self.connections[connection].ack += 1
self.connections[connection].seq += 1
self.connections[connection].set_flags(ack=True)
packet_to_send = pickle.dumps(self.connections[connection])
self.own_socket.sendto(packet_to_send, connection)
with self.connection_lock:
if len(self.connections):
self.connections.pop(connection)
if len(self.peer_keypair):
self.peer_keypair.pop(connection)
#if len(self.connections) == 0 and self.client:
# self.own_socket.close()
# self.status = 0
self.send_ack(connection, self.connections[connection].ack + 1)
self.drop_connection(connection)
if len(self.connections) == 0 and self.client:
self.own_socket.close()
self.status = 0
except Exception as error:
raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
@ -436,7 +449,7 @@ class TCP(object):
def checksum(source_bytes):
return hashlib.sha1(source_bytes).digest()
def find_correct_packet(self, condition, address=('Any',)):
def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH):
not_found = True
tries = 0
while not_found and tries < 2 and self.status:
@ -448,7 +461,9 @@ class TCP(object):
if condition == 'ACK':
tries += 1
if condition == 'DATA or FIN':
packet = self.packets_received[condition][address].pop()
with self.connection_lock:
packet = self.packets_received[condition][address][:size]
self.packets_received[condition][address] = self.packets_received[condition][address][size:]
if not len(self.packets_received[condition][address]):
del self.packets_received[condition][address]
else:
@ -457,21 +472,33 @@ class TCP(object):
except KeyError:
not_found = True
self.incoming_packet_event.wait(0.1)
def blink_event(self):
def blink_incoming_packet_event(self):
self.incoming_packet_event.set()
self.incoming_packet_event.clear()
def blink_new_conn_event(self):
self.new_conn_event.set()
self.new_conn_event.clear()
def sort_answers(self, packet, address):
if packet.packet_type() == 'DATA' or packet.packet_type() == 'FIN':
if address not in self.connections and packet.packet_type() != 'SYN':
return
if packet.packet_type() == 'FIN':
self.disconnect(address)
elif packet.packet_type() == 'DATA':
if packet.checksum == TCP.checksum(packet.data):
data_chunk = packet.data if not self.encrypted else self.peer_keypair[address].my_key.decrypt_raw(packet.data)
if data_chunk != PACKET_END:
with self.connection_lock:
if address not in self.packets_received['DATA or FIN']:
self.packets_received['DATA or FIN'][address] = []
self.packets_received['DATA or FIN'][address].insert(0, packet)
self.blink_event()
self.packets_received['DATA or FIN'][address] = b''
self.packets_received['DATA or FIN'][address] += data_chunk
self.send_ack(address, packet.seq + len(packet.data))
self.blink_incoming_packet_event()
elif packet.packet_type() == '':
#print('redundant packet found', packet)
pass
else:
self.packets_received[packet.packet_type()][address] = packet
self.blink_event()
self.blink_incoming_packet_event()
def central_receive_handler(self):
while True and self.status: