Source code for replay_trajectory_classification.likelihoods.spiking_likelihood_glm

"""Estimates a Poisson likelihood using place fields estimated with a GLM
with a spline basis"""

from __future__ import annotations

import logging
from typing import Optional

import dask
import numpy as np
import pandas as pd
import scipy.stats
import xarray as xr
from dask.distributed import Client, get_client
from patsy import DesignInfo, DesignMatrix, build_design_matrices, dmatrix
from regularized_glm import penalized_IRLS
from statsmodels.api import families

from replay_trajectory_classification.environments import get_n_bins


def make_spline_design_matrix(
    position: np.ndarray, place_bin_edges: np.ndarray, knot_spacing: float = 10.0
) -> DesignMatrix:
    """Build the design matrix for a regression on a position spline basis.

    For each position dimension, candidate knots are laid out evenly across
    the range of the bin edges (the number of candidates comes from
    `get_n_bins` with `knot_spacing` as the bin size), the two end points
    are dropped, and only knots strictly inside the range of the observed
    positions are kept. The design matrix is an intercept plus a centered
    tensor product (`te`) of one patsy `cr` term per position dimension.

    Parameters
    ----------
    position : np.ndarray, shape (n_time, n_position_dims)
    place_bin_edges : np.ndarray, shape (n_bins, n_position_dims)
    knot_spacing : float, optional
        Bin size handed to `get_n_bins` when choosing the number of
        candidate knots per dimension.

    Returns
    -------
    design_matrix : patsy.DesignMatrix

    """

    def _interior_knots(observed: np.ndarray, edges: np.ndarray) -> np.ndarray:
        # Candidate knots for one dimension, restricted to the observed range.
        n_candidates = get_n_bins(edges, bin_size=knot_spacing)
        candidates = np.linspace(edges.min(), edges.max(), n_candidates)[1:-1]
        is_within_data = (candidates > observed.min()) & (candidates < observed.max())
        return candidates[is_within_data]

    # NOTE: the formula below refers to `inner_knots` by name and patsy looks
    # it up in this function's namespace, so this local must keep its name and
    # `dmatrix` must be called from this function.
    inner_knots = np.meshgrid(
        *[
            _interior_knots(observed, edges)
            for observed, edges in zip(position.T, place_bin_edges.T)
        ]
    )

    dim_indices = range(position.shape[1])
    spline_terms = "".join(
        f"cr(x{ind}, knots=inner_knots[{ind}])" + ", " for ind in dim_indices
    )
    formula = "1 + te(" + spline_terms + 'constraints="center")'
    data = {f"x{ind}": position[:, ind] for ind in dim_indices}

    return dmatrix(formula, data)
def make_spline_predict_matrix(
    design_info: DesignInfo, place_bin_centers: np.ndarray
) -> DesignMatrix:
    """Build the design matrix evaluated at the position bin centers.

    Column `ind` of `place_bin_centers` is supplied as the variable
    `x{ind}`, matching the variable names used when the training design
    matrix was built.

    Parameters
    ----------
    design_info : patsy.DesignInfo
        Design specification taken from the training design matrix.
    place_bin_centers : np.ndarray, shape (n_bins, n_position_dims)

    Returns
    -------
    predict_matrix : patsy.DesignMatrix

    """
    n_position_dims = place_bin_centers.shape[1]
    predict_data = {
        f"x{ind}": place_bin_centers[:, ind] for ind in range(n_position_dims)
    }
    # One design info goes in, so the matrix of interest is the first result.
    matrices = build_design_matrices([design_info], predict_data)
    return matrices[0]
def get_firing_rate(
    design_matrix: DesignMatrix, results: tuple, sampling_frequency: int = 1
):
    """Predict the firing rate from fitted model coefficients.

    Parameters
    ----------
    design_matrix : patsy.DesignMatrix, shape (n_time, n_coefficients)
    results : tuple
        Fit results; must expose a `coefficients` attribute.
    sampling_frequency : int, optional
        Multiplier applied to the predicted rate.

    Returns
    -------
    rate : np.ndarray, shape (n_time,)
        `exp(design_matrix @ coefficients) * sampling_frequency`, or all
        zeros when any coefficient is NaN.

    """
    coefficients = results.coefficients

    # A fit with NaN coefficients carries no usable rate; report zero firing.
    if np.any(np.isnan(coefficients)):
        return np.zeros((design_matrix.shape[0],))

    return np.exp(design_matrix @ coefficients) * sampling_frequency
