diff --git a/dtls/openssl.py b/dtls/openssl.py index 1331a01..1e26006 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -84,6 +84,7 @@ BIO_NOCLOSE = 0x00 BIO_CLOSE = 0x01 SSLEAY_VERSION = 0 SSL_OP_NO_COMPRESSION = 0x00020000 +SSL_OP_NO_QUERY_MTU = 0x00001000 SSL_VERIFY_NONE = 0x00 SSL_VERIFY_PEER = 0x01 SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02 @@ -97,17 +98,58 @@ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = 0x0100 SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200 SSL_SESS_CACHE_NO_INTERNAL = \ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE +SSL_BUILD_CHAIN_FLAG_UNTRUSTED = 0x1 +SSL_BUILD_CHAIN_FLAG_NO_ROOT = 0x2 +SSL_BUILD_CHAIN_FLAG_CHECK = 0x4 +SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = 0x8 +SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = 0x10 SSL_FILE_TYPE_PEM = 1 GEN_DIRNAME = 4 NID_subject_alt_name = 85 CRYPTO_LOCK = 1 +SSL_ST_MASK = 0x0FFF +SSL_ST_CONNECT = 0x1000 +SSL_ST_ACCEPT = 0x2000 +SSL_ST_INIT = (SSL_ST_CONNECT | SSL_ST_ACCEPT) +SSL_ST_BEFORE = 0x4000 +SSL_ST_OK = 0x03 +SSL_ST_RENEGOTIATE = (0x04 | SSL_ST_INIT) +SSL_ST_ERR = 0x05 + +SSL_CB_LOOP = 0x01 +SSL_CB_EXIT = 0x02 +SSL_CB_READ = 0x04 +SSL_CB_WRITE = 0x08 +SSL_CB_ALERT = 0x4000 +SSL_CB_READ_ALERT = (SSL_CB_ALERT | SSL_CB_READ) +SSL_CB_WRITE_ALERT = (SSL_CB_ALERT | SSL_CB_WRITE) +SSL_CB_ACCEPT_LOOP = (SSL_ST_ACCEPT | SSL_CB_LOOP) +SSL_CB_ACCEPT_EXIT = (SSL_ST_ACCEPT | SSL_CB_EXIT) +SSL_CB_CONNECT_LOOP = (SSL_ST_CONNECT | SSL_CB_LOOP) +SSL_CB_CONNECT_EXIT = (SSL_ST_CONNECT | SSL_CB_EXIT) +SSL_CB_HANDSHAKE_START = 0x10 +SSL_CB_HANDSHAKE_DONE = 0x20 + # # Integer constants - internal # SSL_CTRL_SET_SESS_CACHE_MODE = 44 SSL_CTRL_SET_READ_AHEAD = 41 SSL_CTRL_OPTIONS = 32 +SSL_CTRL_CLEAR_OPTIONS = 77 +SSL_CTRL_SET_ECDH_AUTO = 94 +SSL_CTRL_BUILD_CERT_CHAIN = 105 +SSL_CTRL_SET_MTU = 17 +SSL_CTRL_GET_CURVES = 90 +SSL_CTRL_SET_CURVES = 91 +SSL_CTRL_SET_CURVES_LIST = 92 +SSL_CTRL_GET_SHARED_CURVE = 93 +SSL_CTRL_SET_SIGALGS = 97 +SSL_CTRL_SET_SIGALGS_LIST = 98 +SSL_CTRL_SET_CLIENT_SIGALGS = 101 +SSL_CTRL_SET_CLIENT_SIGALGS_LIST = 102 +SSL_CTRL_SET_TMP_ECDH = 4 BIO_CTRL_INFO = 3 BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_GET_PEER = 46 @@ -119,6 +161,53 @@ DTLS_CTRL_LISTEN = 75 X509_NAME_MAXLEN = 256 GETS_MAXLEN = 2048 + +class _EllipticCurve(object): + _curves = None + + @classmethod + def _load_elliptic_curves(cls, lib): + num_curves = lib.EC_get_builtin_curves(0, 0) + if num_curves > 0: + builtin_curves = create_string_buffer(sizeof(EC_builtin_curve) * num_curves) + lib.EC_get_builtin_curves(builtin_curves, num_curves) + return set(cls.from_nid(lib, c.nid) for c in cast(builtin_curves, POINTER(EC_builtin_curve))[:num_curves]) + return set() + + @classmethod + def _get_elliptic_curves(cls, lib): + if cls._curves is None: + cls._curves = cls._load_elliptic_curves(lib) + return cls._curves + + @classmethod + def from_nid(cls, lib, nid): + return cls(lib, nid, cast(lib.OBJ_nid2sn(nid), c_char_p).value.decode("ascii")) + + def __init__(self, lib, nid, name): + self._lib = lib + self._nid = nid + self.name = name + + def __repr__(self): + return "" % (self.name,) + + def as_EC_KEY(self): + key = self._lib.EC_KEY_new_by_curve_name(self._nid) + return key + + +def get_elliptic_curves(): + return _EllipticCurve._get_elliptic_curves(libcrypto) + + +def get_elliptic_curve(name): + for curve in get_elliptic_curves(): + if curve.name == name: + return curve + raise ValueError("unknown curve name", name) + + # # Parameter data types # @@ -171,6 +260,11 @@ class SSL(FuncParam): super(SSL, self).__init__(value) +class EC_builtin_curve(Structure): + _fields_ = [("nid", c_int), + ("comment", c_char_p)] + + class BIO(FuncParam): def __init__(self, value): super(BIO, self).__init__(value) @@ -472,14 +566,22 @@ _subst = {c_long_parm: c_long} _sigs = {} __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSLEAY_VERSION", - "SSL_OP_NO_COMPRESSION", + "SSL_OP_NO_COMPRESSION", "SSL_OP_NO_QUERY_MTU", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER", "SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE", "SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT", "SSL_SESS_CACHE_SERVER", "SSL_SESS_CACHE_BOTH", "SSL_SESS_CACHE_NO_AUTO_CLEAR", "SSL_SESS_CACHE_NO_INTERNAL_LOOKUP", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", + "SSL_ST_MASK", "SSL_ST_CONNECT", "SSL_ST_ACCEPT", "SSL_ST_INIT", "SSL_ST_BEFORE", "SSL_ST_OK", + "SSL_ST_RENEGOTIATE", "SSL_ST_ERR", "SSL_CB_LOOP", "SSL_CB_EXIT", "SSL_CB_READ", "SSL_CB_WRITE", + "SSL_CB_ALERT", "SSL_CB_READ_ALERT", "SSL_CB_WRITE_ALERT", + "SSL_CB_ACCEPT_LOOP", "SSL_CB_ACCEPT_EXIT", + "SSL_CB_CONNECT_LOOP", "SSL_CB_CONNECT_EXIT", + "SSL_CB_HANDSHAKE_START", "SSL_CB_HANDSHAKE_DONE", "SSL_FILE_TYPE_PEM", + "SSL_BUILD_CHAIN_FLAG_NO_ROOT", + "SSL_state_string_long", "GEN_DIRNAME", "NID_subject_alt_name", "CRYPTO_LOCK", "CRYPTO_set_locking_callback", @@ -489,15 +591,28 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "BIO_dgram_set_connected", "BIO_dgram_get_peer", "BIO_dgram_set_peer", "BIO_set_nbio", + "SSL_CTX_set_ecdh_auto", + "SSL_CTX_build_cert_chain", "SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead", "SSL_CTX_set_options", + "SSL_set_options", "SSL_clear_options", "SSL_get_options", + "SSL_get1_curves", "SSL_CTX_set1_curves", "SSL_CTX_set1_curves_list", "SSL_set1_curves_list", + "SSL_CTX_set_tmp_ecdh", + "SSL_CTX_set_info_callback", + "SSL_set_mtu", "SSL_read", "SSL_write", "SSL_CTX_set_cookie_cb", + "SSL_set1_client_sigalgs_list", "SSL_set1_client_sigalgs", + "SSL_CTX_set1_client_sigalgs_list", "SSL_CTX_set1_client_sigalgs", + "SSL_set1_sigalgs_list", "SSL_set1_sigalgs", + "SSL_CTX_set1_sigalgs_list", "SSL_CTX_set1_sigalgs", "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", "X509_get_notAfter", "ASN1_item_d2i", "GENERAL_NAME_print", + "EC_KEY_free", "sk_value", "sk_pop_free", + "get_elliptic_curves", "i2d_X509"] # note: the following map adds to this list map(lambda x: _make_function(*x), ( @@ -511,6 +626,9 @@ map(lambda x: _make_function(*x), ( ("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)), ("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)), + ("DTLSv1_2_client_method", libssl, ((DTLSv1Method, "ret"),)), + ("DTLSv1_2_server_method", libssl, ((DTLSv1Method, "ret"),)), + ("DTLS_server_method", libssl, ((DTLSv1Method, "ret"),)), ("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), ("SSL_CTX_free", libssl, ((None, "ret"), (SSLCTX, "ctx"))), ("SSL_CTX_set_cookie_generate_cb", libssl, @@ -563,6 +681,8 @@ map(lambda x: _make_function(*x), ( ("SSL_CTX_set_verify", libssl, ((None, "ret"), (SSLCTX, "ctx"), (c_int, "mode"), (c_void_p, "verify_callback", 1, None))), + ("SSL_CTX_set_verify_depth", libssl, + ((None, "ret"), (SSLCTX, "ctx"), (c_int, "depth"))), ("SSL_accept", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_connect", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_set_connect_state", libssl, ((None, "ret"), (SSL, "ssl"))), @@ -630,6 +750,22 @@ map(lambda x: _make_function(*x), ( ("SSL_CIPHER_get_bits", libssl, ((c_int, "ret"), (SSL_CIPHER, "cipher"), (POINTER(c_int), "alg_bits", 1, None)), True, None), + ("EC_get_builtin_curves", libcrypto, + ((c_int, "ret"), (POINTER(EC_builtin_curve), "r"), (c_int, "nitems"))), + ("EC_KEY_new_by_curve_name", libcrypto, + ((POINTER(c_char), "ret"), (c_int, "nid"))), + ("EC_KEY_free", libcrypto, + ((None, "ret"), (POINTER(c_char), "key")), False), + ("OBJ_nid2sn", libcrypto, + ((c_char_p, "ret"), (c_int, "n"))), + ("EC_curve_nist2nid", libcrypto, + ((c_int, "ret"), (POINTER(c_char), "name")), True, None), + ("EC_curve_nid2nist", libcrypto, + ((c_char_p, "ret"), (c_int, "nid")), True, None), + ("SSL_CTX_set_info_callback", libssl, + ((None, "ret"), (SSLCTX, "ctx"), (c_void_p, "callback")), False), + ("SSL_state_string_long", libssl, + ((c_char_p, "ret"), (SSL, "ssl")), False), )) # @@ -648,6 +784,15 @@ def CRYPTO_set_locking_callback(locking_function): _locking_cb = _rvoid_int_int_charp_int(py_locking_function) _CRYPTO_set_locking_callback(_locking_cb) +def SSL_set_mtu(ssl, mtu): + return _SSL_ctrl(ssl, SSL_CTRL_SET_MTU, mtu, None) + +def SSL_CTX_build_cert_chain(ctx, flags): + return _SSL_CTX_ctrl(ctx, SSL_CTRL_BUILD_CERT_CHAIN, flags, None) + +def SSL_CTX_set_ecdh_auto(ctx, onoff): + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_ECDH_AUTO, onoff, None) + def SSL_CTX_set_session_cache_mode(ctx, mode): # Returns the previous value of mode _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SESS_CACHE_MODE, mode, None) @@ -658,7 +803,47 @@ def SSL_CTX_set_read_ahead(ctx, m): def SSL_CTX_set_options(ctx, options): # Returns the new option bitmaks after adding the given options - _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None) + +def SSL_set_options(ssl, op): + return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, op, None) + +def SSL_clear_options(ssl, op): + return _SSL_ctrl(ssl, SSL_CTRL_CLEAR_OPTIONS, op, None) + +def SSL_get_options(ssl): + return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, 0, None) + +def SSL_get1_curves(ssl, s): + _s = cast(s, POINTER(c_int)) + if s: + mem = None + cnt = SSL_get1_curves(ssl, 0) + if cnt >= s: + mem = create_string_buffer(sizeof(POINTER(c_int)) * s) + ret = _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, s, mem) + return [x for x in cast(mem, POINTER(c_int))[:s]] + else: + return _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, 0, _s) + +def SSL_CTX_set1_curves(ctx, clist, clistlen): + _curves = (c_int * len(clist))(*clist) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES, len(_curves), _curves) + +def SSL_get_shared_curve(ssl, n): + return _SSL_ctrl(ssl, SSL_CTRL_GET_SHARED_CURVE, n, 0) + +def SSL_CTX_set1_curves_list(ctx, s): + _s = cast(s, POINTER(c_char)) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES_LIST, 0, _s) + +def SSL_set1_curves_list(ssl, s): + _s = cast(s, POINTER(c_char)) + return _SSL_ctrl(ssl, SSL_CTRL_SET_CURVES_LIST, 0, _s) + +def SSL_CTX_set_tmp_ecdh(ctx, ecdh): + # return 1 on success and 0 on failure + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TMP_ECDH, 0, ecdh) _rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), POINTER(c_uint)) @@ -745,6 +930,8 @@ def SSL_write(ssl, data): str_data = data elif hasattr(data, "tobytes") and callable(data.tobytes): str_data = data.tobytes() + elif isinstance(data, ctypes.Array): + str_data = data.raw else: str_data = str(data) return _SSL_write(ssl, str_data, len(str_data)) @@ -827,3 +1014,73 @@ def i2d_X509(x509): bio = _BIO(BIO_new(BIO_s_mem())) _i2d_X509_bio(bio.value, x509) return BIO_get_mem_data(bio.value) + +def EC_KEY_free(key): + _EC_KEY_free(cast(key, POINTER(c_char))) + +_rvoid_voidp_int_int = CFUNCTYPE(None, c_void_p, c_int, c_int) + +def SSL_CTX_set_info_callback(ctx, callback): + """ + Set the info callback + + :param callback: The Python callback to use + :return: None + """ + def py_info_callback(ssl, where, ret): + try: + callback(SSL(ssl), where, ret) + except: + pass + return + + global _info_callback + _info_callback = _rvoid_voidp_int_int(py_info_callback) + _SSL_CTX_set_info_callback(ctx, _info_callback) + + +def SSL_state_string_long(ssl): + try: + ret = _SSL_state_string_long(ssl) + except: + pass + return ret + +# sigalgs_list: (Only for DTLS v1.2) +# +# The short or long name values for digests can be used in a string (for example "MD5", "SHA1", "SHA224", "SHA256", "SHA384", "SHA512") +# and the public key algorithm strings "RSA", "DSA" or "ECDSA". +# The use of MD5 as a digest is strongly discouraged due to security weaknesses. +# Example: "ECDSA+SHA256:RSA+SHA256" + +def SSL_CTX_set1_client_sigalgs(ctx, slist, slistlen): + _slist = (c_int * len(slist))(*slist) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist) + +def SSL_CTX_set1_client_sigalgs_list(ctx, s): + _s = cast(s, POINTER(c_char)) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s) + +def SSL_set1_client_sigalgs(ssl, slist, slistlen): + _slist = (c_int * len(slist))(*slist) + return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist) + +def SSL_set1_client_sigalgs_list(ssl, s): + _s = cast(s, POINTER(c_char)) + return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s) + +def SSL_CTX_set1_sigalgs(ctx, slist, slistlen): + _slist = (c_int * len(slist))(*slist) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS, len(_slist), _slist) + +def SSL_CTX_set1_sigalgs_list(ctx, s): + _s = cast(s, POINTER(c_char)) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS_LIST, 0, _s) + +def SSL_set1_sigalgs(ssl, slist, slistlen): + _slist = (c_int * len(slist))(*slist) + return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS, len(_slist), _slist) + +def SSL_set1_sigalgs_list(ssl, s): + _s = cast(s, POINTER(c_char)) + return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS_LIST, 0, _s) diff --git a/dtls/patch.py b/dtls/patch.py index d7d73d4..f3bfeb1 100644 --- a/dtls/patch.py +++ b/dtls/patch.py @@ -34,16 +34,18 @@ has the following effects: PROTOCOL_DTLSv1 for the parameter ssl_version is supported """ -from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error -from socket import AF_INET, SOCK_DGRAM, getaddrinfo -from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE -from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION -from sslconnection import DTLS_OPENSSL_VERSION_INFO -from err import raise_as_ssl_module_error -from types import MethodType +from socket import socket, getaddrinfo, _delegate_methods, error as socket_error +from socket import AF_INET, SOCK_DGRAM +from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE +from types import MethodType, BuiltinMethodType from weakref import proxy import errno +from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2 +from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO +from err import raise_as_ssl_module_error + + def do_patch(): import ssl as _ssl # import to be avoided if ssl module is never patched global _orig_SSLSocket_init, _orig_get_server_certificate @@ -51,8 +53,19 @@ def do_patch(): ssl = _ssl if hasattr(ssl, "PROTOCOL_DTLSv1"): return + _orig_wrap_socket = ssl.wrap_socket + ssl.wrap_socket = _wrap_socket + SSLSocket_ = ssl.SSLSocket + class SSLSocket(SSLSocket_): + def __getattr__(self, item): + if hasattr(self, "_sslobj") and hasattr(self._sslobj, item): + return getattr(self._sslobj, item) + raise AttributeError + ssl.SSLSocket = SSLSocket ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 + ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2 ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" + ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2" ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO @@ -62,8 +75,23 @@ def do_patch(): ssl.get_server_certificate = _get_server_certificate raise_as_ssl_module_error() -PROTOCOL_SSLv3 = 1 -PROTOCOL_SSLv23 = 2 + +def _wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None, + cb_user_ssl_ctx_config=None, cb_user_ssl_config=None): + + return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers, + cb_user_ssl_ctx_config=cb_user_ssl_ctx_config, + cb_user_ssl_config=cb_user_ssl_config) + def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): """Retrieve a server certificate @@ -74,10 +102,10 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): If 'ssl_version' is specified, use it in the connection attempt. """ - if ssl_version != PROTOCOL_DTLSv1: + if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2): return _orig_get_server_certificate(addr, ssl_version, ca_certs) - if (ca_certs is not None): + if ca_certs is not None: cert_reqs = ssl.CERT_REQUIRED else: cert_reqs = ssl.CERT_NONE @@ -94,7 +122,8 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): + suppress_ragged_eofs=True, ciphers=None, + cb_user_ssl_ctx_config=None, cb_user_ssl_config=None): is_connection = is_datagram = False if isinstance(sock, SSLConnection): is_connection = True @@ -138,7 +167,8 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None, server_side, cert_reqs, ssl_version, ca_certs, do_handshake_on_connect, - suppress_ragged_eofs, ciphers) + suppress_ragged_eofs, ciphers, + cb_user_ssl_ctx_config, cb_user_ssl_config) else: self._sslobj = sock @@ -150,6 +180,8 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None, self.ciphers = ciphers self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs + self.cb_user_ssl_ctx_config = cb_user_ssl_ctx_config + self.cb_user_ssl_config = cb_user_ssl_config self._makefile_refs = 0 # Perform method substitution and addition (without reference cycle) @@ -169,7 +201,8 @@ def _SSLSocket_listen(self, ignored): self.cert_reqs, self.ssl_version, self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers) + self.suppress_ragged_eofs, self.ciphers, + self.cb_user_ssl_ctx_config, self.cb_user_ssl_config) def _SSLSocket_accept(self): if self._connected: @@ -184,7 +217,8 @@ def _SSLSocket_accept(self): self.cert_reqs, self.ssl_version, self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers) + self.suppress_ragged_eofs, self.ciphers, + self.cb_user_ssl_ctx_config, self.cb_user_ssl_config) return new_ssl_sock, addr def _SSLSocket_real_connect(self, addr, return_errno): @@ -195,7 +229,8 @@ def _SSLSocket_real_connect(self, addr, return_errno): self.cert_reqs, self.ssl_version, self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers) + self.suppress_ragged_eofs, self.ciphers, + self.cb_user_ssl_ctx_config, self.cb_user_ssl_config) try: self._sslobj.connect(addr) except socket_error as e: diff --git a/dtls/prebuilt/win32-x86/libeay32.dll b/dtls/prebuilt/win32-x86/libeay32.dll index 8f8a254..db98e5b 100644 Binary files a/dtls/prebuilt/win32-x86/libeay32.dll and b/dtls/prebuilt/win32-x86/libeay32.dll differ diff --git a/dtls/prebuilt/win32-x86/ssleay32.dll b/dtls/prebuilt/win32-x86/ssleay32.dll index 1a70063..965b609 100644 Binary files a/dtls/prebuilt/win32-x86/ssleay32.dll and b/dtls/prebuilt/win32-x86/ssleay32.dll differ diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 9a5dcef..20447f4 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -48,13 +48,14 @@ from logging import getLogger from os import urandom from select import select from weakref import proxy + from err import openssl_error, InvalidSocketError from err import raise_ssl_error -from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL +from err import SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_SYSCALL from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT -from err import ERR_BOTH_KEY_CERT_FILES_SVR +from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR from x509 import _X509, decode_cert from tlock import tlock_init from openssl import * @@ -63,6 +64,8 @@ from util import _Rsrc, _BIO _logger = getLogger(__name__) PROTOCOL_DTLSv1 = 256 +PROTOCOL_DTLSv1_2 = 258 +PROTOCOL_DTLS = 259 CERT_NONE = 0 CERT_OPTIONAL = 1 CERT_REQUIRED = 2 @@ -122,6 +125,185 @@ class _CallbackProxy(object): return self.ssl_func(self.ssl_connection, *args, **kwargs) +def _ssl_logging_cb(conn, where, return_code): + + def get_alert_desc(return_code): + _alertDesc = { + 0: 'close notify', + 10: 'unexpected message', + 20: 'bad record mac', + 30: 'decompression failure', + 40: 'handshake failure', + 41: 'no certificate', + 42: 'bad certificate', + 43: 'unsupported certificate', + 44: 'certificate revoked', + 45: 'certificate expired', + 46: 'certificate unknown', + 47: 'illegal parameter', + } + _typeDescr = { + 1: 'warning', + 2: 'fatal', + } + + _type = return_code >> 8 + _ret = return_code & 0xFF + + _typeStr = _typeDescr[_type] if _type in _typeDescr else ('unknown (%d)' % _type) + + if _ret in _alertDesc: + return _typeStr, _alertDesc[_ret] + + return _typeStr, str(_ret) + + _state = where & ~SSL_ST_MASK + state = "SSL_undef (%04x)" % _state + + if _state & SSL_ST_INIT == SSL_ST_INIT: + state = "SSL_init" + if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE: + state = "SSL_renew" + elif _state & SSL_ST_CONNECT: + state = "SSL_connect" + elif _state & SSL_ST_ACCEPT: + state = "SSL_accept" + elif _state & SSL_ST_BEFORE: + state = "SSL_before" + + if where & SSL_CB_LOOP: + state += '_loop' + _logger.debug("%s: %s" % (state, SSL_state_string_long(conn))) + + elif where & SSL_CB_ALERT: + op = "read" if where & SSL_CB_READ else "write" + state += '_alert' + _logger.debug("%s %s: %s" % (state, op, ' - '.join(get_alert_desc(return_code)))) + + elif where & SSL_CB_EXIT: + state += '_exit' + if return_code == 0: + _logger.debug("%s: failed in %s" % (state, SSL_state_string_long(conn))) + elif return_code < 0: + _logger.debug("%s: error %d in %s" % (state, return_code, SSL_state_string_long(conn))) + else: + _logger.debug("%s: %s" % (state, SSL_state_string_long(conn))) + + else: + _logger.debug("%s: %s" % (state, SSL_state_string_long(conn))) + + +class SSLContext(object): + + def __init__(self, ctx): + self._ctx = ctx + + def set_ciphers(self, ciphers): + u''' + s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html + + :param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ... + :return: 1 for success and 0 for failure + ''' + retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers) + return retVal + + def set_sigalgs(self, sigalgs): + u''' + s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html + + :param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ... + :return: 1 for success and 0 for failure + ''' + retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs) + return retVal + + def set_curves(self, curves): + u''' Set supported curves by name, nid or nist. + + :param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ... + :return: 1 for success and 0 for failure + ''' + retVal = None + if isinstance(curves, str): + retVal = SSL_CTX_set1_curves_list(self._ctx, curves) + elif isinstance(curves, tuple): + retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves)) + return retVal + + @staticmethod + def get_ec_nist2nid(nist): + if not isinstance(nist, tuple): + nist = nist.split(":") + nid = tuple(EC_curve_nist2nid(x) for x in nist) + return nid + + @staticmethod + def get_ec_nid2nist(nid): + if not isinstance(nid, tuple): + nid = (nid, ) + nist = ":".join([EC_curve_nid2nist(x) for x in nid]) + return nist + + @staticmethod + def get_ec_available(bAsName=True): + curves = get_elliptic_curves() + return sorted([x.name for x in curves] if bAsName else [x._nid for x in curves]) + + def set_ecdh_curve(self, curve_name=None): + u''' + Used for server only! + + s.a. openssl.exe ecparam -list_curves + + :param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ... + :return: 1 for success and 0 for failure + ''' + if curve_name: + retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0) + avail_curves = get_elliptic_curves() + key = [curve for curve in avail_curves if curve.name == curve_name][0].as_EC_KEY() + retVal = SSL_CTX_set_tmp_ecdh(self._ctx, key) + EC_KEY_free(key) + else: + retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1) + return retVal + + def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT): + u''' + Used for server side only! + + :param flags: + :return: 1 for success and 0 for failure + ''' + retVal = SSL_CTX_build_cert_chain(self._ctx, flags) + return retVal + + def set_ssl_logging(self, enable=False, func=_ssl_logging_cb): + u''' Enable or disable SSL logging + + :param True | False enable: Enable or disable SSL logging + :param func: Callback function for logging + ''' + if enable: + SSL_CTX_set_info_callback(self._ctx, func) + else: + SSL_CTX_set_info_callback(self._ctx, 0) + + +class SSL(object): + + def __init__(self, ssl): + self._ssl = ssl + + def set_mtu(self, mtu=None): + if mtu: + SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU) + SSL_set_mtu(self._ssl, mtu) + else: + SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU) + + class SSLConnection(object): """DTLS peer association @@ -149,7 +331,12 @@ class SSLConnection(object): else: self._rsock = rsock self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) - self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) + server_method = DTLS_server_method + if self._ssl_version == PROTOCOL_DTLSv1_2: + server_method = DTLSv1_2_server_method + elif self._ssl_version == PROTOCOL_DTLSv1: + server_method = DTLSv1_server_method + self._ctx = _CTX(SSL_CTX_new(server_method())) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) if self._cert_reqs == CERT_NONE: verify_mode = SSL_VERIFY_NONE @@ -179,7 +366,12 @@ class SSLConnection(object): self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._rbio = self._wbio - self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) + client_method = DTLSv1_client_method + if self._ssl_version == PROTOCOL_DTLSv1_2: + client_method = DTLSv1_2_client_method + elif self._ssl_version == PROTOCOL_DTLSv1: + client_method = DTLSv1_client_method + self._ctx = _CTX(SSL_CTX_new(client_method())) if self._cert_reqs == CERT_NONE: verify_mode = SSL_VERIFY_NONE else: @@ -197,7 +389,9 @@ class SSLConnection(object): # corruption when packet loss occurs SSL_CTX_set_options(self._ctx.value, SSL_OP_NO_COMPRESSION) if self._certfile: - SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) + # SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) + SSL_CTX_use_certificate_file(self._ctx.value, self._certfile, + SSL_FILE_TYPE_PEM) if self._keyfile: SSL_CTX_use_PrivateKey_file(self._ctx.value, self._keyfile, SSL_FILE_TYPE_PEM) @@ -208,6 +402,8 @@ class SSLConnection(object): SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) except openssl_error() as err: raise_ssl_error(ERR_NO_CIPHER, err) + if self._user_ssl_ctx_config: + self._user_ssl_ctx_config(SSLContext(self._ctx.value)) def _copy_server(self): source = self._sock @@ -241,6 +437,8 @@ class SSLConnection(object): new_source_wbio.value) new_source_rbio.disown() new_source_wbio.disown() + if self._user_ssl_config: + self._user_ssl_config(SSL(source._ssl.value)) def _reconnect_unwrapped(self): source = self._sock @@ -310,9 +508,11 @@ class SSLConnection(object): def __init__(self, sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_DTLSv1, ca_certs=None, + ssl_version=PROTOCOL_DTLS, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): + suppress_ragged_eofs=True, ciphers=None, + cb_user_ssl_ctx_config=None, + cb_user_ssl_config=None): """Constructor Arguments: @@ -330,10 +530,13 @@ class SSLConnection(object): if not ciphers: ciphers = "DEFAULT" + assert ssl_version in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2) + self._sock = sock self._keyfile = keyfile self._certfile = certfile self._cert_reqs = cert_reqs + self._ssl_version = ssl_version self._ca_certs = ca_certs self._do_handshake_on_connect = do_handshake_on_connect self._suppress_ragged_eofs = suppress_ragged_eofs @@ -341,6 +544,9 @@ class SSLConnection(object): self._handshake_done = False self._wbio_nb = self._rbio_nb = False + self._user_ssl_ctx_config = cb_user_ssl_ctx_config + self._user_ssl_config = cb_user_ssl_config + if isinstance(sock, SSLConnection): post_init = self._copy_server() elif isinstance(sock, _UnwrappedSocket): @@ -358,6 +564,8 @@ class SSLConnection(object): SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) self._rbio.disown() self._wbio.disown() + if self._user_ssl_config: + self._user_ssl_config(SSL(self._ssl.value)) if post_init: post_init() @@ -462,9 +670,11 @@ class SSLConnection(object): _logger.debug("Accept returning without connection") return new_conn = SSLConnection(self, self._keyfile, self._certfile, True, - self._cert_reqs, PROTOCOL_DTLSv1, + self._cert_reqs, self._ssl_version, self._ca_certs, self._do_handshake_on_connect, - self._suppress_ragged_eofs, self._ciphers) + self._suppress_ragged_eofs, self._ciphers, + cb_user_ssl_ctx_config=self._user_ssl_ctx_config, + cb_user_ssl_config=self._user_ssl_config) new_peer = self._pending_peer_address self._pending_peer_address = None if self._do_handshake_on_connect: @@ -540,8 +750,12 @@ class SSLConnection(object): number of bytes actually transmitted """ - return self._wrap_socket_library_call( + retVal = self._wrap_socket_library_call( lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) + # for client side ... we want to know when the handshake is completed + if self._wbio is self._rbio and retVal > 0: + self._handshake_done = True + return retVal def shutdown(self): """Shut down the DTLS connection diff --git a/dtls/wrapper.py b/dtls/wrapper.py new file mode 100644 index 0000000..16fa805 --- /dev/null +++ b/dtls/wrapper.py @@ -0,0 +1,286 @@ +# -*- encoding: utf-8 -*- + +import datetime +import select + +from logging import getLogger + +import ssl +import socket +from dtls import do_patch +do_patch() + +_logger = getLogger(__name__) + + +class _ClientSession(object): + + def __init__(self, host, port, handshake_done=False): + self.host = host + self.port = int(port) + self.handshake_done = handshake_done + + def getAddr(self): + return self.host, self.port + + +class DtlsSocket(object): + + def __init__(self, + host, + port, + keyfile=None, + certfile=None, + server_side=False, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_DTLSv1_2, + ca_certs=None, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + ciphers=None, + curves=None, + sigalgs=None, + user_mtu=None): + + self._ssl_logging = False + self._peer = (host, int(port)) + self._server_side = server_side + self._ciphers = ciphers + self._curves = curves + self._sigalgs = sigalgs + self._user_mtu = user_mtu + + self._sock = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + keyfile=keyfile, + certfile=certfile, + server_side=self._server_side, + cert_reqs=cert_reqs, + ssl_version=ssl_version, + ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + ciphers=self._ciphers, + cb_user_ssl_ctx_config=self.user_ssl_ctx_config, + cb_user_ssl_config=self.user_ssl_config) + + if self._server_side: + self._clients = {} + self._timeout = None + + self._sock.bind(self._peer) + self._sock.listen(0) + else: + self._sock.connect(self._peer) + + def user_ssl_ctx_config(self, _ctx): + _ctx.set_ssl_logging(self._ssl_logging) + if self._ciphers: + _ctx.set_ciphers(self._ciphers) + if self._curves: + _ctx.set_curves(self._curves) + if self._sigalgs: + _ctx.set_sigalgs(self._sigalgs) + if self._server_side: + _ctx.build_cert_chain() + _ctx.set_ecdh_curve() # ("secp256k1") + + def user_ssl_config(self, _ssl): + if self._user_mtu: + _ssl.set_mtu(self._user_mtu) + + def settimeout(self, t): + if self._server_side: + self._timeout = t + else: + self._sock.settimeout(t) + + def close(self): + if self._server_side: + for cli in self._clients.keys(): + cli.close() + else: + self._sock.unwrap() + self._sock.close() + + def recvfrom(self, bufsize, flags=0): + if self._server_side: + return self._recvfrom_on_server_side(bufsize, flags=flags) + else: + return self._recvfrom_on_client_side(bufsize, flags=flags) + + def _recvfrom_on_server_side(self, bufsize, flags): + try: + r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout) + + except socket.timeout as e_timeout: + raise e_timeout + + else: + for conn in r: # type: ssl.SSLSocket + if self._sockIsServerSock(conn): + # Connect + self._clientAccept(conn) + else: + # Handshake + if not self._clientHandshakeDone(conn): + self._clientDoHandshake(conn) + # Normal read + else: + buf = self._clientRead(conn, bufsize) + if buf and conn in self._clients: + return buf, self._clients[conn].getAddr() + + for conn in self._getClientReadingSockets(): + if conn.get_timeout(): + conn.handle_timeout() + + raise socket.timeout + + def _recvfrom_on_client_side(self, bufsize, flags): + try: + buf = self._sock.recv(bufsize, flags) + + except ssl.SSLError as e_ssl: + if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN: + return '', self._peer + elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: + raise + else: # like in [ssl.SSL_ERROR_WANT_READ, ...] + pass + + else: + if buf: + return buf, self._peer + + raise socket.timeout + + def sendto(self, buf, address): + if self._server_side: + return self._sendto_from_server_side(buf, address) + else: + return self._sendto_from_client_side(buf, address) + + def _sendto_from_server_side(self, buf, address): + for conn, client in self._clients.iteritems(): + if client.getAddr() == address: + return self._clientWrite(conn, buf) + return 0 + + def _sendto_from_client_side(self, buf, address): + while True: + try: + bytes_sent = self._sock.send(buf) + + except ssl.SSLError as e_ssl: + if str(e_ssl).startswith("503:"): + # The write operation timed out + continue + # elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ]: + # # no ciphers available + # if e_ssl.args[1][0][0] in [336081077, ]: + # raise + raise + + else: + if bytes_sent: + break + + return bytes_sent + + def _getClientReadingSockets(self): + return [x for x in self._clients.keys()] + + def _getAllReadingSockets(self): + return [self._sock] + self._getClientReadingSockets() + + def _sockIsServerSock(self, conn): + return conn is self._sock + + def _clientHandshakeDone(self, conn): + return conn in self._clients and self._clients[conn].handshake_done is True + + def _clientAccept(self, conn): + _logger.debug('+' * 60) + ret = None + + try: + ret = conn.accept() + _logger.debug('Accept returned with ... %s' % (str(ret))) + + except Exception as e_accept: + pass + + else: + if ret: + client, addr = ret + host, port = addr + if client in self._clients: + raise + self._clients[client] = _ClientSession(host=host, port=port) + + self._clientDoHandshake(client) + + def _clientDoHandshake(self, conn): + _logger.debug('-' * 60) + conn.setblocking(False) + + try: + conn.do_handshake() + _logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr()))) + + self._clients[conn].handshake_done = True + + except ssl.SSLError as e_handshake: + if str(e_handshake).startswith("504:"): + pass + elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ: + pass + else: + raise e_handshake + + def _clientRead(self, conn, bufsize=4096): + _logger.debug('*' * 60) + ret = None + + try: + ret = conn.recv(bufsize) + _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret)))) + + except ssl.SSLError as e_read: + if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN: + self._clientDrop(conn) + elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]: + self._clientDrop(conn, error=e_read) + else: # like in [ssl.SSL_ERROR_WANT_READ, ...] + pass + + return ret + + def _clientWrite(self, conn, data): + _logger.debug('#' * 60) + ret = None + + try: + ret = conn.send(data.raw) + _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret))) + + except Exception as e_write: + raise + + return ret + + def _clientDrop(self, conn, error=None): + _logger.debug('$' * 60) + + try: + if error: + _logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error)) + else: + _logger.debug('Drop client %s' % str(self._clients[conn].getAddr())) + + if conn in self._clients: + del self._clients[conn] + conn.unwrap() + conn.close() + + except Exception as e_drop: + pass