diff --git a/dtls/wrapper.py b/dtls/wrapper.py index 3a53d26..cdb267c 100644 --- a/dtls/wrapper.py +++ b/dtls/wrapper.py @@ -72,13 +72,26 @@ class DtlsSocket(object): class _ClientSession(object): - def __init__(self, host, port, handshake_done=False): + def __init__(self, host, port, handshake_done=False, timeout=None): self.host = host self.port = int(port) self.handshake_done = handshake_done + self.timeout = timeout + self.updateTimestamp() def getAddr(self): return self.host, self.port + + def updateTimestamp(self): + if self.timeout != None: + self.last_update = time.time() + + def expired(self): + if self.timeout == None: + return False + else: + return (time.time() - self.last_update) > self.timeout + def __init__(self, sock=None, @@ -95,7 +108,8 @@ class DtlsSocket(object): sigalgs=None, user_mtu=None, server_key_exchange_curve=None, - server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): + server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE, + client_timeout=None): if server_cert_options is None: server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE @@ -108,6 +122,7 @@ class DtlsSocket(object): self._user_mtu = user_mtu self._server_key_exchange_curve = server_key_exchange_curve self._server_cert_options = server_cert_options + self._client_timeout = client_timeout # Default socket creation _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -205,6 +220,7 @@ class DtlsSocket(object): else: buf = self._clientRead(conn, bufsize) if buf: + self._clients[conn].updateTimestamp() if conn in self._clients: return buf, self._clients[conn].getAddr() else: @@ -220,6 +236,10 @@ class DtlsSocket(object): ret = conn.handle_timeout() _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret)) + if self._clients[conn].expired() == True: + _logger.debug('Found expired session') + self._clientDrop(conn) + except Exception as e: raise e