Classifying with Clusterless Spikes#

Set Plotting Defaults#

import logging
import os

import matplotlib
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Hex colour for each state label; the classification plot looks colours up
# here by the state's name.
STATE_COLORS = {
    "stationary": "#9f043a",
    "fragmented": "#ff6944",
    "continuous": "#521b65",
    "stationary-continuous-mix": "#61c5e6",
    "fragmented-continuous-mix": "#2a586a",
    "": "#c7c7c7",  # colour for the empty-string label
}

# Matplotlib overrides handed to seaborn below.
rc_params = {
    # Font type 42 keeps text editable in Adobe Illustrator (PDF and PS output)
    **dict.fromkeys(("pdf.fonttype", "ps.fonttype"), 42),
    # Dark grey for axis labels and all other text
    **dict.fromkeys(("axes.labelcolor", "text.color"), "#222222"),
}

# White background, "paper" context, slightly enlarged fonts
sns.set(
    style="white",
    context="paper",
    font_scale=1.3,
    rc=rc_params,
)

# Seed NumPy's global random generator so the analysis is reproducible
np.random.seed(0)

# Show INFO-level log messages
logging.basicConfig(level=logging.INFO)

Simulate Clusterless Data#

from replay_trajectory_classification.clusterless_simulation import make_simulated_run_data

# Simulated run session: time stamps, linearized position, sampling rate,
# the mark array and a spike indicator array.
(time, linear_distance, sampling_frequency,
 multiunits, multiunits_spikes) = make_simulated_run_data()

# Row (time) and column (tetrode) index of every nonzero spike-indicator entry
spike_ind, neuron_ind = np.nonzero(multiunits_spikes)

# One row for position, one for the spike raster, plus one row per tetrode.
# Deriving the row count from the data keeps the loop below in bounds for any
# number of tetrodes instead of assuming exactly five.
n_tetrodes = multiunits.shape[-1]
fig, axes = plt.subplots(2 + n_tetrodes, 1, figsize=(12, 12),
                         constrained_layout=True, sharex=True)

axes[0].plot(time, linear_distance, linewidth=3)
axes[0].set_ylabel('Position (cm)')

# Raster with 1-based tetrode numbering, matching the per-tetrode labels below
axes[1].scatter(time[spike_ind], neuron_ind + 1, color='black', s=2)
axes[1].set_yticks((0, multiunits_spikes.shape[1]))
axes[1].set_ylabel('Tetrode Index')

# First mark feature (index 0 on the middle axis) of each tetrode over time
for tetrode_ind in range(n_tetrodes):
    tetrode_ax = axes[2 + tetrode_ind]
    tetrode_ax.scatter(time, multiunits[:, 0, tetrode_ind], s=1)
    tetrode_ax.set_ylabel(f'Tetrode {tetrode_ind + 1} \n Channel 1 \n Spike Amplitude')

sns.despine()
axes[-1].set_xlabel('Time (s)')
axes[-1].set_xlim((time.min(), time.max()))
WARNING:replay_trajectory_classification.core:Cupy is not installed or GPU is not detected. Ignore this message if not using GPU
(0.0, 349.999)
../../_images/a6fe3e7f03b5da46d8ea74ae8b302b9d0d0a5b77b3aba94eebffe337d010631a.png
# Scatter the first two mark features of the first tetrode against each other.
# The x-axis holds feature index 0 (plotted as "Channel 1" in the overview
# figure), so it is amplitude 1; the y-axis holds feature index 1.
plt.scatter(multiunits[:, 0, 0], multiunits[:, 1, 0])
plt.xlabel('Spike Amplitude 1')
plt.ylabel('Spike Amplitude 2');
../../_images/e2f840eb209b22a94eb0c03b7f52372affab299a6f7ddb8dcd98e5f32336e6a0.png

Fit Clusterless Classifier#

