diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 825c4c1..3dfc596 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -154,13 +154,23 @@ class Interface(ABC, Generic[T]): """ Find the interface that should be used for this array based on its input type """ - matches = [i for i in cls.interfaces() if i.check(array)] + # first try and find a non-numpy interface, since the numpy interface + # will try and load the array into memory in its check method + interfaces = cls.interfaces() + non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] + np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] + + matches = [i for i in non_np_interfaces if i.check(array)] if len(matches) > 1: msg = f"More than one interface matches input {array}:\n" msg += "\n".join([f" - {i}" for i in matches]) raise ValueError(msg) elif len(matches) == 0: - raise ValueError(f"No matching interfaces found for input {array}") + # now try the numpy interface + if np_interface.check(array): + return np_interface + else: + raise ValueError(f"No matching interfaces found for input {array}") else: return matches[0] diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index d7f6676..5ee988a 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -25,7 +25,7 @@ class NumpyInterface(Interface): input_types = (ndarray, list) return_type = ndarray - priority = -1 + priority = -999 """ The numpy interface is usually the interface of last resort. We want to use any more specific interface that we might have, @@ -45,7 +45,7 @@ class NumpyInterface(Interface): try: _ = np.array(array) return True - except TypeError: + except Exception: return False def before_validation(self, array: Any) -> ndarray: diff --git a/tests/test_interface/test_interface.py b/tests/test_interface/test_interface.py index bbafb7a..337ec93 100644 --- a/tests/test_interface/test_interface.py +++ b/tests/test_interface/test_interface.py @@ -51,7 +51,7 @@ def test_interface_match_error(interfaces): assert "Interface2" in e with pytest.raises(ValueError) as e: - Interface.match("hey") + Interface.match([[1, 2, 3], ["hey"]]) assert "No matching interfaces" in e with pytest.raises(ValueError) as e: