pydtls/dtls/wrapper.py

391 lines
13 KiB
Python
Raw Permalink Normal View History

# -*- coding: utf-8 -*-
# DTLS Socket: A wrapper for a server and client using a DTLS connection.
# Copyright 2017 Björn Freise
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# The License is also distributed with this work in the file named "LICENSE."
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DTLS Socket
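
This wrapper encapsulates the state and behavior associated with the connection
between the OpenSSL library and an individual peer when using the DTLS
protocol.

Functions:

  wrap_client -- create a client-side DtlsSocket around an existing socket
  wrap_server -- create a server-side DtlsSocket around an existing socket

Classes:

  DtlsSocket -- DTLS Socket wrapper for use as a client or server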
"""
import select
import socket
import ssl
import time
from logging import getLogger

# NOTE: do_patch() stays ahead of the remaining imports and definitions, as in
# the original ordering; the definitions below read DTLS-specific names from
# the ssl module (e.g. ssl.PROTOCOL_DTLSv1_2) in their default arguments.
from patch import do_patch
do_patch()
from sslconnection import SSLContext, SSL
import err as err_codes
_logger = getLogger(__name__)
def wrap_client(sock, keyfile=None, certfile=None,
                cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
                do_handshake_on_connect=True, suppress_ragged_eofs=True,
                ciphers=None, curves=None, sigalgs=None, user_mtu=None):
    """Build a client-side DtlsSocket around *sock*.

    Every argument is handed to the DtlsSocket constructor under the same
    name. The server-only settings are pinned: ``server_side`` is False,
    ``server_key_exchange_curve`` is None and ``server_cert_options`` is
    ``ssl.SSL_BUILD_CHAIN_FLAG_NONE``.

    Returns the new DtlsSocket instance.
    """
    options = {
        'sock': sock,
        'keyfile': keyfile,
        'certfile': certfile,
        'server_side': False,
        'cert_reqs': cert_reqs,
        'ssl_version': ssl_version,
        'ca_certs': ca_certs,
        'do_handshake_on_connect': do_handshake_on_connect,
        'suppress_ragged_eofs': suppress_ragged_eofs,
        'ciphers': ciphers,
        'curves': curves,
        'sigalgs': sigalgs,
        'user_mtu': user_mtu,
        # Server-only knobs are fixed for a client socket.
        'server_key_exchange_curve': None,
        'server_cert_options': ssl.SSL_BUILD_CHAIN_FLAG_NONE,
    }
    return DtlsSocket(**options)
def wrap_server(sock, keyfile=None, certfile=None,
                cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
                do_handshake_on_connect=False, suppress_ragged_eofs=True,
                ciphers=None, curves=None, sigalgs=None, user_mtu=None,
                server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE,
                client_timeout=None):
    """Build a server-side DtlsSocket around *sock*.

    Every argument is handed to the DtlsSocket constructor under the same
    name, with ``server_side`` fixed to True.

    ``client_timeout`` is forwarded unchanged to the constructor's
    ``client_timeout`` parameter; it defaults to None, so existing callers
    are unaffected.

    Returns the new DtlsSocket instance.
    """
    return DtlsSocket(sock=sock,
                      keyfile=keyfile,
                      certfile=certfile,
                      server_side=True,
                      cert_reqs=cert_reqs,
                      ssl_version=ssl_version,
                      ca_certs=ca_certs,
                      do_handshake_on_connect=do_handshake_on_connect,
                      suppress_ragged_eofs=suppress_ragged_eofs,
                      ciphers=ciphers,
                      curves=curves,
                      sigalgs=sigalgs,
                      user_mtu=user_mtu,
                      server_key_exchange_curve=server_key_exchange_curve,
                      server_cert_options=server_cert_options,
                      client_timeout=client_timeout)
class DtlsSocket(object):
    """Datagram-style socket wrapper that runs DTLS as a client or a server.

    The underlying socket is wrapped with ``ssl.wrap_socket`` and kept in
    ``self._sock``; attributes that are not defined on this class are looked
    up on that wrapped socket (see ``__getattr__``).

    Client side: ``sendto`` connects the wrapped socket on first use and
    ``recvfrom`` reads from it.

    Server side: ``recvfrom`` multiplexes the wrapped socket and every
    accepted client connection with ``select``. Each accepted connection is
    a key of ``self._clients``; the value is its ``_ClientSession``.
    """

    class _ClientSession(object):
        """Per-peer record kept by a server-side socket.

        Holds the peer address, whether its handshake has completed, and an
        optional inactivity timeout (compared against ``time.time()``
        differences, i.e. seconds).
        """

        def __init__(self, host, port, handshake_done=False, timeout=None):
            self.host = host
            self.port = int(port)
            self.handshake_done = handshake_done
            self.timeout = timeout
            # Only meaningful when a timeout is configured; expired() does
            # not read it otherwise.
            self.last_update = None
            self.updateTimestamp()

        def getAddr(self):
            """Return the peer address as a ``(host, port)`` tuple."""
            return self.host, self.port

        def updateTimestamp(self):
            """Record "now" as the last activity time, if a timeout is set."""
            if self.timeout is not None:
                self.last_update = time.time()

        def expired(self):
            """Return True when the session has been idle longer than its timeout.

            Always False when no timeout is configured.
            """
            if self.timeout is None:
                return False
            return (time.time() - self.last_update) > self.timeout

    def __init__(self,
                 sock=None,
                 keyfile=None,
                 certfile=None,
                 server_side=False,
                 cert_reqs=ssl.CERT_NONE,
                 ssl_version=ssl.PROTOCOL_DTLSv1_2,
                 ca_certs=None,
                 do_handshake_on_connect=False,
                 suppress_ragged_eofs=True,
                 ciphers=None,
                 curves=None,
                 sigalgs=None,
                 user_mtu=None,
                 server_key_exchange_curve=None,
                 server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE,
                 client_timeout=None):
        """Wrap *sock* (or a fresh UDP socket) for DTLS.

        :param sock: socket to wrap; when it is not a ``socket.socket``
            instance a new ``AF_INET``/``SOCK_DGRAM`` socket is created.
        :param keyfile, certfile, cert_reqs, ssl_version, ca_certs,
            do_handshake_on_connect, suppress_ragged_eofs, ciphers:
            passed through to ``ssl.wrap_socket``.
        :param server_side: selects server (True) or client (False) behavior
            for ``recvfrom``, ``sendto``, ``settimeout`` and ``close``.
        :param curves, sigalgs: applied to the SSL context in
            ``user_config_ssl_ctx`` when set.
        :param user_mtu: applied via ``set_link_mtu`` in ``user_config_ssl``
            when set.
        :param server_key_exchange_curve: passed to ``set_ecdh_curve`` on
            the server side.
        :param server_cert_options: flags for ``build_cert_chain`` on the
            server side; None is replaced by ``ssl.SSL_BUILD_CHAIN_FLAG_NONE``.
        :param client_timeout: inactivity timeout given to every
            ``_ClientSession`` created for an accepted client; None disables
            session expiry.
        """
        if server_cert_options is None:
            server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE

        # These must be in place before ssl.wrap_socket() runs, because the
        # configuration callbacks handed to it read them.
        self._ssl_logging = False
        self._server_side = server_side
        self._ciphers = ciphers
        self._curves = curves
        self._sigalgs = sigalgs
        self._user_mtu = user_mtu
        self._server_key_exchange_curve = server_key_exchange_curve
        self._server_cert_options = server_cert_options
        self._client_timeout = client_timeout

        # Create the default socket only when the caller did not supply one,
        # so no descriptor is opened just to be thrown away.
        if isinstance(sock, socket.socket):
            _sock = sock
        else:
            _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        self._sock = ssl.wrap_socket(_sock,
                                     keyfile=keyfile,
                                     certfile=certfile,
                                     server_side=self._server_side,
                                     cert_reqs=cert_reqs,
                                     ssl_version=ssl_version,
                                     ca_certs=ca_certs,
                                     do_handshake_on_connect=do_handshake_on_connect,
                                     suppress_ragged_eofs=suppress_ragged_eofs,
                                     ciphers=self._ciphers,
                                     cb_user_config_ssl_ctx=self.user_config_ssl_ctx,
                                     cb_user_config_ssl=self.user_config_ssl)

        if self._server_side:
            self._clients = {}
            self._timeout = None

    def __getattr__(self, item):
        """Delegate unknown attributes to the wrapped socket.

        Raises AttributeError (naming the attribute) when the wrapped socket
        is not set yet or does not provide *item*.
        """
        # Read _sock straight from the instance dict: hasattr(self, "_sock")
        # would re-enter __getattr__ and recurse while _sock is still unset.
        wrapped = self.__dict__.get("_sock")
        if wrapped is not None and hasattr(wrapped, item):
            return getattr(wrapped, item)
        raise AttributeError("%r object has no attribute %r" % (type(self).__name__, item))

    def user_config_ssl_ctx(self, _ctx):
        """Configure the SSL context; passed to ``ssl.wrap_socket`` as
        ``cb_user_config_ssl_ctx``.

        :param SSLContext _ctx:
        """
        _ctx.set_ssl_logging(self._ssl_logging)
        if self._ciphers:
            _ctx.set_ciphers(self._ciphers)
        if self._curves:
            _ctx.set_curves(self._curves)
        if self._sigalgs:
            _ctx.set_sigalgs(self._sigalgs)
        if self._server_side:
            _ctx.build_cert_chain(flags=self._server_cert_options)
            _ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve)

    def user_config_ssl(self, _ssl):
        """Configure the SSL object; passed to ``ssl.wrap_socket`` as
        ``cb_user_config_ssl``.

        :param SSL _ssl:
        """
        if self._user_mtu:
            _ssl.set_link_mtu(self._user_mtu)

    def settimeout(self, t):
        """Set the receive timeout.

        On the server side the value is used as the ``select`` timeout in
        ``recvfrom``; on the client side it is set on the wrapped socket.
        """
        if self._server_side:
            self._timeout = t
        else:
            self._sock.settimeout(t)

    def close(self):
        """Close the wrapped socket and, on the server side, all client connections."""
        try:
            if self._server_side:
                # Empty the registry first so no closed connection is ever
                # handed to select() again.
                connections = list(self._clients)
                self._clients.clear()
                for cli in connections:
                    try:
                        cli.close()
                    except Exception as e:
                        # Best effort: one failing close must not keep the
                        # remaining connections open.
                        _logger.debug('Ignoring error while closing client connection: %s', e)
            else:
                try:
                    self._sock.unwrap()
                except Exception as e:
                    # Best effort: the shutdown may fail (e.g. the socket was
                    # never connected); the socket is closed regardless.
                    _logger.debug('Ignoring error during unwrap: %s', e)
        finally:
            self._sock.close()

    def recvfrom(self, bufsize, flags=0):
        """Receive one datagram.

        Returns ``(data, (host, port))``. Raises ``socket.timeout`` when no
        application data was received.
        """
        if self._server_side:
            return self._recvfrom_on_server_side(bufsize, flags=flags)
        return self._recvfrom_on_client_side(bufsize, flags=flags)

    def _recvfrom_on_server_side(self, bufsize, flags):
        """Service all readable sockets once and return the first payload read.

        New peers are accepted, pending handshakes are advanced, and data is
        read from established clients. Afterwards every client connection is
        checked for a pending retransmission and for session expiry.

        Any exception raised while servicing a readable socket is re-raised
        with a ``peer`` attribute holding that socket's peer name (None when
        the socket is not connected). Raises ``socket.timeout`` when no
        payload was read.
        """
        # select() signals a timeout by returning empty lists, not by raising.
        readable, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)

        last_peer = None
        try:
            for conn in readable:
                # Reset per connection so a failure never reports a stale peer.
                last_peer = None
                if conn._connected:
                    last_peer = conn.getpeername()

                if self._sockIsServerSock(conn):
                    # Connect
                    self._clientAccept(conn)
                elif not self._clientHandshakeDone(conn):
                    # Handshake
                    self._clientDoHandshake(conn)
                else:
                    # Normal read
                    buf = self._clientRead(conn, bufsize)
                    if buf:
                        if conn in self._clients:
                            self._clients[conn].updateTimestamp()
                            return buf, self._clients[conn].getAddr()
                        _logger.debug('Received data from an already disconnected client!')
        except Exception as e:
            e.peer = last_peer
            raise

        # _getClientReadingSockets() returns a copy, so dropping a client
        # while looping is safe.
        for conn in self._getClientReadingSockets():
            if conn.get_timeout():
                ret = conn.handle_timeout()
                _logger.debug('Retransmission triggered for %s: %d',
                              self._clients[conn].getAddr(), ret)
            if self._clients[conn].expired():
                _logger.debug('Found expired session')
                self._clientDrop(conn)

        # __No_data__ received from any client
        raise socket.timeout

    def _recvfrom_on_client_side(self, bufsize, flags):
        """Read once from the wrapped socket.

        Returns ``(data, peername)``. A read timeout or want-read condition,
        as well as an empty read, results in ``socket.timeout``; any other
        ``ssl.SSLError`` propagates.
        """
        try:
            buf = self._sock.recv(bufsize, flags)
        except ssl.SSLError as e:
            if not (e.errno == ssl.ERR_READ_TIMEOUT or self._isWantRead(e)):
                raise
        else:
            if buf:
                return buf, self._sock.getpeername()

        # __No_data__ received from the peer
        raise socket.timeout

    def sendto(self, buf, address):
        """Send *buf* to *address* and return the value of the underlying send.

        On the server side 0 is returned when *address* matches no tracked
        client.
        """
        if self._server_side:
            return self._sendto_from_server_side(buf, address)
        return self._sendto_from_client_side(buf, address)

    def _sendto_from_server_side(self, buf, address):
        """Write *buf* to the tracked client whose address equals *address*."""
        for conn, client in list(self._clients.items()):
            if client.getAddr() == address:
                return self._clientWrite(conn, buf)
        return 0

    def _sendto_from_client_side(self, buf, address):
        """Connect to *address* if not yet connected, then send *buf*."""
        if not self._sock._connected:
            self._sock.connect(address)
        return self._sock.send(buf)

    def _getClientReadingSockets(self):
        """Return a new list of all tracked client connections."""
        return list(self._clients)

    def _getAllReadingSockets(self):
        """Return the wrapped socket followed by all tracked client connections."""
        return [self._sock] + self._getClientReadingSockets()

    def _sockIsServerSock(self, conn):
        """Return True when *conn* is the wrapped (listening) socket itself."""
        return conn is self._sock

    def _clientHandshakeDone(self, conn):
        """Return True when *conn* is tracked and its handshake has completed."""
        return conn in self._clients and self._clients[conn].handshake_done is True

    @staticmethod
    def _isWantRead(error):
        """Return True when the first argument of *error* is ``ssl.SSL_ERROR_WANT_READ``."""
        return bool(error.args) and error.args[0] == ssl.SSL_ERROR_WANT_READ

    def _clientAccept(self, conn):
        """Accept a peer on *conn*, register a session for it and start its handshake.

        Does nothing when ``accept()`` returns a falsy value. Raises
        ValueError when the accepted connection is already tracked.
        """
        _logger.debug('+' * 60)
        ret = conn.accept()
        _logger.debug('Accept returned with ... %s', ret)
        if not ret:
            return

        client, addr = ret
        host, port = addr
        if client in self._clients:
            _logger.debug('Client already connected %s', client)
            raise ValueError('Client already connected %s' % str(client))

        # Hand the configured inactivity timeout to the session; without it
        # expired() can never report True.
        self._clients[client] = self._ClientSession(host=host, port=port,
                                                    timeout=self._client_timeout)
        self._clientDoHandshake(client)

    def _clientDoHandshake(self, conn):
        """Advance the handshake on *conn* without blocking.

        A handshake timeout or want-read condition leaves the session
        pending. Any other ``ssl.SSLError`` drops the client and is re-raised.
        """
        _logger.debug('-' * 60)
        conn.setblocking(False)
        try:
            conn.do_handshake()
        except ssl.SSLError as e:
            if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or self._isWantRead(e):
                # Not finished yet; a later recvfrom() call tries again.
                return
            self._clientDrop(conn, error=e)
            raise

        _logger.debug('Connection from %s successful', self._clients[conn].getAddr())
        self._clients[conn].handshake_done = True

    def _clientRead(self, conn, bufsize=4096):
        """Read up to *bufsize* bytes from *conn*.

        Returns the data, or None when the read raised ``ssl.SSLError``; an
        error other than want-read also drops the client.
        """
        _logger.debug('*' * 60)
        ret = None
        try:
            ret = conn.recv(bufsize)
            _logger.debug('From client %s ... bytes received %s',
                          self._clients[conn].getAddr(), len(ret))
        except ssl.SSLError as e:
            if not self._isWantRead(e):
                self._clientDrop(conn, error=e)
        return ret

    def _clientWrite(self, conn, data):
        """Send *data* on *conn* and return the result of ``conn.send``."""
        _logger.debug('#' * 60)
        ret = conn.send(data)
        _logger.debug('To client %s ... bytes sent %s', self._clients[conn].getAddr(), ret)
        return ret

    def _clientDrop(self, conn, error=None):
        """Forget *conn* and shut it down; never raises.

        The connection is unwrapped and closed even when it is not (or no
        longer) tracked in ``self._clients``.
        """
        _logger.debug('$' * 60)
        # pop() with a default tolerates an untracked connection, so the
        # shutdown below is reached in every case.
        session = self._clients.pop(conn, None)
        addr = session.getAddr() if session is not None else None
        if error:
            _logger.debug('Drop client %s ... with error: %s', addr, error)
        else:
            _logger.debug('Drop client %s', addr)

        try:
            conn.unwrap()
        except Exception:
            # Best effort: the connection is being discarded anyway.
            pass
        try:
            conn.close()
        except Exception as e:
            _logger.debug('Ignoring error while closing dropped client: %s', e)