@dask.delayed
def fit_glm(
    response: np.ndarray,
    design_matrix: np.ndarray,
    penalty: Optional[float] = None,
    tolerance: float = 1e-5,
) -> tuple:
    """Fit an L2-penalized Poisson GLM.

    The function is wrapped in `dask.delayed`, so calling it builds a lazy
    task; the fit itself runs when the task is computed.

    Parameters
    ----------
    response : np.ndarray, shape (n_time,)
        Observed response being modeled (squeezed before fitting).
    design_matrix : np.ndarray, shape (n_time, n_coefficients)
    penalty : None or float
        L2 penalty on the regression coefficients. A value is applied to
        every coefficient except the first (the intercept), which gets zero
        penalty. If None, machine epsilon is used as a scalar penalty.
    tolerance : float
        Smallest difference between iterations to consider model fitting
        converged.

    Returns
    -------
    results : tuple
        Whatever `penalized_IRLS` returns.

    """
    if penalty is None:
        penalty_weights = np.finfo(float).eps
    else:
        n_coefficients = design_matrix.shape[1]
        penalty_weights = np.ones((n_coefficients,)) * penalty
        penalty_weights[0] = 0.0  # don't penalize the intercept

    return penalized_IRLS(
        design_matrix,
        response.squeeze(),
        family=families.Poisson(),
        penalty=penalty_weights,
        tolerance=tolerance,
    )
def poisson_log_likelihood(
    spikes: np.ndarray, conditional_intensity: np.ndarray
) -> np.ndarray:
    """Poisson log probability of the spikes at every time and place bin.

    Parameters
    ----------
    spikes : np.ndarray, shape (n_time,)
        Spike count (or spike / no spike indicator) at each time.
    conditional_intensity : np.ndarray, shape (n_place_bins,)
        Expected spike count for each place bin.

    Returns
    -------
    poisson_log_likelihood : np.ndarray, shape (n_time, n_place_bins)

    """
    # Time runs down the rows and place bins across the columns, so
    # broadcasting yields every (time, bin) combination.
    spikes_by_time = np.expand_dims(spikes, axis=1)
    # The tiny offset keeps the rate strictly positive.
    rate_by_bin = np.expand_dims(conditional_intensity, axis=0) + np.spacing(1)

    return scipy.stats.poisson.logpmf(spikes_by_time, rate_by_bin)
def combined_likelihood(
    spikes: np.ndarray, conditional_intensity: np.ndarray
) -> np.ndarray:
    """Sum the Poisson log likelihoods of all the cells.

    Parameters
    ----------
    spikes : np.ndarray, shape (n_time, n_neurons)
    conditional_intensity : np.ndarray, shape (n_bins, n_neurons)

    Returns
    -------
    log_likelihood : np.ndarray, shape (n_time, n_bins)

    """
    total = np.zeros((spikes.shape[0], conditional_intensity.shape[0]))

    # Transposing puts neurons on the first axis so each pair is one cell's
    # spike train together with that cell's place field.
    per_neuron = zip(spikes.T, conditional_intensity.T)
    for neuron_spikes, neuron_rate in per_neuron:
        np.add(total, poisson_log_likelihood(neuron_spikes, neuron_rate), out=total)

    return total
def estimate_spiking_likelihood(
    spikes: np.ndarray,
    conditional_intensity: np.ndarray,
    is_track_interior: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Log likelihood of the spikes for every position bin.

    The per-neuron log likelihoods are summed, and bins that are not part of
    the track interior are set to NaN.

    Parameters
    ----------
    spikes : np.ndarray, shape (n_time, n_neurons)
    conditional_intensity : np.ndarray, shape (n_bins, n_neurons)
    is_track_interior : None or np.ndarray, optional, shape (n_x_position_bins,
                                                             n_y_position_bins)
        Boolean mask of bins on the track. It is flattened in column-major
        ("F") order. If None, every bin is treated as on the track.

    Returns
    -------
    likelihood : np.ndarray, shape (n_time, n_bins)

    """
    if is_track_interior is None:
        n_bins = conditional_intensity.shape[0]
        is_track_interior = np.ones((n_bins,), dtype=bool)
    else:
        is_track_interior = is_track_interior.ravel(order="F")

    summed_log_likelihood = combined_likelihood(spikes, conditional_intensity)

    # Multiplying by 1.0 leaves interior bins unchanged; multiplying by NaN
    # blanks out the bins that are off the track.
    interior_or_nan = np.ones_like(is_track_interior, dtype=float)
    interior_or_nan[~is_track_interior] = np.nan

    return summed_log_likelihood * interior_or_nan
def estimate_place_fields(
    position: np.ndarray,
    spikes: np.ndarray,
    place_bin_centers: np.ndarray,
    place_bin_edges: np.ndarray,
    edges: Optional[np.ndarray] = None,
    is_track_boundary: Optional[np.ndarray] = None,
    is_track_interior: Optional[np.ndarray] = None,
    penalty: float = 1e-1,
    knot_spacing: int = 10,
) -> xr.DataArray:
    """Gives the conditional intensity of the neurons' spiking with respect to
    position.

    One penalized Poisson GLM per neuron is fit on a spline basis of
    position, and each fit is then evaluated at the place bin centers. The
    fits run through dask: the current distributed client is used if one
    exists, otherwise a new `Client()` is started.

    Parameters
    ----------
    position : np.ndarray, shape (n_time, n_position_dims)
        Only 1 or 2 position dimensions are supported.
    spikes : np.ndarray, shape (n_time, n_neurons)
    place_bin_centers : np.ndarray, shape (n_bins, n_position_dims)
    place_bin_edges : np.ndarray, shape (n_bins + 1, n_position_dims)
    edges : None or np.ndarray
        Not used by this function.
    is_track_boundary : None or np.ndarray
        Not used by this function.
    is_track_interior : None or np.ndarray
        Not used by this function.
    penalty : None or float, optional
        L2 penalty on regression. If None, penalty is smallest possible.
    knot_spacing : int, optional
        Spacing of position knots. Controls how smooth the firing rate is.

    Returns
    -------
    conditional_intensity : xr.DataArray, shape (n_bins, n_neurons)
        Dimensions are ("position", "neuron"). For 2-D positions the
        "position" coordinate is a MultiIndex with levels "x_position" and
        "y_position".

    Raises
    ------
    ValueError
        If `position` is not 2-D with 1 or 2 columns.

    """
    # Validate before the expensive fit: the output coordinates are only
    # defined for 1-D and 2-D positions, and failing after the fit would
    # throw the work away.
    if position.ndim != 2 or position.shape[1] not in (1, 2):
        raise ValueError(
            "position must have shape (n_time, 1) or (n_time, 2); "
            f"got shape {position.shape}"
        )
    n_position_dims = position.shape[1]

    if np.any(np.ptp(place_bin_edges, axis=0) <= knot_spacing):
        logging.warning("Range of position is smaller than knot spacing.")

    design_matrix = make_spline_design_matrix(position, place_bin_edges, knot_spacing)
    # Keep the design info: the scattered design matrix below is a plain
    # array future and no longer carries it.
    design_info = design_matrix.design_info

    try:
        client = get_client()
    except ValueError:
        client = Client()

    # Send the design matrix to the workers once instead of once per neuron.
    design_matrix = client.scatter(np.asarray(design_matrix), broadcast=True)
    results = [fit_glm(is_spike, design_matrix, penalty) for is_spike in spikes.T]
    results = dask.compute(*results)

    predict_matrix = make_spline_predict_matrix(design_info, place_bin_centers)
    place_fields = np.stack(
        [get_firing_rate(predict_matrix, result) for result in results], axis=1
    )

    DIMS = ["position", "neuron"]
    if n_position_dims == 1:
        coords = {"position": place_bin_centers.squeeze()}
    else:
        coords = {
            "position": pd.MultiIndex.from_arrays(
                place_bin_centers.T.tolist(), names=["x_position", "y_position"]
            )
        }

    return xr.DataArray(data=place_fields, coords=coords, dims=DIMS)