Deep Q Learning

#

This is a Deep Q Learning implementation with:

* Double Q Network
* Dueling Network
* Prioritized Replay

It is based on the OpenAI Baselines implementation. I have also taken some inspiration from the Berkeley Deep RL Course.

It’s hardcoded for Atari Breakout, and tested with TensorFlow 1.7

There are two local imports: `util`, which provides the `Orthogonal` initializer, `huber_loss`, and `PiecewiseSchedule` helpers, and `worker`, which runs the Atari game in separate worker processes.
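`util.py` and `worker.py` themselves are not listed on this page. As a rough sketch of what the two `util` helpers used below might look like (my assumption, based only on how they are called in this file, not the author's actual code), `huber_loss` and `PiecewiseSchedule` could be implemented along these lines:

import tensorflow as tf

def huber_loss(x, delta: float = 1.0):
    # quadratic near zero, linear beyond |x| > delta
    abs_x = tf.abs(x)
    return tf.where(abs_x < delta,
                    0.5 * tf.square(x),
                    delta * (abs_x - 0.5 * delta))

class PiecewiseSchedule:
    def __init__(self, endpoints, outside_value):
        # endpoints is a list of (time_step, value) pairs, sorted by time_step
        self.endpoints = endpoints
        self.outside_value = outside_value

    def __call__(self, t):
        # linearly interpolate between the two surrounding endpoints
        for (t0, v0), (t1, v1) in zip(self.endpoints[:-1], self.endpoints[1:]):
            if t0 <= t < t1:
                fraction = (t - t0) / (t1 - t0)
                return v0 + fraction * (v1 - v0)
        # outside the range covered by the endpoints
        return self.outside_value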

I terminate episodes after one life is lost, so the episode reward is the total reward for a single life.

If someone reading this has any questions or comments please find me on Twitter, @vpj.

#

Imports

import time
from collections import deque

import io
import numpy as np
import random
import tensorflow as tf
from matplotlib import pyplot
from pathlib import Path, PurePath
from typing import Dict, Union

from util import Orthogonal, huber_loss, PiecewiseSchedule
from worker import Worker
#

I was using a computer with two GPUs and I wanted TensorFlow to use only one of them.

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#

Neural Network Model for $Q$ Values

Dueling Network ⚔️

We are using a dueling network to calculate the Q-values. The intuition behind the dueling network architecture is that in most states the action doesn't matter much, while in some states the action is significant. The dueling network lets this be represented very well.

So we create two networks for $V$ and $A$ and get $Q$ from them. We share the initial layers of the $V$ and $A$ networks.

$\epsilon$-greedy Sampling

When sampling actions we use an $\epsilon$-greedy strategy: we take the greedy action with probability $1 - \epsilon$ and a random action with probability $\epsilon$. We refer to $\epsilon$ as the exploration fraction.
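As a small standalone illustration (a numpy sketch, separate from the model code below) of the dueling combination and $\epsilon$-greedy sampling:

import numpy as np

rng = np.random.RandomState(0)
batch_size, n_actions = 3, 4

v = rng.randn(batch_size, 1)          # V(s), one value per state
a = rng.randn(batch_size, n_actions)  # A(s, a), one advantage per action

# Q(s, a) = V(s) + (A(s, a) - mean over actions of A(s, a))
q = v + (a - a.mean(axis=1, keepdims=True))

# epsilon-greedy: random action with probability epsilon, greedy action otherwise
epsilon = 0.1
greedy_actions = q.argmax(axis=1)
random_actions = rng.randint(0, n_actions, size=batch_size)
is_random = rng.rand(batch_size) < epsilon
actions = np.where(is_random, random_actions, greedy_actions)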

class Model(object):
#
#

Initialize

We need `scope` because we need multiple copies of the variables, one set for the target network and one for the training network.

    def __init__(self, *, scope: str, reuse: bool, batch_size: int,
                 scaled_images: tf.Tensor = None):
#
#

If a scaled input is provided we use that; otherwise we process the observation from the game.

        if scaled_images is None:
#

observations input (B, 84, 84, 4)

            self.obs = tf.placeholder(shape=(batch_size, 84, 84, 4),
                                      name="obs",
                                      dtype=tf.uint8)
            obs_float = tf.to_float(self.obs, name="obs_float")
#

scale image values to [0, 1] from [0, 255]

            self.scaled_images = tf.cast(obs_float, tf.float32) / 255.
        else:
            self.scaled_images = scaled_images
#

exploration, $\epsilon$, the probability of making a random action

        self.exploration_fraction = tf.placeholder(shape=[],
                                                   name="epsilon",
                                                   dtype=tf.float32)

        with tf.variable_scope(scope, reuse=reuse):
#

flattened output of the convolution network

            with tf.variable_scope("convolution_network"):
                self.h = Model._cnn(self.scaled_images)
#

$A(s,a)$

            with tf.variable_scope("action_value"):
                self.action_score = Model._create_action_score(self.h, 4)
#

$V(s)$

            with tf.variable_scope("state_value"):
                self.state_score = Model._create_state_score(self.h)
#

all trainable variables in this scope. I previously didn't specify the scope, and it took all the trainable variables in the graph.

            self.params = tf.trainable_variables(scope=scope)
#

$Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$

        action_score_mean = tf.reduce_mean(self.action_score, axis=1)
        action_score_centered = self.action_score - tf.expand_dims(action_score_mean, axis=1)
        self.q = self.state_score + action_score_centered
#

greedy action

        greedy_action = tf.argmax(self.q, axis=1)
#

random action

        random_action = tf.random_uniform([batch_size], minval=0, maxval=4, dtype=tf.int64)
#

choose random action with probability $\epsilon$

        random_uniform = tf.random_uniform([batch_size],
                                           minval=0,
                                           maxval=1,
                                           dtype=tf.float32)
        is_choose_random = random_uniform < self.exploration_fraction
#

$\epsilon$-greedy action

        self.action = tf.where(is_choose_random, random_action, greedy_action)
#

Convolutional Neural Network

    @staticmethod
    def _cnn(scaled_images: tf.Tensor):
#
#

three convolution layers

        h1 = tf.layers.conv2d(scaled_images,
                              name="conv1",
                              filters=32,
                              kernel_size=8,
                              kernel_initializer=Orthogonal(scale=np.sqrt(2)),
                              strides=4,
                              padding="valid",
                              activation=tf.nn.relu)

        h2 = tf.layers.conv2d(h1,
                              name="conv2",
                              filters=64,
                              kernel_size=4,
                              kernel_initializer=Orthogonal(scale=np.sqrt(2)),
                              strides=2,
                              padding="valid",
                              activation=tf.nn.relu)

        h3 = tf.layers.conv2d(h2,
                              name="conv3",
                              filters=64,
                              kernel_size=3,
                              kernel_initializer=Orthogonal(scale=np.sqrt(2)),
                              strides=1,
                              padding="valid",
                              activation=tf.nn.relu)
#

flatten the output of the convolution network

        nh = np.prod([v.value for v in h3.get_shape()[1:]])
        flat = tf.reshape(h3, [-1, nh])

        return flat
#

$A(s,a)$ head

    @staticmethod
    def _create_action_score(flat: tf.Tensor, n: int) -> tf.Tensor:
#

fully connected layer

        h = tf.layers.dense(flat, 256,
                            activation=tf.nn.relu,
                            kernel_initializer=Orthogonal(scale=np.sqrt(2)),
                            name="hidden")

        return tf.layers.dense(h, n,
                               activation=None,
                               kernel_initializer=Orthogonal(scale=0.01),
                               name="scores")
#

$V(s)$ head

    @staticmethod
    def _create_state_score(flat: tf.Tensor) -> tf.Tensor:
#

fully connected layer

        h = tf.layers.dense(flat, 256,
                            activation=tf.nn.relu,
                            kernel_initializer=Orthogonal(scale=np.sqrt(2)),
                            name="hidden")

        value = tf.layers.dense(h, 1,
                                activation=None,
                                kernel_initializer=Orthogonal(),
                                name="score")
        return value
#

Evaluate $Q(s, a)$ for all actions at state obs

    def evaluate(self, session: tf.Session, obs: np.ndarray) -> np.ndarray:
#
        return session.run(self.q,
                           feed_dict={self.obs: obs})
#

Sample an $\epsilon$-greedy action for obs

    def sample(self, session: tf.Session, obs: np.ndarray, exploration_fraction: float) -> np.ndarray:
#
        return session.run(self.action,
                           feed_dict={self.obs: obs, self.exploration_fraction: exploration_fraction})
#

Trainer

We want to find the optimal action-value function.

Target network 🎯

In order to improve stability we use experience replay, which randomly samples from previous experience $U(D)$. We also use a Q network with a separate set of parameters $\theta_i^{-}$ to calculate the target; $\theta_i^{-}$ is updated periodically. This is according to the paper by DeepMind.

So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)} \Bigg[ \bigg(r + \gamma \max_{a'} Q(s', a'; \theta_i^{-}) - Q(s, a; \theta_i)\bigg)^2 \Bigg]$$

Double $Q$-Learning

The max operator in the above calculation uses the same network for both selecting the best action and evaluating its value. That is,

$$\max_{a'} Q(s', a'; \theta_i^{-}) = Q\Big(s', \mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i^{-}); \theta_i^{-}\Big)$$

This can lead to overestimation of the $Q$ values. We use double Q-learning instead, where the $\operatorname{argmax}$ is taken from $\theta_i$ and the value is taken from $\theta_i^{-}$.

And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)} \Bigg[ \bigg(r + \gamma Q\Big(s', \mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i); \theta_i^{-}\Big) - Q(s, a; \theta_i)\bigg)^2 \Bigg]$$
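To make this concrete, here is a small numpy sketch of how the double Q-learning target is formed (illustration only; the actual TensorFlow version is in the `Trainer` class below):

import numpy as np

rng = np.random.RandomState(0)
batch_size, n_actions, gamma = 5, 4, 0.99

q_next_online = rng.randn(batch_size, n_actions)  # Q(s', ., theta_i)
q_next_target = rng.randn(batch_size, n_actions)  # Q(s', ., theta_i^-)
reward = rng.randn(batch_size)
done = rng.rand(batch_size) < 0.1                 # whether the episode ended

# argmax over the online network, value taken from the target network
best_next_action = q_next_online.argmax(axis=1)
best_next_q = q_next_target[np.arange(batch_size), best_next_action]

# mask out terminal transitions and form the target
target = reward + gamma * (1.0 - done.astype(np.float32)) * best_next_q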

class Trainer(object):
#
#

Initialize

`double_q_model` has the same weights as `sample_model` (and `model`), but is fed the next observations

    def __init__(self, gamma: float, model: Model, target_model: Model, double_q_model: Model):
#
#

learning rate

        self.learning_rate = tf.placeholder(dtype=tf.float32, shape=[], name="learning_rate")
#

model for $Q(s, a; \theta_i)$

        self.model = model
#

model for $Q(s, a; \theta_i^{-})$

        self.target_model = target_model
#

model for $Q(s', a; \theta_i)$, used for double Q-learning. We need a separate copy in the TensorFlow graph because it is fed the next observations, but the parameters are the same as `model`

        self.double_q_model = double_q_model
#

we are treating observations as state $s$

        self.sampled_obs = self.model.obs
#

next state, $s'$

        self.sampled_next_obs = self.target_model.obs
#

sampled action $a$

        self.sampled_action = tf.placeholder(dtype=tf.int32, shape=[None],
                                             name="sampled_action")
#

sampled reward $r$

        self.sampled_reward = tf.placeholder(dtype=tf.float32, shape=[None],
                                             name="sampled_reward")
#

whether the game ended

        self.sampled_done = tf.placeholder(dtype=tf.float32, shape=[None],
                                           name="sampled_done")
#

weights of the samples

        self.sample_weights = tf.placeholder(dtype=tf.float32, shape=[None],
                                             name="sample_weights")
#

$Q(s, a; \theta_i)$

        self.q = tf.reduce_sum(self.model.q * tf.one_hot(self.sampled_action, 4),
                               axis=1)
#

$\mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i)$

        best_next_action = tf.argmax(double_q_model.q,
                                     axis=1)
#

$Q\Big(s', \mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i); \theta_i^{-}\Big)$

        best_next_q = tf.reduce_sum(self.target_model.q * tf.one_hot(best_next_action, 4),
                                    axis=1)
#

mask out if game ended

        best_next_q_masked = (1. - self.sampled_done) * best_next_q
#

$r + \gamma Q\Big(s', \mathop{\operatorname{argmax}}_{a'} Q(s', a'; \theta_i); \theta_i^{-}\Big)$

        q_update = self.sampled_reward + gamma * best_next_q_masked
#

histograms for debugging

        tf.summary.histogram('q', self.q)
        tf.summary.histogram('q_update', q_update)
#

Temporal Difference $\delta$

        self.td_error = self.q - tf.stop_gradient(q_update)
#

take Huber loss instead of mean squared error

        error = huber_loss(self.td_error)
#

error weighted by the importance-sampling weights of the samples

        weighted_error = tf.reduce_mean(self.sample_weights * error)
#

apply clipped gradients

        adam = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        grads = adam.compute_gradients(weighted_error, var_list=self.model.params)
        for i, (grad, var) in enumerate(grads):
            if grad is not None:
                grads[i] = (tf.clip_by_norm(grad, 10), var)
        self.train_op = adam.apply_gradients(grads, name="apply_gradients")
#

update $\theta_i^{-}$ to $\theta_i$ periodically

        update_target_expr = []
        for var, var_target in zip(sorted(self.model.params, key=lambda v: v.name),
                                   sorted(self.target_model.params, key=lambda v: v.name)):
            update_target_expr.append(var_target.assign(var))
        self.update_target_op = tf.group(*update_target_expr)
#

histogram summaries

        self.summaries = tf.summary.merge_all()
#

Train model with samples

    def train(self, session: tf.Session, samples: Dict[str, np.ndarray], learning_rate: float):
#
        feed_dict = {self.sampled_obs: samples['obs'],
                     self.sampled_next_obs: samples['next_obs'],
                     self.sampled_action: samples['action'],
                     self.sampled_reward: samples['reward'],
                     self.sampled_done: samples['done'],
                     self.sample_weights: samples['weights'],
                     self.learning_rate: learning_rate}

        evals = [self.q,
                 self.td_error,
                 self.train_op]
#

return all results except train_op

        return session.run(evals, feed_dict=feed_dict)[:-1]
#

Update $\theta_i^{-}$

    def update_target(self, session: tf.Session):
#
        session.run(self.update_target_op, feed_dict={})
#

Generate summary for TensorBoard

    def summarize(self, session: tf.Session, samples):
#
        feed_dict = {self.sampled_obs: samples['obs'],
                     self.sampled_next_obs: samples['next_obs'],
                     self.sampled_action: samples['action'],
                     self.sampled_reward: samples['reward'],
                     self.sampled_done: samples['done'],
                     self.sample_weights: samples['weights']}

        return session.run(self.summaries, feed_dict=feed_dict)
#

Buffer for Prioritized Experience Replay

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the Temporal Difference error.

We sample transition $i$ with probability $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$, where $\alpha$ is a hyper-parameter that determines how much prioritization is used; $\alpha = 0$ corresponds to the uniform case.

We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where $\delta_i$ is the temporal difference for transition $i$.

We correct the bias introduced by prioritized replay with importance-sampling (IS) weights $w_i = \Big(\frac{1}{N} \frac{1}{P(i)}\Big)^\beta$, which fully compensate for the non-uniform sampling probabilities when $\beta = 1$. We normalize the weights by $\frac{1}{\max_i w_i}$ for stability. Unbiased updates matter most near convergence at the end of training, so we increase $\beta$ towards the end of training.
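For a concrete example of these formulas (a standalone numpy sketch, separate from the buffer implementation that follows):

import numpy as np

td_errors = np.array([0.5, 2.0, 0.1, 1.0])
alpha, beta, epsilon = 0.6, 0.4, 1e-6

# proportional prioritization, p_i = |delta_i| + epsilon
p = np.abs(td_errors) + epsilon

# sampling probabilities, P(i) = p_i^alpha / sum_k p_k^alpha
probs = p ** alpha / np.sum(p ** alpha)

# importance-sampling weights, w_i = (1/N * 1/P(i))^beta, normalized by max_i w_i
n = len(p)
weights = (n * probs) ** (-beta)
weights = weights / weights.max()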

Binary Segment Trees

We use binary segment trees to efficiently calculate $\sum_{k=1}^{i} p_k^\alpha$, the cumulative sum of priorities, which is needed for sampling. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. A min-heap could be used for the minimum instead.

This is how a binary segment tree works for the sum; it is similar for the minimum. Let $x_i$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j^{\mathop{th}}$ node (indexed from $0$) of the $i^{\mathop{th}}$ row in the binary tree. That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.

The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$ will have values of $x$. Every node keeps the sum of the two child nodes. So the root node keeps the sum of the entire array of values. The two children of the root node keep the sum of the first half of the array and the sum of the second half of the array, and so on.

Row $i$ has $2^{i-1}$ nodes, which is one more than the total number of nodes in all the rows above it. So we can use a single array $a$ to store the tree, leaving $a_0$ unused, where

$$b_{i,j} \rightarrow a_{2^{i-1} + j}$$

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$. That is,

$$a_i = a_{2i} + a_{2i + 1}$$

This way of maintaining binary trees is very easy to program. Note that we are indexing from 1.
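A tiny standalone example of this array layout with $N = 4$ leaves (illustration only; the `ReplayBuffer` below uses the same indexing for its sum and minimum trees):

# illustrative sketch of the 1-indexed array layout of a sum segment tree
capacity = 4                   # N, a power of 2
tree = [0.0] * (2 * capacity)  # a_0 is unused; leaves live at capacity .. 2 * capacity - 1

def set_leaf(i, value):
    # set leaf i (0-indexed) and update all of its ancestors
    idx = i + capacity
    tree[idx] = value
    while idx >= 2:
        idx //= 2
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]

for i, x in enumerate([1.0, 2.0, 3.0, 4.0]):
    set_leaf(i, x)

assert tree[1] == 10.0                    # the root holds the sum of all leaves
assert tree[2] == 3.0 and tree[3] == 7.0  # each child of the root holds the sum of one half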

class ReplayBuffer(object):
#
#

Initialize

    def __init__(self, capacity, alpha):
#

we use a power of 2 for capacity because it keeps the binary segment tree indexing simple (and easier to debug)

        self.capacity = capacity
#

index of the next slot; we start overwriting from the beginning of the queue once it reaches capacity

        self.next_idx = 0
#

$\alpha$

        self.alpha = alpha
#

maintain segment binary trees to take sum and find minimum over a range

        self.priority_sum = [0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]
#

current max priority, $p$, to be assigned to new transitions

        self.max_priority = 1.
#

arrays for buffer

        self.data = {
            'obs': np.zeros(shape=(capacity, 84, 84, 4), dtype=np.uint8),
            'action': np.zeros(shape=capacity, dtype=np.int32),
            'reward': np.zeros(shape=capacity, dtype=np.float32),
            'next_obs': np.zeros(shape=(capacity, 84, 84, 4), dtype=np.uint8),
            'done': np.zeros(shape=capacity, dtype=np.bool)
        }
#

size of the buffer

        self.size = 0
#

Add sample to queue

    def add(self, obs, action, reward, next_obs, done):
#
        idx = self.next_idx
#

store in the queue

        self.data['obs'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_obs'][idx] = next_obs
        self.data['done'][idx] = done
#

increment head of the queue and calculate the size

        self.next_idx = (idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)
#

$p_i^\alpha$, new samples get max_priority

        priority_alpha = self.max_priority ** self.alpha
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)
#

Set priority in binary segment tree for minimum

    def _set_priority_min(self, idx, priority_alpha):
#
#

leaf of the binary tree

        idx += self.capacity
        self.priority_min[idx] = priority_alpha
#

update tree, by traversing along ancestors

        while idx >= 2:
            idx //= 2
            self.priority_min[idx] = min(self.priority_min[2 * idx],
                                         self.priority_min[2 * idx + 1])
#

Set priority in binary segment tree for sum

    def _set_priority_sum(self, idx, priority):
#
#

leaf of the binary tree

        idx += self.capacity
        self.priority_sum[idx] = priority
#

update tree, by traversing along ancestors

        while idx >= 2:
            idx //= 2
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]
#

$\sum_k p_k^\alpha$

    def _sum(self):
#
        return self.priority_sum[1]
#

$\min_k p_k^\alpha$

    def _min(self):
#
        return self.priority_min[1]
#

Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$

    def find_prefix_sum_idx(self, prefix_sum):
#
#

start from the root

        idx = 1
        while idx < self.capacity:
#

if the sum of the left branch is higher than the required sum

            if self.priority_sum[idx * 2] > prefix_sum:
#

go to the left branch of the tree

                idx = 2 * idx
            else:
#

otherwise go to the right branch and subtract the sum of the left branch from the required sum

                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        return idx - self.capacity
#

Sample from buffer

    def sample(self, batch_size, beta):
#
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32)
        }
#

get samples

        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx
#

$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$

        prob_min = self._min() / self._sum()
#

$\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$

        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
#

$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$

            prob = self.priority_sum[idx + self.capacity] / self._sum()
#

$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$

            weight = (prob * self.size) ** (-beta)
#

normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term

            samples['weights'][i] = weight / max_weight
#

get samples data

        for k, v in self.data.items():
            samples[k] = v[samples['indexes']]

        return samples
#

Update priorities

    def update_priorities(self, indexes, priorities):
#
        for idx, priority in zip(indexes, priorities):
            self.max_priority = max(self.max_priority, priority)
#

$p_i^\alpha$

            priority_alpha = priority ** self.alpha
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)
#

Is the buffer full

We only start sampling after the buffer is full.

    def is_full(self):
#
        return self.capacity == self.size
#

Main class
This class runs the training loop. It initializes TensorFlow, handles logging and monitoring, and runs the game workers as separate processes.
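`worker.py` is not shown on this page. As a sketch of the `Worker` protocol that `Main` assumes (the `reset`/`step`/`close` messages and the `child` pipe follow from how workers are used below; the dummy game and everything else here is my assumption, not the author's implementation):

import multiprocessing
import multiprocessing.connection
import numpy as np

class _DummyGame:
    # stand-in for the Atari Breakout wrapper in worker.py (assumed)
    def __init__(self, seed: int):
        self.rng = np.random.RandomState(seed)

    def reset(self):
        return np.zeros((84, 84, 4), dtype=np.uint8)

    def step(self, action):
        next_obs = self.rng.randint(0, 256, (84, 84, 4), dtype=np.uint8)
        reward, done, info = 0.0, False, {}
        return next_obs, reward, done, info

def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    # the worker process answers "reset", "step" and "close" messages on the pipe
    game = _DummyGame(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break

class Worker:
    def __init__(self, seed: int):
        # `child` is the connection that Main uses to talk to the worker process
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()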

class Main(object):
#
#

Initialize

    def __init__(self):
#
#

Configurations

#

$\gamma$

        self.GAMMA = 0.99
#

learning rate

        self.LEARNING_RATE = 1e-4
#

total number of time steps

        self.TOTAL_TIME_STEPS = int(40e6)
#

number of workers

        self.WORKERS = 8
#

steps sampled on each update

        self.SAMPLE_STEPS = 4
#

number of samples collected per update

        self.SAMPLES_PER_UPDATE = self.WORKERS * self.SAMPLE_STEPS
#

number of training iterations

        self.TRAIN_ITERS = max(1, self.SAMPLES_PER_UPDATE // 4)
#

number of updates

        self.UPDATES = self.TOTAL_TIME_STEPS // self.SAMPLES_PER_UPDATE
#

size of mini batch for training

        self.MINI_BATCH_SIZE = 32
#

exploration as a function of time step

        self.EXPLORATION = PiecewiseSchedule(
            [
                (0, 1.0),
                (1e6, 0.1),
                (self.TOTAL_TIME_STEPS / 2, 0.01)
            ], outside_value=0.01)
#

update the target network every 10,000 time steps; the constant is expressed as a number of updates, since each update samples `SAMPLES_PER_UPDATE` (= 4 * `TRAIN_ITERS`) time steps

        self.UPDATE_TARGET_NETWORK = 10000 // (4 * self.TRAIN_ITERS)
#

size of the replay buffer

        self.REPLAY_BUFFER_SIZE = 2 ** 14
#

$\alpha$ for replay buffer

        self.PRIORITIZED_REPLAY_ALPHA = 0.6
#

$\beta$ for replay buffer as a function of time steps

        self.PRIORITIZED_REPLAY_BETA = PiecewiseSchedule(
            [
                (0, 0.4),
                (self.TOTAL_TIME_STEPS, 1)
            ], outside_value=1)
#

initialize TensorFlow session

        Main._init_tf_session()
#

create game

        self.workers = [Worker(47 + i) for i in range(self.WORKERS)]
#

replay buffer

        self.replay_buffer = ReplayBuffer(self.REPLAY_BUFFER_SIZE, self.PRIORITIZED_REPLAY_ALPHA)
#

episode information for monitoring

        self.episode_reward = [0 for _ in range(self.WORKERS)]
        self.episode_length = [0 for _ in range(self.WORKERS)]
        self.episode_info = deque(maxlen=100)
        self.best_episode = {
            'reward': 0,
            'obs': None
        }
#

model for sampling, $Q(s, a; \theta_i)$

        self.sample_model = Model(scope="q_function",
                                  reuse=False,
                                  batch_size=self.WORKERS)
#

model for target, $Q(s, a; \theta_i^{-})$

        self.target_model = Model(scope="target_q_function",
                                  reuse=False, batch_size=self.MINI_BATCH_SIZE)
#

model for training with same parameters, $Q(s, a; \theta_i)$

        self.train_model = Model(scope="q_function",
                                 reuse=True,
                                 batch_size=self.MINI_BATCH_SIZE)
#

model for double Q-learning with same parameters, $Q(s, a; \theta_i)$

        self.double_q_model = Model(scope="q_function",
                                    reuse=True,
                                    batch_size=self.MINI_BATCH_SIZE,
                                    scaled_images=self.target_model.scaled_images)
#

trainer

        self.trainer = Trainer(self.GAMMA,
                               self.train_model,
                               self.target_model,
                               self.double_q_model)
#

last observation for each worker

        self.obs = np.zeros((self.WORKERS, 84, 84, 4), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
#

create TensorFlow session

        self.session: tf.Session = tf.get_default_session()
#

initialize TensorFlow variables

        init_op = tf.global_variables_initializer()
        self.session.run(init_op)
#

Sample data from sample_model

    def sample(self, exploration):
#
#

sample `SAMPLE_STEPS` steps from each worker

        for t in range(self.SAMPLE_STEPS):
#

sample actions

            actions = self.sample_model.sample(self.session, self.obs, exploration)
#

run sampled actions on each worker

            for w, worker in enumerate(self.workers):
                worker.child.send(("step", actions[w]))
#

collect information from each worker

            for w, worker in enumerate(self.workers):
#

get results after executing the actions

                next_obs, reward, done, info = worker.child.recv()
                next_obs = np.asarray(next_obs, dtype=np.uint8)
#

add transition to replay buffer

                self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
#

update episode information

                self.episode_length[w] += 1
                self.episode_reward[w] += reward
                if done:
                    if self.best_episode['reward'] < self.episode_reward[w]:
                        self.best_episode['reward'] = self.episode_reward[w]
                        self.best_episode['obs'] = self.obs[w]

                    self.episode_info.append({
                        "reward": self.episode_reward[w],
                        "length": self.episode_length[w]})
                    self.episode_reward[w] = 0
                    self.episode_length[w] = 0
#

update current observation

                self.obs[w, ...] = next_obs
#

Train the model

    def train(self, beta: float):
#
        td_errors_all = []
        q_all = []
        for _ in range(self.TRAIN_ITERS):
#

sample from priority replay buffer

            samples = self.replay_buffer.sample(self.MINI_BATCH_SIZE, beta)
#

train network

            q, td_errors = self.trainer.train(session=self.session,
                                              samples=samples,
                                              learning_rate=self.LEARNING_RATE)
            td_errors_all.append(td_errors)
            q_all.append(q)
#

$p_i = |\delta_i| + \epsilon$

            new_priorities = np.abs(td_errors) + 1e-6
#

update replay buffer

            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
#

return averages for monitoring

        return np.mean(q_all), np.std(q_all), np.mean(np.abs(td_errors_all))
#

Get histogram summaries

    def summarize(self, beta: float):
#
        samples = self.replay_buffer.sample(self.MINI_BATCH_SIZE, beta)
#

evaluate the histogram summaries on the sampled batch

        return self.trainer.summarize(session=self.session,
                                      samples=samples)
#

Run training loop

    def run_training_loop(self):
#
#

load saved model

        self._load_model()
#

summary writer for TensorBoard

        writer = self._create_summary_writer()
        histogram_writer = self._create_summary_writer_histogram()
#

copy to target network initially

        self.trainer.update_target(self.session)

        for update in range(self.UPDATES):
            time_start = time.time()
            time_step = update * self.SAMPLES_PER_UPDATE
#

$\epsilon$, exploration fraction

            exploration = self.EXPLORATION(time_step)
#

$\beta$ for priority replay

            beta = self.PRIORITIZED_REPLAY_BETA(time_step)
#

sample with current policy

            self.sample(exploration)

            if self.replay_buffer.is_full():
#

train the model

                q, q_std, td_error = self.train(beta)
#

periodically update target network

                if update % self.UPDATE_TARGET_NETWORK == 0:
                    self.trainer.update_target(self.session)
            else:
                td_error = q = q_std = 0.

            time_end = time.time()
#

frame rate

            fps = int(self.SAMPLES_PER_UPDATE / (time_end - time_start))
#

log every 10 updates

            if update % 10 == 0:
#

mean of last 100 episodes

                reward_mean, length_mean, best_obs_frame = self._get_mean_episode_info()
#

write summary info to the writer, and log to the screen

                Main._write_summary(writer, best_obs_frame, time_step, fps,
                                    float(reward_mean), float(length_mean),
                                    float(q), float(q_std), float(td_error),
                                    exploration, beta)

                if self.replay_buffer.is_full():
#

write histogram summaries

                    histogram_summary = self.summarize(beta)
                    histogram_writer.add_summary(histogram_summary, global_step=time_step)
#

save model once in a while

            if self.replay_buffer.is_full() and update % 100000 == 0:
                self._save_model()
#

Initialize TensorFlow session

    @staticmethod
    def _init_tf_session():
#
#

let TensorFlow decide where to run operations; I think it chooses the GPU for everything if you have one

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=True)
#

grow GPU memory as needed

        config.gpu_options.allow_growth = True

        tf.Session(config=config).__enter__()
#

set random seeds; however, this doesn't seem to produce identical results across runs.

One explanation would be accumulated floating point error. But floating point calculations are deterministic, even if they can be hard to predict, so that alone shouldn't explain it. More likely, certain hardware optimizations (for example, non-deterministic ordering of parallel reductions on the GPU) make the results differ between runs.

        np.random.seed(7)
        tf.set_random_seed(7)
#

Get average episode reward and episode length

    def _get_mean_episode_info(self):
#
        return (np.mean([info["reward"] for info in self.episode_info]),
                np.mean([info["length"] for info in self.episode_info]),
                self.best_episode['obs'])
#

Create summary writer

I used TensorBoard for monitoring. I made copies of the program when I was making changes, and logged each version to a different directory so that I could later compare how each version performed.

    def _create_summary_writer(self) -> tf.summary.FileWriter:
#
        log_dir = str(PurePath("log/", Path(__file__).stem))
        if tf.gfile.Exists(log_dir):
            tf.gfile.DeleteRecursively(log_dir)

        return tf.summary.FileWriter(log_dir, self.session.graph)
#

Create summary writer for histograms

    def _create_summary_writer_histogram(self) -> tf.summary.FileWriter:
#
        log_dir = str(PurePath("log/", "histograms"))
        if tf.gfile.Exists(log_dir):
            tf.gfile.DeleteRecursively(log_dir)

        return tf.summary.FileWriter(log_dir, self.session.graph)
#

Get checkpoint path

Different paths based on source code file name

    @staticmethod
    def _get_checkpoint_path() -> (str, str):
#
        checkpoint_path = PurePath("checkpoints/", Path(__file__).stem)
        model_file = checkpoint_path / 'model'
        return str(checkpoint_path), str(model_file)
#

Write summary

    @staticmethod
    def _write_summary(writer: tf.summary.FileWriter,
                       best_obs_frame: Union[np.ndarray, None],
                       time_step: int,
                       fps: int,
                       reward_mean: float,
                       length_mean: float,
                       q: float,
                       q_std: float,
                       td_error: float,
                       exploration: float,
                       beta: float):
#
        print("{:4} {:3} {:.2f} {:.3f}".format(time_step, fps, reward_mean, length_mean))

        summary = tf.Summary()
#

add an image

        if best_obs_frame is not None:
            sample_observation = best_obs_frame
            observation_png = io.BytesIO()
            pyplot.imsave(observation_png, sample_observation, format='png', cmap='gray')

            observation_png = tf.Summary.Image(encoded_image_string=observation_png.getvalue(),
                                               height=84,
                                               width=84)
            summary.value.add(tag="observation", image=observation_png)
#

add scalars

        summary.value.add(tag="fps", simple_value=fps)
        summary.value.add(tag='q', simple_value=q)
        summary.value.add(tag='q_std', simple_value=q_std)
        summary.value.add(tag='td_error', simple_value=td_error)
        summary.value.add(tag="reward_mean", simple_value=reward_mean)
        summary.value.add(tag="length_mean", simple_value=length_mean)
        summary.value.add(tag="exploration", simple_value=exploration)
        summary.value.add(tag="beta", simple_value=beta)
#

write to file

        writer.add_summary(summary, global_step=time_step)
#

Destroy

Stop the workers

    def destroy(self):
#
        for worker in self.workers:
            worker.child.send(("close", None))
#

Load model

    def _load_model(self):
#
        checkpoint_path, model_file = Main._get_checkpoint_path()
        if tf.train.latest_checkpoint(checkpoint_path) is not None:
            saver = tf.train.Saver()
            saver.restore(self.session, model_file)
            print("Loaded model")
#

Save model

    def _save_model(self):
#
        checkpoint_path, model_file = Main._get_checkpoint_path()
        os.makedirs(checkpoint_path, exist_ok=True)
        saver = tf.train.Saver()
        saver.save(self.session, model_file)
        print("Saved model")
#

Run it

if __name__ == "__main__":
    m = Main()
    m.run_training_loop()
    m.destroy()