Merge remote-tracking branch 'refs/remotes/origin/master' into clean-redo-in-steps

incoming
mcfreis 2017-03-28 08:04:57 +02:00
commit 129fd349df
17 changed files with 2311 additions and 2295 deletions

View File

@ -1,3 +1,20 @@
2017-03-28 Björn Freise <mcfreis@gmx.net>
Minor fixes and "hopefully" compatible to Ubuntu 16.04
* dtls/__init__.py: Removed wrapper import
* dtls/openssl.py: Fixed line endings to LF
* dtls/patch.py: Removed PROTOCOL_SSLv3 import and fixed line endings to LF
* dtls/sslconnection.py: Fixed line endings to LF
* dtls/test/certs/*_ec.pem: Fixed line endings to LF
* dtls/test/echo_seq.py: Fixed line endings to LF
* dtls/test/simple_client.py: Fixed line endings to LF
* dtls/test/unit.py: Fixed line endings to LF
* dtls/test/unit_wrapper.py: Corrected wrapper import and fixed line endings to LF
* dtls/util.py: Fixed line endings to LF
* dtls/wrapper.py: Corrected function naming to wrap_client() and wrap_server(); Fixed line endings to LF
* dtls/x509.py: Fixed line endings to LF
2017-03-23 Björn Freise <mcfreis@gmx.net> 2017-03-23 Björn Freise <mcfreis@gmx.net>
Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client() Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()

View File

@ -61,4 +61,3 @@ _prep_bins() # prepare before module imports
from patch import do_patch from patch import do_patch
from sslconnection import SSLContext, SSL, SSLConnection from sslconnection import SSLContext, SSL, SSLConnection
from demux import force_routing_demux, reset_default_demux from demux import force_routing_demux, reset_default_demux
from wrapper import DtlsSocket, client as wrap_client, server as wrap_server

File diff suppressed because it is too large Load Diff

View File

@ -36,7 +36,7 @@ has the following effects:
from socket import socket, getaddrinfo, _delegate_methods, error as socket_error from socket import socket, getaddrinfo, _delegate_methods, error as socket_error
from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM from socket import AF_INET, SOCK_STREAM, SOCK_DGRAM
from ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, CERT_NONE from ssl import PROTOCOL_SSLv23, CERT_NONE
from types import MethodType from types import MethodType
from weakref import proxy from weakref import proxy
import errno import errno

View File

@ -45,30 +45,30 @@ import socket
import hmac import hmac
import datetime import datetime
from logging import getLogger from logging import getLogger
from os import urandom from os import urandom
from select import select from select import select
from weakref import proxy from weakref import proxy
from err import openssl_error, InvalidSocketError from err import openssl_error, InvalidSocketError
from err import raise_ssl_error from err import raise_ssl_error
from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL
from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER from err import ERR_WRONG_VERSION_NUMBER, ERR_COOKIE_MISMATCH, ERR_NO_SHARED_CIPHER
from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE
from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT from err import ERR_READ_TIMEOUT, ERR_WRITE_TIMEOUT
from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR, ERR_NO_CERTS from err import ERR_BOTH_KEY_CERT_FILES, ERR_BOTH_KEY_CERT_FILES_SVR, ERR_NO_CERTS
from x509 import _X509, decode_cert from x509 import _X509, decode_cert
from tlock import tlock_init from tlock import tlock_init
from openssl import * from openssl import *
from util import _Rsrc, _BIO from util import _Rsrc, _BIO
_logger = getLogger(__name__) _logger = getLogger(__name__)
PROTOCOL_DTLSv1 = 256 PROTOCOL_DTLSv1 = 256
PROTOCOL_DTLSv1_2 = 258 PROTOCOL_DTLSv1_2 = 258
PROTOCOL_DTLS = 259 PROTOCOL_DTLS = 259
CERT_NONE = 0 CERT_NONE = 0
CERT_OPTIONAL = 1 CERT_OPTIONAL = 1
CERT_REQUIRED = 2 CERT_REQUIRED = 2
# #
# One-time global OpenSSL library initialization # One-time global OpenSSL library initialization
@ -83,64 +83,64 @@ DTLS_OPENSSL_VERSION_INFO = (
DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor
DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix
DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch
DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status
def _ssl_logging_cb(conn, where, return_code): def _ssl_logging_cb(conn, where, return_code):
_state = where & ~SSL_ST_MASK _state = where & ~SSL_ST_MASK
state = "SSL" state = "SSL"
if _state & SSL_ST_INIT == SSL_ST_INIT: if _state & SSL_ST_INIT == SSL_ST_INIT:
if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE: if _state & SSL_ST_RENEGOTIATE == SSL_ST_RENEGOTIATE:
state += "_renew" state += "_renew"
else: else:
state += "_init" state += "_init"
elif _state & SSL_ST_CONNECT: elif _state & SSL_ST_CONNECT:
state += "_connect" state += "_connect"
elif _state & SSL_ST_ACCEPT: elif _state & SSL_ST_ACCEPT:
state += "_accept" state += "_accept"
elif _state == 0: elif _state == 0:
if where & SSL_CB_HANDSHAKE_START: if where & SSL_CB_HANDSHAKE_START:
state += "_handshake_start" state += "_handshake_start"
elif where & SSL_CB_HANDSHAKE_DONE: elif where & SSL_CB_HANDSHAKE_DONE:
state += "_handshake_done" state += "_handshake_done"
if where & SSL_CB_LOOP: if where & SSL_CB_LOOP:
state += '_loop' state += '_loop'
_logger.debug("%s:%s:%d" % (state, _logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn), SSL_state_string_long(conn),
return_code)) return_code))
elif where & SSL_CB_ALERT: elif where & SSL_CB_ALERT:
state += '_alert' state += '_alert'
state += "_read" if where & SSL_CB_READ else "_write" state += "_read" if where & SSL_CB_READ else "_write"
_logger.debug("%s:%s:%s" % (state, _logger.debug("%s:%s:%s" % (state,
SSL_alert_type_string_long(return_code), SSL_alert_type_string_long(return_code),
SSL_alert_desc_string_long(return_code))) SSL_alert_desc_string_long(return_code)))
elif where & SSL_CB_EXIT: elif where & SSL_CB_EXIT:
state += '_exit' state += '_exit'
if return_code == 0: if return_code == 0:
_logger.debug("%s:%s:%d(failed)" % (state, _logger.debug("%s:%s:%d(failed)" % (state,
SSL_state_string_long(conn), SSL_state_string_long(conn),
return_code)) return_code))
elif return_code < 0: elif return_code < 0:
_logger.debug("%s:%s:%d(error)" % (state, _logger.debug("%s:%s:%d(error)" % (state,
SSL_state_string_long(conn), SSL_state_string_long(conn),
return_code)) return_code))
else: else:
_logger.debug("%s:%s:%d" % (state, _logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn), SSL_state_string_long(conn),
return_code)) return_code))
else: else:
_logger.debug("%s:%s:%d" % (state, _logger.debug("%s:%s:%d" % (state,
SSL_state_string_long(conn), SSL_state_string_long(conn),
return_code)) return_code))
class _CTX(_Rsrc): class _CTX(_Rsrc):
"""SSL_CTX wrapper""" """SSL_CTX wrapper"""
def __init__(self, value): def __init__(self, value):
super(_CTX, self).__init__(value) super(_CTX, self).__init__(value)
def __del__(self): def __del__(self):
@ -174,130 +174,130 @@ class _CallbackProxy(object):
self.ssl_func = cbm.im_func self.ssl_func = cbm.im_func
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ssl_func(self.ssl_connection, *args, **kwargs) return self.ssl_func(self.ssl_connection, *args, **kwargs)
class SSLContext(object): class SSLContext(object):
def __init__(self, ctx): def __init__(self, ctx):
self._ctx = ctx self._ctx = ctx
def set_ciphers(self, ciphers): def set_ciphers(self, ciphers):
u''' u'''
s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html s.a. https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
:param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ... :param str ciphers: Example "AES256-SHA:ECDHE-ECDSA-AES256-SHA", ...
:return: 1 for success and 0 for failure :return: 1 for success and 0 for failure
''' '''
retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers) retVal = SSL_CTX_set_cipher_list(self._ctx, ciphers)
return retVal return retVal
def set_sigalgs(self, sigalgs): def set_sigalgs(self, sigalgs):
u''' u'''
s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html s.a. https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set1_sigalgs_list.html
:param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ... :param str sigalgs: Example "RSA+SHA256", "ECDSA+SHA256", ...
:return: 1 for success and 0 for failure :return: 1 for success and 0 for failure
''' '''
retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs) retVal = SSL_CTX_set1_sigalgs_list(self._ctx, sigalgs)
return retVal return retVal
def set_curves(self, curves): def set_curves(self, curves):
u''' Set supported curves by name, nid or nist. u''' Set supported curves by name, nid or nist.
:param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ... :param str | tuple(int) curves: Example "secp384r1:secp256k1", (715, 714), "P-384", "K-409:B-409:K-571", ...
:return: 1 for success and 0 for failure :return: 1 for success and 0 for failure
''' '''
retVal = None retVal = None
if isinstance(curves, str): if isinstance(curves, str):
retVal = SSL_CTX_set1_curves_list(self._ctx, curves) retVal = SSL_CTX_set1_curves_list(self._ctx, curves)
elif isinstance(curves, tuple): elif isinstance(curves, tuple):
retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves)) retVal = SSL_CTX_set1_curves(self._ctx, curves, len(curves))
return retVal return retVal
@staticmethod @staticmethod
def get_ec_nist2nid(nist): def get_ec_nist2nid(nist):
if not isinstance(nist, tuple): if not isinstance(nist, tuple):
nist = nist.split(":") nist = nist.split(":")
nid = tuple(EC_curve_nist2nid(x) for x in nist) nid = tuple(EC_curve_nist2nid(x) for x in nist)
return nid return nid
@staticmethod @staticmethod
def get_ec_nid2nist(nid): def get_ec_nid2nist(nid):
if not isinstance(nid, tuple): if not isinstance(nid, tuple):
nid = (nid, ) nid = (nid, )
nist = ":".join([EC_curve_nid2nist(x) for x in nid]) nist = ":".join([EC_curve_nid2nist(x) for x in nid])
return nist return nist
@staticmethod @staticmethod
def get_ec_available(bAsName=True): def get_ec_available(bAsName=True):
curves = get_elliptic_curves() curves = get_elliptic_curves()
return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves]) return sorted([x.name for x in curves] if bAsName else [x.nid for x in curves])
def set_ecdh_curve(self, curve_name=None): def set_ecdh_curve(self, curve_name=None):
u''' Select a curve to use for ECDH(E) key exchange or set it to auto mode u''' Select a curve to use for ECDH(E) key exchange or set it to auto mode
Used for server only! Used for server only!
s.a. openssl.exe ecparam -list_curves s.a. openssl.exe ecparam -list_curves
:param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ... :param None | str curve_name: None = Auto-mode, "secp256k1", "secp384r1", ...
:return: 1 for success and 0 for failure :return: 1 for success and 0 for failure
''' '''
if curve_name: if curve_name:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0) retVal = SSL_CTX_set_ecdh_auto(self._ctx, 0)
avail_curves = get_elliptic_curves() avail_curves = get_elliptic_curves()
key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY() key = [curve for curve in avail_curves if curve.name == curve_name][0].to_EC_KEY()
retVal &= SSL_CTX_set_tmp_ecdh(self._ctx, key) retVal &= SSL_CTX_set_tmp_ecdh(self._ctx, key)
else: else:
retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1) retVal = SSL_CTX_set_ecdh_auto(self._ctx, 1)
return retVal return retVal
def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NONE): def build_cert_chain(self, flags=SSL_BUILD_CHAIN_FLAG_NONE):
u''' u'''
Used for server side only! Used for server side only!
:param flags: :param flags:
:return: 1 for success and 0 for failure :return: 1 for success and 0 for failure
''' '''
retVal = SSL_CTX_build_cert_chain(self._ctx, flags) retVal = SSL_CTX_build_cert_chain(self._ctx, flags)
return retVal return retVal
def set_ssl_logging(self, enable=False, func=_ssl_logging_cb): def set_ssl_logging(self, enable=False, func=_ssl_logging_cb):
u''' Enable or disable SSL logging u''' Enable or disable SSL logging
:param True | False enable: Enable or disable SSL logging :param True | False enable: Enable or disable SSL logging
:param func: Callback function for logging :param func: Callback function for logging
''' '''
if enable: if enable:
SSL_CTX_set_info_callback(self._ctx, func) SSL_CTX_set_info_callback(self._ctx, func)
else: else:
SSL_CTX_set_info_callback(self._ctx, 0) SSL_CTX_set_info_callback(self._ctx, 0)
class SSL(object): class SSL(object):
def __init__(self, ssl): def __init__(self, ssl):
self._ssl = ssl self._ssl = ssl
def set_mtu(self, mtu=None): def set_mtu(self, mtu=None):
if mtu: if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU) SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
SSL_set_mtu(self._ssl, mtu) SSL_set_mtu(self._ssl, mtu)
else: else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU) SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
def set_link_mtu(self, mtu=None): def set_link_mtu(self, mtu=None):
if mtu: if mtu:
SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU) SSL_set_options(self._ssl, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl, mtu) DTLS_set_link_mtu(self._ssl, mtu)
else: else:
SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU) SSL_clear_options(self._ssl, SSL_OP_NO_QUERY_MTU)
class SSLConnection(object): class SSLConnection(object):
"""DTLS peer association """DTLS peer association
This class associates two DTLS peer instances, wrapping OpenSSL library This class associates two DTLS peer instances, wrapping OpenSSL library
state including SSL (struct ssl_st), SSL_CTX, and BIO instances. state including SSL (struct ssl_st), SSL_CTX, and BIO instances.
""" """
@ -319,19 +319,19 @@ class SSLConnection(object):
rsock = self._udp_demux.get_connection(None) rsock = self._udp_demux.get_connection(None)
if rsock is self._sock: if rsock is self._sock:
self._rbio = self._wbio self._rbio = self._wbio
else: else:
self._rsock = rsock self._rsock = rsock
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
server_method = DTLS_server_method server_method = DTLS_server_method
if self._ssl_version == PROTOCOL_DTLSv1_2: if self._ssl_version == PROTOCOL_DTLSv1_2:
server_method = DTLSv1_2_server_method server_method = DTLSv1_2_server_method
elif self._ssl_version == PROTOCOL_DTLSv1: elif self._ssl_version == PROTOCOL_DTLSv1:
server_method = DTLSv1_server_method server_method = DTLSv1_server_method
self._ctx = _CTX(SSL_CTX_new(server_method())) self._ctx = _CTX(SSL_CTX_new(server_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value) self._intf_ssl_ctx = SSLContext(self._ctx.value)
SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
elif self._cert_reqs == CERT_OPTIONAL: elif self._cert_reqs == CERT_OPTIONAL:
verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE
else: else:
@ -345,37 +345,37 @@ class SSLConnection(object):
self._pending_peer_address = None self._pending_peer_address = None
self._cb_keepalive = SSL_CTX_set_cookie_cb( self._cb_keepalive = SSL_CTX_set_cookie_cb(
self._ctx.value, self._ctx.value,
_CallbackProxy(self._generate_cookie_cb), _CallbackProxy(self._generate_cookie_cb),
_CallbackProxy(self._verify_cookie_cb)) _CallbackProxy(self._verify_cookie_cb))
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value) SSL_set_accept_state(self._ssl.value)
if peer_address and self._do_handshake_on_connect: if peer_address and self._do_handshake_on_connect:
return lambda: self.do_handshake() return lambda: self.do_handshake()
def _init_client(self, peer_address): def _init_client(self, peer_address):
if self._sock.type != socket.SOCK_DGRAM: if self._sock.type != socket.SOCK_DGRAM:
raise InvalidSocketError("sock must be of type SOCK_DGRAM") raise InvalidSocketError("sock must be of type SOCK_DGRAM")
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible) client_method = DTLSv1_2_client_method # no "any" exists, therefore use v1_2 (highest possible)
if self._ssl_version == PROTOCOL_DTLSv1_2: if self._ssl_version == PROTOCOL_DTLSv1_2:
client_method = DTLSv1_2_client_method client_method = DTLSv1_2_client_method
elif self._ssl_version == PROTOCOL_DTLSv1: elif self._ssl_version == PROTOCOL_DTLSv1:
client_method = DTLSv1_client_method client_method = DTLSv1_client_method
self._ctx = _CTX(SSL_CTX_new(client_method())) self._ctx = _CTX(SSL_CTX_new(client_method()))
self._intf_ssl_ctx = SSLContext(self._ctx.value) self._intf_ssl_ctx = SSLContext(self._ctx.value)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
verify_mode = SSL_VERIFY_NONE verify_mode = SSL_VERIFY_NONE
else: else:
verify_mode = SSL_VERIFY_PEER verify_mode = SSL_VERIFY_PEER
self._config_ssl_ctx(verify_mode) self._config_ssl_ctx(verify_mode)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
SSL_set_connect_state(self._ssl.value) SSL_set_connect_state(self._ssl.value)
if peer_address: if peer_address:
return lambda: self.connect(peer_address) return lambda: self.connect(peer_address)
def _config_ssl_ctx(self, verify_mode): def _config_ssl_ctx(self, verify_mode):
SSL_CTX_set_verify(self._ctx.value, verify_mode) SSL_CTX_set_verify(self._ctx.value, verify_mode)
@ -392,14 +392,14 @@ class SSLConnection(object):
SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None) SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None)
if self._ciphers: if self._ciphers:
try: try:
SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers)
except openssl_error() as err: except openssl_error() as err:
raise_ssl_error(ERR_NO_CIPHER, err) raise_ssl_error(ERR_NO_CIPHER, err)
if self._user_config_ssl_ctx: if self._user_config_ssl_ctx:
self._user_config_ssl_ctx(self._intf_ssl_ctx) self._user_config_ssl_ctx(self._intf_ssl_ctx)
def _copy_server(self): def _copy_server(self):
source = self._sock source = self._sock
self._udp_demux = source._udp_demux self._udp_demux = source._udp_demux
rsock = self._udp_demux.get_connection(source._pending_peer_address) rsock = self._udp_demux.get_connection(source._pending_peer_address)
self._ctx = source._ctx self._ctx = source._ctx
@ -419,16 +419,16 @@ class SSLConnection(object):
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = self._wbio self._rbio = self._wbio
new_source_rbio = new_source_wbio new_source_rbio = new_source_wbio
BIO_dgram_set_connected(self._wbio.value, BIO_dgram_set_connected(self._wbio.value,
source._pending_peer_address) source._pending_peer_address)
source._ssl = _SSL(SSL_new(self._ctx.value)) source._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(source._ssl.value) self._intf_ssl = SSL(source._ssl.value)
SSL_set_accept_state(source._ssl.value) SSL_set_accept_state(source._ssl.value)
if self._user_config_ssl: if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl) self._user_config_ssl(self._intf_ssl)
source._rbio = new_source_rbio source._rbio = new_source_rbio
source._wbio = new_source_wbio source._wbio = new_source_wbio
SSL_set_bio(source._ssl.value, SSL_set_bio(source._ssl.value,
new_source_rbio.value, new_source_rbio.value,
new_source_wbio.value) new_source_wbio.value)
new_source_rbio.disown() new_source_rbio.disown()
@ -441,16 +441,16 @@ class SSLConnection(object):
self._rsock = source._rsock self._rsock = source._rsock
self._ctx = source._ctx self._ctx = source._ctx
self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE))
self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE))
BIO_dgram_set_peer(self._wbio.value, source._peer_address) BIO_dgram_set_peer(self._wbio.value, source._peer_address)
self._ssl = _SSL(SSL_new(self._ctx.value)) self._ssl = _SSL(SSL_new(self._ctx.value))
self._intf_ssl = SSL(self._ssl.value) self._intf_ssl = SSL(self._ssl.value)
SSL_set_accept_state(self._ssl.value) SSL_set_accept_state(self._ssl.value)
if self._user_config_ssl: if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl) self._user_config_ssl(self._intf_ssl)
if self._do_handshake_on_connect: if self._do_handshake_on_connect:
return lambda: self.do_handshake() return lambda: self.do_handshake()
def _check_nbio(self): def _check_nbio(self):
timeout = self._sock.gettimeout() timeout = self._sock.gettimeout()
if self._wbio_nb != timeout is not None: if self._wbio_nb != timeout is not None:
@ -502,17 +502,17 @@ class SSLConnection(object):
def _verify_cookie_cb(self, ssl, cookie): def _verify_cookie_cb(self, ssl, cookie):
if self._get_cookie(ssl) != cookie: if self._get_cookie(ssl) != cookie:
raise Exception("DTLS cookie mismatch") raise Exception("DTLS cookie mismatch")
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_DTLS, ca_certs=None, ssl_version=PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None, suppress_ragged_eofs=True, ciphers=None,
cb_user_config_ssl_ctx=None, cb_user_config_ssl_ctx=None,
cb_user_config_ssl=None): cb_user_config_ssl=None):
"""Constructor """Constructor
Arguments: Arguments:
these arguments match the ones of the SSLSocket class in the these arguments match the ones of the SSLSocket class in the
standard library's ssl module standard library's ssl module
""" """
@ -528,46 +528,46 @@ class SSLConnection(object):
ciphers = "DEFAULT" ciphers = "DEFAULT"
self._sock = sock self._sock = sock
self._keyfile = keyfile self._keyfile = keyfile
self._certfile = certfile self._certfile = certfile
self._cert_reqs = cert_reqs self._cert_reqs = cert_reqs
self._ssl_version = ssl_version self._ssl_version = ssl_version
self._ca_certs = ca_certs self._ca_certs = ca_certs
self._do_handshake_on_connect = do_handshake_on_connect self._do_handshake_on_connect = do_handshake_on_connect
self._suppress_ragged_eofs = suppress_ragged_eofs self._suppress_ragged_eofs = suppress_ragged_eofs
self._ciphers = ciphers self._ciphers = ciphers
self._handshake_done = False self._handshake_done = False
self._wbio_nb = self._rbio_nb = False self._wbio_nb = self._rbio_nb = False
self._user_config_ssl_ctx = cb_user_config_ssl_ctx self._user_config_ssl_ctx = cb_user_config_ssl_ctx
self._intf_ssl_ctx = None self._intf_ssl_ctx = None
self._user_config_ssl = cb_user_config_ssl self._user_config_ssl = cb_user_config_ssl
self._intf_ssl = None self._intf_ssl = None
if isinstance(sock, SSLConnection): if isinstance(sock, SSLConnection):
post_init = self._copy_server() post_init = self._copy_server()
elif isinstance(sock, _UnwrappedSocket): elif isinstance(sock, _UnwrappedSocket):
post_init = self._reconnect_unwrapped() post_init = self._reconnect_unwrapped()
else: else:
try: try:
peer_address = sock.getpeername() peer_address = sock.getpeername()
except socket.error: except socket.error:
peer_address = None peer_address = None
if server_side: if server_side:
post_init = self._init_server(peer_address) post_init = self._init_server(peer_address)
else: else:
post_init = self._init_client(peer_address) post_init = self._init_client(peer_address)
if self._user_config_ssl: if self._user_config_ssl:
self._user_config_ssl(self._intf_ssl) self._user_config_ssl(self._intf_ssl)
else: else:
SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU) SSL_set_options(self._ssl.value, SSL_OP_NO_QUERY_MTU)
DTLS_set_link_mtu(self._ssl.value, 1500) DTLS_set_link_mtu(self._ssl.value, 1500)
SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value)
self._rbio.disown() self._rbio.disown()
self._wbio.disown() self._wbio.disown()
if post_init: if post_init:
post_init() post_init()
def get_socket(self, inbound): def get_socket(self, inbound):
"""Retrieve a socket used by this connection """Retrieve a socket used by this connection
@ -633,25 +633,25 @@ class SSLConnection(object):
self._ssl.raw) self._ssl.raw)
dtls_peer_address = DTLSv1_listen(self._ssl.value) dtls_peer_address = DTLSv1_listen(self._ssl.value)
except openssl_error() as err: except openssl_error() as err:
if err.ssl_error == SSL_ERROR_WANT_READ: if err.ssl_error == SSL_ERROR_WANT_READ:
# This method must be called again to forward the next datagram # This method must be called again to forward the next datagram
_logger.debug("DTLSv1_listen must be resumed") _logger.debug("DTLSv1_listen must be resumed")
return return
elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER: elif err.errqueue and err.errqueue[0][0] == ERR_WRONG_VERSION_NUMBER:
_logger.debug("Wrong version number; aborting handshake") _logger.debug("Wrong version number; aborting handshake")
raise raise
elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH: elif err.errqueue and err.errqueue[0][0] == ERR_COOKIE_MISMATCH:
_logger.debug("Mismatching cookie received; aborting handshake") _logger.debug("Mismatching cookie received; aborting handshake")
raise raise
elif err.errqueue and err.errqueue[0][0] == ERR_NO_SHARED_CIPHER: elif err.errqueue and err.errqueue[0][0] == ERR_NO_SHARED_CIPHER:
_logger.debug("No shared cipher; aborting handshake") _logger.debug("No shared cipher; aborting handshake")
raise raise
_logger.exception("Unexpected error in DTLSv1_listen") _logger.exception("Unexpected error in DTLSv1_listen")
raise raise
finally: finally:
self._listening = False self._listening = False
self._listening_peer_address = None self._listening_peer_address = None
if type(peer_address) is tuple: if type(peer_address) is tuple:
_logger.debug("New local peer: %s", dtls_peer_address) _logger.debug("New local peer: %s", dtls_peer_address)
self._pending_peer_address = peer_address self._pending_peer_address = peer_address
else: else:
@ -672,17 +672,17 @@ class SSLConnection(object):
if not self._pending_peer_address: if not self._pending_peer_address:
if not self.listen(): if not self.listen():
_logger.debug("Accept returning without connection") _logger.debug("Accept returning without connection")
return return
new_conn = SSLConnection(self, self._keyfile, self._certfile, True, new_conn = SSLConnection(self, self._keyfile, self._certfile, True,
self._cert_reqs, self._ssl_version, self._cert_reqs, self._ssl_version,
self._ca_certs, self._do_handshake_on_connect, self._ca_certs, self._do_handshake_on_connect,
self._suppress_ragged_eofs, self._ciphers, self._suppress_ragged_eofs, self._ciphers,
cb_user_config_ssl_ctx=self._user_config_ssl_ctx, cb_user_config_ssl_ctx=self._user_config_ssl_ctx,
cb_user_config_ssl=self._user_config_ssl) cb_user_config_ssl=self._user_config_ssl)
new_peer = self._pending_peer_address new_peer = self._pending_peer_address
self._pending_peer_address = None self._pending_peer_address = None
if self._do_handshake_on_connect: if self._do_handshake_on_connect:
# Note that since that connection's socket was just created in its # Note that since that connection's socket was just created in its
# constructor, the following operation must be blocking; hence # constructor, the following operation must be blocking; hence
# handshake-on-connect can only be used with a routing demux if # handshake-on-connect can only be used with a routing demux if
@ -734,48 +734,48 @@ class SSLConnection(object):
Read up to len bytes and return them. Read up to len bytes and return them.
Arguments: Arguments:
len -- maximum number of bytes to read len -- maximum number of bytes to read
Return value: Return value:
string containing read bytes string containing read bytes
""" """
try: try:
return self._wrap_socket_library_call( return self._wrap_socket_library_call(
lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT) lambda: SSL_read(self._ssl.value, len, buffer), ERR_READ_TIMEOUT)
except openssl_error() as err: except openssl_error() as err:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err) raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise raise
def write(self, data): def write(self, data):
"""Write data to connection """Write data to connection
Write data as string of bytes. Write data as string of bytes.
Arguments: Arguments:
data -- buffer containing data to be written data -- buffer containing data to be written
Return value: Return value:
number of bytes actually transmitted number of bytes actually transmitted
""" """
try: try:
ret = self._wrap_socket_library_call( ret = self._wrap_socket_library_call(
lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT) lambda: SSL_write(self._ssl.value, data), ERR_WRITE_TIMEOUT)
except openssl_error() as err: except openssl_error() as err:
if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: if err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1:
raise_ssl_error(ERR_PORT_UNREACHABLE, err) raise_ssl_error(ERR_PORT_UNREACHABLE, err)
raise raise
if ret: if ret:
self._handshake_done = True self._handshake_done = True
return ret return ret
def shutdown(self): def shutdown(self):
"""Shut down the DTLS connection """Shut down the DTLS connection
This method attemps to complete a bidirectional shutdown between This method attemps to complete a bidirectional shutdown between
peers. For non-blocking sockets, it should be called repeatedly until peers. For non-blocking sockets, it should be called repeatedly until
it no longer raises continuation request exceptions. it no longer raises continuation request exceptions.
""" """
@ -826,33 +826,33 @@ class SSLConnection(object):
return return
if binary_form: if binary_form:
return i2d_X509(peer_cert.value) return i2d_X509(peer_cert.value)
if self._cert_reqs == CERT_NONE: if self._cert_reqs == CERT_NONE:
return {} return {}
return decode_cert(peer_cert) return decode_cert(peer_cert)
peer_certificate = getpeercert # compatibility with _ssl call interface peer_certificate = getpeercert # compatibility with _ssl call interface
def getpeercertchain(self, binary_form=False): def getpeercertchain(self, binary_form=False):
try: try:
stack, num, certs = SSL_get_peer_cert_chain(self._ssl.value) stack, num, certs = SSL_get_peer_cert_chain(self._ssl.value)
except openssl_error(): except openssl_error():
return return
peer_cert_chain = [_Rsrc(cert) for cert in certs] peer_cert_chain = [_Rsrc(cert) for cert in certs]
ret = [] ret = []
if binary_form: if binary_form:
ret = [i2d_X509(x.value) for x in peer_cert_chain] ret = [i2d_X509(x.value) for x in peer_cert_chain]
elif len(peer_cert_chain): elif len(peer_cert_chain):
ret = [decode_cert(x) for x in peer_cert_chain] ret = [decode_cert(x) for x in peer_cert_chain]
return ret return ret
def cipher(self): def cipher(self):
"""Retrieve information about the current cipher """Retrieve information about the current cipher
Return a triple consisting of cipher name, SSL protocol version defining Return a triple consisting of cipher name, SSL protocol version defining
its use, and the number of secret bits. Return None if handshaking its use, and the number of secret bits. Return None if handshaking
has not been completed. has not been completed.
""" """

View File

@ -1,11 +1,11 @@
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIIBgzCCASoCCQDdMwvUA/R3lzAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET MIIBgzCCASoCCQDdMwvUA/R3lzAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU3WhcNMjcwMzA1MDgzNjU3WjBKMQsw AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU3WhcNMjcwMzA1MDgzNjU3WjBKMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENB CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENB
IEluYzERMA8GA1UEAwwIUmF5Q0FJbmMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC IEluYzERMA8GA1UEAwwIUmF5Q0FJbmMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AASD4xiQkPryjEwUl/GYeGu1CSA3UC6BUY3TiGED3zrC5Bn/POaVVn9GGOQMZUFi AASD4xiQkPryjEwUl/GYeGu1CSA3UC6BUY3TiGED3zrC5Bn/POaVVn9GGOQMZUFi
rCkuTgfg/qeIzTrTFndiR5C/MAoGCCqGSM49BAMDA0cAMEQCIHpd9qMvZZV6iaB5 rCkuTgfg/qeIzTrTFndiR5C/MAoGCCqGSM49BAMDA0cAMEQCIHpd9qMvZZV6iaB5
HrmlyfmhIuLBxDQra20Uxl2Y8N64AiAmPKqwPPp7z6IT2AzAXyHCPoVxwWA0NfGx HrmlyfmhIuLBxDQra20Uxl2Y8N64AiAmPKqwPPp7z6IT2AzAXyHCPoVxwWA0NfGx
nmXoYpDFlw== nmXoYpDFlw==
-----END CERTIFICATE----- -----END CERTIFICATE-----

View File

@ -1,19 +1,19 @@
-----BEGIN EC PARAMETERS----- -----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw== BggqhkjOPQMBBw==
-----END EC PARAMETERS----- -----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY----- -----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEMWCku4TqKwrQdeECm5LQPCBnr7+cqE4InlRYeObLOxoAoGCCqGSM49 MHcCAQEEIEMWCku4TqKwrQdeECm5LQPCBnr7+cqE4InlRYeObLOxoAoGCCqGSM49
AwEHoUQDQgAEgroFe2fym1V7E3zr/zjuJixpyAjwfig+UTsxxm/04IvXzk2jQCQC AwEHoUQDQgAEgroFe2fym1V7E3zr/zjuJixpyAjwfig+UTsxxm/04IvXzk2jQCQC
TgbDVohJ8dgh4iEENZv2axWye7XCBzbftQ== TgbDVohJ8dgh4iEENZv2axWye7XCBzbftQ==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ== sW3uB0zEuJyqIQ==
-----END CERTIFICATE----- -----END CERTIFICATE-----

View File

@ -1,11 +1,11 @@
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET MIIBhjCCASwCCQCZ3L2TA/e93zAKBggqhkjOPQQDAzBKMQswCQYDVQQGEwJVUzET
MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE MBEGA1UECAwKV2FzaGluZ3RvbjETMBEGA1UECgwKUmF5IENBIEluYzERMA8GA1UE
AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw AwwIUmF5Q0FJbmMwHhcNMTcwMzA3MDgzNjU4WhcNMjcwMzA1MDgzNjU4WjBMMQsw
CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy CQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjEUMBIGA1UECgwLUmF5IFNy
diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH diBJbmMxEjAQBgNVBAMMCVJheVNydkluYzBZMBMGByqGSM49AgEGCCqGSM49AwEH
A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI A0IABIK6BXtn8ptVexN86/847iYsacgI8H4oPlE7McZv9OCL185No0AkAk4Gw1aI
SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz SfHYIeIhBDWb9msVsnu1wgc237UwCgYIKoZIzj0EAwMDSAAwRQIhAK4caAt0QSTz
A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg A1WYlrEAA2AH181P7USiXkqQ5qRyoWQNAiBm3vKaoB+0p4B98HeI+h5V/7loomQg
sW3uB0zEuJyqIQ== sW3uB0zEuJyqIQ==
-----END CERTIFICATE----- -----END CERTIFICATE-----

View File

@ -36,13 +36,13 @@ import socket
from os import path from os import path
from logging import basicConfig, DEBUG from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code basicConfig(level=DEBUG) # set now for dtls import code
from dtls.sslconnection import SSLConnection from dtls.sslconnection import SSLConnection
from dtls.err import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_ZERO_RETURN from dtls.err import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_ZERO_RETURN
def main(): def main():
sck = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sck = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sck.bind(("127.0.0.1", 28000)) sck.bind(("127.0.0.1", 28000))
sck.settimeout(30) sck.settimeout(30)
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs") cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
scn = SSLConnection( scn = SSLConnection(

View File

@ -1,7 +1,7 @@
HOME = . HOME = .
RANDFILE = $ENV::HOME/.rnd RANDFILE = $ENV::HOME/.rnd
[ req ] [ req ]
distinguished_name = req_distinguished_name distinguished_name = req_distinguished_name
prompt = no prompt = no

View File

@ -1,7 +1,7 @@
HOME = . HOME = .
RANDFILE = $ENV::HOME/.rnd RANDFILE = $ENV::HOME/.rnd
[ req ] [ req ]
distinguished_name = req_distinguished_name distinguished_name = req_distinguished_name
prompt = no prompt = no

View File

@ -1,15 +1,15 @@
from os import path from os import path
import ssl import ssl
from socket import socket, AF_INET, SOCK_DGRAM, SHUT_RDWR from socket import socket, AF_INET, SOCK_DGRAM, SHUT_RDWR
from logging import basicConfig, DEBUG from logging import basicConfig, DEBUG
basicConfig(level=DEBUG) # set now for dtls import code basicConfig(level=DEBUG) # set now for dtls import code
from dtls import do_patch from dtls import do_patch
do_patch() do_patch()
cert_path = path.join(path.abspath(path.dirname(__file__)), "certs") cert_path = path.join(path.abspath(path.dirname(__file__)), "certs")
sock = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), cert_reqs=ssl.CERT_REQUIRED, ca_certs=path.join(cert_path, "ca-cert.pem")) sock = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), cert_reqs=ssl.CERT_REQUIRED, ca_certs=path.join(cert_path, "ca-cert.pem"))
sock.connect(('localhost', 28000)) sock.connect(('localhost', 28000))
sock.send('Hi there') sock.send('Hi there')
print sock.recv() print sock.recv()
sock.unwrap() sock.unwrap()
sock.shutdown(SHUT_RDWR) sock.shutdown(SHUT_RDWR)

View File

@ -78,12 +78,12 @@ class BasicSocketTests(unittest.TestCase):
def test_constants(self): def test_constants(self):
ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_TLSv1 ssl.PROTOCOL_TLSv1
ssl.PROTOCOL_DTLSv1 # added ssl.PROTOCOL_DTLSv1 # added
ssl.PROTOCOL_DTLSv1_2 # added ssl.PROTOCOL_DTLSv1_2 # added
ssl.PROTOCOL_DTLS # added ssl.PROTOCOL_DTLS # added
ssl.CERT_NONE ssl.CERT_NONE
ssl.CERT_OPTIONAL ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED ssl.CERT_REQUIRED
def test_dtls_openssl_version(self): def test_dtls_openssl_version(self):
@ -91,22 +91,22 @@ class BasicSocketTests(unittest.TestCase):
t = ssl.DTLS_OPENSSL_VERSION_INFO t = ssl.DTLS_OPENSSL_VERSION_INFO
s = ssl.DTLS_OPENSSL_VERSION s = ssl.DTLS_OPENSSL_VERSION
self.assertIsInstance(n, (int, long)) self.assertIsInstance(n, (int, long))
self.assertIsInstance(t, tuple) self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str) self.assertIsInstance(s, str)
# Some sanity checks follow # Some sanity checks follow
# >= 1.0.2 # >= 1.0.2
self.assertGreaterEqual(n, 0x10002000) self.assertGreaterEqual(n, 0x10002000)
# < 2.0 # < 2.0
self.assertLess(n, 0x20000000) self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t major, minor, fix, patch, status = t
self.assertGreaterEqual(major, 1) self.assertGreaterEqual(major, 1)
self.assertLess(major, 2) self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0) self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256) self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 2) self.assertGreaterEqual(fix, 2)
self.assertLess(fix, 256) self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0) self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26) self.assertLessEqual(patch, 26)
self.assertGreaterEqual(status, 0) self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15) self.assertLessEqual(status, 15)
# Version string as returned by OpenSSL, the format might change # Version string as returned by OpenSSL, the format might change
@ -299,37 +299,37 @@ class NetworkedTests(unittest.TestCase):
s.close() s.close()
if test_support.verbose: if test_support.verbose:
sys.stdout.write(("\nNeeded %d calls to do_handshake() " + sys.stdout.write(("\nNeeded %d calls to do_handshake() " +
"to establish session.\n") % count) "to establish session.\n") % count)
def test_get_server_certificate(self): def test_get_server_certificate(self):
for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS): for prot in (ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLS):
with test_support.transient_internet() as remote: with test_support.transient_internet() as remote:
pem = ssl.get_server_certificate(remote, pem = ssl.get_server_certificate(remote,
prot) prot)
if not pem: if not pem:
self.fail("No server certificate!") self.fail("No server certificate!")
try: try:
pem = ssl.get_server_certificate(remote, pem = ssl.get_server_certificate(remote,
prot, prot,
ca_certs=OTHER_CERTFILE) ca_certs=OTHER_CERTFILE)
except ssl.SSLError: except ssl.SSLError:
# should fail # should fail
pass pass
else: else:
self.fail("Got server certificate %s!" % pem) self.fail("Got server certificate %s!" % pem)
pem = ssl.get_server_certificate(remote, pem = ssl.get_server_certificate(remote,
prot, prot,
ca_certs=ISSUER_CERTFILE) ca_certs=ISSUER_CERTFILE)
if not pem: if not pem:
self.fail("No server certificate!") self.fail("No server certificate!")
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\nVerified certificate is\n%s\n" % pem) sys.stdout.write("\nVerified certificate is\n%s\n" % pem)
class ThreadedEchoServer(threading.Thread): class ThreadedEchoServer(threading.Thread):
class ConnectionHandler(threading.Thread): class ConnectionHandler(threading.Thread):
"""A mildly complicated class, because we want it to work both """A mildly complicated class, because we want it to work both
with and without the SSL wrapper around the socket connection, so with and without the SSL wrapper around the socket connection, so
@ -532,20 +532,20 @@ class ThreadedEchoServer(threading.Thread):
if acc_ret: if acc_ret:
newconn, connaddr = acc_ret newconn, connaddr = acc_ret
if test_support.verbose and self.chatty: if test_support.verbose and self.chatty:
sys.stdout.write(' server: new connection from ' sys.stdout.write(' server: new connection from '
+ str(connaddr) + '\n') + str(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn) handler = self.ConnectionHandler(self, newconn)
handler.start() handler.start()
except socket.timeout: except socket.timeout:
pass pass
except ssl.SSLError: except ssl.SSLError:
pass pass
except KeyboardInterrupt: except KeyboardInterrupt:
self.stop() self.stop()
self.sock.close() self.sock.close()
def register_handler(self, add): def register_handler(self, add):
with self.num_handlers_lock: with self.num_handlers_lock:
if add: if add:
self.num_handlers += 1 self.num_handlers += 1
else: else:
@ -1042,40 +1042,40 @@ class ThreadedTests(unittest.TestCase):
"certs", "badkey.pem")) "certs", "badkey.pem"))
def test_protocol_dtlsv1(self): def test_protocol_dtlsv1(self):
"""Connecting to a DTLSv1 server with various client options""" """Connecting to a DTLSv1 server with various client options"""
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
# server: 1.0 - client: 1.0 -> ok # server: 1.0 - client: 1.0 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_OPTIONAL) ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
# server: any - client: 1.0 and 1.2(any) -> ok # server: any - client: 1.0 and 1.2(any) -> ok
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True) try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True, try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1, True,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True) try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True, try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True) try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True)
try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True, try_protocol_combo(ssl.PROTOCOL_DTLS, ssl.PROTOCOL_DTLS, True,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
# server: 1.0 - client: 1.2 -> fail # server: 1.0 - client: 1.2 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False) try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False, try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1_2, False,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.0 -> fail # server: 1.2 - client: 1.0 -> fail
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False) try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False, try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1, False,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
# server: 1.2 - client: 1.2 -> ok # server: 1.2 - client: 1.2 -> ok
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True) try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True)
try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True, try_protocol_combo(ssl.PROTOCOL_DTLSv1_2, ssl.PROTOCOL_DTLSv1_2, True,
ssl.CERT_REQUIRED) ssl.CERT_REQUIRED)
def test_starttls(self): def test_starttls(self):
"""Switching from clear text to encrypted and back again.""" """Switching from clear text to encrypted and back again."""
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS",
"msg 5", "msg 6") "msg 5", "msg 6")
@ -1088,13 +1088,13 @@ class ThreadedTests(unittest.TestCase):
server.start(flag) server.start(flag)
# wait for it to start # wait for it to start
flag.wait() flag.wait()
# try to connect # try to connect
wrapped = False wrapped = False
try: try:
s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1) s = ssl.wrap_socket(socket.socket(AF_INET4_6, socket.SOCK_DGRAM), ssl_version=ssl.PROTOCOL_DTLSv1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
s = s.unwrap() s = s.unwrap()
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
for indata in msgs: for indata in msgs:
if test_support.verbose: if test_support.verbose:

File diff suppressed because it is too large Load Diff

View File

@ -54,18 +54,18 @@ class _BIO(_Rsrc):
if self.owned: if self.owned:
_logger.debug("Freeing BIO: %d", self.raw) _logger.debug("Freeing BIO: %d", self.raw)
from openssl import BIO_free from openssl import BIO_free
BIO_free(self._value) BIO_free(self._value)
self.owned = False self.owned = False
self._value = None self._value = None
class _EC_KEY(_Rsrc): class _EC_KEY(_Rsrc):
"""EC KEY wrapper""" """EC KEY wrapper"""
def __init__(self, value): def __init__(self, value):
super(_EC_KEY, self).__init__(value) super(_EC_KEY, self).__init__(value)
def __del__(self): def __del__(self):
_logger.debug("Freeing EC_KEY: %d", self.raw) _logger.debug("Freeing EC_KEY: %d", self.raw)
from openssl import EC_KEY_free from openssl import EC_KEY_free
EC_KEY_free(self._value) EC_KEY_free(self._value)
self._value = None self._value = None

View File

@ -1,370 +1,370 @@
# -*- encoding: utf-8 -*- # -*- coding: utf-8 -*-
# DTLS Socket: A wrapper for a server and client using a DTLS connection. # DTLS Socket: A wrapper for a server and client using a DTLS connection.
# Copyright 2017 Björn Freise # Copyright 2017 Björn Freise
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# The License is also distributed with this work in the file named "LICENSE." # The License is also distributed with this work in the file named "LICENSE."
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""DTLS Socket """DTLS Socket
This wrapper encapsulates the state and behavior associated with the connection This wrapper encapsulates the state and behavior associated with the connection
between the OpenSSL library and an individual peer when using the DTLS between the OpenSSL library and an individual peer when using the DTLS
protocol. protocol.
Classes: Classes:
DtlsSocket -- DTLS Socket wrapper for use as a client or server DtlsSocket -- DTLS Socket wrapper for use as a client or server
""" """
import select import select
from logging import getLogger from logging import getLogger
import ssl import ssl
import socket import socket
from patch import do_patch from patch import do_patch
do_patch() do_patch()
from sslconnection import SSLContext, SSL from sslconnection import SSLContext, SSL
import err as err_codes import err as err_codes
_logger = getLogger(__name__) _logger = getLogger(__name__)
def client(sock, keyfile=None, certfile=None, def wrap_client(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None, cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
do_handshake_on_connect=True, suppress_ragged_eofs=True, do_handshake_on_connect=True, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None): ciphers=None, curves=None, sigalgs=None, user_mtu=None):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False, return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs, do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu, ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE) server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
def server(sock, keyfile=None, certfile=None, def wrap_server(sock, keyfile=None, certfile=None,
cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None, cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
do_handshake_on_connect=False, suppress_ragged_eofs=True, do_handshake_on_connect=False, suppress_ragged_eofs=True,
ciphers=None, curves=None, sigalgs=None, user_mtu=None, ciphers=None, curves=None, sigalgs=None, user_mtu=None,
server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True, return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs, do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu, ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options) server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
class DtlsSocket(object): class DtlsSocket(object):
class _ClientSession(object): class _ClientSession(object):
def __init__(self, host, port, handshake_done=False): def __init__(self, host, port, handshake_done=False):
self.host = host self.host = host
self.port = int(port) self.port = int(port)
self.handshake_done = handshake_done self.handshake_done = handshake_done
def getAddr(self): def getAddr(self):
return self.host, self.port return self.host, self.port
def __init__(self, def __init__(self,
sock=None, sock=None,
keyfile=None, keyfile=None,
certfile=None, certfile=None,
server_side=False, server_side=False,
cert_reqs=ssl.CERT_NONE, cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_DTLSv1_2, ssl_version=ssl.PROTOCOL_DTLSv1_2,
ca_certs=None, ca_certs=None,
do_handshake_on_connect=False, do_handshake_on_connect=False,
suppress_ragged_eofs=True, suppress_ragged_eofs=True,
ciphers=None, ciphers=None,
curves=None, curves=None,
sigalgs=None, sigalgs=None,
user_mtu=None, user_mtu=None,
server_key_exchange_curve=None, server_key_exchange_curve=None,
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
if server_cert_options is None: if server_cert_options is None:
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
self._ssl_logging = False self._ssl_logging = False
self._server_side = server_side self._server_side = server_side
self._ciphers = ciphers self._ciphers = ciphers
self._curves = curves self._curves = curves
self._sigalgs = sigalgs self._sigalgs = sigalgs
self._user_mtu = user_mtu self._user_mtu = user_mtu
self._server_key_exchange_curve = server_key_exchange_curve self._server_key_exchange_curve = server_key_exchange_curve
self._server_cert_options = server_cert_options self._server_cert_options = server_cert_options
# Default socket creation # Default socket creation
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if isinstance(sock, socket.socket): if isinstance(sock, socket.socket):
_sock = sock _sock = sock
self._sock = ssl.wrap_socket(_sock, self._sock = ssl.wrap_socket(_sock,
keyfile=keyfile, keyfile=keyfile,
certfile=certfile, certfile=certfile,
server_side=self._server_side, server_side=self._server_side,
cert_reqs=cert_reqs, cert_reqs=cert_reqs,
ssl_version=ssl_version, ssl_version=ssl_version,
ca_certs=ca_certs, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect, do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs, suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=self._ciphers, ciphers=self._ciphers,
cb_user_config_ssl_ctx=self.user_config_ssl_ctx, cb_user_config_ssl_ctx=self.user_config_ssl_ctx,
cb_user_config_ssl=self.user_config_ssl) cb_user_config_ssl=self.user_config_ssl)
if self._server_side: if self._server_side:
self._clients = {} self._clients = {}
self._timeout = None self._timeout = None
def __getattr__(self, item): def __getattr__(self, item):
if hasattr(self, "_sock") and hasattr(self._sock, item): if hasattr(self, "_sock") and hasattr(self._sock, item):
return getattr(self._sock, item) return getattr(self._sock, item)
raise AttributeError raise AttributeError
def user_config_ssl_ctx(self, _ctx): def user_config_ssl_ctx(self, _ctx):
""" """
:param SSLContext _ctx: :param SSLContext _ctx:
""" """
_ctx.set_ssl_logging(self._ssl_logging) _ctx.set_ssl_logging(self._ssl_logging)
if self._ciphers: if self._ciphers:
_ctx.set_ciphers(self._ciphers) _ctx.set_ciphers(self._ciphers)
if self._curves: if self._curves:
_ctx.set_curves(self._curves) _ctx.set_curves(self._curves)
if self._sigalgs: if self._sigalgs:
_ctx.set_sigalgs(self._sigalgs) _ctx.set_sigalgs(self._sigalgs)
if self._server_side: if self._server_side:
_ctx.build_cert_chain(flags=self._server_cert_options) _ctx.build_cert_chain(flags=self._server_cert_options)
_ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve) _ctx.set_ecdh_curve(curve_name=self._server_key_exchange_curve)
def user_config_ssl(self, _ssl): def user_config_ssl(self, _ssl):
""" """
:param SSL _ssl: :param SSL _ssl:
""" """
if self._user_mtu: if self._user_mtu:
_ssl.set_link_mtu(self._user_mtu) _ssl.set_link_mtu(self._user_mtu)
def settimeout(self, t): def settimeout(self, t):
if self._server_side: if self._server_side:
self._timeout = t self._timeout = t
else: else:
self._sock.settimeout(t) self._sock.settimeout(t)
def close(self): def close(self):
if self._server_side: if self._server_side:
for cli in self._clients.keys(): for cli in self._clients.keys():
cli.close() cli.close()
else: else:
try: try:
self._sock.unwrap() self._sock.unwrap()
except: except:
pass pass
self._sock.close() self._sock.close()
def recvfrom(self, bufsize, flags=0): def recvfrom(self, bufsize, flags=0):
if self._server_side: if self._server_side:
return self._recvfrom_on_server_side(bufsize, flags=flags) return self._recvfrom_on_server_side(bufsize, flags=flags)
else: else:
return self._recvfrom_on_client_side(bufsize, flags=flags) return self._recvfrom_on_client_side(bufsize, flags=flags)
def _recvfrom_on_server_side(self, bufsize, flags): def _recvfrom_on_server_side(self, bufsize, flags):
try: try:
r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout) r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
except socket.timeout: except socket.timeout:
# __Nothing__ received from any client # __Nothing__ received from any client
raise socket.timeout raise socket.timeout
try: try:
for conn in r: for conn in r:
_last_peer = conn.getpeername() if conn._connected else None _last_peer = conn.getpeername() if conn._connected else None
if self._sockIsServerSock(conn): if self._sockIsServerSock(conn):
# Connect # Connect
self._clientAccept(conn) self._clientAccept(conn)
else: else:
# Handshake # Handshake
if not self._clientHandshakeDone(conn): if not self._clientHandshakeDone(conn):
self._clientDoHandshake(conn) self._clientDoHandshake(conn)
# Normal read # Normal read
else: else:
buf = self._clientRead(conn, bufsize) buf = self._clientRead(conn, bufsize)
if buf: if buf:
if conn in self._clients: if conn in self._clients:
return buf, self._clients[conn].getAddr() return buf, self._clients[conn].getAddr()
else: else:
_logger.debug('Received data from an already disconnected client!') _logger.debug('Received data from an already disconnected client!')
except Exception as e: except Exception as e:
setattr(e, 'peer', _last_peer) setattr(e, 'peer', _last_peer)
raise e raise e
try: try:
for conn in self._getClientReadingSockets(): for conn in self._getClientReadingSockets():
if conn.get_timeout(): if conn.get_timeout():
ret = conn.handle_timeout() ret = conn.handle_timeout()
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret)) _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
except Exception as e: except Exception as e:
raise e raise e
# __No_data__ received from any client # __No_data__ received from any client
raise socket.timeout raise socket.timeout
def _recvfrom_on_client_side(self, bufsize, flags): def _recvfrom_on_client_side(self, bufsize, flags):
try: try:
buf = self._sock.recv(bufsize, flags) buf = self._sock.recv(bufsize, flags)
except ssl.SSLError as e: except ssl.SSLError as e:
if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ: if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass pass
else: else:
raise e raise e
else: else:
if buf: if buf:
return buf, self._sock.getpeername() return buf, self._sock.getpeername()
# __No_data__ received from any client # __No_data__ received from any client
raise socket.timeout raise socket.timeout
def sendto(self, buf, address): def sendto(self, buf, address):
if self._server_side: if self._server_side:
return self._sendto_from_server_side(buf, address) return self._sendto_from_server_side(buf, address)
else: else:
return self._sendto_from_client_side(buf, address) return self._sendto_from_client_side(buf, address)
def _sendto_from_server_side(self, buf, address): def _sendto_from_server_side(self, buf, address):
for conn, client in self._clients.iteritems(): for conn, client in self._clients.iteritems():
if client.getAddr() == address: if client.getAddr() == address:
return self._clientWrite(conn, buf) return self._clientWrite(conn, buf)
return 0 return 0
def _sendto_from_client_side(self, buf, address): def _sendto_from_client_side(self, buf, address):
try: try:
if not self._sock._connected: if not self._sock._connected:
self._sock.connect(address) self._sock.connect(address)
bytes_sent = self._sock.send(buf) bytes_sent = self._sock.send(buf)
except ssl.SSLError as e: except ssl.SSLError as e:
raise e raise e
return bytes_sent return bytes_sent
def _getClientReadingSockets(self): def _getClientReadingSockets(self):
return [x for x in self._clients.keys()] return [x for x in self._clients.keys()]
def _getAllReadingSockets(self): def _getAllReadingSockets(self):
return [self._sock] + self._getClientReadingSockets() return [self._sock] + self._getClientReadingSockets()
def _sockIsServerSock(self, conn): def _sockIsServerSock(self, conn):
return conn is self._sock return conn is self._sock
def _clientHandshakeDone(self, conn): def _clientHandshakeDone(self, conn):
return conn in self._clients and self._clients[conn].handshake_done is True return conn in self._clients and self._clients[conn].handshake_done is True
def _clientAccept(self, conn): def _clientAccept(self, conn):
_logger.debug('+' * 60) _logger.debug('+' * 60)
ret = None ret = None
try: try:
ret = conn.accept() ret = conn.accept()
_logger.debug('Accept returned with ... %s' % (str(ret))) _logger.debug('Accept returned with ... %s' % (str(ret)))
except Exception as e: except Exception as e:
raise e raise e
else: else:
if ret: if ret:
client, addr = ret client, addr = ret
host, port = addr host, port = addr
if client in self._clients: if client in self._clients:
_logger.debug('Client already connected %s' % str(client)) _logger.debug('Client already connected %s' % str(client))
raise ValueError raise ValueError
self._clients[client] = self._ClientSession(host=host, port=port) self._clients[client] = self._ClientSession(host=host, port=port)
self._clientDoHandshake(client) self._clientDoHandshake(client)
def _clientDoHandshake(self, conn): def _clientDoHandshake(self, conn):
_logger.debug('-' * 60) _logger.debug('-' * 60)
conn.setblocking(False) conn.setblocking(False)
try: try:
conn.do_handshake() conn.do_handshake()
_logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr()))) _logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
self._clients[conn].handshake_done = True self._clients[conn].handshake_done = True
except ssl.SSLError as e: except ssl.SSLError as e:
if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ: if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass pass
else: else:
self._clientDrop(conn, error=e) self._clientDrop(conn, error=e)
raise e raise e
def _clientRead(self, conn, bufsize=4096): def _clientRead(self, conn, bufsize=4096):
_logger.debug('*' * 60) _logger.debug('*' * 60)
ret = None ret = None
try: try:
ret = conn.recv(bufsize) ret = conn.recv(bufsize)
_logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret)))) _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
except ssl.SSLError as e: except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_READ: if e.args[0] == ssl.SSL_ERROR_WANT_READ:
pass pass
else: else:
self._clientDrop(conn, error=e) self._clientDrop(conn, error=e)
return ret return ret
def _clientWrite(self, conn, data): def _clientWrite(self, conn, data):
_logger.debug('#' * 60) _logger.debug('#' * 60)
ret = None ret = None
try: try:
_data = data _data = data
if False: if False:
_data = data.raw _data = data.raw
ret = conn.send(_data) ret = conn.send(_data)
_logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret))) _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
except Exception as e: except Exception as e:
raise e raise e
return ret return ret
def _clientDrop(self, conn, error=None): def _clientDrop(self, conn, error=None):
_logger.debug('$' * 60) _logger.debug('$' * 60)
try: try:
if error: if error:
_logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error)) _logger.debug('Drop client %s ... with error: %s' % (self._clients[conn].getAddr(), error))
else: else:
_logger.debug('Drop client %s' % str(self._clients[conn].getAddr())) _logger.debug('Drop client %s' % str(self._clients[conn].getAddr()))
if conn in self._clients: if conn in self._clients:
del self._clients[conn] del self._clients[conn]
try: try:
conn.unwrap() conn.unwrap()
except: except:
pass pass
conn.close() conn.close()
except Exception as e: except Exception as e:
pass pass

View File

@ -40,23 +40,23 @@ _logger = getLogger(__name__)
class _X509(_Rsrc): class _X509(_Rsrc):
"""Wrapper for the cryptographic library's X509 resource""" """Wrapper for the cryptographic library's X509 resource"""
def __init__(self, value): def __init__(self, value):
super(_X509, self).__init__(value) super(_X509, self).__init__(value)
def __del__(self): def __del__(self):
_logger.debug("Freeing X509: %d", self.raw) _logger.debug("Freeing X509: %d", self.raw)
X509_free(self._value) X509_free(self._value)
self._value = None self._value = None
class _STACK(_Rsrc): class _STACK(_Rsrc):
"""Wrapper for the cryptographic library's stacks""" """Wrapper for the cryptographic library's stacks"""
def __init__(self, value): def __init__(self, value):
super(_STACK, self).__init__(value) super(_STACK, self).__init__(value)
def __del__(self): def __del__(self):
_logger.debug("Freeing stack: %d", self.raw) _logger.debug("Freeing stack: %d", self.raw)
sk_pop_free(self._value) sk_pop_free(self._value)
self._value = None self._value = None
def decode_cert(cert): def decode_cert(cert):
"""Convert an X509 certificate into a Python dictionary """Convert an X509 certificate into a Python dictionary