# Source code for replay_trajectory_classification.initial_conditions
"""Classes for constructing the initial conditions for the state space models."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from replay_trajectory_classification.environments import Environment
@dataclass
class UniformInitialConditions:
    """Initial conditions where all discrete states and position bins are
    equally likely.

    Every track-interior position bin of every discrete state receives the
    same probability; bins outside the track interior receive zero.
    """

    def make_initial_conditions(
        self,
        environments: tuple[Environment, ...],
        environment_names_to_state: tuple[str, ...],
    ) -> list[np.ndarray]:
        """Creates initial conditions array

        Parameters
        ----------
        environments : tuple[Environment, ...]
            Spatial environments in the model
        environment_names_to_state : tuple[str, ...]
            Environment name for each discrete state; one output array is
            produced per entry, in the same order.

        Returns
        -------
        initial_conditions : list of arrays
            One flattened array per discrete state. Track-interior bins hold
            ``1 / n_total_place_bins`` and all other bins hold 0, where
            ``n_total_place_bins`` is the number of interior bins summed over
            every state, so all arrays together sum to 1. Empty if
            `environment_names_to_state` is empty.

        Raises
        ------
        ValueError
            If there is at least one state but none of them has any
            track-interior bin, so the distribution cannot be normalized.
        """
        # NOTE(review): `environments.index(name)` finds an environment by
        # comparing it against a name string, which assumes Environment
        # defines __eq__ against its name -- confirm in environments.py.
        # ravel(order="F") flattens column-major; presumably this matches the
        # place-bin ordering used by the rest of the model -- verify.
        is_track_interior = [
            environments[
                environments.index(environment_name)
            ].is_track_interior_.ravel(order="F")
            for environment_name in environment_names_to_state
        ]
        n_total_place_bins = sum(interior.sum() for interior in is_track_interior)

        # Without this guard, 0 / 0 silently yields NaN probabilities.
        if is_track_interior and n_total_place_bins == 0:
            raise ValueError(
                "Cannot build uniform initial conditions: the environments "
                "have no track interior bins."
            )

        return [interior / n_total_place_bins for interior in is_track_interior]
@dataclass
class UniformOneEnvironmentInitialConditions:
    """Initial conditions where all position bins are
    equally likely for one environment and zero for other environments.

    Every track-interior bin of every discrete state whose environment is
    `environment_name` receives the same probability; all bins of all other
    states receive zero.
    """

    # Name of the environment whose states receive all of the initial
    # probability mass. Must match an entry of environment_names_to_state.
    environment_name: str = ""

    def make_initial_conditions(
        self,
        environments: tuple[Environment, ...],
        environment_names_to_state: tuple[str, ...],
    ) -> list[np.ndarray]:
        """Creates initial conditions array

        Parameters
        ----------
        environments : tuple[Environment, ...]
            Spatial environments in the model
        environment_names_to_state : tuple[str, ...]
            Environment name for each discrete state; one output array is
            produced per entry, in the same order.

        Returns
        -------
        initial_conditions : list of arrays
            One flattened array per discrete state. For states whose
            environment is `self.environment_name`, track-interior bins hold
            ``1 / n_total_place_bins`` (interior bins summed over those
            states); every other entry is 0. All arrays together sum to 1.
            Empty if `environment_names_to_state` is empty.

        Raises
        ------
        ValueError
            If there is at least one state but `self.environment_name` matches
            none of them (e.g. the default ``""``), or the matching states
            have no track-interior bins, so the distribution cannot be
            normalized.
        """
        # NOTE(review): `environments.index(name)` finds an environment by
        # comparing it against a name string, which assumes Environment
        # defines __eq__ against its name -- confirm in environments.py.
        # ravel(order="F") flattens column-major; presumably this matches the
        # place-bin ordering used by the rest of the model -- verify.
        initial_conditions = []
        for environment_name in environment_names_to_state:
            is_track_interior = environments[
                environments.index(environment_name)
            ].is_track_interior_.ravel(order="F")
            if environment_name == self.environment_name:
                initial_conditions.append(is_track_interior)
            else:
                initial_conditions.append(np.zeros_like(is_track_interior))

        # Non-selected states are all zeros, so this counts only the interior
        # bins of the selected environment's states.
        n_total_place_bins = sum(ic.sum() for ic in initial_conditions)

        # Without this guard, 0 / 0 silently yields NaN probabilities.
        if initial_conditions and n_total_place_bins == 0:
            raise ValueError(
                f"Cannot build initial conditions for environment "
                f"{self.environment_name!r}: it does not appear in "
                "environment_names_to_state or has no track interior bins."
            )

        return [ic / n_total_place_bins for ic in initial_conditions]