import torch
import torch.nn as nn
import numpy as np
import time
from copy import deepcopy
from .fp import FixedPointCollection
from rnntoolkit.fixed_points.fp_finder_base import FixedPointFinderBase
class FixedPointFinder(FixedPointFinderBase):
    """Finds fixed points of a PyTorch RNN by gradient descent.

    For a batch of candidate states ``x`` and constant inputs ``u`` the
    finder minimizes ``q = 0.5 * ||x - F(x, u)||^2`` (``F`` being one step of
    the wrapped RNN) jointly over all candidates with Adam and a
    ``ReduceLROnPlateau`` learning-rate schedule.
    """

    _default_hps = {
        "lr_init": 1e-4,
        "lr_patience": 5,
        "lr_factor": 0.95,
        "lr_cooldown": 0,
        "tol_q": 1e-12,
        "tol_dq": 1e-20,
        "max_iters": 5000,
        "do_rerun_q_outliers": False,
        "outlier_q_scale": 10.0,
        "do_exclude_distance_outliers": True,
        "outlier_distance_scale": 10.0,
        "tol_unique": 1e-3,
        "max_n_unique": np.inf,
        "dtype": "float32",
        "random_seed": 0,
        "verbose": True,
        "super_verbose": False,
        "n_iters_per_print_update": 100,
    }

    @classmethod
    def default_hps(cls):
        """Returns a deep copy of the default hyperparameters dict.

        The deep copy protects against external updates to the defaults, which
        in turn protects against unintended interactions with the hashing done
        by the Hyperparameters class.

        Args:
            None.

        Returns:
            dict of hyperparameters.
        """
        return deepcopy(cls._default_hps)

    def __init__(
        self,
        rnn: nn.RNN,
        lr_init: float = _default_hps["lr_init"],
        lr_patience: int = _default_hps["lr_patience"],
        lr_factor: float = _default_hps["lr_factor"],
        lr_cooldown: int = _default_hps["lr_cooldown"],
        tol_q: float = _default_hps["tol_q"],
        tol_dq: float = _default_hps["tol_dq"],
        max_iters: int = _default_hps["max_iters"],
        do_rerun_q_outliers: bool = _default_hps["do_rerun_q_outliers"],
        outlier_q_scale: float = _default_hps["outlier_q_scale"],
        do_exclude_distance_outliers: bool = _default_hps[
            "do_exclude_distance_outliers"
        ],
        outlier_distance_scale: float = _default_hps["outlier_distance_scale"],
        tol_unique: float = _default_hps["tol_unique"],
        max_n_unique: float = _default_hps["max_n_unique"],
        dtype: str = _default_hps["dtype"],
        random_seed: int = _default_hps["random_seed"],
        verbose: bool = _default_hps["verbose"],
        super_verbose: bool = _default_hps["super_verbose"],
        n_iters_per_print_update: int = _default_hps["n_iters_per_print_update"],
    ):
        """Creates a FixedPointFinder object.

        Inherited from FixedPointFinderBase.

        Optimization terminates once every initialization satisfies one or
        both of the following criteria:
            1. q < tol_q
            2. dq < tol_dq * learning_rate

        Args:
            rnn: A PyTorch RNN. Its first parameter determines the device
                on which the optimization runs.
            lr_init (optional): Initial learning rate handed to Adam.
                Default: 1e-4.
            lr_patience (optional): Number of non-improving iterations the
                ReduceLROnPlateau scheduler tolerates before lowering the
                learning rate. Default: 5.
            lr_factor (optional): Multiplicative factor applied to the
                learning rate when the scheduler lowers it. Default: 0.95.
            lr_cooldown (optional): Number of iterations the scheduler waits
                after a reduction before resuming normal operation.
                Default: 0.
            tol_q (optional): A positive scalar specifying the optimization
                termination criteria on each q-value. Default: 1e-12.
            tol_dq (optional): A positive scalar specifying the optimization
                termination criteria on the improvement of each q-value
                (i.e., "dq") from one optimization iteration to the next.
                Default: 1e-20.
            max_iters (optional): A non-negative integer specifying the
                maximum number of gradient descent iterations allowed.
                Default: 5000.
            do_rerun_q_outliers (optional): A bool indicating whether or not
                to run additional optimization iterations on putative
                outlier states. Default: False.
            outlier_q_scale (optional): A positive float specifying the q
                value for putative outlier fixed points, relative to the
                median q value across all identified fixed points.
                Default: 10.
            do_exclude_distance_outliers (optional): A bool indicating
                whether or not to discard states that are far away from the
                set of initial states. Default: True.
            outlier_distance_scale (optional): A positive float specifying a
                normalized distance cutoff used to exclude distance
                outliers. Default: 10.
            tol_unique (optional): A positive scalar specifying the
                numerical precision required to label two fixed points as
                being unique from one another. Default: 1e-3.
            max_n_unique (optional): A positive number indicating the max
                number of unique fixed points to keep. Default: np.inf
                (keep all).
            dtype: Name of the torch dtype (e.g. 'float32') forwarded to the
                FixedPointCollection results. Default: 'float32'.
            random_seed: Seed for numpy random number generator. Default: 0.
            verbose (optional): A bool indicating whether to print high-level
                status updates. Default: True.
            super_verbose (optional): A bool indicating whether or not to
                print per-iteration updates during each optimization.
                Default: False.
            n_iters_per_print_update (optional): An int specifying how often
                to print updates during the fixed point optimizations.
                Default: 100.
        """
        super().__init__(rnn)

        self.dtype = dtype
        self.device = next(rnn.parameters()).device
        self.torch_dtype = getattr(torch, self.dtype)

        # Make random sequences reproducible
        self.random_seed = random_seed
        self.rng = np.random.RandomState(random_seed)

        # *********************************************************************
        # Optimization hyperparameters ****************************************
        # *********************************************************************
        self.lr_init = lr_init
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor
        self.lr_cooldown = lr_cooldown
        self.tol_q = tol_q
        self.tol_dq = tol_dq
        self.max_iters = max_iters
        self.do_rerun_q_outliers = do_rerun_q_outliers
        self.outlier_q_scale = outlier_q_scale
        self.do_exclude_distance_outliers = do_exclude_distance_outliers
        self.outlier_distance_scale = outlier_distance_scale
        self.tol_unique = tol_unique
        self.max_n_unique = max_n_unique
        self.verbose = verbose
        self.super_verbose = super_verbose
        self.n_iters_per_print_update = n_iters_per_print_update

    # *************************************************************************
    # Primary exposed functions ***********************************************
    # *************************************************************************

    def find_fixed_points(
        self,
        initial_states: torch.Tensor,
        ext_inputs: torch.Tensor,
        n_rounds_q_opt: int = 1,
    ) -> tuple[FixedPointCollection, FixedPointCollection]:
        """Finds RNN fixed points starting from a set of initial states.

        Args:
            initial_states: Tensor specifying the initial states of the RNN,
                from which the optimization will search for fixed points.
            ext_inputs: External inputs to the RNN, held constant during the
                optimization.
            n_rounds_q_opt: Maximum number of extra optimization rounds run
                on q outliers (only used when do_rerun_q_outliers is True).

        Returns:
            unique_fps: A FixedPointCollection containing the set of unique
                fixed points after optimizing from all initial_states.
            all_fps: A FixedPointCollection containing the likely redundant
                set of fixed points (and associated metadata) resulting from
                ALL initializations in initial_states.
        """
        all_fps = self._fp_optimization(initial_states, ext_inputs)

        # Filter out duplicates from the first optimization round.
        unique_fps = all_fps.get_unique()
        self._print_if_verbose("\tIdentified %d unique fixed points." % unique_fps.n)

        if self.do_exclude_distance_outliers:
            unique_fps = self._exclude_distance_outliers(unique_fps, initial_states)

        # Optionally run additional optimization iterations on identified
        # fixed points with q values on the large side of the q-distribution.
        if self.do_rerun_q_outliers:
            unique_fps = self._run_additional_iterations_on_outliers(
                unique_fps,
                n_rounds=n_rounds_q_opt,
            )
            # Filter out duplicates from the second optimization round.
            unique_fps = unique_fps.get_unique()

        # Optionally subselect from the unique fixed points (e.g., for
        # computational savings when not all are needed). The comparison is
        # always False for the default max_n_unique of np.inf.
        if unique_fps.n > self.max_n_unique:
            self._print_if_verbose(
                "\tRandomly selecting %d unique "
                "fixed points to keep." % self.max_n_unique
            )
            max_n_unique = int(self.max_n_unique)
            idx_keep = self.rng.choice(
                unique_fps.n, max_n_unique, replace=False
            ).tolist()
            unique_fps = unique_fps[idx_keep]

        self._print_if_verbose("\tFixed point finding complete.\n")
        return unique_fps, all_fps

    # *************************************************************************
    # Helper functions ********************************************************
    # *************************************************************************

    def _exclude_distance_outliers(
        self, fps: FixedPointCollection, initial_states: torch.Tensor
    ) -> FixedPointCollection:
        """Removes putative distance outliers from a set of fixed points.

        The indices to keep are obtained from get_fp_non_distance_outliers(...)
        using self.outlier_distance_scale as the cutoff; see that method for
        the exact criterion.

        Args:
            fps: Fixed points to filter.
            initial_states: Initial states the distances are measured against.

        Returns:
            The subset of fps that are not distance outliers.
        """
        idx_keep = self.get_fp_non_distance_outliers(
            fps, initial_states, self.outlier_distance_scale
        )
        return fps[idx_keep.tolist()]

    def _run_additional_iterations_on_outliers(
        self, fps: FixedPointCollection, n_rounds: int = 1
    ) -> FixedPointCollection:
        """Detects outlier states with respect to the q function and runs
        additional optimization iterations on those states.

        A state is a putative outlier when identify_q_outliers(...) flags it
        against the threshold ``median(fps.qstar) * outlier_q_scale``. The
        threshold is computed once, from the collection passed in, and is
        reused for every round.

        Args:
            fps: A FixedPointCollection containing (partially) optimized
                fixed points and associated metadata. It is updated in place.
            n_rounds: Maximum number of re-optimization rounds. Rounds stop
                early as soon as no outliers remain.

        Returns:
            A FixedPointCollection containing the further-optimized fixed
            points and associated metadata.
        """
        assert fps.qstar is not None
        outlier_min_q = float(np.median(fps.qstar) * self.outlier_q_scale)

        def detect_outliers(current: FixedPointCollection) -> torch.Tensor:
            # Find the current outliers and report how many there are.
            idx = self.identify_q_outliers(current, outlier_min_q)
            self._print_if_verbose(
                "\n\tDetected %d putative outliers "
                "(q>%.2e)." % (len(idx), outlier_min_q)
            )
            return idx

        def reoptimize(
            current: FixedPointCollection, idx: torch.Tensor
        ) -> FixedPointCollection:
            # Restart the optimization from the outliers' current estimates
            # and write the refined results back into the collection.
            idx_list = idx.tolist()
            outlier_fps = current[idx_list]
            n_prev_iters = outlier_fps.n_iters
            inputs = outlier_fps.inputs
            self._print_if_verbose(
                "\tPerforming another round of "
                "joint optimization, "
                "over outlier states only."
            )
            assert inputs is not None
            assert n_prev_iters is not None
            updated_fps = self._fp_optimization(outlier_fps.xstar, inputs)
            assert updated_fps.n_iters is not None
            # Report the cumulative iteration count across all rounds.
            updated_fps.n_iters += n_prev_iters
            current[idx_list] = updated_fps
            return current

        idx_outliers = detect_outliers(fps)
        for _ in range(n_rounds):
            if len(idx_outliers) == 0:
                break
            fps = reoptimize(fps, idx_outliers)
            idx_outliers = detect_outliers(fps)
        return fps

    def _print_if_verbose(self, *args, **kwargs):
        """Forwards the arguments to print() when self.verbose is set."""
        if self.verbose:
            print(*args, **kwargs)

    @classmethod
    def _print_iter_update(
        cls,
        iter_count: int,
        t_start: float,
        q: torch.Tensor,
        dq: torch.Tensor,
        lr: float,
        is_final: bool = False,
    ) -> None:
        """Prints a one-line (or, if final, multi-line) optimization summary.

        Args:
            iter_count: Number of iterations completed so far (>= 1).
            t_start: time.time() stamp taken when the optimization started.
            q: Per-state q values.
            dq: Per-state absolute change in q since the previous iteration.
            lr: Current learning rate.
            is_final: Whether this is the summary printed after the
                optimization has terminated.
        """
        avg_iter_time = (time.time() - t_start) / iter_count

        if is_final:
            delimiter = "\n\t\t"
            print("\t\t%d iters%s" % (iter_count, delimiter), end="")
        else:
            delimiter = ", "
            print("\tIter: %d%s" % (iter_count, delimiter), end="")

        # Tensor.size is a method, so the element count must come from
        # numel(); a std over a single element would be NaN.
        if q.numel() == 1:
            print(
                "q = %.2e%sdq = %.2e%s"
                % (q.item(), delimiter, dq.item(), delimiter),
                end="",
            )
        else:
            print(
                "q = %.2e +/- %.2e%s"
                "dq = %.2e +/- %.2e%s"
                % (
                    torch.mean(q).item(),
                    torch.std(q).item(),
                    delimiter,
                    torch.mean(dq).item(),
                    torch.std(dq).item(),
                    delimiter,
                ),
                end="",
            )

        print("learning rate = %.2e%s" % (lr, delimiter), end="")
        print("avg iter time = %.2e sec" % avg_iter_time, end="")
        # Terminate the line so the next message starts on a fresh one.
        print()

    def _fp_optimization(
        self,
        initial_states: torch.Tensor,
        ext_inp: torch.Tensor,
    ) -> FixedPointCollection:
        """Finds multiple fixed points via a joint optimization over multiple
        state vectors.

        The caller's tensors are never modified: the optimization runs on a
        detached copy of initial_states. RNN parameters are frozen for the
        duration of the optimization and their requires_grad flags are
        restored afterwards.

        Args:
            initial_states: Tensor specifying the initial states of the RNN,
                from which the optimization will search for fixed points.
            ext_inp: Tensor specifying a set of constant inputs into the RNN.

        Returns:
            fps: A FixedPointCollection containing the optimized fixed points
                and associated metadata.
        """
        # Position of the (length-1) time axis in the RNN input.
        time_dim = 1 if self.batch_first else 0

        # presumably _broadcast_nxd yields [n, d] tensors -- confirm in base.
        initial_states = self._broadcast_nxd(initial_states, tile_n=1)
        n = initial_states.shape[0]
        ext_inp = self._broadcast_nxd(ext_inp, tile_n=n)

        # Check batch sizes BEFORE inserting the time axis; afterwards dim 0
        # is the time axis when batch_first is False.
        assert ext_inp.shape[0] == n

        # detach() (rather than setting requires_grad) also works for
        # non-leaf tensors; inputs stay constant during the optimization.
        ext_inp = ext_inp.detach().to(self.device).unsqueeze(time_dim)

        # Keep the true starting points and optimize a separate leaf tensor
        # so that neither the caller's tensor nor x_init is overwritten.
        x_init = initial_states.detach().to(self.device).clone()
        x_bxd = x_init.clone().requires_grad_(True)

        self._print_if_verbose(
            "\nSearching for fixed points from %d initial states.\n" % n
        )

        # Ensure that fixed point optimization does not alter RNN parameters.
        self._print_if_verbose(
            "\tFreezing model parameters so model is not affected by fixed point optimization."
        )
        params = list(self.rnn.parameters())
        prev_requires_grad = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)

        self._print_if_verbose("\tFinding fixed points via joint optimization.")

        optimizer = torch.optim.Adam([x_bxd], lr=self.lr_init)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=self.lr_factor,
            patience=int(self.lr_patience),
            cooldown=int(self.lr_cooldown),
            threshold=1e-10,
        )

        iter_count = 1
        iter_learning_rate = self.lr_init
        t_start = time.time()
        # NaN so that the first dq is NaN and can never signal convergence.
        q_prev_b = torch.full((n,), float("nan"), device=self.device)

        try:
            while True:
                # currently only works for 1 layer rnns
                _, F_x_1xbxd = self.rnn(ext_inp, x_bxd.unsqueeze(0))
                F_x_bxd = F_x_1xbxd.squeeze(0)

                dx_bxd = x_bxd - F_x_bxd
                q_b = 0.5 * torch.sum(torch.square(dx_bxd), dim=-1)
                q_scalar = torch.mean(q_b)
                dq_b = torch.abs(q_b.detach() - q_prev_b)

                optimizer.zero_grad()
                q_scalar.backward()
                optimizer.step()
                scheduler.step(metrics=q_scalar.detach())
                iter_learning_rate = optimizer.param_groups[0]["lr"]

                ev_q_b = q_b.detach().cpu()
                ev_dq_b = dq_b.detach().cpu()

                if (
                    self.super_verbose
                    and iter_count % self.n_iters_per_print_update == 0
                ):
                    self._print_iter_update(
                        iter_count, t_start, ev_q_b, ev_dq_b, iter_learning_rate
                    )

                # Here dq is scaled by the learning rate. Otherwise very
                # small steps due to very small learning rates would
                # spuriously indicate convergence. This scaling is roughly
                # equivalent to measuring the gradient norm.
                converged = torch.logical_or(
                    ev_dq_b < self.tol_dq * iter_learning_rate,
                    ev_q_b < self.tol_q,
                )
                if iter_count > 1 and torch.all(converged):
                    self._print_if_verbose(
                        "\tOptimization complete to desired tolerance."
                    )
                    break

                if iter_count + 1 > self.max_iters:
                    self._print_if_verbose(
                        "\tMaximum iteration count reached. Terminating."
                    )
                    break

                # Detached: only the value is needed for the next dq.
                q_prev_b = q_b.detach()
                iter_count += 1
        finally:
            # Leave the model exactly as trainable as it was on entry.
            for p, flag in zip(params, prev_requires_grad):
                p.requires_grad_(flag)

        if self.verbose:
            self._print_iter_update(
                iter_count, t_start, ev_q_b, ev_dq_b, iter_learning_rate, is_final=True
            )

        # NOTE(review): xstar is read after the final optimizer step, while
        # F_xstar / qstar / dq were evaluated just before it, so they lag
        # xstar by one step. Kept as-is to preserve the reported values.
        xstar = x_bxd.detach().cpu()
        F_xstar = F_x_bxd.detach().cpu()

        # Indicate same n_iters for each initialization (i.e., joint optimization)
        n_iters = torch.full((n,), iter_count, dtype=torch.long)

        inputs_bxd = ext_inp.squeeze(time_dim)

        fps = FixedPointCollection(
            xstar=xstar,
            x_init=x_init.cpu(),
            inputs=inputs_bxd,
            F_xstar=F_xstar,
            qstar=ev_q_b,
            dq=ev_dq_b,
            n_iters=n_iters,
            tol_unique=self.tol_unique,
            dtype=self.torch_dtype,
        )
        return fps