Add optional parameter to DtlsSocket: client_timeout (seconds)
If client_timeout is specified, clients that have not communicated within the time frame will be dropped.
This commit is contained in:
		
							parent
							
								
									0d6ee12121
								
							
						
					
					
						commit
						80d05b7d82
					
				@ -72,14 +72,27 @@ class DtlsSocket(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    class _ClientSession(object):
 | 
					    class _ClientSession(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def __init__(self, host, port, handshake_done=False):
 | 
					        def __init__(self, host, port, handshake_done=False, timeout=None):
 | 
				
			||||||
            self.host = host
 | 
					            self.host = host
 | 
				
			||||||
            self.port = int(port)
 | 
					            self.port = int(port)
 | 
				
			||||||
            self.handshake_done = handshake_done
 | 
					            self.handshake_done = handshake_done
 | 
				
			||||||
 | 
					            self.timeout = timeout
 | 
				
			||||||
 | 
					            self.updateTimestamp()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def getAddr(self):
 | 
					        def getAddr(self):
 | 
				
			||||||
            return self.host, self.port
 | 
					            return self.host, self.port
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        def updateTimestamp(self):
 | 
				
			||||||
 | 
					            if self.timeout != None:
 | 
				
			||||||
 | 
					                self.last_update = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def expired(self):
 | 
				
			||||||
 | 
					            if self.timeout == None:
 | 
				
			||||||
 | 
					                return False
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return (time.time() - self.last_update) > self.timeout
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 sock=None,
 | 
					                 sock=None,
 | 
				
			||||||
                 keyfile=None,
 | 
					                 keyfile=None,
 | 
				
			||||||
@ -95,7 +108,8 @@ class DtlsSocket(object):
 | 
				
			|||||||
                 sigalgs=None,
 | 
					                 sigalgs=None,
 | 
				
			||||||
                 user_mtu=None,
 | 
					                 user_mtu=None,
 | 
				
			||||||
                 server_key_exchange_curve=None,
 | 
					                 server_key_exchange_curve=None,
 | 
				
			||||||
                 server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE):
 | 
					                 server_cert_options=ssl.SSL_BUILD_CHAIN_FLAG_NONE,
 | 
				
			||||||
 | 
					                 client_timeout=None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if server_cert_options is None:
 | 
					        if server_cert_options is None:
 | 
				
			||||||
            server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
 | 
					            server_cert_options = ssl.SSL_BUILD_CHAIN_FLAG_NONE
 | 
				
			||||||
@ -108,6 +122,7 @@ class DtlsSocket(object):
 | 
				
			|||||||
        self._user_mtu = user_mtu
 | 
					        self._user_mtu = user_mtu
 | 
				
			||||||
        self._server_key_exchange_curve = server_key_exchange_curve
 | 
					        self._server_key_exchange_curve = server_key_exchange_curve
 | 
				
			||||||
        self._server_cert_options = server_cert_options
 | 
					        self._server_cert_options = server_cert_options
 | 
				
			||||||
 | 
					        self._client_timeout = client_timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Default socket creation
 | 
					        # Default socket creation
 | 
				
			||||||
        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
					        _sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
				
			||||||
@ -205,6 +220,7 @@ class DtlsSocket(object):
 | 
				
			|||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        buf = self._clientRead(conn, bufsize)
 | 
					                        buf = self._clientRead(conn, bufsize)
 | 
				
			||||||
                        if buf:
 | 
					                        if buf:
 | 
				
			||||||
 | 
					                            self._clients[conn].updateTimestamp()
 | 
				
			||||||
                            if conn in self._clients:
 | 
					                            if conn in self._clients:
 | 
				
			||||||
                                return buf, self._clients[conn].getAddr()
 | 
					                                return buf, self._clients[conn].getAddr()
 | 
				
			||||||
                            else:
 | 
					                            else:
 | 
				
			||||||
@ -220,6 +236,10 @@ class DtlsSocket(object):
 | 
				
			|||||||
                    ret = conn.handle_timeout()
 | 
					                    ret = conn.handle_timeout()
 | 
				
			||||||
                    _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
 | 
					                    _logger.debug('Retransmission triggered for %s: %d' % (str(self._clients[conn].getAddr()), ret))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if self._clients[conn].expired() == True:
 | 
				
			||||||
 | 
					                    _logger.debug('Found expired session')
 | 
				
			||||||
 | 
					                    self._clientDrop(conn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            raise e
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user