Merge pull request #3 from mcfreis/master
Update "Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()"
This commit is contained in:
		
						commit
						418655bed0
					
				
							
								
								
									
										14
									
								
								ChangeLog
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								ChangeLog
									
									
									
									
									
								
							@ -1,3 +1,17 @@
 | 
				
			|||||||
 | 
					2017-03-23  Björn Freise  <mcfreis@gmx.net>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Patched ssl-Module with SSL_BUILD_*- and ERR_*- constants and added aliases for wrap_server() and wrap_client()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						* dtls/__init__.py: Added DtlsSocket() from wrapper and aliases for wrap_server() and wrap_client()
 | 
				
			||||||
 | 
						* dtls/err.py: Added patch_ssl_errors() to patch ssl-Module with ERR_* constants
 | 
				
			||||||
 | 
						* dtls/patch.py: Patched ssl-Module with SSL_BUILD_* constants and added call to patch_ssl_errors()
 | 
				
			||||||
 | 
						* dtls/wrapper.py:
 | 
				
			||||||
 | 
						    - Added a server and client function to alias/wrap DtlsSocket() creation
 | 
				
			||||||
 | 
						    - Cleanup of DtlsSocket.__init__()
 | 
				
			||||||
 | 
						    - Cleanup of exception handling in all member methods
 | 
				
			||||||
 | 
						    - Cleanup sendto() from client: no endless loop and first do a connect if not already connected
 | 
				
			||||||
 | 
						* dtls/test/unit_wrapper.py: Adopt the changes made described above
 | 
				
			||||||
 | 
					
 | 
				
			||||||
2017-03-17  Björn Freise  <mcfreis@gmx.net>
 | 
					2017-03-17  Björn Freise  <mcfreis@gmx.net>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Added a wrapper for a DTLS-Socket either as client or server - including unit tests
 | 
						Added a wrapper for a DTLS-Socket either as client or server - including unit tests
 | 
				
			||||||
 | 
				
			|||||||
@ -61,4 +61,4 @@ _prep_bins()  # prepare before module imports
 | 
				
			|||||||
from patch import do_patch
 | 
					from patch import do_patch
 | 
				
			||||||
from sslconnection import SSLContext, SSL, SSLConnection
 | 
					from sslconnection import SSLContext, SSL, SSLConnection
 | 
				
			||||||
from demux import force_routing_demux, reset_default_demux
 | 
					from demux import force_routing_demux, reset_default_demux
 | 
				
			||||||
import err as error_codes
 | 
					from wrapper import DtlsSocket, client as wrap_client, server as wrap_server
 | 
				
			||||||
 | 
				
			|||||||
@ -54,7 +54,14 @@ ERR_COOKIE_MISMATCH = 0x1408A134
 | 
				
			|||||||
ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
 | 
					ERR_CERTIFICATE_VERIFY_FAILED = 0x14090086
 | 
				
			||||||
