remove garbage. RPyC over UTCP full functional version

master
Роман Бородин 2019-04-12 13:48:08 +03:00
parent 328692bd20
commit ed57cbcb57
5 changed files with 97 additions and 726 deletions

View File

@ -1,237 +0,0 @@
# Based on rpyc (4.0.2)
# !!!!!!! MULTIPLE CONNECTIONS DON'T WORK !!!!!!!!!
import rpyc
from rpyc.utils.server import ThreadedServer, spawn
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
import socket
from mbedtls import tls
import datetime as dt
from mbedtls import hash as hashlib
from mbedtls import pk
from mbedtls import x509
from uuid import uuid4
from contextlib import suppress
import sys
from rpyc.lib import safe_import
zlib = safe_import("zlib")
import errno
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
def block(callback, *args, **kwargs):
while True:
with suppress(tls.WantReadError, tls.WantWriteError):
return callback(*args, **kwargs)
class DTLSCerts:
def __init__(self):
now = dt.datetime.utcnow()
self.ca0_key = pk.RSA()
_ = self.ca0_key.generate()
ca0_csr = x509.CSR.new(self.ca0_key, 'CN=Black mamba thrusted CA', hashlib.sha256())
self.ca0_crt = x509.CRT.selfsign(
ca0_csr, self.ca0_key,
not_before=now, not_after=now + dt.timedelta(days=3650),
serial_number=0x123456,
basic_constraints=x509.BasicConstraints(True, 1))
self.ca1_key = pk.ECC()
_ = self.ca1_key.generate()
ca1_csr = x509.CSR.new(self.ca1_key, 'CN=Black mamba intermediate CA', hashlib.sha256())
self.ca1_crt = self.ca0_crt.sign(
ca1_csr, self.ca0_key, now, now + dt.timedelta(days=3650), 0x123456,
basic_constraints=x509.BasicConstraints(ca=True, max_path_length=3))
self.srv_crt, self.srv_key = self.server_cert()
def server_cert(self):
now = dt.datetime.utcnow()
ee0_key = pk.ECC()
_ = ee0_key.generate()
ee0_csr = x509.CSR.new(ee0_key, f'CN=Black mamba peer [{uuid4().hex}]', hashlib.sha256())
ee0_crt = self.ca1_crt.sign(
ee0_csr, self.ca1_key, now, now + dt.timedelta(days=3650), 0x987654)
return ee0_crt, ee0_key
dtls_certs = DTLSCerts()
trust_store = tls.TrustStore()
trust_store.add(dtls_certs.ca0_crt)
srv_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
certificate_chain=([dtls_certs.srv_crt, dtls_certs.ca1_crt], dtls_certs.srv_key),
validate_certificates=False,
)
cli_ctx_conf = tls.DTLSConfiguration(
trust_store=trust_store,
validate_certificates=False,
)
MAX_IO_CHUNK = 8192
class DTLSSocketStream(SocketStream):
MAX_IO_CHUNK = MAX_IO_CHUNK
@classmethod
def dtls_connect(cls, host, port, ssl_kwargs, timeout=3, **kwargs):
if kwargs.pop('ipv6', False):
family = socket.AF_INET6
else:
family = socket.AF_INET
#tls._enable_debug_output(cli_ctx_conf)
#tls._set_debug_level(10)
dtls_cli_ctx = tls.ClientContext(cli_ctx_conf)
dtls_cli = dtls_cli_ctx.wrap_socket(
socket.socket(family, socket.SOCK_DGRAM),
server_hostname=None,
)
dtls_cli.settimeout(timeout)
dtls_cli.connect((host, port))
block(dtls_cli.do_handshake)
return cls(dtls_cli)
def read(self, count):
while True:
try:
buf = block(self.sock.recv, self.MAX_IO_CHUNK)
except socket.timeout:
continue
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in retry_errnos:
# windows just has to be a bitch
# inpos: I agree
continue
self.close()
raise EOFError(ex)
else:
break
if not buf:
self.close()
raise EOFError("connection closed by peer")
return buf
def write(self, data):
try:
_ = block(self.sock.send, data[:self.MAX_IO_CHUNK])
except socket.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
class DTLSChannel(Channel):
MAX_IO_CHUNK = MAX_IO_CHUNK
def recv(self):
header = self.stream.read(self.MAX_IO_CHUNK)
length, compressed = self.FRAME_HEADER.unpack(header)
length += len(self.FLUSHER)
data = b''
while length:
dat = self.stream.read(self.MAX_IO_CHUNK)
data += dat
length -= len(dat)
data = data[:-len(self.FLUSHER)]
if compressed:
data = zlib.decompress(data)
return data
def send(self, data):
if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
compressed = 1
data = zlib.compress(data, self.COMPRESSION_LEVEL)
else:
compressed = 0
header = self.FRAME_HEADER.pack(len(data), compressed)
self.stream.write(header)
data = data + self.FLUSHER
data = [data[i:i + self.MAX_IO_CHUNK] for i in range(0, len(data), self.MAX_IO_CHUNK)]
for chunk in data:
self.stream.write(chunk)
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(DTLSChannel(stream), service=service, config=config)
def dtls_connect(host, port, keyfile=None, certfile=None, ca_certs=None,
cert_reqs=None, ssl_version=None, ciphers=None,
service=rpyc.VoidService, config={}, ipv6=False, keepalive=False):
ssl_kwargs = {'server_side' : False}
if ciphers is not None:
ssl_kwargs['ciphers'] = ciphers
s = DTLSSocketStream.dtls_connect(host, port, ssl_kwargs, ipv6=ipv6, keepalive=keepalive)
return connect_stream(s, service, config)
class DTLSThreadedServer(ThreadedServer):
def __init__(self, service, hostname = "", ipv6 = False, port = 0,
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
ThreadedServer.__init__(self, service, hostname=hostname, ipv6=ipv6, port=port,
backlog=backlog, reuse_addr=reuse_addr, authenticator=authenticator, registrar=registrar,
auto_register=auto_register, protocol_config=protocol_config, logger=logger, listener_timeout=listener_timeout,
socket_path=socket_path)
self.listener.close()
self.host = hostname
self.port = port
#tls._enable_debug_output(srv_ctx_conf)
#tls._set_debug_level(10)
dtls_srv_ctx = tls.ServerContext(srv_ctx_conf)
dtls_srv = dtls_srv_ctx.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
if reuse_addr and sys.platform != 'win32':
dtls_srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
dtls_srv.bind((hostname, port))
# dtls_srv.settimeout(listener_timeout)
self.listener = dtls_srv
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
if credentials:
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
conn = self.service._connect(DTLSChannel(DTLSSocketStream(sock)), config)
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)
def _listen(self):
if self.active:
return
#self.listener.listen(self.backlog) #####################
if not self.port:
self.port = self.listener.getsockname()[1]
self.logger.info('server started on [%s]:%s', self.host, self.port)
self.active = True
def accept(self):
while self.active:
try:
print('Accepting new connections!')
sock, addrinfo = self.listener.accept()
except socket.timeout:
pass
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in (errno.EINTR, errno.EAGAIN):
pass
else:
raise EOFError()
raise
else:
break
if not self.active:
return
sock.setblocking(True)
addr = sock.getpeername()
sock.setcookieparam(addr[0].encode())
with suppress(tls.HelloVerifyRequest):
block(sock.do_handshake)
sock2, addr = sock.accept()
sock.close()
sock2.setblocking(True)
sock2.setcookieparam(addr[0].encode())
block(sock2.do_handshake)
self.logger.info("accepted %s with fd %s", addrinfo, sock2.fileno())
print("accepted %s with fd %s" % (addrinfo, sock2.fileno()))
self.clients.add(sock2)
self._accept_method(sock2)

