Source code for replay_trajectory_classification.discrete_state_transitions
"""Classes to generate transitions between categories."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import xarray as xr
[docs]
@dataclass
class DiagonalDiscrete:
    """Transition matrix with `diagonal_value` on the diagonal.

    Each off-diagonal entry is (1 - `diagonal_value`) / (`n_states` - 1), so
    every row sums to 1.

    Attributes
    ----------
    diagonal_value : float, optional
        Probability of remaining in the same discrete state, by default 0.98.
    """

    diagonal_value: float = 0.98

    def make_state_transition(self, n_states: int) -> np.ndarray:
        """Makes discrete state transition matrix.

        Parameters
        ----------
        n_states : int

        Returns
        -------
        discrete_state_transition : np.ndarray, shape (n_states, n_states)
            Also stored as `self.state_transition_`.
        """
        if n_states == 1:
            # There are no off-diagonal entries, and the general formula would
            # divide by zero; the only row-stochastic 1x1 matrix is [[1.0]].
            state_transition = np.ones((1, 1))
        else:
            off_diagonal_value = (1 - self.diagonal_value) / (n_states - 1)
            state_transition = np.full(
                (n_states, n_states), off_diagonal_value, dtype=float
            )
            np.fill_diagonal(state_transition, self.diagonal_value)
        self.state_transition_ = state_transition
        return self.state_transition_
[docs]
@dataclass
class RandomDiscrete:
    """All state transitions are random.

    Entries are drawn with `np.random.random_sample` (uniform on [0, 1)) and
    each row is then normalized to sum to 1.
    """

    def make_state_transition(self, n_states: int) -> np.ndarray:
        """Makes discrete state transition matrix.

        Parameters
        ----------
        n_states : int

        Returns
        -------
        discrete_state_transition : np.ndarray, shape (n_states, n_states)
            Also stored as `self.state_transition_`.
        """
        # `random_sample` takes the output shape as a single `size` tuple;
        # passing two positional ints raises TypeError.
        state_transition = np.random.random_sample((n_states, n_states))
        state_transition /= state_transition.sum(axis=1, keepdims=True)
        self.state_transition_ = state_transition
        return self.state_transition_
[docs]
@dataclass
class UserDefinedDiscrete:
    """State transitions are provided by user.

    Attributes
    ----------
    state_transition_ : np.ndarray, shape (n_states, n_states)
        The transition matrix returned as-is by `make_state_transition`.
    """

    state_transition_: np.ndarray

    def make_state_transition(self, n_states: int) -> np.ndarray:
        """Return the user-provided discrete state transition matrix.

        Parameters
        ----------
        n_states : int
            Number of discrete states; the stored matrix must be
            (n_states, n_states).

        Returns
        -------
        discrete_state_transition : np.ndarray, shape (n_states, n_states)
            The stored `state_transition_` object itself (not a copy).

        Raises
        ------
        ValueError
            If the stored matrix is not of shape (n_states, n_states).
        """
        # Fail early with a clear message instead of an obscure broadcasting
        # error further downstream.
        shape = np.shape(self.state_transition_)
        if shape != (n_states, n_states):
            raise ValueError(
                f"state_transition_ has shape {shape}, expected "
                f"({n_states}, {n_states})."
            )
        return self.state_transition_
[docs]
def expected_duration(
    discrete_state_transition: np.ndarray, sampling_frequency: int = 1
):
    """Mean time spent in each discrete state under a geometric dwell model.

    If the chain stays in state ``i`` with probability ``p_ii`` at every
    step, the number of steps spent there is geometric with mean
    ``1 / (1 - p_ii)``. Dividing by `sampling_frequency` converts that count
    of samples into seconds; with the default of 1 the result is in samples.

    Parameters
    ----------
    discrete_state_transition : np.ndarray, shape (n_states, n_states)
    sampling_frequency : int, optional
        Samples per second, by default 1.

    Returns
    -------
    duration : np.ndarray, shape (n_states,)
        A self-transition probability of exactly 1 gives ``inf`` (NumPy
        division by zero).
    """
    stay_probability = np.diag(discrete_state_transition)
    mean_samples = 1 / (1 - stay_probability)
    return mean_samples / sampling_frequency
[docs]
def estimate_discrete_state_transition(
    classifier,
    results: xr.Dataset,
) -> np.ndarray:
    """Estimate a new discrete transition matrix given the old one and updated smoother results.

    The pairwise posterior ``xi[t, i, j]`` (state ``i`` at time ``t`` and state
    ``j`` at time ``t + 1``) is built from the causal (filtered) and acausal
    (smoothed) posteriors, the likelihood, and the current transition matrix.
    Summing ``xi`` over time gives expected transition counts, which are
    row-normalized into the new matrix.

    Parameters
    ----------
    classifier : ClusterlessClassifier or SortedSpikesClassifier instance
        Must expose the current matrix as `discrete_state_transition_`.
    results : xr.Dataset
        Must contain `likelihood`, `causal_posterior` and `acausal_posterior`
        with a ``position`` dimension; after summing over position each is
        expected to be (n_time, n_states).

    Returns
    -------
    new_transition_matrix : np.ndarray, shape (n_states, n_states)
        A row with zero expected transitions keeps the corresponding row of
        the current transition matrix instead of becoming NaN. This happens
        for a "from" state with no posterior mass, or when there are fewer
        than two time bins.
    """
    likelihood = results.likelihood.sum("position").values
    causal_prob = results.causal_posterior.sum("position").values
    acausal_prob = results.acausal_posterior.sum("position").values
    transition_matrix = np.asarray(classifier.discrete_state_transition_)

    # xi[t, from_state, to_state]: probability of `from_state` at time t and
    # `to_state` at time t + 1 (unnormalized). Axis 1 is "from", axis 2 is
    # "to"; the operand order matches the per-element formula exactly.
    xi = (
        causal_prob[:-1, :, np.newaxis]
        * likelihood[1:, np.newaxis, :]
        * acausal_prob[1:, np.newaxis, :]
        * transition_matrix[np.newaxis, :, :]
        / (causal_prob[1:, np.newaxis, :] + np.spacing(1))
    )

    # Normalize each time step to a joint distribution. A step with zero mass
    # carries no information; leave it at zero rather than producing 0/0 = NaN,
    # which would contaminate the sum over all time steps.
    step_mass = xi.sum(axis=(1, 2), keepdims=True)
    xi = np.divide(xi, step_mass, out=np.zeros_like(xi), where=step_mass != 0)

    summed_xi = xi.sum(axis=0)
    row_mass = summed_xi.sum(axis=1, keepdims=True)
    # A "from" state with no expected transitions has no evidence to
    # re-estimate from, so keep its previous row.
    has_mass = row_mass != 0
    new_transition_matrix = np.where(
        has_mass,
        summed_xi / np.where(has_mass, row_mass, 1.0),
        transition_matrix,
    )
    return new_transition_matrix