world-models-ppo

PyTorch World Model implementation with PPO.

World Model implementation with PPO in PyTorch. This repository builds on world-models for the VAE and MDN-RNN implementations and on firedup for the PPO optimization of the Controller network. See the firedup setup file for requirements.
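The three components form a pipeline. The VAE (V) compresses each frame into a latent vector z, the MDN-RNN (M) carries a recurrent state h, and the Controller (C) acts on [z, h]. The NumPy sketch below traces one time step. The sizes (64x64x3 frames, 32-dim z, 256-dim h, 3-dim action) are assumed from the original World Models paper, and the random weight matrices `W_v`, `W_m`, `W_c` are hypothetical stand-ins for the trained PyTorch modules, so only the data flow and shapes are meaningful.

```python
import numpy as np

# Sizes assumed from the original World Models paper for CarRacing.
OBS_SHAPE, Z_DIM, H_DIM, A_DIM = (64, 64, 3), 32, 256, 3
rng = np.random.default_rng(0)

# Hypothetical random stand-in weights; the real V, M and C are trained networks.
W_v = rng.standard_normal((Z_DIM, int(np.prod(OBS_SHAPE)))) * 0.01
W_m = rng.standard_normal((H_DIM, Z_DIM + A_DIM + H_DIM)) * 0.01
W_c = rng.standard_normal((A_DIM, Z_DIM + H_DIM)) * 0.1

def vision(frame):
    """V: compress a frame into a latent vector z (stand-in for the VAE encoder)."""
    return W_v @ (frame.reshape(-1) / 255.0)

def memory(z, a, h):
    """M: update the recurrent state from (z, a, h) (stand-in for the MDN-RNN's LSTM step)."""
    return np.tanh(W_m @ np.concatenate([z, a, h]))

def controller(z, h):
    """C: map [z, h] to an action; tanh squashes each component into [-1, 1]."""
    return np.tanh(W_c @ np.concatenate([z, h]))

h = np.zeros(H_DIM)                                   # initial recurrent state
frame = rng.integers(0, 256, size=OBS_SHAPE).astype(np.float64)
z = vision(frame)                                     # V: frame -> z
a = controller(z, h)                                  # C: [z, h] -> action
h = memory(z, a, h)                                   # M: advance h for the next step
print(z.shape, a.shape, h.shape)  # (32,) (3,) (256,)
```

Mapping the squashed action onto CarRacing's gas and brake ranges is omitted here.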

First, generate and save a number of rollouts of the CarRacing-v0 Gym environment in the data_dir folder; these form the training and test sets:

python env/carracing.py --data_dir './env/data' --n_fold_train 20 --n_fold_test 1
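Each rollout is one episode of observations, actions, rewards and done flags. The sketch below shows one plausible on-disk layout, a compressed `.npz` archive per episode. The key names, the `rollout_{i}.npz` file pattern and the helpers `save_rollout` and `fake_rollout` are assumptions for illustration, not necessarily what `env/carracing.py` writes. Random data with CarRacing's native 96x96x3 frame shape and [steering, gas, brake] action layout stands in for real environment steps.

```python
import os
import tempfile
import numpy as np

def save_rollout(path, observations, actions, rewards, terminals):
    """Store one episode as a compressed .npz archive (hypothetical key names)."""
    np.savez_compressed(path,
                        observations=np.asarray(observations, dtype=np.uint8),
                        actions=np.asarray(actions, dtype=np.float32),
                        rewards=np.asarray(rewards, dtype=np.float32),
                        terminals=np.asarray(terminals, dtype=bool))

def fake_rollout(seq_len, rng):
    """Random data with CarRacing-like shapes, standing in for real env.step() output."""
    obs = rng.integers(0, 256, size=(seq_len, 96, 96, 3))
    # Action = [steering in [-1, 1], gas in [0, 1], brake in [0, 1]].
    actions = np.column_stack([rng.uniform(-1, 1, seq_len),
                               rng.uniform(0, 1, seq_len),
                               rng.uniform(0, 1, seq_len)])
    rewards = rng.normal(size=seq_len)
    terminals = np.zeros(seq_len, dtype=bool)
    terminals[-1] = True                     # episode ends on the last step
    return obs, actions, rewards, terminals

rng = np.random.default_rng(0)
data_dir = tempfile.mkdtemp()                # stand-in for './env/data'
for i in range(3):
    save_rollout(os.path.join(data_dir, f"rollout_{i}.npz"), *fake_rollout(50, rng))

with np.load(os.path.join(data_dir, "rollout_0.npz")) as data:
    print(data["observations"].shape, data["actions"].shape)  # (50, 96, 96, 3) (50, 3)
```

Storing frames as `uint8` keeps the files about 8x smaller than float64. The original World Models paper downscales frames to 64x64 before they reach the VAE.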

Then train the Variational Autoencoder (VAE) using the stored rollouts:

from vae.train import run
run(data_dir='./env/data', vae_dir='./vae/model', epochs=5)
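A VAE is typically trained on a reconstruction term plus a KL term that pulls the encoder's Gaussian N(mu, sigma^2) toward N(0, I), and it samples z with the reparameterization trick. The NumPy sketch below shows one common formulation, a summed squared error with an encoder that outputs `mu` and `logsigma`. It illustrates the technique and is an assumption, not a copy of the loss in `vae.train`.

```python
import numpy as np

def reparameterize(mu, logsigma, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I); in PyTorch this keeps z differentiable."""
    return mu + np.exp(logsigma) * rng.standard_normal(mu.shape)

def vae_loss(recon_x, x, mu, logsigma):
    """Summed squared reconstruction error plus KL(N(mu, sigma^2) || N(0, I))."""
    recon = np.sum((recon_x - x) ** 2)
    # Closed-form KL for diagonal Gaussians, using log(sigma^2) = 2 * logsigma.
    kld = -0.5 * np.sum(1 + 2 * logsigma - mu ** 2 - np.exp(2 * logsigma))
    return recon + kld

rng = np.random.default_rng(0)
x = rng.random((4, 3, 64, 64))               # a batch of 4 frames scaled to [0, 1]
mu, logsigma = np.zeros((4, 32)), np.zeros((4, 32))
z = reparameterize(mu, logsigma, rng)
print(z.shape, vae_loss(x, x, mu, logsigma))  # (4, 32) 0.0
```

A perfect reconstruction with a standard-normal posterior gives zero loss. Moving `mu` away from 0 is penalized by the KL term alone.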

Using the pretrained VAE, we train the Recurrent Mixture Density Network (MDN-RNN) to predict the next latent state from the current latent state and action:

from mdnrnn.train import run
run(data_dir='./env/data', vae_dir='./vae/model', mdnrnn_dir='./mdnrnn/model', epochs=5)
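The "mixture density" part means the RNN outputs the parameters of a Gaussian mixture over the next latent z (K mixture weights, plus K means and standard deviations per latent dimension), and training minimizes the negative log-likelihood of the observed next z. The NumPy sketch below covers only that latent-prediction term, assuming diagonal Gaussians and log mixture weights `logpi`. The full loss in `mdnrnn.train` may also include reward and done terms, which are omitted here.

```python
import numpy as np

def gmm_nll(z_next, mus, sigmas, logpi):
    """Mean negative log-likelihood of z_next under a mixture of diagonal Gaussians.

    z_next: (B, Z); mus, sigmas: (B, K, Z); logpi: (B, K) log mixture weights.
    """
    z = z_next[:, None, :]                        # (B, 1, Z) broadcasts against K
    # Per-dimension log N(z | mu, sigma) = -0.5 * ((z - mu) / sigma)^2 - log sigma - 0.5 * log(2 pi)
    log_normal = -0.5 * (((z - mus) / sigmas) ** 2 + 2 * np.log(sigmas) + np.log(2 * np.pi))
    log_prob = logpi + log_normal.sum(axis=-1)    # (B, K): log pi_k + log N_k(z)
    m = log_prob.max(axis=1, keepdims=True)       # log-sum-exp trick for numerical stability
    log_mix = m[:, 0] + np.log(np.exp(log_prob - m).sum(axis=1))
    return -log_mix.mean()

rng = np.random.default_rng(0)
B, K, Z = 8, 5, 32                                # K = 5 mixture components (assumed)
logits = rng.standard_normal((B, K))
logpi = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))  # log-softmax
loss = gmm_nll(rng.standard_normal((B, Z)), rng.standard_normal((B, K, Z)),
               np.exp(rng.standard_normal((B, K, Z))), logpi)
print(np.isfinite(loss))  # True
```

Working with log weights and log-sum-exp avoids underflow when every component assigns a tiny density to z.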

Finally, we train the Controller network, which steers the car, using PPO:

from rl.algos.ppo.ppo import run
run(exp_name='carracing_ppo', epochs=100)
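At the core of PPO is the clipped surrogate objective. The probability ratio between the new and old policy is clipped to [1 - eps, 1 + eps], so a single update cannot move the policy far from the one that collected the data. The NumPy sketch below computes that loss and the approximate KL, given per-step log-probabilities and advantages. The default `clip_ratio=0.2` follows OpenAI Spinning Up, of which firedup appears to be a PyTorch port. How `rl.algos.ppo.ppo` computes advantages and uses the KL estimate (Spinning Up uses it for early stopping) is not shown here.

```python
import numpy as np

def ppo_clip_loss(logp_new, logp_old, advantages, clip_ratio=0.2):
    """PPO-Clip policy loss (to be minimized) and an approximate KL, averaged over a batch."""
    ratio = np.exp(logp_new - logp_old)                      # pi_new(a|s) / pi_old(a|s)
    clipped = np.clip(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages
    # Pessimistic bound: take the smaller of the unclipped and clipped objectives.
    loss = -np.mean(np.minimum(ratio * advantages, clipped))
    approx_kl = np.mean(logp_old - logp_new)                 # sample estimate of KL(old || new)
    return loss, approx_kl

logp_old = np.log(np.array([0.5, 0.5]))
logp_new = logp_old + np.log(2.0)           # new policy doubles both action probabilities
loss, kl = ppo_clip_loss(logp_new, logp_old, np.array([1.0, -1.0]))
print(round(loss, 6))  # 0.4
```

With a positive advantage the doubled ratio is clipped to 1.2, while with a negative advantage the unclipped -2 is kept, giving -(1.2 - 2) / 2 = 0.4. The clip only limits how much the objective can gain from a policy change, never how much it can lose.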