from replay_trajectory_classification import (
    ClusterlessClassifier, Environment, 
    RandomWalk, Uniform, Identity, estimate_movement_var)

# Movement variance estimated from the run position and its sampling rate
movement_var = estimate_movement_var(linear_distance, sampling_frequency)

# Factor applied to the estimated movement variance for the random-walk
# transitions.
# NOTE(review): presumably widens the random walk to cover replay speeds —
# confirm the rationale for the value 120.
RANDOM_WALK_VAR_SCALE = 120
random_walk_var = movement_var * RANDOM_WALK_VAR_SCALE

# General-purpose mark likelihood.
# NOTE(review): if your marks are integers, the library may offer a much
# faster integer-specialised algorithm (presumably
# 'multiunit_likelihood_integer') — confirm against the installed version.
clusterless_algorithm = 'multiunit_likelihood'
clusterless_algorithm_params = {
    'mark_std': 1.0,
    'position_std': 12.5,
}

# Position bins one movement standard deviation wide
environment = Environment(place_bin_size=np.sqrt(movement_var))

# 3 x 3 grid of position-transition models, one per pair of states.
# NOTE(review): presumably entry [j][k] is used when moving from state j to
# state k, with states ordered continuous, fragmented, stationary — confirm.
continuous_transition_types = [
    [RandomWalk(movement_var=random_walk_var), Uniform(), Identity()],
    [Uniform(),                                Uniform(), Uniform()],
    [RandomWalk(movement_var=random_walk_var), Uniform(), Identity()],
]

classifier = ClusterlessClassifier(
    environments=environment,
    continuous_transition_types=continuous_transition_types,
    clusterless_algorithm=clusterless_algorithm,
    clusterless_algorithm_params=clusterless_algorithm_params)
classifier.fit(linear_distance, multiunits)
INFO:replay_trajectory_classification.classifier:Fitting initial conditions...
INFO:replay_trajectory_classification.classifier:Fitting continuous state transition...
INFO:replay_trajectory_classification.classifier:Fitting discrete state transition
INFO:replay_trajectory_classification.classifier:Fitting multiunits...
ClusterlessClassifier(clusterless_algorithm='multiunit_likelihood',
                      clusterless_algorithm_params={'mark_std': 1.0,
                                                    'position_std': 12.5},
                      continuous_transition_types=[[RandomWalk(environment_name='', movement_var=33.31010499585128, movement_mean=0.0, use_diffusion=False),
                                                    Uniform(environment_name='', environment2_name=None),
                                                    Identity(environment_name='')],
                                                   [Un...
                      environments=(Environment(environment_name='', place_bin_size=0.5268626085601071, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False, bin_count_threshold=0),),
                      infer_track_interior=True,
                      initial_conditions_type=UniformInitialConditions(),
                      observation_models=(ObservationModel(environment_name='', encoding_group=0),
                                          ObservationModel(environment_name='', encoding_group=0),
                                          ObservationModel(environment_name='', encoding_group=0)))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
# 3 x 3 grid of heat maps: one panel per (row state, column state) pair of the
# fitted continuous state transition.
fig, axes = plt.subplots(3, 3, figsize=(9, 9),
                         sharex=True, sharey=True,
                         constrained_layout=True)
# Rebinds the name to an array of the fitted transition objects; the plot
# below does not read it.
continuous_transition_types = np.asarray(classifier.continuous_transition_types)
x, y = np.meshgrid(environment.place_bin_edges_, environment.place_bin_edges_)
state_names = ['continuous', 'fragmented', 'stationary']
bin_min, bin_max = linear_distance.min(), linear_distance.max()

for j, from_state in enumerate(state_names):
    for k, to_state in enumerate(state_names):
        transition = classifier.continuous_state_transition_[j, k]
        # Each panel is scaled to its own maximum so every structure is visible
        im = axes[j, k].pcolormesh(x, y, transition, cmap='Blues',
                                   vmin=0.0, vmax=transition.max())
        # Separate the two state names; without the arrow they run together
        axes[j, k].set_title(rf'{from_state} $\rightarrow$ {to_state}')
        axes[j, k].set_xticks((bin_min, bin_max))
        axes[j, k].set_yticks((bin_min, bin_max))

# Axes are shared, so setting the limits once applies to every panel
plt.xlim((bin_min, bin_max))
plt.ylim((bin_min, bin_max))
axes[1, 0].set_ylabel(r'$position_{t-1}$')
axes[-1, 1].set_xlabel(r'$position_{t}$')
plt.suptitle('Continuous State Transition', y=1.04, fontsize=22)
# The colorbar is built from the last panel only and labelled qualitatively.
# NOTE(review): ticks=[0, 1] assumes that panel's maximum is 1 — confirm.
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), ticks=[0, 1], label='Probability')
cbar.ax.set_yticklabels(['0', 'Max'])
sns.despine()
../../_images/a795e84c0183985cd1ec5070a0392b4f0b9b2103271e0262acee56389b169546.png
# Labels for the three states, in index order.
# NOTE(review): presumably this order must match the rows/columns of the
# transition grid given to the classifier — confirm.
state_names = ["continuous", "fragmented", "stationary"]
classifier.plot_discrete_state_transition(state_names)
../../_images/e453dc04bf232798d3d297321866ff300a86e3360cee0c080e657f3a930c8057.png

