import torch
import numpy as np
import pickle
from typing import Self
# [docs]
class FixedPointCollection:
"""
A class for storing fixed points and associated data.
"""
""" List of class attributes that represent data corresponding to fixed
points. All of these refer to Numpy arrays with axis 0 as the batch
dimension. Thus, each is concatenatable using np.concatenate(..., axis=0).
"""
_data_attrs_fp = [
"xstar",
"x_init",
"inputs",
"F_xstar",
"qstar",
"dq",
"n_iters",
]
""" List of class attributes that apply to all fixed points
(i.e., these are not indexed per fixed point). """
_nonspecific_attrs = [
"dtype",
"dtype_complex",
"tol_unique",
"verbose",
]
def __init__(
self,
xstar: torch.Tensor, # Fixed-point specific data
x_init: torch.Tensor | None = None,
inputs: torch.Tensor | None = None,
F_xstar: torch.Tensor | None = None,
qstar: torch.Tensor | None = None,
dq: torch.Tensor | None = None,
n_iters: torch.Tensor | None = None,
tol_unique: float = 1e-3,
dtype=torch.float32,
dtype_complex=torch.complex64,
verbose: bool = False,
):
"""
Initializes a FixedPoints object with all input arguments as class
properties.
Optional args:
xstar: [n x n_states] tensor with row xstar[i, :]
specifying an the fixed point identified from x_init[i, :].
x_init: [n x n_states] tensor with row x_init[i, :]
specifying the initial state from which xstar[i, :] was optimized.
inputs: [n x n_inputs] tensor with row inputs[i, :]
specifying the input to the RNN during the optimization of
xstar[i, :].
F_xstar: [n x n_states] tensor with F_xstar[i, :]
specifying RNN state after transitioning from the fixed point in
xstar[i, :].
qstar: [n,] tensor with qstar[i] containing the
optimized objective (1/2)(x-F(x))^T(x-F(x))
dq: [n,] tensor with dq[i] containing the absolute
difference in the objective function after (i.e., qstar[i]) vs
before the final gradient descent step of the optimization of
xstar[i, :]
n_iters: [n,] tensor with n_iters[i] as the number of
gradient descent iterations completed to yield xstar[i, :].
tol_unique: Positive scalar specifying the numerical precision
required to label two fixed points as being unique from one
another.
dtype: Data type for representing all of the object's data.
dtype_complex: Data type for representing complex values.
verbose: Bool indicating whether to print status updates.
"""
# These apply to all fixed points
# (one value each, rather than one value per fixed point).
self.tol_unique = tol_unique
self.dtype = dtype
self.dtype_complex = dtype_complex
self.verbose = verbose
self.n, self.n_states = xstar.shape
if inputs is not None:
self.n_inputs = inputs.shape[1]
else:
self.n_inputs = None
# Common FP attributes
self.xstar = xstar.clone()
self.x_init = x_init.clone() if x_init is not None else x_init
self.inputs = inputs.clone() if inputs is not None else inputs
self.F_xstar = F_xstar.clone() if F_xstar is not None else F_xstar
self.qstar = qstar.clone() if qstar is not None else qstar
self.dq = dq.clone() if dq is not None else dq
self.n_iters = n_iters.clone() if n_iters is not None else n_iters
self.assert_valid_shapes()
def __setitem__(self, index: int | list | slice, fps: Self):
"""Indexes into a subset of the fixed points and their associated data.
Usage:
fps_subset = fps[index]
Args:
index: a slice object for indexing into the FixedPoints data.
Returns:
A FixedPoints object containing a subset of the data from the
current FixedPoints object, as specified by index.
"""
assert isinstance(fps, type(self)), (
"fps must be a self object but was %s." % type(fps)
)
if isinstance(index, tuple):
assert len(index) <= 1 # only 1d slicing allowed
# Preserve shape of objects if index is an int
if isinstance(index, int):
index = slice(index, index + 1)
for attr_name in self._data_attrs_fp:
attr_val = getattr(self, attr_name)
if attr_val is not None:
attr_val[index] = getattr(fps, attr_name)
print(attr_name, attr_val is self.xstar)
def __getitem__(self, index: int | list | slice) -> Self:
"""Indexes into a subset of the fixed points and their associated data.
Usage:
fps_subset = fps[index]
Args:
index: a slice object for indexing into the FixedPoints data.
Returns:
A FixedPoints object containing a subset of the data from the
current FixedPoints object, as specified by index.
"""
if isinstance(index, tuple):
assert len(index) <= 1 # only 1d slicing allowed
# Preserve shape of objects if index is an int
if isinstance(index, int):
index = slice(index, index + 1)
kwargs = self._nonspecific_kwargs
for attr_name in self._data_attrs_fp:
attr_val = getattr(self, attr_name)
indexed_val = self._safe_index(attr_val, index)
kwargs[attr_name] = indexed_val.clone() if indexed_val is not None else None
return type(self)(**kwargs)
def __len__(self) -> int:
"""Returns the number of fixed points stored in the object."""
return self.n
def __contains__(self, fp: Self) -> bool:
"""Checks whether a specified fixed point is contained in the object.
Args:
fp: A FP object containing exactly one fixed point.
Returns:
bool indicating whether any fixed point matches fp.
"""
idx = self.find(fp)
return idx.numel() > 0
[docs]
def get_unique(self, use_F_xstar: bool = False) -> Self:
"""
Identifies unique fixed points. Among duplicates identified,
this keeps the one with smallest qstar.
Args:
use_F_xstar (bool): Whether to find unique on F_xstar instead of xstar
Returns:
A FixedPoints object containing only the unique fixed points and
their associated data. Uniqueness is determined down to tol_unique.
"""
assert self.xstar is not None, (
"Cannot find unique fixed points because self.xstar is None."
)
idx_keep = []
idx_checked = torch.zeros(size=(self.n,), dtype=torch.bool)
for idx in range(self.n):
if idx_checked[idx]:
# If this FP matched others, we've already determined which
# of those matching FPs to keep. Repeating would simply
# identify the same FP to keep.
continue
# Don't compare against FPs we've already checked
idx_check = torch.where(~idx_checked)[0]
fps_check = self[idx_check.tolist()] # only check against these FPs
idx_idx_check = fps_check.find(
self[idx], use_F_xstar=use_F_xstar
) # indexes into fps_check
idx_match = idx_check[idx_idx_check] # indexes into self
if len(idx_match) == 1:
# Only matches with itself
idx_keep.append(idx)
else:
qstars_match = self._safe_index(self.qstar, idx_match.tolist())
assert isinstance(qstars_match, torch.Tensor)
idx_candidate = idx_match[torch.argmin(qstars_match)]
idx_keep.append(int(idx_candidate))
idx_checked[idx_match] = True
return self[idx_keep]
[docs]
def find(self, fp: Self, use_F_xstar: bool = False) -> torch.Tensor:
"""Searches in the current FixedPoints object for matches to a
specified fixed point. Two fixed points are defined as matching
if the 2-norm of the difference between their concatenated (xstar,
inputs) is within tol_unique).
Args:
fp: A FixedPoints object containing exactly one fixed point.
Returns:
shape (n_matches,) tensor specifying indices into the current
FixedPoints object where matches to fp were found.
"""
# If not found or comparison is impossible (due to type or shape),
# follow convention of torch.where and return an empty tensor.
result = torch.empty(0, dtype=torch.int)
if isinstance(fp, FixedPointCollection):
# return if either object does not contain F_xstar
if use_F_xstar and (self.F_xstar is None or fp.F_xstar is None):
return result
if fp.n_states == self.n_states and fp.n_inputs == self.n_inputs:
if self.inputs is None:
self_data_nxd = self.F_xstar if use_F_xstar else self.xstar
arg_data_nxd = fp.F_xstar if use_F_xstar else fp.xstar
else:
self_state = self.F_xstar if use_F_xstar else self.xstar
other_state = fp.F_xstar if use_F_xstar else fp.xstar
self_data_nxd = torch.cat((self_state, self.inputs), dim=1)
assert isinstance(fp.inputs, torch.Tensor)
arg_data_nxd = torch.cat((other_state, fp.inputs), dim=1)
assert self_data_nxd is not None
norm_diffs_n = torch.linalg.norm(self_data_nxd - arg_data_nxd, axis=1)
result = torch.where(norm_diffs_n <= self.tol_unique)[0]
return result
[docs]
def update(self, new_fps: Self):
"""Combines the entries from another FixedPoints object into this
object.
Args:
new_fps: a FixedPoints object containing the entries to be
incorporated into this FixedPoints object.
Returns:
None
Raises:
AssertionError if the non-fixed-point specific attributes of
new_fps do not match those of this FixedPoints object.
AssertionError if any data attributes are found in one but not both
FixedPoints objects (especially relevant for decomposed Jacobians).
AssertionError if the updated object has inconsistent data shapes.
"""
self._assert_matching_nonspecific_attrs([new_fps])
for attr_name in self._data_attrs_fp:
this_has = hasattr(self, attr_name)
that_has = hasattr(new_fps, attr_name)
self_attr = getattr(self, attr_name)
other_attr = getattr(new_fps, attr_name)
assert this_has == that_has, (
"One but not both FixedPoints objects have %s. "
"FixedPoints.update does not currently support this "
"configuration." % attr_name
)
# Ensure both objects have the attribute
if this_has and that_has:
# Now check whether, one, both, or neither attrs are None
if self_attr is None and other_attr is None:
setattr(self, attr_name, None)
elif self_attr is None and other_attr is not None:
setattr(self, attr_name, other_attr)
elif self_attr is not None and other_attr is None:
setattr(self, attr_name, self_attr)
else:
cat_attr = torch.cat((self_attr, other_attr), dim=0)
setattr(self, attr_name, cat_attr)
self.n = self.n + new_fps.n
self.assert_valid_shapes()
[docs]
def save(self, save_path: str):
"""Saves all data contained in the FixedPoints object.
Args:
save_path: A string containing the path at which to save
(including directory, filename, and arbitrary extension).
Returns:
None.
"""
if self.verbose:
print("Saving FixedPoints object.")
self.assert_valid_shapes()
file = open(save_path, "wb")
file.write(pickle.dumps(self.__dict__))
file.close()
[docs]
def restore(self, restore_path: str):
"""Restores data from a previously saved FixedPoints object.
Args:
restore_path: A string containing the path at which to find a
previously saved FixedPoints object (including directory, filename,
and extension).
Returns:
None.
"""
if self.verbose:
print("Restoring FixedPoints object.")
file = open(restore_path, "rb")
restore_data = file.read()
file.close()
self.__dict__ = pickle.loads(restore_data)
self.assert_valid_shapes()
[docs]
def assert_valid_shapes(self):
"""Checks that all data attributes reflect the same number of fixed
points.
Raises:
AssertionError if any non-None data attribute does not have
.shape[0] as self.n.
"""
n = self.n
for attr_name in FixedPointCollection._data_attrs_fp:
data = getattr(self, attr_name)
if data is not None:
assert data.shape[0] == self.n, (
"Detected %d fixed points, but %s.shape is %s "
"(shape[0] should be %d" % (n, attr_name, str(data.shape), n)
)
[docs]
@staticmethod
def concatenate(fps_seq: list):
"""Join a sequence of FixedPoints objects.
Args:
fps_seq: sequence of FixedPoints objects. All FixedPoints objects
must have the following attributes in common:
n_states
n_inputs
Returns:
A FixedPoints objects containing the concatenated FixedPoints data.
"""
assert len(fps_seq) > 0, "Cannot concatenate empty list."
FixedPointCollection._assert_matching_nonspecific_attrs(fps_seq)
kwargs = {}
for attr_name in FixedPointCollection._nonspecific_attrs:
kwargs[attr_name] = getattr(fps_seq[0], attr_name)
for attr_name in FixedPointCollection._data_attrs_fp:
if all((hasattr(fps, attr_name) for fps in fps_seq)):
cat_list = [getattr(fps, attr_name) for fps in fps_seq]
if all([attr is None for attr in cat_list]):
cat_attr = None
elif any([attr is None for attr in cat_list]):
cat_attr = None
else:
cat_attr = torch.cat(cat_list)
kwargs[attr_name] = cat_attr
return FixedPointCollection(**kwargs)
@property
def is_single_fixed_point(self) -> bool:
return self.n == 1
@property
def kwargs(self) -> dict:
"""Returns dict of keyword arguments necessary for reinstantiating a
(shallow) copy of this FixedPoints object, i.e.,
fp_copy = FixedPoints(**fp.kwargs)
"""
kwargs = self._nonspecific_kwargs
for attr_name in self._data_attrs_fp:
kwargs[attr_name] = getattr(self, attr_name)
return kwargs
@staticmethod
def _assert_matching_nonspecific_attrs(fps_seq: list):
for attr_name in FixedPointCollection._nonspecific_attrs:
items = [getattr(fps, attr_name) for fps in fps_seq]
for item in items:
assert item == items[0], (
"Cannot concatenate FixedPoints because of mismatched %s "
"(%s is not %s)" % (attr_name, str(items[0]), str(item))
)
@staticmethod
def _safe_index(
x: torch.Tensor | None, idx: int | list | slice
) -> torch.Tensor | None:
"""Safe method for indexing into a tensor that might be None.
Args:
x: Either None or a tensor.
idx: Positive int or index-compatible argument for indexing into x.
"""
if x is None:
return None
else:
return x[idx]
@property
def _nonspecific_kwargs(self) -> dict:
# These are not specific to individual fixed points.
# Thus, simple copy, no indexing required
return {"dtype": self.dtype, "tol_unique": self.tol_unique}