Source code for replay_trajectory_classification.decoder

"""Main classes for decoding trajectories from population spiking"""

from __future__ import annotations

from copy import deepcopy
from logging import getLogger
from typing import Optional, Union

import joblib
import numpy as np
import sklearn
import xarray as xr
from sklearn.base import BaseEstimator

from replay_trajectory_classification.continuous_state_transitions import (
    EmpiricalMovement,
    RandomWalk,
    RandomWalkDirection1,
    RandomWalkDirection2,
    Uniform,
)
from replay_trajectory_classification.core import (
    _acausal_decode,
    _acausal_decode_gpu,
    _causal_decode,
    _causal_decode_gpu,
    atleast_2d,
    get_centers,
    mask,
    scaled_likelihood,
)
from replay_trajectory_classification.environments import Environment
from replay_trajectory_classification.initial_conditions import UniformInitialConditions
from replay_trajectory_classification.likelihoods import (
    _SORTED_SPIKES_ALGORITHMS,
    _ClUSTERLESS_ALGORITHMS,
)

# Module-level logger; named after this module so log records can be filtered.
logger = getLogger(__name__)

# Import-time, process-wide scikit-learn setting: when an estimator is printed,
# show every constructor parameter rather than only the non-default ones.
sklearn.set_config(print_changed_only=False)

# Default keyword arguments forwarded to the clusterless likelihood algorithm.
# Used as the default value of ``ClusterlessDecoder.clusterless_algorithm_params``.
_DEFAULT_CLUSTERLESS_MODEL_KWARGS = {
    "mark_std": 24.0,
    "position_std": 6.0,
}

# Default keyword arguments forwarded to the sorted-spikes likelihood algorithm.
# Used as the default value of
# ``SortedSpikesDecoder.sorted_spikes_algorithm_params``.
_DEFAULT_SORTED_SPIKES_MODEL_KWARGS = {
    "position_std": 6.0,
    "use_diffusion": False,
    "block_size": None,
}


