Added supoort for EC-Functions and DTLSv1.2 (Updated OpenSSL-Lib to V1.0.2j)

Attention: DLL updates for Win32 only! (Till now)
incoming
mcfreis 2017-02-21 14:48:13 +01:00
parent 34dc9ca9cd
commit 5776e38445
6 changed files with 820 additions and 28 deletions

View File

@ -84,6 +84,7 @@ BIO_NOCLOSE = 0x00
BIO_CLOSE = 0x01
SSLEAY_VERSION = 0
SSL_OP_NO_COMPRESSION = 0x00020000
SSL_OP_NO_QUERY_MTU = 0x00001000
SSL_VERIFY_NONE = 0x00
SSL_VERIFY_PEER = 0x01
SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02
@ -97,17 +98,58 @@ SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = 0x0100
SSL_SESS_CACHE_NO_INTERNAL_STORE = 0x0200
SSL_SESS_CACHE_NO_INTERNAL = \
SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE
SSL_BUILD_CHAIN_FLAG_UNTRUSTED = 0x1
SSL_BUILD_CHAIN_FLAG_NO_ROOT = 0x2
SSL_BUILD_CHAIN_FLAG_CHECK = 0x4
SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = 0x8
SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = 0x10
SSL_FILE_TYPE_PEM = 1
GEN_DIRNAME = 4
NID_subject_alt_name = 85
CRYPTO_LOCK = 1
SSL_ST_MASK = 0x0FFF
SSL_ST_CONNECT = 0x1000
SSL_ST_ACCEPT = 0x2000
SSL_ST_INIT = (SSL_ST_CONNECT | SSL_ST_ACCEPT)
SSL_ST_BEFORE = 0x4000
SSL_ST_OK = 0x03
SSL_ST_RENEGOTIATE = (0x04 | SSL_ST_INIT)
SSL_ST_ERR = 0x05
SSL_CB_LOOP = 0x01
SSL_CB_EXIT = 0x02
SSL_CB_READ = 0x04
SSL_CB_WRITE = 0x08
SSL_CB_ALERT = 0x4000
SSL_CB_READ_ALERT = (SSL_CB_ALERT | SSL_CB_READ)
SSL_CB_WRITE_ALERT = (SSL_CB_ALERT | SSL_CB_WRITE)
SSL_CB_ACCEPT_LOOP = (SSL_ST_ACCEPT | SSL_CB_LOOP)
SSL_CB_ACCEPT_EXIT = (SSL_ST_ACCEPT | SSL_CB_EXIT)
SSL_CB_CONNECT_LOOP = (SSL_ST_CONNECT | SSL_CB_LOOP)
SSL_CB_CONNECT_EXIT = (SSL_ST_CONNECT | SSL_CB_EXIT)
SSL_CB_HANDSHAKE_START = 0x10
SSL_CB_HANDSHAKE_DONE = 0x20
#
# Integer constants - internal
#
SSL_CTRL_SET_SESS_CACHE_MODE = 44
SSL_CTRL_SET_READ_AHEAD = 41
SSL_CTRL_OPTIONS = 32
SSL_CTRL_CLEAR_OPTIONS = 77
SSL_CTRL_SET_ECDH_AUTO = 94
SSL_CTRL_BUILD_CERT_CHAIN = 105
SSL_CTRL_SET_MTU = 17
SSL_CTRL_GET_CURVES = 90
SSL_CTRL_SET_CURVES = 91
SSL_CTRL_SET_CURVES_LIST = 92
SSL_CTRL_GET_SHARED_CURVE = 93
SSL_CTRL_SET_SIGALGS = 97
SSL_CTRL_SET_SIGALGS_LIST = 98
SSL_CTRL_SET_CLIENT_SIGALGS = 101
SSL_CTRL_SET_CLIENT_SIGALGS_LIST = 102
SSL_CTRL_SET_TMP_ECDH = 4
BIO_CTRL_INFO = 3
BIO_CTRL_DGRAM_SET_CONNECTED = 32
BIO_CTRL_DGRAM_GET_PEER = 46
@ -119,6 +161,53 @@ DTLS_CTRL_LISTEN = 75
X509_NAME_MAXLEN = 256
GETS_MAXLEN = 2048
class _EllipticCurve(object):
_curves = None
@classmethod
def _load_elliptic_curves(cls, lib):
num_curves = lib.EC_get_builtin_curves(0, 0)
if num_curves > 0:
builtin_curves = create_string_buffer(sizeof(EC_builtin_curve) * num_curves)
lib.EC_get_builtin_curves(builtin_curves, num_curves)
return set(cls.from_nid(lib, c.nid) for c in cast(builtin_curves, POINTER(EC_builtin_curve))[:num_curves])
return set()
@classmethod
def _get_elliptic_curves(cls, lib):
if cls._curves is None:
cls._curves = cls._load_elliptic_curves(lib)
return cls._curves
@classmethod
def from_nid(cls, lib, nid):
return cls(lib, nid, cast(lib.OBJ_nid2sn(nid), c_char_p).value.decode("ascii"))
def __init__(self, lib, nid, name):
self._lib = lib
self._nid = nid
self.name = name
def __repr__(self):
return "<Curve %r>" % (self.name,)
def as_EC_KEY(self):
key = self._lib.EC_KEY_new_by_curve_name(self._nid)
return key
def get_elliptic_curves():
return _EllipticCurve._get_elliptic_curves(libcrypto)
def get_elliptic_curve(name):
for curve in get_elliptic_curves():
if curve.name == name:
return curve
raise ValueError("unknown curve name", name)
#
# Parameter data types
#
@ -171,6 +260,11 @@ class SSL(FuncParam):
super(SSL, self).__init__(value)
class EC_builtin_curve(Structure):
_fields_ = [("nid", c_int),
("comment", c_char_p)]
class BIO(FuncParam):
def __init__(self, value):
super(BIO, self).__init__(value)
@ -472,14 +566,22 @@ _subst = {c_long_parm: c_long}
_sigs = {}
__all__ = ["BIO_NOCLOSE", "BIO_CLOSE",
"SSLEAY_VERSION",
"SSL_OP_NO_COMPRESSION",
"SSL_OP_NO_COMPRESSION", "SSL_OP_NO_QUERY_MTU",
"SSL_VERIFY_NONE", "SSL_VERIFY_PEER",
"SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE",
"SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT",
"SSL_SESS_CACHE_SERVER", "SSL_SESS_CACHE_BOTH",
"SSL_SESS_CACHE_NO_AUTO_CLEAR", "SSL_SESS_CACHE_NO_INTERNAL_LOOKUP",
"SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL",
"SSL_ST_MASK", "SSL_ST_CONNECT", "SSL_ST_ACCEPT", "SSL_ST_INIT", "SSL_ST_BEFORE", "SSL_ST_OK",
"SSL_ST_RENEGOTIATE", "SSL_ST_ERR", "SSL_CB_LOOP", "SSL_CB_EXIT", "SSL_CB_READ", "SSL_CB_WRITE",
"SSL_CB_ALERT", "SSL_CB_READ_ALERT", "SSL_CB_WRITE_ALERT",
"SSL_CB_ACCEPT_LOOP", "SSL_CB_ACCEPT_EXIT",
"SSL_CB_CONNECT_LOOP", "SSL_CB_CONNECT_EXIT",
"SSL_CB_HANDSHAKE_START", "SSL_CB_HANDSHAKE_DONE",
"SSL_FILE_TYPE_PEM",
"SSL_BUILD_CHAIN_FLAG_NO_ROOT",
"SSL_state_string_long",
"GEN_DIRNAME", "NID_subject_alt_name",
"CRYPTO_LOCK",
"CRYPTO_set_locking_callback",
@ -489,15 +591,28 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE",
"BIO_dgram_set_connected",
"BIO_dgram_get_peer", "BIO_dgram_set_peer",
"BIO_set_nbio",
"SSL_CTX_set_ecdh_auto",
"SSL_CTX_build_cert_chain",
"SSL_CTX_set_session_cache_mode", "SSL_CTX_set_read_ahead",
"SSL_CTX_set_options",
"SSL_set_options", "SSL_clear_options", "SSL_get_options",
"SSL_get1_curves", "SSL_CTX_set1_curves", "SSL_CTX_set1_curves_list", "SSL_set1_curves_list",
"SSL_CTX_set_tmp_ecdh",
"SSL_CTX_set_info_callback",
"SSL_set_mtu",
"SSL_read", "SSL_write",
"SSL_CTX_set_cookie_cb",
"SSL_set1_client_sigalgs_list", "SSL_set1_client_sigalgs",
"SSL_CTX_set1_client_sigalgs_list", "SSL_CTX_set1_client_sigalgs",
"SSL_set1_sigalgs_list", "SSL_set1_sigalgs",
"SSL_CTX_set1_sigalgs_list", "SSL_CTX_set1_sigalgs",
"OBJ_obj2txt", "decode_ASN1_STRING", "ASN1_TIME_print",
"X509_get_notAfter",
"ASN1_item_d2i", "GENERAL_NAME_print",
"EC_KEY_free",
"sk_value",
"sk_pop_free",
"get_elliptic_curves",
"i2d_X509"] # note: the following map adds to this list
map(lambda x: _make_function(*x), (
@ -511,6 +626,9 @@ map(lambda x: _make_function(*x), (
("CRYPTO_num_locks", libcrypto, ((c_int, "ret"),)),
("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLSv1_2_client_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLSv1_2_server_method", libssl, ((DTLSv1Method, "ret"),)),
("DTLS_server_method", libssl, ((DTLSv1Method, "ret"),)),
("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))),
("SSL_CTX_free", libssl, ((None, "ret"), (SSLCTX, "ctx"))),
("SSL_CTX_set_cookie_generate_cb", libssl,
@ -563,6 +681,8 @@ map(lambda x: _make_function(*x), (
("SSL_CTX_set_verify", libssl,
((None, "ret"), (SSLCTX, "ctx"), (c_int, "mode"),
(c_void_p, "verify_callback", 1, None))),
("SSL_CTX_set_verify_depth", libssl,
((None, "ret"), (SSLCTX, "ctx"), (c_int, "depth"))),
("SSL_accept", libssl, ((c_int, "ret"), (SSL, "ssl"))),
("SSL_connect", libssl, ((c_int, "ret"), (SSL, "ssl"))),
("SSL_set_connect_state", libssl, ((None, "ret"), (SSL, "ssl"))),
@ -630,6 +750,22 @@ map(lambda x: _make_function(*x), (
("SSL_CIPHER_get_bits", libssl,
((c_int, "ret"), (SSL_CIPHER, "cipher"),
(POINTER(c_int), "alg_bits", 1, None)), True, None),
("EC_get_builtin_curves", libcrypto,
((c_int, "ret"), (POINTER(EC_builtin_curve), "r"), (c_int, "nitems"))),
("EC_KEY_new_by_curve_name", libcrypto,
((POINTER(c_char), "ret"), (c_int, "nid"))),
("EC_KEY_free", libcrypto,
((None, "ret"), (POINTER(c_char), "key")), False),
("OBJ_nid2sn", libcrypto,
((c_char_p, "ret"), (c_int, "n"))),
("EC_curve_nist2nid", libcrypto,
((c_int, "ret"), (POINTER(c_char), "name")), True, None),
("EC_curve_nid2nist", libcrypto,
((c_char_p, "ret"), (c_int, "nid")), True, None),
("SSL_CTX_set_info_callback", libssl,
((None, "ret"), (SSLCTX, "ctx"), (c_void_p, "callback")), False),
("SSL_state_string_long", libssl,
((c_char_p, "ret"), (SSL, "ssl")), False),
))
#
@ -648,6 +784,15 @@ def CRYPTO_set_locking_callback(locking_function):
_locking_cb = _rvoid_int_int_charp_int(py_locking_function)
_CRYPTO_set_locking_callback(_locking_cb)
def SSL_set_mtu(ssl, mtu):
return _SSL_ctrl(ssl, SSL_CTRL_SET_MTU, mtu, None)
def SSL_CTX_build_cert_chain(ctx, flags):
return _SSL_CTX_ctrl(ctx, SSL_CTRL_BUILD_CERT_CHAIN, flags, None)
def SSL_CTX_set_ecdh_auto(ctx, onoff):
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_ECDH_AUTO, onoff, None)
def SSL_CTX_set_session_cache_mode(ctx, mode):
# Returns the previous value of mode
_SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SESS_CACHE_MODE, mode, None)
@ -658,7 +803,47 @@ def SSL_CTX_set_read_ahead(ctx, m):
def SSL_CTX_set_options(ctx, options):
# Returns the new option bitmaks after adding the given options
_SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_OPTIONS, options, None)
def SSL_set_options(ssl, op):
return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, op, None)
def SSL_clear_options(ssl, op):
return _SSL_ctrl(ssl, SSL_CTRL_CLEAR_OPTIONS, op, None)
def SSL_get_options(ssl):
return _SSL_ctrl(ssl, SSL_CTRL_OPTIONS, 0, None)
def SSL_get1_curves(ssl, s):
_s = cast(s, POINTER(c_int))
if s:
mem = None
cnt = SSL_get1_curves(ssl, 0)
if cnt >= s:
mem = create_string_buffer(sizeof(POINTER(c_int)) * s)
ret = _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, s, mem)
return [x for x in cast(mem, POINTER(c_int))[:s]]
else:
return _SSL_ctrl(ssl, SSL_CTRL_GET_CURVES, 0, _s)
def SSL_CTX_set1_curves(ctx, clist, clistlen):
_curves = (c_int * len(clist))(*clist)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES, len(_curves), _curves)
def SSL_get_shared_curve(ssl, n):
return _SSL_ctrl(ssl, SSL_CTRL_GET_SHARED_CURVE, n, 0)
def SSL_CTX_set1_curves_list(ctx, s):
_s = cast(s, POINTER(c_char))
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CURVES_LIST, 0, _s)
def SSL_set1_curves_list(ssl, s):
_s = cast(s, POINTER(c_char))
return _SSL_ctrl(ssl, SSL_CTRL_SET_CURVES_LIST, 0, _s)
def SSL_CTX_set_tmp_ecdh(ctx, ecdh):
# return 1 on success and 0 on failure
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TMP_ECDH, 0, ecdh)
_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte),
POINTER(c_uint))
@ -745,6 +930,8 @@ def SSL_write(ssl, data):
str_data = data
elif hasattr(data, "tobytes") and callable(data.tobytes):
str_data = data.tobytes()
elif isinstance(data, ctypes.Array):
str_data = data.raw
else:
str_data = str(data)
return _SSL_write(ssl, str_data, len(str_data))
@ -827,3 +1014,73 @@ def i2d_X509(x509):
bio = _BIO(BIO_new(BIO_s_mem()))
_i2d_X509_bio(bio.value, x509)
return BIO_get_mem_data(bio.value)
def EC_KEY_free(key):
_EC_KEY_free(cast(key, POINTER(c_char)))
_rvoid_voidp_int_int = CFUNCTYPE(None, c_void_p, c_int, c_int)
def SSL_CTX_set_info_callback(ctx, callback):
"""
Set the info callback
:param callback: The Python callback to use
:return: None
"""
def py_info_callback(ssl, where, ret):
try:
callback(SSL(ssl), where, ret)
except:
pass
return
global _info_callback
_info_callback = _rvoid_voidp_int_int(py_info_callback)
_SSL_CTX_set_info_callback(ctx, _info_callback)
def SSL_state_string_long(ssl):
try:
ret = _SSL_state_string_long(ssl)
except:
pass
return ret
# sigalgs_list: (Only for DTLS v1.2)
#
# The short or long name values for digests can be used in a string (for example "MD5", "SHA1", "SHA224", "SHA256", "SHA384", "SHA512")
# and the public key algorithm strings "RSA", "DSA" or "ECDSA".
# The use of MD5 as a digest is strongly discouraged due to security weaknesses.
# Example: "ECDSA+SHA256:RSA+SHA256"
def SSL_CTX_set1_client_sigalgs(ctx, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist)
def SSL_CTX_set1_client_sigalgs_list(ctx, s):
_s = cast(s, POINTER(c_char))
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s)
def SSL_set1_client_sigalgs(ssl, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS, len(_slist), _slist)
def SSL_set1_client_sigalgs_list(ssl, s):
_s = cast(s, POINTER(c_char))
return _SSL_ctrl(ssl, SSL_CTRL_SET_CLIENT_SIGALGS_LIST, 0, _s)
def SSL_CTX_set1_sigalgs(ctx, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS, len(_slist), _slist)
def SSL_CTX_set1_sigalgs_list(ctx, s):
_s = cast(s, POINTER(c_char))
return _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_SIGALGS_LIST, 0, _s)
def SSL_set1_sigalgs(ssl, slist, slistlen):
_slist = (c_int * len(slist))(*slist)
return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS, len(_slist), _slist)
def SSL_set1_sigalgs_list(ssl, s):
_s = cast(s, POINTER(c_char))
return _SSL_ctrl(ssl, SSL_CTRL_SET_SIGALGS_LIST, 0, _s)