View File

@ -1,15 +1,10 @@
import rpyc
from rpyc.utils.server import ThreadedServer
from rpyc.core import SocketStream, Channel
from rpyc.core.stream import retry_errnos
from rpyc.utils.factory import connect_channel
from rpyc.utils.factory import connect_stream
from rpyc.lib import Timeout
from rpyc.lib.compat import select_error
from rpyc.lib.compat import get_exc_errno, BYTES_LITERAL
from . import utcp
import sys
import errno
import socket
class UTCPSocketStream(SocketStream):
MAX_IO_CHUNK = utcp.DATA_LENGTH
@ -20,24 +15,22 @@ class UTCPSocketStream(SocketStream):
return cls(sock)
def poll(self, timeout):
timeout = Timeout(timeout)
return self.sock.poll(timeout.timeleft())
def read(self, count):
try:
while True:
try:
rl = self.sock.poll(timeout.timeleft())
except select_error:
ex = sys.exc_info()[1]
if ex.args[0] == errno.EINTR:
continue
else:
raise
else:
break
except ValueError:
return self.sock.recv(count)
except EOFError:
self.close()
raise EOFError
def write(self, data):
try:
self.sock.send(data)
except EOFError:
ex = sys.exc_info()[1]
raise select_error(str(ex))
return rl
def connect_stream(stream, service=rpyc.VoidService, config={}):
return connect_channel(Channel(stream), service=service, config=config)
self.close()
raise EOFError(ex)
def utcp_connect(host, port, service=rpyc.VoidService, config={}, **kw):
s = UTCPSocketStream.utcp_connect(host, port, **kw)
return connect_stream(s, service, config)
@ -71,4 +64,4 @@ class UTCPThreadedServer(ThreadedServer):
self._handle_connection(conn)
finally:
self.logger.info("goodbye %s", addrinfo)

