Cleanup ndarray validator

Remove outdated comments, rename inner function to match
This commit is contained in:
Jonny Saunders 2024-01-16 02:06:08 -08:00 committed by GitHub
parent eac5ef4c80
commit 4ee97263ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -68,14 +68,9 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
assert value.dtype == dtype or value.dtype.name in allowed_precisions[dtype.__name__], f"Invalid dtype! expected {dtype}, got {value.dtype}"
return value
def validate_array(value: Any) -> np.ndarray:
# not using instancecheck because nwb doesnt actually validate precision
# this step is now just validating shape
# if isinstance(value, np.ndarray):
# assert cls.__instancecheck__(value), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
# elif isinstance(value, DaskArray):
assert shape is Any or check_shape(value.shape, shape), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
def validate_shape(value: Any) -> np.ndarray:
assert shape is Any or check_shape(value.shape, shape), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
return value
def coerce_list(value: Any) -> np.ndarray:
@ -161,7 +156,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
core_schema.is_instance_schema(cls=NDArrayProxy)
]),
core_schema.no_info_plain_validator_function(validate_dtype),
core_schema.no_info_plain_validator_function(validate_array)
core_schema.no_info_plain_validator_function(validate_shape)
]
),
serialization=core_schema.plain_serializer_function_ser_schema(