Source code for abracudabra.device.query

"""Query the device of a numpy/cupy array, series or torch tensor."""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, overload

from .._import import get_library_name
from .._validate import Library, validate_obj_type
from .base import Device, DeviceType

if TYPE_CHECKING:
    from torch import Tensor

    from ..annotations import Array, DataFrame, Series


def _cupy_get_device(array: object, /) -> Device:
    """Return the CUDA device on which a cupy array is stored.

    Args:
        array: The cupy array to inspect; its ``device.id`` attribute is read.

    Returns:
        A ``"cuda"`` device whose index is the array's device id.
    """
    cuda_index = array.device.id  # type: ignore[attr-defined]
    return Device("cuda", cuda_index)
@overload
def frame_get_device_type(
    frame: Series | DataFrame, /, *, raise_if_unknown: Literal[True] = ...
) -> DeviceType: ...


@overload
def frame_get_device_type(
    frame: Series | DataFrame, /, *, raise_if_unknown: bool = ...
) -> DeviceType | None: ...


def frame_get_device_type(
    frame: Series | DataFrame, /, *, raise_if_unknown: bool = True
) -> DeviceType | None:
    """Get the device type of a pandas or cudf series or dataframe.

    Args:
        frame: The pandas/cudf object to inspect.
        raise_if_unknown: Whether to raise when ``frame`` does not come from
            a recognized library.

    Returns:
        ``"cpu"`` for pandas objects, ``"cuda"`` for cudf objects, or
        ``None`` when the object is unrecognized and ``raise_if_unknown``
        is ``False``.

    Raises:
        TypeError: If the object is unrecognized and ``raise_if_unknown``
            is ``True``.
    """
    lib_name = get_library_name(frame)
    # The type is only validated for the library the name lookup reported.
    if lib_name == "pandas":
        if validate_obj_type(frame, Library.pandas):
            return "cpu"
    elif lib_name == "cudf":
        if validate_obj_type(frame, Library.cudf):
            return "cuda"

    if not raise_if_unknown:
        return None
    msg = (
        "Expected a pandas/cudf index, series or dataframe, "
        f"but got '{type(frame).__name__}'."
    )
    raise TypeError(msg)
def _torch_get_device(tensor: Tensor, /) -> Device:
    """Return the device on which a torch tensor is allocated.

    Args:
        tensor: The torch tensor to inspect.

    Returns:
        The tensor's device, built with ``Device.validate`` from the torch
        device's ``type`` and ``index``.
    """
    torch_device = tensor.device
    return Device.validate(torch_device.type, torch_device.index)
@overload
def get_device(
    element: Array | Tensor, /, *, raise_if_unknown: Literal[True] = ...
) -> Device: ...


@overload
def get_device(
    element: Array | Tensor, /, *, raise_if_unknown: bool = ...
) -> Device | None: ...


def get_device(
    element: Array | Tensor, /, *, raise_if_unknown: bool = True
) -> Device | None:
    """Get the device of a numpy/cupy array or torch tensor.

    Args:
        element: The element to check.
        raise_if_unknown: Whether to raise an error if the element is not a
            known array or tensor.

    Returns:
        ``Device("cpu")`` for numpy arrays, the array's CUDA device for cupy
        arrays, the tensor's device for torch tensors, or ``None`` when the
        element is unrecognized and ``raise_if_unknown`` is ``False``.

    Raises:
        TypeError: If the element is unrecognized and ``raise_if_unknown``
            is ``True``.
    """
    lib_name = get_library_name(element)
    # The type is only validated for the library the name lookup reported.
    if lib_name == "numpy":
        if validate_obj_type(element, Library.numpy):
            return Device("cpu")
    elif lib_name == "cupy":
        if validate_obj_type(element, Library.cupy):
            return _cupy_get_device(element)
    elif lib_name == "torch":
        if validate_obj_type(element, Library.torch):
            return _torch_get_device(element)

    if not raise_if_unknown:
        return None
    msg = (
        "Expected a numpy/cupy array or torch array or tensor, "
        f"but got '{type(element).__name__}'."
    )
    raise TypeError(msg)
def guess_device(*elements: Array | Tensor, skip_unknown: bool = True) -> Device:
    """Guess the single device shared by numpy/cupy arrays and torch tensors.

    Args:
        *elements: The elements to check.
        skip_unknown: Whether to skip elements that are not known arrays or
            tensors. When ``False``, an unknown element makes ``get_device``
            raise.

    Returns:
        The device shared by every recognized element.

    Raises:
        TypeError: If ``skip_unknown`` is ``False`` and an element is not a
            known array or tensor (raised by ``get_device``).
        ValueError: If no elements are given, or if none of the given
            elements is a recognized array or tensor.
        ValueError: If the elements are on different devices.
    """
    if not elements:
        msg = "Expected at least one element, but got none."
        raise ValueError(msg)

    devices = {
        device
        for element in elements
        if (device := get_device(element, raise_if_unknown=not skip_unknown))
        is not None
    }

    if not devices:
        # Elements were given, but every one was skipped as unrecognized;
        # report that instead of claiming nothing was passed in.
        msg = (
            "Expected at least one numpy/cupy array or torch tensor, but none "
            f"of the {len(elements)} given elements was recognized."
        )
        raise ValueError(msg)

    if len(devices) > 1:
        # Sort the reprs: set iteration order is arbitrary, and a stable
        # message is easier to read and to match in tests/logs.
        found = ", ".join(sorted(map(repr, devices)))
        msg = (
            "Expected all elements to be on the same device, "
            f"but found {len(devices)} different devices: {found}"
        )
        raise ValueError(msg)

    return devices.pop()