Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
583 changes: 583 additions & 0 deletions notebooks/structural_components_dataclass.ipynb

Large diffs are not rendered by default.

261 changes: 261 additions & 0 deletions pymc_extras/statespace/core/properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
from collections.abc import Iterator
from dataclasses import dataclass, field, fields
from typing import Generic, Self, TypeVar

from pymc_extras.statespace.core import PyMCStateSpace
from pymc_extras.statespace.utils.constants import (
    ALL_STATE_AUX_DIM,
    ALL_STATE_DIM,
    OBS_STATE_AUX_DIM,
    OBS_STATE_DIM,
    SHOCK_AUX_DIM,
    SHOCK_DIM,
)


@dataclass(frozen=True)
class Property:
def __str__(self) -> str:
return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))


T = TypeVar("T", bound=Property)


@dataclass(frozen=True)
class Info(Generic[T]):
items: tuple[T, ...]
key_field: str = "name"
_index: dict[str, T] | None = None

def __post_init__(self):
index = {}
missing_attr = []
for item in self.items:
if not hasattr(item, self.key_field):
missing_attr.append(item)
continue
key = getattr(item, self.key_field)
if key in index:
raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
index[key] = item
if missing_attr:
raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
object.__setattr__(self, "_index", index)

def _key(self, item: T) -> str:
return getattr(item, self.key_field)

def get(self, key: str, default=None) -> T | None:
return self._index.get(key, default)

def __getitem__(self, key: str) -> T:
try:
return self._index[key]
except KeyError as e:
available = ", ".join(self._index.keys())
raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e

def __contains__(self, key: object) -> bool:
return key in self._index

def __iter__(self) -> Iterator[str]:
return iter(self.items)

def __len__(self) -> int:
return len(self.items)

def __str__(self) -> str:
return f"{self.key_field}s: {list(self._index.keys())}"

@property
def names(self) -> tuple[str, ...]:
return tuple(self._index.keys())


