Added a wrapper for a DTLS-Socket either as client or server - including unit tests

* dtls/__init__.py: Import SSLContext() and SSL() for external use
* dtls/wrapper.py: Added class DtlsSocket() to be used as client or server
* dtls/test/unit_wrapper.py: unit test for DtlsSocket()
incoming
mcfreis 2017-03-20 16:53:55 +01:00
parent ff509e0724
commit 26634280a5
4 changed files with 1024 additions and 1 deletions

View File

@ -1,3 +1,11 @@
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added a wrapper for a DTLS-Socket either as client or server - including unit tests
* dtls/__init__.py: Import SSLContext() and SSL() for external use
* dtls/wrapper.py: Added class DtlsSocket() to be used as client or server
* dtls/test/unit_wrapper.py: unit test for DtlsSocket()
2017-03-17 Björn Freise <mcfreis@gmx.net> 2017-03-17 Björn Freise <mcfreis@gmx.net>
Added more on error evaluation and a method to get the peer certificate chain Added more on error evaluation and a method to get the peer certificate chain

View File

@ -59,6 +59,6 @@ def _prep_bins():
_prep_bins() # prepare before module imports _prep_bins() # prepare before module imports
from patch import do_patch from patch import do_patch
from sslconnection import SSLConnection from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux from demux import force_routing_demux, reset_default_demux
import err as error_codes import err as error_codes

View File

