Proximal Policy Optimization - PPO in PyTorch

This is a minimalistic implementation of Proximal Policy Optimization (PPO), clipped version, for the Atari Breakout game on OpenAI Gym. It has fewer than 250 lines of code. It runs the game environments on multiple processes to sample efficiently. Advantages are calculated using Generalized Advantage Estimation.

The code for this tutorial is available on GitHub at labml/rl_samples, and the web version of the tutorial is available on my blog.

If someone reading this has any questions or comments please find me on Twitter, @vpj.

18import multiprocessing
19import multiprocessing.connection
20from typing import Dict, List
21
22import cv2
23import gym
24import numpy as np
25import torch
26from labml import monit, tracker, logger, experiment
27from torch import nn
28from torch import optim
29from torch.distributions import Categorical
30from torch.nn import functional as F
31
# Use the first (default) CUDA device when one is available, otherwise the CPU.
# A hard-coded secondary index such as "cuda:1" fails on single-GPU machines.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Game environment

This is a wrapper for the OpenAI Gym game environment. We do a few things here:

  1. Apply the same action on four frames and get the last frame
  2. Convert observation frames to grayscale and scale them to (84, 84)
  3. Stack four frames of the last four actions
  4. Add episode information (total reward for the entire episode) for monitoring
  5. Restrict an episode to a single life (the game has 5 lives; we reset after every single life)

Observation format

The observation is a tensor of size (4, 84, 84). It consists of four frames (images of the game screen) stacked on the first axis, i.e., each channel is a frame.

class Game:
    """
    Wrapper around the OpenAI Gym ``BreakoutNoFrameskip-v4`` environment.

    * applies the same action on up to four frames and keeps the last frame,
    * converts frames to grayscale and rescales them to 84x84,
    * stacks the frames of the last four actions,
    * records episode information (total reward and length) for monitoring,
    * ends an episode whenever a single life is lost.

    Observations have shape ``(4, 84, 84)``: four frames stacked on the first
    axis, i.e. each channel is a frame.
    """

    def __init__(self, seed: int):
        # create environment
        self.env = gym.make('BreakoutNoFrameskip-v4')
        self.env.seed(seed)

        # tensor for a stack of 4 frames
        self.obs_4 = np.zeros((4, 84, 84))

        # keep track of the episode rewards
        self.rewards = []
        # and number of lives left (set properly by `reset`)
        self.lives = 0

    def step(self, action):
        """
        Execute ``action`` for up to 4 time steps.

        Returns a tuple ``(observation, reward, done, episode_info)``:

        * observation: stacked 4 frames (this frame and frames for the last 3 actions)
        * reward: total reward while the action was executed
        * done: whether the episode finished (a life was lost, or the environment
          itself reported termination)
        * episode_info: ``{"reward": ..., "length": ...}`` if the episode finished,
          otherwise ``None``
        """
        reward = 0.
        done = False

        # run for up to 4 steps
        for _ in range(4):
            # execute the action in the OpenAI Gym environment
            obs, r, done, info = self.env.step(action)
            reward += r

            # get number of lives left; losing a life ends the episode
            lives = self.env.unwrapped.ale.lives()
            if lives < self.lives:
                done = True
            # stop as soon as the episode is over so a finished environment
            # is never stepped again
            if done:
                break

        # Transform the last observation to (84, 84).
        # Only the last frame is used; no max over the last two frames is taken.
        obs = self._process_obs(obs)

        # maintain rewards for each step
        self.rewards.append(reward)

        if done:
            # if finished, set episode information and reset
            episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}
            self.reset()
        else:
            episode_info = None
            # push it to the stack of 4 frames
            self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)
            self.obs_4[-1] = obs

        return self.obs_4, reward, done, episode_info

    def reset(self):
        """
        Reset the environment, clear the episode rewards and fill the 4-frame
        stack with the first observation. Returns the stacked observation.
        """
        # reset OpenAI Gym environment
        obs = self.env.reset()

        # reset caches: every slot of the frame stack gets the first frame
        obs = self._process_obs(obs)
        self.obs_4[:] = obs
        self.rewards = []

        self.lives = self.env.unwrapped.ale.lives()

        return self.obs_4

    @staticmethod
    def _process_obs(obs):
        """Convert a game frame to grayscale and rescale it to 84x84."""
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
        return obs

Worker Process

Each worker process runs this function.