class _DecoderBase(BaseEstimator):
    """Base class for decoder objects.

    Provides the pieces shared by the concrete decoders: fitting the spatial
    environment, the initial conditions and the continuous state transition,
    running the causal/acausal decoding on a precomputed likelihood, and
    packaging the results as an ``xarray.Dataset``.  Subclasses implement
    ``fit`` and ``predict``.

    Parameters
    ----------
    environment : Environment, optional
        The spatial environment to fit
    transition_type : EmpiricalMovement | RandomWalk | RandomWalkDirection1 | RandomWalkDirection2 | Uniform
        The continuous state transition matrix
    initial_conditions_type : UniformInitialConditions, optional
        The initial conditions class instance
    infer_track_interior : bool, optional
        Whether to infer the spatial geometry of track from position

    Notes
    -----
    The default ``environment``, ``transition_type`` and
    ``initial_conditions_type`` objects are created once, when the class is
    defined, and are therefore shared by every decoder constructed without
    explicit arguments.  Pass fresh instances if decoders must be independent.

    """

    def __init__(
        self,
        environment: Environment = Environment(environment_name=""),
        transition_type: Union[
            EmpiricalMovement,
            RandomWalk,
            RandomWalkDirection1,
            RandomWalkDirection2,
            Uniform,
        ] = RandomWalk(),
        initial_conditions_type: UniformInitialConditions = UniformInitialConditions(),
        infer_track_interior: bool = True,
    ):
        # Parameters are stored untouched (no copies, no validation), following
        # the scikit-learn estimator convention.
        self.environment = environment
        self.transition_type = transition_type
        self.initial_conditions_type = initial_conditions_type
        self.infer_track_interior = infer_track_interior

    def fit_environment(self, position: np.ndarray) -> None:
        """Discretize the spatial environment into bins. Determine valid track positions.

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
            Position of the animal in the environment.
        """
        self.environment.fit_place_grid(
            position, infer_track_interior=self.infer_track_interior
        )

    def fit_initial_conditions(self):
        """Set the initial probability of position.

        Stores the result for this decoder's single environment in
        ``self.initial_conditions_``.
        """
        logger.info("Fitting initial conditions...")
        # The helper works on lists of environments / environment names; the
        # decoder has exactly one environment, so take the first entry.
        self.initial_conditions_ = self.initial_conditions_type.make_initial_conditions(
            [self.environment], [self.environment.environment_name]
        )[0]

    def fit_state_transition(
        self,
        position: np.ndarray,
        is_training: Optional[np.ndarray] = None,
        transition_type: Optional[
            Union[
                EmpiricalMovement,
                RandomWalk,
                RandomWalkDirection1,
                RandomWalkDirection2,
                Uniform,
            ]
        ] = None,
    ):
        """Construct the continuous state transition matrix.

        Stores the result in ``self.state_transition_``.

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
            Position of the animal. Only used by ``EmpiricalMovement``.
        is_training : np.ndarray, shape (n_time,), optional
            Boolean array selecting the samples used to estimate the
            transition. Only used by ``EmpiricalMovement``; if None, all
            samples are used.
        transition_type : transition instance, optional
            If given, replaces ``self.transition_type`` before fitting.

        """
        logger.info("Fitting state transition...")

        if transition_type is not None:
            self.transition_type = transition_type

        if isinstance(self.transition_type, EmpiricalMovement):
            # The empirical transition is estimated from the animal's movement,
            # so it additionally needs the position and the training mask.
            if is_training is None:
                is_training = np.ones((position.shape[0],), dtype=bool)
            is_training = np.asarray(is_training).squeeze()
            n_time = position.shape[0]
            # A decoder has a single encoding group: label every sample 0.
            encoding_group_labels = np.zeros((n_time,), dtype=np.int32)
            self.state_transition_ = self.transition_type.make_state_transition(
                (self.environment,),
                position,
                is_training,
                encoding_group_labels,
                environment_labels=None,
            )
        else:
            self.state_transition_ = self.transition_type.make_state_transition(
                (self.environment,)
            )

    def fit(self):
        """To be implemented by inheriting class"""
        raise NotImplementedError

    def predict(self):
        """To be implemented by inheriting class"""
        raise NotImplementedError

    def save_model(self, filename="model.pkl"):
        """Save the classifier to a pickled file.

        Parameters
        ----------
        filename : str, optional

        """
        joblib.dump(self, filename)

    @staticmethod
    def load_model(filename="model.pkl"):
        """Load the classifier from a file.

        Parameters
        ----------
        filename : str, optional

        Returns
        -------
        classifier instance

        Notes
        -----
        ``joblib.load`` is built on pickle and can execute arbitrary code
        while loading. Only load files that come from a trusted source.

        """
        # SECURITY: pickle-based deserialization; never call on untrusted files.
        return joblib.load(filename)

    def copy(self):
        """Makes a copy of the classifier"""
        return deepcopy(self)

    def project_1D_position_to_2D(
        self, results: xr.Dataset, posterior_type="acausal_posterior"
    ) -> np.ndarray:
        """Project the 1D most probable position into the 2D track graph space.

        Only works for single environment.

        Parameters
        ----------
        results : xr.Dataset
            Decoding results with a ``position`` dimension.
        posterior_type : causal_posterior | acausal_posterior | likelihood
            Name of the variable in `results` to take the maximum of.

        Returns
        -------
        map_position2D : np.ndarray
            The ``x_position`` and ``y_position`` columns of
            ``environment.place_bin_centers_nodes_df_`` at the most probable
            position bin of each time step.

        """
        posterior = results[posterior_type]
        # Results produced by this class have no "state" dimension, and xarray
        # raises when asked to reduce over a dimension that does not exist.
        # Only marginalize over "state" when it is actually present.
        if "state" in posterior.dims:
            posterior = posterior.sum("state")
        map_position_ind = posterior.argmax("position").to_numpy().squeeze()

        return self.environment.place_bin_centers_nodes_df_.iloc[map_position_ind][
            ["x_position", "y_position"]
        ].to_numpy()

    def _get_results(
        self,
        results: dict,
        n_time: int,
        time: Optional[np.ndarray] = None,
        is_compute_acausal: bool = True,
        use_gpu: bool = False,
    ) -> xr.Dataset:
        """Run the decoding on the likelihood and label the results.

        Adds ``"causal_posterior"`` (and optionally ``"acausal_posterior"``)
        to `results` in place, then converts the dict to an ``xr.Dataset``.

        Parameters
        ----------
        results : dict
            Must contain ``"likelihood"``, indexed as (time, position bin).
        n_time : int
            Number of time bins.
        time : np.ndarray, optional
            Labels for the time axis. Defaults to ``np.arange(n_time)``.
        is_compute_acausal : bool
            If True, also compute the acausal posterior.
        use_gpu : bool
            If True, use the GPU decoding functions with float32 outputs.

        Returns
        -------
        results : xr.Dataset

        """
        is_track_interior = self.environment.is_track_interior_.ravel(order="F")
        n_position_bins = is_track_interior.shape[0]
        # Row/column selector that restricts the transition matrix to
        # transitions between track-interior bins.
        st_interior_ind = np.ix_(is_track_interior, is_track_interior)

        logger.info("Estimating causal posterior...")
        # Decoding only runs on the interior bins; bins off the track keep the
        # NaN they are initialized with.
        if not use_gpu:
            results["causal_posterior"] = np.full(
                (n_time, n_position_bins), np.nan, dtype=float
            )
            (
                results["causal_posterior"][:, is_track_interior],
                data_log_likelihood,
            ) = _causal_decode(
                self.initial_conditions_[is_track_interior].astype(float),
                self.state_transition_[st_interior_ind].astype(float),
                results["likelihood"][:, is_track_interior].astype(float),
            )
        else:
            # GPU path: float32 output buffer, inputs passed without casting.
            results["causal_posterior"] = np.full(
                (n_time, n_position_bins), np.nan, dtype=np.float32
            )
            (
                results["causal_posterior"][:, is_track_interior],
                data_log_likelihood,
            ) = _causal_decode_gpu(
                self.initial_conditions_[is_track_interior],
                self.state_transition_[st_interior_ind],
                results["likelihood"][:, is_track_interior],
            )

        if is_compute_acausal:
            logger.info("Estimating acausal posterior...")
            # The acausal functions are given the causal posterior with a
            # trailing singleton axis, and the output buffer carries the same
            # trailing axis.
            if not use_gpu:
                results["acausal_posterior"] = np.full(
                    (n_time, n_position_bins, 1), np.nan, dtype=float
                )
                results["acausal_posterior"][:, is_track_interior] = _acausal_decode(
                    results["causal_posterior"][
                        :, is_track_interior, np.newaxis
                    ].astype(float),
                    self.state_transition_[st_interior_ind].astype(float),
                )
            else:
                results["acausal_posterior"] = np.full(
                    (n_time, n_position_bins, 1), np.nan, dtype=np.float32
                )
                results["acausal_posterior"][:, is_track_interior] = (
                    _acausal_decode_gpu(
                        results["causal_posterior"][:, is_track_interior, np.newaxis],
                        self.state_transition_[st_interior_ind],
                    )
                )

        if time is None:
            time = np.arange(n_time)

        return self.convert_results_to_xarray(results, time, data_log_likelihood)

    def convert_results_to_xarray(
        self, results: dict, time: np.ndarray, data_log_likelihood: float
    ) -> xr.Dataset:
        """Converts the results dict into a collection of labeled arrays.

        Parameters
        ----------
        results : dict
            Maps a variable name to an array indexed by (time, position bin).
        time : np.ndarray
            Labels for the time axis.
        data_log_likelihood : float
            Stored in the dataset attributes under ``"data_log_likelihood"``.

        Returns
        -------
        results : xr.Dataset
            Dimensions are ``time`` plus ``position`` (1D environment) or
            ``x_position``/``y_position`` (more than one position dimension).

        """
        n_position_dims = self.environment.place_bin_centers_.shape[1]
        n_time = time.shape[0]

        attrs = {"data_log_likelihood": data_log_likelihood}

        if n_position_dims > 1:
            dims = ["time", "x_position", "y_position"]
            coords = dict(
                time=time,
                x_position=get_centers(self.environment.edges_[0]),
                y_position=get_centers(self.environment.edges_[1]),
            )
        else:
            dims = ["time", "position"]
            coords = dict(
                time=time,
                position=get_centers(self.environment.edges_[0]),
            )
        new_shape = (n_time, *self.environment.centers_shape_)
        is_track_interior = self.environment.is_track_interior_.ravel(order="F")

        def _build_dataset(**reshape_kwargs) -> xr.Dataset:
            # Mask each array with the track interior and fold the flat
            # position axis back into the grid shape.
            return xr.Dataset(
                {
                    key: (
                        dims,
                        mask(value, is_track_interior).reshape(
                            new_shape, **reshape_kwargs
                        ),
                    )
                    for key, value in results.items()
                },
                coords=coords,
                attrs=attrs,
            )

        try:
            # Position bins are flattened in Fortran order elsewhere in this
            # class (see ``ravel(order="F")``), so unflatten the same way.
            return _build_dataset(order="F")
        except ValueError:
            # Fall back to the default reshape order if that fails.
            return _build_dataset()


