A Linear Gaussian State Space Model (LGSSM) for PyTorch, also known as a Linear Dynamical System (LDS)
MIT License
A minimal Linear Gaussian State Space Model for PyTorch.
It supports sampling, and it computes the log-likelihood log(p(x)) of a batch of observed sequences x using the Kalman filter.
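The Kalman filter obtains log(p(x)) through the prediction-error decomposition, log p(x_1:T) = sum over t of log p(x_t | x_1:t-1), where each factor is Gaussian and is produced by a predict/update recursion. Below is a minimal sketch for a single sequence. It assumes that the prior is placed on the first latent state z_1 and that all six parameters are dense mean vectors and covariance matrices. `kalman_log_prob` is a hypothetical helper for illustration, not this package's API:

```python
import torch


def kalman_log_prob(x, mu0, P0, A, Q, C, R):
    """Hypothetical helper (not this package's API): log p(x_1:T) for one
    sequence x of shape (T, x_dim) under the assumed model
        z_1 ~ N(mu0, P0)
        z_t = A z_{t-1} + w_t,   w_t ~ N(0, Q)
        x_t = C z_t + v_t,       v_t ~ N(0, R)
    Also returns the per-step prior and posterior filtering distributions."""
    MVN = torch.distributions.MultivariateNormal
    m, P = mu0, P0                      # assumed prior over z_1
    log_p = torch.zeros((), dtype=x.dtype)
    steps = []
    for t in range(x.shape[0]):
        if t > 0:
            # Predict: push the previous posterior through the dynamics.
            m = A @ m
            P = A @ P @ A.T + Q
        # Predictive distribution of x_t given x_1:t-1.
        x_mean = C @ m
        S = C @ P @ C.T + R
        log_p = log_p + MVN(x_mean, S).log_prob(x[t])
        # Update: Kalman gain K = P C^T S^-1. Because P and S are symmetric,
        # K^T = S^-1 C P, which a linear solve gives without an explicit inverse.
        K = torch.linalg.solve(S, C @ P).T
        m_post = m + K @ (x[t] - x_mean)
        P_post = P - K @ C @ P
        steps.append({"z_prior": (m, P), "x_prior": (x_mean, S),
                      "z_post": (m_post, P_post)})
        m, P = m_post, P_post
    return log_p, steps
```

Because every x_t is a linear function of Gaussian variables, x_1:T is jointly Gaussian, so the recursion must agree with a brute-force evaluation of that joint density. The filter simply reaches the same value in time linear in T. The per-step dictionaries are similar in spirit to the `steps` returned by `model.log_prob`, but the package's actual structure may differ.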
from lgssm import LinearGaussianStateSpaceModel
model = LinearGaussianStateSpaceModel(prior_mean, prior_covariance, transition_matrix, transition_covariance, observation_matrix, observation_covariance)
x, z = model.sample(8)  # sample 8 time steps: x is (8, x_dim), z is (8, z_dim)
logpx, steps = model.log_prob(x.unsqueeze(1))  # x.unsqueeze(1) is (8, 1, x_dim): a batch of one sequence
# steps contains the prior and posterior filtering distributions for x and z at each time step
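Sampling from an LGSSM is ancestral. You draw z_1 from the prior, then alternate between emitting x_t from the observation model and advancing z_t through the transition model. Below is a minimal sketch that uses the same assumptions as the filtering discussion: the prior is on z_1, and all covariances are dense matrices. `sample_lgssm` is a hypothetical helper. It only reproduces the output shapes of `model.sample(8)` in the usage above and is not the package's implementation:

```python
import torch


def sample_lgssm(T, mu0, P0, A, Q, C, R):
    """Hypothetical helper (not this package's API): draw T time steps.
    Returns x of shape (T, x_dim) and z of shape (T, z_dim)."""
    MVN = torch.distributions.MultivariateNormal
    xs, zs = [], []
    z = MVN(mu0, P0).sample()                # z_1 from the (assumed) prior
    for t in range(T):
        if t > 0:
            z = MVN(A @ z, Q).sample()       # z_t | z_{t-1} ~ N(A z_{t-1}, Q)
        zs.append(z)
        xs.append(MVN(C @ z, R).sample())    # x_t | z_t ~ N(C z_t, R)
    return torch.stack(xs), torch.stack(zs)
```

Because the latent path is returned along with the observations, you can check the observation model directly. With a near-zero observation covariance R, every row of x should be close to C applied to the matching row of z.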
pip install git+https://github.com/rasmusbergpalm/pytorch-lgssm