Test classifier on different replay types#

def plot_classification(replay_time, test_multiunits, results):
    """Plot the spikes, state probabilities and decoded position of a replay.

    Draws three stacked panels that share the time axis: a spike raster per
    tetrode, the probability of each state over time, and the posterior over
    position summed across states.

    Parameters
    ----------
    replay_time : np.ndarray
        Time of each sample; indexed with spike time indices and used for the
        x-axis limits.
    test_multiunits : np.ndarray
        Three-dimensional mark array in which NaN means "no spike". A time
        bin counts as a spike on a tetrode when any entry along axis 1 is
        not NaN. Assumed to be (n_time, n_features, n_tetrodes) — TODO
        confirm.
    results : object
        Must expose ``acausal_posterior`` with ``'state'`` and ``'position'``
        dimensions (summed and plotted via its ``.plot`` method). Every state
        label must be a key of ``STATE_COLORS``.

    """
    _, axes = plt.subplots(3, 1, figsize=(12, 10), constrained_layout=True, sharex=True)
    raster_ax, probability_ax, posterior_ax = axes

    # Spike raster. Tetrodes are numbered from 1 so the (0, n_tetrodes) ticks
    # bracket the data and match the numbering of the training-data raster.
    test_multiunit_spikes = np.any(~np.isnan(test_multiunits), axis=1)
    spike_time_ind, neuron_ind = np.nonzero(test_multiunit_spikes)
    raster_ax.scatter(replay_time[spike_time_ind], neuron_ind + 1, color='black')
    raster_ax.set_yticks((0, test_multiunit_spikes.shape[1]))
    raster_ax.set_ylabel('Tetrode Index')

    # Probability of each state over time (posterior summed over position)
    replay_probabilities = results.acausal_posterior.sum('position')
    for state, probability in replay_probabilities.groupby('state'):
        probability.plot(x='time', color=STATE_COLORS[state], linewidth=3,
                         ax=probability_ax, label=state)
    probability_ax.set_title('')
    probability_ax.set_ylabel('Probability')
    probability_ax.set_ylim((0.0, 1.05))
    probability_ax.legend()

    # Posterior over position (summed over states)
    results.acausal_posterior.sum('state').plot(
        x='time', y='position', robust=True, vmin=0.0, ax=posterior_ax)

    # Set the limits on an explicit axes rather than on pyplot's implicit
    # "current" axes; the x-axis is shared, so all three panels follow.
    posterior_ax.set_xlim((replay_time.min(), replay_time.max()))
    sns.despine()

Continuous#