@ -0,0 +1,654 @@
# -*- encoding: utf-8 -*-
# Test the support for DTLS through the SSL module. Adapted from the Python
# standard library's test_ssl.py regression test module by Björn Freise.
import unittest
import threading
import sys
import socket
import os
import pprint
from logging import basicConfig, DEBUG, getLogger
# basicConfig(level=DEBUG, format="%(asctime)s - %(threadName)-10s - %(name)s - %(levelname)s - %(message)s")
_logger = getLogger(__name__)
import ssl
from dtls import do_patch, error_codes
from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT
HOST = "localhost"
CHATTY = True
CHATTY_CLIENT = True
class ThreadedEchoServer(threading.Thread):
def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None,
ciphers=None, curves=None, sigalgs=None,
mtu=None, server_key_exchange_curve=None, server_cert_options=None,
chatty=True):
if ssl_version is None:
ssl_version = ssl.PROTOCOL_DTLSv1
if certreqs is None:
certreqs = ssl.CERT_NONE
self.certificate = certificate
self.protocol = ssl_version
self.certreqs = certreqs
self.cacerts = cacerts
self.ciphers = ciphers
self.curves = curves
self.sigalgs = sigalgs
self.mtu = mtu
self.server_key_exchange_curve = server_key_exchange_curve
self.server_cert_options = server_cert_options
self.chatty = chatty
self.flag = None
self.sock = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=self.certificate,
certfile=self.certificate,
server_side=True,
cert_reqs=self.certreqs,
ssl_version=self.protocol,
ca_certs=self.cacerts,
ciphers=self.ciphers,
curves=self.curves,
sigalgs=self.sigalgs,
user_mtu=self.mtu,
server_key_exchange_curve=self.server_key_exchange_curve,
server_cert_options=self.server_cert_options)
if self.chatty:
sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
self.sock.bind((HOST, 0))
self.port = self.sock.getsockname()[1]
self.active = False
threading.Thread.__init__(self)
self.daemon = True
def start(self, flag=None):
self.flag = flag
self.starter = threading.current_thread().ident
threading.Thread.start(self)
def run(self):
self.sock.settimeout(0.05)
self.sock.listen(0)
self.active = True
if self.flag:
# signal an event
self.flag.set()
while self.active:
try:
acc_ret = self.sock.recvfrom(4096)
if acc_ret:
newdata, connaddr = acc_ret
if self.chatty:
sys.stdout.write(' server: new data from ' + str(connaddr) + '\n')
self.sock.sendto(newdata.lower(), connaddr)
except socket.timeout:
pass
except KeyboardInterrupt:
self.stop()
except Exception as e:
if self.chatty:
sys.stdout.write(' server: error ' + str(e) + '\n')
pass
if self.chatty:
sys.stdout.write(' server: closing socket as %s\n' % str(self.sock))
self.sock.close()
def stop(self):
self.active = False
if self.starter != threading.current_thread().ident:
return
self.join() # don't allow spawning new handlers after we've checked
CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "keycert.pem")
CERTFILE_EC = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "keycert_ec.pem")
ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "ca-cert.pem")
ISSUER_CERTFILE_EC = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "ca-cert_ec.pem")
# certfile, protocol, certreqs, cacertsfile,
# ciphers=None, curves=None, sigalgs=None,
tests = [
{'testcase':
{'name': 'standard dtls v1',
'desc': 'Standard DTLS v1 test with out-of-the box configuration and RSA certificate',
'start_server': True},
'input':
{'certfile': CERTFILE,
'protocol': ssl.PROTOCOL_DTLSv1,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'standard dtls v1_2',
'desc': 'Standard DTLS v1_2 test with out-of-the box configuration and ECDSA certificate',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'protocol version mismatch',
'desc': 'Client and server have different protocol versions',
'start_server': True},
'input':
{'certfile': CERTFILE,
'protocol': ssl.PROTOCOL_DTLSv1,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_WRONG_SSL_VERSION,
'exception': None}},
{'testcase':
{'name': 'certificate verify fails',
'desc': 'Server certificate cannot be verified by client',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED,
'exception': None}},
{'testcase':
{'name': 'no matching curve',
'desc': 'Client doesn\'t support curve used by server ECDSA certificate',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': 'secp384r1',
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching curve',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': 'prime256v1',
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'no host',
'desc': 'No server port is listening',
'start_server': False},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_PORT_UNREACHABLE,
'exception': None}},
{'testcase':
{'name': 'no matching sigalgs',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': "RSA+SHA256"},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching sigalgs',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': "ECDSA+SHA256"},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'no matching cipher',
'desc': 'Server using a ECDSA certificate while client is only able to use RSA encryption',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': "AES256-SHA",
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching cipher',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': "ECDHE-ECDSA-AES256-SHA",
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
]
def params_test(start_server, certfile, protocol, certreqs, cacertsfile,
client_certfile=None, client_protocol=None, client_certreqs=None, client_cacertsfile=None,
ciphers=None, curves=None, sigalgs=None,
client_ciphers=None, client_curves=None, client_sigalgs=None,
mtu=1500, server_key_exchange_curve=None, server_cert_options=None,
indata="FOO\n", chatty=False, connectionchatty=False):
"""
Launch a server, connect a client to it and try various reads
and writes.
"""
server = ThreadedEchoServer(certfile,
ssl_version=protocol,
certreqs=certreqs,
cacerts=cacertsfile,
ciphers=ciphers,
curves=curves,
sigalgs=sigalgs,
mtu=mtu,
server_key_exchange_curve=server_key_exchange_curve,
server_cert_options=server_cert_options,
chatty=chatty)
# should we really run the server?
if start_server:
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
else:
server.sock.close()
# try to connect
if client_protocol is None:
client_protocol = protocol
if client_ciphers is None:
client_ciphers = ciphers
if client_curves is None:
client_curves = curves
if client_sigalgs is None:
client_sigalgs = sigalgs
try:
s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=client_certfile,
certfile=client_certfile,
cert_reqs=client_certreqs,
ssl_version=client_protocol,
ca_certs=client_cacertsfile,
ciphers=client_ciphers,
curves=client_curves,
sigalgs=client_sigalgs,
user_mtu=mtu)
s.connect((HOST, server.port))
if connectionchatty:
sys.stdout.write(" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if connectionchatty:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata), 20)], len(outdata),
indata[:min(len(indata), 20)].lower(), len(indata)))
cert = s.getpeercert()
cipher = s.cipher()
if connectionchatty:
sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n")
sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n")
if connectionchatty:
sys.stdout.write(" client: closing connection.\n")
try:
s.close()
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: error closing connection %s...\n" % (repr(e)))
pass
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e)))
return False, e
finally:
if start_server:
server.stop()
return True, None
class TestSequenceMeta(type):
def __new__(mcs, name, bases, dict):
def gen_test(_case, _input, _result):
def test(self):
try:
if CHATTY or CHATTY_CLIENT:
sys.stdout.write("\nTestcase: %s\n" % _case['name'])
ret, e = params_test(_case['start_server'], chatty=CHATTY, connectionchatty=CHATTY_CLIENT, **_input)
if _result['ret_success']:
self.assertEqual(ret, _result['ret_success'])
else:
try:
last_error = e.errqueue[-1][0]
except:
try:
last_error = e.errno
except:
last_error = None
self.assertEqual(last_error, _result['error_code'])
except Exception as e:
raise
return test
for testcase in tests:
_case, _input, _result = testcase.itervalues()
test_name = "test_%s" % _case['name'].lower().replace(' ', '_')
dict[test_name] = gen_test(_case, _input, _result)
return type.__new__(mcs, name, bases, dict)
class WrapperTests(unittest.TestCase):
__metaclass__ = TestSequenceMeta
def setUp(self):
super(WrapperTests, self).setUp()
do_patch()
def test_build_cert_chain(self):
steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT]
chatty, connectionchatty = CHATTY, CHATTY_CLIENT
indata = 'FOO'
certs = dict()
if chatty or connectionchatty:
sys.stdout.write("\nTestcase: test_build_cert_chain\n")
for step in steps:
server = ThreadedEchoServer(certificate=CERTFILE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
certreqs=ssl.CERT_NONE,
cacerts=ISSUER_CERTFILE,
ciphers=None,
curves=None,
sigalgs=None,
mtu=1500,
server_key_exchange_curve=None,
server_cert_options=step,
chatty=chatty)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
try:
s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=None,
certfile=None,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=ISSUER_CERTFILE,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=1500)
s.connect((HOST, server.port))
if connectionchatty:
sys.stdout.write(" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if connectionchatty:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata), 20)], len(outdata),
indata[:min(len(indata), 20)].lower(), len(indata)))
# cert = s.getpeercert()
# cipher = s.cipher()
# if connectionchatty:
# sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n")
# sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n")
certs[step] = s.getpeercertchain()
if connectionchatty:
sys.stdout.write(" client: closing connection.\n")
try:
s.close()
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: error closing connection %s...\n" % (repr(e)))
pass
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e)))
raise
finally:
server.stop()
if chatty:
sys.stdout.write("certs:\n")
for step in steps:
sys.stdout.write("SSL_CTX_build_cert_chain: %s\n%s\n" % (step, pprint.pformat(certs[step])))
self.assertNotEqual(certs[steps[0]], certs[steps[1]])
self.assertEqual(len(certs[steps[0]]) - len(certs[steps[1]]), 1)
def test_set_ecdh_curve(self):
steps = {
# server, client, result
'all auto': (None, None, True), # Auto
'client restricted': (None, "secp256k1:prime256v1", True), # client can handle key curve
'client too restricted': (None, "secp256k1", False), # client _cannot_ handle key curve
'client minimum': (None, "prime256v1", True), # client can only handle key curve
'server restricted': ("secp384r1", None, True), # client can handle key curve
'server one, client two': ("secp384r1", "prime256v1:secp384r1", True), # client can handle key curve
'server one, client one': ("secp384r1", "secp384r1", False), # client _cannot_ handle key curve
}
chatty, connectionchatty = CHATTY, CHATTY_CLIENT
indata = 'FOO'
certs = dict()
if chatty or connectionchatty:
sys.stdout.write("\nTestcase: test_ecdh_curve\n")
for step, tmp in steps.iteritems():
if chatty or connectionchatty:
sys.stdout.write("\n Subcase: %s\n" % step)
server_curve, client_curve, result = tmp
server = ThreadedEchoServer(certificate=CERTFILE_EC,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
certreqs=ssl.CERT_NONE,
cacerts=ISSUER_CERTFILE_EC,
ciphers=None,
curves=None,
sigalgs=None,
mtu=1500,
server_key_exchange_curve=server_curve,
server_cert_options=None,
chatty=chatty)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
try:
s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=None,
certfile=None,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=ISSUER_CERTFILE_EC,
ciphers=None,
curves=client_curve,
sigalgs=None,
user_mtu=1500)
s.connect((HOST, server.port))
if connectionchatty:
sys.stdout.write(" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if connectionchatty:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata), 20)], len(outdata),
indata[:min(len(indata), 20)].lower(), len(indata)))
if connectionchatty:
sys.stdout.write(" client: closing connection.\n")
try:
s.close()
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: error closing connection %s...\n" % (repr(e)))
pass
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e)))
if result:
raise
finally:
server.stop()
pass
if __name__ == '__main__':
unittest.main()

