diff --git a/nwb_linkml/src/nwb_linkml/types/ndarray.py b/nwb_linkml/src/nwb_linkml/types/ndarray.py index 825b354..612b5d3 100644 --- a/nwb_linkml/src/nwb_linkml/types/ndarray.py +++ b/nwb_linkml/src/nwb_linkml/types/ndarray.py @@ -68,14 +68,9 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta): assert value.dtype == dtype or value.dtype.name in allowed_precisions[dtype.__name__], f"Invalid dtype! expected {dtype}, got {value.dtype}" return value - def validate_array(value: Any) -> np.ndarray: - # not using instancecheck because nwb doesnt actually validate precision - # this step is now just validating shape - # if isinstance(value, np.ndarray): - # assert cls.__instancecheck__(value), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}' - # elif isinstance(value, DaskArray): + + def validate_shape(value: Any) -> np.ndarray: assert shape is Any or check_shape(value.shape, shape), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}' - return value def coerce_list(value: Any) -> np.ndarray: @@ -161,7 +156,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta): core_schema.is_instance_schema(cls=NDArrayProxy) ]), core_schema.no_info_plain_validator_function(validate_dtype), - core_schema.no_info_plain_validator_function(validate_array) + core_schema.no_info_plain_validator_function(validate_shape) ] ), serialization=core_schema.plain_serializer_function_ser_schema(