From 7c6a512f94e121730d24af1a214c9b63eeb2e50f Mon Sep 17 00:00:00 2001 From: Ray Brown Date: Sun, 2 Dec 2012 10:39:39 -0800 Subject: [PATCH] 64-bit port On a 64-bit OS, pointer return values needed to be marked as c_void_p instead of a user-defined type, which would result in the transfer of 32 bits only. In order to still return an instance of the user-defined type to the caller, imported functions are now marked with the return type, and the return value is converted to that type by a new error checking function used only with imported functions that create and return user-defined types. On 64-bit Linux, the long type becomes 8 bytes, whereas the int type remains 4 bytes. The various sockaddr_* fields therefore needed to be changed from long to int, as did the type signatures of the packed string to array conversion functions. On an Ubuntu server installation, it was found that the name "localhost" does not resolve to an ipv6 address. A name search has therefore been added to the unit test driver, along with an ip number fallback. Tested on Ubuntu Server 12.04.1 LTS 64-bit. Regression tested on Ubuntu 12.04.1 LTS 32-bit. --- dtls/openssl.py | 50 +++++++++++++++++++++++++++++-------------- dtls/sslconnection.py | 8 +++---- dtls/test/unit.py | 17 +++++++++++++++ dtls/util.py | 6 +++++- 4 files changed, 60 insertions(+), 21 deletions(-) diff --git a/dtls/openssl.py b/dtls/openssl.py index 0c63d1f..bea985c 100644 --- a/dtls/openssl.py +++ b/dtls/openssl.py @@ -122,11 +122,15 @@ class FuncParam(object): return value._as_parameter def __init__(self, value): - self._as_parameter = value + self._as_parameter = c_void_p(value) def __nonzero__(self): return bool(self._as_parameter) + @property + def raw(self): + return self._as_parameter.value + class DTLSv1Method(FuncParam): def __init__(self, value): @@ -268,15 +272,15 @@ class sockaddr_storage(Structure): class sockaddr_in(Structure): _fields_ = [("sin_family", c_short), ("sin_port", c_ushort), - ("sin_addr", c_ulong * 1), + ("sin_addr", c_uint * 1), ("sin_zero", c_char * 8)] class sockaddr_in6(Structure): _fields_ = [("sin6_family", c_short), ("sin6_port", c_ushort), - ("sin6_flowinfo", c_ulong), - ("sin6_addr", c_ulong * 4), - ("sin6_scope_id", c_ulong)] + ("sin6_flowinfo", c_uint), + ("sin6_addr", c_uint * 4), + ("sin6_scope_id", c_uint)] class sockaddr_u(Union): _fields_ = [("ss", sockaddr_storage), @@ -302,27 +306,27 @@ if not py_inet_pton: def inet_ntop(address_family, packed_ip): if py_inet_ntop: return py_inet_ntop(address_family, - array.array('L', packed_ip).tostring()) + array.array('I', packed_ip).tostring()) if wsa_inet_ntop: string_buf = create_string_buffer(47) wsa_inet_ntop(address_family, packed_ip, string_buf, sizeof(string_buf)) if not string_buf.value: raise ValueError("wsa_inet_ntop failed with: %s" % - array.array('L', packed_ip).tostring()) + array.array('I', packed_ip).tostring()) return string_buf.value if address_family == socket.AF_INET6: raise ValueError("Platform does not support IPv6") - return socket.inet_ntoa(array.array('L', packed_ip).tostring()) + return socket.inet_ntoa(array.array('I', packed_ip).tostring()) def inet_pton(address_family, string_ip): if address_family == socket.AF_INET6: - ret_packed_ip = (c_ulong * 4)() + ret_packed_ip = (c_uint * 4)() else: - ret_packed_ip = (c_ulong * 1)() + ret_packed_ip = (c_uint * 1)() if py_inet_pton: ret_string = py_inet_pton(address_family, string_ip) - ret_packed_ip[:] = array.array('L', ret_string) + ret_packed_ip[:] = array.array('I', ret_string) elif wsa_inet_pton: if wsa_inet_pton(address_family, string_ip, ret_packed_ip) != 1: raise ValueError("wsa_inet_pton failed with: %s" % string_ip) @@ -330,7 +334,7 @@ def inet_pton(address_family, string_ip): if address_family == socket.AF_INET6: raise ValueError("Platform does not support IPv6") ret_string = socket.inet_aton(string_ip) - ret_packed_ip[:] = array.array('L', ret_string) + ret_packed_ip[:] = array.array('I', ret_string) return ret_packed_ip def addr_tuple_from_sockaddr_u(su): @@ -393,6 +397,11 @@ def errcheck_p(result, func, args): raise_ssl_error(result, func, args, None) return args +def errcheck_FuncParam(result, func, args): + if not result: + raise_ssl_error(result, func, args, None) + return func.ret_type(result) + # # Function prototypes # @@ -405,6 +414,12 @@ def _make_function(name, lib, args, export=True, errcheck="default"): return map_type sig = tuple(type_subst(i[0]) for i in args) + # Handle pointer return values (width is architecture-dependent) + if isinstance(sig[0], type) and issubclass(sig[0], FuncParam): + sig = (c_void_p,) + sig[1:] + pointer_return = True + else: + pointer_return = False if not _sigs.has_key(sig): _sigs[sig] = CFUNCTYPE(*sig) if export: @@ -418,13 +433,16 @@ def _make_function(name, lib, args, export=True, errcheck="default"): [:3 if len(i) > 3 else 2] for i in args[1:])) func.func_name = name + if pointer_return: + func.ret_type = args[0][0] # for fix-up during error checking protocol if errcheck == "default": # Assign error checker based on return type if args[0][0] in (c_int,): errcheck = errcheck_ord - elif args[0][0] in (c_void_p, c_char_p) or \ - isinstance(args[0][0], type) and issubclass(args[0][0], FuncParam): + elif args[0][0] in (c_void_p, c_char_p): errcheck = errcheck_p + elif pointer_return: + errcheck = errcheck_FuncParam else: errcheck = None if errcheck: @@ -616,9 +634,9 @@ def SSL_CTX_set_read_ahead(ctx, m): # Returns the previous value of m _SSL_CTX_ctrl(ctx, SSL_CTRL_SET_READ_AHEAD, m, None) -_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_int, POINTER(c_ubyte), +_rint_voidp_ubytep_uintp = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), POINTER(c_uint)) -_rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_int, POINTER(c_ubyte), c_uint) +_rint_voidp_ubytep_uint = CFUNCTYPE(c_int, c_void_p, POINTER(c_ubyte), c_uint) def SSL_CTX_set_cookie_cb(ctx, generate, verify): def py_generate_cookie_cb(ssl, cookie, cookie_len): diff --git a/dtls/sslconnection.py b/dtls/sslconnection.py index 7fbcb9d..c088302 100644 --- a/dtls/sslconnection.py +++ b/dtls/sslconnection.py @@ -71,7 +71,7 @@ class _CTX(_Rsrc): super(_CTX, self).__init__(value) def __del__(self): - _logger.debug("Freeing SSL CTX: %d", self._value._as_parameter) + _logger.debug("Freeing SSL CTX: %d", self.raw) SSL_CTX_free(self._value) self._value = None @@ -82,7 +82,7 @@ class _SSL(_Rsrc): super(_SSL, self).__init__(value) def __del__(self): - _logger.debug("Freeing SSL: %d", self._value._as_parameter) + _logger.debug("Freeing SSL: %d", self.raw) SSL_free(self._value) self._value = None @@ -268,7 +268,7 @@ class SSLConnection(object): def _get_cookie(self, ssl): assert self._listening - assert self._ssl.value._as_parameter == ssl._as_parameter + assert self._ssl.raw == ssl.raw if self._listening_peer_address: peer_address = self._listening_peer_address else: @@ -397,7 +397,7 @@ class SSLConnection(object): self._listening = True try: _logger.debug("Invoking DTLSv1_listen for ssl: %d", - self._ssl.value._as_parameter) + self._ssl.raw) dtls_peer_address = DTLSv1_listen(self._ssl.value) except openssl_error() as err: if err.ssl_error == SSL_ERROR_WANT_READ: diff --git a/dtls/test/unit.py b/dtls/test/unit.py index 0f4a449..678d315 100644 --- a/dtls/test/unit.py +++ b/dtls/test/unit.py @@ -1333,6 +1333,22 @@ class ThreadedTests(unittest.TestCase): server.close() +def hostname_for_protocol(protocol): + global HOST + # We can't quite predict the content of the hosts file, but we prefer names + # to numbers in order to test name resolution; if we can't find a name, + # then fall back to a number for the given protocol + for name in HOST, "localhost", "ip6-localhost", "127.0.0.1", "::1": + try: + socket.getaddrinfo(name, 0, protocol) + except socket.error: + pass + else: + HOST = name + return + # Is the loopback interface enabled along with ipv6 for that interface? + raise Exception("Failed to select hostname for protocol %d" % protocol) + def test_main(verbose=True): global CERTFILE, ISSUER_CERTFILE, OTHER_CERTFILE, AF_INET4_6 CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, @@ -1352,6 +1368,7 @@ def test_main(verbose=True): for demux in "platform-native", "routing": for AF_INET4_6 in socket.AF_INET, socket.AF_INET6: print "Suite run: demux: %s, protocol: %d" % (demux, AF_INET4_6) + hostname_for_protocol(AF_INET4_6) res = unittest.main(exit=False).result.wasSuccessful() if not res: print "Suite run failed: demux: %s, protocol: %d" % ( diff --git a/dtls/util.py b/dtls/util.py index 4dc10c3..df938b9 100644 --- a/dtls/util.py +++ b/dtls/util.py @@ -19,6 +19,10 @@ class _Rsrc(object): def value(self): return self._value + @property + def raw(self): + return self._value.raw + class _BIO(_Rsrc): """BIO wrapper""" @@ -31,7 +35,7 @@ class _BIO(_Rsrc): def __del__(self): if self.owned: - _logger.debug("Freeing BIO: %d", self._value._as_parameter) + _logger.debug("Freeing BIO: %d", self.raw) from openssl import BIO_free BIO_free(self._value) self.owned = False