Add optional parameter to DtlsSocket: client_timeout (seconds)
If client_timeout is specified, clients that have not communicated within the time frame will be dropped.incoming
parent
0d6ee12121
commit
80d05b7d82
|
@ -72,13 +72,26 @@ class DtlsSocket(object):
|
|||
|
||||
class _ClientSession(object):
|
||||
|
||||
def __init__(self, host, port, handshake_done=False):
|
||||
def __init__(self, host, port, handshake_done=False, timeout=None):
|
||||
self.host = host
|
||||
self.port = int(port)
|
||||
self.handshake_done = handshake_done
|
||||
self.timeout = timeout
|
||||
self.updateTimestamp()
|
||||
|
||||
def getAddr(self):
|
||||
return self.host, self.port
|
||||
|
||||
def updateTimestamp(self):
|
||||
if self.timeout != None:
|
||||
self.last_update = time.time()
|
||||
|
||||
def expired(self):
|
||||
if self.timeout == None:
|
||||
return False
|
||||
else:
|
||||
return (time.time() - self.last_update) > self.timeout
|
||||
|
||||
|
||||
def __init__(self,
|
||||
sock=None,
|
||||
|
@ -95,7 +108,8 @@ class DtlsSocket(object):
|
|||
sigalgs=None,
|
||||
user_mtu=None,
|
||||
server_key_exchange_curve=None,
|
||||
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
|
||||
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE,
|
||||
client_timeout=None):
|
||||
|
||||
if server_cert_options is None:
|
||||
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
|
||||
|
@ -108,6 +122,7 @@ class DtlsSocket(object):
|
|||
self._user_mtu = user_mtu
|
||||
self._server_key_exchange_curve = server_key_exchange_curve
|
||||
self._server_cert_options = server_cert_options
|
||||
self._client_timeout = client_timeout
|
||||
|
||||
# Default socket creation
|
||||
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
@ -205,6 +220,7 @@ class DtlsSocket(object):
|
|||
else:
|
||||
buf = self._clientRead(conn, bufsize)
|
||||
if buf:
|
||||
self._clients[conn].updateTimestamp()
|
||||
if conn in self._clients:
|
||||
return buf, self._clients[conn].getAddr()
|
||||
else:
|
||||
|
@ -220,6 +236,10 @@ class DtlsSocket(object):
|
|||
ret = conn.handle_timeout()
|
||||
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
|
||||
|
||||
if self._clients[conn].expired() == True:
|
||||
_logger.debug('Found expired session')
|
||||
self._clientDrop(conn)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
|
Loading…
Reference in New Issue