Add optional parameter to DtlsSocket: client_timeout (seconds)

If client_timeout is specified, clients that have not communicated
within the time frame will be dropped.
incoming
Jason Youzwak 2017-04-26 20:04:45 -04:00
parent 0d6ee12121
commit 80d05b7d82
1 changed files with 22 additions and 2 deletions

View File

@ -72,13 +72,26 @@ class DtlsSocket(object):
class _ClientSession(object): class _ClientSession(object):
def __init__(self, host, port, handshake_done=False): def __init__(self, host, port, handshake_done=False, timeout=None):
self.host = host self.host = host
self.port = int(port) self.port = int(port)
self.handshake_done = handshake_done self.handshake_done = handshake_done
self.timeout = timeout
self.updateTimestamp()
def getAddr(self): def getAddr(self):
return self.host, self.port return self.host, self.port
def updateTimestamp(self):
if self.timeout != None:
self.last_update = time.time()
def expired(self):
if self.timeout == None:
return False
else:
return (time.time() - self.last_update) > self.timeout
def __init__(self, def __init__(self,
sock=None, sock=None,
@ -95,7 +108,8 @@ class DtlsSocket(object):
sigalgs=None, sigalgs=None,
user_mtu=None, user_mtu=None,
server_key_exchange_curve=None, server_key_exchange_curve=None,
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE): server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE,
client_timeout=None):
if server_cert_options is None: if server_cert_options is None:
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
@ -108,6 +122,7 @@ class DtlsSocket(object):
self._user_mtu = user_mtu self._user_mtu = user_mtu
self._server_key_exchange_curve = server_key_exchange_curve self._server_key_exchange_curve = server_key_exchange_curve
self._server_cert_options = server_cert_options self._server_cert_options = server_cert_options
self._client_timeout = client_timeout
# Default socket creation # Default socket creation
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -205,6 +220,7 @@ class DtlsSocket(object):
else: else:
buf = self._clientRead(conn, bufsize) buf = self._clientRead(conn, bufsize)
if buf: if buf:
self._clients[conn].updateTimestamp()
if conn in self._clients: if conn in self._clients:
return buf, self._clients[conn].getAddr() return buf, self._clients[conn].getAddr()
else: else:
@ -220,6 +236,10 @@ class DtlsSocket(object):
ret = conn.handle_timeout() ret = conn.handle_timeout()
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret)) _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
if self._clients[conn].expired() == True:
_logger.debug('Found expired session')
self._clientDrop(conn)
except Exception as e: except Exception as e:
raise e raise e