Keras and OpenAI Gym implementation of the Deep Q-learning algorithm to play Atari games.
This is an implementation in Keras and OpenAI Gym of the Deep Q-Learning algorithm (often referred to as Deep Q-Network, or DQN) by Mnih et al., applied to the well-known Atari games.
Rather than a pre-packaged tool that simply lets you watch the agent play, this is a model that needs to be trained and fine-tuned by hand, and it is intended mainly as an educational resource. The code tries to replicate the experimental setup described in the original DeepMind paper.
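At the heart of that setup is the Q-learning regression target computed on minibatches of stored transitions. As a rough sketch of that step (not code from this repo; the function and variable names are illustrative):

```python
import numpy as np

def q_targets(rewards, next_q_values, terminal, discount_factor=0.99):
    """Compute DQN regression targets y = r + gamma * max_a' Q_target(s', a').

    rewards:       (batch,) rewards observed for each transition
    next_q_values: (batch, n_actions) Q values of the *target* network at s'
    terminal:      (batch,) booleans, True where the episode ended
    """
    max_next_q = next_q_values.max(axis=1)
    # No bootstrapping from terminal states: the target is just the reward.
    return rewards + discount_factor * max_next_q * (~terminal)

# Tiny example: two transitions, three actions.
rewards = np.array([1.0, 0.0])
next_q = np.array([[0.5, 2.0, 1.0],
                   [3.0, 0.0, 0.0]])
terminal = np.array([False, True])
targets = q_targets(rewards, next_q, terminal)
print(targets)  # 1.0 + 0.99 * 2.0 for the first transition, 0.0 for the terminal one
```

The targets are then regressed against the online network's predictions for the actions actually taken, which is what the Keras model is fitted on.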
A similar project on Deep Q-Learning applied to videogames can be found on this repo of mine.
Make sure to cite the paper by Mnih et al. if you use this code for your research:
@article{mnih2015human,
title={Human-level control through deep reinforcement learning},
author={Mnih, Volodymyr and Kavukcuoglu, Koray and Silver, David and Rusu, Andrei A and Veness, Joel and Bellemare, Marc G and Graves, Alex and Riedmiller, Martin and Fidjeland, Andreas K and Ostrovski, Georg and others},
journal={Nature},
volume={518},
number={7540},
pages={529--533},
year={2015},
publisher={Nature Research}
}
and remember to give me a shoutout by linking to my Github profile, blog, or Twitter.
A big thank you to Carlo D'Eramo (@carloderamo), who helped me develop this code and spent a good amount of time debugging in the beginning.
To run the script you'll need the following dependencies, which should all be available through Pip.
No additional setup is needed, so simply clone the repo:

```sh
git clone https://gitlab.com/danielegrattarola/deep-q-atari.git
cd deep-q-atari
```
A default training session can be run by typing:

```sh
python atari.py -t
```

which will train the model with the same parameters described in this Nature article, on the MsPacmanDeterministic-v4 environment.
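During training, the agent stores transitions in a replay memory and learns from random minibatches sampled from it; the `--replay-memory-size` and `--minibatch-size` options below control exactly this. A minimal, illustrative sketch of such a buffer (not the repo's actual class):

```python
import random
from collections import deque

class ReplayMemory:
    """Fixed-size buffer of (state, action, reward, next_state, done) tuples."""

    def __init__(self, max_size):
        # When full, the oldest transitions are evicted first.
        self.buffer = deque(maxlen=max_size)

    def add(self, transition):
        self.buffer.append(transition)

    def sample(self, minibatch_size):
        # Uniform sampling without replacement, as in the original DQN paper.
        return random.sample(list(self.buffer), minibatch_size)

memory = ReplayMemory(max_size=1000)
for step in range(5):
    memory.add((f"s{step}", 0, 1.0, f"s{step + 1}", False))
batch = memory.sample(minibatch_size=2)
print(len(batch))  # 2
```

Sampling uniformly breaks the temporal correlation between consecutive frames, which is one of the key stabilizing tricks of the algorithm.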
By running:

```sh
python atari.py -h
```

you'll see the options list. The possible options are:
- `-t`, `--train`: train the agent;
- `-l`, `--load`: load the neural network weights from the given path;
- `-v`, `--video`: show video output;
- `-d`, `--debug`: run in debug mode (no output files);
- `--eval`: evaluate the agent;
- `-e`, `--environment`: name of the OpenAI Gym environment to use (default: MsPacman-v0);
- `--minibatch-size`: number of samples on which to train the DQN at each update;
- `--replay-memory-size`: number of samples stored in the replay memory;
- `--target-network-update-freq`: frequency (number of frames) with which the target DQN is updated;
- `--avg-val-computation-freq`: frequency (number of DQN updates) with which the average reward and Q value are computed;
- `--discount-factor`: discount factor for the environment;
- `--update-freq`: frequency (number of steps) with which to train the DQN;
- `--learning-rate`: learning rate for the DQN;
- `--epsilon`: initial exploration rate for the agent;
- `--min-epsilon`: final exploration rate for the agent;
- `--epsilon-decrease`: rate at which to linearly decrease epsilon;
- `--replay-start-size`: minimum number of transitions (collected with a fully random policy) to store in the replay memory before training starts;
- `--initial-random-actions`: number of random actions performed by the agent at the beginning of each episode;
- `--dropout`: dropout rate for the DQN;
- `--max-episodes`: maximum number of episodes that the agent can experience before quitting;
- `--max-episode-length`: maximum number of steps in an episode;
- `--max-frames-number`: maximum number of frames for a run;
- `--test-freq`: frequency (number of episodes) with which to test the agent's performance;
- `--validation-frames`: number of frames on which to test the model, as in table 3 of the paper;
- `--test-states`: number of states on which to compute the average Q value.
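The exploration schedule controlled by `--epsilon`, `--min-epsilon` and `--epsilon-decrease` is the linear annealing described in the paper. A hedged sketch of what that schedule looks like (the parameter names mirror the flags, but the default values here are only placeholders):

```python
def epsilon_schedule(step, epsilon=1.0, min_epsilon=0.1, epsilon_decrease=9e-7):
    """Linearly anneal the exploration rate from `epsilon` down to `min_epsilon`."""
    return max(min_epsilon, epsilon - epsilon_decrease * step)

print(epsilon_schedule(0))          # 1.0: act fully at random at the start
print(epsilon_schedule(2_000_000))  # clipped at min_epsilon, here 0.1
```

At each step the agent picks a random action with probability epsilon and the greedy (max-Q) action otherwise, so training moves smoothly from pure exploration to mostly exploitation.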
The possible environments on which the agent can be trained are all the environments in the Atari gym package. A typical usage of this script on a headless server (e.g. an EC2 instance) would look like this:

```sh
python atari.py -t -e BreakoutDeterministic-v4
```
If you want to see the actual game being played by the agent, simply add the `-v` flag to the above command (note that this will obviously slow down the collection of samples).
You'll find some csv files in the output folder of the run (output/runYYYYMMDD-hhmmss), which contain raw data for the analysis of the agent's performance.
More specifically, the following files will be produced as output:

- a csv file with training statistics, including the loss reported by the `fit` method of Keras;
- a csv file with the average Q value computed on the `--test-states` sample states (updated every `--avg-val-computation-freq` DQN updates);
- a csv file with the evaluation results (only when running with the `--eval` flag);
- the weights of the DQN, saved periodically and when the run is interrupted with `ctrl+c`. You can pass any of these files as argument with the `--load` option.
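Since the output files are plain csv, analysing a run only takes a few lines of Python using the standard library. A sketch (the column names here are placeholders, not the script's actual output schema):

```python
import csv
import io

# Placeholder data standing in for one of the run's csv files; in practice
# you would open output/runYYYYMMDD-hhmmss/<file>.csv instead.
raw = io.StringIO("episode,length,score\n0,812,120\n1,901,180\n2,950,210\n")

rows = list(csv.DictReader(raw))
scores = [float(row["score"]) for row in rows]
mean_score = sum(scores) / len(scores)
print(mean_score)  # 170.0
```

The same pattern (or pandas, if you prefer) works for plotting learning curves from the loss and Q-value logs.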