Source code for rnntoolkit.flow_fields.flow_field_finder_base

import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from typing import Generic, TypeVar, Tuple

# Type variable for the wrapped network: any ``nn.Module`` subclass.
RNN = TypeVar("RNN", bound=nn.Module)


class FlowFieldFinderBase(Generic[RNN]):
    """
    Base class for gathering 2D flow fields about a trajectory.

    A 2-component PCA is fit on the supplied states. Helper methods build a
    regular grid in that 2D space, project it back to the full state space,
    and compute/reshape velocities and speeds. Subclasses must implement
    ``find_nonlinear_flow`` and ``find_linear_flow``.

    This class is meant to be inherited only by very specific users who are
    designing custom RNNs and would like a base structure to go off of.
    Typical users training RNNs or mRNNs should only use the FlowFieldFinder
    or mFlowFieldFinder classes provided by the packages.
    """

    def __init__(
        self,
        rnn: RNN,
        fit_states: torch.Tensor,
        num_points: int,
        x_offset: int,
        y_offset: int,
        x_center: int = 0,
        y_center: int = 0,
        **kwargs,
    ):
        """
        Store the configuration and fit the PCA on ``fit_states``.

        Args:
            rnn (RNN): RNN-like object; must expose ``batch_first`` and
                ``parameters()``
            fit_states (Tensor): states used to fit the 2D PCA; flattened to
                (-1, fit_states.shape[-1]) before fitting
            num_points (int): number of points to use in grid, results in
                (num_points, num_points)
            x_offset (int): scale to offset grid about trajectory in x
                direction
            y_offset (int): scale to offset grid about trajectory in y
                direction
            x_center (int): x position to offset from using x_offset
            y_center (int): y position to offset from using y_offset
            **kwargs: accepted for subclass compatibility; unused here
        """
        self.rnn = rnn
        self.fit_states = fit_states
        self.num_points = num_points
        self.x_offset = x_offset
        self.y_offset = y_offset
        self.x_center = x_center
        self.y_center = y_center
        self.time_dim = 1 if self.rnn.batch_first else 0

        # Mirror the dtype of the network parameters; a parameter-free module
        # falls back to torch's default dtype instead of raising StopIteration.
        first_param = next(self.rnn.parameters(), None)
        self.dtype = (
            first_param.dtype if first_param is not None else torch.get_default_dtype()
        )

        # class objects
        self.reduce_obj = PCA(n_components=2)
        self._fit_traj(self.fit_states)

    def find_nonlinear_flow(self, *args, **kwargs) -> list:
        """Compute 2D flow fields in a region subspace along a trajectory."""
        raise NotImplementedError

    def find_linear_flow(self, *args, **kwargs) -> list:
        """Compute linearized flow fields in a 2D subspace."""
        raise NotImplementedError

    @staticmethod
    def _to_numpy(x: torch.Tensor):
        """
        Convert a tensor to a numpy array for use with the PCA object.

        The tensor is detached and moved to CPU first, so tensors that
        require grad or live on another device convert without error.

        Args:
            x (Tensor): tensor to convert

        Returns:
            numpy.ndarray: array with the same values as ``x``
        """
        return x.detach().cpu().numpy()

    def _nxd(self, x: torch.Tensor) -> torch.Tensor:
        """
        Broadcast to nxd, even for a 1d tensor

        Args:
            x (Tensor): tensor to broadcast

        Returns:
            Tensor: 2D tensor; all leading dimensions are flattened into one
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = torch.flatten(x, end_dim=-2)
        return x

    def _fit_traj(self, trajectory: torch.Tensor):
        """
        Fit PCA object

        Args:
            trajectory (Tensor): states to reduce
        """
        # Collapse every leading dimension so each row is one state vector
        temp_act = torch.reshape(trajectory, (-1, trajectory.shape[-1]))
        # Explicit numpy conversion: handles grad-tracking / non-CPU tensors
        self.reduce_obj.fit(self._to_numpy(temp_act))

    def _reduce_traj(self, trajectory: torch.Tensor) -> torch.Tensor:
        """
        Transform trajectory with the already-fitted PCA object

        Args:
            trajectory (Tensor): states to reduce

        Returns:
            Tensor: reduced states
        """
        # Collapse every leading dimension so each row is one state vector
        temp_act = torch.reshape(trajectory, (-1, trajectory.shape[-1]))
        reduced_traj = self.reduce_obj.transform(self._to_numpy(temp_act))
        reduced_traj = torch.from_numpy(reduced_traj)
        return reduced_traj

    def _inverse_grid(
        self,
        lower_bound_x: float,
        upper_bound_x: float,
        lower_bound_y: float,
        upper_bound_y: float,
        expand_dims: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Obtain a low dimensional grid and its projection to higher dim
        state space

        Args:
            lower_bound_x (float): lower bound of grid in x direction
            upper_bound_x (float): upper bound of grid in x direction
            lower_bound_y (float): lower bound of grid in y direction
            upper_bound_y (float): upper bound of grid in y direction
            expand_dims (bool): if True, return both grids reshaped to
                (num_points, num_points, -1) instead of flattened rows

        Returns:
            tuple: the low dimensional and projected low dimensional grid
        """
        # Num points is along each axis, not in total
        x = torch.linspace(lower_bound_x, upper_bound_x, self.num_points)
        y = torch.linspace(lower_bound_y, upper_bound_y, self.num_points)

        # Gather 2D grid for flow fields
        xv, yv = torch.meshgrid(x, y, indexing="ij")
        xv, yv = xv.unsqueeze(-1), yv.unsqueeze(-1)

        # Convert the grid to a tensor and flatten for PCA
        low_dim_grid = torch.cat((xv, yv), dim=-1)
        low_dim_grid = torch.flatten(low_dim_grid, start_dim=0, end_dim=1)

        # Inverse PCA to input grid into network
        inverse_grid = self.reduce_obj.inverse_transform(self._to_numpy(low_dim_grid))
        inverse_grid = torch.from_numpy(inverse_grid).to(self.dtype)

        if expand_dims:
            low_dim_grid = torch.reshape(
                low_dim_grid, (self.num_points, self.num_points, 2)
            )
            inverse_grid = torch.reshape(
                inverse_grid,
                (self.num_points, self.num_points, inverse_grid.shape[-1]),
            )

        return low_dim_grid, inverse_grid

    def _compute_velocity(
        self, h_next: torch.Tensor, h_prev: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """compute velocity, or h_next - h_prev, for the first two components"""
        x_vel = h_next[..., 0] - h_prev[..., 0]
        y_vel = h_next[..., 1] - h_prev[..., 1]
        return x_vel, y_vel

    def _compute_speed(self, x_vel: torch.Tensor, y_vel: torch.Tensor) -> torch.Tensor:
        """
        compute magnitude of velocities, normalized by the maximum

        Args:
            x_vel (Tensor): x velocities
            y_vel (Tensor): y velocities

        Returns:
            Tensor: speeds divided by their maximum. An empty input or an
                all-zero field is returned unnormalized, which avoids the
                error from ``max()`` on an empty tensor and the NaNs from 0/0.
        """
        speed = torch.sqrt(x_vel**2 + y_vel**2)
        if speed.numel() == 0:
            return speed
        peak = speed.max()
        if peak == 0:
            return speed
        return speed / peak

    def _reshape_vals(
        self,
        x_vels: torch.Tensor,
        y_vels: torch.Tensor,
        grid: torch.Tensor,
        speeds: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Broadcast data to FlowField format

        Args:
            x_vels (Tensor): x velocities
            y_vels (Tensor): y velocities
            grid (Tensor): low dimensional grid coordinates
            speeds (Tensor): magnitude of x_vels and y_vels

        Returns:
            tuple: x_vels, y_vels and speeds reshaped to
                (num_points, num_points), grid to (num_points, num_points, 2)
        """
        # Reshape to match FlowField object requirements
        x_vels = torch.reshape(x_vels, (self.num_points, self.num_points))
        y_vels = torch.reshape(y_vels, (self.num_points, self.num_points))
        grid = torch.reshape(grid, (self.num_points, self.num_points, 2))
        speeds = torch.reshape(speeds, (self.num_points, self.num_points))
        return x_vels, y_vels, grid, speeds

    def _set_bounds(self) -> tuple[float, float, float, float]:
        """
        Design grid bounds using a static center and offsetting

        Returns:
            tuple: lower x, upper x, lower y, upper y bounds
        """
        lower_bound_x = self.x_center - self.x_offset
        upper_bound_x = self.x_center + self.x_offset
        lower_bound_y = self.y_center - self.y_offset
        upper_bound_y = self.y_center + self.y_offset
        return lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y

    def _set_tv_bounds(self, traj: torch.Tensor) -> tuple[float, float, float, float]:
        """
        Design grid bounds centered at the state traj using the offsets

        Args:
            traj (Tensor): state to center grid at; traj[0] is used as the x
                coordinate and traj[1] as the y coordinate

        Returns:
            tuple: lower x, upper x, lower y, upper y bounds, each rounded to
                one decimal place
        """
        lower_bound_x = torch.round(traj[0] - self.x_offset, decimals=1).item()
        upper_bound_x = torch.round(traj[0] + self.x_offset, decimals=1).item()
        lower_bound_y = torch.round(traj[1] - self.y_offset, decimals=1).item()
        upper_bound_y = torch.round(traj[1] + self.y_offset, decimals=1).item()
        return lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y