Merge pull request #2 from mcfreis/clean-redo-in-steps

Merge "DTLSv1.2 methods and extensions added" from branch to master
incoming
mcfreis 2017-03-21 08:10:17 +01:00 committed by GitHub
commit d12b23ba9f
24 changed files with 2600 additions and 651 deletions

153
ChangeLog
View File

@ -1,3 +1,156 @@
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added a wrapper for a DTLS-Socket either as client or server - including unit tests
* dtls/__init__.py: Import SSLContext() and SSL() for external use
* dtls/wrapper.py: Added class DtlsSocket() to be used as client or server
* dtls/test/unit_wrapper.py: unit test for DtlsSocket()
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added more on error evaluation and a method to get the peer certificate chain
* dtls/__init__.py: import error codes from err.py as error_codes for external access
* dtls/err.py: Added errors for ERR_WRONG_SSL_VERSION, ERR_CERTIFICATE_VERIFY_FAILED, ERR_NO_SHARED_CIPHER and ERR_SSL_HANDSHAKE_FAILURE
* dtls/openssl.py:
- Added constant SSL_BUILD_CHAIN_FLAG_NONE for SSL_CTX_build_cert_chain()
- Added method SSL_get_peer_cert_chain()
* dtls/patch.py: Added getpeercertchain() as method to ssl.SSLSocket()
* dtls/sslconnection.py:
- Bugfix SSLContext.set_ecdh_curve() returns 1 for success and 0 for failure
- SSLContext.build_cert_chain() changed default flags to SSL_BUILD_CHAIN_FLAG_NONE
- In SSLConnection() the mtu size gets only set if no user config function is given
- SSLConnection.listen() raises an exception for ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER and all other unknown errors
- SSLConnection.read() and write() now can also raise ERR_PORT_UNREACHABLE
- If SSLConnection.write() successfully writes bytes to the peer, then the handshake is assumed to be okay
- Added method SSLConnection.getpeercertchain()
* dtls/test/unit.py: ThreadedEchoServer() with an extra exception branch for the newly raised exceptions in SSLConnection.listen()
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added certificate creation using ECDSA
* dtls/test/makecerts_ec.bat: creates ca-cert_ec.pem, keycert_ec.pem and server-cert_ec.pem
* dtls/test/openssl_ca.cnf and openssl_server.cnf: Added HOME to be able to use the conf file under windows
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added an interface in SSLConnection() to access SSLContext() and SSL() for manipulating settings during creation
* dtls/openssl.py:
- Added utility functions EC_curve_nist2nid() and EC_curve_nid2nist()
* dtls/patch.py:
- Extended wrap_socket() arguments with callbacks for user config functions of ssl context and ssl session values
- Extended SSLSocket() arguments with callbacks for user config functions of ssl context and ssl session values
* dtls/sslconnection.py:
- Extended SSLConnection() arguments with callbacks for user config functions of ssl context and ssl session values
- During the init of client and server the corresponding user config functions are called (if given)
- Added new classes SSLContext() [set_ciphers(), set_sigalgs(), set_curves(), set_ecdh_curve(), build_cert_chain(),
set_ssl_logging()] and SSL() [set_mtu(), set_link_mtu()]
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added methods getting the curves supported by the runtime openSSL lib
* dtls/openssl.py:
- Added class _EllipticCurve() for easy handling of the builtin curves
- Added wrapper get_elliptic_curves() - which uses _EllipticCurve()
- Added EC_get_builtin_curves(), EC_KEY_new_by_curve_name() and EC_KEY_free()
- Added OBJ_nid2sn() for translating numeric ids to names
* dtls/util.py: Added _EC_KEY() derived from _Rsrc() with own free/del method
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added methods for setting and getting the curves used during negotiation and encryption
* dtls/openssl.py:
- Added SSL_CTX_set1_curves() and SSL_CTX_set1_curves_list()
- Added SSL_CTX_set_ecdh_auto() and SSL_CTX_set_tmp_ecdh()
- Added SSL_get1_curves(), SSL_get_shared_curve(), SSL_set1_curves() and SSL_set1_curves_list()
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added methods for setting the signature algorithms
* dtls/openssl.py:
- Added SSL_CTX_set1_client_sigalgs_list(), SSL_CTX_set1_client_sigalgs(), SSL_CTX_set1_sigalgs_list() and SSL_CTX_set1_sigalgs()
- Added SSL_set1_client_sigalgs_list(), SSL_set1_client_sigalgs(), SSL_set1_sigalgs_list() and SSL_set1_sigalgs()
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added method SSL_CTX_build_cert_chain()
* dtls/openssl.py: Added SSL_CTX_build_cert_chain() and corresponding constants
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added methods *_clear_options() and *_get_options()
* dtls/openssl.py:
- Added SSL_CTX_clear_options() and SSL_CTX_get_options()
- Added SSL_clear_options() and SSL_get_options()
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added new methods for DTLSv1.2
* dtls/err.py: Added error code ERR_WRONG_VERSION_NUMBER
* dtls/openssl.py: Added DTLS_server_method(), DTLSv1_2_server_method() and DTLSv1_2_client_method()
* dtls/patch.py: Default protocol DTLS for ssl.wrap_socket() and ssl.SSLSocket()
* dtls/sslconnection.py:
- Introduced PROTOCOL_DTLSv1_2 and PROTOCOL_DTLS (the latter one is a synonym for the "higher" version)
- Updated _init_client() and _init_server() with the new protocol methods
- Default protocol DTLS for SSLConnection()
- Return on ERR_WRONG_VERSION_NUMBER if client and server cannot agree on protocol version
* dtls/test/unit.py:
- Extended test_get_server_certificate() to iterate over the different protocol combinations
- Extended test_protocol_dtlsv1() to try the different protocol combinations between client and server
2017-03-17 Björn Freise <mcfreis@gmx.net>
Updating openSSL libs to v1.0.2l-dev
* dtls/openssl.py: Added mtu-functions SSL_set_mtu() and DTLS_set_link_mtu()
* dtls/prebuilt/win32-*: Updated libs for x86 and x86_64 to version 1.0.2l-dev
* dtls/sslconnection.py: mtu size set hardcoded to 1500 - otherwise the windows implementation has problems
2017-03-17 Björn Freise <mcfreis@gmx.net>
Added interface for SSL_CTX_set_info_callback()
* dtls/openssl.py:
- Added methods SSL_CTX_set_info_callback(), SSL_state_string_long(), SSL_alert_type_string_long() and SSL_alert_desc_string_long()
- Added constants for state and error evaluation during callback
* dtls/sslconnection.py: Added _ssl_logging_cb() as default callback function - only outputs messages when logger is active
2017-03-17 Björn Freise <mcfreis@gmx.net>
SSL_write() extended to handle ctypes.Array as data
* dtls/openssl.py: SSL_write() can handle ctypes.Array data
* dtls/sslconnection.py: Added missing import ERR_BOTH_KEY_CERT_FILES
* dtls/test/simple_client.py: Added basic test client to use with dtls/test/echo_seq.py
2017-03-17 Björn Freise <mcfreis@gmx.net>
Beautified lists and maps, grouped imports for easy merges in the future - no changed functionality!
* dtls/openssl.py:
- Ordered constants according to header file from openSSL
- Beautified __all__-list and map for _make_function() in order to easy merges in the future
- Added a few returns in order to evaluate the success of the called methods
* dtls/patch.py: Grouped imports in the following order - system, local
* dtls/sslconnection.py: ssl protocol not hardcoded anymore for forked objects
* dtls/x509.py: logger messages working again
2017-02-27 Ray Brown <code@liquibits.com>
* dtls/openssl.py: support reading directly into given buffer instead of forcing buffer copy (for ssl module compatibility)
* dtls/sslconnection.py: in-situ receive support, as above
* dtls/patch.py: various changes for compatibility with the ssl module of Python 2.7.12; note that the ssl module's new SSLContext is not supported
* dtls/test/unit.py: changes to support the updated ssl module, including fix of deprecation warnings
* setup.py: increase version to 1.0.2
2014-01-18 Ray Brown <code@liquibits.com> 2014-01-18 Ray Brown <code@liquibits.com>
* setup.py: Increase version to 1.0.1 for release to PyPI * setup.py: Increase version to 1.0.1 for release to PyPI

View File