View File

@ -1,356 +0,0 @@
import socketserver
import pickle
import simplecrypto
from uuid import uuid4
import datetime, io
import threading
import time
from struct import Struct
try:
import zlib
except:
zlib = None
peers = {}
PORT = 16386
RETRANSMIT_RETRIES = 3
DATAGRAM_MAX_SIZE = 9000
RAW_DATA_MAX_SIZE = 8000
PACKET_NUM_SEQ_TTL = 300
SOCK_SEND_TIMEOUT = 60
PACKET_TYPE_HELLO = 0x00
PACKET_TYPE_PEER_PUB_KEY_REQUEST = 0x01
PACKET_TYPE_PEER_PUB_KEY_REPLY = 0x02
PACKET_TYPE_PEER_NEW_PUB_KEY = 0x03
PACKET_TYPE_PACKET = 0xa0
PACKET_TYPE_CONFIRM_RECV = 0xa1
PACKET_TYPE_GOODBUY = 0xff
class InvalidPacket(Exception): pass
class OldPacket(Exception): pass
def pickle_data(data):
return pickle.dumps(data, protocol=4)
################
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow datetime
if module == "datetime" and name == 'datetime':
return getattr(datetime, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_pickle_loads(s):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(s)).load()
################
# From rpyc.lib
class Timeout:
def __init__(self, timeout):
if isinstance(timeout, Timeout):
self.finite = timeout.finite
self.tmax = timeout.tmax
else:
self.finite = timeout is not None and timeout >= 0
self.tmax = time.time()+timeout if self.finite else None
def expired(self):
return self.finite and time.time() >= self.tmax
def timeleft(self):
return max((0, self.tmax - time.time())) if self.finite else None
def sleep(self, interval):
time.sleep(min(interval, self.timeleft()) if self.finite else interval)
class Packet:
def __init__(self, packet_payload):
try:
d = restricted_pickle_loads(packet_payload)
self.sid = d['sid']
self.type = d['type']
self.reset_timestamp = d['reset_timestamp']
self.num = d['num']
self.data = d['data']
except:
raise InvalidPacket
class Peer:
def __init__(self, sock, endpoint):
self.sid = None
self.sock = sock
self.endpoint = endpoint
self.my_key = None
self.peer_pub_key = None
self.buf = []
self.confirm_wait_packet = None
self.last_packet = None
self.request_lock = threading.Lock()
self.num_seq_ttl = datetime.timedelta(seconds=PACKET_NUM_SEQ_TTL)
self.last_sent_packet_num = -1
self.last_sent_packet_num_reset_time = datetime.datetime.utcnow()
self.last_received_packet_num = -1
self.last_received_packet_num_reset_time = None
self.retransmit_count = 0
def next_packet_num(self):
new_time = datetime.datetime.utcnow()
if (new_time - self.last_sent_packet_num_reset_time) >= self.num_seq_ttl:
self.last_sent_packet_num = -1
self.last_sent_packet_num += 1
return self.last_sent_packet_num
def poll(self):
return bool(len(self.buf))
def get_next_block(self):
if not len(self.buf):
return None
return self.buf.pop()
def put_block(self, data):
self.buf.insert(0, data)
def send(self, d, encrypted=False, confirm=False):
if 'sid' not in d: d['sid'] = self.sid
if 'num' not in d: d['num'] = None
if 'reset_timestamp' not in d: d['reset_timestamp'] = None
if 'data' not in d: d['data'] = b''
if encrypted: d['data'] = self.peer_pub_key.encrypt_raw(d['data'])
data = pickle_data(d)
if confirm:
self.last_packet = data
self.confirm_wait_packet = (d['reset_timestamp'], d['num'])
self.sock.sendto(data, self.endpoint)
def mark_packet(self, d):
d['num'] = self.next_packet_num()
d['reset_timestamp'] = self.last_sent_packet_num_reset_time
return d
def retransmit(self):
self.retransmit_count += 1
if self.retransmit_count > RETRANSMIT_RETRIES:
raise EOFError('retransmit limit reached')
self.sock.sendto(self.last_packet, self.endpoint)
def reply_my_pub_key(self, packet):
try:
self.peer_pub_key = simplecrypto.RsaPublicKey(packet.data)
except:
raise EOFError('invalid pubkey data')
self.my_key = simplecrypto.RsaKeypair()
d = {
'type': PACKET_TYPE_PEER_PUB_KEY_REPLY,
'data': self.my_key.publickey.serialize()
}
self.send(d, encrypted=True)
def request_peer_bub_key(self, packet):
self.sid = packet.sid
self.my_key = simplecrypto.RsaKeypair()
d = {
'type': PACKET_TYPE_PEER_PUB_KEY_REQUEST,
'data': self.my_key.publickey.serialize()
}
self.send(d)
def confirm_packet_recv(self, packet):
self.confirm_wait_packet = None
self.last_packet = None
d = {
'type': PACKET_TYPE_CONFIRM_RECV,
'num': packet.num,
'reset_timestamp': packet.reset_timestamp
}
self.send(d)
def check_received_packet(self, packet):
if self.last_received_packet_num_reset_time:
if self.last_received_packet_num_reset_time > packet.reset_timestamp:
raise OldPacket('packet from past')
elif self.last_received_packet_num_reset_time < packet.reset_timestamp:
self.last_received_packet_num_reset_time = packet.reset_timestamp
if (self.last_received_packet_num + 1) != packet.num:
raise EOFError('packet sequence corrupt')
else:
self.last_received_packet_num_reset_time = packet.reset_timestamp
self.last_received_packet_num = packet.num
def send_recv_confirmation(self, packet):
pass
def hello(self):
self.sid = uuid4().hex
d = {
'type': PACKET_TYPE_HELLO,
}
self.sock.sendto(pickle_data(d))
def recv_packet(self, packet_payload):
with self.request_lock:
try:
packet = Packet(packet_payload)
if packet.type == PACKET_TYPE_GOODBUY:
raise EOFError('connection closed')
except:
raise EOFError('invalid packet')
if packet.type != PACKET_TYPE_HELLO and (not self.sid or self.sid != packet.sid):
self.hello()
return
############################################
if not self.peer_pub_key:
if packet.type == PACKET_TYPE_PEER_PUB_KEY_REPLY:
try:
self.peer_pub_key = simplecrypto.RsaPublicKey(self.my_key.decrypt_raw(packet.data))
return
except:
raise EOFError('create pubkey failed')
elif packet.type == PACKET_TYPE_PEER_PUB_KEY_REQUEST:
self.reply_my_pub_key(packet)
return
elif packet.type == PACKET_TYPE_HELLO:
self.request_peer_bub_key(packet)
return
############################################
if self.confirm_wait_packet:
if (packet.reset_timestamp, packet.num) == self.confirm_wait_packet and packet.type == PACKET_TYPE_CONFIRM_RECV:
self.confirm_packet_recv(packet)
return
else:
self.retransmit()
return
############################################
else:
if packet.type == PACKET_TYPE_PACKET:
try:
self.check_received_packet(packet)
except OldPacket:
return
try:
raw = self.my_key.decrypt_raw(packet.data)
except:
raise EOFError('decrypt packet error')
self.put_block(raw)
else:
raise EOFError('connection lost')
def send_packet(self, raw):
if self.confirm_wait_packet:
timeout = Timeout(SOCK_SEND_TIMEOUT)
while timeout.timeleft():
if not self.confirm_wait_packet: break
if self.confirm_wait_packet:
raise EOFError('connection lost')
d = {
'type': PACKET_TYPE_PACKET,
'data': raw
}
self.send(self.mark_packet(d), encrypted=True, confirm=True)
class UDPRequestHandler(socketserver.DatagramRequestHandler):
def finish(self):
'''Don't send anything'''
pass
def handle(self):
datagram = self.rfile.read(DATAGRAM_MAX_SIZE)
peer_addr = self.client_address
if peer_addr not in peers: peers[peer_addr] = Peer(self.socket, peer_addr)
try:
peers[peer_addr].recv_packet(datagram)
except EOFError:
del peers[peer_addr]
class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer):
pass
udpserver = ThreadingUDPServer(('0.0.0.0', PORT), UDPRequestHandler)
udpserver_thread = threading.Thread(target=udpserver.serve_forever)
udpserver_thread.start()
class EncryptedUDPStream:
def __init__(self, sock, peer_addr):
self.peer_addr = peer_addr
self.sock = sock
@classmethod
def _connect(cls, host, port):
peers[(host, port)] = Peer(udpserver.socket, (host, port))
peers[(host, port)].hello()
return udpserver.socket
@classmethod
def connect(cls, host, port, **kwargs):
return cls(cls._connect(host, port), (host, port))
def poll(self, timeout):
timeout = Timeout(timeout)
while timeout.timeleft():
try:
rl = peers[self.peer_addr].poll()
if rl: break
except:
raise EOFError
return rl
def close(self):
if self.peer_addr in peers: del peers[self.peer_addr]
@property
def closed(self):
return self.peer_addr not in peers
def fileno(self):
try:
return self.sock.fileno()
except:
self.close()
raise EOFError
def read(self):
try:
buf = peers[self.peer_addr].get_next_block()
except:
raise EOFError
return buf
def write(self, data):
try:
peers[self.peer_addr].send_packet(data)
except:
raise EOFError
class Channel(object):
MAX_IO_CHUNK = 8000
COMPRESSION_THRESHOLD = 3000
COMPRESSION_LEVEL = 1
FRAME_HEADER = Struct("!LB")
FLUSHER = b'\n'
__slots__ = ["stream", "compress"]
def __init__(self, stream, compress = True):
self.stream = stream
if not zlib:
compress = False
self.compress = compress
def close(self):
self.stream.close()
@property
def closed(self):
return self.stream.closed
def fileno(self):
return self.stream.fileno()
def poll(self, timeout):
return self.stream.poll(timeout)
def recv(self):
header = self.stream.read()
if len(header) != self.FRAME_HEADER.size:
raise EOFError('CHANNEL: Not a header received')
length, compressed = self.FRAME_HEADER.unpack(header)
block_len = length + len(self.FLUSHER)
full_block = b''.join((self.stream.read() for x in range(0, block_len, self.MAX_IO_CHUNK)))
if len(full_block) != block_len:
raise EOFError('CHANNEL: Received block with wrong size')
data = full_block[:-len(self.FLUSHER)]
if compressed:
data = zlib.decompress(data)
return data
def send(self, data):
if self.compress and len(data) > self.COMPRESSION_THRESHOLD:
compressed = 1
data = zlib.compress(data, self.COMPRESSION_LEVEL)
else:
compressed = 0
header = self.FRAME_HEADER.pack(len(data), compressed)
self.stream.write(header)
buf = data + self.FLUSHER
for chunk_start in range(0, len(buf), self.MAX_IO_CHUNK):
self.stream.write(buf[chunk_start:self.MAX_IO_CHUNK])
import rpyc.utils.server
import rpyc.utils.factory
import rpyc.Service

