diff --git a/ChangeLog b/ChangeLog index 69577c4..b08cbf0 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,11 @@ +2017-03-17 Björn Freise + + Added a wrapper for a DTLS-Socket either as client or server - including unit tests + + * dtls/__init__.py: Import SSLContext() and SSL() for external use + * dtls/wrapper.py: Added class DtlsSocket() to be used as client or server + * dtls/test/unit_wrapper.py: unit test for DtlsSocket() + 2017-03-17 Björn Freise Added more on error evaluation and a method to get the peer certificate chain diff --git a/dtls/__init__.py b/dtls/__init__.py index b5c5517..e263633 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -59,6 +59,6 @@ def _prep_bins(): _prep_bins() # prepare before module imports from patch import do_patch -from sslconnection import SSLConnection +from sslconnection import SSLContext, SSL, SSLConnection from demux import force_routing_demux, reset_default_demux import err as error_codes diff --git a/dtls/test/unit_wrapper.py b/dtls/test/unit_wrapper.py new file mode 100644 index 0000000..582e3e8 --- /dev/null +++ b/dtls/test/unit_wrapper.py @@ -0,0 +1,654 @@ +# -*- encoding: utf-8 -*- + +# Test the support for DTLS through the SSL module. Adapted from the Python +# standard library's test_ssl.py regression test module by Björn Freise. + +import unittest +import threading +import sys +import socket +import os +import pprint + +from logging import basicConfig, DEBUG, getLogger +# basicConfig(level=DEBUG, format="%(asctime)s - %(threadName)-10s - %(name)s - %(levelname)s - %(message)s") +_logger = getLogger(__name__) + +import ssl +from dtls import do_patch, error_codes +from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT + + +HOST = "localhost" +CHATTY = True +CHATTY_CLIENT = True + + +class ThreadedEchoServer(threading.Thread): + + def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None, + ciphers=None, curves=None, sigalgs=None, + mtu=None, server_key_exchange_curve=None, server_cert_options=None, + chatty=True): + + if ssl_version is None: + ssl_version = ssl.PROTOCOL_DTLSv1 + if certreqs is None: + certreqs = ssl.CERT_NONE + + self.certificate = certificate + self.protocol = ssl_version + self.certreqs = certreqs + self.cacerts = cacerts + self.ciphers = ciphers + self.curves = curves + self.sigalgs = sigalgs + self.mtu = mtu + self.server_key_exchange_curve = server_key_exchange_curve + self.server_cert_options = server_cert_options + self.chatty = chatty + + self.flag = None + + self.sock = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + keyfile=self.certificate, + certfile=self.certificate, + server_side=True, + cert_reqs=self.certreqs, + ssl_version=self.protocol, + ca_certs=self.cacerts, + ciphers=self.ciphers, + curves=self.curves, + sigalgs=self.sigalgs, + user_mtu=self.mtu, + server_key_exchange_curve=self.server_key_exchange_curve, + server_cert_options=self.server_cert_options) + + if self.chatty: + sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock)) + self.sock.bind((HOST, 0)) + self.port = self.sock.getsockname()[1] + self.active = False + threading.Thread.__init__(self) + self.daemon = True + + def start(self, flag=None): + self.flag = flag + self.starter = threading.current_thread().ident + threading.Thread.start(self) + + def run(self): + self.sock.settimeout(0.05) + self.sock.listen(0) + self.active = True + if self.flag: + # signal an event + self.flag.set() + while self.active: + try: + acc_ret = self.sock.recvfrom(4096) + if acc_ret: + newdata, connaddr = acc_ret + if self.chatty: + sys.stdout.write(' server: new data from ' + str(connaddr) + '\n') + self.sock.sendto(newdata.lower(), connaddr) + except socket.timeout: + pass + except KeyboardInterrupt: + self.stop() + except Exception as e: + if self.chatty: + sys.stdout.write(' server: error ' + str(e) + '\n') + pass + if self.chatty: + sys.stdout.write(' server: closing socket as %s\n' % str(self.sock)) + self.sock.close() + + def stop(self): + self.active = False + if self.starter != threading.current_thread().ident: + return + self.join() # don't allow spawning new handlers after we've checked + + +CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "keycert.pem") +CERTFILE_EC = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "keycert_ec.pem") +ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "ca-cert.pem") +ISSUER_CERTFILE_EC = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "ca-cert_ec.pem") + +# certfile, protocol, certreqs, cacertsfile, +# ciphers=None, curves=None, sigalgs=None, +tests = [ + {'testcase': + {'name': 'standard dtls v1', + 'desc': 'Standard DTLS v1 test with out-of-the box configuration and RSA certificate', + 'start_server': True}, + 'input': + {'certfile': CERTFILE, + 'protocol': ssl.PROTOCOL_DTLSv1, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': True, + 'error_code': None, + 'exception': None}}, + {'testcase': + {'name': 'standard dtls v1_2', + 'desc': 'Standard DTLS v1_2 test with out-of-the box configuration and ECDSA certificate', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': True, + 'error_code': None, + 'exception': None}}, + {'testcase': + {'name': 'protocol version mismatch', + 'desc': 'Client and server have different protocol versions', + 'start_server': True}, + 'input': + {'certfile': CERTFILE, + 'protocol': ssl.PROTOCOL_DTLSv1, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': False, + 'error_code': error_codes.ERR_WRONG_SSL_VERSION, + 'exception': None}}, + {'testcase': + {'name': 'certificate verify fails', + 'desc': 'Server certificate cannot be verified by client', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': False, + 'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED, + 'exception': None}}, + {'testcase': + {'name': 'no matching curve', + 'desc': 'Client doesn\'t support curve used by server ECDSA certificate', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': None, + 'client_curves': 'secp384r1', + 'client_sigalgs': None}, + 'result': + {'ret_success': False, + 'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, + 'exception': None}}, + {'testcase': + {'name': 'matching curve', + 'desc': '', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': None, + 'client_curves': 'prime256v1', + 'client_sigalgs': None}, + 'result': + {'ret_success': True, + 'error_code': None, + 'exception': None}}, + {'testcase': + {'name': 'no host', + 'desc': 'No server port is listening', + 'start_server': False}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': False, + 'error_code': error_codes.ERR_PORT_UNREACHABLE, + 'exception': None}}, + {'testcase': + {'name': 'no matching sigalgs', + 'desc': '', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': "RSA+SHA256"}, + 'result': + {'ret_success': False, + 'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, + 'exception': None}}, + {'testcase': + {'name': 'matching sigalgs', + 'desc': '', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': None, + 'client_curves': None, + 'client_sigalgs': "ECDSA+SHA256"}, + 'result': + {'ret_success': True, + 'error_code': None, + 'exception': None}}, + {'testcase': + {'name': 'no matching cipher', + 'desc': 'Server using a ECDSA certificate while client is only able to use RSA encryption', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': "AES256-SHA", + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': False, + 'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE, + 'exception': None}}, + {'testcase': + {'name': 'matching cipher', + 'desc': '', + 'start_server': True}, + 'input': + {'certfile': CERTFILE_EC, + 'protocol': ssl.PROTOCOL_DTLSv1_2, + 'certreqs': None, + 'cacertsfile': ISSUER_CERTFILE_EC, + 'ciphers': None, + 'curves': None, + 'sigalgs': None, + 'client_certfile': None, + 'client_protocol': ssl.PROTOCOL_DTLSv1_2, + 'client_certreqs': ssl.CERT_REQUIRED, + 'client_cacertsfile': ISSUER_CERTFILE_EC, + 'client_ciphers': "ECDHE-ECDSA-AES256-SHA", + 'client_curves': None, + 'client_sigalgs': None}, + 'result': + {'ret_success': True, + 'error_code': None, + 'exception': None}}, +] + + +def params_test(start_server, certfile, protocol, certreqs, cacertsfile, + client_certfile=None, client_protocol=None, client_certreqs=None, client_cacertsfile=None, + ciphers=None, curves=None, sigalgs=None, + client_ciphers=None, client_curves=None, client_sigalgs=None, + mtu=1500, server_key_exchange_curve=None, server_cert_options=None, + indata="FOO\n", chatty=False, connectionchatty=False): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ + server = ThreadedEchoServer(certfile, + ssl_version=protocol, + certreqs=certreqs, + cacerts=cacertsfile, + ciphers=ciphers, + curves=curves, + sigalgs=sigalgs, + mtu=mtu, + server_key_exchange_curve=server_key_exchange_curve, + server_cert_options=server_cert_options, + chatty=chatty) + # should we really run the server? + if start_server: + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + else: + server.sock.close() + # try to connect + if client_protocol is None: + client_protocol = protocol + if client_ciphers is None: + client_ciphers = ciphers + if client_curves is None: + client_curves = curves + if client_sigalgs is None: + client_sigalgs = sigalgs + try: + s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + keyfile=client_certfile, + certfile=client_certfile, + cert_reqs=client_certreqs, + ssl_version=client_protocol, + ca_certs=client_cacertsfile, + ciphers=client_ciphers, + curves=client_curves, + sigalgs=client_sigalgs, + user_mtu=mtu) + s.connect((HOST, server.port)) + if connectionchatty: + sys.stdout.write(" client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if connectionchatty: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata), 20)], len(outdata), + indata[:min(len(indata), 20)].lower(), len(indata))) + cert = s.getpeercert() + cipher = s.cipher() + if connectionchatty: + sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n") + sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n") + if connectionchatty: + sys.stdout.write(" client: closing connection.\n") + try: + s.close() + except Exception as e: + if connectionchatty: + sys.stdout.write(" client: error closing connection %s...\n" % (repr(e))) + pass + except Exception as e: + if connectionchatty: + sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e))) + return False, e + finally: + if start_server: + server.stop() + return True, None + + +class TestSequenceMeta(type): + def __new__(mcs, name, bases, dict): + + def gen_test(_case, _input, _result): + def test(self): + try: + if CHATTY or CHATTY_CLIENT: + sys.stdout.write("\nTestcase: %s\n" % _case['name']) + ret, e = params_test(_case['start_server'], chatty=CHATTY, connectionchatty=CHATTY_CLIENT, **_input) + if _result['ret_success']: + self.assertEqual(ret, _result['ret_success']) + else: + try: + last_error = e.errqueue[-1][0] + except: + try: + last_error = e.errno + except: + last_error = None + self.assertEqual(last_error, _result['error_code']) + except Exception as e: + raise + return test + + for testcase in tests: + _case, _input, _result = testcase.itervalues() + test_name = "test_%s" % _case['name'].lower().replace(' ', '_') + dict[test_name] = gen_test(_case, _input, _result) + + return type.__new__(mcs, name, bases, dict) + + +class WrapperTests(unittest.TestCase): + __metaclass__ = TestSequenceMeta + + def setUp(self): + super(WrapperTests, self).setUp() + + do_patch() + + def test_build_cert_chain(self): + steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT] + chatty, connectionchatty = CHATTY, CHATTY_CLIENT + indata = 'FOO' + certs = dict() + + if chatty or connectionchatty: + sys.stdout.write("\nTestcase: test_build_cert_chain\n") + for step in steps: + server = ThreadedEchoServer(certificate=CERTFILE, + ssl_version=ssl.PROTOCOL_DTLSv1_2, + certreqs=ssl.CERT_NONE, + cacerts=ISSUER_CERTFILE, + ciphers=None, + curves=None, + sigalgs=None, + mtu=1500, + server_key_exchange_curve=None, + server_cert_options=step, + chatty=chatty) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + try: + s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + keyfile=None, + certfile=None, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl.PROTOCOL_DTLSv1_2, + ca_certs=ISSUER_CERTFILE, + ciphers=None, + curves=None, + sigalgs=None, + user_mtu=1500) + s.connect((HOST, server.port)) + if connectionchatty: + sys.stdout.write(" client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if connectionchatty: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata), 20)], len(outdata), + indata[:min(len(indata), 20)].lower(), len(indata))) + # cert = s.getpeercert() + # cipher = s.cipher() + # if connectionchatty: + # sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n") + # sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n") + certs[step] = s.getpeercertchain() + if connectionchatty: + sys.stdout.write(" client: closing connection.\n") + try: + s.close() + except Exception as e: + if connectionchatty: + sys.stdout.write(" client: error closing connection %s...\n" % (repr(e))) + pass + except Exception as e: + if connectionchatty: + sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e))) + raise + finally: + server.stop() + + if chatty: + sys.stdout.write("certs:\n") + for step in steps: + sys.stdout.write("SSL_CTX_build_cert_chain: %s\n%s\n" % (step, pprint.pformat(certs[step]))) + self.assertNotEqual(certs[steps[0]], certs[steps[1]]) + self.assertEqual(len(certs[steps[0]]) - len(certs[steps[1]]), 1) + + def test_set_ecdh_curve(self): + steps = { + # server, client, result + 'all auto': (None, None, True), # Auto + 'client restricted': (None, "secp256k1:prime256v1", True), # client can handle key curve + 'client too restricted': (None, "secp256k1", False), # client _cannot_ handle key curve + 'client minimum': (None, "prime256v1", True), # client can only handle key curve + 'server restricted': ("secp384r1", None, True), # client can handle key curve + 'server one, client two': ("secp384r1", "prime256v1:secp384r1", True), # client can handle key curve + 'server one, client one': ("secp384r1", "secp384r1", False), # client _cannot_ handle key curve + } + + chatty, connectionchatty = CHATTY, CHATTY_CLIENT + indata = 'FOO' + certs = dict() + + if chatty or connectionchatty: + sys.stdout.write("\nTestcase: test_ecdh_curve\n") + for step, tmp in steps.iteritems(): + if chatty or connectionchatty: + sys.stdout.write("\n Subcase: %s\n" % step) + server_curve, client_curve, result = tmp + server = ThreadedEchoServer(certificate=CERTFILE_EC, + ssl_version=ssl.PROTOCOL_DTLSv1_2, + certreqs=ssl.CERT_NONE, + cacerts=ISSUER_CERTFILE_EC, + ciphers=None, + curves=None, + sigalgs=None, + mtu=1500, + server_key_exchange_curve=server_curve, + server_cert_options=None, + chatty=chatty) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + try: + s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + keyfile=None, + certfile=None, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl.PROTOCOL_DTLSv1_2, + ca_certs=ISSUER_CERTFILE_EC, + ciphers=None, + curves=client_curve, + sigalgs=None, + user_mtu=1500) + s.connect((HOST, server.port)) + if connectionchatty: + sys.stdout.write(" client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if connectionchatty: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata), 20)], len(outdata), + indata[:min(len(indata), 20)].lower(), len(indata))) + if connectionchatty: + sys.stdout.write(" client: closing connection.\n") + try: + s.close() + except Exception as e: + if connectionchatty: + sys.stdout.write(" client: error closing connection %s...\n" % (repr(e))) + pass + except Exception as e: + if connectionchatty: + sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e))) + if result: + raise + finally: + server.stop() + + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/dtls/wrapper.py b/dtls/wrapper.py new file mode 100644 index 0000000..ef525c5 --- /dev/null +++ b/dtls/wrapper.py @@ -0,0 +1,361 @@ +# -*- encoding: utf-8 -*- + +# DTLS Socket: A wrapper for a server and client using a DTLS connection. + +# Copyright 2017 Björn Freise +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# The License is also distributed with this work in the file named "LICENSE." +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DTLS Socket + +This wrapper encapsulates the state and behavior associated with the connection +between the OpenSSL library and an individual peer when using the DTLS +protocol. + +Classes: + + DtlsSocket -- DTLS Socket wrapper for use as a client or server +""" + +import select + +from logging import getLogger + +import ssl +import socket +from patch import do_patch +do_patch() +from sslconnection import SSLContext, SSL +from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \ + SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK + +_logger = getLogger(__name__) + + +class DtlsSocket(object): + + class _ClientSession(object): + + def __init__(self, host, port, handshake_done=False): + self.host = host + self.port = int(port) + self.handshake_done = handshake_done + + def getAddr(self): + return self.host, self.port + + def __init__(self, + peerOrSock, + keyfile=None, + certfile=None, + server_side=False, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_DTLSv1_2, + ca_certs=None, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + ciphers=None, + curves=None, + sigalgs=None, + user_mtu=None, + server_key_exchange_curve=None, + server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE): + + if server_cert_options is None: + server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE + + self._ssl_logging = False + self._peer = None + self._server_side = server_side + self._ciphers = ciphers + self._curves = curves + self._sigalgs = sigalgs + self._user_mtu = user_mtu + self._server_key_exchange_curve = server_key_exchange_curve + self._server_cert_options = server_cert_options + + # Default socket creation + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + if isinstance(peerOrSock, tuple): + # Address tuple + self._peer = peerOrSock + else: + # Socket, use given + sock = peerOrSock + + self._sock = ssl.wrap_socket(sock, + keyfile=keyfile, + certfile=certfile, + server_side=self._server_side, + cert_reqs=cert_reqs, + ssl_version=ssl_version, + ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=self._ciphers, + cb_user_config_ssl_ctx=self.user_config_ssl_ctx, + cb_user_config_ssl=self.user_config_ssl) + + if self._server_side: + self._clients = {} + self._timeout = None + + if self._peer: + self._sock.bind(self._peer) + self._sock.listen(0) + else: + if self._peer: + self._sock.connect(self._peer) + + def __getattr__(self, item): + if hasattr(self, "_sock") and hasattr(self._sock, item): + return getattr(self._sock, item) + raise AttributeError + + def user_config_ssl_ctx(self, _ctx): + """ + + :param SSLContext _ctx: + """ + _ctx.set_ssl_logging(self._ssl_logging) + if self._ciphers: + _ctx.set_ciphers(self._ciphers) + if self._curves: + _ctx.set_curves(self._curves) + if self._sigalgs: + _ctx.set_sigalgs(self._sigalgs) + if self._server_side: + _ctx.build_cert_chain(flags=self._server_cert_options) + _ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve) + + def user_config_ssl(self, _ssl): + """ + + :param SSL _ssl: + """ + if self._user_mtu: + _ssl.set_link_mtu(self._user_mtu) + + def settimeout(self, t): + if self._server_side: + self._timeout = t + else: + self._sock.settimeout(t) + + def close(self): + if self._server_side: + for cli in self._clients.keys(): + cli.close() + else: + self._sock.unwrap() + self._sock.close() + + def write(self, data): + # return self._sock.write(data) + return self.sendto(data, self._peer) + + def read(self, len=1024): + # return self._sock.read(len=len) + return self.recvfrom(len)[0] + + def recvfrom(self, bufsize, flags=0): + if self._server_side: + return self._recvfrom_on_server_side(bufsize, flags=flags) + else: + return self._recvfrom_on_client_side(bufsize, flags=flags) + + def _recvfrom_on_server_side(self, bufsize, flags): + try: + r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout) + + except socket.timeout as e_timeout: + raise e_timeout + + try: + for conn in r: # type: ssl.SSLSocket + if self._sockIsServerSock(conn): + # Connect + self._clientAccept(conn) + else: + # Handshake + if not self._clientHandshakeDone(conn): + self._clientDoHandshake(conn) + # Normal read + else: + buf = self._clientRead(conn, bufsize) + if buf and conn in self._clients: + return buf, self._clients[conn].getAddr() + + except Exception as e: + raise e + + try: + for conn in self._getClientReadingSockets(): + if conn.get_timeout(): + conn.handle_timeout() + + except Exception as e: + raise e + + raise socket.timeout + + def _recvfrom_on_client_side(self, bufsize, flags): + try: + buf = self._sock.recv(bufsize, flags) + + except ssl.SSLError as e_ssl: + if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN: + return '', self._peer + elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: + raise e_ssl + else: # like in [ssl.SSL_ERROR_WANT_READ, ...] + pass + + else: + if buf: + return buf, self._peer + + raise socket.timeout + + def sendto(self, buf, address): + if self._server_side: + return self._sendto_from_server_side(buf, address) + else: + return self._sendto_from_client_side(buf, address) + + def _sendto_from_server_side(self, buf, address): + for conn, client in self._clients.iteritems(): + if client.getAddr() == address: + return self._clientWrite(conn, buf) + return 0 + + def _sendto_from_client_side(self, buf, address): + while True: + try: + bytes_sent = self._sock.send(buf) + + except ssl.SSLError as e_ssl: + if str(e_ssl).startswith("503:"): + # The write operation timed out + continue + raise e_ssl + + else: + if bytes_sent: + break + + return bytes_sent + + def _getClientReadingSockets(self): + return [x for x in self._clients.keys()] + + def _getAllReadingSockets(self): + return [self._sock] + self._getClientReadingSockets() + + def _sockIsServerSock(self, conn): + return conn is self._sock + + def _clientHandshakeDone(self, conn): + return conn in self._clients and self._clients[conn].handshake_done is True + + def _clientAccept(self, conn): + _logger.debug('+' * 60) + ret = None + + try: + ret = conn.accept() + _logger.debug('Accept returned with ... %s' % (str(ret))) + + except Exception as e_accept: + raise e_accept + + else: + if ret: + client, addr = ret + host, port = addr + if client in self._clients: + raise ValueError + self._clients[client] = self._ClientSession(host=host, port=port) + + self._clientDoHandshake(client) + + def _clientDoHandshake(self, conn): + _logger.debug('-' * 60) + conn.setblocking(False) + + try: + conn.do_handshake() + _logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr()))) + + self._clients[conn].handshake_done = True + + except ssl.SSLError as e_handshake: + if str(e_handshake).startswith("504:"): + pass + elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ: + pass + else: + raise e_handshake + + def _clientRead(self, conn, bufsize=4096): + _logger.debug('*' * 60) + ret = None + + try: + ret = conn.recv(bufsize) + _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret)))) + + except ssl.SSLError as e_read: + if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN: + self._clientDrop(conn) + elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: + self._clientDrop(conn, error=e_read) + else: # like in [ssl.SSL_ERROR_WANT_READ, ...] + pass + + return ret + + def _clientWrite(self, conn, data): + _logger.debug('#' * 60) + ret = None + + try: + _data = data + if False: + _data = data.raw + ret = conn.send(_data) + _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret))) + + except Exception as e_write: + raise e_write + + return ret + + def _clientDrop(self, conn, error=None): + _logger.debug('$' * 60) + + try: + if error: + _logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error)) + else: + _logger.debug('Drop client %s' % str(self._clients[conn].getAddr())) + + if conn in self._clients: + del self._clients[conn] + conn.unwrap() + conn.close() + + except Exception as e_drop: + pass