From 77a50c73820996115044dfedd287ca8648323317 Mon Sep 17 00:00:00 2001 From: Ray Brown Date: Tue, 27 Nov 2012 21:33:09 -0800 Subject: [PATCH] Implement and turn on the osnet demux This change introduces a demux that uses the kernel's network stack for UDP datagram-to-socket assignment based on packet source address (as opposed to the forwarding strategy of the routing demux). The osnet demux is used by default on non-Windows platforms. When possible, use of the osnet demux is preferred over the routing demux, since it can be expected to perform better. The unit test suite has been extended to run all tests first with the demux selected by default for the current platform, and then with the routing demux, if the latter differs from the former. Tests were already being run twice, first with IPv4 and then with IPv6, and thus we now run each test four times on Linux, twice on Windows. All unit tests pass with both demux types. --- dtls/__init__.py | 1 + dtls/demux/__init__.py | 26 ++++++++++++++-- dtls/demux/osnet.py | 62 +++++++++++++++++++++++++++++++------ dtls/sslconnection.py | 69 ++++++++++++++++++++++++++++-------------- dtls/test/unit.py | 20 +++++++----- 5 files changed, 134 insertions(+), 44 deletions(-) diff --git a/dtls/__init__.py b/dtls/__init__.py index 93f6688..6d40c37 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -17,3 +17,4 @@ wrap_socket's parameters and their semantics have been maintained. from patch import do_patch from sslconnection import SSLConnection +from demux import force_routing_demux, reset_default_demux diff --git a/dtls/demux/__init__.py b/dtls/demux/__init__.py index ee53ba6..7875fd3 100644 --- a/dtls/demux/__init__.py +++ b/dtls/demux/__init__.py @@ -20,8 +20,28 @@ import sys if sys.platform.startswith('win') or sys.platform.startswith('cygwin'): from router import UDPDemux + _routing = True else: - #from osnet import UDPDemux - from router import UDPDemux + from osnet import UDPDemux + _routing = False +_default_demux = None -__all__ = ["UDPDemux"] +def force_routing_demux(): + global _routing + if _routing: + return False # no change - already loaded + global UDPDemux, _default_demux + import router + _default_demux = UDPDemux + UDPDemux = router.UDPDemux + _routing = True + return True # new router loaded and switched + +def reset_default_demux(): + global UDPDemux, _routing, _default_demux + if _default_demux: + UDPDemux = _default_demux + _default_demux = None + _routing = not _routing + +__all__ = ["UDPDemux", "force_routing_demux", "reset_default_demux"] diff --git a/dtls/demux/osnet.py b/dtls/demux/osnet.py index cb7e607..f1f8fe7 100644 --- a/dtls/demux/osnet.py +++ b/dtls/demux/osnet.py @@ -17,9 +17,16 @@ Classes: Exceptions: + InvalidSocketError -- exception raised for improper socket objects KeyError -- raised for unknown peer addresses """ +import socket +from logging import getLogger +from ..err import InvalidSocketError + +_logger = getLogger(__name__) + class UDPDemux(object): """OS network stack configuring demux @@ -31,10 +38,34 @@ class UDPDemux(object): Methods: get_connection -- create a new connection or retrieve an existing one - remove_connection -- remove an existing connection service -- this method does nothing for this type of demux """ + def __init__(self, datagram_socket): + """Constructor + + Arguments: + datagram_socket -- the root socket; this must be a bound, unconnected + datagram socket + """ + + if datagram_socket.type != socket.SOCK_DGRAM: + raise InvalidSocketError("datagram_socket is not of " + + "type SOCK_DGRAM") + try: + datagram_socket.getsockname() + except: + raise InvalidSocketError("datagram_socket is unbound") + try: + datagram_socket.getpeername() + except: + pass + else: + raise InvalidSocketError("datagram_socket is connected") + + datagram_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._datagram_socket = datagram_socket + def get_connection(self, address): """Create or retrieve a muxed connection @@ -47,16 +78,27 @@ class UDPDemux(object): in case address was None """ - def remove_connection(self, address): - """Remove a muxed connection + if not address: + return self._datagram_socket - Arguments: - address -- an address for which a muxed connection was previously - retrieved through get_connection, which has not yet - been removed + # Create a new datagram socket bound to the same interface and port as + # the root socket, but connected to the given peer + conn = socket.socket(self._datagram_socket.family, + self._datagram_socket.type, + self._datagram_socket.proto) + conn.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + conn.bind(self._datagram_socket.getsockname()) + conn.connect(address) + _logger.debug("Created new connection for address: %s", address) + return conn - Return: - the socket object whose connection has been removed + @staticmethod + def service(): + """Service the root socket + + This type of demux performs no servicing work on the root socket, + and instead advises the caller to proceed to listening on the root + socket. """ - return self.connections.pop(address) + return True diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index d2323b5..7fbcb9d 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -116,14 +116,21 @@ class SSLConnection(object): def _init_server(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") - if peer_address: - raise InvalidSocketError("server-side socket must be unconnected") - from demux import UDPDemux - self._udp_demux = UDPDemux(self._sock) - self._rsock = self._udp_demux.get_connection(None) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) - self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + if peer_address: + # Connect directly to this client peer, bypassing the demux + rsock = self._sock + BIO_dgram_set_connected(self._wbio.value, peer_address) + else: + from demux import UDPDemux + self._udp_demux = UDPDemux(self._sock) + rsock = self._udp_demux.get_connection(None) + if rsock is self._sock: + self._rbio = self._wbio + else: + self._rsock = rsock + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) if self._cert_reqs == CERT_NONE: @@ -133,16 +140,20 @@ class SSLConnection(object): else: verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE | \ SSL_VERIFY_FAIL_IF_NO_PEER_CERT - self._listening = False - self._listening_peer_address = None - self._pending_peer_address = None self._config_ssl_ctx(verify_mode) - self._cb_keepalive = SSL_CTX_set_cookie_cb( - self._ctx.value, - _CallbackProxy(self._generate_cookie_cb), - _CallbackProxy(self._verify_cookie_cb)) + if not peer_address: + # Configure UDP listening socket + self._listening = False + self._listening_peer_address = None + self._pending_peer_address = None + self._cb_keepalive = SSL_CTX_set_cookie_cb( + self._ctx.value, + _CallbackProxy(self._generate_cookie_cb), + _CallbackProxy(self._verify_cookie_cb)) self._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_accept_state(self._ssl.value) + if peer_address and self._do_handshake_on_connect: + return lambda: self.do_handshake() def _init_client(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: @@ -179,19 +190,27 @@ class SSLConnection(object): def _copy_server(self): source = self._sock - self._sock = source._sock self._udp_demux = source._udp_demux - self._rsock = self._udp_demux.get_connection( - source._pending_peer_address) - self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) - self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) - BIO_dgram_set_peer(self._wbio.value, source._pending_peer_address) + rsock = self._udp_demux.get_connection(source._pending_peer_address) self._ctx = source._ctx self._ssl = source._ssl new_source_wbio = _BIO(BIO_new_dgram(source._sock.fileno(), BIO_NOCLOSE)) - new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), - BIO_NOCLOSE)) + if hasattr(source, "_rsock"): + self._sock = source._sock + self._rsock = rsock + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = _BIO(BIO_new_dgram(rsock.fileno(), BIO_NOCLOSE)) + new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), + BIO_NOCLOSE)) + BIO_dgram_set_peer(self._wbio.value, source._pending_peer_address) + else: + self._sock = rsock + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = self._wbio + new_source_rbio = new_source_wbio + BIO_dgram_set_connected(self._wbio.value, + source._pending_peer_address) source._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_accept_state(source._ssl.value) source._rbio = new_source_rbio @@ -349,6 +368,9 @@ class SSLConnection(object): encountered, None if a datagram for a known peer was forwarded """ + if not hasattr(self, "_listening"): + raise InvalidSocketError("listen called on non-listening socket") + self._pending_peer_address = None try: peer_address = self._udp_demux.service() @@ -521,12 +543,13 @@ class SSLConnection(object): lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT) else: raise - if hasattr(self, "_udp_demux"): + if hasattr(self, "_rsock"): # Return wrapped connected server socket (non-listening) return _UnwrappedSocket(self._sock, self._rsock, self._udp_demux, self._ctx, BIO_dgram_get_peer(self._wbio.value)) - # Return unwrapped client-side socket + # Return unwrapped client-side socket or unwrapped server-side socket + # for single-socket servers return self._sock def getpeercert(self, binary_form=False): diff --git a/dtls/test/unit.py b/dtls/test/unit.py index 8cedc39..0f4a449 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -22,7 +22,7 @@ from SimpleHTTPServer import SimpleHTTPRequestHandler from collections import OrderedDict import ssl -from dtls import do_patch +from dtls import do_patch, force_routing_demux, reset_default_demux HOST = "localhost" CONNECTION_TIMEOUT = datetime.timedelta(seconds=30) @@ -1347,14 +1347,18 @@ def test_main(verbose=True): raise Exception("Can't read certificate files!") TestSupport.verbose = verbose + reset_default_demux() do_patch() - AF_INET4_6 = socket.AF_INET - res = unittest.main(exit=False).result.wasSuccessful() - if not res: - print "IPv4 test suite failed; not proceeding to IPv6" - sys.exit(not res) - AF_INET4_6 = socket.AF_INET6 - unittest.main() + for demux in "platform-native", "routing": + for AF_INET4_6 in socket.AF_INET, socket.AF_INET6: + print "Suite run: demux: %s, protocol: %d" % (demux, AF_INET4_6) + res = unittest.main(exit=False).result.wasSuccessful() + if not res: + print "Suite run failed: demux: %s, protocol: %d" % ( + demux, AF_INET4_6) + sys.exit(True) + if not force_routing_demux(): + break if __name__ == "__main__": verbose = True if len(sys.argv) > 1 and sys.argv[1] == "-v" else False