@ -53,11 +53,12 @@ def _prep_bins():
for prebuilt_file in files: for prebuilt_file in files:
try: try:
copy(path.join(prebuilt_path, prebuilt_file), package_root) copy(path.join(prebuilt_path, prebuilt_file), package_root)
except IOError: except IOError:
pass pass
_prep_bins() # prepare before module imports _prep_bins() # prepare before module imports
from patch import do_patch from patch import do_patch
from sslconnection import SSLConnection from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux from demux import force_routing_demux, reset_default_demux
import err as error_codes

View File

@ -42,18 +42,24 @@ SSL_ERROR_WANT_ACCEPT = 8
ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES = 500
ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_BOTH_KEY_CERT_FILES_SVR = 298
ERR_NO_CERTS = 331 ERR_NO_CERTS = 331
ERR_NO_CIPHER = 501 ERR_NO_CIPHER = 501
ERR_READ_TIMEOUT = 502 ERR_READ_TIMEOUT = 502
ERR_WRITE_TIMEOUT = 503 ERR_WRITE_TIMEOUT = 503
ERR_HANDSHAKE_TIMEOUT = 504 ERR_HANDSHAKE_TIMEOUT = 504
ERR_PORT_UNREACHABLE = 505 ERR_PORT_UNREACHABLE = 505
ERR_COOKIE_MISMATCH = 0x1408A134
ERR_WRONG_SSL_VERSION = 0x1409210A
ERR_WRONG_VERSION_NUMBER = 0x1408A10B
class SSLError(socket_error): ERR_COOKIE_MISMATCH = 0x1408A134
"""This exception is raised by modules in the dtls package.""" ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
def __init__(self, *args): ERR_NO_SHARED_CIPHER = 0x1408A0C1
super(SSLError, self).__init__(*args) ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
class SSLError(socket_error):
"""This exception is raised by modules in the dtls package."""
def __init__(self, *args):
super(SSLError, self).__init__(*args)
class InvalidSocketError(Exception): class InvalidSocketError(Exception):
@ -96,8 +102,8 @@ def raise_ssl_error(code, nested=None):
"""Raise an SSL error with the given error code""" """Raise an SSL error with the given error code"""
err_string = str(code) + ": " + _ssl_errors[code] err_string = str(code) + ": " + _ssl_errors[code]
if nested: if nested:
raise SSLError(err_string, nested) raise SSLError(code, err_string + str(nested))
raise SSLError(err_string) raise SSLError(code, err_string)
_ssl_errors = { _ssl_errors = {
ERR_NO_CERTS: "No root certificates specified for verification " + \ ERR_NO_CERTS: "No root certificates specified for verification " + \

File diff suppressed because it is too large Load Diff

View File

@ -34,16 +34,18 @@ has the following effects:
PROTOCOL_DTLSv1 for the parameter ssl_version is supported PROTOCOL_DTLSv1 for the parameter ssl_version is supported
""" """
from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_DGRAM, getaddrinfo from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM
from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION
from sslconnection import DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
from types import MethodType from types import MethodType
from weakref import proxy from weakref import proxy
import errno import errno
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
def do_patch(): def do_patch():
import ssl as _ssl # import to be avoided if ssl module is never patched import ssl as _ssl # import to be avoided if ssl module is never patched
global _orig_SSLSocket_init, _orig_get_server_certificate global _orig_SSLSocket_init, _orig_get_server_certificate
@ -51,8 +53,14 @@ def do_patch():
ssl = _ssl ssl = _ssl
if hasattr(ssl, "PROTOCOL_DTLSv1"): if hasattr(ssl, "PROTOCOL_DTLSv1"):
return return
_orig_wrap_socket = ssl.wrap_socket
ssl.wrap_socket = _wrap_socket
ssl.PROTOCOL_DTLS = PROTOCOL_DTLS
ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1
ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2
ssl._PROTOCOL_NAMES[PROTOCOL_DTLS] = "DTLS"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2"
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
@ -62,10 +70,25 @@ def do_patch():
ssl.get_server_certificate = _get_server_certificate ssl.get_server_certificate = _get_server_certificate
raise_as_ssl_module_error() raise_as_ssl_module_error()
PROTOCOL_SSLv3 = 1 def _wrap_socket(sock, keyfile=None, certfile=None,
PROTOCOL_SSLv23 = 2 server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers,
cb_user_config_ssl_ctx=cb_user_config_ssl_ctx,
cb_user_config_ssl=cb_user_config_ssl)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
"""Retrieve a server certificate """Retrieve a server certificate
Retrieve the certificate from the server at the specified address, Retrieve the certificate from the server at the specified address,
@ -74,10 +97,10 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
If 'ssl_version' is specified, use it in the connection attempt. If 'ssl_version' is specified, use it in the connection attempt.
""" """
if ssl_version != PROTOCOL_DTLSv1: if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2):
return _orig_get_server_certificate(addr, ssl_version, ca_certs) return _orig_get_server_certificate(addr, ssl_version, ca_certs)
if (ca_certs is not None): if ca_certs is not None:
cert_reqs = ssl.CERT_REQUIRED cert_reqs = ssl.CERT_REQUIRED
else: else:
cert_reqs = ssl.CERT_NONE cert_reqs = ssl.CERT_NONE
@ -90,11 +113,16 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
s.close() s.close()
return ssl.DER_cert_to_PEM_cert(dercert) return ssl.DER_cert_to_PEM_cert(dercert)
def _SSLSocket_init(self, sock, keyfile=None, certfile=None, def _SSLSocket_init(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None,
_context=None,
cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
is_connection = is_datagram = False is_connection = is_datagram = False
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
is_connection = True is_connection = True
@ -102,11 +130,19 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
is_datagram = True is_datagram = True
if not is_connection and not is_datagram: if not is_connection and not is_datagram:
# Non-DTLS code path # Non-DTLS code path
return _orig_SSLSocket_init(self, sock, keyfile, certfile, return _orig_SSLSocket_init(self, sock=sock, keyfile=keyfile,
server_side, cert_reqs, certfile=certfile, server_side=server_side,
ssl_version, ca_certs, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=
do_handshake_on_connect, do_handshake_on_connect,
suppress_ragged_eofs, ciphers) family=family, type=type, proto=proto,
fileno=fileno,
suppress_ragged_eofs=suppress_ragged_eofs,
npn_protocols=npn_protocols,
ciphers=ciphers,
server_hostname=server_hostname,
_context=_context)
# DTLS code paths: datagram socket and newly accepted DTLS connection # DTLS code paths: datagram socket and newly accepted DTLS connection
if is_datagram: if is_datagram:
socket.__init__(self, _sock=sock._sock) socket.__init__(self, _sock=sock._sock)
@ -138,10 +174,17 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
server_side, cert_reqs, server_side, cert_reqs,
ssl_version, ca_certs, ssl_version, ca_certs,
do_handshake_on_connect, do_handshake_on_connect,
suppress_ragged_eofs, ciphers) suppress_ragged_eofs, ciphers,
cb_user_config_ssl_ctx=cb_user_config_ssl_ctx,
cb_user_config_ssl=cb_user_config_ssl)
else: else:
self._connected = True
self._sslobj = sock self._sslobj = sock
class FakeContext(object):
check_hostname = False
self._context = FakeContext()
self.keyfile = keyfile self.keyfile = keyfile
self.certfile = certfile self.certfile = certfile
self.cert_reqs = cert_reqs self.cert_reqs = cert_reqs
@ -151,25 +194,35 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
self.do_handshake_on_connect = do_handshake_on_connect self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0 self._makefile_refs = 0
self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._user_config_ssl = cb_user_config_ssl
# Perform method substitution and addition (without reference cycle) # Perform method substitution and addition (without reference cycle)
self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self))
self.listen = MethodType(_SSLSocket_listen, proxy(self)) self.listen = MethodType(_SSLSocket_listen, proxy(self))
self.accept = MethodType(_SSLSocket_accept, proxy(self)) self.accept = MethodType(_SSLSocket_accept, proxy(self))
self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self))
self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self))
def _SSLSocket_listen(self, ignored): # Extra
if self._connected: self.getpeercertchain = MethodType(_getpeercertchain, proxy(self))
raise ValueError("attempt to listen on connected SSLSocket!")
if self._sslobj: def _getpeercertchain(self, binary_form=False):
return return self._sslobj.getpeercertchain(binary_form)
self._sslobj = SSLConnection(socket(_sock=self._sock),
def _SSLSocket_listen(self, ignored):
if self._connected:
raise ValueError("attempt to listen on connected SSLSocket!")
if self._sslobj:
return
self._sslobj = SSLConnection(socket(_sock=self._sock),
self.keyfile, self.certfile, True, self.keyfile, self.certfile, True,
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
def _SSLSocket_accept(self): def _SSLSocket_accept(self):
if self._connected: if self._connected:
@ -184,7 +237,9 @@ def _SSLSocket_accept(self):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
return new_ssl_sock, addr return new_ssl_sock, addr
def _SSLSocket_real_connect(self, addr, return_errno): def _SSLSocket_real_connect(self, addr, return_errno):
@ -195,7 +250,9 @@ def _SSLSocket_real_connect(self, addr, return_errno):
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs, self.ca_certs,
self.do_handshake_on_connect, self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers) self.suppress_ragged_eofs, self.ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl)
try: try:
self._sslobj.connect(addr) self._sslobj.connect(addr)
except socket_error as e: except socket_error as e:

Binary file not shown.

Binary file not shown.

View File

@ -45,27 +45,30 @@ import socket
import hmac import hmac
import datetime import datetime
from logging import getLogger from logging import getLogger
from os import urandom from os import urandom
from select import select from select import select
from weakref import proxy from weakref import proxy
from err import openssl_error, InvalidSocketError
from err import raise_ssl_error from err import openssl_error, InvalidSocketError
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import raise_ssl_error
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_BOTH_KEY_CERT_FILES_SVR from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from x509 import _X509, decode_cert from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR, ERR_NO_CERTS
from tlock import tlock_init from x509 import _X509, decode_cert
from openssl import * from tlock import tlock_init
from util import _Rsrc, _BIO from openssl import *
from util import _Rsrc, _BIO
_logger = getLogger(__name__)
_logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256
CERT_NONE = 0 PROTOCOL_DTLSv1 = 256
CERT_OPTIONAL = 1 PROTOCOL_DTLSv1_2 = 258
CERT_REQUIRED = 2 PROTOCOL_DTLS = 259
CERT_NONE = 0
CERT_OPTIONAL = 1
CERT_REQUIRED = 2
# #
# One-time global OpenSSL library initialization # One-time global OpenSSL library initialization
@ -80,12 +83,64 @@ DTLS_OPENSSL_VERSION_INFO = (
DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor
DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix
DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch
DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status
class _CTX(_Rsrc): def _ssl_logging_cb(conn, where, return_code):
"""SSL_CTX wrapper""" _state = where & ~SSL_ST_MASK
def __init__(self, value): state = "SSL"
if _state & SSL_ST_INIT == SSL_ST_INIT:
if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE:
state += "_renew"
else:
state += "_init"
elif _state & SSL_ST_CONNECT:
state += "_connect"
elif _state & SSL_ST_ACCEPT:
state += "_accept"
elif _state == 0:
if where & SSL_CB_HANDSHAKE_START:
state += "_handshake_start"
elif where & SSL_CB_HANDSHAKE_DONE:
state += "_handshake_done"
if where & SSL_CB_LOOP:
state += '_loop'
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
elif where & SSL_CB_ALERT:
state += '_alert'
state += "_read" if where & SSL_CB_READ else "_write"
_logger.debug("%s:%s:%s" % (state,
SSL_alert_type_string_long(return_code),
SSL_alert_desc_string_long(return_code)))
elif where & SSL_CB_EXIT:
state += '_exit'
if return_code == 0:
_logger.debug("%s:%s:%d(failed)" % (state,
SSL_state_string_long(conn),
return_code))
elif return_code < 0:
_logger.debug("%s:%s:%d(error)" % (state,
SSL_state_string_long(conn),
return_code))
else:
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
else:
_logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn),
return_code))
class _CTX(_Rsrc):
"""SSL_CTX wrapper"""
def __init__(self, value):
super(_CTX, self).__init__(value) super(_CTX, self).__init__(value)
def __del__(self): def __del__(self):
@ -119,12 +174,130 @@ class _CallbackProxy(object):
self.ssl_func = cbm.im_func self.ssl_func = cbm.im_func
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ssl_func(self.ssl_connection, *args, **kwargs) return self.ssl_func(self.ssl_connection, *args, **kwargs)
class SSLConnection(object): class SSLContext(object):
"""DTLS peer association
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u''' Select a curve to use for ECDH(E) key exchange or set it to auto mode
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY()
retVal &= SSL_CTX_set_tmp_ecdh(self._ctx, key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NONE):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
def set_link_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object):
"""DTLS peer association
This class associates two DTLS peer instances, wrapping OpenSSL library This class associates two DTLS peer instances, wrapping OpenSSL library
state including SSL (struct ssl_st), SSL_CTX, and BIO instances. state including SSL (struct ssl_st), SSL_CTX, and BIO instances.
""" """
@ -146,13 +319,19 @@ class SSLConnection(object):
rsock = self._udp_demux.get_connection(None) rsock = self._udp_demux.get_connection(None)
if rsock is self._sock: if rsock is self._sock:
self._rbio = self._wbio self._rbio = self._wbio
else: else:
self._rsock = rsock self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) server_method = DTLS_server_method
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) if self._ssl_version == PROTOCOL_DTLSv1_2:
if self._cert_reqs == CERT_NONE: server_method = DTLSv1_2_server_method
verify_mode = SSL_VERIFY_NONE elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
elif self._cert_reqs == CERT_OPTIONAL: elif self._cert_reqs == CERT_OPTIONAL:
verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE
else: else:
@ -166,29 +345,37 @@ class SSLConnection(object):
self._pending_peer_address = None self._pending_peer_address = None
self._cb_keepalive = SSL_CTX_set_cookie_cb( self._cb_keepalive = SSL_CTX_set_cookie_cb(
self._ctx.value, self._ctx.value,
_CallbackProxy(self._generate_cookie_cb), _CallbackProxy(self._generate_cookie_cb),
_CallbackProxy(self._verify_cookie_cb)) _CallbackProxy(self._verify_cookie_cb))
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_accept_state(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
if peer_address and self._do_handshake_on_connect: SSL_set_accept_state(self._ssl.value)
return lambda: self.do_handshake() if peer_address and self._do_handshake_on_connect:
return lambda: self.do_handshake()
def _init_client(self, peer_address): def _init_client(self, peer_address):
if self._sock.type != socket.SOCK_DGRAM: if self._sock.type != socket.SOCK_DGRAM:
raise InvalidSocketError("sock must be of type SOCK_DGRAM") raise InvalidSocketError("sock must be of type SOCK_DGRAM")
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible)
if self._cert_reqs == CERT_NONE: if self._ssl_version == PROTOCOL_DTLSv1_2:
verify_mode = SSL_VERIFY_NONE client_method = DTLSv1_2_client_method
else: elif self._ssl_version == PROTOCOL_DTLSv1:
verify_mode = SSL_VERIFY_PEER client_method = DTLSv1_client_method
self._config_ssl_ctx(verify_mode) self._ctx = _CTX(SSL_CTX_new(client_method()))
self._ssl = _SSL(SSL_new(self._ctx.value)) self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_set_connect_state(self._ssl.value) if self._cert_reqs == CERT_NONE:
if peer_address: verify_mode = SSL_VERIFY_NONE
return lambda: self.connect(peer_address) else:
verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value)
SSL_set_connect_state(self._ssl.value)
if peer_address:
return lambda: self.connect(peer_address)
def _config_ssl_ctx(self, verify_mode): def _config_ssl_ctx(self, verify_mode):
SSL_CTX_set_verify(self._ctx.value, verify_mode) SSL_CTX_set_verify(self._ctx.value, verify_mode)
@ -205,12 +392,14 @@ class SSLConnection(object):
SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None) SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None)
if self._ciphers: if self._ciphers:
try: try:
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err: except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err) raise_ssl_error(ERR_NO_CIPHER, err)
if self._user_config_ssl_ctx:
def _copy_server(self): self._user_config_ssl_ctx(self._intf_ssl_ctx)
source = self._sock
def _copy_server(self):
source = self._sock
self._udp_demux = source._udp_demux self._udp_demux = source._udp_demux
rsock = self._udp_demux.get_connection(source._pending_peer_address) rsock = self._udp_demux.get_connection(source._pending_peer_address)
self._ctx = source._ctx self._ctx = source._ctx
@ -230,13 +419,16 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
new_source_rbio = new_source_wbio new_source_rbio = new_source_wbio
BIO_dgram_set_connected(self._wbio.value, BIO_dgram_set_connected(self._wbio.value,
source._pending_peer_address) source._pending_peer_address)
source._ssl = _SSL(SSL_new(self._ctx.value)) source._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_accept_state(source._ssl.value) self._intf_ssl = SSL(source._ssl.value)
source._rbio = new_source_rbio SSL_set_accept_state(source._ssl.value)
source._wbio = new_source_wbio if self._user_config_ssl:
SSL_set_bio(source._ssl.value, self._user_config_ssl(self._intf_ssl)
source._rbio = new_source_rbio
source._wbio = new_source_wbio
SSL_set_bio(source._ssl.value,
new_source_rbio.value, new_source_rbio.value,
new_source_wbio.value) new_source_wbio.value)
new_source_rbio.disown() new_source_rbio.disown()
@ -249,13 +441,16 @@ class SSLConnection(object):
self._rsock = source._rsock self._rsock = source._rsock
self._ctx = source._ctx self._ctx = source._ctx
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
BIO_dgram_set_peer(self._wbio.value, source._peer_address) BIO_dgram_set_peer(self._wbio.value, source._peer_address)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
SSL_set_accept_state(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
if self._do_handshake_on_connect: SSL_set_accept_state(self._ssl.value)
return lambda: self.do_handshake() if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl)
if self._do_handshake_on_connect:
return lambda: self.do_handshake()
def _check_nbio(self): def _check_nbio(self):
timeout = self._sock.gettimeout() timeout = self._sock.gettimeout()
if self._wbio_nb != timeout is not None: if self._wbio_nb != timeout is not None:
@ -307,15 +502,17 @@ class SSLConnection(object):
def _verify_cookie_cb(self, ssl, cookie): def _verify_cookie_cb(self, ssl, cookie):
if self._get_cookie(ssl) != cookie: if self._get_cookie(ssl) != cookie:
raise Exception("DTLS cookie mismatch") raise Exception("DTLS cookie mismatch")
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLSv1, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None): suppress_ragged_eofs=True, ciphers=None,
"""Constructor cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None):
Arguments: """Constructor
Arguments:
these arguments match the ones of the SSLSocket class in the these arguments match the ones of the SSLSocket class in the
standard library's ssl module standard library's ssl module
""" """
@ -331,36 +528,46 @@ class SSLConnection(object):
ciphers = "DEFAULT" ciphers = "DEFAULT"
self._sock = sock self._sock = sock
self._keyfile = keyfile self._keyfile = keyfile
self._certfile = certfile self._certfile = certfile
self._cert_reqs = cert_reqs self._cert_reqs = cert_reqs
self._ca_certs = ca_certs self._ssl_version = ssl_version
self._do_handshake_on_connect = do_handshake_on_connect self._ca_certs = ca_certs
self._suppress_ragged_eofs = suppress_ragged_eofs self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs
self._ciphers = ciphers self._ciphers = ciphers
self._handshake_done = False self._handshake_done = False
self._wbio_nb = self._rbio_nb = False self._wbio_nb = self._rbio_nb = False
if isinstance(sock, SSLConnection): self._user_config_ssl_ctx = cb_user_config_ssl_ctx
post_init = self._copy_server() self._intf_ssl_ctx = None
elif isinstance(sock, _UnwrappedSocket): self._user_config_ssl = cb_user_config_ssl
self._intf_ssl = None
if isinstance(sock, SSLConnection):
post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket):
post_init = self._reconnect_unwrapped() post_init = self._reconnect_unwrapped()
else: else:
try: try:
peer_address = sock.getpeername() peer_address = sock.getpeername()
except socket.error: except socket.error:
peer_address = None peer_address = None
if server_side: if server_side:
post_init = self._init_server(peer_address) post_init = self._init_server(peer_address)
else: else:
post_init = self._init_client(peer_address) post_init = self._init_client(peer_address)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) if self._user_config_ssl:
self._rbio.disown() self._user_config_ssl(self._intf_ssl)
self._wbio.disown() else:
if post_init: SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU)
post_init() DTLS_set_link_mtu(self._ssl.value, 1500)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown()
self._wbio.disown()
if post_init:
post_init()
def get_socket(self, inbound): def get_socket(self, inbound):
"""Retrieve a socket used by this connection """Retrieve a socket used by this connection
@ -426,19 +633,25 @@ class SSLConnection(object):
self._ssl.raw) self._ssl.raw)
dtls_peer_address = DTLSv1_listen(self._ssl.value) dtls_peer_address = DTLSv1_listen(self._ssl.value)
except openssl_error() as err: except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ: if err.ssl_error == SSL_ERROR_WANT_READ:
# This method must be called again to forward the next datagram # This method must be called again to forward the next datagram
_logger.debug("DTLSv1_listen must be resumed") _logger.debug("DTLSv1_listen must be resumed")
return return
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER:
_logger.debug("Mismatching cookie received; aborting handshake") _logger.debug("Wrong version number; aborting handshake")
return raise
_logger.exception("Unexpected error in DTLSv1_listen") elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH:
raise _logger.debug("Mismatching cookie received; aborting handshake")
finally: raise
self._listening = False elif err.errqueue and err.errqueue[0][0] == ERR_NO_SHARED_CIPHER:
self._listening_peer_address = None _logger.debug("No shared cipher; aborting handshake")
if type(peer_address) is tuple: raise
_logger.exception("Unexpected error in DTLSv1_listen")
raise
finally:
self._listening = False
self._listening_peer_address = None
if type(peer_address) is tuple:
_logger.debug("New local peer: %s", dtls_peer_address) _logger.debug("New local peer: %s", dtls_peer_address)
self._pending_peer_address = peer_address self._pending_peer_address = peer_address
else: else:
@ -459,15 +672,17 @@ class SSLConnection(object):
if not self._pending_peer_address: if not self._pending_peer_address:
if not self.listen(): if not self.listen():
_logger.debug("Accept returning without connection") _logger.debug("Accept returning without connection")
return return
new_conn = SSLConnection(self, self._keyfile, self._certfile, True, new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, PROTOCOL_DTLSv1, self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect, self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers) self._suppress_ragged_eofs, self._ciphers,
new_peer = self._pending_peer_address cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
self._pending_peer_address = None cb_user_config_ssl=self._user_config_ssl)
if self._do_handshake_on_connect: new_peer = self._pending_peer_address
self._pending_peer_address = None
if self._do_handshake_on_connect:
# Note that since that connection's socket was just created in its # Note that since that connection's socket was just created in its
# constructor, the following operation must be blocking; hence # constructor, the following operation must be blocking; hence
# handshake-on-connect can only be used with a routing demux if # handshake-on-connect can only be used with a routing demux if
@ -514,40 +729,53 @@ class SSLConnection(object):
self._handshake_done = True self._handshake_done = True
_logger.debug("...completed handshake") _logger.debug("...completed handshake")
def read(self, len=1024): def read(self, len=1024, buffer=None):
"""Read data from connection """Read data from connection
Read up to len bytes and return them. Read up to len bytes and return them.
Arguments: Arguments:
len -- maximum number of bytes to read len -- maximum number of bytes to read
Return value: Return value:
string containing read bytes string containing read bytes
""" """
return self._wrap_socket_library_call( try:
lambda: SSL_read(self._ssl.value, len), ERR_READ_TIMEOUT) return self._wrap_socket_library_call(
lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT)
def write(self, data): except openssl_error() as err:
"""Write data to connection if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
Write data as string of bytes. raise
def write(self, data):
"""Write data to connection
Write data as string of bytes.
Arguments: Arguments:
data -- buffer containing data to be written data -- buffer containing data to be written
Return value: Return value:
number of bytes actually transmitted number of bytes actually transmitted
""" """
return self._wrap_socket_library_call( try:
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) ret = self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
def shutdown(self): except openssl_error() as err:
"""Shut down the DTLS connection if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err)
This method attemps to complete a bidirectional shutdown between raise
peers. For non-blocking sockets, it should be called repeatedly until if ret:
self._handshake_done = True
return ret
def shutdown(self):
"""Shut down the DTLS connection
This method attemps to complete a bidirectional shutdown between
peers. For non-blocking sockets, it should be called repeatedly until
it no longer raises continuation request exceptions. it no longer raises continuation request exceptions.
""" """
@ -598,18 +826,33 @@ class SSLConnection(object):
return return
if binary_form: if binary_form:
return i2d_X509(peer_cert.value) return i2d_X509(peer_cert.value)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
return {} return {}
return decode_cert(peer_cert) return decode_cert(peer_cert)
peer_certificate = getpeercert # compatibility with _ssl call interface peer_certificate = getpeercert # compatibility with _ssl call interface
def cipher(self): def getpeercertchain(self, binary_form=False):
"""Retrieve information about the current cipher try:
stack, num, certs = SSL_get_peer_cert_chain(self._ssl.value)
Return a triple consisting of cipher name, SSL protocol version defining except openssl_error():
its use, and the number of secret bits. Return None if handshaking return
peer_cert_chain = [_Rsrc(cert) for cert in certs]
ret = []
if binary_form:
ret = [i2d_X509(x.value) for x in peer_cert_chain]
elif len(peer_cert_chain):
ret = [decode_cert(x) for x in peer_cert_chain]
return ret
def cipher(self):
"""Retrieve information about the current cipher
Return a triple consisting of cipher name, SSL protocol version defining
its use, and the number of secret bits. Return None if handshaking
has not been completed. has not been completed.
""" """