class SortedSpikesDecoder(_DecoderBase):
    """Decodes neural population representation of position from clustered cells.

    Parameters
    ----------
    environment : Environment instance, optional
        The spatial environment to fit
    transition_type : EmpiricalMovement | RandomWalk | RandomWalkDirection1 | RandomWalkDirection2 | Uniform
        Movement model for the continuous state
    initial_conditions_type : UniformInitialConditions, optional
        The initial conditions class instance
    infer_track_interior : bool, optional
        Whether to infer the spatial geometry of track from position
    sorted_spikes_algorithm : str, optional
        The type of algorithm. See _SORTED_SPIKES_ALGORITHMS for keys
    sorted_spikes_algorithm_params : dict, optional
        Keyword arguments passed to the algorithm's fitting function.
        If None, no extra keyword arguments are passed.

    """

    def __init__(
        self,
        environment: Environment = Environment(environment_name=""),
        transition_type: Union[
            EmpiricalMovement,
            RandomWalk,
            RandomWalkDirection1,
            RandomWalkDirection2,
            Uniform,
        ] = RandomWalk(),
        initial_conditions_type: UniformInitialConditions = UniformInitialConditions(),
        infer_track_interior: bool = True,
        sorted_spikes_algorithm: str = "spiking_likelihood_kde",
        sorted_spikes_algorithm_params: dict = _DEFAULT_SORTED_SPIKES_MODEL_KWARGS,
    ):
        super().__init__(
            environment, transition_type, initial_conditions_type, infer_track_interior
        )
        # Stored untouched, following the scikit-learn estimator convention.
        self.sorted_spikes_algorithm = sorted_spikes_algorithm
        self.sorted_spikes_algorithm_params = sorted_spikes_algorithm_params

    def fit_place_fields(
        self,
        position: np.ndarray,
        spikes: np.ndarray,
        is_training: Optional[np.ndarray] = None,
    ):
        """Fits the place intensity function.

        Stores the result in ``self.place_fields_``. The environment must
        already be fitted (see ``fit_environment``).

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
            Position of the animal.
        spikes : np.ndarray, (n_time, n_neurons)
            Binary indicator of whether there was a spike in a given time bin
            for a given neuron.
        is_training : np.ndarray, shape (n_time,), optional
            Boolean array to indicate which data should be included in fitting
            of place fields, by default None (use all samples)

        """
        logger.info("Fitting place fields...")
        if is_training is None:
            is_training = np.ones((position.shape[0],), dtype=bool)
        is_training = np.asarray(is_training).squeeze()

        kwargs = self.sorted_spikes_algorithm_params
        if kwargs is None:
            kwargs = {}

        # Each registry entry is indexed here with [0] for the fitting
        # function; ``predict`` uses [1] for the likelihood function.
        self.place_fields_ = _SORTED_SPIKES_ALGORITHMS[self.sorted_spikes_algorithm][0](
            position[is_training],
            spikes[is_training],
            place_bin_centers=self.environment.place_bin_centers_,
            place_bin_edges=self.environment.place_bin_edges_,
            edges=self.environment.edges_,
            is_track_interior=self.environment.is_track_interior_,
            is_track_boundary=self.environment.is_track_boundary_,
            **kwargs,
        )

    def plot_place_fields(
        self, sampling_frequency: int = 1, col_wrap: int = 5
    ) -> xr.plot.FacetGrid:
        """Plots the fitted place fields for each neuron.

        Parameters
        ----------
        sampling_frequency : int, optional
            Number of samples per second
        col_wrap : int, optional
            Number of columns in the subplot.

        Returns
        -------
        g : xr.plot.FacetGrid instance

        """
        try:
            # First try a 2D plot: unstack "position" into x/y and blank out
            # bins that are not on the track.
            g = (
                self.place_fields_.unstack("position").where(
                    self.environment.is_track_interior_
                )
                * sampling_frequency
            ).plot(x="x_position", y="y_position", col="neuron", col_wrap=col_wrap)
        except ValueError:
            # Otherwise fall back to a 1D plot along "position".
            g = (self.place_fields_ * sampling_frequency).plot(
                x="position", col="neuron", col_wrap=col_wrap
            )

        return g

    def fit(
        self,
        position: np.ndarray,
        spikes: np.ndarray,
        is_training: Optional[np.ndarray] = None,
    ):
        """Fit the spatial grid, initial conditions, place field model, and
        transition matrix.

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
            Position of the animal.
        spikes : np.ndarray, shape (n_time, n_neurons)
            Binary indicator of whether there was a spike in a given time bin
            for a given neuron.
        is_training : None or np.ndarray, shape (n_time), optional
            Boolean array to indicate which data should be included in fitting
            of place fields, by default None

        Returns
        -------
        self

        """
        position = atleast_2d(np.asarray(position))
        spikes = np.asarray(spikes)
        # Order matters: the environment grid is needed by every later step.
        self.fit_environment(position)
        self.fit_initial_conditions()
        self.fit_state_transition(
            position, is_training, transition_type=self.transition_type
        )
        self.fit_place_fields(position, spikes, is_training)

        return self

    def predict(
        self,
        spikes: np.ndarray,
        time: Optional[np.ndarray] = None,
        is_compute_acausal: bool = True,
        use_gpu: bool = False,
    ) -> xr.Dataset:
        """Predict the probability of spatial position from the spikes.

        Parameters
        ----------
        spikes : np.ndarray, shape (n_time, n_neurons)
            Binary indicator of whether there was a spike in a given time bin
            for a given neuron.
        time : np.ndarray or None, shape (n_time,), optional
            Label the time axis with these values.
        is_compute_acausal : bool, optional
            If True, compute the acausal posterior.
        use_gpu : bool, optional
            Use GPU for the state space part of the model, not the likelihood.

        Returns
        -------
        results : xarray.Dataset

        """
        spikes = np.asarray(spikes)
        n_time = spikes.shape[0]

        logger.info("Estimating likelihood...")
        results = {}
        results["likelihood"] = scaled_likelihood(
            _SORTED_SPIKES_ALGORITHMS[self.sorted_spikes_algorithm][1](
                spikes, np.asarray(self.place_fields_)
            )
        )

        return self._get_results(results, n_time, time, is_compute_acausal, use_gpu)
