import torch
import torch.nn as nn
from .fp import FixedPointCollection
from typing import Generic, TypeVar
RNN = TypeVar("RNN", bound=nn.Module)


class FixedPointFinderBase(Generic[RNN]):
    """Base class for fixed point finders operating on a PyTorch RNN."""

    def __init__(
        self,
        rnn: RNN,
        **kwargs,
    ):
        """Creates a FixedPointFinder object.

        Base class that can be overwritten for different types of RNNs.
        This is meant for running optimization on RNN states to find fixed
        points. Overwrite the find_fixed_points method to find fixed points
        for your specific rnn.

        Base methods include sampling states, identifying outliers, and
        broadcasting.

        Args:
            rnn: A PyTorch RNN module. It must expose a ``batch_first``
                attribute, which is copied onto this object.
            **kwargs: Accepted for subclass compatibility; not used by this
                base class.
        """
        # A module without parameters has no device to inherit; fall back to
        # the CPU instead of leaking a bare StopIteration out of next().
        first_param = next(rnn.parameters(), None)
        if first_param is None:
            self.device = torch.device("cpu")
        else:
            self.device = first_param.device
        self.rnn = rnn
        self.batch_first = self.rnn.batch_first

    # *************************************************************************
    # Primary exposed functions ***********************************************
    # *************************************************************************

    def sample_states(
        self,
        state_traj: torch.Tensor,
        n_inits: int,
        noise_scale: float = 0.0,
        exclude_zero_tensors: bool = False,
    ) -> torch.Tensor:
        """Draws random samples from trajectories of the RNN state.

        Samples are drawn uniformly with replacement and can optionally be
        corrupted by independent and identically distributed (IID) Gaussian
        noise. These samples are intended to be used as initial states for
        fixed point optimizations.

        Args:
            state_traj: 1D or ND tensor containing example trajectories of
                the RNN state. The last dimension is treated as the state
                dimension; all leading dimensions are flattened together.
            n_inits: int specifying the number of sampled states to return.
            noise_scale (optional): non-negative float specifying the
                standard deviation of IID Gaussian noise samples added to the
                sampled states.
            exclude_zero_tensors (bool, optional): whether to exclude states
                that are entirely zero from the pool being sampled.

        Returns:
            initial_states: Sampled RNN states as a [n_inits x n_states]
            tensor.

        Raises:
            ValueError: if noise_scale is negative.
            AssertionError: if the sampled states contain NaNs.
        """
        if state_traj.dim() == 1:
            state_traj = state_traj.unsqueeze(0)

        # Collapse every leading (batch/time) dimension; -1 is the state dim.
        flat_state_traj = torch.flatten(state_traj, end_dim=-2)

        if exclude_zero_tensors:
            # Keep only rows holding at least one non-zero entry.
            non_zero_rows, _ = torch.nonzero(flat_state_traj, as_tuple=True)
            non_zero_rows = torch.unique(non_zero_rows)
            flat_state_traj = flat_state_traj[non_zero_rows]

        # Build the indices on the same device as the data being indexed.
        rand_indices = torch.randint(
            high=flat_state_traj.shape[0],
            size=(n_inits,),
            device=flat_state_traj.device,
        )
        states = flat_state_traj[rand_indices]

        # Add IID Gaussian noise to the sampled states
        states = self._add_gaussian_noise(states, noise_scale)

        assert not torch.any(torch.isnan(states)), (
            "Detected NaNs in sampled states. Check state_traj."
        )
        return states

    def find_fixed_points(self, *args, **kwargs):
        """Overwritten by subclass for network specific optimizations"""
        raise NotImplementedError

    # *************************************************************************
    # Helper functions ********************************************************
    # *************************************************************************

    def _add_gaussian_noise(
        self, data: torch.Tensor, noise_scale: float = 0.0
    ) -> torch.Tensor:
        """Adds IID Gaussian noise to a tensor.

        Args:
            data: Tensor
            noise_scale: (Optional) non-negative scalar indicating the
                standard deviation of the Gaussian noise samples to be
                generated. Default: 0.0.

        Returns:
            Tensor matching shape of data with noise added. When noise_scale
            is 0 the input tensor itself is returned (no copy is made).

        Raises:
            ValueError: if noise_scale is negative.
        """
        if noise_scale < 0:
            raise ValueError(
                f"noise_scale must be non-negative, but was {noise_scale}."
            )
        if noise_scale == 0.0:
            return data  # no noise to add
        # Create the noise on the data's device so the addition also works
        # for tensors that do not live on the CPU.
        noise = torch.randn(data.shape, device=data.device)
        return data + noise_scale * noise

    @staticmethod
    def identify_q_outliers(
        fps: "FixedPointCollection", q_thresh: float
    ) -> torch.Tensor:
        """Identify fixed points with optimized q values that exceed a
        specified threshold.

        Args:
            fps: A FixedPoints object containing optimized fixed points and
                associated metadata.
            q_thresh: A scalar float indicating the threshold on fixed
                points' q values.

        Returns:
            A tensor containing the indices into fps corresponding to
            the fixed points with q values exceeding the threshold.

        Usage:
            idx = identify_q_outliers(fps, q_thresh)
            outlier_fps = fps[idx]
        """
        assert fps.qstar is not None
        return torch.where(fps.qstar > q_thresh)[0]

    @staticmethod
    def identify_q_non_outliers(
        fps: "FixedPointCollection", q_thresh: float
    ) -> torch.Tensor:
        """Identify fixed points with optimized q values that do not exceed a
        specified threshold.

        Args:
            fps: A FixedPoints object containing optimized fixed points and
                associated metadata.
            q_thresh: A scalar float indicating the threshold on fixed
                points' q values.

        Returns:
            A tensor containing the indices into fps corresponding to the
            fixed points with q values that do not exceed the threshold.

        Usage:
            idx = identify_q_non_outliers(fps, q_thresh)
            non_outlier_fps = fps[idx]
        """
        assert fps.qstar is not None
        return torch.where(fps.qstar <= q_thresh)[0]

    @staticmethod
    def _centroid_and_mean_distance(initial_states: torch.Tensor):
        """Returns the centroid of initial_states, each state's distance to
        it, and the mean of those distances.

        Args:
            initial_states (Tensor): states laid out as [n, state_dim]

        Returns:
            Tuple (centroid, init_dists, avg_init_dist) with shapes
            (state_dim,), (n,) and scalar respectively.
        """
        # Centroid of initial_states, shape (n_states,)
        centroid = torch.mean(initial_states, dim=0)
        # Distance of each initial state from the centroid, shape (n,)
        init_dists = torch.linalg.norm(initial_states - centroid, dim=1)
        avg_init_dist = torch.mean(init_dists)
        return centroid, init_dists, avg_init_dist

    @staticmethod
    def get_init_non_distance_outliers(
        initial_states: torch.Tensor, dist_thresh: float
    ) -> torch.Tensor:
        """
        Get the initial states that are NOT far from the centroid of all
        initial states.

        Each state's distance to the centroid is normalized by the mean of
        those distances; states whose normalized distance is strictly below
        dist_thresh are kept.

        Args:
            initial_states (Tensor): initial states of fp optimization
                [n, state_dim]
            dist_thresh (float): Threshold on the normalized distance at or
                above which an initial state is considered far.

        Returns:
            init_non_outlier_idx (Tensor): indices to initial_states tensor
            inside threshold
        """
        _, init_dists, avg_init_dist = (
            FixedPointFinderBase._centroid_and_mean_distance(initial_states)
        )
        # Normalized distances of initial states to the centroid, shape: (n,)
        # NOTE: if every initial state coincides, the mean distance is zero,
        # the ratio is NaN and no index passes the comparison below.
        scaled_init_dists = init_dists / avg_init_dist
        init_non_outlier_idx = torch.where(scaled_init_dists < dist_thresh)[0]
        return init_non_outlier_idx

    @staticmethod
    def get_fp_non_distance_outliers(
        fps: "FixedPointCollection",
        initial_states: torch.Tensor,
        dist_thresh: float,
    ) -> torch.Tensor:
        """
        Get the fixed points that are NOT far from the centroid of the
        initial states.

        Each fixed point's distance to the initial-state centroid is
        normalized by the mean distance of the initial states to that
        centroid; fixed points whose normalized distance is strictly below
        dist_thresh are kept.

        Args:
            fps (FixedPointCollection): fps discovered; fps.xstar is read as
                [n, state_dim]
            initial_states (Tensor): initial states of optimization
                [n, state_dim]
            dist_thresh (float): threshold on the normalized distance at or
                above which fixed points are considered far

        Returns:
            fps_non_outlier_idx (Tensor): indices to fps object that are not
            far
        """
        centroid, _, avg_init_dist = (
            FixedPointFinderBase._centroid_and_mean_distance(initial_states)
        )
        # Distance of each FP from the initial_states centroid
        fps_dists = torch.linalg.norm(fps.xstar - centroid, dim=1)
        # Normalized by the mean initial-state distance.
        # NOTE: a zero mean distance yields inf/NaN ratios, so no index
        # passes the comparison below.
        scaled_fps_dists = fps_dists / avg_init_dist
        fps_non_outlier_idx = torch.where(scaled_fps_dists < dist_thresh)[0]
        return fps_non_outlier_idx

    def _broadcast_nxd(self, data: torch.Tensor, tile_n: int = 1) -> torch.Tensor:
        """
        Takes in a tensor of shape [..., d] and reshapes to nxd.

        A 1D tensor of shape [d] is tiled tile_n times into [tile_n, d];
        tensors with more dimensions have all leading dimensions flattened
        (tile_n is ignored in that case). The result is moved to self.device.
        """
        # Broadcast to [n, d]
        if data.dim() == 1:
            # If only 1d, then tile
            data = torch.tile(data, [tile_n, 1])
        else:
            # If > 1d, then flatten up to last dim
            data = torch.flatten(data, end_dim=-2)
        # Ensure proper device (dtype is left unchanged)
        data = data.to(self.device)
        return data