View File

@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBgzCCASoCCQDdMwvUA/R3lzAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU3WhcNMjcwMzA1MDgzNjU3WjBKMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENB
IEluYzERMA8GA1UEAwwIUmF5Q0FJbmMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AASD4xiQkPryjEwUl/GYeGu1CSA3UC6BUY3TiGED3zrC5Bn/POaVVn9GGOQMZUFi
rCkuTgfg/qeIzTrTFndiR5C/MAoGCCqGSM49BAMDA0cAMEQCIHpd9qMvZZV6iaB5
HrmlyfmhIuLBxDQra20Uxl2Y8N64AiAmPKqwPPp7z6IT2AzAXyHCPoVxwWA0NfGx
nmXoYpDFlw==
-----END CERTIFICATE-----

View File

@ -0,0 +1,19 @@
-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEMWCku4TqKwrQdeECm5LQPCBnr7+cqE4InlRYeObLOxoAoGCCqGSM49
AwEHoUQDQgAEgroFe2fym1V7E3zr/zjuJixpyAjwfig+UTsxxm/04IvXzk2jQCQC
TgbDVohJ8dgh4iEENZv2axWye7XCBzbftQ==
-----END EC PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ==
-----END CERTIFICATE-----

View File

@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ==
-----END CERTIFICATE-----