class ClusterlessDecoder(_DecoderBase):
    """Decodes neural population representation of position from multiunit
    spikes and waveforms.

    Parameters
    ----------
    environment : Environment, optional
        The spatial environment to fit
    transition_type : EmpiricalMovement | RandomWalk | RandomWalkDirection1 | RandomWalkDirection2 | Uniform
        The continuous state transition matrix
    initial_conditions_type : UniformInitialConditions, optional
        The initial conditions class instance
    infer_track_interior : bool, optional
        Whether to infer the spatial geometry of track from position
    clusterless_algorithm : str
        The type of clusterless algorithm. See _ClUSTERLESS_ALGORITHMS for keys
    clusterless_algorithm_params : dict
        Keyword arguments forwarded to the clusterless fitting function.
        If None, no extra keyword arguments are forwarded.

    """

    def __init__(
        self,
        environment: Environment = Environment(environment_name=""),
        transition_type: Union[
            EmpiricalMovement,
            RandomWalk,
            RandomWalkDirection1,
            RandomWalkDirection2,
            Uniform,
        ] = RandomWalk(),
        initial_conditions_type: UniformInitialConditions = UniformInitialConditions(),
        infer_track_interior: bool = True,
        clusterless_algorithm: str = "multiunit_likelihood",
        clusterless_algorithm_params: dict = _DEFAULT_CLUSTERLESS_MODEL_KWARGS,
    ):
        super().__init__(
            environment, transition_type, initial_conditions_type, infer_track_interior
        )
        self.clusterless_algorithm = clusterless_algorithm
        self.clusterless_algorithm_params = clusterless_algorithm_params

    def fit_multiunits(
        self,
        position: np.ndarray,
        multiunits: np.ndarray,
        is_training: Optional[np.ndarray] = None,
    ):
        """Fit the clusterless encoding model and store it in
        ``self.encoding_model_``.

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
        multiunits : np.ndarray, shape (n_time, n_marks, n_electrodes)
        is_training : None or array_like, shape (n_time,)
            Samples to include in the fit; all samples when None.

        """
        logger.info("Fitting multiunits...")

        training_rows = (
            np.ones((position.shape[0],), dtype=bool)
            if is_training is None
            else is_training
        )
        training_rows = np.asarray(training_rows).squeeze()

        algorithm_params = self.clusterless_algorithm_params
        if algorithm_params is None:
            algorithm_params = {}

        # Index 0 of the registry entry is used here for fitting;
        # ``predict`` uses index 1 for the likelihood.
        fit_encoding_model = _ClUSTERLESS_ALGORITHMS[self.clusterless_algorithm][0]
        self.encoding_model_ = fit_encoding_model(
            position=position[training_rows],
            multiunits=multiunits[training_rows],
            place_bin_centers=self.environment.place_bin_centers_,
            is_track_interior=self.environment.is_track_interior_.ravel(order="F"),
            **algorithm_params,
        )

    def fit(
        self,
        position: np.ndarray,
        multiunits: np.ndarray,
        is_training: Optional[np.ndarray] = None,
    ):
        """Fit the spatial grid, initial conditions, state transition and
        clusterless encoding model, in that order.

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
            Position of the animal.
        multiunits : np.ndarray, shape (n_time, n_marks, n_electrodes)
            Array where spikes are indicated by non-Nan values that correspond
            to the waveform features for each electrode.
        is_training : None or np.ndarray, shape (n_time), optional
            Boolean array to indicate which data should be included in the
            fit, by default None

        Returns
        -------
        self

        """
        position_2d = atleast_2d(np.asarray(position))
        spike_features = np.asarray(multiunits)

        self.fit_environment(position_2d)
        self.fit_initial_conditions()
        self.fit_state_transition(
            position_2d, is_training, transition_type=self.transition_type
        )
        self.fit_multiunits(position_2d, spike_features, is_training)

        return self

    def predict(
        self,
        multiunits: np.ndarray,
        time: Optional[np.ndarray] = None,
        is_compute_acausal: bool = True,
        use_gpu: bool = False,
    ) -> xr.Dataset:
        """Predict the probability of spatial position from the multiunit
        spikes and waveforms.

        Parameters
        ----------
        multiunits : np.ndarray, shape (n_time, n_marks, n_electrodes)
            Array where spikes are indicated by non-Nan values that correspond
            to the waveform features for each electrode.
        time : np.ndarray or None, shape (n_time,), optional
            Label the time axis with these values.
        is_compute_acausal : bool, optional
            If True, compute the acausal posterior.
        use_gpu : bool, optional
            Use GPU for the state space part of the model, not the likelihood.

        Returns
        -------
        results : xarray.Dataset

        """
        spike_features = np.asarray(multiunits)
        interior_bins = self.environment.is_track_interior_.ravel(order="F")
        n_time = spike_features.shape[0]

        logger.info("Estimating likelihood...")
        estimate_likelihood = _ClUSTERLESS_ALGORITHMS[self.clusterless_algorithm][1]
        # The fitted encoding model is unpacked as keyword arguments.
        raw_likelihood = estimate_likelihood(
            multiunits=spike_features,
            place_bin_centers=self.environment.place_bin_centers_,
            is_track_interior=interior_bins,
            **self.encoding_model_,
        )
        results = {"likelihood": scaled_likelihood(raw_likelihood)}

        return self._get_results(results, n_time, time, is_compute_acausal, use_gpu)