Source code for CADETProcess.dataStructure.aggregator
import numpy as np
from .dataStructure import Aggregator
[docs]
class SizedAggregator(Aggregator):
    """Aggregator for sized parameters.

    Values collected from the container are converted to numpy arrays, and
    values handed to the aggregator are validated against the shape that is
    expected from the current container contents.
    """

    def _parameter_shape(self, instance):
        """Return the shape shared by the collected parameter entries.

        Parameters
        ----------
        instance : object
            Object that is passed on to the container value getter.

        Returns
        -------
        tuple
            Shape common to every entry obtained by iterating over the
            collected values, or ``()`` if there are no values.

        Raises
        ------
        ValueError
            If the entries do not all have the same shape.
        """
        values = self._get_parameter_values_from_container(instance)

        # The getter signals "nothing to aggregate" with None rather than an
        # empty array; iterating over None would raise a TypeError, so treat
        # it like the empty case below.
        if values is None:
            return ()

        shapes = [np.array(el, ndmin=1).shape for el in values]

        if len(set(shapes)) > 1:
            raise ValueError("Inconsistent parameter shapes.")

        if not shapes:
            return ()

        return shapes[0]

    def _expected_shape(self, instance):
        """Return the shape a value must have to pass ``_check``.

        The result is the value of ``_n_instances(instance)`` followed by the
        dimensions returned by ``_parameter_shape(instance)``.
        """
        return (self._n_instances(instance), ) + self._parameter_shape(instance)

    def _get_parameter_values_from_container(self, instance):
        """Return the container values as a transposed array.

        Returns
        -------
        np.ndarray or None
            ``None`` if the parent implementation yields ``None`` or an empty
            collection; otherwise the values as an array with at least two
            dimensions, transposed.
        """
        value = super()._get_parameter_values_from_container(instance)

        if value is None or len(value) == 0:
            return None

        # ndmin=2 guarantees a 2D array, so the transpose is always defined
        # even when the parent returns a flat sequence.
        return np.array(value, ndmin=2).T

    def _check(self, instance, value, recursive=False):
        """Validate that ``value`` has the expected shape.

        Parameters
        ----------
        instance : object
            Object used to determine the expected shape.
        value : np.ndarray
            Value to validate; its ``shape`` attribute is compared against
            ``_expected_shape(instance)``.
        recursive : bool, optional
            If True, the parent class check is run afterwards.

        Raises
        ------
        ValueError
            If the shape of ``value`` differs from the expected shape.
        """
        expected_shape = self._expected_shape(instance)

        if value.shape != expected_shape:
            raise ValueError(
                f"Expected a array with shape {expected_shape}, got {value.shape}"
            )

        if recursive:
            super()._check(instance, value, recursive)

    def _prepare(self, instance, value, recursive=False):
        """Convert ``value`` to a numpy array.

        Parameters
        ----------
        instance : object
            Object passed on to the parent implementation.
        value : array_like
            Value to convert.
        recursive : bool, optional
            If True, the converted value is additionally passed through the
            parent class ``_prepare``.

        Returns
        -------
        np.ndarray
            The converted value, or whatever the parent implementation
            returns when ``recursive`` is True.
        """
        value = np.array(value)

        if recursive:
            value = super()._prepare(instance, value, recursive)

        return value
[docs]
class ClassDependentAggregator(Aggregator):
    """Aggregator whose parameter name is chosen by the type of each element."""

    def __init__(self, *args, mapping, **kwargs):
        """
        Set up the descriptor with a type-to-parameter-name mapping.

        Parameters
        ----------
        *args : tuple, optional
            Positional arguments forwarded to the parent initializer.
        mapping : dict
            Parameter name to read for each element type. The entry stored
            under ``None`` is used for element types that are not listed, and
            a parameter name of ``None`` causes the element to be skipped.
        **kwargs : dict, optional
            Keyword arguments forwarded to the parent initializer.
        """
        self.mapping = mapping
        super().__init__(*args, **kwargs)

    def _get_parameter_values_from_container(self, instance):
        """Collect the mapped attribute from every element of the container.

        Returns
        -------
        list or None
            The attribute values in container order, or ``None`` if no value
            was collected.
        """
        collected = []

        for element in self._container_obj(instance):
            element_type = type(element)
            # Types without their own entry use the one stored under None.
            lookup_key = element_type if element_type in self.mapping else None
            attr_name = self.mapping[lookup_key]

            if attr_name is not None:
                collected.append(getattr(element, attr_name))

        return collected if collected else None
[docs]
class SizedClassDependentAggregator(SizedAggregator, ClassDependentAggregator):
    """Sized aggregator whose parameter name depends on the element type.

    Adds no behavior of its own; everything is inherited from
    ``SizedAggregator`` and ``ClassDependentAggregator``.
    """