View File

@ -36,18 +36,19 @@ import socket
from os import path from os import path
from logging import basicConfig, DEBUG from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code basicConfig(level=DEBUG) # set now for dtls import code
from dtls.sslconnection import SSLConnection from dtls.sslconnection import SSLConnection
from dtls.err import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_ZERO_RETURN from dtls.err import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_ZERO_RETURN
def main():
sck = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) def main():
sck.bind(("127.0.0.1", 28000)) sck = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sck.bind(("127.0.0.1", 28000))
sck.settimeout(30) sck.settimeout(30)
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs") cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
scn = SSLConnection( scn = SSLConnection(
sck, sck,
keyfile=path.join(cert_path, "server-key.pem"), keyfile=path.join(cert_path, "keycert.pem"),
certfile=path.join(cert_path, "server-cert.pem"), certfile=path.join(cert_path, "keycert.pem"),
server_side=True, server_side=True,
ca_certs=path.join(cert_path, "ca-cert.pem"), ca_certs=path.join(cert_path, "ca-cert.pem"),
do_handshake_on_connect=False) do_handshake_on_connect=False)
@ -76,7 +77,7 @@ def main():
try: try:
conn.do_handshake() conn.do_handshake()
except SSLError as err: except SSLError as err:
if str(err).startswith("504:"): if err.errno == 504:
continue continue
raise raise
print "Completed handshaking with peer" print "Completed handshaking with peer"
@ -92,7 +93,7 @@ def main():
try: try:
message = conn.read() message = conn.read()
except SSLError as err: except SSLError as err:
if str(err).startswith("502:"): if err.errno == 502:
continue continue
if err.args[0] == SSL_ERROR_ZERO_RETURN: if err.args[0] == SSL_ERROR_ZERO_RETURN:
break break
@ -111,7 +112,7 @@ def main():
s = conn.shutdown() s = conn.shutdown()
s.shutdown(socket.SHUT_RDWR) s.shutdown(socket.SHUT_RDWR)
except SSLError as err: except SSLError as err:
if str(err).startswith("502:"): if err.errno == 502:
continue continue
raise raise
break break

View File

@ -0,0 +1,24 @@
@echo off
set RANDFILE=.rnd
rem # Generate self-signed certificate for the certificate authority
echo Generating CA...
openssl ecparam -name prime256v1 -genkey -out tmp_ca_ec.key
openssl req -config "openssl_ca.cnf" -x509 -new -SHA384 -nodes -key tmp_ca_ec.key -days 3650 -out ca-cert_ec.pem
rem # Generate a certificate request
echo Generating certificate request...
openssl ecparam -name prime256v1 -genkey -out tmp_server_ec.key
openssl req -config "openssl_server.cnf" -new -SHA384 -nodes -key tmp_server_ec.key -out tmp_server_ec.req
rem # Sign the request with the certificate authority's certificate created above
echo Signing certificate request...
openssl req -in tmp_server_ec.req -noout -text
openssl x509 -req -SHA384 -days 3650 -in tmp_server_ec.req -CA ca-cert_ec.pem -CAkey tmp_ca_ec.key -CAcreateserial -out server-cert_ec.pem
rem # Build pem file with private and public keys, ready for unprompted server use
cat tmp_server_ec.key server-cert_ec.pem > keycert_ec.pem
rem # Clean up
rm tmp_ca_ec.key tmp_server_ec.key tmp_server_ec.req ca-cert_ec.srl

View File

@ -1,6 +1,7 @@
RANDFILE = $ENV::HOME/.rnd HOME = .
RANDFILE = $ENV::HOME/.rnd
[ req ]
[ req ]
distinguished_name = req_distinguished_name distinguished_name = req_distinguished_name
prompt = no prompt = no

View File

@ -1,6 +1,7 @@
RANDFILE = $ENV::HOME/.rnd HOME = .
RANDFILE = $ENV::HOME/.rnd
[ req ]
[ req ]
distinguished_name = req_distinguished_name distinguished_name = req_distinguished_name
prompt = no prompt = no

View File

@ -0,0 +1,15 @@
from os import path
import ssl
from socket import socket, AF_INET, SOCK_DGRAM, SHUT_RDWR
from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code
from dtls import do_patch
do_patch()
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
sock = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), cert_reqs=ssl.CERT_REQUIRED, ca_certs=path.join(cert_path, "ca-cert.pem"))
sock.connect(('localhost', 28000))
sock.send('Hi there')
print sock.recv()
sock.unwrap()
sock.shutdown(SHUT_RDWR)

View File

