Add optional parameter to DtlsSocket: client_timeout (seconds)
If client_timeout is specified, clients that have not communicated within the time frame will be dropped.incoming
parent
0d6ee12121
commit
80d05b7d82
|
@ -72,14 +72,27 @@ class DtlsSocket(object):
|
||||||
|
|
||||||
class _ClientSession(object):
|
class _ClientSession(object):
|
||||||
|
|
||||||
def __init__(self, host, port, handshake_done=False):
|
def __init__(self, host, port, handshake_done=False, timeout=None):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = int(port)
|
self.port = int(port)
|
||||||
self.handshake_done = handshake_done
|
self.handshake_done = handshake_done
|
||||||
|
self.timeout = timeout
|
||||||
|
self.updateTimestamp()
|
||||||
|
|
||||||
def getAddr(self):
|
def getAddr(self):
|
||||||
return self.host, self.port
|
return self.host, self.port
|
||||||
|
|
||||||
|
def updateTimestamp(self):
|
||||||
|
if self.timeout != None:
|
||||||
|
self.last_update = time.time()
|
||||||
|
|
||||||
|
def expired(self):
|
||||||
|
if self.timeout == None:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return (time.time() - self.last_update) > self.timeout
|
||||||
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
sock=None,
|
sock=None,
|
||||||
keyfile=None,
|
keyfile=None,
|
||||||
|
@ -95,7 +108,8 @@ class DtlsSocket(object):
|
||||||
sigalgs=None,
|
sigalgs=None,
|
||||||
user_mtu=None,
|
user_mtu=None,
|
||||||
server_key_exchange_curve=None,
|
server_key_exchange_curve=None,
|
||||||
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
|
server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE,
|
||||||
|
client_timeout=None):
|
||||||
|
|
||||||
if server_cert_options is None:
|
if server_cert_options is None:
|
||||||
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
|
server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
|
||||||
|
@ -108,6 +122,7 @@ class DtlsSocket(object):
|
||||||
self._user_mtu = user_mtu
|
self._user_mtu = user_mtu
|
||||||
self._server_key_exchange_curve = server_key_exchange_curve
|
self._server_key_exchange_curve = server_key_exchange_curve
|
||||||
self._server_cert_options = server_cert_options
|
self._server_cert_options = server_cert_options
|
||||||
|
self._client_timeout = client_timeout
|
||||||
|
|
||||||
# Default socket creation
|
# Default socket creation
|
||||||
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
@ -205,6 +220,7 @@ class DtlsSocket(object):
|
||||||
else:
|
else:
|
||||||
buf = self._clientRead(conn, bufsize)
|
buf = self._clientRead(conn, bufsize)
|
||||||
if buf:
|
if buf:
|
||||||
|
self._clients[conn].updateTimestamp()
|
||||||
if conn in self._clients:
|
if conn in self._clients:
|
||||||
return buf, self._clients[conn].getAddr()
|
return buf, self._clients[conn].getAddr()
|
||||||
else:
|
else:
|
||||||
|
@ -220,6 +236,10 @@ class DtlsSocket(object):
|
||||||
ret = conn.handle_timeout()
|
ret = conn.handle_timeout()
|
||||||
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
|
_logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
|
||||||
|
|
||||||
|
if self._clients[conn].expired() == True:
|
||||||
|
_logger.debug('Found expired session')
|
||||||
|
self._clientDrop(conn)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue