# pydtls/dtls/wrapper.py (287 lines, 9.0 KiB, Python)
# -*- encoding: utf-8 -*-
import datetime
import select
from logging import getLogger
import ssl
import socket
from dtls import do_patch
# Apply the dtls package's patch at import time, before the classes below are
# defined. NOTE(review): presumably this is what adds the DTLS protocol
# constants (e.g. ssl.PROTOCOL_DTLSv1_2) and the extra wrap_socket() keyword
# arguments used further down -- confirm against the dtls package.
do_patch()
# Module-level logger, named after this module.
_logger = getLogger(__name__)
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
class DtlsSocket(object):
    """Datagram-style ``recvfrom``/``sendto`` facade over a wrapped UDP socket.

    In client mode the instance owns a single socket connected to
    ``(host, port)``. In server mode it owns a listening socket bound to
    ``(host, port)`` plus one accepted socket per peer, tracked in
    ``self._clients`` (socket -> ``_ClientSession``), and multiplexes them with
    ``select``.
    """

    def __init__(self,
                 host,
                 port,
                 keyfile=None,
                 certfile=None,
                 server_side=False,
                 cert_reqs=ssl.CERT_NONE,
                 ssl_version=ssl.PROTOCOL_DTLSv1_2,
                 ca_certs=None,
                 do_handshake_on_connect=False,
                 suppress_ragged_eofs=True,
                 ciphers=None,
                 curves=None,
                 sigalgs=None,
                 user_mtu=None):
        """Create the wrapped UDP socket and bind (server) or connect (client).

        :param host: address to bind to (server) or connect to (client).
        :param port: port to bind/connect to; coerced with ``int()``.
        :param keyfile: forwarded to ``ssl.wrap_socket``.
        :param certfile: forwarded to ``ssl.wrap_socket``.
        :param server_side: ``True`` for a listening server, ``False`` for a
            client.
        :param cert_reqs: forwarded to ``ssl.wrap_socket``.
        :param ssl_version: forwarded to ``ssl.wrap_socket``.
        :param ca_certs: forwarded to ``ssl.wrap_socket``.
        :param do_handshake_on_connect: forwarded to ``ssl.wrap_socket``.
        :param suppress_ragged_eofs: accepted for signature compatibility;
            it is not forwarded to ``ssl.wrap_socket`` by this class.
        :param ciphers: optional cipher selection, applied in
            ``user_ssl_ctx_config`` and also forwarded to ``wrap_socket``.
        :param curves: optional curve selection, applied in
            ``user_ssl_ctx_config``.
        :param sigalgs: optional signature-algorithm selection, applied in
            ``user_ssl_ctx_config``.
        :param user_mtu: optional MTU, applied in ``user_ssl_config``.
        """
        # These attributes are read by the configuration callbacks handed to
        # wrap_socket() below, so they must exist before that call.
        self._ssl_logging = False
        self._peer = (host, int(port))
        self._server_side = server_side
        self._ciphers = ciphers
        self._curves = curves
        self._sigalgs = sigalgs
        self._user_mtu = user_mtu

        self._sock = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                                     keyfile=keyfile,
                                     certfile=certfile,
                                     server_side=self._server_side,
                                     cert_reqs=cert_reqs,
                                     ssl_version=ssl_version,
                                     ca_certs=ca_certs,
                                     do_handshake_on_connect=do_handshake_on_connect,
                                     ciphers=self._ciphers,
                                     cb_user_ssl_ctx_config=self.user_ssl_ctx_config,
                                     cb_user_ssl_config=self.user_ssl_config)

        if self._server_side:
            # accepted socket -> _ClientSession
            self._clients = {}
            # select() timeout used by recvfrom(); None means block.
            self._timeout = None
            self._sock.bind(self._peer)
            self._sock.listen(0)
        else:
            self._sock.connect(self._peer)

    def user_ssl_ctx_config(self, _ctx):
        """Callback: apply the user's settings to the context object ``_ctx``.

        Passed to ``ssl.wrap_socket`` as ``cb_user_ssl_ctx_config``.
        """
        _ctx.set_ssl_logging(self._ssl_logging)
        if self._ciphers:
            _ctx.set_ciphers(self._ciphers)
        if self._curves:
            _ctx.set_curves(self._curves)
        if self._sigalgs:
            _ctx.set_sigalgs(self._sigalgs)
        if self._server_side:
            _ctx.build_cert_chain()
            # Called without an explicit curve name (e.g. "secp256k1").
            _ctx.set_ecdh_curve()

    def user_ssl_config(self, _ssl):
        """Callback: apply the user-supplied MTU, if any, to ``_ssl``.

        Passed to ``ssl.wrap_socket`` as ``cb_user_ssl_config``.
        """
        if self._user_mtu:
            _ssl.set_mtu(self._user_mtu)

    def settimeout(self, t):
        """Set the receive timeout.

        On the server side the value is stored and used as the ``select``
        timeout in ``recvfrom``; on the client side it is set on the socket.
        """
        if self._server_side:
            self._timeout = t
        else:
            self._sock.settimeout(t)

    def close(self):
        """Close all sockets owned by this instance.

        Server side: close every accepted client socket and forget it.
        Client side: ``unwrap()`` the connection first. The main socket is
        closed in every case, even if the steps before it raise (the
        exception still propagates to the caller).
        """
        try:
            if self._server_side:
                # Iterate over a snapshot so the dict can be emptied safely.
                for cli in list(self._clients):
                    cli.close()
                self._clients.clear()
            else:
                self._sock.unwrap()
        finally:
            self._sock.close()

    def recvfrom(self, bufsize, flags=0):
        """Receive one datagram.

        :returns: ``(data, (host, port))``.
        :raises socket.timeout: when no application data was received.
        """
        if self._server_side:
            return self._recvfrom_on_server_side(bufsize, flags=flags)
        return self._recvfrom_on_client_side(bufsize, flags=flags)

    def _recvfrom_on_server_side(self, bufsize, flags):
        """Wait for activity on any socket and service it.

        New peers are accepted, pending handshakes are advanced, and the first
        non-empty payload read from an established client is returned together
        with that client's address. ``flags`` is not used on this path.

        :raises socket.timeout: if no payload was returned from this call.
        """
        readable, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)

        for conn in readable:  # type: ssl.SSLSocket
            if self._sockIsServerSock(conn):
                # Connect
                self._clientAccept(conn)
            elif not self._clientHandshakeDone(conn):
                # Handshake
                self._clientDoHandshake(conn)
            else:
                # Normal read; the read may have dropped the client, hence
                # the membership re-check.
                buf = self._clientRead(conn, bufsize)
                if buf and conn in self._clients:
                    return buf, self._clients[conn].getAddr()

        for conn in self._getClientReadingSockets():
            # NOTE(review): if get_timeout() returns a timedelta, an already
            # expired timer (timedelta(0)) is falsy and is skipped here --
            # confirm against the pydtls API whether that is intended.
            if conn.get_timeout():
                conn.handle_timeout()

        raise socket.timeout

    def _recvfrom_on_client_side(self, bufsize, flags):
        """Read one datagram from the connected peer.

        :returns: ``(data, peer)``; ``('', peer)`` when the read fails with
            ``SSL_ERROR_ZERO_RETURN``.
        :raises ssl.SSLError: for ``SSL_ERROR_SSL`` / ``SSL_ERROR_SYSCALL``.
        :raises socket.timeout: when nothing was read (empty result or any
            other SSL error such as ``SSL_ERROR_WANT_READ``).
        """
        try:
            buf = self._sock.recv(bufsize, flags)
        except ssl.SSLError as e_ssl:
            code = e_ssl.args[0]
            if code == ssl.SSL_ERROR_ZERO_RETURN:
                return '', self._peer
            if code in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
                raise
            # Anything else (like SSL_ERROR_WANT_READ, ...) falls through to
            # the timeout below.
        else:
            if buf:
                return buf, self._peer

        raise socket.timeout

    def sendto(self, buf, address):
        """Send ``buf`` to ``address`` and return the number of bytes sent."""
        if self._server_side:
            return self._sendto_from_server_side(buf, address)
        return self._sendto_from_client_side(buf, address)

    def _sendto_from_server_side(self, buf, address):
        """Send ``buf`` to the client whose address equals ``address``.

        :returns: result of the write, or ``0`` if no such client is known.
        """
        # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
        for conn, client in self._clients.items():
            if client.getAddr() == address:
                return self._clientWrite(conn, buf)
        return 0

    def _sendto_from_client_side(self, buf, address):
        """Send ``buf`` to the connected peer, retrying on write timeouts.

        ``address`` is ignored: the socket is already connected to the peer.

        :raises ssl.SSLError: for any SSL error other than the "503:" timeout.
        """
        # NOTE(review): the loop only ends once send() reports a non-zero
        # count, so a send() that keeps returning 0 would spin here.
        while True:
            try:
                bytes_sent = self._sock.send(buf)
            except ssl.SSLError as e_ssl:
                if str(e_ssl).startswith("503:"):
                    # The write operation timed out
                    continue
                raise
            else:
                if bytes_sent:
                    break
        return bytes_sent

    def _getClientReadingSockets(self):
        """Return a list of all accepted client sockets."""
        return list(self._clients)

    def _getAllReadingSockets(self):
        """Return the listening socket followed by all client sockets."""
        return [self._sock] + self._getClientReadingSockets()

    def _sockIsServerSock(self, conn):
        """Return ``True`` if ``conn`` is this instance's main socket."""
        return conn is self._sock

    def _clientHandshakeDone(self, conn):
        """Return ``True`` if ``conn`` is a known client with a done handshake."""
        session = self._clients.get(conn)
        return session is not None and session.handshake_done is True

    def _clientAccept(self, conn):
        """Accept a pending peer on ``conn`` and start its handshake.

        Failures of ``accept()`` itself are logged and ignored (best effort).

        :raises RuntimeError: if the accepted socket is already registered.
        """
        _logger.debug('+' * 60)
        try:
            ret = conn.accept()
            _logger.debug('Accept returned with ... %s' % (str(ret)))
        except Exception as e_accept:
            # Deliberately best-effort: a failed accept must not break the
            # receive loop, but it should at least be visible in the log.
            _logger.debug('Accept failed: %s', e_accept)
            return

        if not ret:
            return

        client, addr = ret
        host, port = addr
        if client in self._clients:
            # A bare ``raise`` here had no active exception to re-raise and
            # failed with an unrelated error; state the problem explicitly.
            raise RuntimeError('accepted socket is already registered as a client')
        self._clients[client] = _ClientSession(host=host, port=port)
        self._clientDoHandshake(client)

    def _clientDoHandshake(self, conn):
        """Advance the non-blocking handshake on ``conn``.

        Marks the session as done when ``do_handshake()`` succeeds.

        :raises ssl.SSLError: for errors other than a "504:" error or
            ``SSL_ERROR_WANT_READ``.
        """
        _logger.debug('-' * 60)
        conn.setblocking(False)
        try:
            conn.do_handshake()
            _logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
            self._clients[conn].handshake_done = True
        except ssl.SSLError as e_handshake:
            # "504:" (presumably the dtls handshake-timeout code -- confirm)
            # and WANT_READ both mean "not finished yet, try again later".
            still_pending = (str(e_handshake).startswith("504:") or
                             e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ)
            if not still_pending:
                raise

    def _clientRead(self, conn, bufsize=4096):
        """Read up to ``bufsize`` bytes from client ``conn``.

        The client is dropped on ``SSL_ERROR_ZERO_RETURN``, ``SSL_ERROR_SSL``
        and ``SSL_ERROR_SYSCALL``; other SSL errors are ignored.

        :returns: the received data, or ``None`` if an SSL error occurred.
        """
        _logger.debug('*' * 60)
        ret = None
        try:
            ret = conn.recv(bufsize)
            _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
        except ssl.SSLError as e_read:
            code = e_read.args[0]
            if code == ssl.SSL_ERROR_ZERO_RETURN:
                self._clientDrop(conn)
            elif code in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
                self._clientDrop(conn, error=e_read)
            # else: like SSL_ERROR_WANT_READ, ... -- nothing to do.
        return ret

    def _clientWrite(self, conn, data):
        """Send ``data`` to client ``conn`` and return the send result.

        ``data`` may be an object exposing its bytes as ``.raw`` (as the
        original code required) or a plain bytes-like object, matching what
        the client-side send path accepts.
        """
        _logger.debug('#' * 60)
        payload = getattr(data, 'raw', data)
        ret = conn.send(payload)
        _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
        return ret

    def _clientDrop(self, conn, error=None):
        """Forget client ``conn`` and shut its socket down (never raises).

        The socket is always closed, even if it was not registered or if
        ``unwrap()`` fails.

        :param error: optional exception that caused the drop; only logged.
        """
        _logger.debug('$' * 60)
        session = self._clients.pop(conn, None)
        if session is not None:
            if error:
                _logger.debug('Drop client %s ... with error: %s' % (session.getAddr(), error))
            else:
                _logger.debug('Drop client %s' % str(session.getAddr()))
        try:
            conn.unwrap()
        except Exception as e_drop:
            # Best effort: the peer may already be gone.
            _logger.debug('Unwrap failed while dropping client: %s', e_drop)
        finally:
            try:
                conn.close()
            except Exception as e_close:
                _logger.debug('Close failed while dropping client: %s', e_close)