@ -78,11 +78,12 @@ class BasicSocketTests(unittest.TestCase):
def test_constants(self): def test_constants(self):
ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_SSLv3 ssl.PROTOCOL_TLSv1
ssl.PROTOCOL_TLSv1 ssl.PROTOCOL_DTLSv1 # added
ssl.PROTOCOL_DTLSv1 # added ssl.PROTOCOL_DTLSv1_2 # added
ssl.CERT_NONE ssl.PROTOCOL_DTLS # added
ssl.CERT_OPTIONAL ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED ssl.CERT_REQUIRED
def test_dtls_openssl_version(self): def test_dtls_openssl_version(self):
@ -90,22 +91,22 @@ class BasicSocketTests(unittest.TestCase):
t = ssl.DTLS_OPENSSL_VERSION_INFO t = ssl.DTLS_OPENSSL_VERSION_INFO
s = ssl.DTLS_OPENSSL_VERSION s = ssl.DTLS_OPENSSL_VERSION
self.assertIsInstance(n, (int, long)) self.assertIsInstance(n, (int, long))
self.assertIsInstance(t, tuple) self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str) self.assertIsInstance(s, str)
# Some sanity checks follow # Some sanity checks follow
# >= 1.0 # >= 1.0.2
self.assertGreaterEqual(n, 0x10000000) self.assertGreaterEqual(n, 0x10002000)
# < 2.0 # < 2.0
self.assertLess(n, 0x20000000) self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t major, minor, fix, patch, status = t
self.assertGreaterEqual(major, 1) self.assertGreaterEqual(major, 1)
self.assertLess(major, 2) self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0) self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256) self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 0) self.assertGreaterEqual(fix, 2)
self.assertLess(fix, 256) self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0) self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26) self.assertLessEqual(patch, 26)
self.assertGreaterEqual(status, 0) self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15) self.assertLessEqual(status, 15)
# Version string as returned by OpenSSL, the format might change # Version string as returned by OpenSSL, the format might change
@ -298,35 +299,37 @@ class NetworkedTests(unittest.TestCase):
s.close() s.close()
if test_support.verbose: if test_support.verbose:
sys.stdout.write(("\nNeeded %d calls to do_handshake() " + sys.stdout.write(("\nNeeded %d calls to do_handshake() " +
"to establish session.\n") % count) "to establish session.\n") % count)
def test_get_server_certificate(self): def test_get_server_certificate(self):
with test_support.transient_internet() as remote: for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS):
pem = ssl.get_server_certificate(remote, ssl.PROTOCOL_DTLSv1) with test_support.transient_internet() as remote:
if not pem: pem = ssl.get_server_certificate(remote,
self.fail("No server certificate!") prot)
if not pem:
try: self.fail("No server certificate!")
pem = ssl.get_server_certificate(remote,
ssl.PROTOCOL_DTLSv1, try:
ca_certs=OTHER_CERTFILE) pem = ssl.get_server_certificate(remote,
except ssl.SSLError: prot,
#should fail ca_certs=OTHER_CERTFILE)
pass except ssl.SSLError:
else: # should fail
self.fail("Got server certificate %s!" % pem) pass
else:
pem = ssl.get_server_certificate(remote, self.fail("Got server certificate %s!" % pem)
ssl.PROTOCOL_DTLSv1,
ca_certs=ISSUER_CERTFILE) pem = ssl.get_server_certificate(remote,
if not pem: prot,
self.fail("No server certificate!") ca_certs=ISSUER_CERTFILE)
if test_support.verbose: if not pem:
sys.stdout.write("\nVerified certificate is\n%s\n" % pem) self.fail("No server certificate!")
if test_support.verbose:
class ThreadedEchoServer(threading.Thread): sys.stdout.write("\nVerified certificate is\n%s\n" % pem)
class ConnectionHandler(threading.Thread): class ThreadedEchoServer(threading.Thread):
class ConnectionHandler(threading.Thread):
"""A mildly complicated class, because we want it to work both """A mildly complicated class, because we want it to work both
with and without the SSL wrapper around the socket connection, so with and without the SSL wrapper around the socket connection, so
@ -529,18 +532,20 @@ class ThreadedEchoServer(threading.Thread):
if acc_ret: if acc_ret:
newconn, connaddr = acc_ret newconn, connaddr = acc_ret
if test_support.verbose and self.chatty: if test_support.verbose and self.chatty:
sys.stdout.write(' server: new connection from ' sys.stdout.write(' server: new connection from '
+ str(connaddr) + '\n') + str(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn) handler = self.ConnectionHandler(self, newconn)
handler.start() handler.start()
except socket.timeout: except socket.timeout:
pass pass
except KeyboardInterrupt: except ssl.SSLError:
self.stop() pass
self.sock.close() except KeyboardInterrupt:
self.stop()
def register_handler(self, add): self.sock.close()
with self.num_handlers_lock:
def register_handler(self, add):
with self.num_handlers_lock:
if add: if add:
self.num_handlers += 1 self.num_handlers += 1
else: else:
@ -574,6 +579,9 @@ class AsyncoreEchoServer(threading.Thread):
# Complete the handshake # Complete the handshake
self.handle_read_event() self.handle_read_event()
def __hash__(self):
return hash(self.socket)
def readable(self): def readable(self):
while self.socket.pending() > 0: while self.socket.pending() > 0:
self.handle_read_event() self.handle_read_event()
@ -1034,17 +1042,40 @@ class ThreadedTests(unittest.TestCase):
"certs", "badkey.pem")) "certs", "badkey.pem"))
def test_protocol_dtlsv1(self): def test_protocol_dtlsv1(self):
"""Connecting to a DTLSv1 server with various client options""" """Connecting to a DTLSv1 server with various client options"""
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) # server: 1.0 - client: 1.0 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True)
ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, ssl.CERT_OPTIONAL)
ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
def test_starttls(self): # server: any - client: 1.0 and 1.2(any) -> ok
"""Switching from clear text to encrypted and back again.""" try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True,
ssl.CERT_REQUIRED)
# server: 1.0 - client: 1.2 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.0 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False,
ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.2 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS",
"msg 5", "msg 6") "msg 5", "msg 6")
@ -1057,13 +1088,13 @@ class ThreadedTests(unittest.TestCase):
server.start(flag) server.start(flag)
# wait for it to start # wait for it to start
flag.wait() flag.wait()
# try to connect # try to connect
wrapped = False wrapped = False
try: try:
s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM)) s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
s = s.unwrap() s = s.unwrap()
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
for indata in msgs: for indata in msgs:
if test_support.verbose: if test_support.verbose:

View File

