replay_trajectory_classification.classifier.SortedSpikesClassifier

class SortedSpikesClassifier(
    environments: list[Environment] = Environment(environment_name='', place_bin_size=2.0, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False, bin_count_threshold=0),
    observation_models: ObservationModel | None = None,
    continuous_transition_types: list[list[EmpiricalMovement | RandomWalk | RandomWalkDirection1 | RandomWalkDirection2 | Uniform]] = [[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, use_diffusion=False), Uniform(environment_name='', environment2_name=None)], [Uniform(environment_name='', environment2_name=None), Uniform(environment_name='', environment2_name=None)]],
    discrete_transition_type: DiagonalDiscrete | RandomDiscrete | UniformDiscrete | UserDefinedDiscrete = DiagonalDiscrete(diagonal_value=0.98),
    initial_conditions_type: UniformInitialConditions | UniformOneEnvironmentInitialConditions = UniformInitialConditions(),
    infer_track_interior: bool = True,
    sorted_spikes_algorithm: str = 'spiking_likelihood_kde',
    sorted_spikes_algorithm_params: dict = {'block_size': None, 'position_std': 6.0, 'use_diffusion': False},
)

Bases: _ClassifierBase

Classifies the neural population representation of position and trajectory from clustered (spike-sorted) cells.

Parameters:
  • environments (list of Environment instances, optional) – The spatial environment(s) to fit.

  • observation_models (ObservationModel instance, optional) – Links each discrete state to an environment and encoding group.

  • continuous_transition_types (list of list of transition matrix instances, optional) – Types of transition models, by default _DEFAULT_CONTINUOUS_TRANSITIONS. Both the outer list and each inner list have one entry per discrete state.

  • discrete_transition_type (discrete transition instance, optional) – The model for transitions between discrete states, by default DiagonalDiscrete(diagonal_value=0.98).

  • initial_conditions_type (initial conditions instance, optional) – The initial conditions model, by default UniformInitialConditions().

  • infer_track_interior (bool, optional) – Whether to infer the spatial geometry of the track from position, by default True.

  • sorted_spikes_algorithm (str, optional) – The algorithm used to estimate the place fields and the spiking likelihood, by default 'spiking_likelihood_kde'. See _SORTED_SPIKES_ALGORITHMS for valid keys.

  • sorted_spikes_algorithm_params (dict, optional) – Parameters passed to the chosen algorithm, by default {'block_size': None, 'position_std': 6.0, 'use_diffusion': False}.
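The default discrete_transition_type, DiagonalDiscrete(diagonal_value=0.98), strongly favors staying in the current discrete state. A minimal sketch of such a matrix, assuming the leftover probability 1 - diagonal_value is split evenly among the other states and that row i holds the next-state distribution given state i. The helper diagonal_discrete_transition is hypothetical and not part of the library.

```python
def diagonal_discrete_transition(n_states: int, diagonal_value: float = 0.98) -> list[list[float]]:
    """Hypothetical helper: row-stochastic matrix with a strong diagonal.

    Assumes the leftover mass (1 - diagonal_value) in each row is split
    equally among the other states.
    """
    if n_states == 1:
        return [[1.0]]
    off_diagonal = (1.0 - diagonal_value) / (n_states - 1)
    return [
        [diagonal_value if row == col else off_diagonal for col in range(n_states)]
        for row in range(n_states)
    ]


# Two discrete states, matching the 2 x 2 default continuous_transition_types.
transition = diagonal_discrete_transition(2)
for row in transition:
    print([round(p, 2) for p in row])  # [0.98, 0.02] on each row
```

Each row sums to 1, so a diagonal value close to 1 makes switches between discrete states rare from one time bin to the next.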

Methods

convert_2D_to_1D_results(results2D, ...)

Projects a 2D position decoding result to a 1D decoding result.

copy()

Makes a copy of the classifier.

estimate_parameters(fit_args, predict_args)

Estimate the initial conditions and/or discrete transition matrix of the model.

fit(position, spikes[, is_training, ...])

Fit the spatial grid, initial conditions, place field model, and transition matrices.

fit_continuous_state_transition([...])

Constructs the transition matrices for the continuous states.

fit_discrete_state_transition()

Constructs the transition matrix for the discrete states.

fit_environments(position[, environment_labels])

Fits the Environment class on the position data to get information about the spatial environment.

fit_initial_conditions()

Constructs the initial probability for the state and each spatial bin.

fit_place_fields(position, spikes[, ...])

Fits the place intensity function for each encoding group and environment.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

load_model([filename])

Load the classifier from a file.

plot_discrete_state_transition([...])

Plot heatmap of discrete transition matrix.

plot_place_fields([sampling_frequency, figsize])

Plots place fields for each neuron.

predict(spikes[, time, is_compute_acausal, ...])

Predict the probability of spatial position and category from the spikes.

predict_proba(results)

Predicts the probability of each state.

project_1D_position_to_2D(results[, ...])

Project the 1D most probable position into the 2D track graph space.

save_model([filename])

Save the classifier to a pickled file.

set_fit_request(*[, encoding_group_labels, ...])

Request metadata passed to the fit method.

set_params(**params)

Set the parameters of this estimator.

set_predict_request(*[, is_compute_acausal, ...])

Request metadata passed to the predict method.

static convert_2D_to_1D_results(results2D: Dataset, environment2D: Environment, environment1D: Environment) -> Dataset

Projects a 2D position decoding result to a 1D decoding result.

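Parameters:
  • results2D (xr.Dataset) – Decoding results over the 2D environment, e.g. the output of predict.

  • environment2D (Environment) – The fitted 2D environment the results were decoded on.

  • environment1D (Environment) – The 1D (linearized) environment to project onto, with its place grid already fitted.

Returns:

results1D – The decoding results projected onto the 1D environment.

Return type:

xarray.core.dataset.Dataset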

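Examples

    # classifier is a SortedSpikesClassifier fitted on a 2D environment;
    # track_graph, edge_order and edge_spacing describe the linearized track.
    from replay_trajectory_classification.environments import Environment

    results = classifier.predict(spikes)
    environment1D = (
        Environment(track_graph=track_graph,
                    place_bin_size=2.0,
                    edge_order=edge_order,
                    edge_spacing=edge_spacing)
        .fit_place_grid())
    # convert_2D_to_1D_results is a static method, so call it through the
    # class or an instance rather than as a bare function.
    results1D = classifier.convert_2D_to_1D_results(
        results, classifier.environments[0], environment1D)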
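Conceptually, projecting a 2D posterior onto a 1D (linearized) track means assigning each 2D place bin to a 1D bin and summing the probability mass that lands in each 1D bin. The sketch below illustrates only that summation step, assuming a hypothetical precomputed mapping bin_2d_to_1d; it is not the library's implementation, which would also need to derive the mapping from the track graph.

```python
def project_posterior_to_1d(posterior_2d, bin_2d_to_1d, n_bins_1d):
    """Sum 2D-bin probabilities into 1D bins, one time step at a time.

    posterior_2d: list over time of lists over 2D place bins.
    bin_2d_to_1d: hypothetical mapping giving the 1D bin index of each 2D bin.
    """
    posterior_1d = []
    for probabilities in posterior_2d:
        row = [0.0] * n_bins_1d
        for bin_2d, p in enumerate(probabilities):
            # Accumulate the mass of this 2D bin into its assigned 1D bin.
            row[bin_2d_to_1d[bin_2d]] += p
        posterior_1d.append(row)
    return posterior_1d


# Four 2D bins collapsed onto three 1D bins (the first two share a 1D bin).
posterior_2d = [[0.1, 0.2, 0.3, 0.4], [0.25, 0.25, 0.5, 0.0]]
posterior_1d = project_posterior_to_1d(posterior_2d, [0, 0, 1, 2], n_bins_1d=3)
```

Because mass is only moved between bins, each time step still sums to 1 after the projection.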

copy()

Makes a copy of the classifier.

estimate_parameters(fit_args: dict, predict_args: dict, tolerance: float = 0.0001, max_iter: int = 10, verbose: bool = True, store_likelihood: bool = True, estimate_initial_conditions: bool = True, estimate_discrete_transition: bool = True) -> tuple[Dataset, list[float]]

Estimate the initial conditions and/or discrete transition matrix of the model.

Parameters:
  • fit_args (dict) – Arguments that would be passed to the fit method.

  • predict_args (dict) – Arguments that would be passed to the predict method.

  • tolerance (float, optional) – Convergence threshold: iteration stops once the change in data log likelihood between successive iterations falls below this value, by default 1e-4

  • max_iter (int, optional) – Maximum number of iterations, by default 10

  • verbose (bool, optional) – Log results of each iteration, by default True

  • store_likelihood (bool, optional) – If True, compute the likelihood once and reuse it instead of re-estimating it on every iteration, by default True

  • estimate_initial_conditions (bool, optional) – If True, estimate the initial conditions, by default True

  • estimate_discrete_transition (bool, optional) – If True, estimate the discrete state transition, by default True

Returns:

  • results (xr.Dataset)

  • data_log_likelihoods (list, len (n_iter,)) – The data log likelihood of each iteration
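estimate_parameters alternates decoding with re-estimation and stops after max_iter iterations or once the data log likelihood stops changing by more than tolerance. A minimal sketch of two ingredients, assuming a standard Baum-Welch (EM) update for the discrete transition matrix and a relative-change convergence test. Both helpers are hypothetical, and the library's exact convergence rule may differ.

```python
import sys


def reestimate_transition(joint_probabilities):
    """Baum-Welch M-step for a discrete transition matrix.

    joint_probabilities[t][i][j]: posterior probability of state i at time t
    and state j at time t + 1 (already summed over position).
    """
    n_states = len(joint_probabilities[0])
    counts = [[0.0] * n_states for _ in range(n_states)]
    for xi in joint_probabilities:
        for i in range(n_states):
            for j in range(n_states):
                counts[i][j] += xi[i][j]
    matrix = []
    for row in counts:
        total = sum(row) or 1.0  # guard against states that are never visited
        matrix.append([c / total for c in row])
    return matrix


def has_converged(log_likelihoods, tolerance=1e-4):
    """Relative change between the last two data log likelihoods."""
    if len(log_likelihoods) < 2:
        return False
    current, previous = log_likelihoods[-1], log_likelihoods[-2]
    average = (abs(current) + abs(previous) + sys.float_info.epsilon) / 2
    return abs(current - previous) / average < tolerance


xi = [
    [[0.9, 0.1], [0.0, 0.0]],
    [[0.8, 0.0], [0.1, 0.1]],
]
print(reestimate_transition(xi))  # each row sums to 1
print(has_converged([-1000.0, -1000.05]))  # True: relative change is about 5e-5
```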

fit(position: ndarray, spikes: ndarray, is_training: ndarray | None = None, encoding_group_labels: ndarray | None = None, environment_labels: ndarray | None = None)

Fit the spatial grid, initial conditions, place field model, and transition matrices.

Parameters:
  • position (np.ndarray, shape (n_time, n_position_dims)) – Position of the animal.

  • spikes (np.ndarray, shape (n_time, n_neurons)) – Binary indicator of whether there was a spike in a given time bin for a given neuron.

  • is_training (None or np.ndarray, shape (n_time,), optional) – Boolean array indicating which time points are used to fit the place fields, by default None.

  • encoding_group_labels (None or np.ndarray, shape (n_time,), optional) – Encoding group label for each time point, by default None.

  • environment_labels (None or np.ndarray, shape (n_time,), optional) – Environment label for each time point, by default None.

Return type:

self
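fit expects spikes as an (n_time, n_neurons) array marking whether each neuron fired in each time bin, aligned with position. A minimal sketch of building that indicator from per-neuron spike times, assuming fixed-width time bins starting at start_time. The helper bin_spike_times and its arguments are hypothetical.

```python
import math


def bin_spike_times(spike_times_per_neuron, start_time, bin_size, n_time):
    """Return an n_time x n_neurons list of 0/1 spike indicators.

    spike_times_per_neuron: one list of spike times (seconds) per neuron.
    Bin k covers [start_time + k * bin_size, start_time + (k + 1) * bin_size).
    """
    n_neurons = len(spike_times_per_neuron)
    spikes = [[0] * n_neurons for _ in range(n_time)]
    for neuron, spike_times in enumerate(spike_times_per_neuron):
        for t in spike_times:
            k = math.floor((t - start_time) / bin_size)
            if 0 <= k < n_time:  # ignore spikes outside the binned window
                spikes[k][neuron] = 1
    return spikes


# Two neurons, 2 ms bins (500 samples per second), five time bins.
spikes = bin_spike_times([[0.0005, 0.0061], [0.0031, 0.0035, 0.0099]], 0.0, 0.002, 5)
```

Two spikes from the same neuron in one bin still produce a single 1, which is what a binary indicator implies.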

fit_continuous_state_transition(continuous_transition_types: list[list[EmpiricalMovement | RandomWalk | RandomWalkDirection1 | RandomWalkDirection2 | Uniform]] = [[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, use_diffusion=False), Uniform(environment_name='', environment2_name=None)], [Uniform(environment_name='', environment2_name=None), Uniform(environment_name='', environment2_name=None)]], position: ndarray | None = None, is_training: ndarray | None = None, encoding_group_labels: ndarray | None = None, environment_labels: ndarray | None = None) -> None

Constructs the transition matrices for the continuous states.

Parameters:
  • continuous_transition_types (list of list of transition matrix instances, optional) – Types of transition models, by default _DEFAULT_CONTINUOUS_TRANSITIONS.

  • position (np.ndarray, optional) – Position of the animal in the environment, by default None.

  • is_training (np.ndarray, optional) – Boolean array indicating which time points to train on, by default None.

  • encoding_group_labels (np.ndarray, shape (n_time,), optional) – If place fields should correspond to each state, label each time point with its group name. For example, some time points could correspond to inbound trajectories and others to outbound trajectories, by default None.

  • environment_labels (np.ndarray, shape (n_time,), optional) – If there are multiple environments, label each time point with the environment name, by default None.
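The default continuous_transition_types place a RandomWalk(movement_var=6.0) at entry [0][0] and Uniform transitions everywhere else. A minimal 1D sketch of these two kinds of transition matrix, assuming a Gaussian random walk evaluated at bin centers with movement_var as its variance (in squared position units) and each row normalized to sum to 1. The helpers are hypothetical; the library's classes also handle 2D grids, track graphs, and bins outside the track interior.

```python
import math


def random_walk_transition(bin_centers, movement_var=6.0):
    """Row i: probability of moving from bin i to each bin (Gaussian, normalized)."""
    matrix = []
    for x_from in bin_centers:
        weights = [
            math.exp(-((x_to - x_from) ** 2) / (2.0 * movement_var))
            for x_to in bin_centers
        ]
        total = sum(weights)
        matrix.append([w / total for w in weights])
    return matrix


def uniform_transition(n_bins):
    """Every destination bin is equally likely, regardless of the start bin."""
    return [[1.0 / n_bins] * n_bins for _ in range(n_bins)]


bin_centers = [1.0, 3.0, 5.0, 7.0, 9.0]  # place_bin_size = 2.0
rw = random_walk_transition(bin_centers)
print([round(p, 3) for p in rw[2]])  # peaked at the current bin, symmetric
```

The random walk keeps the decoded position near where it was, while the uniform matrix lets it jump anywhere, which is what distinguishes a continuous trajectory from a fragmented one.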

fit_discrete_state_transition()

Constructs the transition matrix for the discrete states.

fit_environments(position: ndarray, environment_labels: ndarray | None = None) -> None

Fits the Environment class on the position data to get information about the spatial environment.

Parameters:
  • position (np.ndarray, shape (n_time, n_position_dims)) – Position of the animal.

  • environment_labels (np.ndarray, optional, shape (n_time,)) – Label indicating which environment each time point belongs to, by default None.
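Fitting an environment discretizes the space covered by position into place bins of width place_bin_size. A minimal 1D sketch of that step, assuming the range is taken from the observed positions and the number of bins is rounded up so no bin is wider than place_bin_size. The helper fit_place_grid_1d is hypothetical; the Environment class additionally handles multiple dimensions, track graphs, and track-interior inference.

```python
import math


def fit_place_grid_1d(positions, place_bin_size=2.0):
    """Return (bin_edges, bin_centers) covering the observed position range."""
    low, high = min(positions), max(positions)
    n_bins = max(1, math.ceil((high - low) / place_bin_size))
    # Shrink the bin width slightly if the range is not a multiple of the bin size.
    width = (high - low) / n_bins if high > low else place_bin_size
    edges = [low + k * width for k in range(n_bins + 1)]
    centers = [(a + b) / 2 for a, b in zip(edges[:-1], edges[1:])]
    return edges, centers


edges, centers = fit_place_grid_1d([0.0, 3.5, 7.2, 10.0], place_bin_size=2.0)
print(centers)  # [1.0, 3.0, 5.0, 7.0, 9.0]
```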

fit_initial_conditions()

Constructs the initial probability for the state and each spatial bin.
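With the default UniformInitialConditions, every discrete state and every place bin on the track is taken to be equally likely at the first time step. A minimal sketch, assuming bins outside the track interior get zero probability and the interior bins of all states share the total mass equally. The helper uniform_initial_conditions is hypothetical.

```python
def uniform_initial_conditions(is_track_interior, n_states):
    """Return one list of bin probabilities per state; all entries sum to 1."""
    n_interior = sum(1 for inside in is_track_interior if inside)
    p = 1.0 / (n_states * n_interior)
    return [
        [p if inside else 0.0 for inside in is_track_interior]
        for _ in range(n_states)
    ]


# Four place bins, the third lies off the track; two discrete states.
initial = uniform_initial_conditions([True, True, False, True], n_states=2)
```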

fit_place_fields(position: ndarray, spikes: ndarray, is_training: ndarray | None = None, encoding_group_labels: ndarray | None = None, environment_labels: ndarray | None = None) -> None

Fits the place intensity function for each encoding group and environment.

Parameters:
  • position (np.ndarray, shape (n_time, n_position_dims)) – Position of the animal.

  • spikes (np.ndarray, shape (n_time, n_neurons)) – Binary indicator of whether there was a spike in a given time bin for a given neuron.

  • is_training (np.ndarray, shape (n_time,), optional) – Boolean array indicating which time points are used to fit the place fields, by default None.

  • encoding_group_labels (np.ndarray, shape (n_time,), optional) – Encoding group label for each time point, by default None.

  • environment_labels (np.ndarray, shape (n_time,), optional) – Environment label for each time point, by default None.
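The default sorted_spikes_algorithm, 'spiking_likelihood_kde' with position_std=6.0, estimates each place field by kernel density estimation. A minimal 1D sketch of an occupancy-normalized KDE estimate, assuming a Gaussian kernel with standard deviation position_std: the expected spike count per time bin at location x is the kernel-weighted number of spikes divided by the kernel-weighted occupancy. The helpers are hypothetical and illustrate the idea rather than the library's implementation, which also handles multiple dimensions, block_size, and use_diffusion.

```python
import math


def gaussian_kernel(distance, std):
    # Unnormalized Gaussian; the normalizing constant cancels in the ratio below.
    return math.exp(-0.5 * (distance / std) ** 2)


def kde_place_field(position, spikes, bin_centers, position_std=6.0, is_training=None):
    """Expected spikes per time bin at each bin center for one neuron.

    position: 1D position per time bin; spikes: 0/1 per time bin.
    is_training: optional booleans selecting which time bins to use.
    """
    if is_training is None:
        is_training = [True] * len(position)
    field = []
    for center in bin_centers:
        spike_density = 0.0
        occupancy = 0.0
        for x, s, use in zip(position, spikes, is_training):
            if not use:
                continue
            w = gaussian_kernel(x - center, position_std)
            occupancy += w
            spike_density += w * s
        field.append(spike_density / occupancy if occupancy > 0 else 0.0)
    return field


position = [0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0]
spikes = [0, 0, 0, 1, 1, 0, 0, 0]
field = kde_place_field(position, spikes, [0.0, 7.0, 14.0], position_std=2.0)
```

Multiplying such a per-bin estimate by the sampling frequency converts it to spikes per second.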

get_metadata_routing()

Get metadata routing of this object.

See the scikit-learn User Guide for how the metadata routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)

Get parameters for this estimator.

Parameters:

deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:

params – Parameter names mapped to their values.

Return type:

dict

static load_model(filename: str = 'model.pkl')

Load the classifier from a file.

Parameters:

filename (str, optional) – Path of the file to load, by default 'model.pkl'.

Return type:

classifier instance

plot_discrete_state_transition(state_names: list[str] | None = None, cmap: str = 'Oranges', ax: Axes | None = None, convert_to_seconds: bool = False, sampling_frequency: int = 1) -> None

Plot heatmap of discrete transition matrix.

Parameters:
  • state_names (list[str], optional) – Names corresponding to each discrete state, by default None

  • cmap (str, optional) – matplotlib colormap, by default “Oranges”

  • ax (matplotlib.axes.Axes, optional) – Plotting axis, by default plots to current axis

  • convert_to_seconds (bool, optional) – Convert the transition probabilities to the expected duration of each state in seconds, by default False

  • sampling_frequency (int, optional) – Number of samples per second, by default 1
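With convert_to_seconds=True, each probability is displayed as an expected duration. If a state is kept with self-transition probability p at every time bin, the number of consecutive bins spent in it is geometric with mean 1 / (1 - p), and dividing by sampling_frequency gives seconds. A minimal sketch of that conversion, assuming this is the quantity plotted. The helper expected_duration_seconds is hypothetical.

```python
def expected_duration_seconds(stay_probability, sampling_frequency=1):
    """Mean dwell time of a state whose self-transition probability is p."""
    return (1.0 / (1.0 - stay_probability)) / sampling_frequency


# The default diagonal value 0.98 at 500 samples per second:
print(round(expected_duration_seconds(0.98, sampling_frequency=500), 3))  # 0.1
```

At 500 samples per second, a diagonal value of 0.98 therefore corresponds to dwelling in a state for about 50 time bins, or 0.1 s, on average.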

plot_place_fields(sampling_frequency: int = 1, figsize: tuple[float, float] = (10.0, 7.0))

Plots place fields for each neuron.

Parameters:
  • sampling_frequency (int, optional) – Samples per second, by default 1.

  • figsize (tuple, optional) – Figure dimensions, by default (10, 7).

predict(spikes: ndarray, time: ndarray | None = None, is_compute_acausal: bool = True, use_gpu: bool = False, state_names: list[str] | None = None, store_likelihood: bool = False) -> Dataset

Predict the probability of spatial position and category from the spikes.

Parameters:
  • spikes (np.ndarray, shape (n_time, n_neurons)) – Binary indicator of whether there was a spike in a given time bin for a given neuron.

  • time (np.ndarray or None, shape (n_time,), optional) – Values used to label the time axis, by default None.

  • is_compute_acausal (bool, optional) – If True, compute the acausal posterior, by default True.

  • use_gpu (bool, optional) – If True, use the GPU for the state space part of the model (not for the likelihood), by default False.

  • state_names (None or array_like, shape (n_states,)) – Labels for the discrete states, by default None.

  • store_likelihood (bool, optional) – If True, store the likelihood so it can be reused in later computations, by default False.

Returns:

results

Return type:

xarray.Dataset
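The causal posterior uses only the spikes up to each time bin (a forward filter); with is_compute_acausal=True, the acausal posterior also uses later spikes (a backward smoother). A minimal sketch of both passes over a small discrete state space, assuming the likelihoods are already computed and that row i of the transition matrix gives the next-state distribution from state i. In the classifier the state space combines discrete states and place bins, but the recursions have the same form. The helpers are hypothetical.

```python
def causal_posterior(initial, transition, likelihood):
    """Forward filter: P(state_t | data up to t) for each time t."""
    n = len(initial)
    posterior = []
    prior = initial
    for lik in likelihood:
        unnormalized = [prior[j] * lik[j] for j in range(n)]
        total = sum(unnormalized)
        filtered = [u / total for u in unnormalized]
        posterior.append(filtered)
        # One-step prediction for the next time bin.
        prior = [sum(filtered[i] * transition[i][j] for i in range(n)) for j in range(n)]
    return posterior


def acausal_posterior(causal, transition):
    """Backward smoother: P(state_t | all data), starting from the last causal step."""
    n = len(transition)
    smoothed = [None] * len(causal)
    smoothed[-1] = causal[-1]
    for t in range(len(causal) - 2, -1, -1):
        predicted = [sum(causal[t][i] * transition[i][j] for i in range(n)) for j in range(n)]
        ratio = [
            smoothed[t + 1][j] / predicted[j] if predicted[j] > 0 else 0.0
            for j in range(n)
        ]
        smoothed[t] = [
            causal[t][i] * sum(transition[i][j] * ratio[j] for j in range(n))
            for i in range(n)
        ]
    return smoothed


initial = [0.5, 0.5]
transition = [[0.98, 0.02], [0.02, 0.98]]
likelihood = [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]
causal = causal_posterior(initial, transition, likelihood)
acausal = acausal_posterior(causal, transition)
```

Here the later evidence for the second state raises its acausal probability at the first time bin well above the causal estimate of 0.1.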

static predict_proba(results: Dataset) -> Dataset

Predicts the probability of each state.

Parameters:

results (xr.Dataset)

Returns:

results

Return type:

xr.Dataset
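The probability of each discrete state can be obtained by summing the posterior over position bins. A minimal sketch, assuming the posterior for each time bin is stored as one list of place-bin probabilities per discrete state; this nested-list layout is hypothetical, and the library presumably performs the equivalent reduction on the xarray posterior.

```python
def state_probabilities(posterior):
    """posterior[t][state][bin] -> probabilities[t][state]."""
    return [[sum(bins) for bins in states] for states in posterior]


posterior = [
    [[0.1, 0.2, 0.1], [0.3, 0.2, 0.1]],  # t = 0
    [[0.5, 0.3, 0.1], [0.0, 0.05, 0.05]],  # t = 1
]
probabilities = state_probabilities(posterior)
```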

project_1D_position_to_2D(results: Dataset, posterior_type='acausal_posterior') -> ndarray

Project the 1D most probable position into the 2D track graph space.

Only works for a single environment.

Parameters:
  • results (xr.Dataset)

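  • posterior_type (str, optional) – Which distribution to take the most probable position from: 'causal_posterior', 'acausal_posterior', or 'likelihood', by default 'acausal_posterior'.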

Returns:

map_position2D

Return type:

np.ndarray
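Projecting the 1D most probable position into 2D amounts to taking, for each time bin, the place bin with the highest probability and looking up where that bin lies in the 2D track-graph space. A minimal sketch, assuming the posterior has already been summed over discrete states and that a hypothetical table bin_center_2d gives the 2D coordinate of each 1D bin center.

```python
def map_position_2d(posterior_1d, bin_center_2d):
    """Return the 2D coordinate of the maximum a posteriori 1D bin per time."""
    positions = []
    for probabilities in posterior_1d:
        map_bin = max(range(len(probabilities)), key=probabilities.__getitem__)
        positions.append(bin_center_2d[map_bin])
    return positions


# Three 1D bins along an L-shaped track: (x, y) of each bin center.
bin_center_2d = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0)]
posterior_1d = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
print(map_position_2d(posterior_1d, bin_center_2d))  # [(0.0, 0.0), (2.0, 2.0)]
```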

save_model(filename: str = 'model.pkl') -> None

Save the classifier to a pickled file.

Parameters:

filename (str, optional) – Path of the pickled file to write, by default 'model.pkl'.

set_fit_request(*, encoding_group_labels: bool | None | str = '$UNCHANGED$', environment_labels: bool | None | str = '$UNCHANGED$', is_training: bool | None | str = '$UNCHANGED$', position: bool | None | str = '$UNCHANGED$', spikes: bool | None | str = '$UNCHANGED$') -> SortedSpikesClassifier

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). See the scikit-learn User Guide for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:
  • encoding_group_labels (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for encoding_group_labels parameter in fit.

  • environment_labels (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for environment_labels parameter in fit.

  • is_training (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for is_training parameter in fit.

  • position (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for position parameter in fit.

  • spikes (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for spikes parameter in fit.

Returns:

self – The updated object.

Return type:

object

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:

**params (dict) – Estimator parameters.

Returns:

self – Estimator instance.

Return type:

estimator instance

set_predict_request(*, is_compute_acausal: bool | None | str = '$UNCHANGED$', spikes: bool | None | str = '$UNCHANGED$', state_names: bool | None | str = '$UNCHANGED$', store_likelihood: bool | None | str = '$UNCHANGED$', time: bool | None | str = '$UNCHANGED$', use_gpu: bool | None | str = '$UNCHANGED$') -> SortedSpikesClassifier

Request metadata passed to the predict method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). See the scikit-learn User Guide for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to predict.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:
  • is_compute_acausal (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for is_compute_acausal parameter in predict.

  • spikes (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for spikes parameter in predict.

  • state_names (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for state_names parameter in predict.

  • store_likelihood (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for store_likelihood parameter in predict.

  • time (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for time parameter in predict.

  • use_gpu (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for use_gpu parameter in predict.

Returns:

self – The updated object.

Return type:

object