diff --git a/ChangeLog b/ChangeLog index 1b6899f..7717a03 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,14 @@ +2017-03-17 Björn Freise + + Added methods getting the curves supported by the runtime openSSL lib + + * dtls/openssl.py: + - Added class _EllipticCurve() for easy handling of the builtin curves + - Added wrapper get_elliptic_curves() - which uses _EllipticCurve() + - Added EC_get_builtin_curves(), EC_KEY_new_by_curve_name() and EC_KEY_free() + - Added OBJ_nid2sn() for translating numeric ids to names + * dtls/util.py: Added _EC_KEY() derived from _Rsrc() with own free/del method + 2017-03-17 Björn Freise Added methods for setting and getting the curves used during negotiation and encryption diff --git a/dtls/openssl.py b/dtls/openssl.py index 4323cd6..31fb244 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -37,13 +37,13 @@ import array import socket from logging import getLogger from os import path -from datetime import timedelta -from err import openssl_error -from err import SSL_ERROR_NONE -from util import _BIO -import ctypes -from ctypes import CDLL -from ctypes import CFUNCTYPE +from datetime import timedelta +from err import openssl_error +from err import SSL_ERROR_NONE +from util import _EC_KEY, _BIO +import ctypes +from ctypes import CDLL +from ctypes import CFUNCTYPE from ctypes import c_void_p, c_int, c_long, c_uint, c_ulong, c_char_p, c_size_t from ctypes import c_short, c_ushort, c_ubyte, c_char from ctypes import byref, POINTER, addressof @@ -169,6 +169,56 @@ X509_NAME_MAXLEN = 256 GETS_MAXLEN = 2048 +class _EllipticCurve(object): + _curves = None + + @classmethod + def _get_elliptic_curves(cls): + if cls._curves is None: + # Load once + cls._curves = cls._load_elliptic_curves() + return cls._curves + + @classmethod + def _load_elliptic_curves(cls): + num_curves = EC_get_builtin_curves(None, 0) + if num_curves > 0: + builtin_curves = create_string_buffer(sizeof(EC_builtin_curve) * num_curves) + EC_get_builtin_curves(cast(builtin_curves, POINTER(EC_builtin_curve)), num_curves) + return [cls(c.nid, OBJ_nid2sn(c.nid)) for c in cast(builtin_curves, POINTER(EC_builtin_curve))[:num_curves]] + return [] + + def __init__(self, nid, name): + self.nid = nid + self.name = name + + def __repr__(self): + return "" % (self.nid, self.name) + + def to_EC_KEY(self): + key = _EC_KEY(EC_KEY_new_by_curve_name(self.nid)) + return key if bool(key.value) else None + + +def get_elliptic_curves(): + u''' Return the available curves. If not yet loaded, then load them once. + + :rtype: list + ''' + return _EllipticCurve._get_elliptic_curves() + + +def get_elliptic_curve(name): + u''' Return the curve from the given name. + + :rtype: _EllipticCurve + ''' + for curve in get_elliptic_curves(): + if curve.name == name: + return curve + raise ValueError("unknown curve name", name) + + # # Parameter data types # @@ -223,12 +273,17 @@ class SSL(FuncParam): class BIO(FuncParam): def __init__(self, value): - super(BIO, self).__init__(value) - - -class X509(FuncParam): - def __init__(self, value): - super(X509, self).__init__(value) + super(BIO, self).__init__(value) + + +class EC_KEY(FuncParam): + def __init__(self, value): + super(EC_KEY, self).__init__(value) + + +class X509(FuncParam): + def __init__(self, value): + super(X509, self).__init__(value) class X509_val_st(Structure): @@ -328,12 +383,17 @@ class X509V3_EXT_METHOD(Structure): class TIMEVAL(Structure): _fields_ = [("tv_sec", c_long), - ("tv_usec", c_long)] - - -# -# Socket address conversions -# + ("tv_usec", c_long)] + + +class EC_builtin_curve(Structure): + _fields_ = [("nid", c_int), + ("comment", c_char_p)] + + +# +# Socket address conversions +# class sockaddr_storage(Structure): _fields_ = [("ss_family", c_short), ("pad", c_char * 126)] @@ -570,11 +630,13 @@ __all__ = [ "SSL_state_string_long", "SSL_alert_type_string_long", "SSL_alert_desc_string_long", "SSL_CTX_set_cookie_cb", "OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print", + "OBJ_nid2sn", "X509_get_notAfter", "ASN1_item_d2i", "GENERAL_NAME_print", "sk_value", "sk_pop_free", "i2d_X509", + "get_elliptic_curves", ] # note: the following map adds to this list map(lambda x: _make_function(*x), ( @@ -690,6 +752,8 @@ map(lambda x: _make_function(*x), ( ((X509, "ret"), (BIO, "bp"), (c_void_p, "x", 1, None), (c_void_p, "cb", 1, None), (c_void_p, "u", 1, None))), ("OBJ_obj2txt", libcrypto, ((c_int, "ret"), (POINTER(c_char), "buf"), (c_int, "buf_len"), (ASN1_OBJECT, "a"), (c_int, "no_name")), False), + ("OBJ_nid2sn", libcrypto, + ((c_char_p, "ret"), (c_int, "n")), False), ("CRYPTO_free", libcrypto, ((None, "ret"), (c_void_p, "ptr"))), ("ASN1_STRING_to_UTF8", libcrypto, @@ -733,6 +797,12 @@ map(lambda x: _make_function(*x), ( ((c_char_p, "ret"), (SSL_CIPHER, "cipher"))), ("SSL_CIPHER_get_bits", libssl, ((c_int, "ret"), (SSL_CIPHER, "cipher"), (POINTER(c_int), "alg_bits", 1, None)), True, None), + ("EC_get_builtin_curves", libcrypto, + ((c_int, "ret"), (POINTER(EC_builtin_curve), "r"), (c_int, "nitems"))), + ("EC_KEY_new_by_curve_name", libcrypto, + ((EC_KEY, "ret"), (c_int, "nid"))), + ("EC_KEY_free", libcrypto, + ((None, "ret"), (EC_KEY, "key"))), )) # @@ -821,9 +891,10 @@ def SSL_CTX_build_cert_chain(ctx, flags): def SSL_CTX_set_ecdh_auto(ctx, onoff): return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_ECDH_AUTO, onoff, None) -def SSL_CTX_set_tmp_ecdh(ctx, ecdh): +def SSL_CTX_set_tmp_ecdh(ctx, ec_key): # return 1 on success and 0 on failure - return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TMP_ECDH, 0, ecdh) + _ec_key_p = cast(ec_key.raw, c_void_p) + return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TMP_ECDH, 0, _ec_key_p) _rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), POINTER(c_uint)) @@ -1001,11 +1072,15 @@ def SSL_alert_desc_string_long(value): def OBJ_obj2txt(asn1_object, no_name): buf = create_string_buffer(X509_NAME_MAXLEN) res_len = _OBJ_obj2txt(buf, sizeof(buf), asn1_object, 1 if no_name else 0) - return buf.raw[:res_len] - -def decode_ASN1_STRING(asn1_string): - utf8_buf_ptr = POINTER(c_ubyte)() - res_len = _ASN1_STRING_to_UTF8(byref(utf8_buf_ptr), asn1_string) + return buf.raw[:res_len] + +def OBJ_nid2sn(nid): + _name = _OBJ_nid2sn(nid) + return cast(_name, c_char_p).value.decode("ascii") + +def decode_ASN1_STRING(asn1_string): + utf8_buf_ptr = POINTER(c_ubyte)() + res_len = _ASN1_STRING_to_UTF8(byref(utf8_buf_ptr), asn1_string) try: return unicode(''.join([chr(i) for i in utf8_buf_ptr[:res_len]]), 'utf-8') diff --git a/dtls/util.py b/dtls/util.py index 13e20f5..3a1e64d 100644 --- a/dtls/util.py +++ b/dtls/util.py @@ -54,6 +54,18 @@ class _BIO(_Rsrc): if self.owned: _logger.debug("Freeing BIO: %d", self.raw) from openssl import BIO_free - BIO_free(self._value) - self.owned = False - self._value = None + BIO_free(self._value) + self.owned = False + self._value = None + + +class _EC_KEY(_Rsrc): + """EC KEY wrapper""" + def __init__(self, value): + super(_EC_KEY, self).__init__(value) + + def __del__(self): + _logger.debug("Freeing EC_KEY: %d", self.raw) + from openssl import EC_KEY_free + EC_KEY_free(self._value) + self._value = None