@ -0,0 +1,654 @@
# -*- encoding: utf-8 -*-
# Test the support for DTLS through the SSL module. Adapted from the Python
# standard library's test_ssl.py regression test module by Björn Freise.
import unittest
import threading
import sys
import socket
import os
import pprint
from logging import basicConfig, DEBUG, getLogger
# basicConfig(level=DEBUG, format="%(asctime)s - %(threadName)-10s - %(name)s - %(levelname)s - %(message)s")
_logger = getLogger(__name__)
import ssl
from dtls import do_patch, error_codes
from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT
HOST = "localhost"
CHATTY = True
CHATTY_CLIENT = True
class ThreadedEchoServer(threading.Thread):
def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None,
ciphers=None, curves=None, sigalgs=None,
mtu=None, server_key_exchange_curve=None, server_cert_options=None,
chatty=True):
if ssl_version is None:
ssl_version = ssl.PROTOCOL_DTLSv1
if certreqs is None:
certreqs = ssl.CERT_NONE
self.certificate = certificate
self.protocol = ssl_version
self.certreqs = certreqs
self.cacerts = cacerts
self.ciphers = ciphers
self.curves = curves
self.sigalgs = sigalgs
self.mtu = mtu
self.server_key_exchange_curve = server_key_exchange_curve
self.server_cert_options = server_cert_options
self.chatty = chatty
self.flag = None
self.sock = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=self.certificate,
certfile=self.certificate,
server_side=True,
cert_reqs=self.certreqs,
ssl_version=self.protocol,
ca_certs=self.cacerts,
ciphers=self.ciphers,
curves=self.curves,
sigalgs=self.sigalgs,
user_mtu=self.mtu,
server_key_exchange_curve=self.server_key_exchange_curve,
server_cert_options=self.server_cert_options)
if self.chatty:
sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
self.sock.bind((HOST, 0))
self.port = self.sock.getsockname()[1]
self.active = False
threading.Thread.__init__(self)
self.daemon = True
def start(self, flag=None):
self.flag = flag
self.starter = threading.current_thread().ident
threading.Thread.start(self)
def run(self):
self.sock.settimeout(0.05)
self.sock.listen(0)
self.active = True
if self.flag:
# signal an event
self.flag.set()
while self.active:
try:
acc_ret = self.sock.recvfrom(4096)
if acc_ret:
newdata, connaddr = acc_ret
if self.chatty:
sys.stdout.write(' server: new data from ' + str(connaddr) + '\n')
self.sock.sendto(newdata.lower(), connaddr)
except socket.timeout:
pass
except KeyboardInterrupt:
self.stop()
except Exception as e:
if self.chatty:
sys.stdout.write(' server: error ' + str(e) + '\n')
pass
if self.chatty:
sys.stdout.write(' server: closing socket as %s\n' % str(self.sock))
self.sock.close()
def stop(self):
self.active = False
if self.starter != threading.current_thread().ident:
return
self.join() # don't allow spawning new handlers after we've checked
CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "keycert.pem")
CERTFILE_EC = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "keycert_ec.pem")
ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "ca-cert.pem")
ISSUER_CERTFILE_EC = os.path.join(os.path.dirname(__file__) or os.curdir, "certs", "ca-cert_ec.pem")
# certfile, protocol, certreqs, cacertsfile,
# ciphers=None, curves=None, sigalgs=None,
tests = [
{'testcase':
{'name': 'standard dtls v1',
'desc': 'Standard DTLS v1 test with out-of-the box configuration and RSA certificate',
'start_server': True},
'input':
{'certfile': CERTFILE,
'protocol': ssl.PROTOCOL_DTLSv1,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'standard dtls v1_2',
'desc': 'Standard DTLS v1_2 test with out-of-the box configuration and ECDSA certificate',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'protocol version mismatch',
'desc': 'Client and server have different protocol versions',
'start_server': True},
'input':
{'certfile': CERTFILE,
'protocol': ssl.PROTOCOL_DTLSv1,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_WRONG_SSL_VERSION,
'exception': None}},
{'testcase':
{'name': 'certificate verify fails',
'desc': 'Server certificate cannot be verified by client',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED,
'exception': None}},
{'testcase':
{'name': 'no matching curve',
'desc': 'Client doesn\'t support curve used by server ECDSA certificate',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': 'secp384r1',
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching curve',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': 'prime256v1',
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'no host',
'desc': 'No server port is listening',
'start_server': False},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_PORT_UNREACHABLE,
'exception': None}},
{'testcase':
{'name': 'no matching sigalgs',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': "RSA+SHA256"},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching sigalgs',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': None,
'client_curves': None,
'client_sigalgs': "ECDSA+SHA256"},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
{'testcase':
{'name': 'no matching cipher',
'desc': 'Server using a ECDSA certificate while client is only able to use RSA encryption',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': "AES256-SHA",
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': False,
'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
'exception': None}},
{'testcase':
{'name': 'matching cipher',
'desc': '',
'start_server': True},
'input':
{'certfile': CERTFILE_EC,
'protocol': ssl.PROTOCOL_DTLSv1_2,
'certreqs': None,
'cacertsfile': ISSUER_CERTFILE_EC,
'ciphers': None,
'curves': None,
'sigalgs': None,
'client_certfile': None,
'client_protocol': ssl.PROTOCOL_DTLSv1_2,
'client_certreqs': ssl.CERT_REQUIRED,
'client_cacertsfile': ISSUER_CERTFILE_EC,
'client_ciphers': "ECDHE-ECDSA-AES256-SHA",
'client_curves': None,
'client_sigalgs': None},
'result':
{'ret_success': True,
'error_code': None,
'exception': None}},
]
def params_test(start_server, certfile, protocol, certreqs, cacertsfile,
client_certfile=None, client_protocol=None, client_certreqs=None, client_cacertsfile=None,
ciphers=None, curves=None, sigalgs=None,
client_ciphers=None, client_curves=None, client_sigalgs=None,
mtu=1500, server_key_exchange_curve=None, server_cert_options=None,
indata="FOO\n", chatty=False, connectionchatty=False):
"""
Launch a server, connect a client to it and try various reads
and writes.
"""
server = ThreadedEchoServer(certfile,
ssl_version=protocol,
certreqs=certreqs,
cacerts=cacertsfile,
ciphers=ciphers,
curves=curves,
sigalgs=sigalgs,
mtu=mtu,
server_key_exchange_curve=server_key_exchange_curve,
server_cert_options=server_cert_options,
chatty=chatty)
# should we really run the server?
if start_server:
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
else:
server.sock.close()
# try to connect
if client_protocol is None:
client_protocol = protocol
if client_ciphers is None:
client_ciphers = ciphers
if client_curves is None:
client_curves = curves
if client_sigalgs is None:
client_sigalgs = sigalgs
try:
s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=client_certfile,
certfile=client_certfile,
cert_reqs=client_certreqs,
ssl_version=client_protocol,
ca_certs=client_cacertsfile,
ciphers=client_ciphers,
curves=client_curves,
sigalgs=client_sigalgs,
user_mtu=mtu)
s.connect((HOST, server.port))
if connectionchatty:
sys.stdout.write(" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if connectionchatty:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata), 20)], len(outdata),
indata[:min(len(indata), 20)].lower(), len(indata)))
cert = s.getpeercert()
cipher = s.cipher()
if connectionchatty:
sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n")
sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n")
if connectionchatty:
sys.stdout.write(" client: closing connection.\n")
try:
s.close()
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: error closing connection %s...\n" % (repr(e)))
pass
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e)))
return False, e
finally:
if start_server:
server.stop()
return True, None
class TestSequenceMeta(type):
def __new__(mcs, name, bases, dict):
def gen_test(_case, _input, _result):
def test(self):
try:
if CHATTY or CHATTY_CLIENT:
sys.stdout.write("\nTestcase: %s\n" % _case['name'])
ret, e = params_test(_case['start_server'], chatty=CHATTY, connectionchatty=CHATTY_CLIENT, **_input)
if _result['ret_success']:
self.assertEqual(ret, _result['ret_success'])
else:
try:
last_error = e.errqueue[-1][0]
except:
try:
last_error = e.errno
except:
last_error = None
self.assertEqual(last_error, _result['error_code'])
except Exception as e:
raise
return test
for testcase in tests:
_case, _input, _result = testcase.itervalues()
test_name = "test_%s" % _case['name'].lower().replace(' ', '_')
dict[test_name] = gen_test(_case, _input, _result)
return type.__new__(mcs, name, bases, dict)
class WrapperTests(unittest.TestCase):
__metaclass__ = TestSequenceMeta
def setUp(self):
super(WrapperTests, self).setUp()
do_patch()
def test_build_cert_chain(self):
steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT]
chatty, connectionchatty = CHATTY, CHATTY_CLIENT
indata = 'FOO'
certs = dict()
if chatty or connectionchatty:
sys.stdout.write("\nTestcase: test_build_cert_chain\n")
for step in steps:
server = ThreadedEchoServer(certificate=CERTFILE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
certreqs=ssl.CERT_NONE,
cacerts=ISSUER_CERTFILE,
ciphers=None,
curves=None,
sigalgs=None,
mtu=1500,
server_key_exchange_curve=None,
server_cert_options=step,
chatty=chatty)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
try:
s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=None,
certfile=None,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=ISSUER_CERTFILE,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=1500)
s.connect((HOST, server.port))
if connectionchatty:
sys.stdout.write(" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if connectionchatty:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata), 20)], len(outdata),
indata[:min(len(indata), 20)].lower(), len(indata)))
# cert = s.getpeercert()
# cipher = s.cipher()
# if connectionchatty:
# sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n")
# sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n")
certs[step] = s.getpeercertchain()
if connectionchatty:
sys.stdout.write(" client: closing connection.\n")
try:
s.close()
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: error closing connection %s...\n" % (repr(e)))
pass
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e)))
raise
finally:
server.stop()
if chatty:
sys.stdout.write("certs:\n")
for step in steps:
sys.stdout.write("SSL_CTX_build_cert_chain: %s\n%s\n" % (step, pprint.pformat(certs[step])))
self.assertNotEqual(certs[steps[0]], certs[steps[1]])
self.assertEqual(len(certs[steps[0]]) - len(certs[steps[1]]), 1)
def test_set_ecdh_curve(self):
steps = {
# server, client, result
'all auto': (None, None, True), # Auto
'client restricted': (None, "secp256k1:prime256v1", True), # client can handle key curve
'client too restricted': (None, "secp256k1", False), # client _cannot_ handle key curve
'client minimum': (None, "prime256v1", True), # client can only handle key curve
'server restricted': ("secp384r1", None, True), # client can handle key curve
'server one, client two': ("secp384r1", "prime256v1:secp384r1", True), # client can handle key curve
'server one, client one': ("secp384r1", "secp384r1", False), # client _cannot_ handle key curve
}
chatty, connectionchatty = CHATTY, CHATTY_CLIENT
indata = 'FOO'
certs = dict()
if chatty or connectionchatty:
sys.stdout.write("\nTestcase: test_ecdh_curve\n")
for step, tmp in steps.iteritems():
if chatty or connectionchatty:
sys.stdout.write("\n Subcase: %s\n" % step)
server_curve, client_curve, result = tmp
server = ThreadedEchoServer(certificate=CERTFILE_EC,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
certreqs=ssl.CERT_NONE,
cacerts=ISSUER_CERTFILE_EC,
ciphers=None,
curves=None,
sigalgs=None,
mtu=1500,
server_key_exchange_curve=server_curve,
server_cert_options=None,
chatty=chatty)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
try:
s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=None,
certfile=None,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=ISSUER_CERTFILE_EC,
ciphers=None,
curves=client_curve,
sigalgs=None,
user_mtu=1500)
s.connect((HOST, server.port))
if connectionchatty:
sys.stdout.write(" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if connectionchatty:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata), 20)], len(outdata),
indata[:min(len(indata), 20)].lower(), len(indata)))
if connectionchatty:
sys.stdout.write(" client: closing connection.\n")
try:
s.close()
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: error closing connection %s...\n" % (repr(e)))
pass
except Exception as e:
if connectionchatty:
sys.stdout.write(" client: aborting with exception %s...\n" % (repr(e)))
if result:
raise
finally:
server.stop()
pass
if __name__ == '__main__':
unittest.main()