from replay_trajectory_classification.clusterless_simulation import make_continuous_replay

# Simulate a continuous replay, classify it with the fitted model and plot it.
replay_time, test_multiunits = make_continuous_replay()
results = classifier.predict(
    test_multiunits,
    time=replay_time,
    state_names=state_names,
)
plot_classification(replay_time, test_multiunits, results)
INFO:replay_trajectory_classification.classifier:Estimating likelihood...
INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...
../../_images/5c221f1294fe9937be944124b500ddd7fbb8571103564433ebdf33b3daf2695c.png

Stationary#

from replay_trajectory_classification.clusterless_simulation import make_hover_replay

# Simulate a hover (stationary) replay, classify it with the fitted model and
# plot it.
replay_time, test_multiunits = make_hover_replay()
results = classifier.predict(
    test_multiunits,
    time=replay_time,
    state_names=state_names,
)
plot_classification(replay_time, test_multiunits, results)
INFO:replay_trajectory_classification.classifier:Estimating likelihood...
INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...
../../_images/b1e08a251afb7b6e59519f7d8b7e1eac81bb4ac5ee604487803e994ddcfb2661.png

Fragmented#

from replay_trajectory_classification.clusterless_simulation import make_fragmented_replay

# Simulate a fragmented replay, classify it with the fitted model and plot it.
replay_time, test_multiunits = make_fragmented_replay()
results = classifier.predict(
    test_multiunits,
    time=replay_time,
    state_names=state_names,
)
plot_classification(replay_time, test_multiunits, results)
INFO:replay_trajectory_classification.classifier:Estimating likelihood...
INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...
../../_images/109d824b6dd56513c09b4f81d51937b106479e22869261c3c1d3c801989937fe.png

Stationary-Continuous-Stationary#

from replay_trajectory_classification.clusterless_simulation import make_hover_continuous_hover_replay

# Simulate a hover-continuous-hover replay, classify it with the fitted model
# and plot it.
replay_time, test_multiunits = make_hover_continuous_hover_replay()
results = classifier.predict(
    test_multiunits,
    time=replay_time,
    state_names=state_names,
)
plot_classification(replay_time, test_multiunits, results)
INFO:replay_trajectory_classification.classifier:Estimating likelihood...
INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...
../../_images/2e782a3448783837d47b743f4b27adb6caf3f2feab412522122a0734ee3d6f09.png

Fragmented-Stationary-Fragmented#

from replay_trajectory_classification.clusterless_simulation import make_fragmented_hover_fragmented_replay

# Simulate a fragmented-hover-fragmented replay, classify it with the fitted
# model and plot it.
replay_time, test_multiunits = make_fragmented_hover_fragmented_replay()
results = classifier.predict(
    test_multiunits,
    time=replay_time,
    state_names=state_names,
)
plot_classification(replay_time, test_multiunits, results)
INFO:replay_trajectory_classification.classifier:Estimating likelihood...
INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...
../../_images/40c9723eaaf9e427831f3b0c69578dcff0c34b4b5b3e33ca1be0592087968f14.png

Fragmented-Continuous-Fragmented#

from replay_trajectory_classification.clusterless_simulation import make_fragmented_continuous_fragmented_replay

# Simulate a fragmented-continuous-fragmented replay, classify it with the
# fitted model and plot it.
replay_time, test_multiunits = make_fragmented_continuous_fragmented_replay()
results = classifier.predict(
    test_multiunits,
    time=replay_time,
    state_names=state_names,
)
plot_classification(replay_time, test_multiunits, results)
INFO:replay_trajectory_classification.classifier:Estimating likelihood...
INFO:replay_trajectory_classification.classifier:Estimating causal posterior...
INFO:replay_trajectory_classification.classifier:Estimating acausal posterior...
../../_images/0a8187555d88e8f5a6db15d9ff43a12c1b14182c8b7d676aa36d052ab6852b83.png