diff --git a/dtls/openssl.py b/dtls/openssl.py index 0c63d1f..bea985c 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -122,11 +122,15 @@ class FuncParam(object): return value._as_parameter def __init__(self, value): - self._as_parameter = value + self._as_parameter = c_void_p(value) def __nonzero__(self): return bool(self._as_parameter) + @property + def raw(self): + return self._as_parameter.value + class DTLSv1Method(FuncParam): def __init__(self, value): @@ -268,15 +272,15 @@ class sockaddr_storage(Structure): class sockaddr_in(Structure): _fields_ = [("sin_family", c_short), ("sin_port", c_ushort), - ("sin_addr", c_ulong * 1), + ("sin_addr", c_uint * 1), ("sin_zero", c_char * 8)] class sockaddr_in6(Structure): _fields_ = [("sin6_family", c_short), ("sin6_port", c_ushort), - ("sin6_flowinfo", c_ulong), - ("sin6_addr", c_ulong * 4), - ("sin6_scope_id", c_ulong)] + ("sin6_flowinfo", c_uint), + ("sin6_addr", c_uint * 4), + ("sin6_scope_id", c_uint)] class sockaddr_u(Union): _fields_ = [("ss", sockaddr_storage), @@ -302,27 +306,27 @@ if not py_inet_pton: def inet_ntop(address_family, packed_ip): if py_inet_ntop: return py_inet_ntop(address_family, - array.array('L', packed_ip).tostring()) + array.array('I', packed_ip).tostring()) if wsa_inet_ntop: string_buf = create_string_buffer(47) wsa_inet_ntop(address_family, packed_ip, string_buf, sizeof(string_buf)) if not string_buf.value: raise ValueError("wsa_inet_ntop failed with: %s" % - array.array('L', packed_ip).tostring()) + array.array('I', packed_ip).tostring()) return string_buf.value if address_family == socket.AF_INET6: raise ValueError("Platform does not support IPv6") - return socket.inet_ntoa(array.array('L', packed_ip).tostring()) + return socket.inet_ntoa(array.array('I', packed_ip).tostring()) def inet_pton(address_family, string_ip): if address_family == socket.AF_INET6: - ret_packed_ip = (c_ulong * 4)() + ret_packed_ip = (c_uint * 4)() else: - ret_packed_ip = (c_ulong * 1)() + ret_packed_ip = (c_uint * 1)() if py_inet_pton: ret_string = py_inet_pton(address_family, string_ip) - ret_packed_ip[:] = array.array('L', ret_string) + ret_packed_ip[:] = array.array('I', ret_string) elif wsa_inet_pton: if wsa_inet_pton(address_family, string_ip, ret_packed_ip) != 1: raise ValueError("wsa_inet_pton failed with: %s" % string_ip) @@ -330,7 +334,7 @@ def inet_pton(address_family, string_ip): if address_family == socket.AF_INET6: raise ValueError("Platform does not support IPv6") ret_string = socket.inet_aton(string_ip) - ret_packed_ip[:] = array.array('L', ret_string) + ret_packed_ip[:] = array.array('I', ret_string) return ret_packed_ip def addr_tuple_from_sockaddr_u(su): @@ -393,6 +397,11 @@ def errcheck_p(result, func, args): raise_ssl_error(result, func, args, None) return args +def errcheck_FuncParam(result, func, args): + if not result: + raise_ssl_error(result, func, args, None) + return func.ret_type(result) + # # Function prototypes # @@ -405,6 +414,12 @@ def _make_function(name, lib, args, export=True, errcheck="default"): return map_type sig = tuple(type_subst(i[0]) for i in args) + # Handle pointer return values (width is architecture-dependent) + if isinstance(sig[0], type) and issubclass(sig[0], FuncParam): + sig = (c_void_p,) + sig[1:] + pointer_return = True + else: + pointer_return = False if not _sigs.has_key(sig): _sigs[sig] = CFUNCTYPE(*sig) if export: @@ -418,13 +433,16 @@ def _make_function(name, lib, args, export=True, errcheck="default"): [:3 if len(i) > 3 else 2] for i in args[1:])) func.func_name = name + if pointer_return: + func.ret_type = args[0][0] # for fix-up during error checking protocol if errcheck == "default": # Assign error checker based on return type if args[0][0] in (c_int,): errcheck = errcheck_ord - elif args[0][0] in (c_void_p, c_char_p) or \ - isinstance(args[0][0], type) and issubclass(args[0][0], FuncParam): + elif args[0][0] in (c_void_p, c_char_p): errcheck = errcheck_p + elif pointer_return: + errcheck = errcheck_FuncParam else: errcheck = None if errcheck: @@ -616,9 +634,9 @@ def SSL_CTX_set_read_ahead(ctx, m): # Returns the previous value of m _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_READ_AHEAD, m, None) -_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_int, POINTER(c_ubyte), +_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), POINTER(c_uint)) -_rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_int, POINTER(c_ubyte), c_uint) +_rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), c_uint) def SSL_CTX_set_cookie_cb(ctx, generate, verify): def py_generate_cookie_cb(ssl, cookie, cookie_len): diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 7fbcb9d..c088302 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -71,7 +71,7 @@ class _CTX(_Rsrc): super(_CTX, self).__init__(value) def __del__(self): - _logger.debug("Freeing SSL CTX: %d", self._value._as_parameter) + _logger.debug("Freeing SSL CTX: %d", self.raw) SSL_CTX_free(self._value) self._value = None @@ -82,7 +82,7 @@ class _SSL(_Rsrc): super(_SSL, self).__init__(value) def __del__(self): - _logger.debug("Freeing SSL: %d", self._value._as_parameter) + _logger.debug("Freeing SSL: %d", self.raw) SSL_free(self._value) self._value = None @@ -268,7 +268,7 @@ class SSLConnection(object): def _get_cookie(self, ssl): assert self._listening - assert self._ssl.value._as_parameter == ssl._as_parameter + assert self._ssl.raw == ssl.raw if self._listening_peer_address: peer_address = self._listening_peer_address else: @@ -397,7 +397,7 @@ class SSLConnection(object): self._listening = True try: _logger.debug("Invoking DTLSv1_listen for ssl: %d", - self._ssl.value._as_parameter) + self._ssl.raw) dtls_peer_address = DTLSv1_listen(self._ssl.value) except openssl_error() as err: if err.ssl_error == SSL_ERROR_WANT_READ: diff --git a/dtls/test/unit.py b/dtls/test/unit.py index 0f4a449..678d315 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -1333,6 +1333,22 @@ class ThreadedTests(unittest.TestCase): server.close() +def hostname_for_protocol(protocol): + global HOST + # We can't quite predict the content of the hosts file, but we prefer names + # to numbers in order to test name resolution; if we can't find a name, + # then fall back to a number for the given protocol + for name in HOST, "localhost", "ip6-localhost", "127.0.0.1", "::1": + try: + socket.getaddrinfo(name, 0, protocol) + except socket.error: + pass + else: + HOST = name + return + # Is the loopback interface enabled along with ipv6 for that interface? + raise Exception("Failed to select hostname for protocol %d" % protocol) + def test_main(verbose=True): global CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE, AF_INET4_6 CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, @@ -1352,6 +1368,7 @@ def test_main(verbose=True): for demux in "platform-native", "routing": for AF_INET4_6 in socket.AF_INET, socket.AF_INET6: print "Suite run: demux: %s, protocol: %d" % (demux, AF_INET4_6) + hostname_for_protocol(AF_INET4_6) res = unittest.main(exit=False).result.wasSuccessful() if not res: print "Suite run failed: demux: %s, protocol: %d" % ( diff --git a/dtls/util.py b/dtls/util.py index 4dc10c3..df938b9 100644 --- a/dtls/util.py +++ b/dtls/util.py @@ -19,6 +19,10 @@ class _Rsrc(object): def value(self): return self._value + @property + def raw(self): + return self._value.raw + class _BIO(_Rsrc): """BIO wrapper""" @@ -31,7 +35,7 @@ class _BIO(_Rsrc): def __del__(self): if self.owned: - _logger.debug("Freeing BIO: %d", self._value._as_parameter) + _logger.debug("Freeing BIO: %d", self.raw) from openssl import BIO_free BIO_free(self._value) self.owned = False