Source code for abracudabra.device.library
"""Library import functions for device types."""
from types import ModuleType
from .._import import import_library
from .base import DeviceType
[docs]
_DEVICE_TO_LIBRARY: dict[str, dict[DeviceType, str]] = {
"array": {"cpu": "numpy", "cuda": "cupy"},
"frame": {"cpu": "pandas", "cuda": "cudf"},
}
"""A collection of mappings from device types to library names."""
[docs]
_DEFAULT_DEVICE_TYPE: DeviceType = "cpu"
"""The default device type, if the device type is not specified."""
[docs]
def _import_library(obj_name: str, device_type: DeviceType | None = None) -> ModuleType:
if device_type is None:
device_type = _DEFAULT_DEVICE_TYPE
library_name = _DEVICE_TO_LIBRARY[obj_name][device_type]
return import_library(library_name)
[docs]
def get_np_or_cp(device_type: DeviceType | None = None) -> ModuleType:
"""Get the numpy or cupy library based on the device type."""
return _import_library("array", device_type)
[docs]
def get_pd_or_cudf(device_type: DeviceType | None = None) -> ModuleType:
"""Get the pandas or cudf library based on the device type."""
return _import_library("frame", device_type)