def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """
    Body of a worker process: build a ``Game`` and serve commands from ``remote``.

    Each message is a ``(cmd, data)`` pair:

    * ``"step"``  -- reply with ``game.step(data)``
    * ``"reset"`` -- reply with ``game.reset()``
    * ``"close"`` -- close the connection and return
    * anything else raises ``NotImplementedError``
    """
    game = Game(seed)

    while True:
        cmd, data = remote.recv()
        if cmd == "close":
            remote.close()
            break
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        else:
            raise NotImplementedError

Creates a new worker and runs it in a separate process.

class Worker:
    """
    Launch ``worker_process`` with the given ``seed`` in a new process.

    ``child`` is the connection the main process uses to talk to that worker;
    the other end of the pipe is handed to the worker process.
    """

    def __init__(self, seed):
        ours, theirs = multiprocessing.Pipe()
        self.child = ours
        self.process = multiprocessing.Process(target=worker_process, args=(theirs, seed))
        self.process.start()

Model

class Model(nn.Module):
    """
    Actor-critic network: a shared convolutional trunk over stacked
    ``(4, 84, 84)`` observations, followed by a policy head (action logits)
    and a value head.
    """

    def __init__(self, n_actions: int = 4):
        """
        :param n_actions: number of discrete actions, i.e. the size of the
            policy output; defaults to 4, the value used for Breakout here.
        """
        super().__init__()

        # The first convolution layer takes an 84x84 frame and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

        # The second convolution layer takes a 20x20 frame and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

        # The third convolution layer takes a 9x9 frame and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)

        # A fully connected layer takes the flattened frame from the third
        # convolution layer and outputs 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        # A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=n_actions)

        # A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)

    def forward(self, obs: torch.Tensor):
        """
        :param obs: batch of observations; the reshape below requires the
            convolutions to yield 7x7 maps, i.e. inputs of shape ``(batch, 4, 84, 84)``
        :return: ``(pi, value)`` -- a ``Categorical`` distribution over actions and
            a 1-D tensor with one state value per batch element
        """
        h = F.relu(self.conv1(obs))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = F.relu(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """
    Turn an observation array into a ``float32`` tensor on ``device``, dividing
    by 255 so that pixel values (presumably in ``[0, 255]``) end up in ``[0, 1]``.
    """
    as_float = torch.tensor(obs, dtype=torch.float32, device=device)
    return as_float / 255.
class Main:
    """
    PPO trainer: holds the configuration, the worker processes, the model and
    the optimizer, and runs the sample/train loop.
    """

    def __init__(self):
        # #### Configurations

        # $\gamma$ and $\lambda$ for advantage calculation
        self.gamma = 0.99
        self.lamda = 0.95
        # number of updates
        self.updates = 10000
        # number of epochs to train the model with sampled data
        self.epochs = 4
        # number of worker processes
        self.n_workers = 8
        # number of steps to run on each process for a single update
        self.worker_steps = 128
        # number of mini batches
        self.n_mini_batch = 4
        # total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
        # size of a mini batch
        self.mini_batch_size = self.batch_size // self.n_mini_batch
        assert self.batch_size % self.n_mini_batch == 0

        # #### Initialize

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # model for sampling
        self.model = Model().to(device)

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

    def sample(self) -> Dict[str, torch.Tensor]:
        """
        Sample data with the current policy.

        Runs ``worker_steps`` steps on every worker and returns a dictionary of
        flattened tensors with keys ``obs``, ``actions``, ``values``, ``log_pis``
        and ``advantages``, each with ``n_workers * worker_steps`` rows.
        """
        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the same dtype
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)

        # sample `worker_steps` from each worker
        for t in range(self.worker_steps):
            with torch.no_grad():
                # `self.obs` keeps track of the last observation from each worker,
                # which is the input for the model to sample the next action
                obs[:, t] = self.obs
                # sample actions from $\pi_{\theta_{OLD}}$ for each worker;
                # this returns arrays of size `n_workers`
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

            # run sampled actions on each worker
            for w, worker in enumerate(self.workers):
                worker.child.send(("step", actions[w, t]))

            for w, worker in enumerate(self.workers):
                # get results after executing the actions
                self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

                # collect episode info, which is available if an episode finished;
                # this includes the total reward and length of the episode -
                # look at `Game` to see how it works.
                if info:
                    tracker.add('reward', info['reward'])
                    tracker.add('length', info['length'])

        # calculate advantages
        advantages = self._calc_advantages(done, rewards, values)
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values,
            'log_pis': log_pis,
            'advantages': advantages
        }

        # samples are currently in a `[workers, time]` table; flatten them
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat

    def _calc_advantages(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
        r"""
        Calculate advantages with Generalized Advantage Estimation (GAE).

        The $k$-step advantage estimate is
        $\hat{A_t^{(k)}} = r_t + \gamma r_{t+1} + \dots + \gamma^{k-1} r_{t+k-1} + \gamma^k V(s_{t+k}) - V(s_t)$.

        $\hat{A_t^{(1)}}$ is high bias, low variance, whilst $\hat{A_t^{(\infty)}}$ is
        unbiased, high variance.

        We take a weighted average of $\hat{A_t^{(k)}}$ to balance bias and variance.
        This is called Generalized Advantage Estimation. We set $w_k = \lambda^{k-1}$;
        this gives a clean calculation for $\hat{A_t}$:

        $\hat{A_t} = (1 - \lambda) \sum_{k=1}^\infty \lambda^{k-1} \hat{A_t^{(k)}}
                   = \delta_t + \gamma \lambda \hat{A_{t+1}}$,
        where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$.

        ``done[:, t]`` cuts the recursion at episode boundaries.
        """
        # advantages table
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0

        # $V(s_{t+1})$ for the step after the last sampled one; it is only a
        # bootstrap target, so no autograd graph is needed
        with torch.no_grad():
            _, last_value = self.model(obs_to_torch(self.obs))
        last_value = last_value.cpu().numpy()

        for t in reversed(range(self.worker_steps)):
            # mask if episode completed after step $t$
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
            # $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

            # $\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$
            last_advantage = delta + self.gamma * self.lamda * last_advantage

            # Note that we are collecting in reverse order.
            # My initial code was appending to a list and I forgot to reverse it later.
            # It took me around 4 to 5 hours to find the bug.
            # The performance of the model was improving slightly during initial runs,
            # probably because the samples are similar.
            advantages[:, t] = last_advantage

            last_value = values[:, t]

        return advantages

    def train(self, samples: Dict[str, torch.Tensor], learning_rate: float, clip_range: float):
        """
        Train the model on the sampled data for ``epochs`` epochs of shuffled
        mini batches.

        It learns faster with a higher number of epochs, but becomes a little
        unstable; that is, the average episode reward does not monotonically
        increase over time. Reducing the clipping range might solve this.
        """
        for _ in range(self.epochs):
            # shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

            # for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
                # get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

                # train
                loss = self._calc_loss(clip_range=clip_range,
                                       samples=mini_batch)

                # set the learning rate, compute and clip gradients, and update
                for pg in self.optimizer.param_groups:
                    pg['lr'] = learning_rate
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                self.optimizer.step()

    @staticmethod
    def _normalize(adv: torch.Tensor):
        """Normalize the advantage function to zero mean and unit standard deviation."""
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor], clip_range: float) -> torch.Tensor:
        r"""
        PPO loss.

        We want to maximize the policy reward
        $\max_\theta J(\pi_\theta) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t \Biggr]$,
        where $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled
        from the policy, and $\gamma$ is the discount factor in $[0, 1]$.

        So,
        $J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t A^{\pi_{\theta_{OLD}}}(s_t, a_t)\Biggr]$.

        Define the discounted-future state distribution
        $d^\pi(s) = (1 - \gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s | \pi)$.

        Then,
        $J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta}\Bigl[A^{\pi_{\theta_{OLD}}}(s, a)\Bigr]$.

        Importance sampling $a$ from $\pi_{\theta_{OLD}}$,
        $J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_\theta}, a \sim \pi_{\theta_{OLD}}}\Biggl[\frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{\theta_{OLD}}}(s, a)\Biggr]$.

        Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar.
        The error we introduce to $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$ by this
        assumption is bounded by the KL divergence between $\pi_\theta$ and
        $\pi_{\theta_{OLD}}$. Constrained Policy Optimization shows the proof of
        this. I haven't read it.
        """
        # $R_t$ returns sampled from $\pi_{\theta_{OLD}}$: $R_t = V(s_t) + \hat{A_t}$
        sampled_return = samples['values'] + samples['advantages']

        # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$, where
        # $\hat{A_t}$ is advantages sampled from $\pi_{\theta_{OLD}}$.
        # Refer to `_calc_advantages` for the calculation of $\hat{A_t}$.
        sampled_normalized_advantage = self._normalize(samples['advantages'])

        # Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$
        # and $V^{\pi_\theta}(s_t)$; we are treating observations as state
        pi, value = self.model(samples['obs'])

        # #### Policy

        # $\log \pi_\theta (a_t|s_t)$, where $a_t$ are actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])

        # ratio $r_t(\theta) = \frac{\pi_\theta (a_t|s_t)}{\pi_{\theta_{OLD}} (a_t|s_t)}$;
        # this is different from rewards $r_t$.
        ratio = torch.exp(log_pi - samples['log_pis'])

        # $\mathcal{L}^{CLIP}(\theta) = \mathbb{E}\Bigl[\min\bigl(r_t(\theta) \bar{A_t},
        #     \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \bar{A_t}\bigr)\Bigr]$
        #
        # The ratio is clipped to be close to 1. We take the minimum so that the
        # gradient will only pull $\pi_\theta$ towards $\pi_{\theta_{OLD}}$ if the
        # ratio is not between $1 - \epsilon$ and $1 + \epsilon$. This keeps the KL
        # divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$ constrained.
        # Large deviation can cause performance collapse, where the policy
        # performance drops and doesn't recover because we are sampling from a
        # bad policy.
        #
        # Using the normalized advantage
        # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$
        # introduces a bias to the policy gradient estimator, but it reduces
        # variance a lot.
        clipped_ratio = ratio.clamp(min=1.0 - clip_range,
                                    max=1.0 + clip_range)
        policy_reward = torch.min(ratio * sampled_normalized_advantage,
                                  clipped_ratio * sampled_normalized_advantage)
        policy_reward = policy_reward.mean()

        # #### Entropy Bonus

        # $\mathcal{L}^{EB}(\theta) = \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

        # #### Value

        # $V^{\pi_\theta}_{CLIP}(s_t) = V^{\pi_{\theta_{OLD}}}(s_t)
        #     + \mathrm{clip}\bigl(V^{\pi_\theta}(s_t) - V^{\pi_{\theta_{OLD}}}(s_t), -\epsilon, +\epsilon\bigr)$
        # $\mathcal{L}^{VF}(\theta) = \frac{1}{2} \mathbb{E}\Bigl[\max\bigl((V^{\pi_\theta}(s_t) - R_t)^2,
        #     (V^{\pi_\theta}_{CLIP}(s_t) - R_t)^2\bigr)\Bigr]$
        #
        # Clipping makes sure the value function $V_\theta$ doesn't deviate
        # significantly from $V_{\theta_{OLD}}$.
        clipped_value = samples['values'] + (value - samples['values']).clamp(min=-clip_range,
                                                                              max=clip_range)
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        vf_loss = 0.5 * vf_loss.mean()

        # $\mathcal{L}^{CLIP+VF+EB} (\theta) =
        #     \mathcal{L}^{CLIP} (\theta) - c_1 \mathcal{L}^{VF} (\theta) + c_2 \mathcal{L}^{EB}(\theta)$,
        # here with $c_1 = 0.5$ and $c_2 = 0.01$.
        # We want to maximize $\mathcal{L}^{CLIP+VF+EB}(\theta)$, so we take the
        # negative of it as the loss.
        loss = -(policy_reward - 0.5 * vf_loss + 0.01 * entropy_bonus)

        # for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
        clip_fraction = (abs((ratio - 1.0)) > clip_range).to(torch.float).mean()

        tracker.add({'policy_reward': policy_reward,
                     'vf_loss': vf_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': clip_fraction})

        return loss

    def run_training_loop(self):
        """Run the training loop: sample with the current policy, then train on it, for ``updates`` updates."""
        # last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            progress = update / self.updates

            # decreasing `learning_rate` and `clip_range` $\epsilon$
            learning_rate = 2.5e-4 * (1 - progress)
            clip_range = 0.1 * (1 - progress)

            # sample with current policy
            samples = self.sample()

            # train the model
            self.train(samples, learning_rate, clip_range)

            # write summary info to the writer, and log to the screen
            tracker.save()
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        Stop the workers: ask each one to close, then wait for its process to
        exit and close this side of its pipe so no processes or handles leak.
        """
        for worker in self.workers:
            worker.child.send(("close", None))
        for worker in self.workers:
            worker.process.join()
            worker.child.close()

Run it

if __name__ == "__main__":
    # Create the experiment, build the trainer (which starts the worker
    # processes), then start tracking and train.
    experiment.create()
    m = Main()
    experiment.start()
    try:
        m.run_training_loop()
    finally:
        # stop the workers even if training fails or is interrupted
        m.destroy()