361
dtls/wrapper.py 100644
View File

@ -0,0 +1,361 @@
# -*- encoding: utf-8 -*-
# DTLS Socket: A wrapper for a server and client using a DTLS connection.
# Copyright 2017 Björn Freise
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# The License is also distributed with this work in the file named "LICENSE."
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DTLS Socket
This wrapper encapsulates the state and behavior associated with the connection
between the OpenSSL library and an individual peer when using the DTLS
protocol.
Classes:
DtlsSocket -- DTLS Socket wrapper for use as a client or server
"""
import select
from logging import getLogger
import ssl
import socket
from patch import do_patch
do_patch()
from sslconnection import SSLContext, SSL
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK
_logger = getLogger(__name__)
class DtlsSocket(object):
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
def __init__(self,
peerOrSock,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=None,
server_key_exchange_curve=None,
server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None:
server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False
self._peer = None
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
self._sigalgs = sigalgs
self._user_mtu = user_mtu
self._server_key_exchange_curve = server_key_exchange_curve
self._server_cert_options = server_cert_options
# Default socket creation
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(peerOrSock, tuple):
# Address tuple
self._peer = peerOrSock
else:
# Socket, use given
sock = peerOrSock
self._sock = ssl.wrap_socket(sock,
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=self._ciphers,
cb_user_config_ssl_ctx=self.user_config_ssl_ctx,
cb_user_config_ssl=self.user_config_ssl)
if self._server_side:
self._clients = {}
self._timeout = None
if self._peer:
self._sock.bind(self._peer)
self._sock.listen(0)
else:
if self._peer:
self._sock.connect(self._peer)
def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item)
raise AttributeError
def user_config_ssl_ctx(self, _ctx):
"""
:param SSLContext _ctx:
"""
_ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers:
_ctx.set_ciphers(self._ciphers)
if self._curves:
_ctx.set_curves(self._curves)
if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs)
if self._server_side:
_ctx.build_cert_chain(flags=self._server_cert_options)
_ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve)
def user_config_ssl(self, _ssl):
"""
:param SSL _ssl:
"""
if self._user_mtu:
_ssl.set_link_mtu(self._user_mtu)
def settimeout(self, t):
if self._server_side:
self._timeout = t
else:
self._sock.settimeout(t)
def close(self):
if self._server_side:
for cli in self._clients.keys():
cli.close()
else:
self._sock.unwrap()
self._sock.close()
def write(self, data):
# return self._sock.write(data)
return self.sendto(data, self._peer)
def read(self, len=1024):
# return self._sock.read(len=len)
return self.recvfrom(len)[0]
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
else:
return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout as e_timeout:
raise e_timeout
try:
for conn in r: # type: ssl.SSLSocket
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
else:
# Handshake
if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn)
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf and conn in self._clients:
return buf, self._clients[conn].getAddr()
except Exception as e:
raise e
try:
for conn in self._getClientReadingSockets():
if conn.get_timeout():
conn.handle_timeout()
except Exception as e:
raise e
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e_ssl:
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
return '', self._peer
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
raise e_ssl
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
else:
if buf:
return buf, self._peer
raise socket.timeout
def sendto(self, buf, address):
if self._server_side:
return self._sendto_from_server_side(buf, address)
else:
return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems():
if client.getAddr() == address:
return self._clientWrite(conn, buf)
return 0
def _sendto_from_client_side(self, buf, address):
while True:
try:
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e_ssl:
if str(e_ssl).startswith("503:"):
# The write operation timed out
continue
raise e_ssl
else:
if bytes_sent:
break
return bytes_sent
def _getClientReadingSockets(self):
return [x for x in self._clients.keys()]
def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn):
return conn is self._sock
def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn):
_logger.debug('+' * 60)
ret = None
try:
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e_accept:
raise e_accept
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port)
self._clientDoHandshake(client)
def _clientDoHandshake(self, conn):
_logger.debug('-' * 60)
conn.setblocking(False)
try:
conn.do_handshake()
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e_handshake:
if str(e_handshake).startswith("504:"):
pass
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e_handshake
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
ret = None
try:
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e_read:
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
self._clientDrop(conn)
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
self._clientDrop(conn, error=e_read)
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
return ret
def _clientWrite(self, conn, data):
_logger.debug('#' * 60)
ret = None
try:
_data = data
if False:
_data = data.raw
ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e_write:
raise e_write
return ret
def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60)
try:
if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients:
del self._clients[conn]
conn.unwrap()
conn.close()
except Exception as e_drop:
pass