ERR_NO_SHARED_CIPHER = 0x1408A0C1
 | 
					ERR_NO_SHARED_CIPHER = 0x1408A0C1
 | 
				
			||||||
ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
 | 
					ERR_SSL_HANDSHAKE_FAILURE = 0x1410C0E5
 | 
				
			||||||
 | 
					ERR_TLSV1_ALERT_UNKNOWN_CA = 0x14102418
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def patch_ssl_errors():
 | 
				
			||||||
 | 
					    import ssl
 | 
				
			||||||
 | 
					    errors = [i for i in globals().iteritems() if type(i[1]) == int and str(i[0]).startswith('ERR_')]
 | 
				
			||||||
 | 
					    for k, v in errors:
 | 
				
			||||||
 | 
					        if not hasattr(ssl, k):
 | 
				
			||||||
 | 
					            setattr(ssl, k, v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SSLError(socket_error):
 | 
					class SSLError(socket_error):
 | 
				
			||||||
    """This exception is raised by modules in the dtls package."""
 | 
					    """This exception is raised by modules in the dtls package."""
 | 
				
			||||||
 | 
				
			|||||||
@ -43,7 +43,9 @@ import errno
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
 | 
					from sslconnection import SSLConnection, PROTOCOL_DTLS, PROTOCOL_DTLSv1, PROTOCOL_DTLSv1_2
 | 
				
			||||||
from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
 | 
					from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION, DTLS_OPENSSL_VERSION_INFO
 | 
				
			||||||
from err import raise_as_ssl_module_error
 | 
					from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
 | 
				
			||||||
 | 
					    SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK, SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR, SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
 | 
				
			||||||
 | 
					from err import raise_as_ssl_module_error, patch_ssl_errors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def do_patch():
 | 
					def do_patch():
 | 
				
			||||||
@ -64,10 +66,17 @@ def do_patch():
 | 
				
			|||||||
    ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
 | 
					    ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER
 | 
				
			||||||
    ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
 | 
					    ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION
 | 
				
			||||||
    ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
 | 
					    ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO
 | 
				
			||||||
 | 
					    ssl.SSL_BUILD_CHAIN_FLAG_NONE = SSL_BUILD_CHAIN_FLAG_NONE
 | 
				
			||||||
 | 
					    ssl.SSL_BUILD_CHAIN_FLAG_UNTRUSTED = SSL_BUILD_CHAIN_FLAG_UNTRUSTED
 | 
				
			||||||
 | 
					    ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT = SSL_BUILD_CHAIN_FLAG_NO_ROOT
 | 
				
			||||||
 | 
					    ssl.SSL_BUILD_CHAIN_FLAG_CHECK = SSL_BUILD_CHAIN_FLAG_CHECK
 | 
				
			||||||
 | 
					    ssl.SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR = SSL_BUILD_CHAIN_FLAG_IGNORE_ERROR
 | 
				
			||||||
 | 
					    ssl.SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR = SSL_BUILD_CHAIN_FLAG_CLEAR_ERROR
 | 
				
			||||||
    _orig_SSLSocket_init = ssl.SSLSocket.__init__
 | 
					    _orig_SSLSocket_init = ssl.SSLSocket.__init__
 | 
				
			||||||
    _orig_get_server_certificate = ssl.get_server_certificate
 | 
					    _orig_get_server_certificate = ssl.get_server_certificate
 | 
				
			||||||
    ssl.SSLSocket.__init__ = _SSLSocket_init
 | 
					    ssl.SSLSocket.__init__ = _SSLSocket_init
 | 
				
			||||||
    ssl.get_server_certificate = _get_server_certificate
 | 
					    ssl.get_server_certificate = _get_server_certificate
 | 
				
			||||||
 | 
					    patch_ssl_errors()
 | 
				
			||||||
    raise_as_ssl_module_error()
 | 
					    raise_as_ssl_module_error()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _wrap_socket(sock, keyfile=None, certfile=None,
 | 
					def _wrap_socket(sock, keyfile=None, certfile=None,
 | 
				
			||||||
 | 
				
			|||||||
@ -15,8 +15,7 @@ from logging import basicConfig, DEBUG, getLogger
 | 
				
			|||||||
_logger = getLogger(__name__)
 | 
					_logger = getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import ssl
 | 
					import ssl
 | 
				
			||||||
from dtls import do_patch, error_codes
 | 
					from dtls import DtlsSocket
 | 
				
			||||||
from dtls.wrapper import DtlsSocket, SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HOST = "localhost"
 | 
					HOST = "localhost"
 | 
				
			||||||
@ -186,7 +185,7 @@ tests = [
 | 
				
			|||||||
          'client_sigalgs': None},
 | 
					          'client_sigalgs': None},
 | 
				
			||||||
     'result':
 | 
					     'result':
 | 
				
			||||||
         {'ret_success': False,
 | 
					         {'ret_success': False,
 | 
				
			||||||
          'error_code': error_codes.ERR_WRONG_SSL_VERSION,
 | 
					          'error_code': ssl.ERR_WRONG_SSL_VERSION,
 | 
				
			||||||
          'exception': None}},
 | 
					          'exception': None}},
 | 
				
			||||||
    {'testcase':
 | 
					    {'testcase':
 | 
				
			||||||
        {'name': 'certificate verify fails',
 | 
					        {'name': 'certificate verify fails',
 | 
				
			||||||
@ -209,7 +208,7 @@ tests = [
 | 
				
			|||||||
          'client_sigalgs': None},
 | 
					          'client_sigalgs': None},
 | 
				
			||||||
     'result':
 | 
					     'result':
 | 
				
			||||||
         {'ret_success': False,
 | 
					         {'ret_success': False,
 | 
				
			||||||
          'error_code': error_codes.ERR_CERTIFICATE_VERIFY_FAILED,
 | 
					          'error_code': ssl.ERR_CERTIFICATE_VERIFY_FAILED,
 | 
				
			||||||
          'exception': None}},
 | 
					          'exception': None}},
 | 
				
			||||||
    {'testcase':
 | 
					    {'testcase':
 | 
				
			||||||
        {'name': 'no matching curve',
 | 
					        {'name': 'no matching curve',
 | 
				
			||||||
@ -232,7 +231,7 @@ tests = [
 | 
				
			|||||||
          'client_sigalgs': None},
 | 
					          'client_sigalgs': None},
 | 
				
			||||||
     'result':
 | 
					     'result':
 | 
				
			||||||
         {'ret_success': False,
 | 
					         {'ret_success': False,
 | 
				
			||||||
          'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
 | 
					          'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
 | 
				
			||||||
          'exception': None}},
 | 
					          'exception': None}},
 | 
				
			||||||
    {'testcase':
 | 
					    {'testcase':
 | 
				
			||||||
         {'name': 'matching curve',
 | 
					         {'name': 'matching curve',
 | 
				
			||||||
@ -278,7 +277,7 @@ tests = [
 | 
				
			|||||||
          'client_sigalgs': None},
 | 
					          'client_sigalgs': None},
 | 
				
			||||||
     'result':
 | 
					     'result':
 | 
				
			||||||
         {'ret_success': False,
 | 
					         {'ret_success': False,
 | 
				
			||||||
          'error_code': error_codes.ERR_PORT_UNREACHABLE,
 | 
					          'error_code': ssl.ERR_PORT_UNREACHABLE,
 | 
				
			||||||
          'exception': None}},
 | 
					          'exception': None}},
 | 
				
			||||||
    {'testcase':
 | 
					    {'testcase':
 | 
				
			||||||
        {'name': 'no matching sigalgs',
 | 
					        {'name': 'no matching sigalgs',
 | 
				
			||||||
@ -301,7 +300,7 @@ tests = [
 | 
				
			|||||||
          'client_sigalgs': "RSA+SHA256"},
 | 
					          'client_sigalgs': "RSA+SHA256"},
 | 
				
			||||||
     'result':
 | 
					     'result':
 | 
				
			||||||
         {'ret_success': False,
 | 
					         {'ret_success': False,
 | 
				
			||||||
          'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
 | 
					          'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
 | 
				
			||||||
          'exception': None}},
 | 
					          'exception': None}},
 | 
				
			||||||
    {'testcase':
 | 
					    {'testcase':
 | 
				
			||||||
        {'name': 'matching sigalgs',
 | 
					        {'name': 'matching sigalgs',
 | 
				
			||||||
@ -347,7 +346,7 @@ tests = [
 | 
				
			|||||||
          'client_sigalgs': None},
 | 
					          'client_sigalgs': None},
 | 
				
			||||||
     'result':
 | 
					     'result':
 | 
				
			||||||
         {'ret_success': False,
 | 
					         {'ret_success': False,
 | 
				
			||||||
          'error_code': error_codes.ERR_SSL_HANDSHAKE_FAILURE,
 | 
					          'error_code': ssl.ERR_SSL_HANDSHAKE_FAILURE,
 | 
				
			||||||
          'exception': None}},
 | 
					          'exception': None}},
 | 
				
			||||||
    {'testcase':
 | 
					    {'testcase':
 | 
				
			||||||
        {'name': 'matching cipher',
 | 
					        {'name': 'matching cipher',
 | 
				
			||||||
@ -493,13 +492,8 @@ class TestSequenceMeta(type):
 | 
				
			|||||||
class WrapperTests(unittest.TestCase):
 | 
					class WrapperTests(unittest.TestCase):
 | 
				
			||||||
    __metaclass__ = TestSequenceMeta
 | 
					    __metaclass__ = TestSequenceMeta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        super(WrapperTests, self).setUp()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        do_patch()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_build_cert_chain(self):
 | 
					    def test_build_cert_chain(self):
 | 
				
			||||||
        steps = [SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_NO_ROOT]
 | 
					        steps = [ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT]
 | 
				
			||||||
        chatty, connectionchatty = CHATTY, CHATTY_CLIENT
 | 
					        chatty, connectionchatty = CHATTY, CHATTY_CLIENT
 | 
				
			||||||
        indata = 'FOO'
 | 
					        indata = 'FOO'
 | 
				
			||||||
        certs = dict()
 | 
					        certs = dict()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										145
									
								
								dtls/wrapper.py
									
									
									
									
									
								
							
							
						
						
									
										145
									
								
								dtls/wrapper.py
									
									
									
									
									
								
							@ -38,12 +38,36 @@ import socket
 | 
				
			|||||||
from patch import do_patch
 | 
					from patch import do_patch
 | 
				
			||||||
do_patch()
 | 
					do_patch()
 | 
				
			||||||
from sslconnection import SSLContext, SSL
 | 
					from sslconnection import SSLContext, SSL
 | 
				
			||||||
from sslconnection import SSL_BUILD_CHAIN_FLAG_NONE, SSL_BUILD_CHAIN_FLAG_UNTRUSTED, \
 | 
					import err as err_codes
 | 
				
			||||||
    SSL_BUILD_CHAIN_FLAG_NO_ROOT, SSL_BUILD_CHAIN_FLAG_CHECK
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
_logger = getLogger(__name__)
 | 
					_logger = getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def client(sock, keyfile=None, certfile=None,
 | 
				
			||||||
 | 
					           cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLSv1_2, ca_certs=None,
 | 
				
			||||||
 | 
					           do_handshake_on_connect=True, suppress_ragged_eofs=True,
 | 
				
			||||||
 | 
					           ciphers=None, curves=None, sigalgs=None, user_mtu=None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=False,
 | 
				
			||||||
 | 
					                      cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
 | 
				
			||||||
 | 
					                      do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
 | 
				
			||||||
 | 
					                      ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
 | 
				
			||||||
 | 
					                      server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def server(sock, keyfile=None, certfile=None,
 | 
				
			||||||
 | 
					           cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_DTLS, ca_certs=None,
 | 
				
			||||||
 | 
					           do_handshake_on_connect=False, suppress_ragged_eofs=True,
 | 
				
			||||||
 | 
					           ciphers=None, curves=None, sigalgs=None, user_mtu=None,
 | 
				
			||||||
 | 
					           server_key_exchange_curve=None, server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return DtlsSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=True,
 | 
				
			||||||
 | 
					                      cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs,
 | 
				
			||||||
 | 
					                      do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs,
 | 
				
			||||||
 | 
					                      ciphers=ciphers, curves=curves, sigalgs=sigalgs, user_mtu=user_mtu,
 | 
				
			||||||
 | 
					                      server_key_exchange_curve=server_key_exchange_curve, server_cert_options=server_cert_options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DtlsSocket(object):
 | 
					class DtlsSocket(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class _ClientSession(object):
 | 
					    class _ClientSession(object):
 | 
				
			||||||
@ -57,7 +81,7 @@ class DtlsSocket(object):
 | 
				
			|||||||
            return self.host, self.port
 | 
					            return self.host, self.port
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 peerOrSock,
 | 
					                 sock=None,
 | 
				
			||||||
                 keyfile=None,
 | 
					                 keyfile=None,
 | 
				
			||||||
                 certfile=None,
 | 
					                 certfile=None,
 | 
				
			||||||
                 server_side=False,
 | 
					                 server_side=False,
 | 
				
			||||||
@ -71,13 +95,12 @@ class DtlsSocket(object):
 | 
				
			|||||||
                 sigalgs=None,
 | 
					                 sigalgs=None,
 | 
				
			||||||
                 user_mtu=None,
 | 
					                 user_mtu=None,
 | 
				
			||||||
                 server_key_exchange_curve=None,
 | 
					                 server_key_exchange_curve=None,
 | 
				
			||||||
                 server_cert_options=SSL_BUILD_CHAIN_FLAG_NONE):
 | 
					                 server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if server_cert_options is None:
 | 
					        if server_cert_options is None:
 | 
				
			||||||
            server_cert_options = SSL_BUILD_CHAIN_FLAG_NONE
 | 
					            server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._ssl_logging = False
 | 
					        self._ssl_logging = False
 | 
				
			||||||
        self._peer = None
 | 
					 | 
				
			||||||
        self._server_side = server_side
 | 
					        self._server_side = server_side
 | 
				
			||||||
        self._ciphers = ciphers
 | 
					        self._ciphers = ciphers
 | 
				
			||||||
        self._curves = curves
 | 
					        self._curves = curves
 | 
				
			||||||
@ -87,15 +110,11 @@ class DtlsSocket(object):
 | 
				
			|||||||
        self._server_cert_options = server_cert_options
 | 
					        self._server_cert_options = server_cert_options
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Default socket creation
 | 
					        # Default socket creation
 | 
				
			||||||
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
					        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
				
			||||||
        if isinstance(peerOrSock, tuple):
 | 
					        if isinstance(sock, socket.socket):
 | 
				
			||||||
            # Address tuple
 | 
					            _sock = sock
 | 
				
			||||||
            self._peer = peerOrSock
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # Socket, use given
 | 
					 | 
				
			||||||
            sock = peerOrSock
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._sock = ssl.wrap_socket(sock,
 | 
					        self._sock = ssl.wrap_socket(_sock,
 | 
				
			||||||
                                     keyfile=keyfile,
 | 
					                                     keyfile=keyfile,
 | 
				
			||||||
                                     certfile=certfile,
 | 
					                                     certfile=certfile,
 | 
				
			||||||
                                     server_side=self._server_side,
 | 
					                                     server_side=self._server_side,
 | 
				
			||||||
@ -112,13 +131,6 @@ class DtlsSocket(object):
 | 
				
			|||||||
            self._clients = {}
 | 
					            self._clients = {}
 | 
				
			||||||
            self._timeout = None
 | 
					            self._timeout = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if self._peer:
 | 
					 | 
				
			||||||
                self._sock.bind(self._peer)
 | 
					 | 
				
			||||||
                self._sock.listen(0)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if self._peer:
 | 
					 | 
				
			||||||
                self._sock.connect(self._peer)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __getattr__(self, item):
 | 
					    def __getattr__(self, item):
 | 
				
			||||||
        if hasattr(self, "_sock") and hasattr(self._sock, item):
 | 
					        if hasattr(self, "_sock") and hasattr(self._sock, item):
 | 
				
			||||||
            return getattr(self._sock, item)
 | 
					            return getattr(self._sock, item)
 | 
				
			||||||
@ -159,17 +171,12 @@ class DtlsSocket(object):
 | 
				
			|||||||
            for cli in self._clients.keys():
 | 
					            for cli in self._clients.keys():
 | 
				
			||||||
                cli.close()
 | 
					                cli.close()
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
                self._sock.unwrap()
 | 
					                self._sock.unwrap()
 | 
				
			||||||
 | 
					            except:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
        self._sock.close()
 | 
					        self._sock.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def write(self, data):
 | 
					 | 
				
			||||||
        # return self._sock.write(data)
 | 
					 | 
				
			||||||
        return self.sendto(data, self._peer)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def read(self, len=1024):
 | 
					 | 
				
			||||||
        # return self._sock.read(len=len)
 | 
					 | 
				
			||||||
        return self.recvfrom(len)[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def recvfrom(self, bufsize, flags=0):
 | 
					    def recvfrom(self, bufsize, flags=0):
 | 
				
			||||||
        if self._server_side:
 | 
					        if self._server_side:
 | 
				
			||||||
            return self._recvfrom_on_server_side(bufsize, flags=flags)
 | 
					            return self._recvfrom_on_server_side(bufsize, flags=flags)
 | 
				
			||||||
@ -180,11 +187,13 @@ class DtlsSocket(object):
 | 
				
			|||||||
        try:
 | 
					        try:
 | 
				
			||||||
            r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
 | 
					            r, _, _ = select.select(self._getAllReadingSockets(), [], [], self._timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except socket.timeout as e_timeout:
 | 
					        except socket.timeout:
 | 
				
			||||||
            raise e_timeout
 | 
					            # __Nothing__ received from any client
 | 
				
			||||||
 | 
					            raise socket.timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            for conn in r:  # type: ssl.SSLSocket
 | 
					            for conn in r:
 | 
				
			||||||
 | 
					                _last_peer = conn.getpeername() if conn._connected else None
 | 
				
			||||||
                if self._sockIsServerSock(conn):
 | 
					                if self._sockIsServerSock(conn):
 | 
				
			||||||
                    # Connect
 | 
					                    # Connect
 | 
				
			||||||
                    self._clientAccept(conn)
 | 
					                    self._clientAccept(conn)
 | 
				
			||||||
@ -195,38 +204,43 @@ class DtlsSocket(object):
 | 
				
			|||||||
                    # Normal read
 | 
					                    # Normal read
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        buf = self._clientRead(conn, bufsize)
 | 
					                        buf = self._clientRead(conn, bufsize)
 | 
				
			||||||
                        if buf and conn in self._clients:
 | 
					                        if buf:
 | 
				
			||||||
 | 
					                            if conn in self._clients:
 | 
				
			||||||
                                return buf, self._clients[conn].getAddr()
 | 
					                                return buf, self._clients[conn].getAddr()
 | 
				
			||||||
 | 
					                            else:
 | 
				
			||||||
 | 
					                                _logger.debug('Received data from an already disconnected client!')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            setattr(e, 'peer', _last_peer)
 | 
				
			||||||
            raise e
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            for conn in self._getClientReadingSockets():
 | 
					            for conn in self._getClientReadingSockets():
 | 
				
			||||||
                if conn.get_timeout():
 | 
					                if conn.get_timeout():
 | 
				
			||||||
                    conn.handle_timeout()
 | 
					                    ret = conn.handle_timeout()
 | 
				
			||||||
 | 
					                    _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            raise e
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # __No_data__ received from any client
 | 
				
			||||||
        raise socket.timeout
 | 
					        raise socket.timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _recvfrom_on_client_side(self, bufsize, flags):
 | 
					    def _recvfrom_on_client_side(self, bufsize, flags):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            buf = self._sock.recv(bufsize, flags)
 | 
					            buf = self._sock.recv(bufsize, flags)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except ssl.SSLError as e_ssl:
 | 
					        except ssl.SSLError as e:
 | 
				
			||||||
            if e_ssl.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
 | 
					            if e.errno == ssl.ERR_READ_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
 | 
				
			||||||
                return '', self._peer
 | 
					 | 
				
			||||||
            elif e_ssl.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
 | 
					 | 
				
			||||||
                raise e_ssl
 | 
					 | 
				
			||||||
            else:  # like in [ssl.SSL_ERROR_WANT_READ, ...]
 | 
					 | 
				
			||||||
                pass
 | 
					                pass
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if buf:
 | 
					            if buf:
 | 
				
			||||||
                return buf, self._peer
 | 
					                return buf, self._sock.getpeername()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # __No_data__ received from any client
 | 
				
			||||||
        raise socket.timeout
 | 
					        raise socket.timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def sendto(self, buf, address):
 | 
					    def sendto(self, buf, address):
 | 
				
			||||||
@ -242,19 +256,13 @@ class DtlsSocket(object):
 | 
				
			|||||||
        return 0
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _sendto_from_client_side(self, buf, address):
 | 
					    def _sendto_from_client_side(self, buf, address):
 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
					            if not self._sock._connected:
 | 
				
			||||||
 | 
					                self._sock.connect(address)
 | 
				
			||||||
            bytes_sent = self._sock.send(buf)
 | 
					            bytes_sent = self._sock.send(buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            except ssl.SSLError as e_ssl:
 | 
					        except ssl.SSLError as e:
 | 
				
			||||||
                if str(e_ssl).startswith("503:"):
 | 
					            raise e
 | 
				
			||||||
                    # The write operation timed out
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
                raise e_ssl
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                if bytes_sent:
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return bytes_sent
 | 
					        return bytes_sent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -278,14 +286,15 @@ class DtlsSocket(object):
 | 
				
			|||||||
            ret = conn.accept()
 | 
					            ret = conn.accept()
 | 
				
			||||||
            _logger.debug('Accept returned with ... %s' % (str(ret)))
 | 
					            _logger.debug('Accept returned with ... %s' % (str(ret)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e_accept:
 | 
					        except Exception as e:
 | 
				
			||||||
            raise e_accept
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if ret:
 | 
					            if ret:
 | 
				
			||||||
                client, addr = ret
 | 
					                client, addr = ret
 | 
				
			||||||
                host, port = addr
 | 
					                host, port = addr
 | 
				
			||||||
                if client in self._clients:
 | 
					                if client in self._clients:
 | 
				
			||||||
 | 
					                    _logger.debug('Client already connected %s' % str(client))
 | 
				
			||||||
                    raise ValueError
 | 
					                    raise ValueError
 | 
				
			||||||
                self._clients[client] = self._ClientSession(host=host, port=port)
 | 
					                self._clients[client] = self._ClientSession(host=host, port=port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -297,17 +306,16 @@ class DtlsSocket(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            conn.do_handshake()
 | 
					            conn.do_handshake()
 | 
				
			||||||
            _logger.debug('Connection from %s succesful' % (str(self._clients[conn].getAddr())))
 | 
					            _logger.debug('Connection from %s successful' % (str(self._clients[conn].getAddr())))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self._clients[conn].handshake_done = True
 | 
					            self._clients[conn].handshake_done = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except ssl.SSLError as e_handshake:
 | 
					        except ssl.SSLError as e:
 | 
				
			||||||
            if str(e_handshake).startswith("504:"):
 | 
					            if e.errno == err_codes.ERR_HANDSHAKE_TIMEOUT or e.args[0] == ssl.SSL_ERROR_WANT_READ:
 | 
				
			||||||
                pass
 | 
					 | 
				
			||||||
            elif e_handshake.args[0] == ssl.SSL_ERROR_WANT_READ:
 | 
					 | 
				
			||||||
                pass
 | 
					                pass
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                raise e_handshake
 | 
					                self._clientDrop(conn, error=e)
 | 
				
			||||||
 | 
					                raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _clientRead(self, conn, bufsize=4096):
 | 
					    def _clientRead(self, conn, bufsize=4096):
 | 
				
			||||||
        _logger.debug('*' * 60)
 | 
					        _logger.debug('*' * 60)
 | 
				
			||||||
@ -317,13 +325,11 @@ class DtlsSocket(object):
 | 
				
			|||||||
            ret = conn.recv(bufsize)
 | 
					            ret = conn.recv(bufsize)
 | 
				
			||||||
            _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
 | 
					            _logger.debug('From client %s ... bytes received %s' % (str(self._clients[conn].getAddr()), str(len(ret))))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except ssl.SSLError as e_read:
 | 
					        except ssl.SSLError as e:
 | 
				
			||||||
            if e_read.args[0] == ssl.SSL_ERROR_ZERO_RETURN:
 | 
					            if e.args[0] == ssl.SSL_ERROR_WANT_READ:
 | 
				
			||||||
                self._clientDrop(conn)
 | 
					 | 
				
			||||||
            elif e_read.args[0] in [ssl.SSL_ERROR_SSL, ssl.SSL_ERROR_SYSCALL]:
 | 
					 | 
				
			||||||
                self._clientDrop(conn, error=e_read)
 | 
					 | 
				
			||||||
            else:  # like in [ssl.SSL_ERROR_WANT_READ, ...]
 | 
					 | 
				
			||||||
                pass
 | 
					                pass
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self._clientDrop(conn, error=e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return ret
 | 
					        return ret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -338,8 +344,8 @@ class DtlsSocket(object):
 | 
				
			|||||||
            ret = conn.send(_data)
 | 
					            ret = conn.send(_data)
 | 
				
			||||||
            _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
 | 
					            _logger.debug('To client %s ... bytes sent %s' % (str(self._clients[conn].getAddr()), str(ret)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e_write:
 | 
					        except Exception as e:
 | 
				
			||||||
            raise e_write
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return ret
 | 
					        return ret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -354,8 +360,11 @@ class DtlsSocket(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            if conn in self._clients:
 | 
					            if conn in self._clients:
 | 
				
			||||||
                del self._clients[conn]
 | 
					                del self._clients[conn]
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
                conn.unwrap()
 | 
					                conn.unwrap()
 | 
				
			||||||
 | 
					            except:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
            conn.close()
 | 
					            conn.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e_drop:
 | 
					        except Exception as e:
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user