diff --git a/dtls/__init__.py b/dtls/__init__.py index 93f6688..6d40c37 100644 --- a/dtls/__init__.py +++ b/dtls/__init__.py @@ -17,3 +17,4 @@ wrap_socket's parameters and their semantics have been maintained. from patch import do_patch from sslconnection import SSLConnection +from demux import force_routing_demux, reset_default_demux diff --git a/dtls/demux/__init__.py b/dtls/demux/__init__.py index ee53ba6..7875fd3 100644 --- a/dtls/demux/__init__.py +++ b/dtls/demux/__init__.py @@ -20,8 +20,28 @@ import sys if sys.platform.startswith('win') or sys.platform.startswith('cygwin'): from router import UDPDemux + _routing = True else: - #from osnet import UDPDemux - from router import UDPDemux + from osnet import UDPDemux + _routing = False +_default_demux = None -__all__ = ["UDPDemux"] +def force_routing_demux(): + global _routing + if _routing: + return False # no change - already loaded + global UDPDemux, _default_demux + import router + _default_demux = UDPDemux + UDPDemux = router.UDPDemux + _routing = True + return True # new router loaded and switched + +def reset_default_demux(): + global UDPDemux, _routing, _default_demux + if _default_demux: + UDPDemux = _default_demux + _default_demux = None + _routing = not _routing + +__all__ = ["UDPDemux", "force_routing_demux", "reset_default_demux"] diff --git a/dtls/demux/osnet.py b/dtls/demux/osnet.py index cb7e607..f1f8fe7 100644 --- a/dtls/demux/osnet.py +++ b/dtls/demux/osnet.py @@ -17,9 +17,16 @@ Classes: Exceptions: + InvalidSocketError -- exception raised for improper socket objects KeyError -- raised for unknown peer addresses """ +import socket +from logging import getLogger +from ..err import InvalidSocketError + +_logger = getLogger(__name__) + class UDPDemux(object): """OS network stack configuring demux @@ -31,10 +38,34 @@ class UDPDemux(object): Methods: get_connection -- create a new connection or retrieve an existing one - remove_connection -- remove an existing connection service -- this method does nothing for this type of demux """ + def __init__(self, datagram_socket): + """Constructor + + Arguments: + datagram_socket -- the root socket; this must be a bound, unconnected + datagram socket + """ + + if datagram_socket.type != socket.SOCK_DGRAM: + raise InvalidSocketError("datagram_socket is not of " + + "type SOCK_DGRAM") + try: + datagram_socket.getsockname() + except: + raise InvalidSocketError("datagram_socket is unbound") + try: + datagram_socket.getpeername() + except: + pass + else: + raise InvalidSocketError("datagram_socket is connected") + + datagram_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._datagram_socket = datagram_socket + def get_connection(self, address): """Create or retrieve a muxed connection @@ -47,16 +78,27 @@ class UDPDemux(object): in case address was None """ - def remove_connection(self, address): - """Remove a muxed connection + if not address: + return self._datagram_socket - Arguments: - address -- an address for which a muxed connection was previously - retrieved through get_connection, which has not yet - been removed + # Create a new datagram socket bound to the same interface and port as + # the root socket, but connected to the given peer + conn = socket.socket(self._datagram_socket.family, + self._datagram_socket.type, + self._datagram_socket.proto) + conn.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + conn.bind(self._datagram_socket.getsockname()) + conn.connect(address) + _logger.debug("Created new connection for address: %s", address) + return conn - Return: - the socket object whose connection has been removed + @staticmethod + def service(): + """Service the root socket + + This type of demux performs no servicing work on the root socket, + and instead advises the caller to proceed to listening on the root + socket. """ - return self.connections.pop(address) + return True diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index d2323b5..7fbcb9d 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -116,14 +116,21 @@ class SSLConnection(object): def _init_server(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: raise InvalidSocketError("sock must be of type SOCK_DGRAM") - if peer_address: - raise InvalidSocketError("server-side socket must be unconnected") - from demux import UDPDemux - self._udp_demux = UDPDemux(self._sock) - self._rsock = self._udp_demux.get_connection(None) self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) - self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) + if peer_address: + # Connect directly to this client peer, bypassing the demux + rsock = self._sock + BIO_dgram_set_connected(self._wbio.value, peer_address) + else: + from demux import UDPDemux + self._udp_demux = UDPDemux(self._sock) + rsock = self._udp_demux.get_connection(None) + if rsock is self._sock: + self._rbio = self._wbio + else: + self._rsock = rsock + self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) self._ctx = _CTX(SSL_CTX_new(DTLSv1_server_method())) SSL_CTX_set_session_cache_mode(self._ctx.value, SSL_SESS_CACHE_OFF) if self._cert_reqs == CERT_NONE: @@ -133,16 +140,20 @@ class SSLConnection(object): else: verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE | \ SSL_VERIFY_FAIL_IF_NO_PEER_CERT - self._listening = False - self._listening_peer_address = None - self._pending_peer_address = None self._config_ssl_ctx(verify_mode) - self._cb_keepalive = SSL_CTX_set_cookie_cb( - self._ctx.value, - _CallbackProxy(self._generate_cookie_cb), - _CallbackProxy(self._verify_cookie_cb)) + if not peer_address: + # Configure UDP listening socket + self._listening = False + self._listening_peer_address = None + self._pending_peer_address = None + self._cb_keepalive = SSL_CTX_set_cookie_cb( + self._ctx.value, + _CallbackProxy(self._generate_cookie_cb), + _CallbackProxy(self._verify_cookie_cb)) self._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_accept_state(self._ssl.value) + if peer_address and self._do_handshake_on_connect: + return lambda: self.do_handshake() def _init_client(self, peer_address): if self._sock.type != socket.SOCK_DGRAM: @@ -179,19 +190,27 @@ class SSLConnection(object): def _copy_server(self): source = self._sock - self._sock = source._sock self._udp_demux = source._udp_demux - self._rsock = self._udp_demux.get_connection( - source._pending_peer_address) - self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) - self._rbio = _BIO(BIO_new_dgram(self._rsock.fileno(), BIO_NOCLOSE)) - BIO_dgram_set_peer(self._wbio.value, source._pending_peer_address) + rsock = self._udp_demux.get_connection(source._pending_peer_address) self._ctx = source._ctx self._ssl = source._ssl new_source_wbio = _BIO(BIO_new_dgram(source._sock.fileno(), BIO_NOCLOSE)) - new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), - BIO_NOCLOSE)) + if hasattr(source, "_rsock"): + self._sock = source._sock + self._rsock = rsock + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = _BIO(BIO_new_dgram(rsock.fileno(), BIO_NOCLOSE)) + new_source_rbio = _BIO(BIO_new_dgram(source._rsock.fileno(), + BIO_NOCLOSE)) + BIO_dgram_set_peer(self._wbio.value, source._pending_peer_address) + else: + self._sock = rsock + self._wbio = _BIO(BIO_new_dgram(self._sock.fileno(), BIO_NOCLOSE)) + self._rbio = self._wbio + new_source_rbio = new_source_wbio + BIO_dgram_set_connected(self._wbio.value, + source._pending_peer_address) source._ssl = _SSL(SSL_new(self._ctx.value)) SSL_set_accept_state(source._ssl.value) source._rbio = new_source_rbio @@ -349,6 +368,9 @@ class SSLConnection(object): encountered, None if a datagram for a known peer was forwarded """ + if not hasattr(self, "_listening"): + raise InvalidSocketError("listen called on non-listening socket") + self._pending_peer_address = None try: peer_address = self._udp_demux.service() @@ -521,12 +543,13 @@ class SSLConnection(object): lambda: SSL_shutdown(self._ssl.value), ERR_READ_TIMEOUT) else: raise - if hasattr(self, "_udp_demux"): + if hasattr(self, "_rsock"): # Return wrapped connected server socket (non-listening) return _UnwrappedSocket(self._sock, self._rsock, self._udp_demux, self._ctx, BIO_dgram_get_peer(self._wbio.value)) - # Return unwrapped client-side socket + # Return unwrapped client-side socket or unwrapped server-side socket + # for single-socket servers return self._sock def getpeercert(self, binary_form=False): diff --git a/dtls/test/unit.py b/dtls/test/unit.py index 8cedc39..0f4a449 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -22,7 +22,7 @@ from SimpleHTTPServer import SimpleHTTPRequestHandler from collections import OrderedDict import ssl -from dtls import do_patch +from dtls import do_patch, force_routing_demux, reset_default_demux HOST = "localhost" CONNECTION_TIMEOUT = datetime.timedelta(seconds=30) @@ -1347,14 +1347,18 @@ def test_main(verbose=True): raise Exception("Can't read certificate files!") TestSupport.verbose = verbose + reset_default_demux() do_patch() - AF_INET4_6 = socket.AF_INET - res = unittest.main(exit=False).result.wasSuccessful() - if not res: - print "IPv4 test suite failed; not proceeding to IPv6" - sys.exit(not res) - AF_INET4_6 = socket.AF_INET6 - unittest.main() + for demux in "platform-native", "routing": + for AF_INET4_6 in socket.AF_INET, socket.AF_INET6: + print "Suite run: demux: %s, protocol: %d" % (demux, AF_INET4_6) + res = unittest.main(exit=False).result.wasSuccessful() + if not res: + print "Suite run failed: demux: %s, protocol: %d" % ( + demux, AF_INET4_6) + sys.exit(True) + if not force_routing_demux(): + break if __name__ == "__main__": verbose = True if len(sys.argv) > 1 and sys.argv[1] == "-v" else False