diff --git a/dtls/__init__.py b/dtls/__init__.py index d110877..93f6688 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -1,11 +1,19 @@ # PyDTLS: datagram TLS for Python. Written by Ray Brown. """PyDTLS package -This package exports OpenSSL's DTLS support to Python. Importing it will add -the constant PROTOCOL_DTLSv1 to the Python standard library's ssl module. -Passing a datagram socket to that module's wrap_socket function (or -instantiating its SSLSocket class with a datagram socket) will activate this -module's DTLS implementation for the returned SSLSocket instance. +This package exports OpenSSL's DTLS support to Python. Calling its patch +function will add the constant PROTOCOL_DTLSv1 to the Python standard library's +ssl module. Subsequently passing a datagram socket to that module's +wrap_socket function (or instantiating its SSLSocket class with a datagram +socket) will activate this module's DTLS implementation for the returned +SSLSocket instance. + +Instead of or in addition to invoking the patch functionality, the +SSLConnection class can be used directly for secure communication over datagram +sockets. wrap_socket's parameters and their semantics have been maintained. """ + +from patch import do_patch +from sslconnection import SSLConnection diff --git a/dtls/err.py b/dtls/err.py index 2e3fcb5..9dbe563 100644 --- a/dtls/err.py +++ b/dtls/err.py @@ -25,7 +25,9 @@ SSL_ERROR_WANT_ACCEPT = 8 ERR_BOTH_KEY_CERT_FILES = 500 ERR_BOTH_KEY_CERT_FILES_SVR = 298 ERR_NO_CERTS = 331 - +ERR_NO_CIPHER = 501 +ERR_HANDSHAKE_TIMEOUT = 502 +ERR_PORT_UNREACHABLE = 503 ERR_COOKIE_MISMATCH = 0x1408A134 @@ -35,27 +37,48 @@ class SSLError(socket_error): super(SSLError, self).__init__(*args) -class OpenSSLError(SSLError): - """This exception is raised when an error occurs in the OpenSSL library""" - def __init__(self, ssl_error, errqueue, result, func, args): - self.ssl_error = ssl_error - self.errqueue = errqueue - self.result = result - self.func = func - self.args = args - super(OpenSSLError, self).__init__(ssl_error, errqueue, - result, func, args) - - class InvalidSocketError(Exception): """There is a problem with a socket passed to the dtls package.""" def __init__(self, *args): super(InvalidSocketError, self).__init__(*args) -def raise_ssl_error(code): +def _make_opensslerror_class(): + global _OpenSSLError + class __OpenSSLError(SSLError): + """ + This exception is raised when an error occurs in the OpenSSL library + """ + def __init__(self, ssl_error, errqueue, result, func, args): + self.ssl_error = ssl_error + self.errqueue = errqueue + self.result = result + self.func = func + self.args = args + SSLError.__init__(self, ssl_error, errqueue, + result, func, args) + + _OpenSSLError = __OpenSSLError + +_make_opensslerror_class() + +def openssl_error(): + """Return the OpenSSL error type for use in exception clauses""" + return _OpenSSLError + +def raise_as_ssl_module_error(): + """Exceptions raised from this module are instances of ssl.SSLError""" + import ssl + global SSLError + SSLError = ssl.SSLError + _make_opensslerror_class() + +def raise_ssl_error(code, nested=None): """Raise an SSL error with the given error code""" - raise SSLError(str(code) + ": " + _ssl_errors[code]) + err_string = str(code) + ": " + _ssl_errors[code] + if nested: + raise SSLError(err_string, nested) + raise SSLError(err_string) _ssl_errors = { ERR_NO_CERTS: "No root certificates specified for verification " + \ @@ -63,5 +86,8 @@ _ssl_errors = { ERR_BOTH_KEY_CERT_FILES: "Both the key & certificate files " + \ "must be specified", ERR_BOTH_KEY_CERT_FILES_SVR: "Both the key & certificate files must be " + \ - "specified for server-side operation" + "specified for server-side operation", + ERR_NO_CIPHER: "No cipher can be selected.", + ERR_HANDSHAKE_TIMEOUT: "The handshake operation timed out", + ERR_PORT_UNREACHABLE: "The peer address is not reachable", } diff --git a/dtls/openssl.py b/dtls/openssl.py index 222d199..455bed4 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -20,7 +20,8 @@ import array import socket from logging import getLogger from os import path -from err import OpenSSLError +from datetime import timedelta +from err import openssl_error from err import SSL_ERROR_NONE from util import _BIO import ctypes @@ -64,6 +65,7 @@ else: # BIO_NOCLOSE = 0x00 BIO_CLOSE = 0x01 +SSLEAY_VERSION = 0 SSL_VERIFY_NONE = 0x00 SSL_VERIFY_PEER = 0x01 SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02 @@ -91,6 +93,8 @@ BIO_CTRL_DGRAM_SET_CONNECTED = 32 BIO_CTRL_DGRAM_GET_PEER = 46 BIO_CTRL_DGRAM_SET_PEER = 44 BIO_C_SET_NBIO = 102 +DTLS_CTRL_GET_TIMEOUT = 73 +DTLS_CTRL_HANDLE_TIMEOUT = 74 DTLS_CTRL_LISTEN = 75 X509_NAME_MAXLEN = 256 GETS_MAXLEN = 2048 @@ -248,6 +252,11 @@ class X509V3_EXT_METHOD(Structure): ("i2d", c_int)] # remaining fields omitted +class TIMEVAL(Structure): + _fields_ = [("tv_sec", c_long), + ("tv_usec", c_long)] + + # # Socket address conversions # @@ -366,7 +375,7 @@ def raise_ssl_error(result, func, args, ssl): _logger.debug("SSL error raised: ssl_error: %d, result: %d, " + "errqueue: %s, func_name: %s", ssl_error, result, errqueue, func.func_name) - raise OpenSSLError(ssl_error, errqueue, result, func, args) + raise openssl_error()(ssl_error, errqueue, result, func, args) def find_ssl_arg(args): for arg in args: @@ -424,6 +433,7 @@ def _make_function(name, lib, args, export=True, errcheck="default"): _subst = {c_long_parm: c_long} _sigs = {} __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", + "SSLEAY_VERSION", "SSL_VERIFY_NONE", "SSL_VERIFY_PEER", "SSL_VERIFY_FAIL_IF_NO_PEER_CERT", "SSL_VERIFY_CLIENT_ONCE", "SSL_SESS_CACHE_OFF", "SSL_SESS_CACHE_CLIENT", @@ -432,6 +442,7 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", "SSL_SESS_CACHE_NO_INTERNAL_STORE", "SSL_SESS_CACHE_NO_INTERNAL", "SSL_FILE_TYPE_PEM", "GEN_DIRNAME", "NID_subject_alt_name", + "DTLSv1_get_timeout", "DTLSv1_handle_timeout", "DTLSv1_listen", "BIO_gets", "BIO_read", "BIO_get_mem_data", "BIO_dgram_set_connected", @@ -450,6 +461,8 @@ __all__ = ["BIO_NOCLOSE", "BIO_CLOSE", map(lambda x: _make_function(*x), ( ("SSL_library_init", libssl, ((c_int, "ret"),)), ("SSL_load_error_strings", libssl, ((None, "ret"),)), + ("SSLeay", libcrypto, ((c_long_parm, "ret"),)), + ("SSLeay_version", libcrypto, ((c_char_p, "ret"), (c_int, "t"))), ("DTLSv1_server_method", libssl, ((DTLSv1Method, "ret"),)), ("DTLSv1_client_method", libssl, ((DTLSv1Method, "ret"),)), ("SSL_CTX_new", libssl, ((SSLCTX, "ret"), (DTLSv1Method, "meth"))), @@ -514,6 +527,7 @@ map(lambda x: _make_function(*x), ( ((c_int, "ret"), (SSL, "ssl"), (c_void_p, "buf"), (c_int, "num")), False), ("SSL_write", libssl, ((c_int, "ret"), (SSL, "ssl"), (c_void_p, "buf"), (c_int, "num")), False), + ("SSL_pending", libssl, ((c_int, "ret"), (SSL, "ssl")), True, None), ("SSL_shutdown", libssl, ((c_int, "ret"), (SSL, "ssl"))), ("SSL_set_read_ahead", libssl, ((None, "ret"), (SSL, "ssl"), (c_int, "yes"))), @@ -630,6 +644,28 @@ def BIO_dgram_set_peer(bio, peer_address): def BIO_set_nbio(bio, n): _BIO_ctrl(bio, BIO_C_SET_NBIO, 1 if n else 0, None) +def DTLSv1_get_timeout(ssl): + tv = TIMEVAL() + ret = _SSL_ctrl(ssl, DTLS_CTRL_GET_TIMEOUT, 0, byref(tv)) + if ret != 1: + return + return timedelta(seconds=tv.tv_sec, microseconds=tv.tv_usec) + +def DTLSv1_handle_timeout(ssl): + ret = _SSL_ctrl(ssl, DTLS_CTRL_HANDLE_TIMEOUT, 0, None) + if ret == 0: + # It was too early to call: no timer had yet expired + return False + if ret == 1: + # Buffered messages were retransmitted + return True + # There was an error: either too many timeouts have occurred or a + # retransmission failed + assert ret < 0 + if ret > 0: + ret = -10 + errcheck_p(ret, _SSL_ctrl, (ssl, DTLS_CTRL_HANDLE_TIMEOUT, 0, None)) + def DTLSv1_listen(ssl): su = sockaddr_u() ret = _SSL_ctrl(ssl, DTLS_CTRL_LISTEN, 0, byref(su)) @@ -642,7 +678,10 @@ def SSL_read(ssl, length): return buf.raw[:res_len] def SSL_write(ssl, data): - str_data = str(data) + if hasattr(data, "tobytes") and callable(data.tobytes): + str_data = data.tobytes() + else: + str_data = str(data) return _SSL_write(ssl, str_data, len(str_data)) def OBJ_obj2txt(asn1_object, no_name): diff --git a/dtls/patch.py b/dtls/patch.py new file mode 100644 index 0000000..910bdf4 --- /dev/null +++ b/dtls/patch.py @@ -0,0 +1,199 @@ +# Patch: patching of the Python stadard library's ssl module for transparent +# use of datagram sockets. Written by Ray Brown. +"""Patch + +This module is used to patch the Python standard library's ssl module. Patching +has the following effects: + + * The constant PROTOCOL_DTLSv1 is added at ssl module level + * DTLSv1's protocol name is added to the ssl module's id-to-name dictionary + * The constants DTLS_OPENSSL_VERSION* are added at the ssl module level + * Instntiation of ssl.SSLSocket with sock.type == socket.SOCK_DGRAM is + supported and leads to substitution of this module's DTLS code paths for + that SSLSocket instance + * Direct instantiation of SSLSocket as well as instantiation through + ssl.wrap_socket are supported + * Invocation of the function get_server_certificate with a value of + PROTOCOL_DTLSv1 for the parameter ssl_version is supported +""" + +from socket import SOCK_DGRAM, socket, _delegate_methods, error as socket_error +from socket import AF_INET, SOCK_DGRAM +from sslconnection import SSLConnection, PROTOCOL_DTLSv1, CERT_NONE +from sslconnection import DTLS_OPENSSL_VERSION_NUMBER, DTLS_OPENSSL_VERSION +from sslconnection import DTLS_OPENSSL_VERSION_INFO +from err import raise_as_ssl_module_error +from types import MethodType +from weakref import proxy +import errno + +def do_patch(): + import ssl as _ssl # import to be avoided if ssl module is never patched + global _orig_SSLSocket_init, _orig_get_server_certificate + global ssl + ssl = _ssl + ssl.PROTOCOL_DTLSv1 = PROTOCOL_DTLSv1 + ssl._PROTOCOL_NAMES[PROTOCOL_DTLSv1] = "DTLSv1" + ssl.DTLS_OPENSSL_VERSION_NUMBER = DTLS_OPENSSL_VERSION_NUMBER + ssl.DTLS_OPENSSL_VERSION = DTLS_OPENSSL_VERSION + ssl.DTLS_OPENSSL_VERSION_INFO = DTLS_OPENSSL_VERSION_INFO + _orig_SSLSocket_init = ssl.SSLSocket.__init__ + _orig_get_server_certificate = ssl.get_server_certificate + ssl.SSLSocket.__init__ = _SSLSocket_init + ssl.get_server_certificate = _get_server_certificate + raise_as_ssl_module_error() + +PROTOCOL_SSLv3 = 1 +PROTOCOL_SSLv23 = 2 + +def _get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): + """Retrieve a server certificate + + Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt. + """ + + if ssl_version != PROTOCOL_DTLSv1: + return _orig_get_server_certificate(addr, ssl_version, ca_certs) + + host, port = addr + if (ca_certs is not None): + cert_reqs = ssl.CERT_REQUIRED + else: + cert_reqs = ssl.CERT_NONE + s = ssl.wrap_socket(socket(AF_INET, SOCK_DGRAM), + ssl_version=ssl_version, + cert_reqs=cert_reqs, ca_certs=ca_certs) + s.connect(addr) + dercert = s.getpeercert(True) + s.close() + return ssl.DER_cert_to_PEM_cert(dercert) + +def _SSLSocket_init(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, ciphers=None): + is_connection = is_datagram = False + if isinstance(sock, SSLConnection): + is_connection = True + elif hasattr(sock, "type") and sock.type == SOCK_DGRAM: + is_datagram = True + if not is_connection and not is_datagram: + # Non-DTLS code path + return _orig_SSLSocket_init(self, sock, keyfile, certfile, + server_side, cert_reqs, + ssl_version, ca_certs, + do_handshake_on_connect, + suppress_ragged_eofs, ciphers) + # DTLS code paths: datagram socket and newly accepted DTLS connection + if is_datagram: + socket.__init__(self, _sock=sock._sock) + else: + socket.__init__(self, _sock=sock.get_socket(True)._sock) + # Copy instance initialization from SSLSocket class + for attr in _delegate_methods: + try: + delattr(self, attr) + except AttributeError: + pass + + if certfile and not keyfile: + keyfile = certfile + if is_datagram: + # see if it's connected + try: + socket.getpeername(self) + except socket_error, e: + if e.errno != errno.ENOTCONN: + raise + # no, no connection yet + self._connected = False + self._sslobj = None + else: + # yes, create the SSL object + self._connected = True + self._sslobj = SSLConnection(sock, keyfile, certfile, + server_side, cert_reqs, + ssl_version, ca_certs, + do_handshake_on_connect, + suppress_ragged_eofs, ciphers) + else: + self._sslobj = sock + + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self._makefile_refs = 0 + + # Perform method substitution and addition (without reference cycle) + self._real_connect = MethodType(_SSLSocket_real_connect, proxy(self)) + self.listen = MethodType(_SSLSocket_listen, proxy(self)) + self.accept = MethodType(_SSLSocket_accept, proxy(self)) + self.get_timeout = MethodType(_SSLSocket_get_timeout, proxy(self)) + self.handle_timeout = MethodType(_SSLSocket_handle_timeout, proxy(self)) + +def _SSLSocket_listen(self, ignored): + if self._connected: + raise ValueError("attempt to listen on connected SSLSocket!") + if self._sslobj: + return + self._sslobj = SSLConnection(socket(_sock=self._sock), + self.keyfile, self.certfile, True, + self.cert_reqs, self.ssl_version, + self.ca_certs, + self.do_handshake_on_connect, + self.suppress_ragged_eofs, self.ciphers) + +def _SSLSocket_accept(self): + if self._connected: + raise ValueError("attempt to accept on connected SSLSocket!") + if not self._sslobj: + raise ValueError("attempt to accept on SSLSocket prior to listen!") + acc_ret = self._sslobj.accept() + if not acc_ret: + return + new_conn, addr = acc_ret + new_ssl_sock = ssl.SSLSocket(new_conn, self.keyfile, self.certfile, True, + self.cert_reqs, self.ssl_version, + self.ca_certs, + self.do_handshake_on_connect, + self.suppress_ragged_eofs, self.ciphers) + return new_ssl_sock, addr + +def _SSLSocket_real_connect(self, addr, return_errno): + if self._connected: + raise ValueError("attempt to connect already-connected SSLSocket!") + self._sslobj = SSLConnection(socket(_sock=self._sock), + self.keyfile, self.certfile, False, + self.cert_reqs, self.ssl_version, + self.ca_certs, + self.do_handshake_on_connect, + self.suppress_ragged_eofs, self.ciphers) + try: + self._sslobj.connect(addr) + except socket_error as e: + if return_errno: + return e.errno + else: + self._sslobj = None + raise e + self._connected = True + return 0 + + +if __name__ == "__main__": + do_patch() + +def _SSLSocket_get_timeout(self): + return self._sslobj.get_timeout() + +def _SSLSocket_handle_timeout(self): + return self._sslobj.handle_timeout() diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index a0947da..8c17de0 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -29,9 +29,11 @@ import hmac from logging import getLogger from os import urandom from weakref import proxy -from err import OpenSSLError, InvalidSocketError +from err import openssl_error, InvalidSocketError from err import raise_ssl_error -from err import SSL_ERROR_WANT_READ, ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from err import SSL_ERROR_WANT_READ, SSL_ERROR_SYSCALL +from err import ERR_COOKIE_MISMATCH, ERR_NO_CERTS +from err import ERR_NO_CIPHER, ERR_HANDSHAKE_TIMEOUT, ERR_PORT_UNREACHABLE from x509 import _X509, decode_cert from openssl import * from util import _Rsrc, _BIO @@ -48,6 +50,14 @@ CERT_REQUIRED = 2 # SSL_library_init() SSL_load_error_strings() +DTLS_OPENSSL_VERSION_NUMBER = SSLeay() +DTLS_OPENSSL_VERSION = SSLeay_version(SSLEAY_VERSION) +DTLS_OPENSSL_VERSION_INFO = ( + DTLS_OPENSSL_VERSION_NUMBER >> 28 & 0xFF, # major + DTLS_OPENSSL_VERSION_NUMBER >> 20 & 0xFF, # minor + DTLS_OPENSSL_VERSION_NUMBER >> 12 & 0xFF, # fix + DTLS_OPENSSL_VERSION_NUMBER >> 4 & 0xFF, # patch + DTLS_OPENSSL_VERSION_NUMBER & 0xF) # status class _CTX(_Rsrc): @@ -98,9 +108,11 @@ class SSLConnection(object): _rnd_key = urandom(16) - def _init_server(self): + def _init_server(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") + if peer_address: + raise InvalidSocketError("server-side socket must be unconnected") from demux import UDPDemux self._udp_demux = UDPDemux(self._sock) @@ -127,7 +139,7 @@ class SSLConnection(object): self._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_accept_state(self._ssl.value) - def _init_client(self): + def _init_client(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") @@ -141,6 +153,8 @@ class SSLConnection(object): self._config_ssl_ctx(verify_mode) self._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_connect_state(self._ssl.value) + if peer_address: + return lambda: self.connect(peer_address) def _config_ssl_ctx(self, verify_mode): SSL_CTX_set_verify(self._ctx.value, verify_mode) @@ -153,7 +167,10 @@ class SSLConnection(object): if self._ca_certs: SSL_CTX_load_verify_locations(self._ctx.value, self._ca_certs, None) if self._ciphers: - SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) + try: + SSL_CTX_set_cipher_list(self._ctx.value, self._ciphers) + except openssl_error() as err: + raise_ssl_error(ERR_NO_CIPHER, err) def _copy_server(self): source = self._sock @@ -171,6 +188,7 @@ class SSLConnection(object): new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), BIO_NOCLOSE)) source._ssl = _SSL(SSL_new(self._ctx.value)) + SSL_set_accept_state(source._ssl.value) source._rbio = new_source_rbio source._wbio = new_source_wbio SSL_set_bio(source._ssl.value, @@ -179,6 +197,20 @@ class SSLConnection(object): new_source_rbio.disown() new_source_wbio.disown() + def _reconnect_unwrapped(self): + source = self._sock + self._sock = source._wsock + self._udp_demux = source._demux + self._rsock = source._rsock + self._ctx = source._ctx + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + BIO_dgram_set_peer(self._wbio.value, source._peer_address) + self._ssl = _SSL(SSL_new(self._ctx.value)) + SSL_set_accept_state(self._ssl.value) + if self._do_handshake_on_connect: + return lambda: self.do_handshake() + def _check_nbio(self): BIO_set_nbio(self._wbio.value, self._sock.gettimeout() is not None) if self._wbio is not self._rbio: @@ -234,15 +266,24 @@ class SSLConnection(object): self._handshake_done = False if isinstance(sock, SSLConnection): - self._copy_server() - elif server_side: - self._init_server() + post_init = self._copy_server() + elif isinstance(sock, _UnwrappedSocket): + post_init = self._reconnect_unwrapped() else: - self._init_client() + try: + peer_address = sock.getpeername() + except socket.error: + peer_address = None + if server_side: + post_init = self._init_server(peer_address) + else: + post_init = self._init_client(peer_address) SSL_set_bio(self._ssl.value, self._rbio.value, self._wbio.value) self._rbio.disown() self._wbio.disown() + if post_init: + post_init() def get_socket(self, inbound): """Retrieve a socket used by this connection @@ -305,7 +346,7 @@ class SSLConnection(object): _logger.debug("Invoking DTLSv1_listen for ssl: %d", self._ssl.value._as_parameter) dtls_peer_address = DTLSv1_listen(self._ssl.value) - except OpenSSLError as err: + except openssl_error() as err: if err.ssl_error == SSL_ERROR_WANT_READ: # This method must be called again to forward the next datagram _logger.debug("DTLSv1_listen must be resumed") @@ -345,6 +386,7 @@ class SSLConnection(object): self._cert_reqs, PROTOCOL_DTLSv1, self._ca_certs, self._do_handshake_on_connect, self._suppress_ragged_eofs, self._ciphers) + new_peer = self._pending_peer_address self._pending_peer_address = None if self._do_handshake_on_connect: # Note that since that connection's socket was just created in its @@ -354,7 +396,7 @@ class SSLConnection(object): # will hang in this call new_conn.do_handshake() _logger.debug("Accept returning new connection for new peer") - return new_conn + return new_conn, new_peer def connect(self, peer_address): """Client-side UDP connection establishment @@ -368,6 +410,7 @@ class SSLConnection(object): """ self._sock.connect(peer_address) + peer_address = self._sock.getpeername() # substituted host addrinfo BIO_dgram_set_connected(self._wbio.value, peer_address) assert self._wbio is self._rbio if self._do_handshake_on_connect: @@ -382,7 +425,15 @@ class SSLConnection(object): _logger.debug("Initiating handshake...") self._check_nbio() - SSL_do_handshake(self._ssl.value) + try: + SSL_do_handshake(self._ssl.value) + except openssl_error() as err: + if err.ssl_error == SSL_ERROR_WANT_READ and \ + self.get_socket(True).gettimeout(): + raise_ssl_error(ERR_HANDSHAKE_TIMEOUT, err) + elif err.ssl_error == SSL_ERROR_SYSCALL and err.result == -1: + raise_ssl_error(ERR_PORT_UNREACHABLE, err) + raise self._handshake_done = True _logger.debug("...completed handshake") @@ -423,19 +474,30 @@ class SSLConnection(object): it no longer raises continuation request exceptions. """ + if hasattr(self, "_listening"): + # Listening server-side sockets cannot be shut down + return + self._check_nbio() try: SSL_shutdown(self._ssl.value) - except OpenSSLError as err: + except openssl_error() as err: if err.result == 0: # close-notify alert was just sent; wait for same from peer # Note: while it might seem wise to suppress further read-aheads # with SSL_set_read_ahead here, doing so causes a shutdown # failure (ret: -1, SSL_ERROR_SYSCALL) on the DTLS shutdown - # initiator side. + # initiator side. And test_starttls does pass. SSL_shutdown(self._ssl.value) else: raise + if hasattr(self, "_udp_demux"): + # Return wrapped connected server socket (non-listening) + return _UnwrappedSocket(self._sock, self._rsock, self._udp_demux, + self._ctx, + BIO_dgram_get_peer(self._wbio.value)) + # Return unwrapped client-side socket + return self._sock def getpeercert(self, binary_form=False): """Retrieve the peer's certificate @@ -453,7 +515,7 @@ class SSLConnection(object): try: peer_cert = _X509(SSL_get_peer_certificate(self._ssl.value)) - except OpenSSLError: + except openssl_error(): return if binary_form: @@ -462,6 +524,8 @@ class SSLConnection(object): return {} return decode_cert(peer_cert) + peer_certificate = getpeercert # compatibility with _ssl call interface + def cipher(self): """Retrieve information about the current cipher @@ -478,3 +542,95 @@ class SSLConnection(object): cipher_version = SSL_CIPHER_get_version(current_cipher) cipher_bits = SSL_CIPHER_get_bits(current_cipher) return cipher_name, cipher_version, cipher_bits + + def pending(self): + """Retrieve number of buffered bytes + + Return the number of bytes that have been read from the socket and + buffered by this connection. Return 0 if no bytes have been buffered. + """ + + return SSL_pending(self._ssl.value) + + def get_timeout(self): + """Retrieve the retransmission timedelta + + Since datagrams are subject to packet loss, DTLS will perform + packet retransmission if a response is not received after a certain + time interval during the handshaking phase. When using non-blocking + sockets, the application must call back after that time interval to + allow for the retransmission to occur. This method returns the + timedelta after which to perform the call to handle_timeout, or None + if no such callback is needed given the current handshake state. + """ + + return DTLSv1_get_timeout(self._ssl.value) + + def handle_timeout(self): + """Perform datagram retransmission, if required + + This method should be called after the timedelta retrieved from + get_timeout has expired, and no datagrams were received in the + meantime. If datagrams were received, a new timeout needs to be + requested. + + Return value: + True -- retransmissions were performed successfully + False -- a timeout was not in effect or had not yet expired + + Exceptions: + Raised when retransmissions fail or too many timeouts occur. + """ + + return DTLSv1_handle_timeout(self._ssl.value) + + +class _UnwrappedSocket(socket.socket): + """Unwrapped server-side socket + + Depending on UDP demux implementation, there may not be single socket + that can be used for both reading and writing to the client socket with + which it is associated. An object of this type is therefore returned from + the SSLSocket's unwrap method to allow for unencrypted communication over + the established channels, including the demux. + """ + + def __init__(self, wsock, rsock, demux, ctx, peer_address): + socket.socket.__init__(self, _sock=rsock._sock) + for attr in "send", "sendto", "sendall": + try: + delattr(self, attr) + except AttributeError: + pass + self._wsock = wsock + self._rsock = rsock # continue to reference to hold in demux map + self._demux = demux + self._ctx = ctx + self._peer_address = peer_address + + def send(self, data, flags=0): + __doc__ = self._wsock.send.__doc__ + return self._wsock.sendto(data, flags, self._peer_address) + + def sendto(self, data, flags_or_addr, addr=None): + __doc__ = self._wsock.sendto.__doc__ + return self._wsock.sendto(data, flags_or_addr, addr) + + def sendall(self, data, flags=0): + __doc__ = self._wsock.sendall.__doc__ + amount = len(data) + count = 0 + while (count < amount): + v = self.send(data[count:], flags) + count += v + return amount + + def getpeername(self): + __doc__ = self._wsock.getpeername.__doc__ + return self._peer_address + + def connect(self, addr): + __doc__ = self._wsock.connect.__doc__ + raise ValueError("Cannot connect already connected unwrapped socket") + + connect_ex = connect diff --git a/dtls/test/certs/badcert.pem b/dtls/test/certs/badcert.pem new file mode 100644 index 0000000..c419146 --- /dev/null +++ b/dtls/test/certs/badcert.pem @@ -0,0 +1,36 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L +opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH +fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB +AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU +D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA +IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM +oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0 +ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/ +loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j +oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA +z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq +ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV +q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +Just bad cert data +-----END CERTIFICATE----- +-----BEGIN RSA PRIVATE KEY----- +MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L +opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH +fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB +AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU +D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA +IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM +oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0 +ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/ +loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j +oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA +z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq +ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV +q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +Just bad cert data +-----END CERTIFICATE----- diff --git a/dtls/test/certs/badkey.pem b/dtls/test/certs/badkey.pem new file mode 100644 index 0000000..1c8a955 --- /dev/null +++ b/dtls/test/certs/badkey.pem @@ -0,0 +1,40 @@ +-----BEGIN RSA PRIVATE KEY----- +Bad Key, though the cert should be OK +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD +VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x +IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT +U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1 +NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl +bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m +dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj +aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh +m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8 +M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn +fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC +AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb +08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx +CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/ +iHkC6gGdBJhogs4= +-----END CERTIFICATE----- +-----BEGIN RSA PRIVATE KEY----- +Bad Key, though the cert should be OK +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD +VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x +IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT +U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1 +NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl +bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m +dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj +aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh +m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8 +M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn +fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC +AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb +08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx +CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/ +iHkC6gGdBJhogs4= +-----END CERTIFICATE----- diff --git a/dtls/test/certs/keycert.pem b/dtls/test/certs/keycert.pem new file mode 100644 index 0000000..05ee34c --- /dev/null +++ b/dtls/test/certs/keycert.pem @@ -0,0 +1,21 @@ +-----BEGIN PRIVATE KEY----- +MIIBVAIBADANBgkqhkiG9w0BAQEFAASCAT4wggE6AgEAAkEAuPd3JmydJfXhyii0 +agsVgRMOUcOyuldbaf/Lu4bZ+U0zH0OSoYkv0Ahbz7ehK+oGMeUy/SuGVAn7JLyj +zlYi8QIDAQABAkAygtnV82lC2Y/Mbis+nkJEGlkZuRCQ1JRRMRqI3n2eF6CviqF3 +PiBXIEEExzKihC9bvbHKTAkYDLr+/4YpbiQBAiEA7JLS5Lp7KI/ayWwEzl2r5XXu +k/cbH++A4zZz6A9XIsECIQDIJ8ciDa5/VGyQnYMzBNgKnwaFDDBOiEUFDaU/9ZN8 +MQIgCG3Gw819G9ncQrbtiOi/eiJ0iKMSPVYMMow7HvaE9UECIQCLyQwPwlJd5s4z +aW4ZkYZ4VHuvK8YI8q6RSuhf9Nhd4QIgFbRNdEeehgrzGzGug2yVCMzVzS3MQNBJ +6LqBZaPlFsM= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIBgDCCASoCAQEwDQYJKoZIhvcNAQEEBQAwSjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCldhc2hpbmd0b24xEzARBgNVBAoTClJheSBDQSBJbmMxETAPBgNVBAMTCFJh +eUNBSW5jMB4XDTEyMDkyMTIxMTYxOFoXDTEzMDkyMTIxMTYxOFowTDELMAkGA1UE +BhMCVVMxEzARBgNVBAgTCldhc2hpbmd0b24xFDASBgNVBAoTC1JheSBTcnYgSW5j +MRIwEAYDVQQDEwlSYXlTcnZJbmMwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAuPd3 +JmydJfXhyii0agsVgRMOUcOyuldbaf/Lu4bZ+U0zH0OSoYkv0Ahbz7ehK+oGMeUy +/SuGVAn7JLyjzlYi8QIDAQABMA0GCSqGSIb3DQEBBAUAA0EAEkxVF8HEGV8N4mYA +hDciYpttnnb9pYL1okHGrhaIFqu9D10LfP1SKps/6s/qNSk3YaIVjydWOHEf6xr4 +zJkiFw== +-----END CERTIFICATE----- diff --git a/dtls/test/certs/nullcert.pem b/dtls/test/certs/nullcert.pem new file mode 100644 index 0000000..e69de29 diff --git a/dtls/test/certs/wrongcert.pem b/dtls/test/certs/wrongcert.pem new file mode 100644 index 0000000..5f92f9b --- /dev/null +++ b/dtls/test/certs/wrongcert.pem @@ -0,0 +1,32 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH +FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T +f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB +AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq +1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW +7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg +SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe +19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg +ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666 +lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs +ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv +frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk +2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB +gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6 ++bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK +24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G +A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9 +OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt +U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ +fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E +usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+ +43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw +eSHj5jpC8iZKjCHBn+mAi4cQ514= +-----END CERTIFICATE----- diff --git a/dtls/test/echo_seq.py b/dtls/test/echo_seq.py index 99bb9cb..7a3bf3b 100644 --- a/dtls/test/echo_seq.py +++ b/dtls/test/echo_seq.py @@ -45,7 +45,7 @@ def main(): break print "Accepting..." - conn = scn.accept() + conn = scn.accept()[0] sck.settimeout(5) conn.get_socket(True).settimeout(5) @@ -59,7 +59,7 @@ def main(): try: conn.do_handshake() except SSLError as err: - if err.args[0] == SSL_ERROR_WANT_READ: + if len(err.args) > 1 and err.args[1].args[0] == SSL_ERROR_WANT_READ: continue raise print "Completed handshaking with peer" diff --git a/dtls/test/unit.py b/dtls/test/unit.py new file mode 100644 index 0000000..c8d7c96 --- /dev/null +++ b/dtls/test/unit.py @@ -0,0 +1,1313 @@ +# Test the support for DTLS through the SSL module. Adapted from the Python +# standard library's test_ssl.py regression test module by Ray Brown. + +import sys +import unittest +import asyncore +import socket +import select +import gc +import os +import errno +import pprint +import urllib, urlparse +import traceback +import weakref +import platform +import threading +import datetime +import SocketServer +from SimpleHTTPServer import SimpleHTTPRequestHandler + +import ssl +from dtls import do_patch + +HOST = "localhost" + +class TestSupport(object): + verbose = True + + class Ctx(object): + def __enter__(self): + self.server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + self.server.start(flag) + flag.wait() + return self.server.sockname + + def __exit__(self, exc_type, exc_value, traceback): + self.server.stop() + self.server.join() + self.server = None + + def transient_internet(self): + return self.Ctx() + +test_support = TestSupport() + +def handle_error(prefix): + exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) + if test_support.verbose: + sys.stdout.write(prefix + exc_format) + + +class BasicTests(unittest.TestCase): + + def test_sslwrap_simple(self): + # A crude test for the legacy API + try: + ssl.sslwrap_simple(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) + except IOError, e: + if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that + pass + else: + raise + try: + ssl.sslwrap_simple(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)._sock) + except IOError, e: + if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that + pass + else: + raise + + +class BasicSocketTests(unittest.TestCase): + + def test_constants(self): + ssl.PROTOCOL_SSLv23 + ssl.PROTOCOL_SSLv3 + ssl.PROTOCOL_TLSv1 + ssl.PROTOCOL_DTLSv1 # added + ssl.CERT_NONE + ssl.CERT_OPTIONAL + ssl.CERT_REQUIRED + + def test_dtls_openssl_version(self): + n = ssl.DTLS_OPENSSL_VERSION_NUMBER + t = ssl.DTLS_OPENSSL_VERSION_INFO + s = ssl.DTLS_OPENSSL_VERSION + self.assertIsInstance(n, (int, long)) + self.assertIsInstance(t, tuple) + self.assertIsInstance(s, str) + # Some sanity checks follow + # >= 1.0 + self.assertGreaterEqual(n, 0x10000000) + # < 2.0 + self.assertLess(n, 0x20000000) + major, minor, fix, patch, status = t + self.assertGreaterEqual(major, 1) + self.assertLess(major, 2) + self.assertGreaterEqual(minor, 0) + self.assertLess(minor, 256) + self.assertGreaterEqual(fix, 0) + self.assertLess(fix, 256) + self.assertGreaterEqual(patch, 0) + self.assertLessEqual(patch, 26) + self.assertGreaterEqual(status, 0) + self.assertLessEqual(status, 15) + # Version string as returned by OpenSSL, the format might change + self.assertTrue( + s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), + (s, t)) + + def test_ciphers(self): + server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + server.start(flag) + flag.wait() + remote = (HOST, server.port) + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE, ciphers="ALL") + s.connect(remote) + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") + s.connect(remote) + # Error checking occurs when connecting, because the SSL context + # isn't created before. + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE, + ciphers="^$:,;?*'dorothyx") + with self.assertRaisesRegexp(ssl.SSLError, + "No cipher can be selected"): + s.connect(remote) + finally: + server.stop() + server.join() + # repeat with AF_INET6? + + @unittest.skipIf(platform.python_implementation() != "CPython", + "Reference cycle test feasible under CPython only") + def test_refcycle(self): + # Issue #7943: an SSL object doesn't create reference cycles with + # itself. + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + ss = ssl.wrap_socket(s) + wr = weakref.ref(ss) + del ss + self.assertEqual(wr(), None) + + def test_wrapped_unconnected(self): + # The _delegate_methods in socket.py are correctly delegated to by an + # unconnected SSLSocket, so they will raise a socket.error rather than + # something unexpected like TypeError. + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + ss = ssl.wrap_socket(s) + self.assertRaises(socket.error, ss.recv, 1) + self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) + self.assertRaises(socket.error, ss.recvfrom, 1) + self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(socket.error, ss.send, b'x') + self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + + +class NetworkedTests(unittest.TestCase): + + def test_connect(self): + with test_support.transient_internet() as remote: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_NONE) + s.connect(remote) + c = s.getpeercert() + if c: + self.fail("Peer cert %s shouldn't be here!") + s.close() + + # this should fail because we have no verification certs + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED) + try: + s.connect(remote) + except ssl.SSLError: + pass + finally: + s.close() + + # this should succeed because we specify the root cert + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=ISSUER_CERTFILE) + try: + s.connect(remote) + finally: + s.close() + + def test_connect_ex(self): + # Issue #11326: check connect_ex() implementation + with test_support.transient_internet() as remote: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=ISSUER_CERTFILE) + try: + self.assertEqual(0, s.connect_ex(remote)) + self.assertTrue(s.getpeercert()) + finally: + s.close() + + def test_non_blocking_connect_ex(self): + # Issue #11326: non-blocking connect_ex() should allow handshake + # to proceed after the socket gets ready. + with test_support.transient_internet() as remote: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=ISSUER_CERTFILE, + do_handshake_on_connect=False) + try: + s.setblocking(False) + rc = s.connect_ex(remote) + # EWOULDBLOCK under Windows, EINPROGRESS elsewhere + self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK)) + # Non-blocking handshake + while True: + try: + s.do_handshake() + break + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + while True: + to = s.get_timeout() + to = to.total_seconds() if to else 5.0 + sel = select.select([s], [], [], to) + if sel[0]: + break + s.handle_timeout() + else: + raise + # SSL established + self.assertTrue(s.getpeercert()) + finally: + s.close() + + @unittest.skipIf(os.name == "nt", + "Can't use a socket as a file under Windows") + def test_makefile_close(self): + # Issue #5238: creating a file-like object with makefile() shouldn't + # delay closing the underlying "real socket" (here tested with its + # file descriptor, hence skipping the test under Windows). + with test_support.transient_internet() as remote: + ss = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + ss.connect(remote) + fd = ss.fileno() + f = ss.makefile() + f.close() + # The fd is still open + os.read(fd, 0) + # Closing the SSL socket should close the fd too + ss.close() + gc.collect() + with self.assertRaises(OSError) as e: + os.read(fd, 0) + self.assertEqual(e.exception.errno, errno.EBADF) + + def test_non_blocking_handshake(self): + with test_support.transient_internet() as remote: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(remote) + s.setblocking(False) + s = ssl.wrap_socket(s, + cert_reqs=ssl.CERT_NONE, + do_handshake_on_connect=False) + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLError, err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + while True: + to = s.get_timeout() + if to: + sel = select.select([s], [], [], + to.total_seconds()) + if sel[0]: + break + s.handle_timeout() + continue + select.select([s], [], []) + break + else: + raise + s.close() + if test_support.verbose: + sys.stdout.write(("\nNeeded %d calls to do_handshake() " + + "to establish session.\n") % count) + + def test_get_server_certificate(self): + with test_support.transient_internet() as remote: + pem = ssl.get_server_certificate(remote, ssl.PROTOCOL_DTLSv1) + if not pem: + self.fail("No server certificate!") + + try: + pem = ssl.get_server_certificate(remote, + ssl.PROTOCOL_DTLSv1, + ca_certs=OTHER_CERTFILE) + except ssl.SSLError: + #should fail + pass + else: + self.fail("Got server certificate %s!" % pem) + + pem = ssl.get_server_certificate(remote, + ssl.PROTOCOL_DTLSv1, + ca_certs=ISSUER_CERTFILE) + if not pem: + self.fail("No server certificate!") + if test_support.verbose: + sys.stdout.write("\nVerified certificate is\n%s\n" % pem) + +class ThreadedEchoServer(threading.Thread): + + class ConnectionHandler(threading.Thread): + + """A mildly complicated class, because we want it to work both + with and without the SSL wrapper around the socket connection, so + that we can test the STARTTLS functionality.""" + + def __init__(self, server, connsock): + self.server = server + self.running = False + self.sock = connsock + self.sock.setblocking(True) + self.sslconn = connsock + threading.Thread.__init__(self) + self.daemon = True + + def show_conn_details(self): + if self.server.certreqs == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" client cert is " + + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if test_support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + + str(cipher) + "\n") + + def wrap_conn(self): + try: + self.sslconn = ssl.wrap_socket( + self.sock, server_side=True, + certfile=self.server.certificate, + ssl_version=self.server.protocol, + ca_certs=self.server.cacerts, + cert_reqs=self.server.certreqs, + ciphers=self.server.ciphers) + except ssl.SSLError: + # XXX Various errors can have happened here, for example + # a mismatching protocol version, an invalid certificate, + # or a low-level bug. This should be made more + # discriminating. + if self.server.chatty: + handle_error("\n server: bad connection attempt " + + "from " + + str(self.sock.getpeername()) + ":\n") + self.close() + self.running = False + self.server.stop() + return False + else: + return True + + def read(self): + if self.sslconn: + return self.sslconn.read() + else: + return self.sock.recv(1024) + + def write(self, bytes): + if self.sslconn: + return self.sslconn.write(bytes) + else: + return self.sock.send(bytes) + + def close(self): + if self.sslconn: + self.sslconn.close() + else: + self.sock._sock.close() + + def run(self): + self.running = True + # Complete the handshake + try: + self.sock.do_handshake() + except ssl.SSLError: + if self.server.chatty: + handle_error("\n server: failed to handshake with " + + str(self.sock.getpeername()) + ":\n") + self.close() + self.running = False + self.server.stop() + return + if self.server.starttls_server: + self.sock = self.sock.unwrap() + self.sslconn = None + else: + self.show_conn_details() + while self.running: + try: + msg = self.read() + if not msg: + # eof, so quit this handler + self.running = False + self.close() + elif msg.strip() == 'over': + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: client closed " + + "connection\n") + self.close() + return + elif self.server.starttls_server and not self.sslconn \ + and msg.strip() == 'STARTTLS': + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: read STARTTLS " + + "from client, sending OK...\n") + self.write("OK\n") + if not self.wrap_conn(): + return + elif self.server.starttls_server and self.sslconn and \ + msg.strip() == 'ENDTLS': + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: read ENDTLS from " + + "client, sending OK...\n") + self.write("OK\n") + self.sslconn.unwrap() + self.sslconn = None + if test_support.verbose and \ + self.server.connectionchatty: + sys.stdout.write(" server: connection is now " + + "unencrypted...\n") + else: + if test_support.verbose and \ + self.server.connectionchatty: + ctype = (self.sslconn and "encrypted") or \ + "unencrypted" + sys.stdout.write((" server: read %s (%s), " + + "sending back %s (%s)...\n") + % (repr(msg), ctype, + repr(msg.lower()), ctype)) + self.write(msg.lower()) + except ssl.SSLError: + if self.server.chatty: + handle_error("Test server failure:\n") + self.close() + self.running = False + # normally, we'd just stop here, but for the test + # harness, we want to stop the server + self.server.stop() + + def __init__(self, certificate, ssl_version=None, + certreqs=None, cacerts=None, + chatty=True, connectionchatty=False, starttls_server=False, + ciphers=None): + + if ssl_version is None: + ssl_version = ssl.PROTOCOL_DTLSv1 + if certreqs is None: + certreqs = ssl.CERT_NONE + self.certificate = certificate + self.protocol = ssl_version + self.certreqs = certreqs + self.cacerts = cacerts + self.ciphers = ciphers + self.chatty = chatty + self.connectionchatty = connectionchatty + self.starttls_server = starttls_server + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.flag = None + self.sock = ssl.wrap_socket(self.sock, server_side=True, + certfile=self.certificate, + cert_reqs=self.certreqs, + ca_certs=self.cacerts, + ssl_version=self.protocol, + do_handshake_on_connect=False, + ciphers=self.ciphers) + if test_support.verbose and self.chatty: + sys.stdout.write(' server: wrapped server ' + + 'socket as %s\n' % str(self.sock)) + self.sock.bind((HOST, 0)) + self.port = self.sock.getsockname()[1] + self.active = False + threading.Thread.__init__(self) + self.daemon = True + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.sock.settimeout(0.05) + self.sock.listen(5) + self.active = True + if self.flag: + # signal an event + self.flag.set() + while self.active: + try: + acc_ret = self.sock.accept() + if acc_ret: + newconn, connaddr = acc_ret + if test_support.verbose and self.chatty: + sys.stdout.write(' server: new connection from ' + + str(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn) + handler.start() + except socket.timeout: + pass + except KeyboardInterrupt: + self.stop() + self.sock.close() + + def stop(self): + self.active = False + +class AsyncoreEchoServer(threading.Thread): + + class EchoServer(asyncore.dispatcher): + + class ConnectionHandler(asyncore.dispatcher): + + def __init__(self, conn, timeout_tracker): + asyncore.dispatcher.__init__(self, conn) + self._timeout_tracker = timeout_tracker + self._ssl_accepting = True + # Complete the handshake + self.handle_read_event() + + def readable(self): + while self.socket.pending() > 0: + self.handle_read_event() + if self._timeout_tracker.has_key(self) and \ + datetime.datetime.now() >= self._timeout_tracker[self]: + self._timeout_tracker.pop(self) + try: + self.socket.handle_timeout() + except: + self.handle_close() + return False + return True + + def writable(self): + return False + + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError, err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SSL): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + raise + except socket.error, err: + if err.args[0] == errno.ECONNABORTED: + return self.handle_close() + else: + self._ssl_accepting = False + + def handle_read(self): + if self._ssl_accepting: + self._do_ssl_handshake() + else: + data = self.recv(1024) + if data and data.strip() != 'over': + self.send(data.lower()) + delta = self.socket.get_timeout() + if delta: + self._timeout_tracker[self] = \ + datetime.datetime.now() + delta + + def handle_close(self): + if self._timeout_tracker.has_key(self): + self._timeout_tracker.pop(self) + self.close() + if test_support.verbose: + sys.stdout.write(" server: closed connection %s\n" % + self.socket) + + def handle_error(self): + raise + + def __init__(self, certfile, timeout_tracker): + asyncore.dispatcher.__init__(self) + self._timeout_tracker = timeout_tracker + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind((HOST, 0)) + self.sockname = sock.getsockname() + self.port = self.sockname[1] + self.set_socket(ssl.wrap_socket(sock, server_side=True, + certfile=certfile, + do_handshake_on_connect=False)) + self.listen(5) + + def writable(self): + return False + + def handle_accept(self): + acc_ret = self.accept() + if acc_ret: + sock_obj, addr = acc_ret + if test_support.verbose: + sys.stdout.write(" server: new connection from " + + "%s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self._timeout_tracker) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.flag = None + self.active = False + self.timeout_tracker = {} + self.server = self.EchoServer(certfile, self.timeout_tracker) + self.sockname = self.server.sockname + self.port = self.server.port + threading.Thread.__init__(self) + self.daemon = True + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + now = datetime.datetime.now() + future_timeouts = filter(lambda x: x > now, + self.timeout_tracker.values()) + future_timeouts.append(now + datetime.timedelta(seconds=0.05)) + first_timeout = min(future_timeouts) - now + asyncore.loop(first_timeout.total_seconds(), count=1) + + def stop(self): + self.active = False + self.server.close() + +# Note that this HTTP-over-UDP server does not implement packet recovery and +# reordering, but it's good enough for testing on a loopback interface +class SocketServerHTTPSServer(threading.Thread): + + class HTTPSServerUDP(SocketServer.ThreadingTCPServer): + + def __init__(self, server_address, RequestHandlerClass, certfile): + SocketServer.ThreadingTCPServer.__init__(self, server_address, + RequestHandlerClass, False) + # account for dealing with a datagram socket + self.socket = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + server_side=True, + certfile=certfile, + do_handshake_on_connect=False) + self.server_bind() + self.server_activate() + + def __str__(self): + return ('<%s %s:%s>' % + (self.__class__.__name__, + self.server_name, + self.server_port)) + + def server_bind(self): + """Override server_bind to store the server name.""" + SocketServer.ThreadingTCPServer.server_bind(self) + host, port = self.socket.getsockname()[:2] + self.server_name = socket.getfqdn(host) + self.server_port = port + + def get_request(self): + # account for the fact that accept can return nothing, and + # according to BaseServer documentation, we should not block here + acc_ret = self.socket.accept() + if not acc_ret: + raise socket.error("No new connection") + return acc_ret + + def shutdown_request(self, request): + # Notify client of termination + request.unwrap() + + class RootedHTTPRequestHandler(SimpleHTTPRequestHandler): + # need to override translate_path to get a known root, + # instead of using os.curdir, since the test could be + # run from anywhere + + server_version = "TestHTTPS-UDP/1.0" + + root = None + + def translate_path(self, path): + """Translate a /-separated PATH to the local filename syntax. + + Components that mean special things to the local file system + (e.g. drive or directory names) are ignored. (XXX They should + probably be diagnosed.) + + """ + # abandon query parameters + path = urlparse.urlparse(path)[2] + path = os.path.normpath(urllib.unquote(path)) + words = path.split('/') + words = filter(None, words) + path = self.root + for word in words: + drive, word = os.path.splitdrive(word) + head, word = os.path.split(word) + if word in self.root: continue + path = os.path.join(path, word) + return path + + def log_message(self, format, *args): + # we override this to suppress logging unless "verbose" + if test_support.verbose: + sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % + (self.server.server_address, + self.server.server_port, + self.request.cipher(), + self.log_date_time_string(), + format%args)) + + + def __init__(self, certfile): + self.flag = None + self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0] + self.server = self.HTTPSServerUDP( + (HOST, 0), self.RootedHTTPRequestHandler, certfile) + self.port = self.server.server_port + threading.Thread.__init__(self) + self.daemon = True + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + if self.flag: + self.flag.set() + self.server.serve_forever(0.05) + + def stop(self): + self.server.shutdown() + + +def bad_cert_test(certfile): + """ + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with the given client certificate fails. + """ + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_REQUIRED, + cacerts=ISSUER_CERTFILE, chatty=False) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + certfile=certfile, + ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + except ssl.SSLError, x: + if test_support.verbose: + sys.stdout.write("\nSSLError is %s\n" % x[1]) + except socket.error, x: + if test_support.verbose: + sys.stdout.write("\nsocket.error is %s\n" % x[1]) + else: + raise AssertionError("Use of invalid cert should have failed!") + finally: + server.stop() + server.join() + +def server_params_test(certfile, protocol, certreqs, cacertsfile, + client_certfile, client_protocol=None, + indata="FOO\n", ciphers=None, chatty=True, + connectionchatty=False): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ + server = ThreadedEchoServer(certfile, + certreqs=certreqs, + ssl_version=protocol, + cacerts=cacertsfile, + ciphers=ciphers, + chatty=chatty, + connectionchatty=connectionchatty) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + if client_protocol is None: + client_protocol = protocol + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + certfile=client_certfile, + ca_certs=cacertsfile, + ciphers=ciphers, + cert_reqs=certreqs, + ssl_version=client_protocol) + s.connect((HOST, server.port)) + for arg in [indata, bytearray(indata), memoryview(indata)]: + if connectionchatty: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(arg))) + s.write(arg) + outdata = s.read() + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise AssertionError( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + server.join() + +def try_protocol_combo(server_protocol, + client_protocol, + expect_success, + certsreqs=None): + if certsreqs is None: + certsreqs = ssl.CERT_NONE + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] + if test_support.verbose: + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" + sys.stdout.write(formatstr % + (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol), + certtype)) + try: + # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client + # will send an SSLv3 hello (rather than SSLv2) starting from + # OpenSSL 1.0.0 (see issue #8322). + server_params_test(CERTFILE, server_protocol, certsreqs, + ISSUER_CERTFILE, CERTFILE, client_protocol, + ciphers="ALL", chatty=False) + # Protocol mismatch can result in either an SSLError, or a + # "Connection reset by peer" error. + except ssl.SSLError: + if expect_success: + raise + except socket.error as e: + if expect_success or e.errno != errno.ECONNRESET: + raise + else: + if not expect_success: + raise AssertionError( + "Client protocol %s succeeded with server protocol %s!" + % (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol))) + + +class ThreadedTests(unittest.TestCase): + + def test_unreachable(self): + server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + server.bind((HOST, 0)) + port = server.getsockname()[1] + server.close() + s = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) + self.assertRaisesRegexp(ssl.SSLError, + "The peer address is not reachable", + s.connect, (HOST, port)) + + def test_echo(self): + """Basic test of an SSL client connecting to a server""" + if test_support.verbose: + sys.stdout.write("\n") + server_params_test(CERTFILE, ssl.PROTOCOL_DTLSv1, ssl.CERT_NONE, + CERTFILE, CERTFILE, ssl.PROTOCOL_DTLSv1, + chatty=True, connectionchatty=True) + + def test_getpeercert(self): + if test_support.verbose: + sys.stdout.write("\n") + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_DTLSv1, + cacerts=CERTFILE, + chatty=False) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM), + certfile=CERTFILE, + ca_certs=ISSUER_CERTFILE, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() + if test_support.verbose: + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Ray Srv Inc'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in " + "certificate subject; should be 'Ray Srv Inc'.") + s.close() + finally: + server.stop() + server.join() + + def test_empty_cert(self): + """Connecting with an empty cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "nullcert.pem")) + def test_malformed_cert(self): + """Connecting with a badly formatted certificate (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "badcert.pem")) + def test_nonexisting_cert(self): + """Connecting with a non-existing cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "wrongcert.pem")) + def test_malformed_key(self): + """Connecting with a badly formatted key (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "badkey.pem")) + + def test_protocol_dtlsv1(self): + """Connecting to a DTLSv1 server with various client options""" + if test_support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_DTLSv1, ssl.PROTOCOL_DTLSv1, True, + ssl.CERT_REQUIRED) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", + "msg 5", "msg 6") + + server = ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_DTLSv1, + starttls_server=True, + chatty=True, + connectionchatty=True) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + wrapped = False + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + s.connect((HOST, server.port)) + s = s.unwrap() + if test_support.verbose: + sys.stdout.write("\n") + for indata in msgs: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % repr(indata)) + if wrapped: + conn.write(indata) + outdata = conn.read() + else: + s.send(indata) + outdata = s.recv(1024) + if (indata == "STARTTLS" and + outdata.strip().lower().startswith("ok")): + # STARTTLS ok, switch to secure mode + if test_support.verbose: + sys.stdout.write( + " client: read %s from server, starting TLS...\n" + % repr(outdata)) + conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_DTLSv1) + wrapped = True + elif (indata == "ENDTLS" and + outdata.strip().lower().startswith("ok")): + # ENDTLS ok, switch back to clear text + if test_support.verbose: + sys.stdout.write( + " client: read %s from server, ending TLS...\n" + % repr(outdata)) + s = conn.unwrap() + wrapped = False + else: + if test_support.verbose: + sys.stdout.write( + " client: read %s from server\n" % repr(outdata)) + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + if wrapped: + conn.write("over\n") + else: + s.send("over\n") + s.close() + finally: + server.stop() + server.join() + + def test_socketserver(self): + """Using a SocketServer to create and manage SSL connections.""" + server = SocketServerHTTPSServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + if test_support.verbose: + sys.stdout.write('\n') + with open(CERTFILE, 'rb') as f: + d1 = f.read() + d2 = [] + # now fetch the same data from the HTTPS-UDP server + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + s.connect((HOST, server.port)) + fl = "/" + os.path.split(CERTFILE)[1] + s.write("GET " + fl + " HTTP/1.1\r\n" + + "Host: " + HOST + "\r\n\r\n") + content = False + last_buf = "" + while True: + try: + buf = last_buf + s.read() + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_ZERO_RETURN: + s = s.unwrap() # complete shutdown protocol with server + break + raise + if test_support.verbose: + sys.stdout.write( + " client: read %d bytes from remote server '%s'\n" + % (len(buf), server)) + if content: + d2.append(buf) + continue + ind = buf.find("\r\n\r\n") + if ind < 0: + last_buf = buf[-3:] # find double-newline across buffers + continue + d2.append(buf[ind + 4:]) + content = True + last_buf = "" + s.close() + self.assertEqual(d1, ''.join(d2)) + finally: + server.stop() + server.join() + + def test_asyncore_server(self): + """Check the example asyncore integration.""" + indata = "TEST MESSAGE of mixed case\n" + + if test_support.verbose: + sys.stdout.write("\n") + server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + s = ssl.wrap_socket(socket.socket(socket.AF_INET, + socket.SOCK_DGRAM)) + s.connect((HOST, server.port)) + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + self.fail( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + # wait for server thread to end + server.join() + + def test_recv_send(self): + """Test recv(), send() and friends.""" + if test_support.verbose: + sys.stdout.write("\n") + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + s = ssl.wrap_socket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_DTLSv1) + s.connect((HOST, server.port)) + try: + # helper methods for standardising recv* method signatures + def _recv_into(): + b = bytearray("\0"*100) + count = s.recv_into(b) + return b[:count] + + def _recvfrom_into(): + b = bytearray("\0"*100) + count, addr = s.recvfrom_into(b) + return b[:count] + + # (name, method, whether to expect success, *args) + send_methods = [ + ('send', s.send, True, []), + ('sendto', s.sendto, False, ["some.address"]), + ('sendall', s.sendall, True, []), + ] + recv_methods = [ + ('recv', s.recv, True, []), + ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recv_into', _recv_into, True, []), + ('recvfrom_into', _recvfrom_into, False, []), + ] + data_prefix = u"PREFIX_" + + for meth_name, send_meth, expect_success, args in send_methods: + indata = data_prefix + meth_name + try: + send_meth(indata.encode('ASCII', 'strict'), *args) + outdata = s.read() + outdata = outdata.decode('ASCII', 'strict') + if outdata != indata.lower(): + self.fail( + "While sending with <<%s>> bad data " + "<<%r>> (%d) received; " + "expected <<%r>> (%d)\n" % ( + meth_name, outdata[:20], len(outdata), + indata[:20], len(indata) + ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to send with method <<%s>>; " + "expected to succeed.\n" % (meth_name,) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<%s>> failed with unexpected " + "exception message: %s\n" % ( + meth_name, e + ) + ) + + for meth_name, recv_meth, expect_success, args in recv_methods: + indata = data_prefix + meth_name + try: + s.send(indata.encode('ASCII', 'strict')) + outdata = recv_meth(*args) + outdata = outdata.decode('ASCII', 'strict') + if outdata != indata.lower(): + self.fail( + "While receiving with <<%s>> bad data " + "<<%r>> (%d) received; " + "expected <<%r>> (%d)\n" % ( + meth_name, outdata[:20], len(outdata), + indata[:20], len(indata) + ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to receive with method <<%s>>; " + "expected to succeed.\n" % (meth_name,) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<%s>> failed with unexpected " + "exception message: %s\n" % ( + meth_name, e + ) + ) + # consume data + s.read() + + s.write("over\n".encode("ASCII", "strict")) + s.close() + finally: + server.stop() + server.join() + + def test_handshake_timeout(self): + # Issue #5103: SSL handshake must respect the socket timeout + server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + server.bind((HOST, 0)) + port = server.getsockname()[1] + + try: + try: + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(0.2) + c.connect((HOST, port)) + # Will attempt handshake and time out + self.assertRaisesRegexp(ssl.SSLError, "timed out", + ssl.wrap_socket, c) + finally: + c.close() + try: + c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + c.settimeout(0.2) + c = ssl.wrap_socket(c) + # Will attempt handshake and time out + self.assertRaisesRegexp(ssl.SSLError, "timed out", + c.connect, (HOST, port)) + finally: + c.close() + finally: + server.close() + + +def test_main(verbose=True): + global CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE + CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "keycert.pem") + ISSUER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "ca-cert.pem") + OTHER_CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, + "certs", "yahoo-cert.pem") + + for fl in CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE: + if not os.path.exists(fl): + raise Exception("Can't read certificate files!") + + TestSupport.verbose = verbose + do_patch() + unittest.main() + +if __name__ == "__main__": + verbose = True if len(sys.argv) > 1 and sys.argv[1] == "-v" else False + test_main(verbose)