diff --git a/dtls/openssl.py b/dtls/openssl.py index 455bed4..519d587 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -82,6 +82,7 @@ SSL_SESS_CACHE_NO_INTERNAL = \ SSL_FILE_TYPE_PEM = 1 GEN_DIRNAME = 4 NID_subject_alt_name = 85 +CRYPTO_LOCK = 1 # # Integer constants - internal @@ -442,6 +443,8 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", "SSL_FILE_TYPE_PEM", "GEN_DIRNAME", "NID_subject_alt_name", + "CRYPTO_LOCK", + "CRYPTO_set_locking_callback", "DTLSv1_get_timeout", "DTLSv1_handle_timeout", "DTLSv1_listen", "BIO_gets", "BIO_read", "BIO_get_mem_data", @@ -463,6 +466,9 @@ map(lambda x: _make_function(*x), ( ("SSL_load_error_strings", libssl, ((None, "ret"),)), ("SSLeay", libcrypto, ((c_long_parm, "ret"),)), ("SSLeay_version", libcrypto, ((c_char_p, "ret"), (c_int, "t"))), + ("CRYPTO_set_locking_callback", libcrypto, + ((None, "ret"), (c_void_p, "func")), False), + ("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)), ("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)), ("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), @@ -589,9 +595,18 @@ map(lambda x: _make_function(*x), ( # # Wrappers - functions generally equivalent to OpenSSL library macros # -_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), - POINTER(c_uint)) -_rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), c_uint) +_rvoid_int_int_charp_int = CFUNCTYPE(None, c_int, c_int, c_char_p, c_int) + +def CRYPTO_set_locking_callback(locking_function): + def py_locking_function(mode, n, file, line): + try: + locking_function(mode, n, file, line) + except: + _logger.exception("Thread locking failed") + + global _locking_cb # for keep-alive + _locking_cb = _rvoid_int_int_charp_int(py_locking_function) + _CRYPTO_set_locking_callback(_locking_cb) def SSL_CTX_set_session_cache_mode(ctx, mode): # Returns the previous value of mode @@ -601,6 +616,10 @@ def SSL_CTX_set_read_ahead(ctx, m): # Returns the previous value of m _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_READ_AHEAD, m, None) +_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), + POINTER(c_uint)) +_rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), c_uint) + def SSL_CTX_set_cookie_cb(ctx, generate, verify): def py_generate_cookie_cb(ssl, cookie, cookie_len): try: diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 8c17de0..3976a62 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -35,6 +35,7 @@ from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from x509 import _X509, decode_cert +from tlock import tlock_init from openssl import * from util import _Rsrc, _BIO @@ -50,6 +51,7 @@ CERT_REQUIRED = 2 # SSL_library_init() SSL_load_error_strings() +tlock_init() DTLS_OPENSSL_VERSION_NUMBER = SSLeay() DTLS_OPENSSL_VERSION = SSLeay_version(SSLEAY_VERSION) DTLS_OPENSSL_VERSION_INFO = ( diff --git a/dtls/tlock.py b/dtls/tlock.py new file mode 100644 index 0000000..4aea6f8 --- /dev/null +++ b/dtls/tlock.py @@ -0,0 +1,36 @@ +# TLock: OpenSSL lock support on thread-enabled systems. Written by Ray Brown. +"""TLock + +This module provides the callbacks required by the OpenSSL library in situations +where it is being entered concurrently by multiple threads. This module is +enagaged automatically by the PyDTLS package on systems that have Python +threading support. It does not have client-visible components. +""" + +from logging import getLogger +from openssl import * + +try: + import threading +except ImportError: + pass + +_logger = getLogger(__name__) +DO_DEBUG_LOG = False + +def tlock_init(): + if not globals().has_key("threading"): + return # nothing to configure + global _locks + num_locks = CRYPTO_num_locks() + _locks = tuple(threading.Lock() for _ in range(num_locks)) + CRYPTO_set_locking_callback(_locking_function) + +def _locking_function(mode, n, file, line): + if DO_DEBUG_LOG: + _logger.debug("Thread lock: mode: %d, n: %d, file: %s, line: %d", + mode, n, file, line) + if mode & CRYPTO_LOCK: + _locks[n].acquire() + else: + _locks[n].release()