From 22083e822193718a6a9206e97359338e4718fce0 Mon Sep 17 00:00:00 2001 From: Ray Brown Date: Wed, 21 Nov 2012 11:23:03 -0800 Subject: [PATCH] SSL standard library module wire-in A patch implementation is provided, which augments and alters the Python standard library's ssl module to support passing of datagram sockets, in which case this package's DTLS protocol support will be activated. The ssl module's interface is intended to operate identically regardless of whether the DTLS protocol or another protocol is chosen. The following features of the ssl module are explicitly supported with datagram sockets: * socket wrapping, unwrapping, and re-wrapping * threaded UDP servers * asynchronous UDP servers (asyncore integration) * socket servers (SocketServer integration) The following modules have been added: * dtls.patch: standard library module patching code and substitution functions and methods * unit.py: this is a port of the standard library's testing module test_ssl.py for datagram sockets; all tests pass at this time; a couple of inapplicable tests have been dropped; a few other tests have been added Also note that the err module's exception raising mechanism has been augmented so as to raise exceptions of type ssl.SSLError (as opposed to dtls.err.SSLError) when instructed to do so through activation of the patching mechanism. This allows code written against the standard library module's interface to remain unchanged. In some cases, types derived from ssl.SSLError are raised. --- dtls/__init__.py | 18 +- dtls/err.py | 58 +- dtls/openssl.py | 45 +- dtls/patch.py | 199 +++++ dtls/sslconnection.py | 186 ++++- dtls/test/certs/badcert.pem | 36 + dtls/test/certs/badkey.pem | 40 + dtls/test/certs/keycert.pem | 21 + dtls/test/certs/nullcert.pem | 0 dtls/test/certs/wrongcert.pem | 32 + dtls/test/echo_seq.py | 4 +- dtls/test/unit.py | 1313 +++++++++++++++++++++++++++++++++ 12 files changed, 1911 insertions(+), 41 deletions(-) create mode 100644 dtls/patch.py create mode 100644 dtls/test/certs/badcert.pem create mode 100644 dtls/test/certs/badkey.pem create mode 100644 dtls/test/certs/keycert.pem create mode 100644 dtls/test/certs/nullcert.pem create mode 100644 dtls/test/certs/wrongcert.pem create mode 100644 dtls/test/unit.py diff --git a/dtls/__init__.py b/dtls/__init__.py index d110877..93f6688 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -1,11 +1,19 @@ # PyDTLS: datagram TLS for Python. Written by Ray Brown. """PyDTLS package -This package exports OpenSSL's DTLS support to Python. Importing it will add -the constant PROTOCOL_DTLSv1 to the Python standard library's ssl module. -Passing a datagram socket to that module's wrap_socket function (or -instantiating its SSLSocket class with a datagram socket) will activate this -module's DTLS implementation for the returned SSLSocket instance. +This package exports OpenSSL's DTLS support to Python. Calling its patch +function will add the constant PROTOCOL_DTLSv1 to the Python standard library's +ssl module. Subsequently passing a datagram socket to that module's +wrap_socket function (or instantiating its SSLSocket class with a datagram +socket) will activate this module's DTLS implementation for the returned +SSLSocket instance. + +Instead of or in addition to invoking the patch functionality, the +SSLConnection class can be used directly for secure communication over datagram +sockets. wrap_socket's parameters and their semantics have been maintained. """ + +from patch import do_patch +from sslconnection import SSLConnection diff --git a/dtls/err.py b/dtls/err.py index 2e3fcb5..9dbe563 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -25,7 +25,9 @@ SSL_ERROR_WANT_ACCEPT = 8 ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 - +ERR_NO_CIPHER = 501 +ERR_HANDSHAKE_TIMEOUT = 502 +ERR_PORT_UNREACHABLE = 503 ERR_COOKIE_MISMATCH = 0x1408A134 @@ -35,27 +37,48 @@ class SSLError(socket_error): super(SSLError, self).__init__(*args) -class OpenSSLError(SSLError): - """This exception is raised when an error occurs in the OpenSSL library""" - def __init__(self, ssl_error, errqueue, result, func, args): - self.ssl_error = ssl_error - self.errqueue = errqueue - self.result = result - self.func = func - self.args = args - super(OpenSSLError, self).__init__(ssl_error, errqueue, - result, func, args) - - class InvalidSocketError(Exception): """There is a problem with a socket passed to the dtls package.""" def __init__(self, *args): super(InvalidSocketError, self).__init__(*args) -def raise_ssl_error(code): +def _make_opensslerror_class(): + global _OpenSSLError + class __OpenSSLError(SSLError): + """ + This exception is raised when an error occurs in the OpenSSL library + """ + def __init__(self, ssl_error, errqueue, result, func, args): + self.ssl_error = ssl_error + self.errqueue = errqueue + self.result = result + self.func = func + self.args = args + SSLError.__init__(self, ssl_error, errqueue, + result, func, args) + + _OpenSSLError = __OpenSSLError + +_make_opensslerror_class() + +def openssl_error(): + """Return the OpenSSL error type for use in exception clauses""" + return _OpenSSLError + +def raise_as_ssl_module_error(): + """Exceptions raised from this module are instances of ssl.SSLError""" + import ssl + global SSLError + SSLError = ssl.SSLError + _make_opensslerror_class() + +def raise_ssl_error(code, nested=None): """Raise an SSL error with the given error code""" - raise SSLError(str(code) + ": " + _ssl_errors[code]) + err_string = str(code) + ": " + _ssl_errors[code] + if nested: + raise SSLError(err_string, nested) + raise SSLError(err_string) _ssl_errors = { ERR_NO_CERTS: "No root certificates specified for verification " + \ @@ -63,5 +86,8 @@ _ssl_errors = { ERR_BOTH_KEY_CERT_FILES: "Both the key & certificate files " + \ "must be specified", ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \ - "specified for server-side operation" + "specified for server-side operation", + ERR_NO_CIPHER: "No cipher can be selected.", + ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out", + ERR_PORT_UNREACHABLE: "The peer address is not reachable", } diff --git a/dtls/openssl.py b/dtls/openssl.py index 222d199..455bed4 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -20,7 +20,8 @@ import array import socket from logging import getLogger from os import path -from err import OpenSSLError +from datetime import timedelta +from err import openssl_error from err import SSL_ERROR_NONE from util import _BIO import ctypes @@ -64,6 +65,7 @@ else: # BIO_NOCLOSE = 0x00 BIO_CLOSE = 0x01 +SSLEAY_VERSION = 0 SSL_VERIFY_NONE = 0x00 SSL_VERIFY_PEER = 0x01 SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02 @@ -91,6 +93,8 @@ BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_GET_PEER = 46 BIO_CTRL_DGRAM_SET_PEER = 44 BIO_C_SET_NBIO = 102 +DTLS_CTRL_GET_TIMEOUT = 73 +DTLS_CTRL_HANDLE_TIMEOUT = 74 DTLS_CTRL_LISTEN = 75 X509_NAME_MAXLEN = 256 GETS_MAXLEN = 2048 @@ -248,6 +252,11 @@ class X509V3_EXT_METHOD(Structure): ("i2d", c_int)] # remaining fields omitted +class TIMEVAL(Structure): + _fields_ = [("tv_sec", c_long), + ("tv_usec", c_long)] + + # # Socket address conversions # @@ -366,7 +375,7 @@ def raise_ssl_error(result, func, args, ssl): _logger.debug("SSL error raised: ssl_error: %d, result: %d, " + "errqueue: %s, func_name: %s", ssl_error, result, errqueue, func.func_name) - raise OpenSSLError(ssl_error, errqueue, result, func, args) + raise openssl_error()(ssl_error, errqueue, result, func, args) def find_ssl_arg(args): for arg in args: @@ -424,6 +433,7 @@ def _make_function(name, lib, args, export=True, errcheck="default"): _subst = {c_long_parm: c_long} _sigs = {} __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", + "SSLEAY_VERSION", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER", "SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE", "SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT", @@ -432,6 +442,7 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", "SSL_FILE_TYPE_PEM", "GEN_DIRNAME", "NID_subject_alt_name", + "DTLSv1_get_timeout", "DTLSv1_handle_timeout", "DTLSv1_listen", "BIO_gets", "BIO_read", "BIO_get_mem_data", "BIO_dgram_set_connected", @@ -450,6 +461,8 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", map(lambda x: _make_function(*x), ( ("SSL_library_init", libssl, ((c_int, "ret"),)), ("SSL_load_error_strings", libssl, ((None, "ret"),)), + ("SSLeay", libcrypto, ((c_long_parm, "ret"),)), + ("SSLeay_version", libcrypto, ((c_char_p, "ret"), (c_int, "t"))), ("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)), ("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), @@ -514,6 +527,7 @@ map(lambda x: _make_function(*x), ( ((c_int, "ret"), (SSL, "ssl"), (c_void_p, "buf"), (c_int, "num")), False), ("SSL_write", libssl, ((c_int, "ret"), (SSL, "ssl"), (c_void_p, "buf"), (c_int, "num")), False), + ("SSL_pending", libssl, ((c_int, "ret"), (SSL, "ssl")), True, None), ("SSL_shutdown", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_set_read_ahead", libssl, ((None, "ret"), (SSL, "ssl"), (c_int, "yes"))), @@ -630,6 +644,28 @@ def BIO_dgram_set_peer(bio, peer_address): def BIO_set_nbio(bio, n): _BIO_ctrl(bio, BIO_C_SET_NBIO, 1 if n else 0, None) +def DTLSv1_get_timeout(ssl): + tv = TIMEVAL() + ret = _SSL_ctrl(ssl, DTLS_CTRL_GET_TIMEOUT, 0, byref(tv)) + if ret != 1: + return + return timedelta(seconds=tv.tv_sec, microseconds=tv.tv_usec) + +def DTLSv1_handle_timeout(ssl): + ret = _SSL_ctrl(ssl, DTLS_CTRL_HANDLE_TIMEOUT, 0, None) + if ret == 0: + # It was too early to call: no timer had yet expired + return False + if ret == 1: + # Buffered messages were retransmitted + return True + # There was an error: either too many timeouts have occurred or a + # retransmission failed + assert ret < 0 + if ret > 0: + ret = -10 + errcheck_p(ret, _SSL_ctrl, (ssl, DTLS_CTRL_HANDLE_TIMEOUT, 0, None)) + def DTLSv1_listen(ssl): su = sockaddr_u() ret = _SSL_ctrl(ssl, DTLS_CTRL_LISTEN, 0, byref(su)) @@ -642,7 +678,10 @@ def SSL_read(ssl, length): return buf.raw[:res_len] def SSL_write(ssl, data): - str_data = str(data) + if hasattr(data, "tobytes") and callable(data.tobytes): + str_data = data.tobytes() + else: + str_data = str(data) return _SSL_write(ssl, str_data, len(str_data)) def OBJ_obj2txt(asn1_object, no_name): diff --git a/dtls/patch.py b/dtls/patch.py new file mode 100644 index 0000000..910bdf4 --- /dev/null +++ b/dtls/patch.py @@ -0,0 +1,199 @@ +# Patch: patching of the Python stadard library's ssl module for transparent +# use of datagram sockets. Written by Ray Brown. +"""Patch + +This module is used to patch the Python standard library's ssl module. Patching +has the following effects: + + * The constant PROTOCOL_DTLSv1 is added at ssl module level + * DTLSv1's protocol name is added to the ssl module's id-to-name dictionary + * The constants DTLS_OPENSSL_VERSION* are added at the ssl module level + * Instntiation of ssl.SSLSocket with sock.type == socket.SOCK_DGRAM is + supported and leads to substitution of this module's DTLS code paths for + that SSLSocket instance + * Direct instantiation of SSLSocket as well as instantiation through + ssl.wrap_socket are supported + * Invocation of the function get_server_certificate with a value of + PROTOCOL_DTLSv1 for the parameter ssl_version is supported +""" + +from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error +from socket import AF_INET, SOCK_DGRAM +from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE +from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION +from sslconnection import DTLS_OPENSSL_VERSION_INFO +from err import raise_as_ssl_module_error +from types import MethodType +from weakref import proxy +import errno + +def do_patch(): + import ssl as _ssl # import to be avoided if ssl module is never patched + global _orig_SSLSocket_init, _orig_get_server_certificate + global ssl + ssl = _ssl + ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 + ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" + ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER + ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION + ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO + _orig_SSLSocket_init = ssl.SSLSocket.__init__ + _orig_get_server_certificate = ssl.get_server_certificate + ssl.SSLSocket.__init__ = _SSLSocket_init + ssl.get_server_certificate = _get_server_certificate + raise_as_ssl_module_error() + +PROTOCOL_SSLv3 = 1 +PROTOCOL_SSLv23 = 2 + +def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): + """Retrieve a server certificate + + Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt. + """ + + if ssl_version != PROTOCOL_DTLSv1: + return _orig_get_server_certificate(addr, ssl_version, ca_certs) + + host, port = addr + if (ca_certs is not None): + cert_reqs = ssl.CERT_REQUIRED + else: + cert_reqs = ssl.CERT_NONE + s = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), + ssl_version=ssl_version, + cert_reqs=cert_reqs, ca_certs=ca_certs) + s.connect(addr) + dercert = s.getpeercert(True) + s.close() + return ssl.DER_cert_to_PEM_cert(dercert) + +def _SSLSocket_init(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None): + is_connection = is_datagram = False + if isinstance(sock, SSLConnection): + is_connection = True + elif hasattr(sock, "type") and sock.type == SOCK_DGRAM: + is_datagram = True + if not is_connection and not is_datagram: + # Non-DTLS code path + return _orig_SSLSocket_init(self, sock, keyfile, certfile, + server_side, cert_reqs, + ssl_version, ca_certs, + do_handshake_on_connect, + suppress_ragged_eofs, ciphers) + # DTLS code paths: datagram socket and newly accepted DTLS connection + if is_datagram: + socket.__init__(self, _sock=sock._sock) + else: + socket.__init__(self, _sock=sock.get_socket(True)._sock) + # Copy instance initialization from SSLSocket class + for attr in _delegate_methods: + try: + delattr(self, attr) + except AttributeError: + pass + + if certfile and not keyfile: + keyfile = certfile + if is_datagram: + # see if it's connected + try: + socket.getpeername(self) + except socket_error, e: + if e.errno != errno.ENOTCONN: + raise + # no, no connection yet + self._connected = False + self._sslobj = None + else: + # yes, create the SSL object + self._connected = True + self._sslobj = SSLConnection(sock, keyfile, certfile, + server_side, cert_reqs, + ssl_version, ca_certs, + do_handshake_on_connect, + suppress_ragged_eofs, ciphers) + else: + self._sslobj = sock + + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self._makefile_refs = 0 + + # Perform method substitution and addition (without reference cycle) + self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) + self.listen = MethodType(_SSLSocket_listen, proxy(self)) + self.accept = MethodType(_SSLSocket_accept, proxy(self)) + self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) + self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) + +def _SSLSocket_listen(self, ignored): + if self._connected: + raise ValueError("attempt to listen on connected SSLSocket!") + if self._sslobj: + return + self._sslobj = SSLConnection(socket(_sock=self._sock), + self.keyfile, self.certfile, True, + self.cert_reqs, self.ssl_version, + self.ca_certs, + self.do_handshake_on_connect, + self.suppress_ragged_eofs, self.ciphers) + +def _SSLSocket_accept(self): + if self._connected: + raise ValueError("attempt to accept on connected SSLSocket!") + if not self._sslobj: + raise ValueError("attempt to accept on SSLSocket prior to listen!") + acc_ret = self._sslobj.accept() + if not acc_ret: + return + new_conn, addr = acc_ret + new_ssl_sock = ssl.SSLSocket(new_conn, self.keyfile, self.certfile, True, + self.cert_reqs, self.ssl_version, + self.ca_certs, + self.do_handshake_on_connect, + self.suppress_ragged_eofs, self.ciphers) + return new_ssl_sock, addr + +def _SSLSocket_real_connect(self, addr, return_errno): + if self._connected: + raise ValueError("attempt to connect already-connected SSLSocket!") + self._sslobj = SSLConnection(socket(_sock=self._sock), + self.keyfile, self.certfile, False, + self.cert_reqs, self.ssl_version, + self.ca_certs, + self.do_handshake_on_connect, + self.suppress_ragged_eofs, self.ciphers) + try: + self._sslobj.connect(addr) + except socket_error as e: + if return_errno: + return e.errno + else: + self._sslobj = None + raise e + self._connected = True + return 0 + + +if __name__ == "__main__": + do_patch() + +def _SSLSocket_get_timeout(self): + return self._sslobj.get_timeout() + +def _SSLSocket_handle_timeout(self): + return self._sslobj.handle_timeout() diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index a0947da..8c17de0 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -29,9 +29,11 @@ import hmac from logging import getLogger from os import urandom from weakref import proxy -from err import OpenSSLError, InvalidSocketError +from err import openssl_error, InvalidSocketError from err import raise_ssl_error -from err import SSL_ERROR_WANT_READ, ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL +from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from x509 import _X509, decode_cert from openssl import * from util import _Rsrc, _BIO @@ -48,6 +50,14 @@ CERT_REQUIRED = 2 # SSL_library_init() SSL_load_error_strings() +DTLS_OPENSSL_VERSION_NUMBER = SSLeay() +DTLS_OPENSSL_VERSION = SSLeay_version(SSLEAY_VERSION) +DTLS_OPENSSL_VERSION_INFO = ( + DTLS_OPENSSL_VERSION_NUMBER >> 28 & 0xFF, # major + DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor + DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix + DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch + DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status class _CTX(_Rsrc): @@ -98,9 +108,11 @@ class SSLConnection(object): _rnd_key = urandom(16) - def _init_server(self): + def _init_server(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") + if peer_address: + raise InvalidSocketError("server-side socket must be unconnected") from demux import UDPDemux self._udp_demux = UDPDemux(self._sock) @@ -127,7 +139,7 @@ class SSLConnection(object): self._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_accept_state(self._ssl.value) - def _init_client(self): + def _init_client(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") @@ -141,6 +153,8 @@ class SSLConnection(object): self._config_ssl_ctx(verify_mode) self._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_connect_state(self._ssl.value) + if peer_address: + return lambda: self.connect(peer_address) def _config_ssl_ctx(self, verify_mode): SSL_CTX_set_verify(self._ctx.value, verify_mode) @@ -153,7 +167,10 @@ class SSLConnection(object): if self._ca_certs: SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None) if self._ciphers: - SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) + try: + SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) + except openssl_error() as err: + raise_ssl_error(ERR_NO_CIPHER, err) def _copy_server(self): source = self._sock @@ -171,6 +188,7 @@ class SSLConnection(object): new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), BIO_NOCLOSE)) source._ssl = _SSL(SSL_new(self._ctx.value)) + SSL_set_accept_state(source._ssl.value) source._rbio = new_source_rbio source._wbio = new_source_wbio SSL_set_bio(source._ssl.value, @@ -179,6 +197,20 @@ class SSLConnection(object): new_source_rbio.disown() new_source_wbio.disown() + def _reconnect_unwrapped(self): + source = self._sock + self._sock = source._wsock + self._udp_demux = source._demux + self._rsock = source._rsock + self._ctx = source._ctx + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + BIO_dgram_set_peer(self._wbio.value, source._peer_address) + self._ssl = _SSL(SSL_new(self._ctx.value)) + SSL_set_accept_state(self._ssl.value) + if self._do_handshake_on_connect: + return lambda: self.do_handshake() + def _check_nbio(self): BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None) if self._wbio is not self._rbio: @@ -234,15 +266,24 @@ class SSLConnection(object): self._handshake_done = False if isinstance(sock, SSLConnection): - self._copy_server() - elif server_side: - self._init_server() + post_init = self._copy_server() + elif isinstance(sock, _UnwrappedSocket): + post_init = self._reconnect_unwrapped() else: - self._init_client() + try: + peer_address = sock.getpeername() + except socket.error: + peer_address = None + if server_side: + post_init = self._init_server(peer_address) + else: + post_init = self._init_client(peer_address) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) self._rbio.disown() self._wbio.disown() + if post_init: + post_init() def get_socket(self, inbound): """Retrieve a socket used by this connection @@ -305,7 +346,7 @@ class SSLConnection(object): _logger.debug("Invoking DTLSv1_listen for ssl: %d", self._ssl.value._as_parameter) dtls_peer_address = DTLSv1_listen(self._ssl.value) - except OpenSSLError as err: + except openssl_error() as err: if err.ssl_error == SSL_ERROR_WANT_READ: # This method must be called again to forward the next datagram _logger.debug("DTLSv1_listen must be resumed") @@ -345,6 +386,7 @@ class SSLConnection(object): self._cert_reqs, PROTOCOL_DTLSv1, self._ca_certs, self._do_handshake_on_connect, self._suppress_ragged_eofs, self._ciphers) + new_peer = self._pending_peer_address self._pending_peer_address = None if self._do_handshake_on_connect: # Note that since that connection's socket was just created in its @@ -354,7 +396,7 @@ class SSLConnection(object): # will hang in this call new_conn.do_handshake() _logger.debug("Accept returning new connection for new peer") - return new_conn + return new_conn, new_peer def connect(self, peer_address): """Client-side UDP connection establishment @@ -368,6 +410,7 @@ class SSLConnection(object): """ self._sock.connect(peer_address) + peer_address = self._sock.getpeername() # substituted host addrinfo BIO_dgram_set_connected(self._wbio.value, peer_address) assert self._wbio is self._rbio if self._do_handshake_on_connect: @@ -382,7 +425,15 @@ class SSLConnection(object): _logger.debug("Initiating handshake...") self._check_nbio() - SSL_do_handshake(self._ssl.value) + try: + SSL_do_handshake(self._ssl.value) + except openssl_error() as err: + if err.ssl_error == SSL_ERROR_WANT_READ and \ + self.get_socket(True).gettimeout(): + raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err) + elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: + raise_ssl_error(ERR_PORT_UNREACHABLE, err) + raise self._handshake_done = True _logger.debug("...completed handshake") @@ -423,19 +474,30 @@ class SSLConnection(object): it no longer raises continuation request exceptions. """ + if hasattr(self, "_listening"): + # Listening server-side sockets cannot be shut down + return + self._check_nbio() try: SSL_shutdown(self._ssl.value) - except OpenSSLError as err: + except openssl_error() as err: if err.result == 0: # close-notify alert was just sent; wait for same from peer # Note: while it might seem wise to suppress further read-aheads # with SSL_set_read_ahead here, doing so causes a shutdown # failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown - # initiator side. + # initiator side. And test_starttls does pass. SSL_shutdown(self._ssl.value) else: raise + if hasattr(self, "_udp_demux"): + # Return wrapped connected server socket (non-listening) + return _UnwrappedSocket(self._sock, self._rsock, self._udp_demux, + self._ctx, + BIO_dgram_get_peer(self._wbio.value)) + # Return unwrapped client-side socket + return self._sock def getpeercert(self, binary_form=False): """Retrieve the peer's certificate @@ -453,7 +515,7 @@ class SSLConnection(object): try: peer_cert = _X509(SSL_get_peer_certificate(self._ssl.value)) - except OpenSSLError: + except openssl_error(): return if binary_form: @@ -462,6 +524,8 @@ class SSLConnection(object): return {} return decode_cert(peer_cert) + peer_certificate = getpeercert # compatibility with _ssl call interface + def cipher(self): """Retrieve information about the current cipher @@ -478,3 +542,95 @@ class SSLConnection(object): cipher_version = SSL_CIPHER_get_version(current_cipher) cipher_bits = SSL_CIPHER_get_bits(current_cipher) return cipher_name, cipher_version, cipher_bits + + def pending(self): + """Retrieve number of buffered bytes + + Return the number of bytes that have been read from the socket and + buffered by this connection. Return 0 if no bytes have been buffered. + """ + + return SSL_pending(self._ssl.value) + + def get_timeout(self): + """Retrieve the retransmission timedelta + + Since datagrams are subject to packet loss, DTLS will perform + packet retransmission if a response is not received after a certain + time interval during the handshaking phase. When using non-blocking + sockets, the application must call back after that time interval to + allow for the retransmission to occur. This method returns the + timedelta after which to perform the call to handle_timeout, or None + if no such callback is needed given the current handshake state. + """ + + return DTLSv1_get_timeout(self._ssl.value) + + def handle_timeout(self): + """Perform datagram retransmission, if required + + This method should be called after the timedelta retrieved from + get_timeout has expired, and no datagrams were received in the + meantime. If datagrams were received, a new timeout needs to be + requested. + + Return value: + True -- retransmissions were performed successfully + False -- a timeout was not in effect or had not yet expired + + Exceptions: + Raised when retransmissions fail or too many timeouts occur. + """ + + return DTLSv1_handle_timeout(self._ssl.value) + + +class _UnwrappedSocket(socket.socket): + """Unwrapped server-side socket + + Depending on UDP demux implementation, there may not be single socket + that can be used for both reading and writing to the client socket with + which it is associated. An object of this type is therefore returned from + the SSLSocket's unwrap method to allow for unencrypted communication over + the established channels, including the demux. + """ + + def __init__(self, wsock, rsock, demux, ctx, peer_address): + socket.socket.__init__(self, _sock=rsock._sock) + for attr in "send", "sendto", "sendall": + try: + delattr(self, attr) + except AttributeError: + pass + self._wsock = wsock + self._rsock = rsock # continue to reference to hold in demux map + self._demux = demux + self._ctx = ctx + self._peer_address = peer_address + + def send(self, data, flags=0): + __doc__ = self._wsock.send.__doc__ + return self._wsock.sendto(data, flags, self._peer_address) + + def sendto(self, data, flags_or_addr, addr=None): + __doc__ = self._wsock.sendto.__doc__ + return self._wsock.sendto(data, flags_or_addr, addr) + + def sendall(self, data, flags=0): + __doc__ = self._wsock.sendall.__doc__ + amount = len(data) + count = 0 + while (count < amount): + v = self.send(data[count:], flags) + count += v + return amount + + def getpeername(self): + __doc__ = self._wsock.getpeername.__doc__ + return self._peer_address + + def connect(self, addr): + __doc__ = self._wsock.connect.__doc__ + raise ValueError("Cannot connect already connected unwrapped socket") + + connect_ex = connect diff --git a/dtls/test/certs/badcert.pem b/dtls/test/certs/badcert.pem new file mode 100644 index 0000000..c419146 --- /dev/null +++ b/dtls/test/certs/badcert.pem @@ -0,0 +1,36 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L +opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH +fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB +AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU +D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA +IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM +oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0 +ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/ +loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j +oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA +z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq +ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV +q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +Just bad cert data +-----END CERTIFICATE----- +-----BEGIN RSA PRIVATE KEY----- +MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L +opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH +fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB +AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU +D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA +IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM +oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0 +ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/ +loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j +oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA +z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq +ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV +q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +Just bad cert data +-----END CERTIFICATE----- diff --git a/dtls/test/certs/badkey.pem b/dtls/test/certs/badkey.pem new file mode 100644 index 0000000..1c8a955 --- /dev/null +++ b/dtls/test/certs/badkey.pem @@ -0,0 +1,40 @@ +-----BEGIN RSA PRIVATE KEY----- +Bad Key, though the cert should be OK +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD +VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x +IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT +U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1 +NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl +bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m +dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj +aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh +m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8 +M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn +fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC +AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb +08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx +CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/ +iHkC6gGdBJhogs4= +-----END CERTIFICATE----- +-----BEGIN RSA PRIVATE KEY----- +Bad Key, though the cert should be OK +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD +VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x +IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT +U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1 +NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl +bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m +dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj +aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh +m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8 +M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn +fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC +AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb +08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx +CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/ +iHkC6gGdBJhogs4= +-----END CERTIFICATE----- diff --git a/dtls/test/certs/keycert.pem b/dtls/test/certs/keycert.pem new file mode 100644 index 0000000..05ee34c --- /dev/null +++ b/dtls/test/certs/keycert.pem @@ -0,0 +1,21 @@ +-----BEGIN PRIVATE KEY----- +MIIBVAIBADANBgkqhkiG9w0BAQEFAASCAT4wggE6AgEAAkEAuPd3JmydJfXhyii0 +agsVgRMOUcOyuldbaf/Lu4bZ+U0zH0OSoYkv0Ahbz7ehK+oGMeUy/SuGVAn7JLyj +zlYi8QIDAQABAkAygtnV82lC2Y/Mbis+nkJEGlkZuRCQ1JRRMRqI3n2eF6CviqF3 +PiBXIEEExzKihC9bvbHKTAkYDLr+/4YpbiQBAiEA7JLS5Lp7KI/ayWwEzl2r5XXu +k/cbH++A4zZz6A9XIsECIQDIJ8ciDa5/VGyQnYMzBNgKnwaFDDBOiEUFDaU/9ZN8 +MQIgCG3Gw819G9ncQrbtiOi/eiJ0iKMSPVYMMow7HvaE9UECIQCLyQwPwlJd5s4z +aW4ZkYZ4VHuvK8YI8q6RSuhf9Nhd4QIgFbRNdEeehgrzGzGug2yVCMzVzS3MQNBJ +6LqBZaPlFsM= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIBgDCCASoCAQEwDQYJKoZIhvcNAQEEBQAwSjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCldhc2hpbmd0b24xEzARBgNVBAoTClJheSBDQSBJbmMxETAPBgNVBAMTCFJh +eUNBSW5jMB4XDTEyMDkyMTIxMTYxOFoXDTEzMDkyMTIxMTYxOFowTDELMAkGA1UE +BhMCVVMxEzARBgNVBAgTCldhc2hpbmd0b24xFDASBgNVBAoTC1JheSBTcnYgSW5j +MRIwEAYDVQQDEwlSYXlTcnZJbmMwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAuPd3 +JmydJfXhyii0agsVgRMOUcOyuldbaf/Lu4bZ+U0zH0OSoYkv0Ahbz7ehK+oGMeUy +/SuGVAn7JLyjzlYi8QIDAQABMA0GCSqGSIb3DQEBBAUAA0EAEkxVF8HEGV8N4mYA +hDciYpttnnb9pYL1okHGrhaIFqu9D10LfP1SKps/6s/qNSk3YaIVjydWOHEf6xr4 +zJkiFw== +-----END CERTIFICATE----- diff --git a/dtls/test/certs/nullcert.pem b/dtls/test/certs/nullcert.pem new file mode 100644 index 0000000..e69de29 diff --git a/dtls/test/certs/wrongcert.pem b/dtls/test/certs/wrongcert.pem new file mode 100644 index 0000000..5f92f9b --- /dev/null +++ b/dtls/test/certs/wrongcert.pem @@ -0,0 +1,32 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH +FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T +f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB +AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq +1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW +7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg +SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe +19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg +ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666 +lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs +ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv +frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk +2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB +gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6 ++bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK +24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G +A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9 +OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt +U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ +fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E +usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+ +43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw +eSHj5jpC8iZKjCHBn+mAi4cQ514= +-----END CERTIFICATE----- diff --git a/dtls/test/echo_seq.py b/dtls/test/echo_seq.py index 99bb9cb..7a3bf3b 100644 --- a/dtls/test/echo_seq.py +++ b/dtls/test/echo_seq.py @@ -45,7 +45,7 @@ def main(): break print "Accepting..." - conn = scn.accept() + conn = scn.accept()[0] sck.settimeout(5) conn.get_socket(True).settimeout(5) @@ -59,7 +59,7 @@ def main(): try: conn.do_handshake() except SSLError as err: - if err.args[0] == SSL_ERROR_WANT_READ: + if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ: continue raise print "Completed handshaking with peer" diff --git a/dtls/test/unit.py b/dtls/test/unit.py new file mode 100644 index 0000000..c8d7c96 --- /dev/null +++ b/dtls/test/unit.py @@ -0,0 +1,1313 @@ +# Test the support for DTLS through the SSL module. Adapted from the Python +# standard library's test_ssl.py regression test module by Ray Brown. + +import sys +import unittest +import asyncore +import socket +import select +import gc +import os +import errno +import pprint +import urllib, urlparse +import traceback +import weakref +import platform +import threading +import datetime +import SocketServer +from SimpleHTTPServer import SimpleHTTPRequestHandler + +import ssl +from dtls import do_patch + +HOST = "localhost" + +class TestSupport(object): + verbose = True + + class Ctx(object): + def __enter__(self): + self.server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + self.server.start(flag) + flag.wait() + return self.server.sockname + + def __exit__(self, exc_type, exc_value, traceback): + self.server.stop() + self.server.join() + self.server = None + + def transient_internet(self): + return self.Ctx() + +test_support = TestSupport() + +def handle_error(prefix): + exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) + if test_support.verbose: + sys.stdout.write(prefix + exc_format) + + +class BasicTests(unittest.TestCase): + + def test_sslwrap_simple(self): + # A crude test for the legacy API + try: + ssl.sslwrap_simple(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) + except IOError, e: + if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that + pass + else: + raise + try: + ssl.sslwrap_simple(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)._sock) + except IOError, e: + if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that + pass + else: + raise + + +class BasicSocketTests(unittest.TestCase): + + def test_constants(self): + ssl.PROTOCOL_SSLv23 + ssl.PROTOCOL_SSLv3 + ssl.PROTOCOL_TLSv1 + ssl.PROTOCOL_DTLSv1 # added + ssl.CERT_NONE + ssl.CERT_OPTIONAL + ssl.CERT_REQUIRED + + def test_dtls_openssl_version(self): + n = ssl.DTLS_OPENSSL_VERSION_NUMBER + t = ssl.DTLS_OPENSSL_VERSION_INFO + s = ssl.DTLS_OPENSSL_VERSION + self.assertIsInstance(n, (int, long)) + self.assertIsInstance(t, tuple) + self.assertIsInstance(s, str) + # Some sanity checks follow + # >= 1.0 + self.assertGreaterEqual(n, 0x10000000) + # < 2.0 + self.assertLess(n, 0x20000000) + major, minor, fix, patch, status = t + self.assertGreaterEqual(major, 1) + self.assertLess(major, 2) + self.assertGreaterEqual(minor, 0) + self.assertLess(minor, 256) + self.assertGreaterEqual(fix, 0) + self.assertLess(fix, 256) + self.assertGreaterEqual(patch, 0) + self.assertLessEqual(patch, 26) + self.assertGreaterEqual(status, 0) + self.assertLessEqual(status, 15) + # Version string as returned by OpenSSL, the format might change + self.assertTrue( + s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), + (s, t)) + + def test_ciphers(self): + server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + server.start(flag) + flag.wait() + remote = (HOST, server.port) + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE, ciphers="ALL") + s.connect(remote) + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") + s.connect(remote) + # Error checking occurs when connecting, because the SSL context + # isn't created before. + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE, + ciphers="^$:,;?*'dorothyx") + with self.assertRaisesRegexp(ssl.SSLError, + "No cipher can be selected"): + s.connect(remote) + finally: + server.stop() + server.join() + # repeat with AF_INET6? + + @unittest.skipIf(platform.python_implementation() != "CPython", + "Reference cycle test feasible under CPython only") + def test_refcycle(self): + # Issue #7943: an SSL object doesn't create reference cycles with + # itself. + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + ss = ssl.wrap_socket(s) + wr = weakref.ref(ss) + del ss + self.assertEqual(wr(), None) + + def test_wrapped_unconnected(self): + # The _delegate_methods in socket.py are correctly delegated to by an + # unconnected SSLSocket, so they will raise a socket.error rather than + # something unexpected like TypeError. + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + ss = ssl.wrap_socket(s) + self.assertRaises(socket.error, ss.recv, 1) + self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) + self.assertRaises(socket.error, ss.recvfrom, 1) + self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(socket.error, ss.send, b'x') + self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + + +class NetworkedTests(unittest.TestCase): + + def test_connect(self): + with test_support.transient_internet() as remote: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE) + s.connect(remote) + c = s.getpeercert() + if c: + self.fail("Peer cert %s shouldn't be here!") + s.close() + + # this should fail because we have no verification certs + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED) + try: + s.connect(remote) + except ssl.SSLError: + pass + finally: + s.close() + + # this should succeed because we specify the root cert + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=ISSUER_CERTFILE) + try: + s.connect(remote) + finally: + s.close() + + def test_connect_ex(self): + # Issue #11326: check connect_ex() implementation + with test_support.transient_internet() as remote: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=ISSUER_CERTFILE) + try: + self.assertEqual(0, s.connect_ex(remote)) + self.assertTrue(s.getpeercert()) + finally: + s.close() + + def test_non_blocking_connect_ex(self): + # Issue #11326: non-blocking connect_ex() should allow handshake + # to proceed after the socket gets ready. + with test_support.transient_internet() as remote: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=ISSUER_CERTFILE, + do_handshake_on_connect=False) + try: + s.setblocking(False) + rc = s.connect_ex(remote) + # EWOULDBLOCK under Windows, EINPROGRESS elsewhere + self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK)) + # Non-blocking handshake + while True: + try: + s.do_handshake() + break + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + while True: + to = s.get_timeout() + to = to.total_seconds() if to else 5.0 + sel = select.select([s], [], [], to) + if sel[0]: + break + s.handle_timeout() + else: + raise + # SSL established + self.assertTrue(s.getpeercert()) + finally: + s.close() + + @unittest.skipIf(os.name == "nt", + "Can't use a socket as a file under Windows") + def test_makefile_close(self): + # Issue #5238: creating a file-like object with makefile() shouldn't + # delay closing the underlying "real socket" (here tested with its + # file descriptor, hence skipping the test under Windows). + with test_support.transient_internet() as remote: + ss = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + ss.connect(remote) + fd = ss.fileno() + f = ss.makefile() + f.close() + # The fd is still open + os.read(fd, 0) + # Closing the SSL socket should close the fd too + ss.close() + gc.collect() + with self.assertRaises(OSError) as e: + os.read(fd, 0) + self.assertEqual(e.exception.errno, errno.EBADF) + + def test_non_blocking_handshake(self): + with test_support.transient_internet() as remote: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(remote) + s.setblocking(False) + s = ssl.wrap_socket(s, + cert_reqs=ssl.CERT_NONE, + do_handshake_on_connect=False) + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLError, err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + while True: + to = s.get_timeout() + if to: + sel = select.select([s], [], [], + to.total_seconds()) + if sel[0]: + break + s.handle_timeout() + continue + select.select([s], [], []) + break + else: + raise + s.close() + if test_support.verbose: + sys.stdout.write(("\nNeeded %d calls to do_handshake() " + + "to establish session.\n") % count) + + def test_get_server_certificate(self): + with test_support.transient_internet() as remote: + pem = ssl.get_server_certificate(remote, ssl.PROTOCOL_DTLSv1) + if not pem: + self.fail("No server certificate!") + + try: + pem = ssl.get_server_certificate(remote, + ssl.PROTOCOL_DTLSv1, + ca_certs=OTHER_CERTFILE) + except ssl.SSLError: + #should fail + pass + else: + self.fail("Got server certificate %s!" % pem) + + pem = ssl.get_server_certificate(remote, + ssl.PROTOCOL_DTLSv1, + ca_certs=ISSUER_CERTFILE) + if not pem: + self.fail("No server certificate!") + if test_support.verbose: + sys.stdout.write("\nVerified certificate is\n%s\n" % pem) + +class ThreadedEchoServer(threading.Thread): + + class ConnectionHandler(threading.Thread): + + """A mildly complicated class, because we want it to work both + with and without the SSL wrapper around the socket connection, so + that we can test the STARTTLS functionality.""" + + def __init__(self, server, connsock): + self.server = server + self.running = False + self.sock = connsock + self.sock.setblocking(True) + self.sslconn = connsock + threading.Thread.__init__(self) + self.daemon = True + + def show_conn_details(self): + if self.server.certreqs == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" client cert is " + + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if test_support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + + str(cipher) + "\n") + + def wrap_conn(self): + try: + self.sslconn = ssl.wrap_socket( + self.sock, server_side=True, + certfile=self.server.certificate, + ssl_version=self.server.protocol, + ca_certs=self.server.cacerts, + cert_reqs=self.server.certreqs, + ciphers=self.server.ciphers) + except ssl.SSLError: + # XXX Various errors can have happened here, for example + # a mismatching protocol version, an invalid certificate, + # or a low-level bug. This should be made more + # discriminating. + if self.server.chatty: + handle_error("\n server: bad connection attempt " + + "from " + + str(self.sock.getpeername()) + ":\n") + self.close() + self.running = False + self.server.stop() + return False + else: + return True + + def read(self): + if self.sslconn: + return self.sslconn.read() + else: + return self.sock.recv(1024) + + def write(self, bytes): + if self.sslconn: + return self.sslconn.write(bytes) + else: + return self.sock.send(bytes) + + def close(self): + if self.sslconn: + self.sslconn.close() + else: + self.sock._sock.close() + + def run(self): + self.running = True + # Complete the handshake + try: + self.sock.do_handshake() + except ssl.SSLError: + if self.server.chatty: + handle_error("\n server: failed to handshake with " + + str(self.sock.getpeername()) + ":\n") + self.close() + self.running = False + self.server.stop() + return + if self.server.starttls_server: + self.sock = self.sock.unwrap() + self.sslconn = None + else: + self.show_conn_details() + while self.running: + try: + msg = self.read() + if not msg: + # eof, so quit this handler + self.running = False + self.close() + elif msg.strip() == 'over': + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: client closed " + + "connection\n") + self.close() + return + elif self.server.starttls_server and not self.sslconn \ + and msg.strip() == 'STARTTLS': + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: read STARTTLS " + + "from client, sending OK...\n") + self.write("OK\n") + if not self.wrap_conn(): + return + elif self.server.starttls_server and self.sslconn and \ + msg.strip() == 'ENDTLS': + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: read ENDTLS from " + + "client, sending OK...\n") + self.write("OK\n") + self.sslconn.unwrap() + self.sslconn = None + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: connection is now " + + "unencrypted...\n") + else: + if test_support.verbose and \ + self.server.connectionchatty: + ctype = (self.sslconn and "encrypted") or \ + "unencrypted" + sys.stdout.write((" server: read %s (%s), " + + "sending back %s (%s)...\n") + % (repr(msg), ctype, + repr(msg.lower()), ctype)) + self.write(msg.lower()) + except ssl.SSLError: + if self.server.chatty: + handle_error("Test server failure:\n") + self.close() + self.running = False + # normally, we'd just stop here, but for the test + # harness, we want to stop the server + self.server.stop() + + def __init__(self, certificate, ssl_version=None, + certreqs=None, cacerts=None, + chatty=True, connectionchatty=False, starttls_server=False, + ciphers=None): + + if ssl_version is None: + ssl_version = ssl.PROTOCOL_DTLSv1 + if certreqs is None: + certreqs = ssl.CERT_NONE + self.certificate = certificate + self.protocol = ssl_version + self.certreqs = certreqs + self.cacerts = cacerts + self.ciphers = ciphers + self.chatty = chatty + self.connectionchatty = connectionchatty + self.starttls_server = starttls_server + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.flag = None + self.sock = ssl.wrap_socket(self.sock, server_side=True, + certfile=self.certificate, + cert_reqs=self.certreqs, + ca_certs=self.cacerts, + ssl_version=self.protocol, + do_handshake_on_connect=False, + ciphers=self.ciphers) + if test_support.verbose and self.chatty: + sys.stdout.write(' server: wrapped server ' + + 'socket as %s\n' % str(self.sock)) + self.sock.bind((HOST, 0)) + self.port = self.sock.getsockname()[1] + self.active = False + threading.Thread.__init__(self) + self.daemon = True + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.sock.settimeout(0.05) + self.sock.listen(5) + self.active = True + if self.flag: + # signal an event + self.flag.set() + while self.active: + try: + acc_ret = self.sock.accept() + if acc_ret: + newconn, connaddr = acc_ret + if test_support.verbose and self.chatty: + sys.stdout.write(' server: new connection from ' + + str(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn) + handler.start() + except socket.timeout: + pass + except KeyboardInterrupt: + self.stop() + self.sock.close() + + def stop(self): + self.active = False + +class AsyncoreEchoServer(threading.Thread): + + class EchoServer(asyncore.dispatcher): + + class ConnectionHandler(asyncore.dispatcher): + + def __init__(self, conn, timeout_tracker): + asyncore.dispatcher.__init__(self, conn) + self._timeout_tracker = timeout_tracker + self._ssl_accepting = True + # Complete the handshake + self.handle_read_event() + + def readable(self): + while self.socket.pending() > 0: + self.handle_read_event() + if self._timeout_tracker.has_key(self) and \ + datetime.datetime.now() >= self._timeout_tracker[self]: + self._timeout_tracker.pop(self) + try: + self.socket.handle_timeout() + except: + self.handle_close() + return False + return True + + def writable(self): + return False + + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError, err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SSL): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + raise + except socket.error, err: + if err.args[0] == errno.ECONNABORTED: + return self.handle_close() + else: + self._ssl_accepting = False + + def handle_read(self): + if self._ssl_accepting: + self._do_ssl_handshake() + else: + data = self.recv(1024) + if data and data.strip() != 'over': + self.send(data.lower()) + delta = self.socket.get_timeout() + if delta: + self._timeout_tracker[self] = \ + datetime.datetime.now() + delta + + def handle_close(self): + if self._timeout_tracker.has_key(self): + self._timeout_tracker.pop(self) + self.close() + if test_support.verbose: + sys.stdout.write(" server: closed connection %s\n" % + self.socket) + + def handle_error(self): + raise + + def __init__(self, certfile, timeout_tracker): + asyncore.dispatcher.__init__(self) + self._timeout_tracker = timeout_tracker + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind((HOST, 0)) + self.sockname = sock.getsockname() + self.port = self.sockname[1] + self.set_socket(ssl.wrap_socket(sock, server_side=True, + certfile=certfile, + do_handshake_on_connect=False)) + self.listen(5) + + def writable(self): + return False + + def handle_accept(self): + acc_ret = self.accept() + if acc_ret: + sock_obj, addr = acc_ret + if test_support.verbose: + sys.stdout.write(" server: new connection from " + + "%s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self._timeout_tracker) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.flag = None + self.active = False + self.timeout_tracker = {} + self.server = self.EchoServer(certfile, self.timeout_tracker) + self.sockname = self.server.sockname + self.port = self.server.port + threading.Thread.__init__(self) + self.daemon = True + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + now = datetime.datetime.now() + future_timeouts = filter(lambda x: x > now, + self.timeout_tracker.values()) + future_timeouts.append(now + datetime.timedelta(seconds=0.05)) + first_timeout = min(future_timeouts) - now + asyncore.loop(first_timeout.total_seconds(), count=1) + + def stop(self): + self.active = False + self.server.close() + +# Note that this HTTP-over-UDP server does not implement packet recovery and +# reordering, but it's good enough for testing on a loopback interface +class SocketServerHTTPSServer(threading.Thread): + + class HTTPSServerUDP(SocketServer.ThreadingTCPServer): + + def __init__(self, server_address, RequestHandlerClass, certfile): + SocketServer.ThreadingTCPServer.__init__(self, server_address, + RequestHandlerClass, False) + # account for dealing with a datagram socket + self.socket = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + server_side=True, + certfile=certfile, + do_handshake_on_connect=False) + self.server_bind() + self.server_activate() + + def __str__(self): + return ('<%s %s:%s>' % + (self.__class__.__name__, + self.server_name, + self.server_port)) + + def server_bind(self): + """Override server_bind to store the server name.""" + SocketServer.ThreadingTCPServer.server_bind(self) + host, port = self.socket.getsockname()[:2] + self.server_name = socket.getfqdn(host) + self.server_port = port + + def get_request(self): + # account for the fact that accept can return nothing, and + # according to BaseServer documentation, we should not block here + acc_ret = self.socket.accept() + if not acc_ret: + raise socket.error("No new connection") + return acc_ret + + def shutdown_request(self, request): + # Notify client of termination + request.unwrap() + + class RootedHTTPRequestHandler(SimpleHTTPRequestHandler): + # need to override translate_path to get a known root, + # instead of using os.curdir, since the test could be + # run from anywhere + + server_version = "TestHTTPS-UDP/1.0" + + root = None + + def translate_path(self, path): + """Translate a /-separated PATH to the local filename syntax. + + Components that mean special things to the local file system + (e.g. drive or directory names) are ignored. (XXX They should + probably be diagnosed.) + + """ + # abandon query parameters + path = urlparse.urlparse(path)[2] + path = os.path.normpath(urllib.unquote(path)) + words = path.split('/') + words = filter(None, words) + path = self.root + for word in words: + drive, word = os.path.splitdrive(word) + head, word = os.path.split(word) + if word in self.root: continue + path = os.path.join(path, word) + return path + + def log_message(self, format, *args): + # we override this to suppress logging unless "verbose" + if test_support.verbose: + sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % + (self.server.server_address, + self.server.server_port, + self.request.cipher(), + self.log_date_time_string(), + format%args)) + + + def __init__(self, certfile): + self.flag = None + self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0] + self.server = self.HTTPSServerUDP( + (HOST, 0), self.RootedHTTPRequestHandler, certfile) + self.port = self.server.server_port + threading.Thread.__init__(self) + self.daemon = True + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + if self.flag: + self.flag.set() + self.server.serve_forever(0.05) + + def stop(self): + self.server.shutdown() + + +def bad_cert_test(certfile): + """ + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with the given client certificate fails. + """ + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_REQUIRED, + cacerts=ISSUER_CERTFILE, chatty=False) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + certfile=certfile, + ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + except ssl.SSLError, x: + if test_support.verbose: + sys.stdout.write("\nSSLError is %s\n" % x[1]) + except socket.error, x: + if test_support.verbose: + sys.stdout.write("\nsocket.error is %s\n" % x[1]) + else: + raise AssertionError("Use of invalid cert should have failed!") + finally: + server.stop() + server.join() + +def server_params_test(certfile, protocol, certreqs, cacertsfile, + client_certfile, client_protocol=None, + indata="FOO\n", ciphers=None, chatty=True, + connectionchatty=False): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ + server = ThreadedEchoServer(certfile, + certreqs=certreqs, + ssl_version=protocol, + cacerts=cacertsfile, + ciphers=ciphers, + chatty=chatty, + connectionchatty=connectionchatty) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + if client_protocol is None: + client_protocol = protocol + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + certfile=client_certfile, + ca_certs=cacertsfile, + ciphers=ciphers, + cert_reqs=certreqs, + ssl_version=client_protocol) + s.connect((HOST, server.port)) + for arg in [indata, bytearray(indata), memoryview(indata)]: + if connectionchatty: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(arg))) + s.write(arg) + outdata = s.read() + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise AssertionError( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + server.join() + +def try_protocol_combo(server_protocol, + client_protocol, + expect_success, + certsreqs=None): + if certsreqs is None: + certsreqs = ssl.CERT_NONE + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] + if test_support.verbose: + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" + sys.stdout.write(formatstr % + (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol), + certtype)) + try: + # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client + # will send an SSLv3 hello (rather than SSLv2) starting from + # OpenSSL 1.0.0 (see issue #8322). + server_params_test(CERTFILE, server_protocol, certsreqs, + ISSUER_CERTFILE, CERTFILE, client_protocol, + ciphers="ALL", chatty=False) + # Protocol mismatch can result in either an SSLError, or a + # "Connection reset by peer" error. + except ssl.SSLError: + if expect_success: + raise + except socket.error as e: + if expect_success or e.errno != errno.ECONNRESET: + raise + else: + if not expect_success: + raise AssertionError( + "Client protocol %s succeeded with server protocol %s!" + % (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol))) + + +class ThreadedTests(unittest.TestCase): + + def test_unreachable(self): + server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + server.bind((HOST, 0)) + port = server.getsockname()[1] + server.close() + s = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) + self.assertRaisesRegexp(ssl.SSLError, + "The peer address is not reachable", + s.connect, (HOST, port)) + + def test_echo(self): + """Basic test of an SSL client connecting to a server""" + if test_support.verbose: + sys.stdout.write("\n") + server_params_test(CERTFILE, ssl.PROTOCOL_DTLSv1, ssl.CERT_NONE, + CERTFILE, CERTFILE, ssl.PROTOCOL_DTLSv1, + chatty=True, connectionchatty=True) + + def test_getpeercert(self): + if test_support.verbose: + sys.stdout.write("\n") + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_DTLSv1, + cacerts=CERTFILE, + chatty=False) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + certfile=CERTFILE, + ca_certs=ISSUER_CERTFILE, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() + if test_support.verbose: + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Ray Srv Inc'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in " + "certificate subject; should be 'Ray Srv Inc'.") + s.close() + finally: + server.stop() + server.join() + + def test_empty_cert(self): + """Connecting with an empty cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "nullcert.pem")) + def test_malformed_cert(self): + """Connecting with a badly formatted certificate (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "badcert.pem")) + def test_nonexisting_cert(self): + """Connecting with a non-existing cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "wrongcert.pem")) + def test_malformed_key(self): + """Connecting with a badly formatted key (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "badkey.pem")) + + def test_protocol_dtlsv1(self): + """Connecting to a DTLSv1 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_REQUIRED) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", + "msg 5", "msg 6") + + server = ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_DTLSv1, + starttls_server=True, + chatty=True, + connectionchatty=True) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + wrapped = False + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + s.connect((HOST, server.port)) + s = s.unwrap() + if test_support.verbose: + sys.stdout.write("\n") + for indata in msgs: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % repr(indata)) + if wrapped: + conn.write(indata) + outdata = conn.read() + else: + s.send(indata) + outdata = s.recv(1024) + if (indata == "STARTTLS" and + outdata.strip().lower().startswith("ok")): + # STARTTLS ok, switch to secure mode + if test_support.verbose: + sys.stdout.write( + " client: read %s from server, starting TLS...\n" + % repr(outdata)) + conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_DTLSv1) + wrapped = True + elif (indata == "ENDTLS" and + outdata.strip().lower().startswith("ok")): + # ENDTLS ok, switch back to clear text + if test_support.verbose: + sys.stdout.write( + " client: read %s from server, ending TLS...\n" + % repr(outdata)) + s = conn.unwrap() + wrapped = False + else: + if test_support.verbose: + sys.stdout.write( + " client: read %s from server\n" % repr(outdata)) + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + if wrapped: + conn.write("over\n") + else: + s.send("over\n") + s.close() + finally: + server.stop() + server.join() + + def test_socketserver(self): + """Using a SocketServer to create and manage SSL connections.""" + server = SocketServerHTTPSServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + if test_support.verbose: + sys.stdout.write('\n') + with open(CERTFILE, 'rb') as f: + d1 = f.read() + d2 = [] + # now fetch the same data from the HTTPS-UDP server + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + s.connect((HOST, server.port)) + fl = "/" + os.path.split(CERTFILE)[1] + s.write("GET " + fl + " HTTP/1.1\r\n" + + "Host: " + HOST + "\r\n\r\n") + content = False + last_buf = "" + while True: + try: + buf = last_buf + s.read() + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_ZERO_RETURN: + s = s.unwrap() # complete shutdown protocol with server + break + raise + if test_support.verbose: + sys.stdout.write( + " client: read %d bytes from remote server '%s'\n" + % (len(buf), server)) + if content: + d2.append(buf) + continue + ind = buf.find("\r\n\r\n") + if ind < 0: + last_buf = buf[-3:] # find double-newline across buffers + continue + d2.append(buf[ind + 4:]) + content = True + last_buf = "" + s.close() + self.assertEqual(d1, ''.join(d2)) + finally: + server.stop() + server.join() + + def test_asyncore_server(self): + """Check the example asyncore integration.""" + indata = "TEST MESSAGE of mixed case\n" + + if test_support.verbose: + sys.stdout.write("\n") + server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + s.connect((HOST, server.port)) + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + self.fail( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + # wait for server thread to end + server.join() + + def test_recv_send(self): + """Test recv(), send() and friends.""" + if test_support.verbose: + sys.stdout.write("\n") + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + s = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + try: + # helper methods for standardising recv* method signatures + def _recv_into(): + b = bytearray("\0"*100) + count = s.recv_into(b) + return b[:count] + + def _recvfrom_into(): + b = bytearray("\0"*100) + count, addr = s.recvfrom_into(b) + return b[:count] + + # (name, method, whether to expect success, *args) + send_methods = [ + ('send', s.send, True, []), + ('sendto', s.sendto, False, ["some.address"]), + ('sendall', s.sendall, True, []), + ] + recv_methods = [ + ('recv', s.recv, True, []), + ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recv_into', _recv_into, True, []), + ('recvfrom_into', _recvfrom_into, False, []), + ] + data_prefix = u"PREFIX_" + + for meth_name, send_meth, expect_success, args in send_methods: + indata = data_prefix + meth_name + try: + send_meth(indata.encode('ASCII', 'strict'), *args) + outdata = s.read() + outdata = outdata.decode('ASCII', 'strict') + if outdata != indata.lower(): + self.fail( + "While sending with <<%s>> bad data " + "<<%r>> (%d) received; " + "expected <<%r>> (%d)\n" % ( + meth_name, outdata[:20], len(outdata), + indata[:20], len(indata) + ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to send with method <<%s>>; " + "expected to succeed.\n" % (meth_name,) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<%s>> failed with unexpected " + "exception message: %s\n" % ( + meth_name, e + ) + ) + + for meth_name, recv_meth, expect_success, args in recv_methods: + indata = data_prefix + meth_name + try: + s.send(indata.encode('ASCII', 'strict')) + outdata = recv_meth(*args) + outdata = outdata.decode('ASCII', 'strict') + if outdata != indata.lower(): + self.fail( + "While receiving with <<%s>> bad data " + "<<%r>> (%d) received; " + "expected <<%r>> (%d)\n" % ( + meth_name, outdata[:20], len(outdata), + indata[:20], len(indata) + ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to receive with method <<%s>>; " + "expected to succeed.\n" % (meth_name,) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<%s>> failed with unexpected " + "exception message: %s\n" % ( + meth_name, e + ) + ) + # consume data + s.read() + + s.write("over\n".encode("ASCII", "strict")) + s.close() + finally: + server.stop() + server.join() + + def test_handshake_timeout(self): + # Issue #5103: SSL handshake must respect the socket timeout + server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + server.bind((HOST, 0)) + port = server.getsockname()[1] + + try: + try: + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(0.2) + c.connect((HOST, port)) + # Will attempt handshake and time out + self.assertRaisesRegexp(ssl.SSLError, "timed out", + ssl.wrap_socket, c) + finally: + c.close() + try: + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(0.2) + c = ssl.wrap_socket(c) + # Will attempt handshake and time out + self.assertRaisesRegexp(ssl.SSLError, "timed out", + c.connect, (HOST, port)) + finally: + c.close() + finally: + server.close() + + +def test_main(verbose=True): + global CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE + CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "keycert.pem") + ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "ca-cert.pem") + OTHER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "yahoo-cert.pem") + + for fl in CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE: + if not os.path.exists(fl): + raise Exception("Can't read certificate files!") + + TestSupport.verbose = verbose + do_patch() + unittest.main() + +if __name__ == "__main__": + verbose = True if len(sys.argv) > 1 and sys.argv[1] == "-v" else False + test_main(verbose)