# Source code for rnntoolkit.flow_fields.flow_field_finder

import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from rnntoolkit.linear import Linearization
from rnntoolkit.flow_fields.flow_field import FlowField
from rnntoolkit.flow_fields.flow_field_finder_base import FlowFieldFinderBase


class FlowFieldFinder(FlowFieldFinderBase):
    """Flow field finder that gathers flow fields about a specified trajectory.

    Builds on ``FlowFieldFinderBase``, which is expected to supply the basic
    structure and the utilities called below (``_nxd``, ``_reduce_traj``,
    ``_set_bounds``, ``_set_tv_bounds``, ``_inverse_grid``,
    ``_compute_velocity``, ``_compute_speed``, ``_reshape_vals``) as well as
    the ``rnn`` and ``time_dim`` attributes.
    """

    def __init__(
        self,
        rnn: nn.RNN,
        fit_states: torch.Tensor,
        num_points: int,
        x_offset: int,
        y_offset: int,
        x_center: int = 0,
        y_center: int = 0,
        follow_traj: bool = False,
    ) -> None:
        """Set up the finder and its linearization object.

        Args:
            rnn (nn.RNN): RNN object, more architectures coming soon
            fit_states (torch.Tensor): states forwarded unchanged to
                ``FlowFieldFinderBase`` (presumably the data the
                low-dimensional projection is fit on; confirm in the base
                class)
            num_points (int): number of points to use in grid, results in
                (num_points, num_points)
            x_offset (int): scale to offset grid about trajectory in x
                direction
            y_offset (int): scale to offset grid about trajectory in y
                direction
            x_center (int): x position to offset from using x_offset
            y_center (int): y position to offset from using y_offset
            follow_traj (bool): whether or not to center the grid around each
                state or use default grid locations
        """
        super().__init__(
            rnn, fit_states, num_points, x_offset, y_offset, x_center, y_center
        )
        self.follow_traj = follow_traj
        # Need to define a valid linearization object in each child of
        # FlowFieldFinderBase. ``self.rnn`` is expected to have been set by
        # the base-class constructor above.
        self.linearization = Linearization(self.rnn)

    def _grid_for_state(self, reduced_state):
        """Build the evaluation grid to use for one state.

        Args:
            reduced_state: the state's coordinates in the reduced space, as
                produced by ``self._reduce_traj``; only consulted when
                ``self.follow_traj`` is true.

        Returns:
            The ``(low_dim_grid, inverse_grid)`` pair returned by
            ``self._inverse_grid``.
        """
        # With follow_traj the grid is re-centered on the current state, so a
        # different grid is produced for every state; otherwise the default
        # bounds are used.
        if self.follow_traj:
            bounds = self._set_tv_bounds(reduced_state)
        else:
            bounds = self._set_bounds()
        lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = bounds
        return self._inverse_grid(
            lower_bound_x,
            upper_bound_x,
            lower_bound_y,
            upper_bound_y,
        )

    def _flow_field_from_step(self, h, low_dim_grid):
        """Turn advanced grid states into a ``FlowField``.

        Args:
            h: states obtained by advancing every grid point one step.
            low_dim_grid: the reduced-space grid those states started from.

        Returns:
            FlowField: velocities, grid and speed for this grid.
        """
        # Bring the advanced states back into the reduced space so they can
        # be compared against the grid they started from.
        h_next = self._reduce_traj(h)
        x_vel, y_vel = self._compute_velocity(h_next, low_dim_grid)
        speed = self._compute_speed(x_vel, y_vel)
        # Reshape to match FlowField object requirements
        x_vel, y_vel, low_dim_grid, speed = self._reshape_vals(
            x_vel, y_vel, low_dim_grid, speed
        )
        return FlowField(x_vel, y_vel, low_dim_grid, speed)

    def find_nonlinear_flow(
        self,
        states: torch.Tensor,
        inp: torch.Tensor,
    ) -> list:
        """Compute 2D flow fields at each given state.

        Reduces the states with ``self._reduce_traj`` (a 2D PCA subspace
        according to the original documentation), constructs a grid for each
        state, and advances the RNN by one step from every grid point to
        estimate the local flow (velocity vectors).

        Args:
            states (torch.Tensor): Hidden activations over time, can be
                batched or 1D
            inp (torch.Tensor): External input sequence, can be batched or
                1D, but the total number of inputs (batch elements) must
                match that of states

        Returns:
            list: FlowField object per sampled state

        Raises:
            AssertionError: if ``states`` and ``inp`` do not have the same
                number of rows after being reshaped to n x d.
        """
        # Reshape to nxd
        states, inp = self._nxd(states), self._nxd(inp)
        assert states.shape[0] == inp.shape[0], (
            f"states and inp must contain the same number of elements, "
            f"got {states.shape[0]} and {inp.shape[0]}"
        )

        n_states = states.shape[0]
        reduced_traj = self._reduce_traj(states)

        flow_field_list = []
        for n in range(n_states):
            low_dim_grid, inverse_grid = self._grid_for_state(reduced_traj[n])

            # Repeat the current input once per grid point so every grid
            # point is driven by the same input.
            full_inp_batch = inp[n].repeat(low_dim_grid.shape[0], 1)

            with torch.no_grad():
                # One forward step from every grid point. The input gets a
                # time axis at self.time_dim; the initial state gets a
                # leading axis (assumes a single-layer, unidirectional RNN
                # so that axis has length 1 -- TODO confirm).
                _, h = self.rnn(
                    full_inp_batch.unsqueeze(self.time_dim),
                    inverse_grid.unsqueeze(0),
                )

            flow_field_list.append(self._flow_field_from_step(h, low_dim_grid))

        return flow_field_list

    def find_linear_flow(
        self,
        states: torch.Tensor,
        inp: torch.Tensor,
        delta_inp: torch.Tensor,
    ) -> list:
        """Compute linearized flow fields in a 2D subspace.

        Similar to :meth:`find_nonlinear_flow`, but uses a local linear
        approximation (Jacobian) of the dynamics around points on the
        trajectory instead of a full forward step.

        Args:
            states (torch.Tensor): Hidden activations over time for selected
                regions, can be 1D or batched
            inp (torch.Tensor): External input sequence, can be batched or
                1D, but the total number of inputs (batch elements) must
                match that of states
            delta_inp (torch.Tensor): External input sequence of input
                perturbations, can be batched or 1D, but the total number of
                inputs (batch elements) must match that of states, and the
                overall shape must match inp

        Returns:
            list: FlowField objects per sampled time.

        Raises:
            AssertionError: if ``states`` and ``delta_inp`` differ in their
                number of rows, or ``delta_inp`` and ``inp`` differ in shape,
                after being reshaped to n x d.
        """
        # Reshape to nxd
        states, inp, delta_inp = (
            self._nxd(states),
            self._nxd(inp),
            self._nxd(delta_inp),
        )
        assert states.shape[0] == delta_inp.shape[0], (
            f"states and delta_inp must contain the same number of elements, "
            f"got {states.shape[0]} and {delta_inp.shape[0]}"
        )
        assert delta_inp.shape == inp.shape, (
            f"delta_inp and inp must have the same shape, "
            f"got {tuple(delta_inp.shape)} and {tuple(inp.shape)}"
        )

        n_states = states.shape[0]
        reduced_traj = self._reduce_traj(states)

        flow_field_list = []
        for n in range(n_states):
            states_n = states[n]
            low_dim_grid, inverse_grid = self._grid_for_state(reduced_traj[n])

            # Perturbation of the activity: offset of every full-dimensional
            # grid point from the state being linearized around.
            delta_h = inverse_grid - states_n

            with torch.no_grad():
                # Call forward method for linearization to get the affine
                # transformation applied to the perturbations.
                h = self.linearization(inp[n], states_n, delta_inp[n], delta_h)

            flow_field_list.append(self._flow_field_from_step(h, low_dim_grid))

        return flow_field_list