View File

@ -1,78 +0,0 @@
import socketserver
import pickle
import simplecrypto
from uuid import uuid4
BUFSIZE = 8192
BLOCKSIZE = 4096
PACKET_TYPE_RECV_RESULT = 'recv_result'
PACKET_TYPE_DATA_FRAGMENT = 'data_fragment'
PACKET_TYPE_SVC_MESSAGE = 'svc_message'
PACKET_TYPE_NEW_CONNECTION = 'new_connection'
DATA_TYPE_FILE_CHUNK = 'file_chunk'
DATA_TYPE_CMD = 'cmd'
DATA_TYPE_SEARCH_QUERY = 'search_query'
DATA_TYPE_SEARCH_RESULT = 'search_result'
DATA_TYPE_PEER_PUBKEY = 'peer_pubkey'
SVC_MESSAGE_BAD_PACKET = 'bad_packet'
SVC_MESSAGE_DECRYPT_ERROR = 'decrypt_error'
SVC_MESSAGE_YOU_ARE_STRANGER = 'you_are_stranger'
RECV_OK_CODE = 0
RECV_ERROR_CODE = 1
HEADER_SVC_YOU_ARE_STRANGER = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': SVC_MESSAGE_YOU_ARE_STRANGER}
HEADER_RECV_OK = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_OK_CODE}
HEADER_RECV_ERROR = {'type': PACKET_TYPE_RECV_RESULT, 'code': RECV_ERROR_CODE}
key = None
peers = {}
class Peer:
def __init__(self, peer_addr, pubkey):
self.addr = peer_addr
self.pubkey = pubkey
def get_key():
global key
if not key:
key = simplecrypto.RsaKeypair()
return key
def pickle_data(data):
return pickle.dumps(data, protocol=4)
def write_svc_msg(fobj, svc_msg, extra_data={}):
MSG = {'type': PACKET_TYPE_SVC_MESSAGE, 'msg': svc_msg}
if extra_data:
MSG.update(extra_data)
fobj.write(pickle_data(MSG))
class UDPRequestHandler(socketserver.DatagramRequestHandler):
def handle(self):
datagram = self.rfile.read(BUFSIZE)
if self.client_address not in peers:
write_svc_msg(self.wfile, SVC_MESSAGE_YOU_ARE_STRANGER)
return
try:
unpickled_datagram = pickle.dumps(datagram)
except pickle.UnpicklingError:
write_svc_msg(self.wfile, SVC_MESSAGE_BAD_PACKET)
return
pk = get_key()
try:
data = pickle.loads(pk.decrypt_raw(unpickled_datagram['packet']))
except:
write_svc_msg(self.wfile, SVC_MESSAGE_DECRYPT_ERROR, {'packet_id': unpickled_datagram['packet_id']})
return
block_id = data['block_id']
fragment_num = data['num']
cmd = data['']

