Source code for replay_trajectory_classification.likelihoods.calcium_likelihood

"""Calculate a Gamma likelihood for calcium imaging activity traces.

References
----------
[1] Farhoodi, S., Plitt, M.H., Giocomo, L., and Eden, U.T. (2020). Estimating Fluctuations in Neural Representations of Uncertain Environments. 20. .

"""

from __future__ import annotations

import logging
from typing import Optional

import dask
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import Client, get_client
from patsy import DesignInfo, DesignMatrix, build_design_matrices, dmatrix
from regularized_glm import penalized_IRLS
from statsmodels.api import families

from replay_trajectory_classification.environments import get_n_bins


[docs] def make_spline_design_matrix( position: np.ndarray, place_bin_edges: np.ndarray, knot_spacing: int = 10 ) -> DesignMatrix: """Creates a design matrix for regression with a position spline basis. Parameters ---------- position : np.ndarray, shape (n_time, n_position_dims) place_bin_edges : np.ndarray, shape (n_bins, n_position_dims) Returns ------- design_matrix : patsy.DesignMatrix """ # TODO: Add history dependence inner_knots = [] for pos, edges in zip(position.T, place_bin_edges.T): n_points = get_n_bins(edges, bin_size=knot_spacing) knots = np.linspace(edges.min(), edges.max(), n_points)[1:-1] knots = knots[(knots > pos.min()) & (knots < pos.max())] inner_knots.append(knots) inner_knots = np.meshgrid(*inner_knots) n_position_dims = position.shape[1] data = {} formula = "1 + te(" for ind in range(n_position_dims): formula += f"cr(x{ind}, knots=inner_knots[{ind}])" formula += ", " data[f"x{ind}"] = position[:, ind] formula += 'constraints="center")' return dmatrix(formula, data)
[docs] def make_spline_predict_matrix( design_info: DesignInfo, place_bin_centers: np.ndarray ) -> DesignMatrix: """Make a design matrix for position bins""" predict_data = {} for ind in range(place_bin_centers.shape[1]): predict_data[f"x{ind}"] = place_bin_centers[:, ind] return build_design_matrices([design_info], predict_data)[0]
[docs] def get_activity_rate(design_matrix: DesignMatrix, results: tuple) -> np.ndarray: """Predicts the calcium activity trace given fitted model coefficents.""" rate = design_matrix @ results.coefficients rate[rate < 0.1] = 0.1 return rate
@dask.delayed def fit_glm( response: np.ndarray, design_matrix: np.ndarray, penalty: Optional[float] = None, tolerance: float = 1e-5, ) -> tuple: """Fits a L2-penalized GLM. Parameters ---------- response : np.ndarray, shape (n_time,) Calcium activity trace design_matrix : np.ndarray, shape (n_time, n_coefficients) penalty : None or float L2 penalty on regression. If None, penalty is smallest possible. tolerance : float Smallest difference between iterations to consider model fitting converged. Returns ------- results : tuple """ if penalty is not None: penalty = np.ones((design_matrix.shape[1],)) * penalty penalty[0] = 0.0 # don't penalize the intercept else: penalty = np.finfo(float).eps return penalized_IRLS( design_matrix, response.squeeze(), family=families.Gamma(families.links.identity()), penalty=penalty, tolerance=tolerance, )
[docs] def gamma_log_likelihood( calcium_activity: np.ndarray, place_field: np.ndarray, scale: float ) -> np.ndarray: """Probability of parameters given spiking at a particular time. Parameters ---------- calcium_activity : np.ndarray, shape (n_time,) place_field : np.ndaarray, shape (n_place_bins,) scale : float Returns ------- gamma_log_likelihood : np.ndarray, shape (n_time, n_place_bins) """ # return scipy.stats.gamma.logpdf(v * calcium_activity / mu, v) # return scipy.stats.gamma.logpdf(mu, v) # return (-scipy.special.loggamma(v) + # v * np.log(v * calcium_activity / mu) - # v * calcium_activity / mu - # np.log(calcium_activity)) gamma = families.Gamma(families.links.identity()) return gamma.loglike_obs( endog=calcium_activity[:, np.newaxis], mu=place_field[np.newaxis, :], scale=scale, )
[docs] def combined_likelihood( calcium_activity: np.ndarray, place_fields: np.ndarray, scales: np.ndarray ) -> np.ndarray: """Combines the likelihoods of all the cells. Parameters ---------- calcium_activity : np.ndarray, shape (n_time, n_neurons) Deconvolved activity rate estimated from the fluorescence level. place_fields : np.ndarray, shape (n_bins, n_neurons) scales : np.ndarray, shape (n_neurons,) """ n_time = calcium_activity.shape[0] n_bins = place_fields.shape[0] log_likelihood = np.zeros((n_time, n_bins)) for activity, place_field, scale in zip(calcium_activity.T, place_fields.T, scales): log_likelihood += gamma_log_likelihood(activity, place_field, scale) return log_likelihood
[docs] def estimate_calcium_likelihood( calcium_activity: np.ndarray, place_fields: np.ndarray, scales: np.ndarray, is_track_interior: Optional[np.ndarray] = None, ) -> np.ndarray: """Find the likelihood given a fitted place field model. Parameters ---------- calcium_activity : np.ndarray, shape (n_time, n_neurons) Deconvolved activity rate estimated from the fluorescence level. place_fields : np.ndarray, shape (n_bins, n_neurons) scales : np.ndarray, shape (n_neurons,) is_track_interior : None or np.ndarray, optional, shape (n_bins, n_position_dims) Returns ------- likelihood : np.ndarray, shape (n_time, n_bins) """ if is_track_interior is not None: is_track_interior = is_track_interior.ravel(order="F") else: n_bins = place_fields.shape[0] is_track_interior = np.ones((n_bins,), dtype=bool) log_likelihood = combined_likelihood(calcium_activity, place_fields, scales) mask = np.ones_like(is_track_interior, dtype=float) mask[~is_track_interior] = np.nan return log_likelihood * mask
[docs] def estimate_calcium_place_fields( position: np.ndarray, calcium_activity: np.ndarray, place_bin_centers: np.ndarray, place_bin_edges: np.ndarray, penalty: float = 1e-1, knot_spacing: int = 10, ) -> xr.DataArray: """Gives the conditional intensity of the neurons' spiking with respect to position. Parameters ---------- position : np.ndarray, shape (n_time, n_position_dims) calcium_activity : np.ndarray, shape (n_time, n_neurons) Deconvolved activity rate estimated from the fluorescence level. place_bin_centers : np.ndarray, shape (n_bins, n_position_dims) place_bin_edges : np.ndarray, shape (n_bins + 1, n_position_dims) penalty : float, optional knot_spacing : int, optional Returns ------- conditional_intensity :xr.DataArray, shape (n_bins, n_neurons) """ if np.any(np.ptp(place_bin_edges, axis=0) <= knot_spacing): logging.warning("Range of position is smaller than knot spacing.") design_matrix = make_spline_design_matrix(position, place_bin_edges, knot_spacing) design_info = design_matrix.design_info try: client = get_client() except ValueError: client = Client() design_matrix = client.scatter(np.asarray(design_matrix), broadcast=True) results = [ fit_glm(activity, design_matrix, penalty) for activity in calcium_activity.T ] results = dask.compute(*results) predict_matrix = make_spline_predict_matrix(design_info, place_bin_centers) place_fields = np.stack( [get_activity_rate(predict_matrix, result) for result in results], axis=1 ) scales = np.asarray([result.scale for result in results]) DIMS = ["position", "neuron"] if position.shape[1] == 1: names = ["position"] coords = {"position": place_bin_centers.squeeze()} elif position.shape[1] == 2: names = ["x_position", "y_position"] coords = { "position": pd.MultiIndex.from_arrays( place_bin_centers.T.tolist(), names=names ) } return xr.DataArray(data=place_fields, coords=coords, dims=DIMS), scales