from __future__ import annotations
from typing import Any, Callable, Iterable, Mapping, Optional
import numpy as np
import numpy.typing as npt
from .dataStructure import Aggregator
# Public API of this module: the names exported by a star-import.
__all__ = [
"NumpyProxyArray",
"SizedAggregator",
"ClassDependentAggregator",
"SizedClassDependentAggregator"
]
# [docs]  (documentation-link marker from the rendered source listing; not code)
class NumpyProxyArray(np.ndarray):
    """
    A numpy array that dynamically updates attributes of container elements.

    The array is created as a view on the values returned by the aggregator.
    Element reads, element writes and ``repr`` query the aggregator again on
    every call instead of relying on the array's own buffer, and element
    writes are handed back to the aggregator via its ``__set__``.

    Attributes
    ----------
    aggregator : SizedAggregator or None
        The aggregator the values are read from and written to. ``None`` for
        arrays that were not created through ``NumpyProxyArray(aggregator, instance)``.
    instance : Any
        The object passed to the aggregator on every read and write.
    """

    def __new__(
        cls, aggregator: SizedAggregator, instance: Any
    ) -> Optional[NumpyProxyArray]:
        """
        Create a new NumpyProxyArray instance using data from an aggregator.

        Parameters
        ----------
        aggregator : SizedAggregator
            The aggregator from which to obtain values.
        instance : Any
            The instance associated with the values.

        Returns
        -------
        Optional[NumpyProxyArray]
            A new instance of the class if values are available, otherwise None.
        """
        values = aggregator._get_values_from_container(instance, transpose=True)
        if values is None:
            return None

        # View casting runs ``__array_finalize__`` first (which resets both
        # attributes to None), so they must be assigned afterwards.
        obj = values.view(cls)
        obj.aggregator = aggregator
        obj.instance = instance
        return obj

    def _get_values_from_aggregator(self) -> Any:
        """Refresh data from the underlying container."""
        return self.aggregator._get_values_from_container(
            self.instance, transpose=True, check=True
        )

    def __getitem__(self, index: Any) -> Any:
        """Retrieve an item from the aggregated parameter array."""
        return self._get_values_from_aggregator()[index]

    def __setitem__(self, index: Any, value: Any) -> None:
        """
        Modify an individual element in the aggregated parameter list.

        The current values are fetched from the aggregator, the requested
        element is replaced, and the complete array is passed to the
        aggregator's ``__set__`` so the change is propagated back to the objects.
        """
        current_value = self._get_values_from_aggregator()
        current_value[index] = value
        self.aggregator.__set__(self.instance, current_value)

    def __array_finalize__(self, obj: Optional[np.ndarray]) -> None:
        """
        Ensure attributes exist on every new instance, view or slice.

        ``obj`` is None for explicit construction, a plain ``ndarray`` for view
        casting, and a ``NumpyProxyArray`` when a new array is derived from an
        existing proxy. ``getattr`` with a default covers all three cases, so
        ``aggregator`` and ``instance`` are always defined. NumPy ignores the
        return value of this hook, so nothing is returned.
        """
        self.aggregator = getattr(obj, "aggregator", None)
        self.instance = getattr(obj, "instance", None)

    @staticmethod
    def _to_plain(result: Any) -> Any:
        """Replace ndarray subclasses inside ``result`` by base-class ndarrays."""
        if isinstance(result, np.ndarray):
            return np.asarray(result)
        if isinstance(result, (tuple, list)):
            items = [NumpyProxyArray._to_plain(item) for item in result]
            if hasattr(result, "_fields"):
                # Named tuples take their fields as separate arguments.
                return type(result)(*items)
            return type(result)(items)
        return result

    def __array_function__(
        self,
        func: Callable,
        types: Iterable[type],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Ensure that high-level NumPy functions return normal np.ndarray objects.

        NumPy invokes this hook as ``(func, types, args, kwargs)``; the extra
        positional values are captured by ``*args`` and forwarded unchanged to
        the ``ndarray`` implementation.

        Returns
        -------
        Any
            ``NotImplemented`` and ``None`` are passed through untouched so the
            dispatch protocol and functions without a return value keep working.
            Tuple and list results keep their container type, with any arrays
            inside converted to base-class ndarrays. Every other result is
            converted with ``np.asarray``.
        """
        result = super().__array_function__(func, types, *args, **kwargs)
        if result is NotImplemented or result is None:
            return result
        if isinstance(result, (tuple, list)):
            # Converting the container itself would stack (or fail on) the
            # individual outputs of multi-output functions.
            return self._to_plain(result)
        return np.asarray(result)

    def __repr__(self) -> str:
        """Return a fresh representation that reflects live data."""
        # The ndarray repr starts with "array(...)"; dropping its first
        # character and prepending "NumpyProxyA" spells "NumpyProxyArray(...)".
        return f"NumpyProxyA{self._get_values_from_aggregator().__repr__()[1:]}"
# [docs]  (documentation-link marker from the rendered source listing; not code)
class SizedAggregator(Aggregator):
    """
    Aggregator for sized parameters.

    Values collected from the container are converted to a numpy array with
    at least two dimensions. Reading the descriptor on an instance returns a
    ``NumpyProxyArray``; assigning to it prepares the value and delegates to
    the parent ``__set__``.
    """

    def __init__(self, *args: Any, transpose: bool = False, **kwargs: Any) -> None:
        """
        Initialize a SizedAggregator instance.

        Parameters
        ----------
        *args : Any
            Variable length argument list.
        transpose : bool, optional
            If False, the parameter shape will be ((n_instances, ) + parameter_shape).
            Else, it will be (parameter_shape + (n_instances, ))
            The default is False.
        **kwargs : Any
            Arbitrary keyword arguments.
        """
        self.transpose = transpose
        super().__init__(*args, **kwargs)

    def _parameter_shape(self, instance: Any) -> tuple[int, ...]:
        """
        Return the common shape of the per-element parameter values.

        Returns
        -------
        tuple[int, ...]
            The shape shared by all entries, or ``()`` when no values are
            available.

        Raises
        ------
        ValueError
            If the entries do not all have the same shape.
        """
        values = self._get_values_from_container(instance, transpose=False)
        if values is None:
            # An empty container yields None; there is nothing to iterate.
            return ()

        shapes = {el.shape for el in values}
        if len(shapes) > 1:
            raise ValueError("Inconsistent parameter shapes.")
        if not shapes:
            return ()
        return next(iter(shapes))

    def _expected_shape(self, instance: Any) -> tuple[int, ...]:
        """Return the shape a full value array must have for ``instance``."""
        n_instances = (self._n_instances(instance),)
        parameter_shape = self._parameter_shape(instance)
        if self.transpose:
            return parameter_shape + n_instances
        return n_instances + parameter_shape

    def _get_values_from_container(
        self, instance: Any, transpose: bool = False, check: bool = False
    ) -> Optional[np.ndarray]:
        """
        Collect the values from the container as an array.

        Parameters
        ----------
        instance : Any
            Instance whose container is queried.
        transpose : bool, optional
            If True and the aggregator itself was created with
            ``transpose=True``, the result is transposed. The default is False.
        check : bool, optional
            If True, the collected values are passed through ``_prepare`` and
            ``_check`` before being returned. The default is False.

        Returns
        -------
        Optional[np.ndarray]
            Array with at least two dimensions, or None if the parent returns
            None or an empty collection.
        """
        value = super()._get_values_from_container(instance, check=False)
        if value is None or len(value) == 0:
            return None

        value = np.array(value, ndmin=2)

        if check:
            value = self._prepare(instance, value, transpose=False, recursive=True)
            self._check(instance, value, transpose=True, recursive=True)

        if transpose and self.transpose:
            value = value.T

        return value

    def _check(
        self,
        instance: Any,
        value: npt.ArrayLike,
        transpose: bool = True,
        recursive: bool = False,
    ) -> None:
        """
        Verify that ``value`` has the expected shape.

        Parameters
        ----------
        instance : Any
            Instance the value belongs to.
        value : npt.ArrayLike
            Value to validate.
        transpose : bool, optional
            If True and the aggregator was created with ``transpose=True``,
            the value is transposed before its shape is compared.
            The default is True.
        recursive : bool, optional
            If True, the parent ``_check`` is called as well.
            The default is False.

        Raises
        ------
        ValueError
            If the shape differs from ``_expected_shape(instance)``.
        """
        value_array = np.array(value, ndmin=2)
        if transpose and self.transpose:
            value_array = value_array.T

        value_shape = value_array.shape
        expected_shape = self._expected_shape(instance)
        if value_shape != expected_shape:
            raise ValueError(f"Expected a array with shape {expected_shape}, got {value_shape}")

        if recursive:
            # NOTE(review): ``recursive`` is forwarded positionally; confirm it
            # lands in the intended parameter of the parent ``_check``.
            super()._check(instance, value, recursive)

    def _prepare(
        self,
        instance: Any,
        value: npt.ArrayLike,
        transpose: bool = False,
        recursive: bool = False,
    ) -> np.ndarray:
        """
        Convert ``value`` to an array with at least two dimensions.

        Parameters
        ----------
        instance : Any
            Instance the value belongs to.
        value : npt.ArrayLike
            Value to convert.
        transpose : bool, optional
            If True and the aggregator was created with ``transpose=True``,
            the array is transposed. The default is False.
        recursive : bool, optional
            If True, the result is additionally passed to the parent
            ``_prepare``. The default is False.

        Returns
        -------
        np.ndarray
            The prepared value.
        """
        value = np.array(value, ndmin=2)
        if transpose and self.transpose:
            value = value.T

        if recursive:
            # NOTE(review): ``recursive`` is forwarded positionally; confirm it
            # lands in the intended parameter of the parent ``_prepare``.
            value = super()._prepare(instance, value, recursive)

        return value

    def __get__(self, instance: Any, cls: Optional[type] = None) -> NumpyProxyArray:
        """
        Retrieve the descriptor value for the given instance.

        Parameters
        ----------
        instance : Any
            Instance to retrieve the descriptor value for.
        cls : Type[Any], optional
            Class to which the descriptor belongs.

        Returns
        -------
        NumpyProxyArray
            A numpy-like array that modifies the underlying objects when changed.
            The descriptor itself is returned when accessed on the class
            (``instance`` is None), and None when no values are available.
        """
        if instance is None:
            return self

        return NumpyProxyArray(self, instance)

    def __set__(self, instance: Any, value: Any) -> None:
        """
        Set the descriptor value for the given instance.

        The value is first converted by ``_prepare`` (with ``transpose=True``)
        and then handed to the parent ``__set__``.

        Parameters
        ----------
        instance : Any
            Instance to set the descriptor value for.
        value : Any
            Value to set.
        """
        value = self._prepare(instance, value, transpose=True)
        super().__set__(instance, value)
# [docs]  (documentation-link marker from the rendered source listing; not code)
class ClassDependentAggregator(Aggregator):
    """Aggregator whose parameter name is chosen by the type of each container element."""

    def __init__(self, *args: Any, mapping: dict, **kwargs: Any) -> None:
        """
        Set up the descriptor with a type-to-attribute mapping.

        Parameters
        ----------
        *args : tuple, optional
            Positional arguments forwarded to the parent initializer.
        mapping : dict
            Maps element types to the attribute name to read from elements of
            that exact type. The entry stored under the key ``None`` is used
            for every type that has no entry of its own. A mapped value of
            ``None`` means elements of that type are skipped.
        **kwargs : dict, optional
            Keyword arguments forwarded to the parent initializer.
        """
        self.mapping = mapping
        super().__init__(*args, **kwargs)

    def _get_values_from_container(self, instance: Any, check: bool = False) -> list:
        """
        Gather one value per container element using the type mapping.

        Parameters
        ----------
        instance : Any
            Instance whose container is traversed.
        check : bool, optional
            If True, the gathered values are run through ``_prepare`` and
            ``_check`` before being returned. The default is False.

        Returns
        -------
        list
            The gathered values, or None when nothing was gathered.
        """
        collected = []
        for element in self._container_obj(instance):
            element_type = type(element)
            # Fall back to the ``None`` entry for types without their own entry.
            key = element_type if element_type in self.mapping else None
            attribute_name = self.mapping[key]
            if attribute_name is not None:
                collected.append(getattr(element, attribute_name))

        if not collected:
            return None

        if check:
            collected = self._prepare(instance, collected, transpose=False, recursive=True)
            self._check(instance, collected, transpose=True, recursive=True)

        return collected
# [docs]  (documentation-link marker from the rendered source listing; not code)
class SizedClassDependentAggregator(SizedAggregator, ClassDependentAggregator):
    """
    Sized aggregator whose parameter name depends on the element type.

    Combines ``SizedAggregator`` and ``ClassDependentAggregator`` through
    multiple inheritance and adds no members of its own.
    """