From 4464d0bd84caf0fc9101bf28963cd49e426a98ce Mon Sep 17 00:00:00 2001 From: Ray Brown Date: Thu, 8 Nov 2012 12:04:40 -0800 Subject: [PATCH] Certificate formatting and retrieval This change introduces the implementation of the SSLConnection methods getpeercert and cipher. The following has been added: * dtls.util: utility elements shared by other modules in this package * dtls.x509: a module for X509-certificate-related functionality, including formatting a certificate into a Python dictionary as prescribed by the Python standard library's ssl module; functionality for testing with PEM-encoded certificates in the file system is included * yahoo-cert.pem: the current certificate of www.yahoo.com: this is a good testing certificate, since it contains the subject alternate name extension Other notable changes: * sslconnection: private attributes are now preceded by "_" * openssl: null-ness in opaque FuncParam-derived return values is now properly detected and an exception is raised as expected --- .gitignore | 7 +- dtls/openssl.py | 320 ++++++++++++++++++++++++++++++--- dtls/sslconnection.py | 283 ++++++++++++++++------------- dtls/test/certs/yahoo-cert.pem | 29 +++ dtls/test/echo_seq.py | 2 +- dtls/test/rl.py | 5 + dtls/util.py | 38 ++++ dtls/x509.py | 124 +++++++++++++ 8 files changed, 656 insertions(+), 152 deletions(-) create mode 100644 dtls/test/certs/yahoo-cert.pem create mode 100644 dtls/util.py create mode 100644 dtls/x509.py diff --git a/.gitignore b/.gitignore index f24cd99..91fe3a9 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,11 @@ pip-log.txt .coverage .tox -#Translations +# Translations *.mo -#Mr Developer +# Mr Developer .mr.developer.cfg + +# Emacs temp files +*~ diff --git a/dtls/openssl.py b/dtls/openssl.py index 2192d52..222d199 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -3,19 +3,16 @@ """OpenSSL Wrapper This module provides run-time access to the OpenSSL cryptographic and -protocols libraries. +protocols libraries. It is designed for use with "from openssl import *". For +this reason, the module variable __all__ contains all of this module's +integer constants, OpenSSL library functions, and wrapper functions. + +Constants and functions are not documented here. See the OpenSSL library +documentation. Exceptions: OpenSSLError -- exception raised when errors occur in the OpenSSL library - -Functions: - -Integer constants: - - BIO_NOCLOSE -- don't destroy encapsulated resource when closing BIO - BIO_CLOSE -- do destroy encapsulated resource when closing BIO - """ import sys @@ -25,14 +22,15 @@ from logging import getLogger from os import path from err import OpenSSLError from err import SSL_ERROR_NONE +from util import _BIO import ctypes from ctypes import CDLL from ctypes import CFUNCTYPE -from ctypes import c_void_p, c_int, c_uint, c_ulong, c_char_p, c_size_t +from ctypes import c_void_p, c_int, c_long, c_uint, c_ulong, c_char_p, c_size_t from ctypes import c_short, c_ushort, c_ubyte, c_char -from ctypes import byref, POINTER +from ctypes import byref, POINTER, addressof from ctypes import Structure, Union -from ctypes import create_string_buffer, sizeof, memmove +from ctypes import create_string_buffer, sizeof, memmove, cast # # Module initialization @@ -44,10 +42,19 @@ _logger = getLogger(__name__) # if sys.platform.startswith('win'): dll_path = path.abspath(path.dirname(__file__)) - #libcrypto = CDLL(path.join(dll_path, "libeay32.dll")) - #libssl = CDLL(path.join(dll_path, "ssleay32.dll")) - libcrypto = CDLL(path.join(dll_path, "cygcrypto-1.0.0.dll")) - libssl = CDLL(path.join(dll_path, "cygssl-1.0.0.dll")) + debug_cryptodll_path = path.join(dll_path, "cygcrypto-1.0.0.dll") + debug_ssldll_path = path.join(dll_path, "cygssl-1.0.0.dll") + release_cryptodll_path = path.join(dll_path, "libeay32.dll") + release_ssldll_path = path.join(dll_path, "ssleay32.dll") + if path.exists(path.join(dll_path, "use_debug_openssl")) and \ + path.exists(debug_cryptodll_path) and \ + path.exists(debug_ssldll_path): + libcrypto = CDLL(debug_cryptodll_path) + libssl = CDLL(debug_ssldll_path) + else: + # If these don't exist, then let the exception propagate + libcrypto = CDLL(release_cryptodll_path) + libssl = CDLL(release_ssldll_path) else: libcrypto = CDLL("libcrypto.so.1.0.0") libssl = CDLL("libssl.so.1.0.0") @@ -71,22 +78,27 @@ SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200 SSL_SESS_CACHE_NO_INTERNAL = \ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE SSL_FILE_TYPE_PEM = 1 +GEN_DIRNAME = 4 +NID_subject_alt_name = 85 # # Integer constants - internal # SSL_CTRL_SET_SESS_CACHE_MODE = 44 SSL_CTRL_SET_READ_AHEAD = 41 +BIO_CTRL_INFO = 3 BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_GET_PEER = 46 BIO_CTRL_DGRAM_SET_PEER = 44 BIO_C_SET_NBIO = 102 DTLS_CTRL_LISTEN = 75 +X509_NAME_MAXLEN = 256 +GETS_MAXLEN = 2048 # # Parameter data types # -class c_long(object): +class c_long_parm(object): """Long integer paramter class c_long must be distinguishable from c_int, as the latter is associated @@ -107,12 +119,20 @@ class FuncParam(object): def __init__(self, value): self._as_parameter = value + def __nonzero__(self): + return bool(self._as_parameter) + class DTLSv1Method(FuncParam): def __init__(self, value): super(DTLSv1Method, self).__init__(value) +class BIO_METHOD(FuncParam): + def __init__(self, value): + super(BIO_METHOD, self).__init__(value) + + class SSLCTX(FuncParam): def __init__(self, value): super(SSLCTX, self).__init__(value) @@ -128,6 +148,106 @@ class BIO(FuncParam): super(BIO, self).__init__(value) +class X509(FuncParam): + def __init__(self, value): + super(X509, self).__init__(value) + + +class X509_val_st(Structure): + _fields_ = [("notBefore", c_void_p), + ("notAfter", c_void_p)] + + +class X509_cinf_st(Structure): + _fields_ = [("version", c_void_p), + ("serialNumber", c_void_p), + ("signature", c_void_p), + ("issuer", c_void_p), + ("validity", POINTER(X509_val_st))] # remaining fields omitted + + +class X509_st(Structure): + _fields_ = [("cert_info", POINTER(X509_cinf_st),)] # remainder omitted + + +class X509_name_st(Structure): + _fields_ = [("entries", c_void_p)] # remaining fields omitted + + +class ASN1_OBJECT(FuncParam): + def __init__(self, value): + super(ASN1_OBJECT, self).__init__(value) + + +class ASN1_STRING(FuncParam): + def __init__(self, value): + super(ASN1_STRING, self).__init__(value) + + +class ASN1_TIME(FuncParam): + def __init__(self, value): + super(ASN1_TIME, self).__init__(value) + + +class SSL_CIPHER(FuncParam): + def __init__(self, value): + super(SSL_CIPHER, self).__init__(value) + + +class GENERAL_NAME_union_d(Union): + _fields_ = [("ptr", c_char_p), + # entries omitted + ("directoryName", POINTER(X509_name_st))] + # remaining fields omitted + + +class STACK(FuncParam): + def __init__(self, value): + super(STACK, self).__init__(value) + + +class GENERAL_NAME(Structure): + _fields_ = [("type", c_int), + ("d", GENERAL_NAME_union_d)] + + +class GENERAL_NAMES(STACK): + stack_element_type = GENERAL_NAME + + def __init__(self, value): + super(GENERAL_NAMES, self).__init__(value) + + +class X509_NAME_ENTRY(Structure): + _fields_ = [("object", c_void_p), + ("value", c_void_p), + ("set", c_int), + ("size", c_int)] + + +class ASN1_OCTET_STRING(Structure): + _fields_ = [("length", c_int), + ("type", c_int), + ("data", POINTER(c_ubyte)), + ("flags", c_long)] + + +class X509_EXTENSION(Structure): + _fields_ = [("object", c_void_p), + ("critical", c_int), + ("value", POINTER(ASN1_OCTET_STRING))] + + +class X509V3_EXT_METHOD(Structure): + _fields_ = [("ext_nid", c_int), + ("ext_flags", c_int), + ("it", c_void_p), + ("ext_new", c_int), + ("ext_free", c_int), + ("d2i", c_int), + ("i2d", c_int)] # remaining fields omitted + + # # Socket address conversions # @@ -293,7 +413,7 @@ def _make_function(name, lib, args, export=True, errcheck="default"): if args[0][0] in (c_int,): errcheck = errcheck_ord elif args[0][0] in (c_void_p, c_char_p) or \ - isinstance(args[0][0], FuncParam): + isinstance(args[0][0], type) and issubclass(args[0][0], FuncParam): errcheck = errcheck_p else: errcheck = None @@ -301,7 +421,7 @@ def _make_function(name, lib, args, export=True, errcheck="default"): func.errcheck = errcheck globals()[glbl_name] = func -_subst = {c_long: ctypes.c_long} +_subst = {c_long_parm: c_long} _sigs = {} __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER", @@ -311,13 +431,21 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSL_SESS_CACHE_NO_AUTO_CLEAR", "SSL_SESS_CACHE_NO_INTERNAL_LOOKUP", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", "SSL_FILE_TYPE_PEM", + "GEN_DIRNAME", "NID_subject_alt_name", "DTLSv1_listen", + "BIO_gets", "BIO_read", "BIO_get_mem_data", "BIO_dgram_set_connected", "BIO_dgram_get_peer", "BIO_dgram_set_peer", "BIO_set_nbio", "SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead", "SSL_read", "SSL_write", - "SSL_CTX_set_cookie_cb"] + "SSL_CTX_set_cookie_cb", + "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", + "X509_get_notAfter", + "ASN1_item_d2i", "GENERAL_NAME_print", + "sk_value", + "sk_pop_free", + "i2d_X509"] # note: the following map adds to this list map(lambda x: _make_function(*x), ( ("SSL_library_init", libssl, ((c_int, "ret"),)), @@ -335,19 +463,28 @@ map(lambda x: _make_function(*x), ( ("SSL_free", libssl, ((None, "ret"), (SSL, "ssl"))), ("SSL_set_bio", libssl, ((None, "ret"), (SSL, "ssl"), (BIO, "rbio"), (BIO, "wbio"))), + ("BIO_new", libcrypto, ((BIO, "ret"), (BIO_METHOD, "type"))), + ("BIO_s_mem", libcrypto, ((BIO_METHOD, "ret"),)), + ("BIO_new_file", libcrypto, + ((BIO, "ret"), (c_char_p, "filename"), (c_char_p, "mode"))), ("BIO_new_dgram", libcrypto, ((BIO, "ret"), (c_int, "fd"), (c_int, "close_flag"))), ("BIO_free", libcrypto, ((c_int, "ret"), (BIO, "a"))), + ("BIO_gets", libcrypto, + ((c_int, "ret"), (BIO, "b"), (POINTER(c_char), "buf"), (c_int, "size")), + False), + ("BIO_read", libcrypto, + ((c_int, "ret"), (BIO, "b"), (c_void_p, "buf"), (c_int, "len")), False), ("SSL_CTX_ctrl", libssl, - ((c_long, "ret"), (SSLCTX, "ctx"), (c_int, "cmd"), (c_long, "larg"), + ((c_long_parm, "ret"), (SSLCTX, "ctx"), (c_int, "cmd"), (c_long, "larg"), (c_void_p, "parg")), False), ("BIO_ctrl", libcrypto, - ((c_long, "ret"), (BIO, "bp"), (c_int, "cmd"), (c_long, "larg"), + ((c_long_parm, "ret"), (BIO, "bp"), (c_int, "cmd"), (c_long, "larg"), (c_void_p, "parg")), False), ("SSL_ctrl", libssl, - ((c_long, "ret"), (SSL, "ssl"), (c_int, "cmd"), (c_long, "larg"), + ((c_long_parm, "ret"), (SSL, "ssl"), (c_int, "cmd"), (c_long, "larg"), (c_void_p, "parg")), False), - ("ERR_get_error", libcrypto, ((c_long, "ret"),), False), + ("ERR_get_error", libcrypto, ((c_long_parm, "ret"),), False), ("ERR_error_string_n", libcrypto, ((None, "ret"), (c_ulong, "e"), (c_char_p, "buf"), (c_size_t, "len")), False), @@ -372,6 +509,7 @@ map(lambda x: _make_function(*x), ( ("SSL_set_connect_state", libssl, ((None, "ret"), (SSL, "ssl"))), ("SSL_set_accept_state", libssl, ((None, "ret"), (SSL, "ssl"))), ("SSL_do_handshake", libssl, ((c_int, "ret"), (SSL, "ssl"))), + ("SSL_get_peer_certificate", libssl, ((X509, "ret"), (SSL, "ssl"))), ("SSL_read", libssl, ((c_int, "ret"), (SSL, "ssl"), (c_void_p, "buf"), (c_int, "num")), False), ("SSL_write", libssl, @@ -379,6 +517,59 @@ map(lambda x: _make_function(*x), ( ("SSL_shutdown", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_set_read_ahead", libssl, ((None, "ret"), (SSL, "ssl"), (c_int, "yes"))), + ("X509_free", libcrypto, ((None, "ret"), (X509, "a"))), + ("PEM_read_bio_X509_AUX", libcrypto, + ((X509, "ret"), (BIO, "bp"), (c_void_p, "x", 1, None), + (c_void_p, "cb", 1, None), (c_void_p, "u", 1, None))), + ("OBJ_obj2txt", libcrypto, + ((c_int, "ret"), (POINTER(c_char), "buf"), (c_int, "buf_len"), + (ASN1_OBJECT, "a"), (c_int, "no_name")), False), + ("CRYPTO_free", libcrypto, ((None, "ret"), (c_void_p, "ptr"))), + ("ASN1_STRING_to_UTF8", libcrypto, + ((c_int, "ret"), (POINTER(POINTER(c_ubyte)), "out"), (ASN1_STRING, "in")), + False), + ("X509_NAME_entry_count", libcrypto, + ((c_int, "ret"), (POINTER(X509_name_st), "name")), True, None), + ("X509_NAME_get_entry", libcrypto, + ((POINTER(X509_NAME_ENTRY), "ret"), (POINTER(X509_name_st), "name"), + (c_int, "loc")), True, errcheck_p), + ("X509_NAME_ENTRY_get_object", libcrypto, + ((ASN1_OBJECT, "ret"), (POINTER(X509_NAME_ENTRY), "ne"))), + ("X509_NAME_ENTRY_get_data", libcrypto, + ((ASN1_STRING, "ret"), (POINTER(X509_NAME_ENTRY), "ne"))), + ("X509_get_subject_name", libcrypto, + ((POINTER(X509_name_st), "ret"), (X509, "a")), True, errcheck_p), + ("ASN1_TIME_print", libcrypto, + ((c_int, "ret"), (BIO, "fp"), (ASN1_TIME, "a")), False), + ("X509_get_ext_by_NID", libcrypto, + ((c_int, "ret"), (X509, "x"), (c_int, "nid"), (c_int, "lastpos")), + True, None), + ("X509_get_ext", libcrypto, + ((POINTER(X509_EXTENSION), "ret"), (X509, "x"), (c_int, "loc")), + True, errcheck_p), + ("X509V3_EXT_get", libcrypto, + ((POINTER(X509V3_EXT_METHOD), "ret"), (POINTER(X509_EXTENSION), "ext")), + True, errcheck_p), + ("ASN1_item_d2i", libcrypto, + ((c_void_p, "ret"), (c_void_p, "val"), (POINTER(POINTER(c_ubyte)), "in"), + (c_long, "len"), (c_void_p, "it")), False, None), + ("sk_num", libcrypto, ((c_int, "ret"), (STACK, "stack")), True, None), + ("sk_value", libcrypto, + ((c_void_p, "ret"), (STACK, "stack"), (c_int, "loc")), False), + ("GENERAL_NAME_print", libcrypto, + ((c_int, "ret"), (BIO, "out"), (POINTER(GENERAL_NAME), "gen")), False), + ("sk_pop_free", libcrypto, + ((None, "ret"), (STACK, "st"), (c_void_p, "func")), False), + ("i2d_X509_bio", libcrypto, ((c_int, "ret"), (BIO, "bp"), (X509, "x")), + False), + ("SSL_get_current_cipher", libssl, ((SSL_CIPHER, "ret"), (SSL, "ssl"))), + ("SSL_CIPHER_get_name", libssl, + ((c_char_p, "ret"), (SSL_CIPHER, "cipher"))), + ("SSL_CIPHER_get_version", libssl, + ((c_char_p, "ret"), (SSL_CIPHER, "cipher"))), + ("SSL_CIPHER_get_bits", libssl, + ((c_int, "ret"), (SSL_CIPHER, "cipher"), + (POINTER(c_int), "alg_bits", 1, None)), True, None), )) # @@ -447,9 +638,88 @@ def DTLSv1_listen(ssl): def SSL_read(ssl, length): buf = create_string_buffer(length) - res_len = _SSL_read(ssl, buf, length) + res_len = _SSL_read(ssl, buf, sizeof(buf)) return buf.raw[:res_len] def SSL_write(ssl, data): str_data = str(data) return _SSL_write(ssl, str_data, len(str_data)) + +def OBJ_obj2txt(asn1_object, no_name): + buf = create_string_buffer(X509_NAME_MAXLEN) + res_len = _OBJ_obj2txt(buf, sizeof(buf), asn1_object, 1 if no_name else 0) + return buf.raw[:res_len] + +def decode_ASN1_STRING(asn1_string): + utf8_buf_ptr = POINTER(c_ubyte)() + res_len = _ASN1_STRING_to_UTF8(byref(utf8_buf_ptr), asn1_string) + try: + return unicode(''.join([chr(i) for i in utf8_buf_ptr[:res_len]]), + 'utf-8') + finally: + CRYPTO_free(utf8_buf_ptr) + +def X509_get_notAfter(x509): + x509_raw = X509.from_param(x509) + x509_ptr = cast(x509_raw, POINTER(X509_st)) + notAfter = x509_ptr.contents.cert_info.contents.validity.contents.notAfter + return ASN1_TIME(notAfter) + +def BIO_gets(bio): + buf = create_string_buffer(GETS_MAXLEN) + res_len = _BIO_gets(bio, buf, sizeof(buf) - 1) + return buf.raw[:res_len] + +def BIO_read(bio, length): + buf = create_string_buffer(length) + res_len = _BIO_read(bio, buf, sizeof(buf)) + return buf.raw[:res_len] + +def BIO_get_mem_data(bio): + buf = POINTER(c_ubyte)() + res_len = _BIO_ctrl(bio, BIO_CTRL_INFO, 0, byref(buf)) + return ''.join([chr(i) for i in buf[:res_len]]) + +def ASN1_TIME_print(asn1_time): + bio = _BIO(BIO_new(BIO_s_mem())) + _ASN1_TIME_print(bio.value, asn1_time) + return BIO_gets(bio.value) + +_rvoidp = CFUNCTYPE(c_void_p) + +def _ASN1_ITEM_ptr(item): + if sys.platform.startswith('win'): + func_ptr = _rvoidp(item) + return func_ptr() + return item + +_rvoidp_voidp_ubytepp_long = CFUNCTYPE(c_void_p, c_void_p, + POINTER(POINTER(c_ubyte)), c_long) + +def ASN1_item_d2i(method, asn1_octet_string): + data_in = POINTER(c_ubyte)(asn1_octet_string.data.contents) + if method.it: + return GENERAL_NAMES(_ASN1_item_d2i(None, byref(data_in), + asn1_octet_string.length, + _ASN1_ITEM_ptr(method.it))) + func_ptr = _rvoidp_voidp_ubytepp_long(method.d2i) + return GENERAL_NAMES(func_ptr(None, byref(data_in), + asn1_octet_string.length)) + +def sk_value(stack, loc): + return cast(_sk_value(stack, loc), POINTER(stack.stack_element_type)) + +def GENERAL_NAME_print(general_name): + bio = _BIO(BIO_new(BIO_s_mem())) + _GENERAL_NAME_print(bio.value, general_name) + return BIO_gets(bio.value) + +_free_func = addressof(c_void_p.in_dll(libcrypto, "sk_free")) + +def sk_pop_free(stack): + _sk_pop_free(stack, _free_func) + +def i2d_X509(x509): + bio = _BIO(BIO_new(BIO_s_mem())) + _i2d_X509_bio(bio.value, x509) + return BIO_get_mem_data(bio.value) diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index c1b909f..a0947da 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -32,7 +32,9 @@ from weakref import proxy from err import OpenSSLError, InvalidSocketError from err import raise_ssl_error from err import SSL_ERROR_WANT_READ, ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from x509 import _X509, decode_cert from openssl import * +from util import _Rsrc, _BIO _logger = getLogger(__name__) @@ -48,16 +50,6 @@ SSL_library_init() SSL_load_error_strings() -class _Rsrc(object): - """Wrapper base for library-owned resources""" - def __init__(self, value): - self._value = value - - @property - def value(self): - return self._value - - class _CTX(_Rsrc): """SSL_CTX wrapper""" def __init__(self, value): @@ -69,23 +61,6 @@ class _CTX(_Rsrc): self._value = None -class _BIO(_Rsrc): - """BIO wrapper""" - def __init__(self, value): - super(_BIO, self).__init__(value) - self.owned = True - - def disown(self): - self.owned = False - - def __del__(self): - if self.owned: - _logger.debug("Freeing BIO: %d", self._value._as_parameter) - BIO_free(self._value) - self.owned = False - self._value = None - - class _SSL(_Rsrc): """SSL structure wrapper""" def __init__(self, value): @@ -124,97 +99,98 @@ class SSLConnection(object): _rnd_key = urandom(16) def _init_server(self): - if self.sock.type != socket.SOCK_DGRAM: + if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") from demux import UDPDemux - self.udp_demux = UDPDemux(self.sock) - self.rsock = self.udp_demux.get_connection(None) - self.wbio = _BIO(BIO_new_dgram(self.sock.fileno(), BIO_NOCLOSE)) - self.rbio = _BIO(BIO_new_dgram(self.rsock.fileno(), BIO_NOCLOSE)) - self.ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) - SSL_CTX_set_session_cache_mode(self.ctx.value, SSL_SESS_CACHE_OFF) - if self.cert_reqs == CERT_NONE: + self._udp_demux = UDPDemux(self._sock) + self._rsock = self._udp_demux.get_connection(None) + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) + SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) + if self._cert_reqs == CERT_NONE: verify_mode = SSL_VERIFY_NONE - elif self.cert_reqs == CERT_OPTIONAL: + elif self._cert_reqs == CERT_OPTIONAL: verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE else: verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE | \ SSL_VERIFY_FAIL_IF_NO_PEER_CERT - self.listening = False - self.listening_peer_address = None - self.pending_peer_address = None + self._listening = False + self._listening_peer_address = None + self._pending_peer_address = None self._config_ssl_ctx(verify_mode) - self.cb_keepalive = SSL_CTX_set_cookie_cb( - self.ctx.value, + self._cb_keepalive = SSL_CTX_set_cookie_cb( + self._ctx.value, _CallbackProxy(self._generate_cookie_cb), _CallbackProxy(self._verify_cookie_cb)) - self.ssl = _SSL(SSL_new(self.ctx.value)) - SSL_set_accept_state(self.ssl.value) + self._ssl = _SSL(SSL_new(self._ctx.value)) + SSL_set_accept_state(self._ssl.value) def _init_client(self): - if self.sock.type != socket.SOCK_DGRAM: + if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") - self.wbio = _BIO(BIO_new_dgram(self.sock.fileno(), BIO_NOCLOSE)) - self.rbio = self.wbio - self.ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) - if self.cert_reqs == CERT_NONE: + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = self._wbio + self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method())) + if self._cert_reqs == CERT_NONE: verify_mode = SSL_VERIFY_NONE else: verify_mode = SSL_VERIFY_PEER self._config_ssl_ctx(verify_mode) - self.ssl = _SSL(SSL_new(self.ctx.value)) - SSL_set_connect_state(self.ssl.value) + self._ssl = _SSL(SSL_new(self._ctx.value)) + SSL_set_connect_state(self._ssl.value) def _config_ssl_ctx(self, verify_mode): - SSL_CTX_set_verify(self.ctx.value, verify_mode) - SSL_CTX_set_read_ahead(self.ctx.value, 1) - if self.certfile: - SSL_CTX_use_certificate_chain_file(self.ctx.value, self.certfile) - if self.keyfile: - SSL_CTX_use_PrivateKey_file(self.ctx.value, self.keyfile, + SSL_CTX_set_verify(self._ctx.value, verify_mode) + SSL_CTX_set_read_ahead(self._ctx.value, 1) + if self._certfile: + SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile) + if self._keyfile: + SSL_CTX_use_PrivateKey_file(self._ctx.value, self._keyfile, SSL_FILE_TYPE_PEM) - if self.ca_certs: - SSL_CTX_load_verify_locations(self.ctx.value, self.ca_certs, None) - if self.ciphers: - SSL_CTX_set_cipher_list(self.ctx.value, self.ciphers) + if self._ca_certs: + SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None) + if self._ciphers: + SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) def _copy_server(self): - source = self.sock - self.sock = source.sock - self.udp_demux = source.udp_demux - self.rsock = self.udp_demux.get_connection(source.pending_peer_address) - self.wbio = _BIO(BIO_new_dgram(self.sock.fileno(), BIO_NOCLOSE)) - self.rbio = _BIO(BIO_new_dgram(self.rsock.fileno(), BIO_NOCLOSE)) - BIO_dgram_set_peer(self.wbio.value, source.pending_peer_address) - self.ctx = source.ctx - self.ssl = source.ssl - new_source_wbio = _BIO(BIO_new_dgram(source.sock.fileno(), + source = self._sock + self._sock = source._sock + self._udp_demux = source._udp_demux + self._rsock = self._udp_demux.get_connection( + source._pending_peer_address) + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + BIO_dgram_set_peer(self._wbio.value, source._pending_peer_address) + self._ctx = source._ctx + self._ssl = source._ssl + new_source_wbio = _BIO(BIO_new_dgram(source._sock.fileno(), BIO_NOCLOSE)) - new_source_rbio = _BIO(BIO_new_dgram(source.rsock.fileno(), + new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), BIO_NOCLOSE)) - source.ssl = _SSL(SSL_new(self.ctx.value)) - source.rbio = new_source_rbio - source.wbio = new_source_wbio - SSL_set_bio(source.ssl.value, + source._ssl = _SSL(SSL_new(self._ctx.value)) + source._rbio = new_source_rbio + source._wbio = new_source_wbio + SSL_set_bio(source._ssl.value, new_source_rbio.value, new_source_wbio.value) new_source_rbio.disown() new_source_wbio.disown() def _check_nbio(self): - BIO_set_nbio(self.wbio.value, self.sock.gettimeout() is not None) - if self.wbio is not self.rbio: - BIO_set_nbio(self.rbio.value, self.rsock.gettimeout() is not None) + BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None) + if self._wbio is not self._rbio: + BIO_set_nbio(self._rbio.value, self._rsock.gettimeout() is not None) def _get_cookie(self, ssl): - assert self.listening - assert self.ssl.value._as_parameter == ssl._as_parameter - if self.listening_peer_address: - peer_address = self.listening_peer_address + assert self._listening + assert self._ssl.value._as_parameter == ssl._as_parameter + if self._listening_peer_address: + peer_address = self._listening_peer_address else: - peer_address = BIO_dgram_get_peer(self.rbio.value) + peer_address = BIO_dgram_get_peer(self._rbio.value) cookie_hmac = hmac.new(self._rnd_key, str(peer_address)) return cookie_hmac.digest() @@ -247,14 +223,15 @@ class SSLConnection(object): if not ciphers: ciphers = "DEFAULT" - self.sock = sock - self.keyfile = keyfile - self.certfile = certfile - self.cert_reqs = cert_reqs - self.ca_certs = ca_certs - self.do_handshake_on_connect = do_handshake_on_connect - self.suppress_ragged_eofs = suppress_ragged_eofs - self.ciphers = ciphers + self._sock = sock + self._keyfile = keyfile + self._certfile = certfile + self._cert_reqs = cert_reqs + self._ca_certs = ca_certs + self._do_handshake_on_connect = do_handshake_on_connect + self._suppress_ragged_eofs = suppress_ragged_eofs + self._ciphers = ciphers + self._handshake_done = False if isinstance(sock, SSLConnection): self._copy_server() @@ -263,9 +240,24 @@ class SSLConnection(object): else: self._init_client() - SSL_set_bio(self.ssl.value, self.rbio.value, self.wbio.value) - self.rbio.disown() - self.wbio.disown() + SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) + self._rbio.disown() + self._wbio.disown() + + def get_socket(self, inbound): + """Retrieve a socket used by this connection + + When inbound is True, then the socket from which this connection reads + data is retrieved. Otherwise the socket to which this connection writes + data is retrieved. + + Read and write sockets differ depending on whether this is a server- or + a client-side connection, and on whether a routing demux is in use. + """ + + if inbound and hasattr(self, "_rsock"): + return self._rsock + return self._sock def listen(self): """Server-side cookie exchange @@ -285,9 +277,9 @@ class SSLConnection(object): encountered, None if a datagram for a known peer was forwarded """ - self.pending_peer_address = None + self._pending_peer_address = None try: - peer_address = self.udp_demux.service() + peer_address = self._udp_demux.service() except socket.timeout: peer_address = None except socket.error as sock_err: @@ -303,16 +295,16 @@ class SSLConnection(object): # The demux advises that a datagram from a new peer may have arrived if type(peer_address) is tuple: # For this type of demux, the write BIO must be pointed at the peer - BIO_dgram_set_peer(self.wbio.value, peer_address) - self.udp_demux.forward() - self.listening_peer_address = peer_address + BIO_dgram_set_peer(self._wbio.value, peer_address) + self._udp_demux.forward() + self._listening_peer_address = peer_address self._check_nbio() - self.listening = True + self._listening = True try: _logger.debug("Invoking DTLSv1_listen for ssl: %d", - self.ssl.value._as_parameter) - dtls_peer_address = DTLSv1_listen(self.ssl.value) + self._ssl.value._as_parameter) + dtls_peer_address = DTLSv1_listen(self._ssl.value) except OpenSSLError as err: if err.ssl_error == SSL_ERROR_WANT_READ: # This method must be called again to forward the next datagram @@ -324,15 +316,15 @@ class SSLConnection(object): _logger.exception("Unexpected error in DTLSv1_listen") raise finally: - self.listening = False - self.listening_peer_address = None + self._listening = False + self._listening_peer_address = None if type(peer_address) is tuple: _logger.debug("New local peer: %s", dtls_peer_address) - self.pending_peer_address = peer_address + self._pending_peer_address = peer_address else: - self.pending_peer_address = dtls_peer_address - _logger.debug("New peer: %s", self.pending_peer_address) - return self.pending_peer_address + self._pending_peer_address = dtls_peer_address + _logger.debug("New peer: %s", self._pending_peer_address) + return self._pending_peer_address def accept(self): """Server-side UDP connection establishment @@ -345,16 +337,16 @@ class SSLConnection(object): forwarding only to an existing peer occurred. """ - if not self.pending_peer_address: + if not self._pending_peer_address: if not self.listen(): _logger.debug("Accept returning without connection") return - new_conn = SSLConnection(self, self.keyfile, self.certfile, True, - self.cert_reqs, PROTOCOL_DTLSv1, - self.ca_certs, self.do_handshake_on_connect, - self.suppress_ragged_eofs, self.ciphers) - self.pending_peer_address = None - if self.do_handshake_on_connect: + new_conn = SSLConnection(self, self._keyfile, self._certfile, True, + self._cert_reqs, PROTOCOL_DTLSv1, + self._ca_certs, self._do_handshake_on_connect, + self._suppress_ragged_eofs, self._ciphers) + self._pending_peer_address = None + if self._do_handshake_on_connect: # Note that since that connection's socket was just created in its # constructor, the following operation must be blocking; hence # handshake-on-connect can only be used with a routing demux if @@ -375,10 +367,10 @@ class SSLConnection(object): peer_address - address tuple of server peer """ - self.sock.connect(peer_address) - BIO_dgram_set_connected(self.wbio.value, peer_address) - assert self.wbio is self.rbio - if self.do_handshake_on_connect: + self._sock.connect(peer_address) + BIO_dgram_set_connected(self._wbio.value, peer_address) + assert self._wbio is self._rbio + if self._do_handshake_on_connect: self.do_handshake() def do_handshake(self): @@ -390,7 +382,8 @@ class SSLConnection(object): _logger.debug("Initiating handshake...") self._check_nbio() - SSL_do_handshake(self.ssl.value) + SSL_do_handshake(self._ssl.value) + self._handshake_done = True _logger.debug("...completed handshake") def read(self, len=1024): @@ -405,7 +398,7 @@ class SSLConnection(object): """ self._check_nbio() - return SSL_read(self.ssl.value, len) + return SSL_read(self._ssl.value, len) def write(self, data): """Write data to connection @@ -420,7 +413,7 @@ class SSLConnection(object): """ self._check_nbio() - return SSL_write(self.ssl.value, data) + return SSL_write(self._ssl.value, data) def shutdown(self): """Shut down the DTLS connection @@ -432,7 +425,7 @@ class SSLConnection(object): self._check_nbio() try: - SSL_shutdown(self.ssl.value) + SSL_shutdown(self._ssl.value) except OpenSSLError as err: if err.result == 0: # close-notify alert was just sent; wait for same from peer @@ -440,6 +433,48 @@ class SSLConnection(object): # with SSL_set_read_ahead here, doing so causes a shutdown # failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown # initiator side. - SSL_shutdown(self.ssl.value) + SSL_shutdown(self._ssl.value) else: raise + + def getpeercert(self, binary_form=False): + """Retrieve the peer's certificate + + When binary form is requested, the peer's DER-encoded certficate is + returned if it was transmitted during the handshake. + + When binary form is not requested, and the peer's certificate has been + validated, then a certificate dictionary is returned. If the certificate + was not validated, an empty dictionary is returned. + + In all cases, None is returned if no certificate was received from the + peer. + """ + + try: + peer_cert = _X509(SSL_get_peer_certificate(self._ssl.value)) + except OpenSSLError: + return + + if binary_form: + return i2d_X509(peer_cert.value) + if self._cert_reqs == CERT_NONE: + return {} + return decode_cert(peer_cert) + + def cipher(self): + """Retrieve information about the current cipher + + Return a triple consisting of cipher name, SSL protocol version defining + its use, and the number of secret bits. Return None if handshaking + has not been completed. + """ + + if not self._handshake_done: + return + + current_cipher = SSL_get_current_cipher(self._ssl.value) + cipher_name = SSL_CIPHER_get_name(current_cipher) + cipher_version = SSL_CIPHER_get_version(current_cipher) + cipher_bits = SSL_CIPHER_get_bits(current_cipher) + return cipher_name, cipher_version, cipher_bits diff --git a/dtls/test/certs/yahoo-cert.pem b/dtls/test/certs/yahoo-cert.pem new file mode 100644 index 0000000..d2cd76d --- /dev/null +++ b/dtls/test/certs/yahoo-cert.pem @@ -0,0 +1,29 @@ +-----BEGIN CERTIFICATE----- +MIIE6jCCBFOgAwIBAgIDEIGKMA0GCSqGSIb3DQEBBQUAME4xCzAJBgNVBAYTAlVT +MRAwDgYDVQQKEwdFcXVpZmF4MS0wKwYDVQQLEyRFcXVpZmF4IFNlY3VyZSBDZXJ0 +aWZpY2F0ZSBBdXRob3JpdHkwHhcNMTAwNDAxMjMwMDE0WhcNMTUwNzAzMDQ1MDAw +WjCBjzEpMCcGA1UEBRMgMmc4YU81d0kxYktKMlpENTg4VXNMdkRlM2dUYmc4RFUx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRIwEAYDVQQHEwlTdW5u +eXZhbGUxFDASBgNVBAoTC1lhaG9vICBJbmMuMRYwFAYDVQQDEw13d3cueWFob28u +Y29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6ZM1jHCkL8rlEKse +1riTTxyC3WvYQ5m34TlFK7dK4QFI/HPttKGqQm3aVB1Fqi0aiTxe4YQMbd++jnKt +djxcpi7sJlFxjMZs4umr1eGo2KgTgSBAJyhxo23k+VpK1SprdPyM3yEfQVdV7JWC +4Y71CE2nE6+GbsIuhk/to+jJMO7jXx/430jvo8vhNPL6GvWe/D6ObbnxS72ynLSd +mLtaltykOvZEZiXbbFKgIaYYmCgh89FGVvBkUbGM/Wb5Voiz7ttQLLxKOYRj8Mdk +TZtzPkM9scIFG1naECPvCxw0NyMyxY3nFOdjUKJ79twanmfCclX2ZO/rk1CpiOuw +lrrr/QIDAQABo4ICDjCCAgowDgYDVR0PAQH/BAQDAgTwMB0GA1UdDgQWBBSmrfKs +68m+dDUSf+S7xJrQ/FXAlzA6BgNVHR8EMzAxMC+gLaArhilodHRwOi8vY3JsLmdl +b3RydXN0LmNvbS9jcmxzL3NlY3VyZWNhLmNybDCCAVsGA1UdEQSCAVIwggFOgg13 +d3cueWFob28uY29tggl5YWhvby5jb22CDHVzLnlhaG9vLmNvbYIMa3IueWFob28u +Y29tggx1ay55YWhvby5jb22CDGllLnlhaG9vLmNvbYIMZnIueWFob28uY29tggxp +bi55YWhvby5jb22CDGNhLnlhaG9vLmNvbYIMYnIueWFob28uY29tggxkZS55YWhv +by5jb22CDGVzLnlhaG9vLmNvbYIMbXgueWFob28uY29tggxpdC55YWhvby5jb22C +DHNnLnlhaG9vLmNvbYIMaWQueWFob28uY29tggxwaC55YWhvby5jb22CDHFjLnlh +aG9vLmNvbYIMdHcueWFob28uY29tggxoay55YWhvby5jb22CDGNuLnlhaG9vLmNv +bYIMYXUueWFob28uY29tggxhci55YWhvby5jb22CDHZuLnlhaG9vLmNvbTAfBgNV +HSMEGDAWgBRI5mj5K9KylddH2CMgEE8zmJCf1DAdBgNVHSUEFjAUBggrBgEFBQcD +AQYIKwYBBQUHAwIwDQYJKoZIhvcNAQEFBQADgYEAp9WOMtcDMM5T0yfPecGv5QhH +RJZRzgeMPZitLksr1JxxicJrdgv82NWq1bw8aMuRj47ijrtaTEWXaCQCy00yXodD +zoRJVNoYIvY1arYZf5zv9VZjN5I0HqUc39mNMe9XdZtbkWE+K6yVh6OimKLbizna +inu9YTrN/4P/w6KzHho= +-----END CERTIFICATE----- diff --git a/dtls/test/echo_seq.py b/dtls/test/echo_seq.py index c93b78f..99bb9cb 100644 --- a/dtls/test/echo_seq.py +++ b/dtls/test/echo_seq.py @@ -47,7 +47,7 @@ def main(): print "Accepting..." conn = scn.accept() sck.settimeout(5) - conn.rsock.settimeout(5) + conn.get_socket(True).settimeout(5) cnt = 0 while True: diff --git a/dtls/test/rl.py b/dtls/test/rl.py index 05e5ff6..7000694 100644 --- a/dtls/test/rl.py +++ b/dtls/test/rl.py @@ -8,7 +8,9 @@ the IPython shell. import dtls import dtls.err +import dtls.util import dtls.sslconnection +import dtls.x509 import dtls.openssl import dtls.demux import dtls.demux.router @@ -16,11 +18,14 @@ import dtls.demux.router def main(): reload(dtls) reload(dtls.err) + reload(dtls.util) reload(dtls.sslconnection) + reload(dtls.x509) reload(dtls.openssl) reload(dtls.demux) reload(dtls.demux.router) reload(dtls.sslconnection) + reload(dtls.x509) if __name__ == "__main__": main() diff --git a/dtls/util.py b/dtls/util.py new file mode 100644 index 0000000..4dc10c3 --- /dev/null +++ b/dtls/util.py @@ -0,0 +1,38 @@ +# Shared implementation internals. Written by Ray Brown. +"""Utilities + +This module contains private implementation details shared among modules of +the PyDTLS package. +""" + +from logging import getLogger + +_logger = getLogger(__name__) + + +class _Rsrc(object): + """Wrapper base for library-owned resources""" + def __init__(self, value): + self._value = value + + @property + def value(self): + return self._value + + +class _BIO(_Rsrc): + """BIO wrapper""" + def __init__(self, value): + super(_BIO, self).__init__(value) + self.owned = True + + def disown(self): + self.owned = False + + def __del__(self): + if self.owned: + _logger.debug("Freeing BIO: %d", self._value._as_parameter) + from openssl import BIO_free + BIO_free(self._value) + self.owned = False + self._value = None diff --git a/dtls/x509.py b/dtls/x509.py new file mode 100644 index 0000000..e43310c --- /dev/null +++ b/dtls/x509.py @@ -0,0 +1,124 @@ +# X509: certificate support. Written by Ray Brown. +"""X509 Certificate + +This module provides support for X509 certificates through the OpenSSL library. +This support includes mapping certificate data to Python dictionaries in the +manner established by the Python standard library's ssl module. This module is +required because the standard library's ssl module does not provide its support +for certificates from arbitrary sources, but instead only for certificates +retrieved from servers during handshaking or get_server_certificate by its +CPython _ssl implementation module. This author is aware of the latter module's +_test_decode_certificate function, but has decided not to use this function +because it is undocumented, and because its use would tie PyDTLS to the CPython +interpreter. +""" + +from logging import getLogger +from openssl import * +from util import _Rsrc, _BIO + +_logger = getLogger(__name__) + + +class _X509(_Rsrc): + """Wrapper for the cryptographic library's X509 resource""" + def __init__(self, value): + super(_X509, self).__init__(value) + + def __del__(self): + _logger.debug("Freeing X509: %d", self._value._as_parameter) + X509_free(self._value) + self._value = None + + +class _STACK(_Rsrc): + """Wrapper for the cryptographic library's stacks""" + def __init__(self, value): + super(_STACK, self).__init__(value) + + def __del__(self): + _logger.debug("Freeing stack: %d", self._value._as_parameter) + sk_pop_free(self._value) + self._value = None + + +def decode_cert(cert): + """Convert an X509 certificate into a Python dictionary + + This function converts the given X509 certificate into a Python dictionary + in the manner established by the Python standard library's ssl module. + """ + + ret_dict = {} + subject_xname = X509_get_subject_name(cert.value) + ret_dict["subject"] = _create_tuple_for_X509_NAME(subject_xname) + + notAfter = X509_get_notAfter(cert.value) + ret_dict["notAfter"] = ASN1_TIME_print(notAfter) + + peer_alt_names = _get_peer_alt_names(cert) + if peer_alt_names is not None: + ret_dict["subjectAltName"] = peer_alt_names + + return ret_dict + +def _test_decode_cert(cert_filename): + """format_cert testing + + Test the certificate conversion functionality with a PEM-encoded X509 + certificate. + """ + + cert_file = _BIO(BIO_new_file(cert_filename, "rb")) + cert = _X509(PEM_read_bio_X509_AUX(cert_file.value)) + return decode_cert(cert) + +def _create_tuple_for_attribute(name, value): + name_str = OBJ_obj2txt(name, False) + value_str = decode_ASN1_STRING(value) + return name_str, value_str + +def _create_tuple_for_X509_NAME(xname): + distinguished_name = [] + relative_distinguished_name = [] + level = -1 + for ind in range(X509_NAME_entry_count(xname)): + name_entry_ptr = X509_NAME_get_entry(xname, ind) + name_entry = name_entry_ptr.contents + if level >= 0 and level != name_entry.set: + distinguished_name.append(tuple(relative_distinguished_name)) + relative_distinguished_name = [] + level = name_entry.set + asn1_object = X509_NAME_ENTRY_get_object(name_entry_ptr) + asn1_string = X509_NAME_ENTRY_get_data(name_entry_ptr) + attribute_tuple = _create_tuple_for_attribute(asn1_object, asn1_string) + relative_distinguished_name.append(attribute_tuple) + if relative_distinguished_name: + distinguished_name.append(tuple(relative_distinguished_name)) + return tuple(distinguished_name) + +def _get_peer_alt_names(cert): + ret_list = None + ext_index = -1 + while True: + ext_index = X509_get_ext_by_NID(cert.value, NID_subject_alt_name, + ext_index) + if ext_index < 0: + break + if ret_list is None: + ret_list = [] + ext_ptr = X509_get_ext(cert.value, ext_index) + method_ptr = X509V3_EXT_get(ext_ptr) + general_names = _STACK(ASN1_item_d2i(method_ptr.contents, + ext_ptr.contents.value.contents)) + for name_index in range(sk_num(general_names.value)): + name_ptr = sk_value(general_names.value, name_index) + if name_ptr.contents.type == GEN_DIRNAME: + name_tuple = "DirName", \ + _create_tuple_for_X509_NAME(name_ptr.contents.d.directoryName) + else: + name_str = GENERAL_NAME_print(name_ptr) + name_tuple = tuple(name_str.split(':', 1)) + ret_list.append(name_tuple) + + return tuple(ret_list) if ret_list is not None else None