From 80d05b7d82b334e05892beb01d0265650e2f9eaa Mon Sep 17 00:00:00 2001 From: Jason Youzwak Date: Wed, 26 Apr 2017 20:04:45 -0400 Subject: [PATCH] Add optional parameter to DtlsSocket: client_timeout (seconds) If client_timeout is specified, clients that have not communicated within the time frame will be dropped. --- dtls/wrapper.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/dtls/wrapper.py b/dtls/wrapper.py index 3a53d26..cdb267c 100644 --- a/dtls/wrapper.py +++ b/dtls/wrapper.py @@ -72,13 +72,26 @@ class DtlsSocket(object): class _ClientSession(object): - def __init__(self, host, port, handshake_done=False): + def __init__(self, host, port, handshake_done=False, timeout=None): self.host = host self.port = int(port) self.handshake_done = handshake_done + self.timeout = timeout + self.updateTimestamp() def getAddr(self): return self.host, self.port + + def updateTimestamp(self): + if self.timeout != None: + self.last_update = time.time() + + def expired(self): + if self.timeout == None: + return False + else: + return (time.time() - self.last_update) > self.timeout + def __init__(self, sock=None, @@ -95,7 +108,8 @@ class DtlsSocket(object): sigalgs=None, user_mtu=None, server_key_exchange_curve=None, - server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): + server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE, + client_timeout=None): if server_cert_options is None: server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE @@ -108,6 +122,7 @@ class DtlsSocket(object): self._user_mtu = user_mtu self._server_key_exchange_curve = server_key_exchange_curve self._server_cert_options = server_cert_options + self._client_timeout = client_timeout # Default socket creation _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -205,6 +220,7 @@ class DtlsSocket(object): else: buf = self._clientRead(conn, bufsize) if buf: + self._clients[conn].updateTimestamp() if conn in self._clients: return buf, self._clients[conn].getAddr() else: @@ -220,6 +236,10 @@ class DtlsSocket(object): ret = conn.handle_timeout() _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret)) + if self._clients[conn].expired() == True: + _logger.debug('Found expired session') + self._clientDrop(conn) + except Exception as e: raise e