View File

@ -34,16 +34,18 @@ has the following effects:
PROTOCOL_DTLSv1 for the parameter ssl_version is supported
"""
from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_DGRAM, getaddrinfo
from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION
from sslconnection import DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
from types import MethodType
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_DGRAM
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE
from types import MethodType, BuiltinMethodType
from weakref import proxy
import errno
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
from err import raise_as_ssl_module_error
def do_patch():
import ssl as _ssl # import to be avoided if ssl module is never patched
global _orig_SSLSocket_init, _orig_get_server_certificate
@ -51,8 +53,19 @@ def do_patch():
ssl = _ssl
if hasattr(ssl, "PROTOCOL_DTLSv1"):
return
_orig_wrap_socket = ssl.wrap_socket
ssl.wrap_socket = _wrap_socket
SSLSocket_ = ssl.SSLSocket
class SSLSocket(SSLSocket_):
def __getattr__(self, item):
if hasattr(self, "_sslobj") and hasattr(self._sslobj, item):
return getattr(self._sslobj, item)
raise AttributeError
ssl.SSLSocket = SSLSocket
ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1
ssl.PROTOCOL_DTLSv1_2 = PROTOCOL_DTLSv1_2
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1"
ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1_2] = "DTLSv1.2"
ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
@ -62,8 +75,23 @@ def do_patch():
ssl.get_server_certificate = _get_server_certificate
raise_as_ssl_module_error()
PROTOCOL_SSLv3 = 1
PROTOCOL_SSLv23 = 2
def _wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None,
cb_user_ssl_ctx_config=None, cb_user_ssl_config=None):
return ssl.SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers,
cb_user_ssl_ctx_config=cb_user_ssl_ctx_config,
cb_user_ssl_config=cb_user_ssl_config)
def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
"""Retrieve a server certificate
@ -74,10 +102,10 @@ def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
If 'ssl_version' is specified, use it in the connection attempt.
"""
if ssl_version != PROTOCOL_DTLSv1:
if ssl_version not in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2):
return _orig_get_server_certificate(addr, ssl_version, ca_certs)
if (ca_certs is not None):
if ca_certs is not None:
cert_reqs = ssl.CERT_REQUIRED
else:
cert_reqs = ssl.CERT_NONE
@ -94,7 +122,8 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None):
suppress_ragged_eofs=True, ciphers=None,
cb_user_ssl_ctx_config=None, cb_user_ssl_config=None):
is_connection = is_datagram = False
if isinstance(sock, SSLConnection):
is_connection = True
@ -138,7 +167,8 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
server_side, cert_reqs,
ssl_version, ca_certs,
do_handshake_on_connect,
suppress_ragged_eofs, ciphers)
suppress_ragged_eofs, ciphers,
cb_user_ssl_ctx_config, cb_user_ssl_config)
else:
self._sslobj = sock
@ -150,6 +180,8 @@ def _SSLSocket_init(self, sock, keyfile=None, certfile=None,
self.ciphers = ciphers
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
self.cb_user_ssl_ctx_config = cb_user_ssl_ctx_config
self.cb_user_ssl_config = cb_user_ssl_config
self._makefile_refs = 0
# Perform method substitution and addition (without reference cycle)
@ -169,7 +201,8 @@ def _SSLSocket_listen(self, ignored):
self.cert_reqs, self.ssl_version,
self.ca_certs,
self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers)
self.suppress_ragged_eofs, self.ciphers,
self.cb_user_ssl_ctx_config, self.cb_user_ssl_config)
def _SSLSocket_accept(self):
if self._connected:
@ -184,7 +217,8 @@ def _SSLSocket_accept(self):
self.cert_reqs, self.ssl_version,
self.ca_certs,
self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers)
self.suppress_ragged_eofs, self.ciphers,
self.cb_user_ssl_ctx_config, self.cb_user_ssl_config)
return new_ssl_sock, addr
def _SSLSocket_real_connect(self, addr, return_errno):
@ -195,7 +229,8 @@ def _SSLSocket_real_connect(self, addr, return_errno):
self.cert_reqs, self.ssl_version,
self.ca_certs,
self.do_handshake_on_connect,
self.suppress_ragged_eofs, self.ciphers)
self.suppress_ragged_eofs, self.ciphers,
self.cb_user_ssl_ctx_config, self.cb_user_ssl_config)
try:
self._sslobj.connect(addr)
except socket_error as e:

Binary file not shown.

Binary file not shown.

View File

@ -48,13 +48,14 @@ from logging import getLogger
from os import urandom
from select import select
from weakref import proxy
from err import openssl_error, InvalidSocketError
from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_SYSCALL
from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES_SVR
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR
from x509 import _X509, decode_cert
from tlock import tlock_init
from openssl import *
@ -63,6 +64,8 @@ from util import _Rsrc, _BIO
_logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256
PROTOCOL_DTLSv1_2 = 258
PROTOCOL_DTLS = 259
CERT_NONE = 0
CERT_OPTIONAL = 1
CERT_REQUIRED = 2
@ -122,6 +125,185 @@ class _CallbackProxy(object):
return self.ssl_func(self.ssl_connection, *args, **kwargs)
def _ssl_logging_cb(conn, where, return_code):
def get_alert_desc(return_code):
_alertDesc = {
0: 'close notify',
10: 'unexpected message',
20: 'bad record mac',
30: 'decompression failure',
40: 'handshake failure',
41: 'no certificate',
42: 'bad certificate',
43: 'unsupported certificate',
44: 'certificate revoked',
45: 'certificate expired',
46: 'certificate unknown',
47: 'illegal parameter',
}
_typeDescr = {
1: 'warning',
2: 'fatal',
}
_type = return_code >> 8
_ret = return_code & 0xFF
_typeStr = _typeDescr[_type] if _type in _typeDescr else ('unknown (%d)' % _type)
if _ret in _alertDesc:
return _typeStr, _alertDesc[_ret]
return _typeStr, str(_ret)
_state = where & ~SSL_ST_MASK
state = "SSL_undef (%04x)" % _state
if _state & SSL_ST_INIT == SSL_ST_INIT:
state = "SSL_init"
if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE:
state = "SSL_renew"
elif _state & SSL_ST_CONNECT:
state = "SSL_connect"
elif _state & SSL_ST_ACCEPT:
state = "SSL_accept"
elif _state & SSL_ST_BEFORE:
state = "SSL_before"
if where & SSL_CB_LOOP:
state += '_loop'
_logger.debug("%s: %s" % (state, SSL_state_string_long(conn)))
elif where & SSL_CB_ALERT:
op = "read" if where & SSL_CB_READ else "write"
state += '_alert'
_logger.debug("%s %s: %s" % (state, op, ' - '.join(get_alert_desc(return_code))))
elif where & SSL_CB_EXIT:
state += '_exit'
if return_code == 0:
_logger.debug("%s: failed in %s" % (state, SSL_state_string_long(conn)))
elif return_code < 0:
_logger.debug("%s: error %d in %s" % (state, return_code, SSL_state_string_long(conn)))
else:
_logger.debug("%s: %s" % (state, SSL_state_string_long(conn)))
else:
_logger.debug("%s: %s" % (state, SSL_state_string_long(conn)))
class SSLContext(object):
def __init__(self, ctx):
self._ctx = ctx
def set_ciphers(self, ciphers):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal
def set_sigalgs(self, sigalgs):
u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal
def set_curves(self, curves):
u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure
'''
retVal = None
if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal
@staticmethod
def get_ec_nist2nid(nist):
if not isinstance(nist, tuple):
nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid
@staticmethod
def get_ec_nid2nist(nid):
if not isinstance(nid, tuple):
nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist
@staticmethod
def get_ec_available(bAsName=True):
curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x._nid for x in curves])
def set_ecdh_curve(self, curve_name=None):
u'''
Used for server only!
s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure
'''
if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves()
key = [curve for curve in avail_curves if curve.name == curve_name][0].as_EC_KEY()
retVal = SSL_CTX_set_tmp_ecdh(self._ctx, key)
EC_KEY_free(key)
else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NO_ROOT):
u'''
Used for server side only!
:param flags:
:return: 1 for success and 0 for failure
'''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging
'''
if enable:
SSL_CTX_set_info_callback(self._ctx, func)
else:
SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object):
def __init__(self, ssl):
self._ssl = ssl
def set_mtu(self, mtu=None):
if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu)
else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object):
"""DTLS peer association
@ -149,7 +331,12 @@ class SSLConnection(object):
else:
self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method()))
server_method = DTLS_server_method
if self._ssl_version == PROTOCOL_DTLSv1_2:
server_method = DTLSv1_2_server_method
elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method()))
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
@ -179,7 +366,12 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio
self._ctx = _CTX(SSL_CTX_new(DTLSv1_client_method()))
client_method = DTLSv1_client_method
if self._ssl_version == PROTOCOL_DTLSv1_2:
client_method = DTLSv1_2_client_method
elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method()))
if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE
else:
@ -197,7 +389,9 @@ class SSLConnection(object):
# corruption when packet loss occurs
SSL_CTX_set_options(self._ctx.value, SSL_OP_NO_COMPRESSION)
if self._certfile:
SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile)
# SSL_CTX_use_certificate_chain_file(self._ctx.value, self._certfile)
SSL_CTX_use_certificate_file(self._ctx.value, self._certfile,
SSL_FILE_TYPE_PEM)
if self._keyfile:
SSL_CTX_use_PrivateKey_file(self._ctx.value, self._keyfile,
SSL_FILE_TYPE_PEM)
@ -208,6 +402,8 @@ class SSLConnection(object):
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err)
if self._user_ssl_ctx_config:
self._user_ssl_ctx_config(SSLContext(self._ctx.value))
def _copy_server(self):
source = self._sock
@ -241,6 +437,8 @@ class SSLConnection(object):
new_source_wbio.value)
new_source_rbio.disown()
new_source_wbio.disown()
if self._user_ssl_config:
self._user_ssl_config(SSL(source._ssl.value))
def _reconnect_unwrapped(self):
source = self._sock
@ -310,9 +508,11 @@ class SSLConnection(object):
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLSv1, ca_certs=None,
ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None):
suppress_ragged_eofs=True, ciphers=None,
cb_user_ssl_ctx_config=None,
cb_user_ssl_config=None):
"""Constructor
Arguments:
@ -330,10 +530,13 @@ class SSLConnection(object):
if not ciphers:
ciphers = "DEFAULT"
assert ssl_version in (PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2)
self._sock = sock
self._keyfile = keyfile
self._certfile = certfile
self._cert_reqs = cert_reqs
self._ssl_version = ssl_version
self._ca_certs = ca_certs
self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs
@ -341,6 +544,9 @@ class SSLConnection(object):
self._handshake_done = False
self._wbio_nb = self._rbio_nb = False
self._user_ssl_ctx_config = cb_user_ssl_ctx_config
self._user_ssl_config = cb_user_ssl_config
if isinstance(sock, SSLConnection):
post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket):
@ -358,6 +564,8 @@ class SSLConnection(object):
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown()
self._wbio.disown()
if self._user_ssl_config:
self._user_ssl_config(SSL(self._ssl.value))
if post_init:
post_init()
@ -462,9 +670,11 @@ class SSLConnection(object):
_logger.debug("Accept returning without connection")
return
new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, PROTOCOL_DTLSv1,
self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers)
self._suppress_ragged_eofs, self._ciphers,
cb_user_ssl_ctx_config=self._user_ssl_ctx_config,
cb_user_ssl_config=self._user_ssl_config)
new_peer = self._pending_peer_address
self._pending_peer_address = None
if self._do_handshake_on_connect:
@ -540,8 +750,12 @@ class SSLConnection(object):
number of bytes actually transmitted
"""
return self._wrap_socket_library_call(
retVal = self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
# for client side ... we want to know when the handshake is completed
if self._wbio is self._rbio and retVal > 0:
self._handshake_done = True
return retVal
def shutdown(self):
"""Shut down the DTLS connection

286
dtls/wrapper.py 100644
View File

@ -0,0 +1,286 @@
# -*- encoding: utf-8 -*-
import datetime
import select
from logging import getLogger
import ssl
import socket
from dtls import do_patch
do_patch()
_logger = getLogger(__name__)
class _ClientSession(object):
def __init__(self, host, port, handshake_done=False):
self.host = host
self.port = int(port)
self.handshake_done = handshake_done
def getAddr(self):
return self.host, self.port
class DtlsSocket(object):
def __init__(self,
host,
port,
keyfile=None,
certfile=None,
server_side=False,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None,
do_handshake_on_connect=False,
suppress_ragged_eofs=True,
ciphers=None,
curves=None,
sigalgs=None,
user_mtu=None):
self._ssl_logging = False
self._peer = (host, int(port))
self._server_side = server_side
self._ciphers = ciphers
self._curves = curves
self._sigalgs = sigalgs
self._user_mtu = user_mtu
self._sock = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
keyfile=keyfile,
certfile=certfile,
server_side=self._server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
ciphers=self._ciphers,
cb_user_ssl_ctx_config=self.user_ssl_ctx_config,
cb_user_ssl_config=self.user_ssl_config)
if self._server_side:
self._clients = {}
self._timeout = None
self._sock.bind(self._peer)
self._sock.listen(0)
else:
self._sock.connect(self._peer)
def user_ssl_ctx_config(self, _ctx):
_ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers:
_ctx.set_ciphers(self._ciphers)
if self._curves:
_ctx.set_curves(self._curves)
if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs)
if self._server_side:
_ctx.build_cert_chain()
_ctx.set_ecdh_curve() # ("secp256k1")
def user_ssl_config(self, _ssl):
if self._user_mtu:
_ssl.set_mtu(self._user_mtu)
def settimeout(self, t):
if self._server_side:
self._timeout = t
else:
self._sock.settimeout(t)
def close(self):
if self._server_side:
for cli in self._clients.keys():
cli.close()
else:
self._sock.unwrap()
self._sock.close()
def recvfrom(self, bufsize, flags=0):
if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags)
else:
return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags):
try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout as e_timeout:
raise e_timeout
else:
for conn in r: # type: ssl.SSLSocket
if self._sockIsServerSock(conn):
# Connect
self._clientAccept(conn)
else:
# Handshake
if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn)
# Normal read
else:
buf = self._clientRead(conn, bufsize)
if buf and conn in self._clients:
return buf, self._clients[conn].getAddr()
for conn in self._getClientReadingSockets():
if conn.get_timeout():
conn.handle_timeout()
raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags):
try:
buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e_ssl:
if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
return '', self._peer
elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
raise
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
else:
if buf:
return buf, self._peer
raise socket.timeout
def sendto(self, buf, address):
if self._server_side:
return self._sendto_from_server_side(buf, address)
else:
return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems():
if client.getAddr() == address:
return self._clientWrite(conn, buf)
return 0
def _sendto_from_client_side(self, buf, address):
while True:
try:
bytes_sent = self._sock.send(buf)
except ssl.SSLError as e_ssl:
if str(e_ssl).startswith("503:"):
# The write operation timed out
continue
# elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ]:
# # no ciphers available
# if e_ssl.args[1][0][0] in [336081077, ]:
# raise
raise
else:
if bytes_sent:
break
return bytes_sent
def _getClientReadingSockets(self):
return [x for x in self._clients.keys()]
def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn):
return conn is self._sock
def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn):
_logger.debug('+' * 60)
ret = None
try:
ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e_accept:
pass
else:
if ret:
client, addr = ret
host, port = addr
if client in self._clients:
raise
self._clients[client] = _ClientSession(host=host, port=port)
self._clientDoHandshake(client)
def _clientDoHandshake(self, conn):
_logger.debug('-' * 60)
conn.setblocking(False)
try:
conn.do_handshake()
_logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True
except ssl.SSLError as e_handshake:
if str(e_handshake).startswith("504:"):
pass
elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
pass
else:
raise e_handshake
def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60)
ret = None
try:
ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e_read:
if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
self._clientDrop(conn)
elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
self._clientDrop(conn, error=e_read)
else: # like in [ssl.SSL_ERROR_WANT_READ, ...]
pass
return ret
def _clientWrite(self, conn, data):
_logger.debug('#' * 60)
ret = None
try:
ret = conn.send(data.raw)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e_write:
raise
return ret
def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60)
try:
if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients:
del self._clients[conn]
conn.unwrap()
conn.close()
except Exception as e_drop:
pass