# Source code for abracudabra.testing
"""Assertions for testing purposes."""
from __future__ import annotations
from typing import TYPE_CHECKING
from ._import import raise_library_not_found
if TYPE_CHECKING:
from torch import Tensor
from abracudabra.annotations import Array, DataFrame, Series
[docs]
def assert_tensors_equal(tensor1: Tensor, tensor2: Tensor, /) -> None:
    """Check that two torch tensors share a device and compare equal.

    Both arguments are positional-only.

    Raises:
        TypeError: If either argument is not a ``torch.Tensor`` instance.
        AssertionError: If the tensors are on different devices, or if
            ``tensor1.equal(tensor2)`` is falsy.
    """
    try:
        import torch
    except ImportError:  # pragma: no cover
        raise_library_not_found("torch")

    # Reject anything that is not a torch tensor, reporting the first offender.
    non_tensors = [
        candidate
        for candidate in (tensor1, tensor2)
        if not isinstance(candidate, torch.Tensor)
    ]
    if non_tensors:
        tensor = non_tensors[0]
        raise TypeError(f"Expected a torch tensor, but got {type(tensor)!r}.")

    # The device comparison runs before any value comparison.
    device1, device2 = tensor1.device, tensor2.device
    if device1 != device2:
        raise AssertionError(
            f"The two tensors are on different devices: {device1} and {device2}."
        )

    # Finally compare the contents through the tensor's own ``equal`` method.
    if tensor1.equal(tensor2):
        return
    raise AssertionError("The two tensors are not equal.")
[docs]
def assert_arrays_equal(array1: Array, array2: Array, /) -> None:
    """Check that two arrays have the same type, shape and elements.

    Both arguments are positional-only. The checks run in order (exact type,
    ``shape``, then element-wise ``==`` reduced with ``.all()``) and the first
    one that fails determines the error message.

    Raises:
        AssertionError: If any of the three checks fails.
    """
    # Exact type identity is required; a subclass does not count as a match.
    kind1, kind2 = type(array1), type(array2)
    if kind1 is not kind2:
        raise AssertionError(
            f"Expected arrays of the same type, but got {kind1!r} and {kind2!r}."
        )

    shape1, shape2 = array1.shape, array2.shape
    if shape1 != shape2:
        raise AssertionError(
            f"Expected arrays of the same shape, but got {shape1} and {shape2}."
        )

    # Every element-wise comparison must hold for the arrays to match.
    elementwise = array1 == array2
    if elementwise.all():
        return
    raise AssertionError("The two arrays are not equal.")
[docs]
def assert_frames_equal(frame1: Series | DataFrame, frame2: Series | DataFrame) -> None:
"""Assert that two pandas or cudf series or dataframes are equal."""
# Type
if type(frame1) is not type(frame2):
msg = (
"Expected frames of the same type, "
f"but got {type(frame1)!r} and {type(frame2)!r}."
)
raise AssertionError(msg)
# Values
if not frame1.equals(frame2): # type: ignore[arg-type]
msg = "The two frames are not equal."
raise AssertionError(msg)