Source code for abracudabra.testing
"""Assertions for testing purposes.
These functions are used in the unit tests. The equality checks are exact: no
tolerance is applied for floating-point errors.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from ._import import raise_library_not_found
if TYPE_CHECKING:
from torch import Tensor
from ._annotations import Array, DataFrame, Series
[docs]
def assert_tensors_equal(tensor1: Tensor, tensor2: Tensor, /) -> None:
    """Check that two Torch tensors match exactly.

    The checks run in this order: both arguments are Torch tensors, both live
    on the same device, both have the same dtype, and ``tensor1.equal(tensor2)``
    holds.

    Args:
        tensor1: The first tensor.
        tensor2: The second tensor.

    Raises:
        TypeError: If either argument is not a ``torch.Tensor``.
        AssertionError: If the devices, the dtypes or the values differ.

    """
    # Torch is imported lazily so that this module can be imported without it.
    try:
        import torch
    except ImportError:  # pragma: no cover
        raise_library_not_found("torch")

    # Both arguments must be genuine tensors before any attribute is read.
    for candidate in (tensor1, tensor2):
        if isinstance(candidate, torch.Tensor):
            continue
        error = f"Expected a Torch tensor, but got {type(candidate)!r}."
        raise TypeError(error)

    # Same device.
    device1, device2 = tensor1.device, tensor2.device
    if device1 != device2:
        error = f"The two tensors are on different devices: {device1} and {device2}."
        raise AssertionError(error)

    # Same dtype.
    dtype1, dtype2 = tensor1.dtype, tensor2.dtype
    if dtype1 != dtype2:
        error = f"The two tensors have different data types: {dtype1} and {dtype2}."
        raise AssertionError(error)

    # Same values.
    if tensor1.equal(tensor2):
        return
    error = "The two tensors are not equal."
    raise AssertionError(error)
[docs]
def assert_arrays_equal(array1: Array, array2: Array, /) -> None:
    """Check that two arrays match exactly.

    The checks run in this order: identical Python type, identical shape,
    identical dtype, and element-wise equality of every entry.

    Args:
        array1: The first array.
        array2: The second array.

    Raises:
        AssertionError: If the types, the shapes, the dtypes or the values
            differ.

    """
    # Same concrete class (compared by identity, so subclasses do not match).
    kind1, kind2 = type(array1), type(array2)
    if kind1 is not kind2:
        error = f"Expected arrays of the same type, but got {kind1!r} and {kind2!r}."
        raise AssertionError(error)

    # Same shape; checked before the values so `==` is applied element-wise.
    shape1, shape2 = array1.shape, array2.shape
    if shape1 != shape2:
        error = f"Expected arrays of the same shape, but got {shape1} and {shape2}."
        raise AssertionError(error)

    # Same dtype.
    dtype1, dtype2 = array1.dtype, array2.dtype
    if dtype1 != dtype2:
        error = f"The two arrays have different data types: {dtype1} and {dtype2}."
        raise AssertionError(error)

    # Same values.
    elementwise = array1 == array2
    if elementwise.all():
        return
    error = "The two arrays are not equal."
    raise AssertionError(error)
[docs]
def assert_frames_equal(frame1: Series | DataFrame, frame2: Series | DataFrame) -> None:
    """Assert that two Pandas/cuDF series or dataframes are equal.

    The checks run in this order: identical Python type, comparable and
    identical data types, and finally ``frame1.equals(frame2)``.

    Args:
        frame1: The first series or dataframe.
        frame2: The second series or dataframe.

    Raises:
        AssertionError: If the types differ, if the per-column data types
            cannot be compared because the column labels differ, if the data
            types differ, or if the values differ.

    """
    # Type
    if type(frame1) is not type(frame2):
        msg = (
            "Expected frames of the same type, "
            f"but got {type(frame1)!r} and {type(frame2)!r}."
        )
        raise AssertionError(msg)

    # dtypes
    try:
        dtypes_equal = frame1.dtypes == frame2.dtypes
    except ValueError as exc:
        # For dataframes, `dtypes` is a series labelled by column. pandas
        # refuses to compare two series whose labels differ and raises
        # ValueError; surface that as the assertion failure callers expect.
        # NOTE(review): assumes cuDF behaves the same way here -- confirm.
        msg = "The two frames have different columns."
        raise AssertionError(msg) from exc

    # A series yields a single boolean; a dataframe yields one per column.
    if not isinstance(dtypes_equal, bool):
        dtypes_equal = dtypes_equal.all()

    if not dtypes_equal:
        msg = "The two frames have different data types."
        raise AssertionError(msg)

    # Values
    if not frame1.equals(frame2):  # type: ignore[arg-type]
        msg = "The two frames are not equal."
        raise AssertionError(msg)