View File

@ -54,6 +54,18 @@ class _BIO(_Rsrc):
if self.owned: if self.owned:
_logger.debug("Freeing BIO: %d", self.raw) _logger.debug("Freeing BIO: %d", self.raw)
from openssl import BIO_free from openssl import BIO_free
BIO_free(self._value) BIO_free(self._value)
self.owned = False self.owned = False
self._value = None self._value = None
class _EC_KEY(_Rsrc):
"""EC KEY wrapper"""
def __init__(self, value):
super(_EC_KEY, self).__init__(value)
def __del__(self):
_logger.debug("Freeing EC_KEY: %d", self.raw)
from openssl import EC_KEY_free
EC_KEY_free(self._value)
self._value = None

361
dtls/wrapper.py 100644
View File

@ -0,0 +1,361 @@
# -*- encoding: utf-8 -*-
# DTLS Socket: A wrapper for a server and client using a DTLS connection.
# Copyright 2017 Björn Freise
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# The License is also distributed with this work in the file named "LICENSE."
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DTLS Socket
This wrapper encapsulates the state and behavior associated with the connection
between the OpenSSL library and an individual peer when using the DTLS
protocol.
Classes:
DtlsSocket -- DTLS Socket wrapper for use as a client or server
"""
import select
from logging import getLogger
import ssl
import socket
from patch import do_patch
do_patch()
from sslconnection import SSLContext, SSL
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK
_logger = getLogger(__name__)
class DtlsSocket(object):
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
def __init__(self,
peerOrSock,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=None,
server_key_exchange_curve=None,
server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None:
server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False
self._peer = None
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
self._sigalgs = sigalgs
self._user_mtu = user_mtu
self._server_key_exchange_curve = server_key_exchange_curve
self._server_cert_options = server_cert_options
# Default socket creation
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(peerOrSock, tuple):
# Address tuple
self._peer = peerOrSock
else:
# Socket, use given
sock = peerOrSock
self._sock = ssl.wrap_socket(sock,
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=self._ciphers,
cb_user_config_ssl_ctx=self.user_config_ssl_ctx,
cb_user_config_ssl=self.user_config_ssl)
if self._server_side:
self._clients = {}
self._timeout = None
if self._peer:
self._sock.bind(self._peer)
self._sock.listen(0)
else:
if self._peer:
self._sock.connect(self._peer)
def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item)
raise AttributeError
def user_config_ssl_ctx(self, _ctx):
"""
:param SSLContext _ctx:
"""
_ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers:
_ctx.set_ciphers(self._ciphers)
if self._curves:
_ctx.set_curves(self._curves)
if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs)
if self._server_side:
_ctx.build_cert_chain(flags=self._server_cert_options)
_ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve)
def user_config_ssl(self, _ssl):
"""
:param SSL _ssl:
"""
if self._user_mtu:
_ssl.set_link_mtu(self._user_mtu)
def settimeout(self, t):
if self._server_side:
self._timeout = t
else:
self._sock.settimeout(t)
def close(self):
if self._server_side:
for cli in self._clients.keys():
cli.close()
else:
self._sock.unwrap()
self._sock.close()
def write(self, data):
# return self._sock.write(data)
return self.sendto(data, self._peer)
def read(self, len=1024):
# return self._sock.read(len=len)
return self.recvfrom(len)[0]
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
else:
return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout as e_timeout:
raise e_timeout
try:
for conn in r: # type: ssl.SSLSocket
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
else:
# Handshake
if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn)
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf and conn in self._clients:
return buf, self._clients[conn].getAddr()
except Exception as e:
raise e
try:
for conn in self._getClientReadingSockets():
if conn.get_timeout():
conn.handle_timeout()
except Exception as e:
raise e
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e_ssl:
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
return '', self._peer
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
raise e_ssl
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
else:
if buf:
return buf, self._peer
raise socket.timeout
def sendto(self, buf, address):
if self._server_side:
return self._sendto_from_server_side(buf, address)
else:
return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems():
if client.getAddr() == address:
return self._clientWrite(conn, buf)
return 0
def _sendto_from_client_side(self, buf, address):
while True:
try:
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e_ssl:
if str(e_ssl).startswith("503:"):
# The write operation timed out
continue
raise e_ssl
else:
if bytes_sent:
break
return bytes_sent
def _getClientReadingSockets(self):
return [x for x in self._clients.keys()]
def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn):
return conn is self._sock
def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn):
_logger.debug('+' * 60)
ret = None
try:
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e_accept:
raise e_accept
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port)
self._clientDoHandshake(client)
def _clientDoHandshake(self, conn):
_logger.debug('-' * 60)
conn.setblocking(False)
try:
conn.do_handshake()
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e_handshake:
if str(e_handshake).startswith("504:"):
pass
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e_handshake
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
ret = None
try:
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e_read:
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
self._clientDrop(conn)
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
self._clientDrop(conn, error=e_read)
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
return ret
def _clientWrite(self, conn, data):
_logger.debug('#' * 60)
ret = None
try:
_data = data
if False:
_data = data.raw
ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e_write:
raise e_write
return ret
def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60)
try:
if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients:
del self._clients[conn]
conn.unwrap()
conn.close()
except Exception as e_drop:
pass

View File

@ -40,24 +40,23 @@ _logger = getLogger(__name__)
class _X509(_Rsrc): class _X509(_Rsrc):
"""Wrapper for the cryptographic library's X509 resource""" """Wrapper for the cryptographic library's X509 resource"""
def __init__(self, value): def __init__(self, value):
super(_X509, self).__init__(value) super(_X509, self).__init__(value)
def __del__(self): def __del__(self):
_logger.debug("Freeing X509: %d", self._value._as_parameter) _logger.debug("Freeing X509: %d", self.raw)
X509_free(self._value) X509_free(self._value)
self._value = None self._value = None
class _STACK(_Rsrc): class _STACK(_Rsrc):
"""Wrapper for the cryptographic library's stacks""" """Wrapper for the cryptographic library's stacks"""
def __init__(self, value): def __init__(self, value):
super(_STACK, self).__init__(value) super(_STACK, self).__init__(value)
def __del__(self): def __del__(self):
_logger.debug("Freeing stack: %d", self._value._as_parameter) _logger.debug("Freeing stack: %d", self.raw)
sk_pop_free(self._value) sk_pop_free(self._value)
self._value = None self._value = None
def decode_cert(cert): def decode_cert(cert):
"""Convert an X509 certificate into a Python dictionary """Convert an X509 certificate into a Python dictionary

View File

@ -33,7 +33,7 @@ for scheme in INSTALL_SCHEMES.values():
scheme['data'] = scheme['purelib'] scheme['data'] = scheme['purelib']
NAME = "Dtls" NAME = "Dtls"
VERSION = "1.0.1" VERSION = "1.0.2"
DIST_DIR = "dist" DIST_DIR = "dist"
FORMAT_TO_SUFFIX = { "zip": ".zip", FORMAT_TO_SUFFIX = { "zip": ".zip",