View File

@ -10,6 +10,8 @@ from struct import Struct
import uuid
import bisect
from datetime import datetime
DATA_DIVIDE_LENGTH = 8000
PACKET_HEADER_SIZE = 512 # Pickle service info
DATA_LENGTH = DATA_DIVIDE_LENGTH
@ -19,6 +21,7 @@ class Connection:
SMALLEST_STARTING_SEQ = 0
HIGHEST_STARTING_SEQ = 4294967295
def __init__(self, remote, encrypted=False):
self.fileno = 0
self.peer_addr = remote
self.seq = Connection.gen_starting_seq_num()
self.recv_seq = -1
@ -153,6 +156,10 @@ class ConnectedSOCK(object):
if self.closed:
raise EOFError
return self.low_sock.send(data, self.client_addr)
def sendall(self, data):
if self.closed:
raise EOFError
self.low_sock.sendall(data, self.client_addr)
def recv(self, size):
if self.closed:
raise EOFError
@ -166,17 +173,13 @@ class ConnectedSOCK(object):
def shutdown(self, *a, **kw):
self.close()
def poll(self, timeout):
if not self.closed:
conn = self.low_sock.connections[self.client_addr]
with conn.recv_lock:
has_data = len(conn.packet_buffer['DATA'])
if has_data:
return True
else:
self.incoming_packet_event.wait(timeout)
with conn.recv_lock:
return len(conn.packet_buffer['DATA'])
return False
return self.low_sock.poll(timeout, self.client_addr)
def packets_arrived(self, packet_type):
return self.low_sock.packets_arrived(packet_type, self.client_addr)
def fileno(self):
if self.closed:
raise EOFError
return self.low_sock.fileno(self.client_addr)
class UTCP(object):
host = None
@ -195,18 +198,40 @@ class UTCP(object):
self.connections = {}
self.connection_queue = []
self.syn_received = {}
def poll(self, timeout):
if len(self.connections):
connection = list(self.connections.keys())[0]
self.fileno_seq = 40000000
def next_fileno(self):
self.fileno_seq += 1
return self.fileno_seq
def packets_arrived(self, packet_type, connection=None):
try:
conn = self.connections[connection]
with conn.recv_lock:
has_data = bool(len(conn.packet_buffer['DATA']))
except:
raise EOFError
with conn.recv_lock:
return bool(len(conn.packet_buffer[packet_type]))
def poll(self, timeout, connection=None):
if connection not in list(self.connections.keys()):
if connection is None:
connection = list(self.connections.keys())[0]
else:
raise EOFError('Connection not in connected devices')
if not self.closed:
has_data = self.packets_arrived('DATA', connection)
if has_data:
return True
else:
self.incoming_packet_event.wait(timeout)
with conn.recv_lock:
return bool(len(conn.packet_buffer['DATA']))
if not timeout:
timeout = 0.5
while True and not self.closed:
self.incoming_packet_event.wait(timeout)
has_data = self.packets_arrived('DATA', connection)
if not has_data:
continue
else:
return has_data
else:
self.incoming_packet_event.wait(timeout)
return self.packets_arrived('DATA', connection)
return False
def get_free_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -256,6 +281,8 @@ class UTCP(object):
return len(data)
except socket.error as error:
raise EOFError('Socket was closed before executing command. Error is: %s.' % error)
def sendall(self, data, connection=None):
_ = self.send(data, connection)
def recv(self, size, connection=None):
if self.closed:
raise EOFError
@ -302,6 +329,7 @@ class UTCP(object):
with self.queue_lock:
if len(self.connection_queue) < max_connections:
conn = Connection(address, self.encrypted)
conn.fileno = self.next_fileno()
if self.encrypted:
try:
conn.peer_pub = simplecrypto.RsaPublicKey(answer.pubkey)
@ -315,6 +343,8 @@ class UTCP(object):
self.own_socket.sendto(b'Connections full', address)
except KeyError:
continue
except TypeError:
continue
except socket.error as error:
raise EOFError('Something went wrong in listen_handler func! Error is: %s.' + str(error))
@ -330,6 +360,11 @@ class UTCP(object):
def stop(self):
self.own_socket.close()
self.status = 0
def shutdown(self, *a, **kw):
self.close()
self.status = 0
self.connections = {}
self.stop()
def accept(self):
while self.status:
try:
@ -354,6 +389,8 @@ class UTCP(object):
raise EOFError('Something went wrong in accept func: ' + str(error))
def connect(self, server_address=('127.0.0.1', 10000)):
if server_address in self.connections:
raise EOFError('Already connected to peer')
try:
self.bind(('', self.get_free_port()))
self.status = 1
@ -379,13 +416,19 @@ class UTCP(object):
ack = Ack(answer.id)
self.__send_packet(server_address, ack)
self.channel = UTCPChannel(self)
conn.fileno = self.next_fileno()
except socket.error as error:
self.own_socket.close()
self.connections = {}
self.status = 0
raise EOFError('The socket was closed. Error:' + str(error))
def fileno(self):
return self.own_socket.fileno()
def fileno(self, connection=None):
if connection not in list(self.connections.keys()):
if connection is None:
connection = list(self.connections.keys())[0]
else:
raise EOFError('Connection not in connected devices')
return self.connections[connection].fileno
@property
def closed(self):
return not bool(len(self.connections))
@ -397,7 +440,10 @@ class UTCP(object):
try:
if connection not in list(self.connections.keys()):
if connection is None:
connection = list(self.connections.keys())[0]
if len(self.connections):
connection = list(self.connections.keys())[0]
else:
return
else:
raise EOFError('Connection not in connected devices')
fin = Fin()
@ -421,8 +467,7 @@ class UTCP(object):
self.__send_packet(connection, fin_ack)
except:
pass
with self.connection_lock:
self.connections.pop(connection)
self.drop_connection(connection)
except Exception as error:
raise EOFError('Something went wrong in disconnect func:%s ' % error)
@ -445,14 +490,17 @@ class UTCP(object):
if address[0] == 'Any':
order = self.syn_received.popitem() # to reverse the tuple received
return order[1], order[0]
conn = self.connections[address]
try:
conn = self.connections[address]
except:
break
if condition in ['ACK', 'SYN-ACK', 'FIN-ACK']:
tries += 1
if condition == 'DATA':
if len(conn.packet_buffer[condition]):
if self.poll(0.1, address):
data = b''
while size:
if not self.poll(0.5):
if not self.poll(0.1, address):
continue
with conn.recv_lock:
packet = conn.packet_buffer[condition][0]
@ -470,11 +518,10 @@ class UTCP(object):
else:
raise KeyError
else:
with conn.recv_lock:
if len(conn.packet_buffer[condition]):
packet = conn.packet_buffer[condition].pop()
else:
raise KeyError
if self.packets_arrived(condition, address):
packet = conn.packet_buffer[condition].pop()
else:
raise KeyError
if want_id and packet.id != want_id:
raise KeyError
return packet
@ -513,6 +560,8 @@ class UTCP(object):
if isinstance(packet, Fin):
self.disconnect(address, packet.id)
elif isinstance(packet, Syn):
if address in self.connections:
return
if packet.id not in map(lambda x: x.id, self.syn_received.values()):
self.syn_received[address] = packet
else: