# Source code for replay_trajectory_classification.likelihoods.multiunit_likelihood_gpu

"""Estimates a marked point process likelihood where the marks are
 features of the spike waveform using GPUs. Features are float32."""
from __future__ import annotations
from typing import Optional, Union

import numpy as np
from tqdm.autonotebook import tqdm

from replay_trajectory_classification.core import atleast_2d
from replay_trajectory_classification.likelihoods.diffusion import (
    diffuse_each_bin,
    estimate_diffusion_position_density,
    estimate_diffusion_position_distance,
)

try:
    import cupy as cp

    @cp.fuse
    def gaussian_pdf(x: cp.ndarray, mean: cp.ndarray, sigma: cp.ndarray) -> cp.ndarray:
        """Evaluate the univariate normal density with the given `mean` and
        standard deviation `sigma` at `x`.

        The arguments are combined elementwise, so the usual broadcasting
        between `x`, `mean` and `sigma` applies.
        """
        standardized = (x - mean) / sigma
        normalizer = sigma * cp.sqrt(2.0 * cp.pi)
        return cp.exp(-0.5 * standardized**2) / normalizer

    def estimate_position_distance(
        place_bin_centers: cp.ndarray,
        positions: cp.ndarray,
        position_std: Union[float, cp.ndarray],
    ) -> cp.ndarray:
        """Evaluates a Gaussian kernel between each position and each position bin.

        Despite the name, the returned values are kernel weights rather than
        raw distances: for every (position, bin) pair the result is the product
        over position dimensions of a Gaussian density centered on the position
        and evaluated at the bin center.

        Parameters
        ----------
        place_bin_centers : cp.ndarray, shape (n_position_bins, n_position_dims)
        positions : cp.ndarray, shape (n_time, n_position_dims)
        position_std : scalar or array_like, shape (n_position_dims,)
            Kernel standard deviation. A scalar (Python or NumPy) is applied to
            every position dimension.

        Returns
        -------
        position_distance : cp.ndarray, shape (n_time, n_position_bins)
            float32 kernel weights.

        """
        n_time, n_position_dims = positions.shape
        n_position_bins = place_bin_centers.shape[0]

        # np.number covers NumPy scalars (e.g. np.float32), which are not
        # instances of the builtin float and cannot be iterated over below.
        if isinstance(position_std, (int, float, np.number)):
            position_std = [position_std] * n_position_dims

        position_distance = cp.ones((n_time, n_position_bins), dtype=cp.float32)

        for position_ind, std in enumerate(position_std):
            # (1, n_position_bins) against (n_time, 1) broadcasts to every
            # position/bin pair for this dimension.
            position_distance *= gaussian_pdf(
                cp.expand_dims(place_bin_centers[:, position_ind], axis=0),
                cp.expand_dims(positions[:, position_ind], axis=1),
                std,
            )

        return position_distance

    def estimate_position_density(
        place_bin_centers: cp.ndarray,
        positions: cp.ndarray,
        position_std: Union[float, cp.ndarray],
        block_size: int = 100,
    ) -> cp.ndarray:
        """Estimates a kernel density estimate over position bins.

        The density at a bin is the mean over all positions of the Gaussian
        kernel weight returned by `estimate_position_distance`. Bins are
        processed in blocks to bound the size of the intermediate
        (n_time, block_size) array.

        Parameters
        ----------
        place_bin_centers : cp.ndarray, shape (n_position_bins, n_position_dims)
        positions : cp.ndarray, shape (n_time, n_position_dims)
        position_std : float or cp.ndarray, shape (n_position_dims,)
        block_size : int or None
            Number of position bins evaluated per block. If None, all bins are
            evaluated in a single block.

        Returns
        -------
        position_density : cp.ndarray, shape (n_position_bins,)

        """
        n_position_bins = place_bin_centers.shape[0]

        if block_size is None:
            # The loop below chunks over position bins, so the "everything in
            # one block" size is the number of bins. max(..., 1) keeps the
            # range() step valid when there are no bins.
            block_size = max(n_position_bins, 1)

        position_density = cp.empty((n_position_bins,))
        for start_ind in range(0, n_position_bins, block_size):
            block_inds = slice(start_ind, start_ind + block_size)
            position_density[block_inds] = cp.mean(
                estimate_position_distance(
                    place_bin_centers[block_inds], positions, position_std
                ),
                axis=0,
            )
        return position_density

    def estimate_log_intensity(
        density: cp.ndarray, occupancy: cp.ndarray, mean_rate: float
    ) -> cp.ndarray:
        """Returns log(mean_rate) + log(density) - log(occupancy), i.e. the log
        of `mean_rate * density / occupancy`."""
        log_rate = cp.log(mean_rate)
        log_density = cp.log(density)
        log_occupancy = cp.log(occupancy)
        return log_rate + log_density - log_occupancy

    def estimate_intensity(
        density: cp.ndarray, occupancy: cp.ndarray, mean_rate: float
    ) -> cp.ndarray:
        """Computes the intensity by exponentiating the log intensity.

        Parameters
        ----------
        density : cp.ndarray, shape (n_bins,)
        occupancy : cp.ndarray, shape (n_bins,)
        mean_rate : float

        Returns
        -------
        intensity : cp.ndarray, shape (n_bins,)
            exp of the value returned by `estimate_log_intensity`.

        """
        log_intensity = estimate_log_intensity(density, occupancy, mean_rate)
        return cp.exp(log_intensity)

    def estimate_log_joint_mark_intensity(
        decoding_marks: cp.ndarray,
        encoding_marks: cp.ndarray,
        mark_std: Union[float, cp.ndarray],
        occupancy: cp.ndarray,
        mean_rate: float,
        place_bin_centers: Optional[cp.ndarray] = None,
        encoding_positions: Optional[cp.ndarray] = None,
        position_std: Union[float, cp.ndarray, None] = None,
        max_mark_diff: int = 6000,
        set_diag_zero: bool = False,
        position_distance: Optional[cp.ndarray] = None,
    ) -> np.ndarray:
        """Finds the joint intensity of the marks and positions in log space.

        A Gaussian kernel is evaluated between every decoding and encoding mark
        (product over mark features), multiplied against the position kernel
        weights of the encoding spikes, averaged over encoding spikes, and
        converted to a log intensity with `estimate_log_intensity`.

        Parameters
        ----------
        decoding_marks : cp.ndarray, shape (n_decoding_spikes, n_features)
        encoding_marks : cp.ndarray, shape (n_encoding_spikes, n_features)
        mark_std : scalar or array_like, shape (n_features,)
            Kernel standard deviation per mark feature. A scalar is applied to
            every feature.
        occupancy : cp.ndarray, shape (n_position_bins,)
        mean_rate : float
        place_bin_centers : None or cp.ndarray, shape (n_position_bins, n_position_dims)
            If None, position distance must be not None
        encoding_positions : None or cp.ndarray, shape (n_encoding_spikes, n_position_dims)
            If None, position distance must be not None
        position_std : None or float or array_like, shape (n_position_dims,)
            If None, position distance must be not None
        max_mark_diff : int
            Accepted for interface compatibility; not used by this function.
        set_diag_zero : bool
            If True, zero the leading diagonal of the
            (n_decoding_spikes, n_encoding_spikes) mark kernel matrix.
            NOTE(review): rows are counted from the start of the
            `decoding_marks` passed in, so this only pairs a spike with itself
            when row i and column i refer to the same spike -- confirm for
            callers that pass a block of decoding spikes.
        position_distance : None or cp.ndarray, shape (n_encoding_spikes, n_position_bins)
            Precalculated position kernel weights between encoding positions
            and position bins.

        Returns
        -------
        log_joint_mark_intensity : np.ndarray, shape (n_decoding_spikes, n_position_bins)
            Copied back to host memory.

        """
        n_encoding_spikes, n_marks = encoding_marks.shape
        n_decoding_spikes = decoding_marks.shape[0]

        # The signature allows a scalar std but the loop below indexes it.
        if isinstance(mark_std, (int, float, np.number)):
            mark_std = [mark_std] * n_marks

        mark_distance = cp.ones(
            (n_decoding_spikes, n_encoding_spikes), dtype=cp.float32
        )

        for mark_ind in range(n_marks):
            mark_distance *= gaussian_pdf(
                cp.expand_dims(decoding_marks[:, mark_ind], axis=1),
                cp.expand_dims(encoding_marks[:, mark_ind], axis=0),
                mark_std[mark_ind],
            )

        if set_diag_zero:
            # Limit to the entries that exist in both dimensions: CuPy does not
            # raise on out-of-bounds integer-array indices, so an oversized
            # index would silently zero unrelated entries.
            diag_ind = cp.arange(min(n_decoding_spikes, n_encoding_spikes))
            mark_distance[diag_ind, diag_ind] = 0.0

        if position_distance is None:
            position_distance = estimate_position_distance(
                place_bin_centers, encoding_positions, position_std
            ).astype(cp.float32)

        # (n_decoding, n_encoding) @ (n_encoding, n_bins) sums the joint kernel
        # over encoding spikes; dividing by their count gives the mean.
        joint_density = mark_distance @ position_distance / n_encoding_spikes

        return cp.asnumpy(estimate_log_intensity(joint_density, occupancy, mean_rate))

    def fit_multiunit_likelihood_gpu(
        position: np.ndarray,
        multiunits: np.ndarray,
        place_bin_centers: np.ndarray,
        mark_std: Union[float, np.ndarray],
        position_std: Union[float, np.ndarray],
        is_track_boundary: Optional[np.ndarray] = None,
        is_track_interior: Optional[np.ndarray] = None,
        edges: Optional[list[np.ndarray]] = None,
        block_size: int = 100,
        use_diffusion: bool = False,
        **kwargs,
    ) -> dict:
        """Fits the clusterless place field model.

        Parameters
        ----------
        position : np.ndarray, shape (n_time, n_position_dims)
        multiunits : np.ndarray, shape (n_time, n_marks, n_electrodes)
            NaN marks a missing feature; a time bin counts as a spike for an
            electrode when at least one of its features is not NaN.
        place_bin_centers : np.ndarray, shape (n_bins, n_position_dims)
        mark_std : float or np.ndarray, shape (n_marks,)
            Amount of smoothing for the mark features.  Standard deviation of kernel.
            A scalar is applied to every mark feature.
        position_std : float or np.ndarray, shape (n_position_dims,)
            Amount of smoothing for position.  Standard deviation of kernel.
        is_track_boundary : None or np.ndarray, shape (n_bins,)
        is_track_interior : None or np.ndarray, shape (n_bins,)
            If None, every bin is treated as interior.
        edges : None or list of np.ndarray
            Bin edges per position dimension; read when diffusion is used.
        block_size : int
            Size of data to process in chunks
        use_diffusion : bool
            Use diffusion to respect the track geometry. Only takes effect when
            there is more than one position dimension.

        Returns
        -------
        encoding_model : dict
            Keys: "encoding_marks" and "encoding_positions" (lists with one
            entry per electrode), "summed_ground_process_intensity"
            (np.ndarray), "occupancy" (cp.ndarray), "mean_rates" (list),
            "mark_std" (np.ndarray), "position_std", "block_size",
            "bin_diffusion_distances", "use_diffusion", "edges", plus any
            extra `kwargs` passed through unchanged.

        """
        if is_track_interior is None:
            is_track_interior = np.ones((place_bin_centers.shape[0],), dtype=bool)

        position = atleast_2d(position)
        place_bin_centers = atleast_2d(place_bin_centers)
        n_bins = place_bin_centers.shape[0]

        flat_is_track_interior = is_track_interior.ravel(order="F")
        interior_place_bin_centers = cp.asarray(
            place_bin_centers[flat_is_track_interior], dtype=cp.float32
        )
        gpu_is_track_interior = cp.asarray(flat_is_track_interior)

        not_nan_position = np.all(~np.isnan(position), axis=1)

        # Diffusion is only applied for more than one position dimension.
        diffuse = bool(use_diffusion) and position.shape[1] > 1

        if diffuse:
            n_total_bins = np.prod(is_track_interior.shape)
            bin_diffusion_distances = diffuse_each_bin(
                is_track_interior,
                is_track_boundary,
                dx=edges[0][1] - edges[0][0],
                dy=edges[1][1] - edges[1][0],
                std=position_std,
            ).reshape((n_total_bins, -1), order="F")
        else:
            bin_diffusion_distances = None

        def _estimate_density(positions: np.ndarray) -> cp.ndarray:
            """Position density over all bins (float32) for the given positions."""
            if diffuse:
                return cp.asarray(
                    estimate_diffusion_position_density(
                        positions,
                        edges,
                        bin_distances=bin_diffusion_distances,
                    ),
                    dtype=cp.float32,
                )
            # Only interior bins are evaluated; the remaining bins stay at zero.
            density = cp.zeros((n_bins,), dtype=cp.float32)
            density[gpu_is_track_interior] = estimate_position_density(
                interior_place_bin_centers,
                cp.asarray(positions, dtype=cp.float32),
                position_std,
                block_size=block_size,
            )
            return density

        occupancy = _estimate_density(position[not_nan_position])

        mean_rates = []
        summed_ground_process_intensity = cp.zeros((n_bins,), dtype=cp.float32)
        encoding_marks = []
        encoding_positions = []

        n_marks = multiunits.shape[1]
        mark_std = np.asarray(mark_std)
        if mark_std.ndim == 0:
            # Scalar (Python or NumPy): use the same std for every feature.
            mark_std = np.repeat(mark_std, n_marks)

        for multiunit in np.moveaxis(multiunits, -1, 0):
            # ground process intensity
            is_spike = np.any(~np.isnan(multiunit), axis=1)
            mean_rates.append(is_spike.mean())
            is_encoding_spike = is_spike & not_nan_position

            if is_encoding_spike.any():
                marginal_density = _estimate_density(position[is_encoding_spike])
            else:
                # No spike with a valid position on this electrode: use a zero
                # density instead of an undefined/stale value or the NaN that
                # averaging over zero positions would produce.
                marginal_density = cp.zeros((n_bins,), dtype=cp.float32)

            summed_ground_process_intensity += estimate_intensity(
                marginal_density, occupancy, mean_rates[-1]
            )

            # Drop feature columns that are NaN for every time bin.
            is_mark_features = np.any(~np.isnan(multiunit), axis=0)
            encoding_marks.append(
                cp.asarray(
                    multiunit[np.ix_(is_encoding_spike, is_mark_features)],
                    dtype=cp.float32,
                )
            )
            encoding_positions.append(position[is_encoding_spike])

        # Move to host memory and add a tiny positive offset.
        summed_ground_process_intensity = cp.asnumpy(
            summed_ground_process_intensity
        ) + np.spacing(1)

        return {
            "encoding_marks": encoding_marks,
            "encoding_positions": encoding_positions,
            "summed_ground_process_intensity": summed_ground_process_intensity,
            "occupancy": occupancy,
            "mean_rates": mean_rates,
            "mark_std": mark_std,
            "position_std": position_std,
            "block_size": block_size,
            "bin_diffusion_distances": bin_diffusion_distances,
            "use_diffusion": use_diffusion,
            "edges": edges,
            **kwargs,
        }

    def estimate_multiunit_likelihood_gpu(
        multiunits: np.ndarray,
        encoding_marks: cp.ndarray,
        mark_std: np.ndarray,
        place_bin_centers: np.ndarray,
        encoding_positions: cp.ndarray,
        position_std: np.ndarray,
        occupancy: cp.ndarray,
        mean_rates: cp.ndarray,
        summed_ground_process_intensity: np.ndarray,
        bin_diffusion_distances: np.ndarray,
        edges: list[np.ndarray],
        max_mark_diff: int = 6000,
        set_diag_zero: bool = False,
        is_track_interior: Optional[np.ndarray] = None,
        time_bin_size: int = 1,
        block_size: int = 100,
        ignore_no_spike: bool = False,
        disable_progress_bar: bool = False,
        use_diffusion: bool = False,
    ) -> np.ndarray:
        """Estimates the likelihood of position bins given multiunit marks.

        Parameters
        ----------
        multiunits : np.ndarray, shape (n_decoding_time, n_marks, n_electrodes)
            NaN marks a missing feature; a time bin counts as a spike for an
            electrode when at least one of its features is not NaN.
        encoding_marks : sequence of cp.ndarray, one per electrode
            Each entry has shape (n_encoding_spikes, n_marks).
        mark_std : np.ndarray, shape (n_marks,)
            Amount of smoothing for mark features
        place_bin_centers : np.ndarray, shape (n_bins, n_position_dims)
        encoding_positions : sequence of arrays, one per electrode
            Each entry has shape (n_encoding_spikes, n_position_dims).
        position_std : float or array_like, shape (n_position_dims,)
            Amount of smoothing for position
        occupancy : cp.ndarray, (n_bins,)
        mean_rates : sequence of float, len (n_electrodes,)
        summed_ground_process_intensity : np.ndarray, shape (n_bins,)
        bin_diffusion_distances : np.ndarray, shape (n_bins, n_bins)
        edges : list of np.ndarray
        max_mark_diff : int
            Maximum difference between mark features
        set_diag_zero : bool
            Remove influence of the same mark in encoding and decoding.
        is_track_interior : None or np.ndarray, shape (n_bins_x, n_bins_y)
            If None, every bin is treated as interior.
        time_bin_size : float
            Size of time steps
        block_size : int or None
            Number of decoding spikes processed per chunk. If None, all of an
            electrode's decoding spikes are processed in one chunk.
        ignore_no_spike : bool
            Set contribution of no spikes to zero
        disable_progress_bar : bool
            If False, a progress bar will be displayed.
        use_diffusion : bool
            Respect track geometry by using diffusion distances. Only takes
            effect when there is more than one position dimension.

        Returns
        -------
        log_likelihood : np.ndarray, shape (n_time, n_bins)
            Bins outside the track interior are set to NaN.

        """

        if is_track_interior is None:
            is_track_interior = np.ones((place_bin_centers.shape[0],), dtype=bool)
        else:
            is_track_interior = is_track_interior.ravel(order="F")

        n_time = multiunits.shape[0]
        # The no-spike term is multiplied by zeros (rather than dropped) when
        # ignore_no_spike is set; the (n_time, 1) column broadcasts it across
        # time bins.
        no_spike_scale = (np.zeros if ignore_no_spike else np.ones)(
            (n_time, 1), dtype=np.float32
        )
        log_likelihood = (
            -time_bin_size * summed_ground_process_intensity * no_spike_scale
        )

        multiunits = np.moveaxis(multiunits, -1, 0)
        n_position_bins = is_track_interior.sum()
        interior_place_bin_centers = cp.asarray(
            place_bin_centers[is_track_interior], dtype=cp.float32
        )
        gpu_is_track_interior = cp.asarray(is_track_interior)
        interior_occupancy = occupancy[gpu_is_track_interior]

        # Loop-invariant settings, computed once.
        diffuse = bool(use_diffusion) and place_bin_centers.shape[1] > 1
        mempool = cp.get_default_memory_pool()

        for multiunit, enc_marks, enc_pos, mean_rate in zip(
            tqdm(multiunits, desc="n_electrodes", disable=disable_progress_bar),
            encoding_marks,
            encoding_positions,
            mean_rates,
        ):
            is_spike = np.any(~np.isnan(multiunit), axis=1)
            n_decoding_marks = int(is_spike.sum())
            if n_decoding_marks == 0:
                # Nothing to add for this electrode; skip the GPU work.
                continue

            is_mark_features = np.any(~np.isnan(multiunit), axis=0)
            decoding_marks = cp.asarray(
                multiunit[np.ix_(is_spike, is_mark_features)], dtype=cp.float32
            )
            log_joint_mark_intensity = np.zeros(
                (n_decoding_marks, n_position_bins), dtype=np.float32
            )

            # Resolve the chunk size per electrode without overwriting
            # `block_size`, so a None setting does not pick up the spike count
            # of an earlier electrode.
            chunk_size = n_decoding_marks if block_size is None else block_size

            if diffuse:
                position_distance = cp.asarray(
                    estimate_diffusion_position_distance(
                        enc_pos,
                        edges,
                        bin_distances=bin_diffusion_distances,
                    )[:, is_track_interior],
                    dtype=cp.float32,
                )
            else:
                position_distance = estimate_position_distance(
                    interior_place_bin_centers,
                    cp.asarray(enc_pos, dtype=cp.float32),
                    position_std,
                ).astype(cp.float32)

            for start_ind in range(0, n_decoding_marks, chunk_size):
                block_inds = slice(start_ind, start_ind + chunk_size)
                log_joint_mark_intensity[
                    block_inds
                ] = estimate_log_joint_mark_intensity(
                    decoding_marks[block_inds],
                    enc_marks,
                    mark_std[is_mark_features],
                    interior_occupancy,
                    mean_rate,
                    max_mark_diff=max_mark_diff,
                    set_diag_zero=set_diag_zero,
                    position_distance=position_distance,
                )
            # nan_to_num keeps NaN/inf log intensities from propagating into
            # the accumulated likelihood.
            log_likelihood[np.ix_(is_spike, is_track_interior)] += np.nan_to_num(
                log_joint_mark_intensity
            )

            # Drop this electrode's GPU arrays before asking the pool to
            # release its unused blocks.
            del decoding_marks, position_distance
            mempool.free_all_blocks()

        log_likelihood[:, ~is_track_interior] = np.nan

        return log_likelihood

except ImportError:

[docs] def estimate_multiunit_likelihood_gpu(*args, **kwargs): print("Cupy is not installed or no GPU detected...")
[docs] def fit_multiunit_likelihood_gpu(*args, **kwargs): print("Cupy is not installed or no GPU detected...")