@dataclass(frozen=True)
class Parameter(Property):
    """Frozen record describing a single parameter.

    ``ParameterInfo`` collects these records and indexes them by ``name``.
    """

    name: str  # lookup key used by ParameterInfo
    shape: tuple[int, ...]
    dims: tuple[str, ...]  # presumably one dimension name per entry of ``shape`` -- TODO confirm
    constraints: str | None = None  # optional; None when not supplied


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Immutable collection of ``Parameter`` records, keyed by ``name``."""

    def __init__(self, parameters: list[Parameter]):
        """Store ``parameters`` as a tuple and index them by their ``name`` attribute."""
        super().__init__(items=tuple(parameters), key_field="name")

    def add(self, parameter: Parameter) -> "ParameterInfo":
        """Return a new ``ParameterInfo`` holding the current items followed by ``parameter``."""
        return ParameterInfo(parameters=[*self.items, parameter])

    def merge(self, other: "ParameterInfo") -> "ParameterInfo":
        """Return a new ``ParameterInfo`` with this object's items followed by ``other``'s.

        Raises
        ------
        TypeError
            If ``other`` is not a ``ParameterInfo``.
        ValueError
            If any parameter name appears in both objects.
        """
        if not isinstance(other, ParameterInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo")

        shared_names = set(self.names) & set(other.names)
        if shared_names:
            raise ValueError(f"Duplicate parameter names found: {shared_names}")

        return ParameterInfo(parameters=[*self.items, *other.items])


@dataclass(frozen=True)
class Data(Property):
    """Frozen record describing a single data variable.

    ``DataInfo`` collects these records and indexes them by ``name``.
    """

    name: str  # lookup key used by DataInfo
    shape: tuple[int, ...]
    dims: tuple[str, ...]  # presumably one dimension name per entry of ``shape`` -- TODO confirm
    is_exogenous: bool  # DataInfo.needs_exogenous_data is True when any record sets this


@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Immutable collection of ``Data`` records, keyed by ``name``."""

    def __init__(self, data: list[Data]):
        """Store ``data`` as a tuple and index the records by their ``name`` attribute."""
        super().__init__(items=tuple(data), key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True if at least one stored record has a truthy ``is_exogenous``."""
        for entry in self.items:
            if entry.is_exogenous:
                return True
        return False

    def __str__(self) -> str:
        """Return the data names and the ``needs_exogenous_data`` flag on two lines."""
        data_names = [entry.name for entry in self.items]
        return f"data: {data_names}\nneeds exogenous data: {self.needs_exogenous_data}"

    def add(self, data: Data) -> "DataInfo":
        """Return a new ``DataInfo`` holding the current items followed by ``data``."""
        return DataInfo(data=[*self.items, data])

    def merge(self, other: "DataInfo") -> "DataInfo":
        """Return a new ``DataInfo`` with this object's items followed by ``other``'s.

        Raises
        ------
        TypeError
            If ``other`` is not a ``DataInfo``.
        ValueError
            If any data name appears in both objects.
        """
        if not isinstance(other, DataInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo")

        shared_names = set(self.names) & set(other.names)
        if shared_names:
            raise ValueError(f"Duplicate data names found: {shared_names}")

        return DataInfo(data=[*self.items, *other.items])


@dataclass(frozen=True)
class Coord(Property):
    """Frozen record pairing a dimension name with its labels.

    ``CoordInfo`` collects these records and indexes them by ``dimension``.
    """

    dimension: str  # lookup key used by CoordInfo
    labels: tuple[str, ...]  # CoordInfo.to_dict omits coords whose labels are empty


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Immutable collection of ``Coord`` records, keyed by ``dimension``."""

    def __init__(self, coords: list[Coord]):
        """Store ``coords`` as a tuple and index them by their ``dimension`` attribute."""
        super().__init__(items=tuple(coords), key_field="dimension")

    def __str__(self) -> str:
        """Return a ``coordinates:`` header followed by each coord's indented description."""
        chunks = ["coordinates:"]
        for entry in self.items:
            shifted = "\n".join(" " + row for row in str(entry).splitlines())
            chunks.append("\n" + shifted + "\n")
        return "".join(chunks)

    # TODO: Need to figure out how to include Component type was causing circular import issues
    @classmethod
    def default_coords_from_model(cls, model: PyMCStateSpace) -> Self:
        """Build the default coords from a model's state, observed-state and shock names.

        Each group of labels is registered twice: ``model.state_names`` under
        ``ALL_STATE_DIM`` and ``ALL_STATE_AUX_DIM``, ``model.observed_state_names`` under
        ``OBS_STATE_DIM`` and ``OBS_STATE_AUX_DIM``, and ``model.shock_names`` under
        ``SHOCK_DIM`` and ``SHOCK_AUX_DIM``.
        """
        state_labels = tuple(model.state_names)
        observed_labels = tuple(model.observed_state_names)
        shock_labels = tuple(model.shock_names)

        return cls(
            [
                Coord(dimension=ALL_STATE_DIM, labels=state_labels),
                Coord(dimension=ALL_STATE_AUX_DIM, labels=state_labels),
                Coord(dimension=OBS_STATE_DIM, labels=observed_labels),
                Coord(dimension=OBS_STATE_AUX_DIM, labels=observed_labels),
                Coord(dimension=SHOCK_DIM, labels=shock_labels),
                Coord(dimension=SHOCK_AUX_DIM, labels=shock_labels),
            ]
        )

    def to_dict(self):
        """Return ``{dimension: labels}`` for every coord that has at least one label."""
        mapping = {}
        for entry in self.items:
            if len(entry.labels) > 0:
                mapping[entry.dimension] = entry.labels
        return mapping

    def add(self, coord: Coord) -> "CoordInfo":
        """Return a new ``CoordInfo`` holding the current items followed by ``coord``."""
        return CoordInfo(coords=[*self.items, coord])

    def merge(self, other: "CoordInfo") -> "CoordInfo":
        """Return a new ``CoordInfo`` with this object's items followed by ``other``'s.

        Raises
        ------
        TypeError
            If ``other`` is not a ``CoordInfo``.
        ValueError
            If any dimension appears in both objects.
        """
        if not isinstance(other, CoordInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo")

        shared_names = set(self.names) & set(other.names)
        if shared_names:
            raise ValueError(f"Duplicate coord names found: {shared_names}")

        return CoordInfo(coords=[*self.items, *other.items])


@dataclass(frozen=True)
class State(Property):
    """Frozen record describing a single state.

    ``StateInfo`` collects these records and indexes them by ``name``.
    """

    name: str  # lookup key used by StateInfo
    observed: bool  # StateInfo.observed_states keeps only records where this is truthy
    shared: bool  # NOTE(review): not read anywhere in this module -- meaning to be confirmed


@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Immutable collection of ``State`` records, keyed by ``name``."""

    def __init__(self, states: list[State]):
        """Store ``states`` as a tuple and index them by their ``name`` attribute."""
        super().__init__(items=tuple(states), key_field="name")

    def __str__(self) -> str:
        """Return the state names and their ``observed`` flags on two lines."""
        state_names = [entry.name for entry in self.items]
        observed_flags = [entry.observed for entry in self.items]
        return f"states: {state_names}\nobserved: {observed_flags}"

    @property
    def observed_states(self) -> tuple[State, ...]:
        """The stored states whose ``observed`` attribute is truthy, in order."""
        return tuple(filter(lambda entry: entry.observed, self.items))

    def add(self, state: State) -> "StateInfo":
        """Return a new ``StateInfo`` holding the current items followed by ``state``."""
        return StateInfo(states=[*self.items, state])

    def merge(self, other: "StateInfo") -> "StateInfo":
        """Return a new ``StateInfo`` with this object's items followed by ``other``'s.

        Raises
        ------
        TypeError
            If ``other`` is not a ``StateInfo``.
        ValueError
            If any state name appears in both objects.
        """
        if not isinstance(other, StateInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

        shared_names = set(self.names) & set(other.names)
        if shared_names:
            raise ValueError(f"Duplicate state names found: {shared_names}")

        return StateInfo(states=[*self.items, *other.items])


@dataclass(frozen=True)
class Shock(Property):
    """Frozen record describing a single shock.

    ``ShockInfo`` collects these records and indexes them by ``name``.
    """

    name: str  # lookup key used by ShockInfo


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Immutable collection of ``Shock`` records, keyed by ``name``."""

    def __init__(self, shocks: list[Shock]):
        """Store ``shocks`` as a tuple and index them by their ``name`` attribute."""
        super().__init__(items=tuple(shocks), key_field="name")

    def add(self, shock: Shock) -> "ShockInfo":
        """Return a new ``ShockInfo`` holding the current items followed by ``shock``."""
        return ShockInfo(shocks=[*self.items, shock])

    def merge(self, other: "ShockInfo") -> "ShockInfo":
        """Return a new ``ShockInfo`` with this object's items followed by ``other``'s.

        Raises
        ------
        TypeError
            If ``other`` is not a ``ShockInfo``.
        ValueError
            If any shock name appears in both objects.
        """
        if not isinstance(other, ShockInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo")

        shared_names = set(self.names) & set(other.names)
        if shared_names:
            raise ValueError(f"Duplicate shock names found: {shared_names}")

        return ShockInfo(shocks=[*self.items, *other.items])
4 changes: 4 additions & 0 deletions pymc_extras/statespace/models/structural/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent
from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError
from pymc_extras.statespace.models.structural.components.regression import RegressionComponent
from pymc_extras.statespace.models.structural.components.regression_dataclass import (
RegressionComponent as RegressionComponentDataClass,
)
from pymc_extras.statespace.models.structural.components.seasonality import (
FrequencySeasonality,
TimeSeasonality,
Expand All @@ -17,5 +20,6 @@
"LevelTrendComponent",
"MeasurementError",
"RegressionComponent",
"RegressionComponentDataClass",
"TimeSeasonality",
]
Loading
Loading