Compare commits
9 Commits
utcp_objec
...
master
Author | SHA1 | Date |
---|---|---|
Роман Бородин | 9cdc27c1f3 | |
Роман Бородин | ed57cbcb57 | |
Роман Бородин | 328692bd20 | |
Роман Бородин | 4fff6e42fb | |
Роман Бородин | 8ef5153e2e | |
Роман Бородин | 11b7ec1ee2 | |
Роман Бородин | 9932525bd2 | |
Роман Бородин | 4c65194dbc | |
Роман Бородин | 544a98edd3 |
|
@ -0,0 +1,5 @@
|
||||||
|
import os
|
||||||
|
app_dir = os.path.join(os.environ['HOME'], '.black-mamba')
|
||||||
|
signalpeers_db = os.path.join(app_dir, 'signalpeers.sqlite')
|
||||||
|
downloads_db = os.path.join(app_dir, 'downloads.sqlite')
|
||||||
|
sharedfiles_db = os.path.join(app_dir, 'sharedfiles.sqlite')
|
|
@ -0,0 +1,15 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def store_signalpeer(conn, signalpeerkvstore):
|
||||||
|
peer_addr = (conn._channel.stream.sock.getpeername()[0], conn.root.peer_port)
|
||||||
|
signalpeer_key = f'{peer_addr[0]:{peer_addr[1]}}'
|
||||||
|
if signalpeer_key not in signalpeerkvstore:
|
||||||
|
signalpeerkvstore[signalpeer_key] = {'addr': peer_addr, 'seen': datetime.utcnow(), 'cname': conn.root.peer_cname}
|
||||||
|
else:
|
||||||
|
signalpeerkvstore[signalpeer_key]['seen'] = datetime.utcnow()
|
||||||
|
if signalpeerkvstore[signalpeer_key]['cname'] != conn.root.peer_cname:
|
||||||
|
signalpeerkvstore[signalpeer_key]['cname'] = conn.root.peer_cname
|
||||||
|
def merge_signalpeers(signalpeers, signalpeerkvstore):
|
||||||
|
new_keys = list(set(signalpeers.keys()) - set(signalpeerkvstore.keys()))
|
||||||
|
for k in new_keys:
|
||||||
|
signalpeerkvstore[k] = signalpeers[k]
|
|
@ -0,0 +1 @@
|
||||||
|
import rpyc
|
|
@ -1,222 +0,0 @@
|
||||||
# Based on rpyc (4.0.2)
|
|
||||||
|
|
||||||
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
|
|
||||||
|
|
||||||
import rpyc
|
|
||||||
from rpyc.utils.server import ThreadedServer, spawn
|
|
||||||
from rpyc.core import SocketStream, Channel
|
|
||||||
from rpyc.core.stream import retry_errnos
|
|
||||||
from rpyc.utils.factory import connect_channel
|
|
||||||
import socket
|
|
||||||
from mbedtls import tls
|
|
||||||
import datetime as dt
|
|
||||||
from mbedtls import hash as hashlib
|
|
||||||
from mbedtls import pk
|
|
||||||
from mbedtls import x509
|
|
||||||
from uuid import uuid4
|
|
||||||
from contextlib import suppress
|
|
||||||
import sys
|
|
||||||
from rpyc.lib import safe_import
|
|
||||||
zlib = safe_import("zlib")
|
|
||||||
import errno
|
|
||||||
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
|
||||||
|
|
||||||
def block(callback, *args, **kwargs):
|
|
||||||
while True:
|
|
||||||
with suppress(tls.WantReadError, tls.WantWriteError):
|
|
||||||
return callback(*args, **kwargs)
|
|
||||||
|
|
||||||
class DTLSCerts:
|
|
||||||
def __init__(self):
|
|
||||||
now = dt.datetime.utcnow()
|
|
||||||
self.ca0_key = pk.RSA()
|
|
||||||
_ = self.ca0_key.generate()
|
|
||||||
ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
|
|
||||||
self.ca0_crt = x509.CRT.selfsign(
|
|
||||||
ca0_csr, self.ca0_key,
|
|
||||||
not_before=now, not_after=now + dt.timedelta(days=3650),
|
|
||||||
serial_number=0x123456,
|
|
||||||
basic_constraints=x509.BasicConstraints(True, 1))
|
|
||||||
self.ca1_key = pk.ECC()
|
|
||||||
_ = self.ca1_key.generate()
|
|
||||||
ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
|
|
||||||
self.ca1_crt = self.ca0_crt.sign(
|
|
||||||
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
|
|
||||||
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
|
|
||||||
self.srv_crt, self.srv_key = self.server_cert()
|
|
||||||
def server_cert(self):
|
|
||||||
now = dt.datetime.utcnow()
|
|
||||||
ee0_key = pk.ECC()
|
|
||||||
_ = ee0_key.generate()
|
|
||||||
ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
|
|
||||||
ee0_crt = self.ca1_crt.sign(
|
|
||||||
ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), 0x987654)
|
|
||||||
return ee0_crt, ee0_key
|
|
||||||
|
|
||||||
dtls_certs = DTLSCerts()
|
|
||||||
trust_store = tls.TrustStore()
|
|
||||||
trust_store.add(dtls_certs.ca0_crt)
|
|
||||||
|
|
||||||
srv_ctx_conf = tls.DTLSConfiguration(
|
|
||||||
trust_store=trust_store,
|
|
||||||
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
|
|
||||||
validate_certificates=False,
|
|
||||||
)
|
|
||||||
cli_ctx_conf = tls.DTLSConfiguration(
|
|
||||||
trust_store=trust_store,
|
|
||||||
validate_certificates=False,
|
|
||||||
)
|
|
||||||
MAX_IO_CHUNK = 20971520
|
|
||||||
|
|
||||||
class DTLSSocketStream(SocketStream):
|
|
||||||
MAX_IO_CHUNK = MAX_IO_CHUNK
|
|
||||||
@classmethod
|
|
||||||
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
|
|
||||||
if kwargs.pop('ipv6', False):
|
|
||||||
family = socket.AF_INET6
|
|
||||||
else:
|
|
||||||
family = socket.AF_INET
|
|
||||||
#tls._enable_debug_output(cli_ctx_conf)
|
|
||||||
#tls._set_debug_level(10)
|
|
||||||
dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
|
|
||||||
dtls_cli = dtls_cli_ctx.wrap_socket(
|
|
||||||
socket.socket(family, socket.SOCK_DGRAM),
|
|
||||||
server_hostname=None,
|
|
||||||
)
|
|
||||||
dtls_cli.settimeout(timeout)
|
|
||||||
dtls_cli.connect((host, port))
|
|
||||||
block(dtls_cli.do_handshake)
|
|
||||||
return cls(dtls_cli)
|
|
||||||
def read(self, count):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
buf = block(self.sock.recv, min(self.MAX_IO_CHUNK, count))
|
|
||||||
except socket.timeout:
|
|
||||||
continue
|
|
||||||
except socket.error:
|
|
||||||
ex = sys.exc_info()[1]
|
|
||||||
if get_exc_errno(ex) in retry_errnos:
|
|
||||||
# windows just has to be a bitch
|
|
||||||
# inpos: I agree
|
|
||||||
continue
|
|
||||||
self.close()
|
|
||||||
raise EOFError(ex)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
if not buf:
|
|
||||||
self.close()
|
|
||||||
raise EOFError("connection closed by peer")
|
|
||||||
return buf
|
|
||||||
def write(self, data):
|
|
||||||
try:
|
|
||||||
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
|
|
||||||
except socket.error:
|
|
||||||
ex = sys.exc_info()[1]
|
|
||||||
self.close()
|
|
||||||
raise EOFError(ex)
|
|
||||||
|
|
||||||
class DTLSChannel(Channel):
|
|
||||||
MAX_IO_CHUNK = MAX_IO_CHUNK
|
|
||||||
def recv(self):
|
|
||||||
raw_data = self.stream.read(self.MAX_IO_CHUNK)
|
|
||||||
header = raw_data[:self.FRAME_HEADER.size]
|
|
||||||
raw_data = raw_data[self.FRAME_HEADER.size:]
|
|
||||||
length, compressed = self.FRAME_HEADER.unpack(header)
|
|
||||||
data = raw_data[:length]
|
|
||||||
if compressed:
|
|
||||||
data = zlib.decompress(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def connect_stream(stream, service=rpyc.VoidService, config={}):
|
|
||||||
return connect_channel(DTLSChannel(stream), service=service, config=config)
|
|
||||||
|
|
||||||
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
|
|
||||||
cert_reqs=None, ssl_version=None, ciphers=None,
|
|
||||||
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
|
|
||||||
ssl_kwargs = {'server_side' : False}
|
|
||||||
if ciphers is not None:
|
|
||||||
ssl_kwargs['ciphers'] = ciphers
|
|
||||||
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
|
|
||||||
return connect_stream(s, service, config)
|
|
||||||
|
|
||||||
class DTLSThreadedServer(ThreadedServer):
|
|
||||||
def __init__(self, service, hostname = "", ipv6 = False, port = 0,
|
|
||||||
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
|
|
||||||
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
|
|
||||||
socket_path = None):
|
|
||||||
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
|
|
||||||
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
|
|
||||||
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
|
|
||||||
socket_path=socket_path)
|
|
||||||
self.listener.close()
|
|
||||||
|
|
||||||
self.host = hostname
|
|
||||||
self.port = port
|
|
||||||
#tls._enable_debug_output(srv_ctx_conf)
|
|
||||||
#tls._set_debug_level(10)
|
|
||||||
dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
|
|
||||||
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
|
|
||||||
if reuse_addr and sys.platform != 'win32':
|
|
||||||
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
dtls_srv.bind((hostname, port))
|
|
||||||
# dtls_srv.settimeout(listener_timeout)
|
|
||||||
self.listener = dtls_srv
|
|
||||||
sockname = self.listener.getsockname()
|
|
||||||
self.host, self.port = sockname[0], sockname[1]
|
|
||||||
def _serve_client(self, sock, credentials):
|
|
||||||
addrinfo = sock.getpeername()
|
|
||||||
if credentials:
|
|
||||||
self.logger.info("welcome %s (%r)", addrinfo, credentials)
|
|
||||||
else:
|
|
||||||
self.logger.info("welcome %s", addrinfo)
|
|
||||||
try:
|
|
||||||
config = dict(self.protocol_config, credentials = credentials,
|
|
||||||
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
|
|
||||||
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
|
|
||||||
self._handle_connection(conn)
|
|
||||||
finally:
|
|
||||||
self.logger.info("goodbye %s", addrinfo)
|
|
||||||
def _listen(self):
|
|
||||||
if self.active:
|
|
||||||
return
|
|
||||||
#self.listener.listen(self.backlog) #####################
|
|
||||||
if not self.port:
|
|
||||||
self.port = self.listener.getsockname()[1]
|
|
||||||
self.logger.info('server started on [%s]:%s', self.host, self.port)
|
|
||||||
self.active = True
|
|
||||||
def accept(self):
|
|
||||||
while self.active:
|
|
||||||
try:
|
|
||||||
print('Accepting new connections!')
|
|
||||||
sock, addrinfo = self.listener.accept()
|
|
||||||
except socket.timeout:
|
|
||||||
pass
|
|
||||||
except socket.error:
|
|
||||||
ex = sys.exc_info()[1]
|
|
||||||
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise EOFError()
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not self.active:
|
|
||||||
return
|
|
||||||
|
|
||||||
sock.setblocking(True)
|
|
||||||
self.logger.info("accepted %s with fd %s", addrinfo, sock.fileno())
|
|
||||||
print("accepted %s with fd %s" % (addrinfo, sock.fileno()))
|
|
||||||
self.clients.add(sock)
|
|
||||||
self._accept_method(sock)
|
|
||||||
def _accept_method(self, sock):
|
|
||||||
addr = sock.getpeername()
|
|
||||||
sock.setcookieparam(addr[0].encode())
|
|
||||||
with suppress(tls.HelloVerifyRequest):
|
|
||||||
block(sock.do_handshake)
|
|
||||||
sock, addr = sock.accept()
|
|
||||||
sock.setcookieparam(addr[0].encode())
|
|
||||||
block(sock.do_handshake)
|
|
||||||
spawn(self._authenticate_and_serve_client, sock)
|
|
||||||
|
|
||||||
|
|
|
@ -1,74 +0,0 @@
|
||||||
import rpyc
|
|
||||||
from rpyc.utils.server import ThreadedServer
|
|
||||||
from rpyc.core import SocketStream, Channel
|
|
||||||
from rpyc.core.stream import retry_errnos
|
|
||||||
from rpyc.utils.factory import connect_channel
|
|
||||||
from rpyc.lib import Timeout
|
|
||||||
from rpyc.lib.compat import select_error
|
|
||||||
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
|
|
||||||
from . import utcp
|
|
||||||
import sys
|
|
||||||
import errno
|
|
||||||
import socket
|
|
||||||
|
|
||||||
class UTCPSocketStream(SocketStream):
|
|
||||||
MAX_IO_CHUNK = utcp.DATA_LENGTH
|
|
||||||
@classmethod
|
|
||||||
def utcp_connect(cls, host, port, *a, **kw):
|
|
||||||
sock = utcp.TCP(encrypted=True)
|
|
||||||
sock.connect((host, port))
|
|
||||||
return cls(sock)
|
|
||||||
def poll(self, timeout):
|
|
||||||
timeout = Timeout(timeout)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
rl = self.sock.poll(timeout.timeleft())
|
|
||||||
except select_error:
|
|
||||||
ex = sys.exc_info()[1]
|
|
||||||
if ex.args[0] == errno.EINTR:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
except ValueError:
|
|
||||||
ex = sys.exc_info()[1]
|
|
||||||
raise select_error(str(ex))
|
|
||||||
return rl
|
|
||||||
def connect_stream(stream, service=rpyc.VoidService, config={}):
|
|
||||||
return connect_channel(Channel(stream), service=service, config=config)
|
|
||||||
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
|
|
||||||
s = UTCPSocketStream.utcp_connect(host, port, **kw)
|
|
||||||
return connect_stream(s, service, config)
|
|
||||||
class UTCPThreadedServer(ThreadedServer):
|
|
||||||
def __init__(self, service, hostname = '', ipv6 = False, port = 0,
|
|
||||||
backlog = 1, reuse_addr = True, authenticator = None, registrar = None,
|
|
||||||
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
|
|
||||||
socket_path = None):
|
|
||||||
backlog = 1
|
|
||||||
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
|
|
||||||
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
|
|
||||||
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
|
|
||||||
socket_path=socket_path)
|
|
||||||
self.listener.close()
|
|
||||||
self.listener = None
|
|
||||||
##########
|
|
||||||
self.listener = utcp.TCP(encrypted=True)
|
|
||||||
self.listener.bind((hostname, port))
|
|
||||||
sockname = self.listener.getsockname()
|
|
||||||
self.host, self.port = sockname[0], sockname[1]
|
|
||||||
def _serve_client(self, sock, credentials):
|
|
||||||
addrinfo = sock.getpeername()
|
|
||||||
if credentials:
|
|
||||||
self.logger.info("welcome %s (%r)", addrinfo, credentials)
|
|
||||||
else:
|
|
||||||
self.logger.info("welcome %s", addrinfo)
|
|
||||||
try:
|
|
||||||
config = dict(self.protocol_config, credentials = credentials,
|
|
||||||
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
|
|
||||||
conn = self.service._connect(Channel(UTCPSocketStream(sock)), config)
|
|
||||||
self._handle_connection(conn)
|
|
||||||
finally:
|
|
||||||
self.logger.info("goodbye %s", addrinfo)
|
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
class Settings:
|
||||||
|
def __init__(self):
|
||||||
|
self.listen_port = 8765
|
||||||
|
self.listen_host = ''
|
||||||
|
self.cname = 'Very cool peer'
|
||||||
|
|
||||||
|
settings = Settings()
|
|
@ -0,0 +1,41 @@
|
||||||
|
import rpyc
|
||||||
|
from . import const
|
||||||
|
from . import typedefs
|
||||||
|
from .settings import settings
|
||||||
|
from uuid import uuid4
|
||||||
|
from sqlitedict import SqliteDict
|
||||||
|
from . import helpers
|
||||||
|
|
||||||
|
peer = {}
|
||||||
|
my_id = uuid4().hex
|
||||||
|
|
||||||
|
class SignalpeerService(rpyc.Service):
|
||||||
|
signalpeerkvstore = None
|
||||||
|
peer_id = None
|
||||||
|
exposed_peer_type = typedefs.SIGNALPEER
|
||||||
|
def on_connect(self, conn):
|
||||||
|
self.peer_id = uuid4().hex
|
||||||
|
if conn.root.pid == my_id:
|
||||||
|
conn.close()
|
||||||
|
return
|
||||||
|
self.signalpeerkvstore = SqliteDict(const.signalpeers_db, autocommit=True)
|
||||||
|
peer[self.peer_id] = conn
|
||||||
|
peer_type = conn.root.peer_type
|
||||||
|
if peer_type == typedefs.SIGNALPEER:
|
||||||
|
helpers.store_signalpeer(conn, self.signalpeerkvstore)
|
||||||
|
helpers.merge_signalpeers(conn.root.signal_peers, self.signalpeerkvstore)
|
||||||
|
def on_disconnect(self, conn):
|
||||||
|
if self.peer_id in peer:
|
||||||
|
peer.pop(self.peer_id)
|
||||||
|
@property
|
||||||
|
def exposed_peer_port(self):
|
||||||
|
return settings.listen_port
|
||||||
|
@property
|
||||||
|
def exposed_peer_cname(self):
|
||||||
|
return settings.cname
|
||||||
|
@property
|
||||||
|
def exposed_pid(self):
|
||||||
|
return my_id
|
||||||
|
@property
|
||||||
|
def exposed_signal_peers(self):
|
||||||
|
return dict(self.signalpeerkvstore.items())
|
356
mods/stream.py
356
mods/stream.py
|
@ -1,356 +0,0 @@
|
||||||
import socketserver
|
|
||||||
import pickle
|
|
||||||
import simplecrypto
|
|
||||||
from uuid import uuid4
|
|
||||||
import datetime, io
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from struct import Struct
|
|
||||||
try:
|
|
||||||
import zlib
|
|
||||||
except:
|
|
||||||
zlib = None
|
|
||||||
|
|
||||||
peers = {}
|
|
||||||
|
|
||||||
PORT = 16386
|
|
||||||
|
|
||||||
RETRANSMIT_RETRIES = 3
|
|
||||||
DATAGRAM_MAX_SIZE = 9000
|
|
||||||
RAW_DATA_MAX_SIZE = 8000
|
|
||||||
PACKET_NUM_SEQ_TTL = 300
|
|
||||||
|
|
||||||
SOCK_SEND_TIMEOUT = 60
|
|
||||||
|
|
||||||
PACKET_TYPE_HELLO = 0x00
|
|
||||||
PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01
|
|
||||||
PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02
|
|
||||||
PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03
|
|
||||||
PACKET_TYPE_PACKET = 0xa0
|
|
||||||
PACKET_TYPE_CONFIRM_RECV = 0xa1
|
|
||||||
PACKET_TYPE_GOODBUY = 0xff
|
|
||||||
|
|
||||||
class InvalidPacket(Exception): pass
|
|
||||||
class OldPacket(Exception): pass
|
|
||||||
|
|
||||||
def pickle_data(data):
|
|
||||||
return pickle.dumps(data, protocol=4)
|
|
||||||
################
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
# Only allow datetime
|
|
||||||
if module == "datetime" and name == 'datetime':
|
|
||||||
return getattr(datetime, name)
|
|
||||||
# Forbid everything else.
|
|
||||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
|
||||||
(module, name))
|
|
||||||
def restricted_pickle_loads(s):
|
|
||||||
"""Helper function analogous to pickle.loads()."""
|
|
||||||
return RestrictedUnpickler(io.BytesIO(s)).load()
|
|
||||||
################
|
|
||||||
|
|
||||||
|
|
||||||
# From rpyc.lib
|
|
||||||
class Timeout:
|
|
||||||
def __init__(self, timeout):
|
|
||||||
if isinstance(timeout, Timeout):
|
|
||||||
self.finite = timeout.finite
|
|
||||||
self.tmax = timeout.tmax
|
|
||||||
else:
|
|
||||||
self.finite = timeout is not None and timeout >= 0
|
|
||||||
self.tmax = time.time()+timeout if self.finite else None
|
|
||||||
def expired(self):
|
|
||||||
return self.finite and time.time() >= self.tmax
|
|
||||||
def timeleft(self):
|
|
||||||
return max((0, self.tmax - time.time())) if self.finite else None
|
|
||||||
def sleep(self, interval):
|
|
||||||
time.sleep(min(interval, self.timeleft()) if self.finite else interval)
|
|
||||||
|
|
||||||
class Packet:
|
|
||||||
def __init__(self, packet_payload):
|
|
||||||
try:
|
|
||||||
d = restricted_pickle_loads(packet_payload)
|
|
||||||
self.sid = d['sid']
|
|
||||||
self.type = d['type']
|
|
||||||
self.reset_timestamp = d['reset_timestamp']
|
|
||||||
self.num = d['num']
|
|
||||||
self.data = d['data']
|
|
||||||
except:
|
|
||||||
raise InvalidPacket
|
|
||||||
|
|
||||||
class Peer:
|
|
||||||
def __init__(self, sock, endpoint):
|
|
||||||
self.sid = None
|
|
||||||
self.sock = sock
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.my_key = None
|
|
||||||
self.peer_pub_key = None
|
|
||||||
self.buf = []
|
|
||||||
self.confirm_wait_packet = None
|
|
||||||
self.last_packet = None
|
|
||||||
self.request_lock = threading.Lock()
|
|
||||||
self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL)
|
|
||||||
self.last_sent_packet_num = -1
|
|
||||||
self.last_sent_packet_num_reset_time = datetime.datetime.utcnow()
|
|
||||||
self.last_received_packet_num = -1
|
|
||||||
self.last_received_packet_num_reset_time = None
|
|
||||||
self.retransmit_count = 0
|
|
||||||
def next_packet_num(self):
|
|
||||||
new_time = datetime.datetime.utcnow()
|
|
||||||
if (new_time - self.last_sent_packet_num_reset_time) >= self.num_seq_ttl:
|
|
||||||
self.last_sent_packet_num = -1
|
|
||||||
self.last_sent_packet_num += 1
|
|
||||||
return self.last_sent_packet_num
|
|
||||||
def poll(self):
|
|
||||||
return bool(len(self.buf))
|
|
||||||
def get_next_block(self):
|
|
||||||
if not len(self.buf):
|
|
||||||
return None
|
|
||||||
return self.buf.pop()
|
|
||||||
def put_block(self, data):
|
|
||||||
self.buf.insert(0, data)
|
|
||||||
def send(self, d, encrypted=False, confirm=False):
|
|
||||||
if 'sid' not in d: d['sid'] = self.sid
|
|
||||||
if 'num' not in d: d['num'] = None
|
|
||||||
if 'reset_timestamp' not in d: d['reset_timestamp'] = None
|
|
||||||
if 'data' not in d: d['data'] = b''
|
|
||||||
if encrypted: d['data'] = self.peer_pub_key.encrypt_raw(d['data'])
|
|
||||||
data = pickle_data(d)
|
|
||||||
if confirm:
|
|
||||||
self.last_packet = data
|
|
||||||
self.confirm_wait_packet = (d['reset_timestamp'], d['num'])
|
|
||||||
self.sock.sendto(data, self.endpoint)
|
|
||||||
def mark_packet(self, d):
|
|
||||||
d['num'] = self.next_packet_num()
|
|
||||||
d['reset_timestamp'] = self.last_sent_packet_num_reset_time
|
|
||||||
return d
|
|
||||||
def retransmit(self):
|
|
||||||
self.retransmit_count += 1
|
|
||||||
if self.retransmit_count > RETRANSMIT_RETRIES:
|
|
||||||
raise EOFError('retransmit limit reached')
|
|
||||||
self.sock.sendto(self.last_packet, self.endpoint)
|
|
||||||
def reply_my_pub_key(self, packet):
|
|
||||||
try:
|
|
||||||
self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data)
|
|
||||||
except:
|
|
||||||
raise EOFError('invalid pubkey data')
|
|
||||||
self.my_key = simplecrypto.RsaKeypair()
|
|
||||||
d = {
|
|
||||||
'type': PACKET_TYPE_PEER_PUB_KEY_REPLY,
|
|
||||||
'data': self.my_key.publickey.serialize()
|
|
||||||
}
|
|
||||||
self.send(d, encrypted=True)
|
|
||||||
def request_peer_bub_key(self, packet):
|
|
||||||
self.sid = packet.sid
|
|
||||||
self.my_key = simplecrypto.RsaKeypair()
|
|
||||||
d = {
|
|
||||||
'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST,
|
|
||||||
'data': self.my_key.publickey.serialize()
|
|
||||||
}
|
|
||||||
self.send(d)
|
|
||||||
def confirm_packet_recv(self, packet):
|
|
||||||
self.confirm_wait_packet = None
|
|
||||||
self.last_packet = None
|
|
||||||
d = {
|
|
||||||
'type': PACKET_TYPE_CONFIRM_RECV,
|
|
||||||
'num': packet.num,
|
|
||||||
'reset_timestamp': packet.reset_timestamp
|
|
||||||
}
|
|
||||||
self.send(d)
|
|
||||||
def check_received_packet(self, packet):
|
|
||||||
if self.last_received_packet_num_reset_time:
|
|
||||||
if self.last_received_packet_num_reset_time > packet.reset_timestamp:
|
|
||||||
raise OldPacket('packet from past')
|
|
||||||
elif self.last_received_packet_num_reset_time < packet.reset_timestamp:
|
|
||||||
self.last_received_packet_num_reset_time = packet.reset_timestamp
|
|
||||||
if (self.last_received_packet_num + 1) != packet.num:
|
|
||||||
raise EOFError('packet sequence corrupt')
|
|
||||||
else:
|
|
||||||
self.last_received_packet_num_reset_time = packet.reset_timestamp
|
|
||||||
self.last_received_packet_num = packet.num
|
|
||||||
def send_recv_confirmation(self, packet):
|
|
||||||
pass
|
|
||||||
def hello(self):
|
|
||||||
self.sid = uuid4().hex
|
|
||||||
d = {
|
|
||||||
'type': PACKET_TYPE_HELLO,
|
|
||||||
}
|
|
||||||
self.sock.sendto(pickle_data(d))
|
|
||||||
def recv_packet(self, packet_payload):
|
|
||||||
with self.request_lock:
|
|
||||||
try:
|
|
||||||
packet = Packet(packet_payload)
|
|
||||||
if packet.type == PACKET_TYPE_GOODBUY:
|
|
||||||
raise EOFError('connection closed')
|
|
||||||
except:
|
|
||||||
raise EOFError('invalid packet')
|
|
||||||
if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid):
|
|
||||||
self.hello()
|
|
||||||
return
|
|
||||||
############################################
|
|
||||||
if not self.peer_pub_key:
|
|
||||||
if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY:
|
|
||||||
try:
|
|
||||||
self.peer_pub_key = simplecrypto.RsaPublicKey(self.my_key.decrypt_raw(packet.data))
|
|
||||||
return
|
|
||||||
except:
|
|
||||||
raise EOFError('create pubkey failed')
|
|
||||||
elif packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST:
|
|
||||||
self.reply_my_pub_key(packet)
|
|
||||||
return
|
|
||||||
elif packet.type == PACKET_TYPE_HELLO:
|
|
||||||
self.request_peer_bub_key(packet)
|
|
||||||
return
|
|
||||||
############################################
|
|
||||||
if self.confirm_wait_packet:
|
|
||||||
if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet and packet.type == PACKET_TYPE_CONFIRM_RECV:
|
|
||||||
self.confirm_packet_recv(packet)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.retransmit()
|
|
||||||
return
|
|
||||||
############################################
|
|
||||||
else:
|
|
||||||
if packet.type == PACKET_TYPE_PACKET:
|
|
||||||
try:
|
|
||||||
self.check_received_packet(packet)
|
|
||||||
except OldPacket:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
raw = self.my_key.decrypt_raw(packet.data)
|
|
||||||
except:
|
|
||||||
raise EOFError('decrypt packet error')
|
|
||||||
self.put_block(raw)
|
|
||||||
else:
|
|
||||||
raise EOFError('connection lost')
|
|
||||||
def send_packet(self, raw):
|
|
||||||
if self.confirm_wait_packet:
|
|
||||||
timeout = Timeout(SOCK_SEND_TIMEOUT)
|
|
||||||
while timeout.timeleft():
|
|
||||||
if not self.confirm_wait_packet: break
|
|
||||||
if self.confirm_wait_packet:
|
|
||||||
raise EOFError('connection lost')
|
|
||||||
d = {
|
|
||||||
'type': PACKET_TYPE_PACKET,
|
|
||||||
'data': raw
|
|
||||||
}
|
|
||||||
self.send(self.mark_packet(d), encrypted=True, confirm=True)
|
|
||||||
|
|
||||||
class UDPRequestHandler(socketserver.DatagramRequestHandler):
|
|
||||||
def finish(self):
|
|
||||||
'''Don't send anything'''
|
|
||||||
pass
|
|
||||||
def handle(self):
|
|
||||||
datagram = self.rfile.read(DATAGRAM_MAX_SIZE)
|
|
||||||
peer_addr = self.client_address
|
|
||||||
if peer_addr not in peers: peers[peer_addr] = Peer(self.socket, peer_addr)
|
|
||||||
try:
|
|
||||||
peers[peer_addr].recv_packet(datagram)
|
|
||||||
except EOFError:
|
|
||||||
del peers[peer_addr]
|
|
||||||
|
|
||||||
class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
udpserver = ThreadingUDPServer(('0.0.0.0', PORT), UDPRequestHandler)
|
|
||||||
udpserver_thread = threading.Thread(target=udpserver.serve_forever)
|
|
||||||
udpserver_thread.start()
|
|
||||||
|
|
||||||
class EncryptedUDPStream:
|
|
||||||
def __init__(self, sock, peer_addr):
|
|
||||||
self.peer_addr = peer_addr
|
|
||||||
self.sock = sock
|
|
||||||
@classmethod
|
|
||||||
def _connect(cls, host, port):
|
|
||||||
peers[(host, port)] = Peer(udpserver.socket, (host, port))
|
|
||||||
peers[(host, port)].hello()
|
|
||||||
return udpserver.socket
|
|
||||||
@classmethod
|
|
||||||
def connect(cls, host, port, **kwargs):
|
|
||||||
return cls(cls._connect(host, port), (host, port))
|
|
||||||
def poll(self, timeout):
|
|
||||||
timeout = Timeout(timeout)
|
|
||||||
while timeout.timeleft():
|
|
||||||
try:
|
|
||||||
rl = peers[self.peer_addr].poll()
|
|
||||||
if rl: break
|
|
||||||
except:
|
|
||||||
raise EOFError
|
|
||||||
return rl
|
|
||||||
def close(self):
|
|
||||||
if self.peer_addr in peers: del peers[self.peer_addr]
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
return self.peer_addr not in peers
|
|
||||||
def fileno(self):
|
|
||||||
try:
|
|
||||||
return self.sock.fileno()
|
|
||||||
except:
|
|
||||||
self.close()
|
|
||||||
raise EOFError
|
|
||||||
def read(self):
|
|
||||||
try:
|
|
||||||
buf = peers[self.peer_addr].get_next_block()
|
|
||||||
except:
|
|
||||||
raise EOFError
|
|
||||||
return buf
|
|
||||||
def write(self, data):
|
|
||||||
try:
|
|
||||||
peers[self.peer_addr].send_packet(data)
|
|
||||||
except:
|
|
||||||
raise EOFError
|
|
||||||
|
|
||||||
class Channel(object):
|
|
||||||
MAX_IO_CHUNK = 8000
|
|
||||||
COMPRESSION_THRESHOLD = 3000
|
|
||||||
COMPRESSION_LEVEL = 1
|
|
||||||
FRAME_HEADER = Struct("!LB")
|
|
||||||
FLUSHER = b'\n'
|
|
||||||
__slots__ = ["stream", "compress"]
|
|
||||||
|
|
||||||
def __init__(self, stream, compress = True):
|
|
||||||
self.stream = stream
|
|
||||||
if not zlib:
|
|
||||||
compress = False
|
|
||||||
self.compress = compress
|
|
||||||
def close(self):
|
|
||||||
self.stream.close()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
return self.stream.closed
|
|
||||||
def fileno(self):
|
|
||||||
return self.stream.fileno()
|
|
||||||
|
|
||||||
def poll(self, timeout):
|
|
||||||
return self.stream.poll(timeout)
|
|
||||||
|
|
||||||
def recv(self):
|
|
||||||
header = self.stream.read()
|
|
||||||
if len(header) != self.FRAME_HEADER.size:
|
|
||||||
raise EOFError('CHANNEL: Not a header received')
|
|
||||||
length, compressed = self.FRAME_HEADER.unpack(header)
|
|
||||||
block_len = length + len(self.FLUSHER)
|
|
||||||
full_block = b''.join((self.stream.read() for x in range(0, block_len, self.MAX_IO_CHUNK)))
|
|
||||||
if len(full_block) != block_len:
|
|
||||||
raise EOFError('CHANNEL: Received block with wrong size')
|
|
||||||
data = full_block[:-len(self.FLUSHER)]
|
|
||||||
if compressed:
|
|
||||||
data = zlib.decompress(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def send(self, data):
|
|
||||||
if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
|
|
||||||
compressed = 1
|
|
||||||
data = zlib.compress(data, self.COMPRESSION_LEVEL)
|
|
||||||
else:
|
|
||||||
compressed = 0
|
|
||||||
header = self.FRAME_HEADER.pack(len(data), compressed)
|
|
||||||
self.stream.write(header)
|
|
||||||
buf = data + self.FLUSHER
|
|
||||||
for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK):
|
|
||||||
self.stream.write(buf[chunk_start:self.MAX_IO_CHUNK])
|
|
||||||
|
|
||||||
import rpyc.utils.server
|
|
||||||
import rpyc.utils.factory
|
|
||||||
import rpyc.Service
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
SIGNALPEER = 1
|
||||||
|
PEER = 2
|
|
@ -1,78 +0,0 @@
|
||||||
import socketserver
|
|
||||||
import pickle
|
|
||||||
import simplecrypto
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
BUFSIZE = 8192
|
|
||||||
BLOCKSIZE = 4096
|
|
||||||
|
|
||||||
PACKET_TYPE_RECV_RESULT = 'recv_result'
|
|
||||||
PACKET_TYPE_DATA_FRAGMENT = 'data_fragment'
|
|
||||||
PACKET_TYPE_SVC_MESSAGE = 'svc_message'
|
|
||||||
PACKET_TYPE_NEW_CONNECTION = 'new_connection'
|
|
||||||
|
|
||||||
DATA_TYPE_FILE_CHUNK = 'file_chunk'
|
|
||||||
DATA_TYPE_CMD = 'cmd'
|
|
||||||
DATA_TYPE_SEARCH_QUERY = 'search_query'
|
|
||||||
DATA_TYPE_SEARCH_RESULT = 'search_result'
|
|
||||||
DATA_TYPE_PEER_PUBKEY = 'peer_pubkey'
|
|
||||||
|
|
||||||
SVC_MESSAGE_BAD_PACKET = 'bad_packet'
|
|
||||||
SVC_MESSAGE_DECRYPT_ERROR = 'decrypt_error'
|
|
||||||
SVC_MESSAGE_YOU_ARE_STRANGER = 'you_are_stranger'
|
|
||||||
|
|
||||||
RECV_OK_CODE = 0
|
|
||||||
RECV_ERROR_CODE = 1
|
|
||||||
|
|
||||||
HEADER_SVC_YOU_ARE_STRANGER = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': SVC_MESSAGE_YOU_ARE_STRANGER}
|
|
||||||
|
|
||||||
HEADER_RECV_OK = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_OK_CODE}
|
|
||||||
HEADER_RECV_ERROR = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_ERROR_CODE}
|
|
||||||
|
|
||||||
key = None
|
|
||||||
|
|
||||||
peers = {}
|
|
||||||
|
|
||||||
class Peer:
|
|
||||||
def __init__(self, peer_addr, pubkey):
|
|
||||||
self.addr = peer_addr
|
|
||||||
self.pubkey = pubkey
|
|
||||||
|
|
||||||
def get_key():
|
|
||||||
global key
|
|
||||||
if not key:
|
|
||||||
key = simplecrypto.RsaKeypair()
|
|
||||||
return key
|
|
||||||
|
|
||||||
def pickle_data(data):
|
|
||||||
return pickle.dumps(data, protocol=4)
|
|
||||||
|
|
||||||
def write_svc_msg(fobj, svc_msg, extra_data={}):
|
|
||||||
MSG = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': svc_msg}
|
|
||||||
if extra_data:
|
|
||||||
MSG.update(extra_data)
|
|
||||||
fobj.write(pickle_data(MSG))
|
|
||||||
|
|
||||||
class UDPRequestHandler(socketserver.DatagramRequestHandler):
|
|
||||||
def handle(self):
|
|
||||||
datagram = self.rfile.read(BUFSIZE)
|
|
||||||
if self.client_address not in peers:
|
|
||||||
write_svc_msg(self.wfile, SVC_MESSAGE_YOU_ARE_STRANGER)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
unpickled_datagram = pickle.dumps(datagram)
|
|
||||||
except pickle.UnpicklingError:
|
|
||||||
write_svc_msg(self.wfile, SVC_MESSAGE_BAD_PACKET)
|
|
||||||
return
|
|
||||||
pk = get_key()
|
|
||||||
try:
|
|
||||||
data = pickle.loads(pk.decrypt_raw(unpickled_datagram['packet']))
|
|
||||||
except:
|
|
||||||
write_svc_msg(self.wfile, SVC_MESSAGE_DECRYPT_ERROR, {'packet_id': unpickled_datagram['packet_id']})
|
|
||||||
return
|
|
||||||
block_id = data['block_id']
|
|
||||||
fragment_num = data['num']
|
|
||||||
cmd = data['']
|
|
||||||
|
|
||||||
|
|
||||||
|
|
481
mods/utcp.py
481
mods/utcp.py
|
@ -1,481 +0,0 @@
|
||||||
# based on https://github.com/ethay012/TCP-over-UDP
|
|
||||||
import random
|
|
||||||
import socket
|
|
||||||
import pickle
|
|
||||||
import threading
|
|
||||||
import io
|
|
||||||
import hashlib
|
|
||||||
import simplecrypto
|
|
||||||
|
|
||||||
DATA_DIVIDE_LENGTH = 8000
|
|
||||||
PACKET_HEADER_SIZE = 512 # Pickle service info
|
|
||||||
DATA_LENGTH = DATA_DIVIDE_LENGTH
|
|
||||||
SENT_SIZE = PACKET_HEADER_SIZE + DATA_LENGTH + 272 # Encrypted data always 272 bytes bigger
|
|
||||||
LAST_CONNECTION = -1
|
|
||||||
FIRST = 0
|
|
||||||
PACKET_END = b'___+++^^^END^^^+++___'
|
|
||||||
|
|
||||||
# need for emulate
|
|
||||||
AF_INET = None
|
|
||||||
SOCK_STREAM = None
|
|
||||||
|
|
||||||
class Connection:
|
|
||||||
SMALLEST_STARTING_SEQ = 0
|
|
||||||
HIGHEST_STARTING_SEQ = 4294967295
|
|
||||||
def __init__(self, remote, encrypted=False):
|
|
||||||
self.peer_addr = remote
|
|
||||||
self.ack = 0
|
|
||||||
self.seq = Connection.gen_starting_seq_num()
|
|
||||||
self.my_key = None
|
|
||||||
if encrypted:
|
|
||||||
self.my_key = simplecrypto.RsaKeypair()
|
|
||||||
self.pubkey = self.my_key.publickey.serialize()
|
|
||||||
self.peer_pub = None
|
|
||||||
self.recv_lock = threading.Lock()
|
|
||||||
self.send_lock = threading.Lock()
|
|
||||||
@staticmethod
|
|
||||||
def gen_starting_seq_num():
|
|
||||||
return random.randint(Connection.SMALLEST_STARTING_SEQ, Connection.HIGHEST_STARTING_SEQ)
|
|
||||||
def seq_inc(self, inc=1):
|
|
||||||
self.seq += inc
|
|
||||||
return self.seq
|
|
||||||
def set_ack(self, ack):
|
|
||||||
self.ack = ack
|
|
||||||
return ack
|
|
||||||
|
|
||||||
class TCPPacket(object):
|
|
||||||
def __init__(self, seq):
|
|
||||||
self.seq = seq
|
|
||||||
self.ack = 0
|
|
||||||
self.flag_ack = 0
|
|
||||||
self.flag_syn = 0
|
|
||||||
self.flag_fin = 0
|
|
||||||
self.checksum = 0
|
|
||||||
self.data = b''
|
|
||||||
def __repr__(self):
|
|
||||||
return f'TCPpacket(type={self.packet_type()})'
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return 'SEQ Number: %d, ACK Number: %d, ACK:%d, SYN:%d, FIN:%d, TYPE:%s, DATA:%s' \
|
|
||||||
% (self.seq, self.ack, self.flag_ack, self.flag_syn, self.flag_fin, self.packet_type(), self.data)
|
|
||||||
|
|
||||||
def __cmp__(self, other):
|
|
||||||
return (self.seq > other.seq) - (self.seq < other.seq)
|
|
||||||
|
|
||||||
def packet_type(self):
|
|
||||||
packet_type = ''
|
|
||||||
if self.flag_syn == 1 and self.flag_ack == 1:
|
|
||||||
packet_type = 'SYN-ACK'
|
|
||||||
elif self.flag_ack == 1 and self.flag_fin == 1:
|
|
||||||
packet_type = 'FIN-ACK'
|
|
||||||
elif self.flag_syn == 1:
|
|
||||||
packet_type = 'SYN'
|
|
||||||
elif self.flag_ack == 1:
|
|
||||||
packet_type = 'ACK'
|
|
||||||
elif self.flag_fin == 1:
|
|
||||||
packet_type = 'FIN'
|
|
||||||
elif self.data != b'':
|
|
||||||
packet_type = 'DATA'
|
|
||||||
return packet_type
|
|
||||||
|
|
||||||
def set_flags(self, ack=False, syn=False, fin=False):
|
|
||||||
if ack:
|
|
||||||
self.flag_ack = 1
|
|
||||||
else:
|
|
||||||
self.flag_ack = 0
|
|
||||||
if syn:
|
|
||||||
self.flag_syn = 1
|
|
||||||
else:
|
|
||||||
self.flag_syn = 0
|
|
||||||
if fin:
|
|
||||||
self.flag_fin = 1
|
|
||||||
else:
|
|
||||||
self.flag_fin = 0
|
|
||||||
def set_data(self, data):
|
|
||||||
self.checksum = TCP.checksum(data)
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
if name == 'TCPPacket':
|
|
||||||
return TCPPacket
|
|
||||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
|
||||||
(module, name))
|
|
||||||
def restricted_pickle_loads(s):
|
|
||||||
return RestrictedUnpickler(io.BytesIO(s)).load()
|
|
||||||
|
|
||||||
class ConnectedSOCK(object):
|
|
||||||
def __init__(self, low_sock, client_addr):
|
|
||||||
self.client_addr = client_addr
|
|
||||||
self.low_sock = low_sock
|
|
||||||
def __getattribute__(self, att):
|
|
||||||
try:
|
|
||||||
return object.__getattribute__(self, att)
|
|
||||||
except AttributeError:
|
|
||||||
return getattr(self.low_sock, att)
|
|
||||||
def getpeername(self):
|
|
||||||
return self.client_addr
|
|
||||||
def send(self, data):
|
|
||||||
if self.closed:
|
|
||||||
raise EOFError
|
|
||||||
return self.low_sock.send(data, self.client_addr)
|
|
||||||
def recv(self, size):
|
|
||||||
if self.closed:
|
|
||||||
raise EOFError
|
|
||||||
return self.low_sock.recv(size, self.client_addr)
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
return self.low_sock.own_socket._closed or (self.client_addr not in self.low_sock.connections or self.low_sock.connections[self.client_addr].flag_fin)
|
|
||||||
def close(self):
|
|
||||||
if self.client_addr in self.low_sock.connections:
|
|
||||||
self.low_sock.close(self.client_addr)
|
|
||||||
def shutdown(self, *a, **kw):
|
|
||||||
self.close()
|
|
||||||
def poll(self, timeout):
|
|
||||||
if self.client_addr in self.packets_received['DATA or FIN']:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
self.incoming_packet_event.wait(timeout)
|
|
||||||
return self.client_addr in self.packets_received['DATA or FIN']
|
|
||||||
return False
|
|
||||||
|
|
||||||
class TCP(object):
|
|
||||||
host = None
|
|
||||||
port = None
|
|
||||||
client = False
|
|
||||||
def __init__(self, af_type=None, sock_type=None, encrypted=False):
|
|
||||||
self.encrypted = encrypted
|
|
||||||
self.incoming_packet_event = threading.Event()
|
|
||||||
self.new_conn_event = threading.Event()
|
|
||||||
self.own_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
self.settimeout()
|
|
||||||
self.connection_lock = threading.Lock()
|
|
||||||
self.queue_lock = threading.Lock()
|
|
||||||
self.connections = {}
|
|
||||||
self.connection_queue = []
|
|
||||||
self.packets_received = {'SYN': {}, 'ACK': {}, 'SYN-ACK': {}, 'DATA or FIN': {}, 'FIN-ACK': {}}
|
|
||||||
def poll(self, timeout):
|
|
||||||
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
self.incoming_packet_event.wait(timeout)
|
|
||||||
if len(self.connections) and list(self.connections.keys())[0] in self.packets_received['DATA or FIN']:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_free_port(self):
|
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
s.bind(('', 0))
|
|
||||||
port = s.getsockname()[1]
|
|
||||||
s.close()
|
|
||||||
return port
|
|
||||||
def setsockopt(self, *a, **kw):
|
|
||||||
pass
|
|
||||||
def settimeout(self, timeout=5):
|
|
||||||
self.own_socket.settimeout(timeout)
|
|
||||||
def setblocking(self, mode):
|
|
||||||
self.own_socket.setblocking(mode)
|
|
||||||
def __repr__(self):
|
|
||||||
return 'TCP()'
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return 'Connections: %s' \
|
|
||||||
% str(self.connections)
|
|
||||||
def getsockname(self):
|
|
||||||
return self.own_socket.getsockname()
|
|
||||||
def getpeername(self):
|
|
||||||
if len(self.connections):
|
|
||||||
return list(self.connections.keys())[0]
|
|
||||||
else:
|
|
||||||
raise EOFError('Not connected')
|
|
||||||
def bind(self, addr):
|
|
||||||
self.host = addr[0]
|
|
||||||
self.port = addr[1]
|
|
||||||
self.own_socket.bind(addr)
|
|
||||||
def send(self, data, connection=None):
|
|
||||||
try:
|
|
||||||
if connection not in list(self.connections.keys()):
|
|
||||||
if connection is None:
|
|
||||||
connection = list(self.connections.keys())[0]
|
|
||||||
else:
|
|
||||||
raise EOFError('Connection not in connected devices')
|
|
||||||
conn = self.connections[connection]
|
|
||||||
data_parts = TCP.data_divider(data)
|
|
||||||
for data_part in data_parts:
|
|
||||||
data_chunk = data_part if not self.encrypted else conn.peer_pub.encrypt_raw(data_part)
|
|
||||||
packet = TCPPacket(conn.seq)
|
|
||||||
packet.set_data(data_chunk)
|
|
||||||
packet_to_send = pickle.dumps(packet)
|
|
||||||
answer = self.retransmit(connection, packet_to_send)
|
|
||||||
conn.seq_inc(len(data_part))
|
|
||||||
return len(data)
|
|
||||||
except socket.error as error:
|
|
||||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
|
||||||
def retransmit(self, peer_addr, pickled_packet, condition='ACK'):
|
|
||||||
data_not_received = True
|
|
||||||
retransmit_count = 0
|
|
||||||
while data_not_received and retransmit_count < 3:
|
|
||||||
data_not_received = False
|
|
||||||
try:
|
|
||||||
self.own_socket.sendto(pickled_packet, peer_addr)
|
|
||||||
answer = self.find_correct_packet(condition, peer_addr)
|
|
||||||
if not answer:
|
|
||||||
data_not_received = True
|
|
||||||
retransmit_count += 1
|
|
||||||
except socket.timeout:
|
|
||||||
data_not_received = True
|
|
||||||
if not answer:
|
|
||||||
self.drop_connection(peer_addr)
|
|
||||||
raise EOFError('Connection lost')
|
|
||||||
return answer
|
|
||||||
def recv(self, size, connection=None):
|
|
||||||
try:
|
|
||||||
if connection not in list(self.connections.keys()):
|
|
||||||
if connection is None:
|
|
||||||
connection = list(self.connections.keys())[0]
|
|
||||||
else:
|
|
||||||
raise EOFError('Connection not in connected devices')
|
|
||||||
|
|
||||||
data = self.find_correct_packet('DATA or FIN', connection, size)
|
|
||||||
if not self.status:
|
|
||||||
raise EOFError('Disconnecting')
|
|
||||||
return data
|
|
||||||
except socket.error as error:
|
|
||||||
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
|
|
||||||
def send_ack(self, connection, ack):
|
|
||||||
conn = self.connections[connection]
|
|
||||||
ack_packet = TCPPacket(conn.seq_inc())
|
|
||||||
ack_packet.ack = conn.set_ack(ack)
|
|
||||||
ack_packet.set_flags(ack=True)
|
|
||||||
packet_to_send = pickle.dumps(ack_packet)
|
|
||||||
self.own_socket.sendto(packet_to_send, connection)
|
|
||||||
def listen_handler(self, max_connections):
|
|
||||||
try:
|
|
||||||
while True and self.status:
|
|
||||||
try:
|
|
||||||
answer, address = self.find_correct_packet('SYN')
|
|
||||||
with self.queue_lock:
|
|
||||||
if len(self.connection_queue) < max_connections:
|
|
||||||
conn = Connection(address, self.encrypted)
|
|
||||||
if self.encrypted:
|
|
||||||
try:
|
|
||||||
conn.peer_pub = simplecrypto.RsaPublicKey(answer.data)
|
|
||||||
except:
|
|
||||||
raise socket.error('Init peer public key error')
|
|
||||||
self.connection_queue.append((answer, conn))
|
|
||||||
self.blink_new_conn_event()
|
|
||||||
else:
|
|
||||||
self.own_socket.sendto('Connections full', address)
|
|
||||||
except KeyError:
|
|
||||||
continue
|
|
||||||
except socket.error as error:
|
|
||||||
raise EOFError('Something went wrong in listen_handler func! Error is: %s.' + str(error))
|
|
||||||
|
|
||||||
def listen(self, max_connections=1):
|
|
||||||
self.status = 1
|
|
||||||
self.central_receive()
|
|
||||||
try:
|
|
||||||
t = threading.Thread(target=self.listen_handler, args=(max_connections,))
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
||||||
except Exception as error:
|
|
||||||
raise EOFError('Something went wrong in listen func! Error is: %s.' % str(error))
|
|
||||||
|
|
||||||
def accept(self):
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
self.new_conn_event.wait(0.1)
|
|
||||||
if self.connection_queue:
|
|
||||||
with self.queue_lock:
|
|
||||||
answer, conn = self.connection_queue.pop()
|
|
||||||
self.connections[conn.peer_addr] = conn
|
|
||||||
packet = TCPPacket(conn.seq)
|
|
||||||
packet.ack = answer.seq + 1
|
|
||||||
packet.seq = conn.seq_inc()
|
|
||||||
packet.set_flags(ack=True, syn=True)
|
|
||||||
if self.encrypted:
|
|
||||||
packet.set_data(conn.peer_pub.encrypt_raw(conn.pubkey))
|
|
||||||
packet_to_send = pickle.dumps(packet)
|
|
||||||
#On packet lost retransmit
|
|
||||||
packet_not_sent_correctly = True
|
|
||||||
while packet_not_sent_correctly or answer is None:
|
|
||||||
try:
|
|
||||||
packet_not_sent_correctly = False
|
|
||||||
self.own_socket.sendto(packet_to_send, conn.peer_addr)
|
|
||||||
answer = self.find_correct_packet('ACK', conn.peer_addr)
|
|
||||||
except socket.timeout:
|
|
||||||
packet_not_sent_correctly = True
|
|
||||||
conn.ack = answer.seq + 1
|
|
||||||
return ConnectedSOCK(self, conn.peer_addr), conn.peer_addr
|
|
||||||
except Exception as error:
|
|
||||||
self.close(conn.peer_addr)
|
|
||||||
raise EOFError('Something went wrong in accept func: ' + str(error))
|
|
||||||
|
|
||||||
def connect(self, server_address=('127.0.0.1', 10000)):
|
|
||||||
try:
|
|
||||||
self.bind(('', self.get_free_port()))
|
|
||||||
self.status = 1
|
|
||||||
self.client = True
|
|
||||||
self.central_receive()
|
|
||||||
conn = Connection(server_address, self.encrypted)
|
|
||||||
self.connections[server_address] = conn
|
|
||||||
syn_packet = TCPPacket(conn.seq)
|
|
||||||
syn_packet.set_flags(syn=True)
|
|
||||||
if self.encrypted:
|
|
||||||
syn_packet.set_data(conn.pubkey)
|
|
||||||
first_packet_to_send = pickle.dumps(syn_packet)
|
|
||||||
self.own_socket.sendto(first_packet_to_send, server_address)
|
|
||||||
answer = self.find_correct_packet('SYN-ACK', server_address)
|
|
||||||
if type(answer) == str: # == 'Connections full':
|
|
||||||
raise socket.error('Server cant receive any connections right now.')
|
|
||||||
if self.encrypted:
|
|
||||||
try:
|
|
||||||
peer_pub = conn.my_key.decrypt_raw(answer.data)
|
|
||||||
conn.peer_pub = simplecrypto.RsaPublicKey(peer_pub)
|
|
||||||
except:
|
|
||||||
raise socket.error('Decrypt peer public key error')
|
|
||||||
ack_packet = TCPPacket(conn.seq_inc())
|
|
||||||
ack_packet.ack = conn.set_ack(answer.seq + 1)
|
|
||||||
ack_packet.set_flags(ack=True)
|
|
||||||
second_packet_to_send = pickle.dumps(ack_packet)
|
|
||||||
self.own_socket.sendto(second_packet_to_send, server_address)
|
|
||||||
|
|
||||||
except socket.error as error:
|
|
||||||
self.own_socket.close()
|
|
||||||
self.connections = {}
|
|
||||||
self.status = 0
|
|
||||||
raise EOFError('The socket was closed. Error:' + str(error))
|
|
||||||
def fileno(self):
|
|
||||||
return self.own_socket.fileno()
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
return self.own_socket._closed
|
|
||||||
def drop_connection(self, connection):
|
|
||||||
with self.connection_lock:
|
|
||||||
if len(self.connections):
|
|
||||||
self.connections.pop(connection)
|
|
||||||
def close(self, connection=None):
|
|
||||||
try:
|
|
||||||
if connection not in list(self.connections.keys()):
|
|
||||||
if connection is None:
|
|
||||||
connection = list(self.connections.keys())[0]
|
|
||||||
else:
|
|
||||||
raise EOFError('Connection not in connected devices')
|
|
||||||
conn = self.connections[connection]
|
|
||||||
fin_packet = TCPPacket(conn.seq_inc())
|
|
||||||
fin_packet.set_flags(fin=True)
|
|
||||||
packet_to_send = pickle.dumps(fin_packet)
|
|
||||||
self.own_socket.sendto(packet_to_send, connection)
|
|
||||||
answer = self.retransmit(connection, packet_to_send)
|
|
||||||
conn.ack += 1
|
|
||||||
answer = self.find_correct_packet('FIN-ACK', connection)
|
|
||||||
if answer.flag_fin != 1:
|
|
||||||
raise Exception('The receiver didn\'t send the fin packet')
|
|
||||||
else:
|
|
||||||
self.send_ack(connection, conn.ack + 1)
|
|
||||||
self.drop_connection(connection)
|
|
||||||
if len(self.connections) == 0 and self.client:
|
|
||||||
self.own_socket.close()
|
|
||||||
self.status = 0
|
|
||||||
except Exception as error:
|
|
||||||
raise EOFError('Something went wrong in the close func! Error is: %s.' % error)
|
|
||||||
|
|
||||||
def disconnect(self, connection):
|
|
||||||
try:
|
|
||||||
conn = self.connections[connection]
|
|
||||||
self.send_ack(connection, conn.set_ack(conn.ack + 1))
|
|
||||||
finack_packet = TCPPacket(conn.seq_inc())
|
|
||||||
finack_packet.set_flags(fin=True, ack=True)
|
|
||||||
packet_to_send = pickle.dumps(finack_packet)
|
|
||||||
try:
|
|
||||||
answer = self.retransmit(connection, packet_to_send)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
with self.connection_lock:
|
|
||||||
self.connections.pop(connection)
|
|
||||||
except Exception as error:
|
|
||||||
raise EOFError('Something went wrong in disconnect func:%s ' % error)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def data_divider(data):
|
|
||||||
'''Divides the data into a list where each element's length is 1024'''
|
|
||||||
data = [data[i:i + DATA_DIVIDE_LENGTH] for i in range(0, len(data), DATA_DIVIDE_LENGTH)]
|
|
||||||
data.append(PACKET_END)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def checksum(source_bytes):
|
|
||||||
return hashlib.sha1(source_bytes).digest()
|
|
||||||
|
|
||||||
def find_correct_packet(self, condition, address=('Any',), size=DATA_LENGTH):
|
|
||||||
not_found = True
|
|
||||||
tries = 0
|
|
||||||
while not_found and tries < 2 and self.status:
|
|
||||||
try:
|
|
||||||
not_found = False
|
|
||||||
if address[0] == 'Any':
|
|
||||||
order = self.packets_received[condition].popitem() # to reverse the tuple received
|
|
||||||
return order[1], order[0]
|
|
||||||
conn = self.connections[address]
|
|
||||||
if condition == 'ACK':
|
|
||||||
tries += 1
|
|
||||||
if condition == 'DATA or FIN':
|
|
||||||
with conn.recv_lock:
|
|
||||||
packet = self.packets_received[condition][address][:size]
|
|
||||||
self.packets_received[condition][address] = self.packets_received[condition][address][size:]
|
|
||||||
if not len(self.packets_received[condition][address]):
|
|
||||||
del self.packets_received[condition][address]
|
|
||||||
else:
|
|
||||||
packet = self.packets_received[condition].pop(address)
|
|
||||||
return packet
|
|
||||||
except KeyError:
|
|
||||||
not_found = True
|
|
||||||
self.incoming_packet_event.wait(0.1)
|
|
||||||
def blink_incoming_packet_event(self):
|
|
||||||
self.incoming_packet_event.set()
|
|
||||||
self.incoming_packet_event.clear()
|
|
||||||
def blink_new_conn_event(self):
|
|
||||||
self.new_conn_event.set()
|
|
||||||
self.new_conn_event.clear()
|
|
||||||
def sort_answers(self, packet, address):
|
|
||||||
if address not in self.connections and packet.packet_type() != 'SYN':
|
|
||||||
return
|
|
||||||
if packet.packet_type() == 'FIN':
|
|
||||||
self.disconnect(address)
|
|
||||||
elif packet.packet_type() == 'DATA':
|
|
||||||
if packet.checksum == TCP.checksum(packet.data):
|
|
||||||
conn = self.connections[address]
|
|
||||||
data_chunk = packet.data if not self.encrypted else conn.my_key.decrypt_raw(packet.data)
|
|
||||||
if data_chunk != PACKET_END:
|
|
||||||
with conn.recv_lock:
|
|
||||||
if address not in self.packets_received['DATA or FIN']:
|
|
||||||
self.packets_received['DATA or FIN'][address] = b''
|
|
||||||
self.packets_received['DATA or FIN'][address] += data_chunk
|
|
||||||
self.send_ack(address, packet.seq + len(packet.data))
|
|
||||||
self.blink_incoming_packet_event()
|
|
||||||
elif packet.packet_type() == '':
|
|
||||||
#print('redundant packet found', packet)
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.packets_received[packet.packet_type()][address] = packet
|
|
||||||
self.blink_incoming_packet_event()
|
|
||||||
|
|
||||||
def central_receive_handler(self):
|
|
||||||
while True and self.status:
|
|
||||||
try:
|
|
||||||
packet, address = self.own_socket.recvfrom(SENT_SIZE)
|
|
||||||
packet = restricted_pickle_loads(packet)
|
|
||||||
self.sort_answers(packet, address)
|
|
||||||
except socket.timeout:
|
|
||||||
continue
|
|
||||||
except socket.error as error:
|
|
||||||
self.own_socket.close()
|
|
||||||
self.status = 0
|
|
||||||
# print('An error has occured: Socket error %s' % error)
|
|
||||||
|
|
||||||
def central_receive(self):
|
|
||||||
t = threading.Thread(target=self.central_receive_handler)
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
Loading…
Reference in New Issue