Decoding with Sorted Spikes#

import logging
import os

import matplotlib
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Make analysis reproducible
np.random.seed(0)

# Enable logging
logging.basicConfig(level=logging.INFO)

Create Simulated Data#

from replay_trajectory_classification.sorted_spikes_simulation import make_simulated_run_data

time, position, sampling_frequency, spikes, place_fields = make_simulated_run_data()
WARNING:replay_trajectory_classification.core:Cupy is not installed or GPU is not detected. Ignore this message if not using GPU

Plot the spikes and the position:

MM_TO_INCHES = 1.0 / 25.4
ONE_COLUMN = 89.0 * MM_TO_INCHES
ONE_AND_HALF_COLUMN = 140.0 * MM_TO_INCHES
TWO_COLUMN = 178.0 * MM_TO_INCHES
PAGE_HEIGHT = 247.0 * MM_TO_INCHES
GOLDEN_RATIO = (np.sqrt(5) - 1.0) / 2.0

STATE_COLORS = {
    'stationary': '#9f043a',
    'fragmented': '#ff6944',
    'continuous': '#521b65',
    'stationary-continuous-mix': '#61c5e6',
    'fragmented-continuous-mix': '#2a586a',
    '': '#c7c7c7',
}


spike_ind, neuron_ind = np.nonzero(spikes)
cmap = plt.get_cmap('tab20')

fig, axes = plt.subplots(3, 1, figsize=(TWO_COLUMN, TWO_COLUMN * GOLDEN_RATIO), constrained_layout=True)
for place_field, color in zip(place_fields.T, cmap.colors):
    axes[0].plot(position, place_field, linewidth=3, color=color)
axes[0].set_xlabel('Position [cm]')
axes[0].set_ylabel('Firing Rate\n[spikes / s]')
axes[0].set_title('True Simulated Place Fields')
axes[0].set_xlim((position.min(), position.max()))
axes[0].set_yticks([0, np.round(place_fields.max())])

axes[1].plot(time, position, linewidth=3)
axes[1].set_ylabel('Position [cm]')
axes[1].set_title('Simulated Position and Spikes')
axes[1].set_yticks([0, np.round(position.max())])
axes[1].set_xticks([])
axes[1].set_xlim((0.0, 90.0))

c = [cmap.colors[ind] for ind in neuron_ind]
axes[2].scatter(time[spike_ind], neuron_ind + 1, c=c, s=0.5)
axes[2].set_yticks((1, spikes.shape[1]))
axes[2].set_ylabel('Cells')

axes[2].set_xlabel('Time [s]')
axes[2].set_xlim((0.0, 90.0))

sns.despine()
../../_images/390a4eb4edd5a7e1f05ea0a6f8363ed053d1b6bb1d68bda266378ce4b93f7529.png

Fitting the Model#

We can fit an encoding model by relating the position to the spikes (i.e., finding the place field of each cell). Here we fit the model on the animal’s position while it is running.

First we use dask for parallelizing the fit.

Then we can fit the model. Here we set the movement variance of the random walk (movement_var) to be the same as the animal’s while running, which we estimate with estimate_movement_var. We set the place_bin_size, which controls the discretization of position, to the square root of the movement variance, i.e. the standard deviation of the random walk. Finally we set position_std to 3.0 (in the units of position, here cm), the width of the smoothing kernel used by the spiking_likelihood_kde algorithm. This controls the smoothness of the place field and corresponds to how fast you expect the firing rate to change over position.

from replay_trajectory_classification import SortedSpikesDecoder, Environment, RandomWalk, estimate_movement_var

movement_var = estimate_movement_var(position, sampling_frequency)

environment = Environment(place_bin_size=np.sqrt(movement_var))
transition_type = RandomWalk(movement_var=movement_var)

decoder = SortedSpikesDecoder(
    environment=environment,
    transition_type=transition_type,
    sorted_spikes_algorithm='spiking_likelihood_kde',
    sorted_spikes_algorithm_params={'block_size': None,
                                    'position_std': [3.0],
                                    'use_diffusion': False},
)
decoder.fit(position, spikes)
INFO:replay_trajectory_classification.decoder:Fitting initial conditions...
INFO:replay_trajectory_classification.decoder:Fitting state transition...
INFO:replay_trajectory_classification.decoder:Fitting place fields...
SortedSpikesDecoder(environment=Environment(environment_name='', place_bin_size=0.5268625668325884, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False),
                    infer_track_interior=True,
                    initial_conditions_type=UniformInitialConditions(),
                    sorted_spikes_algorithm='spiking_likelihood_kde',
                    sorted_spikes_algorithm_params={'block_size': None,
                                                    'position_std': [3.0],
                                                    'use_diffusion': False},
                    transition_type=RandomWalk(environment_name='', movement_var=0.2775841643294236, movement_mean=0.0, use_diffusion=False))

We can access the fitted place fields:

decoder.place_fields_
<xarray.DataArray (position: 342, neuron: 19)>
array([[0.0146893 , 0.00592344, 0.00021617, ..., 0.        , 0.        ,
        0.        ],
       [0.0144644 , 0.00613825, 0.00024234, ..., 0.        , 0.        ,
        0.        ],
       [0.01420782, 0.00639076, 0.00027595, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.00019678, 0.00588007,
        0.01433882],
       [0.        , 0.        , 0.        , ..., 0.00016516, 0.00562208,
        0.01445369],
       [0.        , 0.        , 0.        , ..., 0.00014009, 0.00540576,
        0.01454125]], dtype=float32)
Coordinates:
  * position  (position) float32 0.2632 0.7895 1.316 1.842 ... 178.7 179.2 179.7
Dimensions without coordinates: neuron

And plot them versus the true place fields:

g = (decoder.place_fields_ * sampling_frequency).plot(
        x="position", col="neuron", col_wrap=5, color="red", linewidth=2, alpha=0.9, zorder=1, label="Predicted")
g.axes[0, 0].set_ylabel("Firing Rate [spikes/s]")
for ax, place_field in zip(g.axes.flat, place_fields.T):
    ax.plot(position, place_field, linewidth=2, color="black", zorder=-1, label="True")
sns.despine()
../../_images/daecc875212b5f3764f3214670461cbdcfda9643b9f40e2129928c8c16c23efb.png

We can also plot the state transition model:

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

edge1, edge2 = np.meshgrid(decoder.environment.place_bin_edges_, decoder.environment.place_bin_edges_)
ax.pcolormesh(edge1, edge2, decoder.state_transition_.T, vmin=0.0, vmax=np.percentile(decoder.state_transition_, 99.9))
ax.set_title("Random Walk State Transition Matrix")
ax.set_ylabel("Position at time t-1")
ax.set_xlabel("Position at time t")
ax.axis("square");
../../_images/de81b4327b4e682a4bed06616ff040f935d926ba5aaf3f1764779f8fc2811da1.png
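
To make the role of movement_var concrete, here is a minimal NumPy sketch of a Gaussian random walk transition matrix on a linear track. It is an illustration under simple assumptions (evenly spaced bins, no track boundaries or track graph) and is not the package's implementation; the function name random_walk_transition and the numeric values are hypothetical.

```python
import numpy as np


def random_walk_transition(bin_centers, movement_var):
    """Row-normalized Gaussian random-walk transition matrix.

    Entry [i, j] approximates P(position at t = bin j | position at t-1 = bin i).
    """
    distance = bin_centers[np.newaxis, :] - bin_centers[:, np.newaxis]
    transition = np.exp(-0.5 * distance**2 / movement_var)
    return transition / transition.sum(axis=1, keepdims=True)


movement_var = 0.25               # illustrative value, not estimated from data
bin_size = np.sqrt(movement_var)  # one bin per standard deviation of movement
bin_centers = np.arange(0.0, 10.0, bin_size) + bin_size / 2
transition = random_walk_transition(bin_centers, movement_var)

# Probability mass is concentrated near the diagonal and decays with distance.
print(transition[10, 8:13].round(3))
```

A larger movement_var spreads each row over more bins, which shows up as a wider band around the diagonal in a plot like the one above.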

Predicting Position#

We can predict the first 50,000 time bins, which produces a labeled xarray Dataset called results:

time_ind = slice(0, 50000)

results = decoder.predict(spikes[time_ind], time=time[time_ind])
results
INFO:replay_trajectory_classification.decoder:Estimating likelihood...
INFO:replay_trajectory_classification.decoder:Estimating causal posterior...
INFO:replay_trajectory_classification.decoder:Estimating acausal posterior...
<xarray.Dataset>
Dimensions:            (time: 50000, position: 342)
Coordinates:
  * time               (time) float64 0.0 0.001 0.002 0.003 ... 50.0 50.0 50.0
  * position           (position) float64 0.2632 0.7895 1.316 ... 179.2 179.7
Data variables:
    likelihood         (time, position) float64 0.9993 0.9992 ... 0.9998 1.0
    causal_posterior   (time, position) float64 0.002929 0.002929 ... 4.314e-171
    acausal_posterior  (time, position) float64 0.046 0.06159 ... 4.314e-171
Attributes:
    data_log_likelihood:  -726.859881977992

results has three main variables:

  • likelihood: the probability of spikes given position

  • causal_posterior: the probability of position given only past spikes

  • acausal_posterior: the probability of position given all past and future spikes
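
As a rough illustration of the likelihood term, the sketch below computes a Poisson likelihood of spike counts given position from per-bin expected spike counts. This is a generic textbook formulation with made-up numbers, not the package's code; the function name poisson_log_likelihood and the rescaling of each time bin to a maximum of 1 are assumptions made for the example.

```python
import numpy as np


def poisson_log_likelihood(spike_counts, expected_counts):
    """Log likelihood of spike counts for every time bin and position bin.

    spike_counts:    (n_time, n_neurons) spikes per time bin
    expected_counts: (n_position, n_neurons) expected spikes per time bin
    Returns (n_time, n_position); the log(k!) term is dropped because it
    does not depend on position.
    """
    log_rate = np.log(expected_counts + 1e-15)  # guard against log(0)
    return spike_counts @ log_rate.T - expected_counts.sum(axis=1)


# Two hypothetical cells whose place fields sit in different position bins.
expected_counts = np.array([[0.100, 0.001],   # position bin 0
                            [0.001, 0.100]])  # position bin 1
spike_counts = np.array([[1, 0],   # cell 0 fires
                         [0, 1],   # cell 1 fires
                         [0, 0]])  # silence
log_likelihood = poisson_log_likelihood(spike_counts, expected_counts)

# Rescale each time bin so its largest value is 1 (assumed convention).
likelihood = np.exp(log_likelihood - log_likelihood.max(axis=1, keepdims=True))
print(likelihood.round(3))
```

When a cell fires, the likelihood favors the bin containing its place field; in the silent time bin both positions are equally likely because the two cells have the same total expected count in each bin.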

You’ll probably want to use the acausal_posterior, but we can visualize both the acausal and causal and overlay the true position (magenta dashed line):

fig, axes = plt.subplots(3, 1, sharex=True, constrained_layout=True, figsize=(15, 7))

spike_ind, neuron_ind = np.nonzero(spikes[time_ind])
c = [cmap.colors[ind] for ind in neuron_ind]
axes[0].scatter(time[spike_ind], neuron_ind + 1, c=c, s=0.5, clip_on=False)
axes[0].set_yticks((1, spikes.shape[1]))
axes[0].set_ylabel('Cells')

results.causal_posterior.plot(x="time", y="position", ax=axes[1], cmap="bone_r", vmin=0.0, vmax=0.05, clip_on=False)
axes[1].plot(time[time_ind], position[time_ind], color="magenta", linestyle="--", linewidth=3, clip_on=False)
axes[1].set_xlabel("")
results.acausal_posterior.plot(x="time", y="position", ax=axes[2], cmap="bone_r", vmin=0.0, vmax=0.05, clip_on=False)
axes[2].plot(time[time_ind], position[time_ind], color="magenta", linestyle="--", linewidth=3, clip_on=False)
axes[2].set_xlabel('Time [s]')
sns.despine(offset=5)
../../_images/48dbefc96b0b559cfaa33a70e7be9053b4b60883131c9b82217d0fa96d32f709.png

From this, we can see, as expected, that the decoded position (given as a probability of position over time) matches the true position well.
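
The difference between the causal and acausal posteriors can be sketched with a small forward-backward recursion on a discrete state space. This is a generic hidden Markov model filter and smoother written for illustration, with a hypothetical function name and toy inputs; it is not taken from the package.

```python
import numpy as np


def forward_backward(likelihood, transition, initial):
    """Return (causal, acausal) posteriors for a discrete-state model.

    likelihood: (n_time, n_bins), P(spikes at t | position bin)
    transition: (n_bins, n_bins), entry [i, j] = P(bin j at t | bin i at t-1)
    initial:    (n_bins,), P(position bin) before any data
    """
    n_time, n_bins = likelihood.shape
    causal = np.zeros((n_time, n_bins))

    # Forward pass: predict from the previous step, then weight by the likelihood.
    prior = initial
    for t in range(n_time):
        if t > 0:
            prior = transition.T @ causal[t - 1]
        causal[t] = prior * likelihood[t]
        causal[t] /= causal[t].sum()

    # Backward pass: fold in information from future time steps.
    acausal = np.zeros_like(causal)
    acausal[-1] = causal[-1]
    for t in range(n_time - 2, -1, -1):
        predicted = transition.T @ causal[t]
        ratio = np.divide(acausal[t + 1], predicted,
                          out=np.zeros_like(predicted), where=predicted > 0)
        acausal[t] = causal[t] * (transition @ ratio)
        acausal[t] /= acausal[t].sum()
    return causal, acausal


# Toy example: 3 position bins, 4 time steps, a "sticky" transition matrix.
transition = np.array([[0.8, 0.2, 0.0],
                       [0.1, 0.8, 0.1],
                       [0.0, 0.2, 0.8]])
initial = np.full(3, 1 / 3)
likelihood = np.array([[0.5, 0.5, 0.5],   # uninformative (no spikes)
                       [0.5, 0.5, 0.5],
                       [0.1, 0.3, 0.9],   # spikes favor the last bin
                       [0.1, 0.2, 0.9]])
causal, acausal = forward_backward(likelihood, transition, initial)
```

In this toy example the causal posterior at the first time step is uniform, because no informative spikes have arrived yet, while the acausal posterior already leans toward the last bin because it also uses the later spikes. At the final time step the two posteriors are identical.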

Example using Track Graph#

For anything more complicated than a linear track, you will want to use the track graph to specify the correspondence of 1D and 2D position. Here is an example using a circular track.

Simulated Data Setup#

First we generate some data for a circular track. We also make a track graph (see notebooks/tutorial/01-Introduction_and_Data_Format.ipynb for more information on track graph creation).

from replay_trajectory_classification import make_track_graph, plot_track_graph
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

angle = np.linspace(-np.pi, np.pi, num=24, endpoint=False)
radius = 30

node_positions = np.stack((radius * np.cos(angle), radius * np.sin(angle)), axis=1)

node_ids = np.arange(node_positions.shape[0])
edges = np.stack((node_ids, np.roll(node_ids, shift=1)), axis=1)

track_graph = make_track_graph(node_positions, edges)

position_angles = np.linspace(-np.pi, 31 * np.pi, num=360_000, endpoint=False)
position = np.stack((radius * np.cos(position_angles), radius * np.sin(position_angles)), axis=1)

fig, ax = plt.subplots(figsize=(10, 10))
plot_track_graph(track_graph, ax=ax)
ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
ax.set_xlabel("X-Position")
ax.set_ylabel("Y-Position")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

ax.scatter(position[:, 0], position[:, 1], alpha=0.25, s=10, zorder=11, color="magenta")
<matplotlib.collections.PathCollection at 0x7f90bf73c250>
../../_images/4850ac5c50d1c12ead685e4a118adb14257cd4153353762adafeaa996a1b45ea.png

If we want to linearize the track, we need to map the 2D position to 1D position. We do this by specifying the edge_order and edge_spacing:

from replay_trajectory_classification import plot_graph_as_1D

edge_spacing = 0
n_nodes = len(track_graph.nodes)
edge_order = np.stack((np.roll(np.arange(n_nodes-1, -1, -1), 1),
                       np.arange(n_nodes-1, -1, -1)), axis=1)

fig, ax = plt.subplots(figsize=(n_nodes // 2, 1))
plot_graph_as_1D(track_graph,
                 edge_spacing=edge_spacing,
                 edge_order=edge_order,
                 ax=ax)
../../_images/40cb56dba26c9271fe881ae09f0e4a6cad25da147067ae13c1ba2a44f0605656.png
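
To see what the edge_order expression produces, here is the same construction on a hypothetical ring of only 4 nodes, which is small enough to read by eye. Only NumPy is used; the node count is an assumption for the example.

```python
import numpy as np

n_nodes = 4  # hypothetical small ring; the tutorial uses 24 nodes
node_ids = np.arange(n_nodes)

# Edges of the ring, built the same way as for the track graph.
edges = np.stack((node_ids, np.roll(node_ids, shift=1)), axis=1)

# Order in which the edges are laid end to end along the 1D axis.
descending = np.arange(n_nodes - 1, -1, -1)
edge_order = np.stack((np.roll(descending, 1), descending), axis=1)

print(edges.tolist())       # [[0, 3], [1, 0], [2, 1], [3, 2]]
print(edge_order.tolist())  # [[0, 3], [3, 2], [2, 1], [1, 0]]
```

Each row of edge_order is one edge of the ring, and the second node of each row is the first node of the next, so the edges form a single chain that starts and ends at node 0.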

We can use the track_linearization package to linearize the position from 2D to 1D:

from track_linearization import get_linearized_position

position_df = get_linearized_position(position, track_graph, edge_order=edge_order, edge_spacing=edge_spacing, use_HMM=False)
position_df
        linear_position  track_segment_id  projected_x_position  projected_y_position
0              0.000000                 0            -30.000000         -3.673940e-15
1            187.949411                 1            -29.998916         -8.235002e-03
2            187.941104                 1            -29.997832         -1.647031e-02
3            187.932798                 1            -29.996747         -2.470591e-02
4            187.924491                 1            -29.995663         -3.294182e-02
...                 ...               ...                   ...                   ...
359995         0.041533                 0            -29.994579          4.117803e-02
359996         0.033226                 0            -29.995663          3.294182e-02
359997         0.024919                 0            -29.996747          2.470591e-02
359998         0.016612                 0            -29.997832          1.647031e-02
359999         0.008306                 0            -29.998916          8.235002e-03

360000 rows × 4 columns

Now let’s check to see if the linearization looks correct.

plt.figure(figsize=(20, 5))

sampling_frequency = 1000
time = np.arange(position_df.linear_position.size) / sampling_frequency
plt.scatter(time, position_df.linear_position, clip_on=False, s=1, color='magenta')
plt.xlabel('Time')
plt.ylabel('Linear Position')
sns.despine()
../../_images/c40098f2e8336bfaf503deb0247a209c4d2ca101e94cb92984bec121d6389154.png
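
One quick sanity check on the range of the linear position: because the track graph is a 24-sided polygon inscribed in the circle, the total linear length should equal the polygon's perimeter, which is slightly shorter than the circumference. The following sketch recomputes that length from the node positions with NumPy alone.

```python
import numpy as np

n_nodes, radius = 24, 30
angle = np.linspace(-np.pi, np.pi, num=n_nodes, endpoint=False)
node_positions = np.stack((radius * np.cos(angle), radius * np.sin(angle)), axis=1)

# Length of each edge between neighboring nodes around the ring.
edge_vectors = node_positions - np.roll(node_positions, shift=1, axis=0)
edge_lengths = np.linalg.norm(edge_vectors, axis=1)

track_length = edge_lengths.sum()
circumference = 2 * np.pi * radius

print(f"{track_length:.2f}")   # 187.96
print(f"{circumference:.2f}")  # 188.50
```

This agrees with the linearized positions in the table above, whose largest values sit just below 187.96.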

Now let’s simulate neurons with 2D place fields. Here we simulate 20 cells with place fields spaced equally around the track.

from replay_trajectory_classification.simulate import simulate_neuron_with_place_field

angle = np.linspace(-np.pi, np.pi, num=20, endpoint=False)

place_field_centers = np.stack((radius * np.cos(angle), radius * np.sin(angle)), axis=1)

spikes = np.stack([simulate_neuron_with_place_field(center, position, sampling_frequency=sampling_frequency, variance=6.0**2)
                   for center in place_field_centers[::-1]], axis=1)

fig, ax = plt.subplots(figsize=(10, 10))
for spike in spikes.T:
    spike_ind = np.nonzero(spike)[0]
    ax.scatter(position[spike_ind, 0], position[spike_ind, 1])
    
ax.axis("square")
ax.set_title('Spikes at Position')
sns.despine()
../../_images/593c7d669b9f05006ca29b15b8e54d8751018a37a1e07eb9e2f5821cadddab1d.png

Using the track graph with the decoder#

IMPORTANT!!!

In order to get the proper transition matrix and binning of linear position, we need to specify:

  • track_graph

  • edge_order

  • edge_spacing

in the environment.

These must be the same values that were used for the linearization.

from replay_trajectory_classification import SortedSpikesDecoder

environment = Environment(place_bin_size=0.5,
                          track_graph=track_graph,
                          edge_order=edge_order,
                          edge_spacing=edge_spacing)
transition_type = RandomWalk(movement_var=0.25)

decoder = SortedSpikesDecoder(
    environment=environment,
    transition_type=transition_type,
)

decoder
SortedSpikesDecoder(environment=Environment(environment_name='', place_bin_size=0.5, track_graph=<networkx.classes.graph.Graph object at 0x7f90bf745970>, edge_order=array([[ 0, 23],
       [23, 22],
       [22, 21],
       [21, 20],
       [20, 19],
       [19, 18],
       [18, 17],
       [17, 16],
       [16, 15],
       [15, 14],
       [14, 13],
       [13, 12],
       [12, 11],
       [11, 10],
       [10,  9],
       [ 9,  8],
       [ 8,  7],
       [ 7,  6],
       [ 6,  5],
       [ 5,  4],
       [ 4,  3],
       [ 3,  2],
       [ 2,  1],
       [ 1,  0]]), edge_spaci...l_holes=False, dilate=False),
                    infer_track_interior=True,
                    initial_conditions_type=UniformInitialConditions(),
                    sorted_spikes_algorithm='spiking_likelihood_kde',
                    sorted_spikes_algorithm_params={'block_size': None,
                                                    'position_std': 6.0,
                                                    'use_diffusion': False},
                    transition_type=RandomWalk(environment_name='', movement_var=0.25, movement_mean=0.0, use_diffusion=False))
decoder.fit(position_df.linear_position,
            spikes)
INFO:replay_trajectory_classification.decoder:Fitting initial conditions...
INFO:replay_trajectory_classification.decoder:Fitting state transition...
INFO:replay_trajectory_classification.decoder:Fitting place fields...
SortedSpikesDecoder(environment=Environment(environment_name='', place_bin_size=0.5, track_graph=<networkx.classes.graph.Graph object at 0x7f90bf745970>, edge_order=array([[ 0, 23],
       [23, 22],
       [22, 21],
       [21, 20],
       [20, 19],
       [19, 18],
       [18, 17],
       [17, 16],
       [16, 15],
       [15, 14],
       [14, 13],
       [13, 12],
       [12, 11],
       [11, 10],
       [10,  9],
       [ 9,  8],
       [ 8,  7],
       [ 7,  6],
       [ 6,  5],
       [ 5,  4],
       [ 4,  3],
       [ 3,  2],
       [ 2,  1],
       [ 1,  0]]), edge_spaci...l_holes=False, dilate=False),
                    infer_track_interior=True,
                    initial_conditions_type=UniformInitialConditions(),
                    sorted_spikes_algorithm='spiking_likelihood_kde',
                    sorted_spikes_algorithm_params={'block_size': None,
                                                    'position_std': 6.0,
                                                    'use_diffusion': False},
                    transition_type=RandomWalk(environment_name='', movement_var=0.25, movement_mean=0.0, use_diffusion=False))

Again we can plot the place fields, this time as a function of linear position:

fig, ax = plt.subplots(figsize=(10, 3))
(decoder.place_fields_ * sampling_frequency).plot(x="position", hue="neuron", add_legend=False, ax=ax)
ax.set_xlabel('Linear Position')
ax.set_ylabel('Firing Rate')
ax.set_xlim((0, position_df.linear_position.max()))
sns.despine()
../../_images/e42b7bfb2685130dc2655353c3eb65fe0c32bb32a8f26379ac8f63562e637cea.png

Here is the random walk transition matrix.

Notice that there is now a spot of color at the upper left and bottom right corner of the transition matrix that indicates the position can move between the first and last linear position (which makes sense because this is a circle).

plt.figure(figsize=(10, 10))
bin1, bin2 = np.meshgrid(decoder.environment.place_bin_edges_, decoder.environment.place_bin_edges_)
plt.pcolormesh(bin1, bin2, decoder.state_transition_.T, vmin=0.0, vmax=0.01)
plt.xlabel('Linear Position t')
plt.ylabel('Linear Position t-1');
../../_images/ea741519150c4ca62c8a3847ac152d607f87837f3d9d101ea0362205978a6f2b.png
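
The wrap-around in the corners can be reproduced with a small sketch that compares a random walk using ordinary linear distance with one using distance around a ring. This is an illustration only: the package works from the track graph, whereas the sketch below substitutes a simple circular distance, and the track length, bin size, and helper name normalize_rows are hypothetical.

```python
import numpy as np

track_length = 12.0  # hypothetical ring length
bin_size = 0.5
movement_var = 0.25
bin_centers = np.arange(0.0, track_length, bin_size) + bin_size / 2  # 24 bins

# Distance along a line versus the shorter way around a ring.
linear_distance = np.abs(bin_centers[:, np.newaxis] - bin_centers[np.newaxis, :])
circular_distance = np.minimum(linear_distance, track_length - linear_distance)


def normalize_rows(matrix):
    """Scale each row to sum to 1 so rows are probability distributions."""
    return matrix / matrix.sum(axis=1, keepdims=True)


linear_transition = normalize_rows(np.exp(-0.5 * linear_distance**2 / movement_var))
circular_transition = normalize_rows(np.exp(-0.5 * circular_distance**2 / movement_var))

# On the ring, the first bin is as close to the last bin as to the second bin.
print(circular_transition[0, 1], circular_transition[0, -1])
print(linear_transition[0, -1])  # effectively zero on a line
```

With the circular distance, the first and last bins are neighbors, which is what produces the nonzero entries in the upper left and bottom right corners of the matrix.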

We can predict the data from the first 50,000 time points as before:

time_ind = slice(0, 50_000)
results = decoder.predict(spikes[time_ind], time=time[time_ind])

results
INFO:replay_trajectory_classification.decoder:Estimating likelihood...
INFO:replay_trajectory_classification.decoder:Estimating causal posterior...
INFO:replay_trajectory_classification.decoder:Estimating acausal posterior...
<xarray.Dataset>
Dimensions:            (time: 50000, position: 384)
Coordinates:
  * time               (time) float64 0.0 0.001 0.002 0.003 ... 50.0 50.0 50.0
  * position           (position) float64 0.2447 0.7342 1.224 ... 187.2 187.7
Data variables:
    likelihood         (time, position) float64 0.9989 0.9988 ... 0.9987 0.9987
    causal_posterior   (time, position) float64 0.002606 0.002606 ... 4.562e-13
    acausal_posterior  (time, position) float64 0.04053 0.03893 ... 4.562e-13
Attributes:
    data_log_likelihood:  -561.173387422146

And we can see that the true linear position and the decoded position match:

fig, axes = plt.subplots(2, 1, figsize=(15, 5), sharex=True, constrained_layout=True)

spike_time_ind, neuron_ind = np.nonzero(spikes[time_ind])
cmap = plt.get_cmap('tab20')
c = [cmap.colors[ind] for ind in neuron_ind]
axes[0].scatter(time[time_ind][spike_time_ind], neuron_ind, clip_on=False, s=1, c=c)
axes[0].set_ylabel('Cells')

results.acausal_posterior.plot(x="time", y="position", ax=axes[1],
                               robust=True, cmap="bone_r", vmin=0.0, vmax=0.05,
                               clip_on=False)
axes[1].scatter(time[time_ind], position_df.iloc[time_ind].linear_position, color="magenta", s=1, clip_on=False)
axes[1].set_ylabel('Linear Position')
axes[1].set_xlabel('Time')
sns.despine(offset=5)
../../_images/1ad2f6e110c0d397b4d36ad03d23e5d073df625dbdf5707d8e3b637d528a61df.png

Decoding in 2D#

environment = Environment(place_bin_size=2.5)
transition_type = RandomWalk(movement_var=0.25)

decoder = SortedSpikesDecoder(
    environment=environment,
    transition_type=transition_type,
    sorted_spikes_algorithm='spiking_likelihood_kde',
    sorted_spikes_algorithm_params={'block_size': None,
                                    'position_std': 3.0},
)

decoder.fit(position, spikes)
INFO:replay_trajectory_classification.decoder:Fitting initial conditions...
INFO:replay_trajectory_classification.decoder:Fitting state transition...
INFO:replay_trajectory_classification.decoder:Fitting place fields...
SortedSpikesDecoder(environment=Environment(environment_name='', place_bin_size=2.5, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False),
                    infer_track_interior=True,
                    initial_conditions_type=UniformInitialConditions(),
                    sorted_spikes_algorithm='spiking_likelihood_kde',
                    sorted_spikes_algorithm_params={'block_size': None,
                                                    'position_std': 3.0},
                    transition_type=RandomWalk(environment_name='', movement_var=0.25, movement_mean=0.0, use_diffusion=False))
time_ind = slice(0, 50_000)
results = decoder.predict(spikes[time_ind], time=time[time_ind])

results
INFO:replay_trajectory_classification.decoder:Estimating likelihood...
INFO:replay_trajectory_classification.decoder:Estimating causal posterior...
INFO:replay_trajectory_classification.decoder:Estimating acausal posterior...
<xarray.Dataset>
Dimensions:            (time: 50000, x_position: 26, y_position: 26)
Coordinates:
  * time               (time) float64 0.0 0.001 0.002 0.003 ... 50.0 50.0 50.0
  * x_position         (x_position) float64 -31.25 -28.75 -26.25 ... 28.75 31.25
  * y_position         (y_position) float64 -31.25 -28.75 -26.25 ... 28.75 31.25
Data variables:
    likelihood         (time, x_position, y_position) float64 nan nan ... nan
    causal_posterior   (time, x_position, y_position) float64 nan nan ... nan
    acausal_posterior  (time, x_position, y_position) float64 nan nan ... nan
Attributes:
    data_log_likelihood:  -1822.0986324353144

Here is the posterior probability of the 2D decoded position, summed over time and overlaid with the true position:

results.acausal_posterior.sum('time').plot(x='x_position', y='y_position', cmap='bone_r')
plt.scatter(position[time_ind, 0], position[time_ind, 1], color='magenta', s=1, clip_on=False, label='True Position')
sns.despine()
../../_images/60f57f26138c21c65af0263dca226a722488460e461c5f5fcd4dde14e56d174b.png

Here is the most likely decoded position over time compared with the true position:

map_estimate = results.acausal_posterior.stack(position=['x_position', 'y_position'])
map_estimate = map_estimate.position[map_estimate.argmax('position')]
map_estimate = np.asarray(map_estimate.values.tolist())

plt.figure(figsize=(10, 10))
ax = plt.axes(projection='3d')
ax.plot3D(results.time, map_estimate[:, 0], map_estimate[:, 1], 'black', label='Most likely decoded position')
ax.set_xlabel('Time')
ax.set_ylabel('X-Position')
ax.set_zlabel('Y-Position')

ax.plot3D(results.time, position[time_ind, 0], position[time_ind, 1], 'magenta', label='True position')
plt.legend()
<matplotlib.legend.Legend at 0x7f90d1f2b070>
../../_images/1d31908689182673c3008d9c0f4083b77d589559f20bdebb9347f9acb36ab488.png
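
The most-likely-position computation can also be written directly in NumPy, which makes the handling of the nan entries explicit (in the 2D output above, many bins are nan, presumably because they fall off the track). The sketch below uses a tiny made-up posterior; the array names and values are hypothetical and stand in for results.acausal_posterior and its coordinates.

```python
import numpy as np

# Hypothetical 3 x 3 grid of bin centers and a posterior for 2 time steps.
x_centers = np.array([-2.5, 0.0, 2.5])
y_centers = np.array([-2.5, 0.0, 2.5])
posterior = np.full((2, 3, 3), np.nan)  # nan marks bins that are not decoded
posterior[0, 0, 1], posterior[0, 1, 1] = 0.7, 0.3
posterior[1, 2, 0], posterior[1, 2, 2] = 0.2, 0.8

# Flatten the two spatial axes, take the nan-aware argmax per time step,
# then convert the flat index back to (x, y) bin indices.
flat = posterior.reshape(posterior.shape[0], -1)
flat_ind = np.nanargmax(flat, axis=1)
x_ind, y_ind = np.unravel_index(flat_ind, posterior.shape[1:])
map_estimate = np.stack((x_centers[x_ind], y_centers[y_ind]), axis=1)

print(map_estimate.tolist())  # [[-2.5, 0.0], [2.5, 2.5]]
```

Note that np.nanargmax raises an error if a time step contains only nan values, so real data may need a check for that case first.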