fp
- class rnntoolkit.fixed_points.fp.FixedPointCollection(xstar: Tensor, x_init: Tensor | None = None, inputs: Tensor | None = None, F_xstar: Tensor | None = None, qstar: Tensor | None = None, dq: Tensor | None = None, n_iters: Tensor | None = None, tol_unique: float = 0.001, dtype=torch.float32, dtype_complex=torch.complex64, verbose: bool = False)
Bases: object
A class for storing fixed points and associated data.
- assert_valid_shapes()
Checks that all data attributes reflect the same number of fixed points.
- Raises:
AssertionError if any non-None data attribute has a .shape[0] that is not equal to self.n.
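The documented check amounts to comparing the leading dimension of every non-None data attribute against the number of fixed points. A minimal sketch, assuming the attributes are PyTorch tensors; `check_leading_dims` is a hypothetical helper, not part of rnntoolkit:

```python
import torch


def check_leading_dims(n: int, **attrs) -> None:
    """Hypothetical helper mirroring the documented rule: every non-None
    attribute must have exactly n entries along dimension 0."""
    for name, value in attrs.items():
        if value is None:
            continue  # optional data that was never set is allowed
        assert value.shape[0] == n, (
            f"{name} has shape[0]={value.shape[0]}, expected {n}"
        )


xstar = torch.zeros(5, 3)  # 5 fixed points in a 3-dimensional state space
qstar = torch.zeros(5)     # one q value per fixed point
check_leading_dims(xstar.shape[0], xstar=xstar, qstar=qstar, inputs=None)
```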
- static concatenate(fps_seq: list)
Joins a sequence of FixedPointCollection objects.
- Args:
fps_seq: sequence of FixedPointCollection objects, all of which must share the same n_states and n_inputs.
- Returns:
A FixedPointCollection containing the concatenated data.
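Concatenation presumably stacks each data attribute along the fixed-point dimension (dim 0). The sketch below shows that for a single attribute, under the assumption (suggested by the update() docs) that an attribute must be present in every collection or absent from all of them; `concat_attr` is a hypothetical helper:

```python
from typing import Optional, Sequence

import torch


def concat_attr(parts: Sequence[Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
    """Concatenate one data attribute across collections along dim 0.
    Assumption: the attribute is either set everywhere or nowhere."""
    if all(p is None for p in parts):
        return None
    if any(p is None for p in parts):
        raise ValueError("attribute present in some collections but not others")
    return torch.cat(list(parts), dim=0)


xstar = concat_attr([torch.zeros(2, 4), torch.ones(3, 4)])  # shape (5, 4)
qstar = concat_attr([None, None])                            # stays None
```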
- find(fp: Self, use_F_xstar: bool = False) -> Tensor
Searches the current FixedPointCollection for matches to a specified fixed point. Two fixed points match if the 2-norm of the difference between their concatenated (xstar, inputs) vectors is within tol_unique.
- Args:
fp: A FixedPointCollection containing exactly one fixed point.
use_F_xstar (bool): Whether to match on F_xstar instead of xstar.
- Returns:
A tensor of shape (n_matches,) containing the indices into the current FixedPointCollection at which matches to fp were found.
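The matching rule can be written directly in PyTorch: stack state and input, take the row-wise 2-norm of the difference, and keep the rows within tolerance. A minimal sketch on plain tensors; `find_matches` is a hypothetical helper, and treating "within tol_unique" as `<=` is an assumption:

```python
import torch


def find_matches(xstar, inputs, x_query, u_query, tol_unique=1e-3):
    """Indices i where ||[xstar_i, inputs_i] - [x_query, u_query]||_2 <= tol_unique."""
    # Concatenate state and input so that both must agree for a match.
    stacked = torch.cat([xstar, inputs], dim=1)   # (n, n_states + n_inputs)
    query = torch.cat([x_query, u_query], dim=0)  # (n_states + n_inputs,)
    dists = torch.linalg.vector_norm(stacked - query, dim=1)  # (n,)
    return torch.nonzero(dists <= tol_unique).flatten()       # (n_matches,)


xstar = torch.tensor([[0.0, 0.0], [1.0, 1.0], [0.0, 0.0005]])
inputs = torch.zeros(3, 1)
idx = find_matches(xstar, inputs, torch.tensor([0.0, 0.0]), torch.tensor([0.0]))
```

Here rows 0 and 2 match the query, while row 1 is about 1.41 away.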
- get_unique(use_F_xstar: bool = False) -> Self
Identifies unique fixed points. Among each set of duplicates, keeps the one with the smallest qstar.
- Args:
use_F_xstar (bool): Whether to determine uniqueness using F_xstar instead of xstar.
- Returns:
A FixedPointCollection containing only the unique fixed points and their associated data. Uniqueness is determined to within tol_unique.
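One simple way to realize "keep the duplicate with the smallest qstar" is greedy de-duplication: visit fixed points in order of increasing qstar and keep a point only if it is farther than tol_unique from every point already kept. This is a sketch of that swapped-in greedy technique, not necessarily how rnntoolkit groups duplicates (for example, chains of near neighbors may be grouped differently); `unique_by_qstar` is hypothetical:

```python
import torch


def unique_by_qstar(xstar, qstar, tol_unique=1e-3):
    """Greedy de-duplication: the first point visited in any cluster has the
    smallest qstar in that cluster, so it is the one kept.
    Returns the kept indices in ascending order."""
    kept = []
    for i in torch.argsort(qstar).tolist():
        if all(torch.linalg.vector_norm(xstar[i] - xstar[j]) > tol_unique for j in kept):
            kept.append(i)
    return torch.tensor(sorted(kept), dtype=torch.long)


xstar = torch.tensor([[0.0, 0.0], [0.0, 0.0002], [1.0, 1.0]])
qstar = torch.tensor([1e-6, 1e-9, 1e-7])
idx = unique_by_qstar(xstar, qstar)  # points 0 and 1 are duplicates; 1 has smaller qstar
```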
- property is_single_fixed_point: bool
- property kwargs: dict
Returns a dict of the keyword arguments needed to reinstantiate a (shallow) copy of this FixedPointCollection, i.e.,
fp_copy = FixedPointCollection(**fp.kwargs)
- restore(restore_path: str)
Restores data from a previously saved FixedPointCollection.
- Args:
restore_path: A string containing the path at which to find a previously saved FixedPointCollection (including directory, filename, and extension).
- Returns:
None.
- save(save_path: str)
Saves all data contained in the FixedPointCollection.
- Args:
save_path: A string containing the path at which to save (including directory, filename, and arbitrary extension).
- Returns:
None.
- transform(U: Tensor, offset: float = 0.0) -> Self
Apply an affine transformation to the state-space representation. This may be helpful for plotting fixed points in a given linear subspace (e.g., PCA or an RNN readout space).
- Args:
U: shape (n_states, k) projection matrix (Tensor).
offset (optional): scalar or shape (k,) translation vector. Default: 0.
- Returns:
A FixedPointCollection.
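A common use is projecting fixed points onto the top principal components of sampled RNN states. The sketch below builds U from an SVD of centered states and applies the map on plain tensors. It assumes the transform has the form `xstar @ U + offset`; the docs do not state whether the offset is applied before or after projection. `project_fixed_points` is a hypothetical helper:

```python
import torch


def project_fixed_points(xstar, U, offset=0.0):
    """Assumed affine form: xstar @ U + offset, giving shape (n, k)."""
    return xstar @ U + offset


torch.manual_seed(0)
states = torch.randn(200, 8)           # (n_samples, n_states) example RNN states
mean = states.mean(dim=0)
_, _, Vh = torch.linalg.svd(states - mean, full_matrices=False)
U = Vh[:2].T                           # (n_states, k) top-2 principal axes, k = 2
xstar = torch.randn(5, 8)              # 5 fixed points
# Under the assumed form, an offset of -(mean @ U) centers the projection like PCA.
z = project_fixed_points(xstar, U, offset=-(mean @ U))  # (5, 2) plot coordinates
```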
- update(new_fps: Self)
Combines the entries from another FixedPointCollection into this object.
- Args:
new_fps: a FixedPointCollection containing the entries to be incorporated into this FixedPointCollection.
- Returns:
None.
- Raises:
AssertionError if the non-fixed-point-specific attributes of new_fps do not match those of this FixedPointCollection.
AssertionError if any data attribute is present in one but not both objects (especially relevant for decomposed Jacobians).
AssertionError if the updated object has inconsistent data shapes.
fp_finder
- class rnntoolkit.fixed_points.fp_finder.FixedPointFinder(rnn: RNN, lr_init: float = 0.0001, lr_patience: float = 5, lr_factor: float = 0.95, lr_cooldown: float = 0, tol_q: float = 1e-12, tol_dq: float = 1e-20, max_iters: int = 5000, do_rerun_q_outliers: bool = False, outlier_q_scale: float = 10.0, do_exclude_distance_outliers: bool = True, outlier_distance_scale: float = 10.0, tol_unique: float = 0.001, max_n_unique: int = inf, dtype: str = 'float32', random_seed: int = 0, verbose: bool = True, super_verbose: bool = False, n_iters_per_print_update: int = 100)
Bases: FixedPointFinderBase
- classmethod default_hps()
Returns a deep copy of the default hyperparameters dict.
The deep copy protects against external updates to the defaults, which in turn protects against unintended interactions with the hashing done by the Hyperparameters class.
- Args:
None.
- Returns:
dict of hyperparameters.
- find_fixed_points(initial_states: Tensor, ext_inputs: Tensor, n_rounds_q_opt: int = 1) -> tuple[FixedPointCollection, FixedPointCollection]
Finds RNN fixed points and the Jacobians at the fixed points.
- Args:
initial_states: Tensor specifying the initial states of the RNN, from which the optimization will search for fixed points.
ext_inputs: external inputs to the RNN.
n_rounds_q_opt: number of rounds of extra iterations to run on q outliers.
- Returns:
unique_fps: A FixedPointCollection containing the set of unique fixed points after optimizing from all initial_states.
all_fps: A FixedPointCollection containing the likely redundant set of fixed points (and associated metadata) resulting from ALL initializations in initial_states.
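find_fixed_points is documented only at the interface level. Its hyperparameters (lr_init, tol_q, tol_dq, max_iters) suggest the standard q-minimization approach: treat each initial state as a free variable, minimize q(x) = 0.5 * ||F(x, u) - x||^2 with a gradient optimizer, then evaluate the Jacobian dF/dx at each result. A minimal sketch of that technique in plain PyTorch, under loud assumptions: the RNN update F is a `torch.nn.RNNCell`, `find_fixed_points_sketch` is a hypothetical name, and FixedPointFinder's learning-rate schedule, tol_dq stopping rule, and outlier handling are all omitted:

```python
import torch


def find_fixed_points_sketch(cell, x0, u, lr=1e-2, max_iters=2000, tol_q=1e-12):
    """Hypothetical sketch: minimize q(x) = 0.5 * ||F(x, u) - x||^2 for each
    initial state with Adam, where F is a torch.nn.RNNCell update."""
    x = x0.clone().requires_grad_(True)  # one free variable per initial state
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(max_iters):
        q = 0.5 * ((cell(u, x) - x) ** 2).sum(dim=1)  # shape (n_inits,)
        if q.max().item() < tol_q:
            break
        # States do not interact, so differentiating the sum gives each state
        # its own gradient; autograd.grad leaves the cell's parameters untouched.
        (grad,) = torch.autograd.grad(q.sum(), [x])
        x.grad = grad
        opt.step()
    xstar = x.detach()
    with torch.no_grad():
        qstar = 0.5 * ((cell(u, xstar) - xstar) ** 2).sum(dim=1)
    # Jacobian dF/dx at each fixed point: (n_inits, n_states, n_states).
    jacobians = torch.stack([
        torch.autograd.functional.jacobian(lambda s: cell(ui[None], s[None])[0], xi)
        for xi, ui in zip(xstar, u)
    ])
    return xstar, qstar, jacobians


torch.manual_seed(0)
cell = torch.nn.RNNCell(input_size=1, hidden_size=3)
with torch.no_grad():
    cell.weight_hh.mul_(0.5)  # with default init, this makes the map a contraction
x0 = torch.randn(4, 3)        # 4 initial states
u = torch.zeros(4, 1)         # zero input for every initial state
xstar, qstar, J = find_fixed_points_sketch(cell, x0, u)
```

Because the toy map is a contraction, all four initial states should converge to the same fixed point. With a real network, they typically land on several distinct fixed points, which is why the method returns both unique_fps and all_fps.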
fp_finder_base
- class rnntoolkit.fixed_points.fp_finder_base.FixedPointFinderBase(rnn: RNN, **kwargs)
Bases: Generic[RNN]
- find_fixed_points(*args, **kwargs)
Overridden by subclasses to implement network-specific optimization.
- static get_fp_non_distance_outliers(fps: FixedPointCollection, initial_states: Tensor, dist_thresh: float) -> Tensor
Gets the fixed points that are not far from the initial states, as determined by a distance threshold.
- Args:
fps (FixedPointCollection): fixed points discovered by the optimization, [n, state_dim].
initial_states (Tensor): initial states of the optimization, [n, state_dim].
dist_thresh (float): threshold beyond which a fixed point is considered far.
- Returns:
fsp_non_outlier_distance (Tensor): indices into fps of the fixed points that are not far.
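The docs do not define "far". One plausible criterion measures each fixed point's distance from the centroid of the initial states and normalizes it by the mean distance of the initial states from that centroid, so dist_thresh becomes a unitless scale factor. The normalization is an assumption, and `fp_non_distance_outliers` is a hypothetical helper:

```python
import torch


def fp_non_distance_outliers(xstar, initial_states, dist_thresh):
    """Indices of fixed points whose normalized distance from the centroid of
    initial_states is below dist_thresh (normalization is assumed)."""
    centroid = initial_states.mean(dim=0)
    avg_init_dist = torch.linalg.vector_norm(initial_states - centroid, dim=1).mean()
    fp_dists = torch.linalg.vector_norm(xstar - centroid, dim=1) / avg_init_dist
    return torch.nonzero(fp_dists < dist_thresh).flatten()


initial_states = torch.tensor([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
xstar = torch.tensor([[0.5, 0.0], [50.0, 0.0]])  # the second point ran off far away
idx = fp_non_distance_outliers(xstar, initial_states, dist_thresh=10.0)
```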
- static get_init_non_distance_outliers(initial_states: Tensor, dist_thresh: float) -> Tensor
Gets the initial states that are not far from their centroid, as determined by a distance threshold.
- Args:
initial_states (Tensor): initial states of the fixed-point optimization, [n, state_dim].
dist_thresh (float): threshold beyond which an initial state is considered far from the centroid.
- Returns:
init_non_outlier_idx (Tensor): indices into initial_states of the states within the threshold.
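For initial states, the same idea applies relative to their own centroid. A minimal sketch, assuming distances are normalized by their mean (an assumption, not confirmed by the docs); `init_non_distance_outliers` is hypothetical:

```python
import torch


def init_non_distance_outliers(initial_states, dist_thresh):
    """Indices of initial states whose distance from their centroid, divided
    by the mean such distance, is below dist_thresh."""
    centroid = initial_states.mean(dim=0)
    dists = torch.linalg.vector_norm(initial_states - centroid, dim=1)
    return torch.nonzero(dists / dists.mean() < dist_thresh).flatten()


# Nine states at the origin plus one far-away state.
states = torch.cat([torch.zeros(9, 2), torch.tensor([[100.0, 0.0]])])
idx = init_non_distance_outliers(states, dist_thresh=3.0)
```

In this example the centroid is (10, 0) and the mean distance is 18, so the nine clustered states have a normalized distance of about 0.56 and are kept, while the distant state has a normalized distance of 5 and is excluded. The outlier itself inflates the mean, so this kind of rule only flags fairly extreme states.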
- static identify_q_non_outliers(fps: FixedPointCollection, q_thresh: float) -> Tensor
Identifies fixed points whose optimized q values do not exceed a specified threshold.
- Args:
- fps: A FixedPointCollection containing optimized fixed points and associated metadata.
- q_thresh: A scalar float specifying the threshold on the fixed points’ q values.
- Returns:
A tensor of indices into fps corresponding to the fixed points whose q values do not exceed the threshold.
- Usage:
idx = FixedPointFinderBase.identify_q_non_outliers(fps, q_thresh)
non_outlier_fps = fps[idx]
- static identify_q_outliers(fps: FixedPointCollection, q_thresh: float) -> Tensor
Identifies fixed points whose optimized q values exceed a specified threshold.
- Args:
- fps: A FixedPointCollection containing optimized fixed points and associated metadata.
- q_thresh: A scalar float specifying the threshold on the fixed points’ q values.
- Returns:
A tensor of indices into fps corresponding to the fixed points whose q values exceed the threshold.
- Usage:
idx = FixedPointFinderBase.identify_q_outliers(fps, q_thresh)
outlier_fps = fps[idx]
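identify_q_non_outliers and identify_q_outliers are complementary: "does not exceed" means q <= q_thresh, and "exceeds" means q > q_thresh, so together they partition the fixed points. A minimal sketch on an illustrative qstar tensor, assuming each fixed point's qstar is compared directly against q_thresh:

```python
import torch

qstar = torch.tensor([1e-14, 3e-3, 5e-13, 2e-10])  # illustrative q values
q_thresh = 1e-12

# Non-outliers: q does not exceed the threshold.
non_outlier_idx = torch.nonzero(qstar <= q_thresh).flatten()
# Outliers: q exceeds the threshold.
outlier_idx = torch.nonzero(qstar > q_thresh).flatten()
```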
- sample_states(state_traj: Tensor, n_inits: int, noise_scale: float = 0.0, exclude_zero_tensors: bool = False) -> Tensor
Draws random samples from trajectories of the RNN state. Samples can optionally be corrupted by independent and identically distributed (IID) Gaussian noise. These samples are intended to be used as initial states for fixed point optimizations.
- Args:
state_traj: 1D or ND tensor containing example trajectories of the RNN state.
n_inits: int specifying the number of sampled states to return.
noise_scale (optional): non-negative float specifying the standard deviation of the IID Gaussian noise added to the sampled states.
exclude_zero_tensors (bool, optional): whether to exclude all-zero states that may be present in state_traj.
- Returns:
initial_states: Sampled RNN states as a [n_inits x n_states] tensor
- Raises:
ValueError if noise_scale is negative.
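The sampling procedure can be sketched as follows: flatten the trajectories into a pool of n_states-dimensional states, optionally drop all-zero states (for example, zero-padded time steps), draw n_inits states, and add IID Gaussian noise. Sampling uniformly with replacement, flattening every leading dimension, and the `generator` argument are assumptions of this sketch. `sample_states_sketch` is a hypothetical stand-in for sample_states:

```python
import torch


def sample_states_sketch(state_traj, n_inits, noise_scale=0.0,
                         exclude_zero_tensors=False, generator=None):
    """Draw n_inits states (with replacement) from state_traj and add
    IID Gaussian noise with standard deviation noise_scale."""
    if noise_scale < 0.0:
        raise ValueError(f"noise_scale must be non-negative, got {noise_scale}")
    pool = state_traj.reshape(-1, state_traj.shape[-1])  # (n_samples, n_states)
    if exclude_zero_tensors:
        pool = pool[pool.abs().sum(dim=1) > 0]           # drop all-zero states
    idx = torch.randint(0, pool.shape[0], (n_inits,), generator=generator)
    states = pool[idx]
    if noise_scale > 0.0:
        states = states + noise_scale * torch.randn(states.shape, generator=generator)
    return states                                        # (n_inits, n_states)


torch.manual_seed(0)
g = torch.Generator().manual_seed(0)
traj = torch.randn(8, 50, 16)  # (n_batch, n_time, n_states)
traj[:, -10:] = 0.0            # e.g. zero-padded final time steps
inits = sample_states_sketch(traj, n_inits=32, noise_scale=0.01,
                             exclude_zero_tensors=True, generator=g)
```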