Breaking down State-of-the-Art PPO Implementations in JAX

Source: towardsdatascience.com

All the tricks and details you wish you knew about Proximal Policy Optimization


Since its publication in a 2017 paper by OpenAI, Proximal Policy Optimization (PPO) has been widely regarded as one of the state-of-the-art algorithms in Reinforcement Learning. Indeed, PPO has demonstrated remarkable performance across various tasks, from attaining superhuman level in Dota 2 to solving a Rubik’s cube with a single robotic hand, while maintaining three main advantages: simplicity, stability, and sample efficiency.

However, implementing RL algorithms from scratch is notoriously difficult and error-prone, given the numerous error sources and implementation details to be aware of.

In this article, we’ll focus on breaking down the clever tricks and programming concepts used in a popular implementation of PPO in JAX. Specifically, we’ll focus on the implementation featured in the PureJaxRL library, developed by Chris Lu.

Disclaimer: Rather than diving too deep into theory, this article covers the practical implementation details and (numerous) tricks used in popular versions of PPO. Should you require any reminders about PPO’s theory, please refer to the “references” section at the end of this article. Additionally, all the code (minus the added comments) is copied directly from PureJaxRL for pedagogical purposes.

GitHub - luchris429/purejaxrl: Really Fast End-to-End Jax RL Implementations

Actor-Critic Architectures

Proximal Policy Optimization is categorized within the policy gradient family of algorithms, a subset of which includes actor-critic methods. The designation ‘actor-critic’ reflects the dual components of the model:

  • The actor network produces a distribution over actions given the current state of the environment and returns an action sampled from this distribution. Here, the actor network comprises three dense layers separated by two activation layers (either ReLU or hyperbolic tangent), with a softmax applied to the final layer’s outputs to define a categorical distribution over actions.
  • The critic network estimates the value function of the current state, in other words, how good it is to be in that state (its expected return). Its architecture is almost identical to the actor network, except for the final softmax: the critic applies no activation function to the outputs of its last dense layer, as it performs a regression task.
Actor-critic architecture, as defined in PureJaxRL (illustration made by the author)

Additionally, this implementation pays particular attention to weight initialization in dense layers. Indeed, all dense layers are initialized by orthogonal matrices with specific coefficients. This initialization strategy has been shown to preserve the gradient norms (i.e. scale) during forward passes and backpropagation, leading to smoother convergence and limiting the risks of vanishing or exploding gradients[1].

Orthogonal initialization is used in conjunction with specific scaling coefficients:

  • Square root of 2: Used for the first two dense layers of both networks, this factor aims to compensate for the variance reduction induced by ReLU activations (as inputs with negative values are set to 0). For the tanh activation, the Xavier initialization is a popular alternative[2].
  • 0.01: Used in the last dense layer of the actor network, this factor helps to minimize the initial differences in logit values before applying the softmax function. This will reduce the difference in action probabilities and thus encourage early exploration.
  • 1: As the critic network is performing a regression task, we do not scale the initial weights.
https://medium.com/media/6f187b7361f6450202f999af5aa17df0/href
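For reference, here is a minimal sketch of such an actor-critic module written with Flax and distrax, following the description above. The hidden width of 64 units, the input dimension of 8, and the class name are illustrative assumptions rather than an exact copy of the PureJaxRL code:

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
import distrax


class ActorCritic(nn.Module):
    """Sketch of an actor-critic module with orthogonal weight initialization."""
    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        act = nn.relu if self.activation == "relu" else nn.tanh

        # actor: two hidden layers scaled by sqrt(2), final layer scaled by 0.01
        a = act(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x))
        a = act(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(a))
        logits = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(a)
        pi = distrax.Categorical(logits=logits)  # softmax is applied internally

        # critic: same trunk, scalar output with unit scale (regression task)
        c = act(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x))
        c = act(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(c))
        value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(c)

        return pi, jnp.squeeze(value, axis=-1)


# usage: initialize the parameters and query the two heads
rng = jax.random.PRNGKey(0)
model = ActorCritic(action_dim=4)
params = model.init(rng, jnp.zeros((1, 8)))
pi, value = model.apply(params, jnp.zeros((1, 8)))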

Training Loop

The training loop is divided into 3 main blocks that share similar coding patterns, taking advantage of JAX’s functionalities:

  1. Trajectory collection: First, we’ll interact with the environment for a set number of steps and collect observations and rewards.
  2. Generalized Advantage Estimation (GAE): Then, we’ll approximate the expected return for each trajectory by computing the generalized advantage estimation.
  3. Update step: Finally, we’ll compute the gradient of the loss and update the network parameters via gradient descent.

Before going through each block in detail, here’s a quick reminder about the jax.lax.scan function that will show up multiple times throughout the code:

jax.lax.scan

A common programming pattern in JAX consists of defining a function that acts on a single sample and using jax.lax.scan to iteratively apply it to elements of a sequence or an array, while carrying along some state.
For instance, we’ll apply it to the step function to step our environment N consecutive times while carrying the new state of the environment through each iteration.

In pure Python, we could proceed as follows:

trajectories = []

for step in range(n_steps):
    # select an action and step the environment, carrying the state forward
    action = actor_network(obs)
    obs, state, reward, done, info = env.step(action, state)
    trajectories.append((obs, state, reward, done, info))

However, we avoid writing such loops in JAX for performance reasons: pure Python loops are unrolled at trace time by the JIT compiler, which leads to long compilation times and prevents the loop from being optimized as a whole. The alternative is jax.lax.scan, which is equivalent to:

import numpy as np

def scan(f, init, xs, length=None):
    """Example provided in the JAX documentation."""
    if xs is None:
        xs = [None] * length

    carry = init
    ys = []
    for x in xs:
        # apply function f to the current carry and element x
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

Using jax.lax.scan is more efficient than a Python loop because it allows the transformation to be optimized and executed as a single compiled operation rather than interpreting each loop iteration at runtime.

We can see that the scan function takes multiple arguments:

  • f: A function that is applied at each step. It takes the current state and an element of xs (or a placeholder if xs is None) and returns the updated state and an output.
  • init: The initial state that f will use in its first invocation.
  • xs: A sequence of inputs that are iteratively processed by f. If xs is None, the function simulates a loop with length iterations using None as the input for each iteration.
  • length: Specifies the number of iterations if xs is None, ensuring that the function can still operate without explicit inputs.

Additionally, scan returns:

  • carry: The final state after all iterations.
  • ys: An array of outputs corresponding to each step’s application of f, stacked for easy analysis or further processing.

Finally, scan can be used in combination with vmap to scan a function over multiple dimensions in parallel. As we’ll see in the next section, this allows us to interact with several environments in parallel to collect trajectories rapidly.

Illustration of vmap, scan, and scan + vmap in the context of the step function (made by the author)
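To make this pattern concrete, here is a small toy example (not taken from PureJaxRL) that scans a cumulative sum over a sequence, then uses vmap to run the same scan over a batch of sequences in parallel:

import jax
import jax.numpy as jnp

def cumulative_sum(carry, x):
    # carry the running total forward and also emit it at every step
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(1, 6)  # [1, 2, 3, 4, 5]
total, running = jax.lax.scan(cumulative_sum, jnp.asarray(0), xs)
# total is 15, running is [1, 3, 6, 10, 15]

# scan + vmap: process several sequences in parallel, one scan per row
batched_xs = jnp.stack([xs, 10 * xs])  # shape (2, 5)
scan_fn = lambda seq: jax.lax.scan(cumulative_sum, jnp.asarray(0), seq)
totals, runnings = jax.vmap(scan_fn)(batched_xs)  # shapes (2,) and (2, 5)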

1. Trajectory Collection

As mentioned in the previous section, the trajectory collection block consists of a step function scanned across N iterations. This step function successively:

  • Selects an action using the actor network
  • Steps the environment
  • Stores transition data in a transition tuple
  • Stores the model parameters, the environment state, the current observation, and rng keys in a runner_state tuple
  • Returns runner_state and transition

Scanning this function returns the latest runner_state and traj_batch, an array of transition tuples. In practice, transitions are collected from multiple environments in parallel for efficiency, as indicated by the use of jax.vmap(env.step, …) (for more details about vectorized environments and vmap, refer to my previous article).

https://medium.com/media/ee2875a6bdb941f399155c6c0904c4c0/href
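As a rough guide, a sketch of such a step function might look as follows. Note that network, env, env_params, num_envs, num_steps, and the Transition tuple are assumed to be defined elsewhere, and a gymnax-style env.step(key, state, action, params) signature is also an assumption:

import jax

def _env_step(runner_state, unused):
    """Sketch of a single interaction step, written to be scanned with jax.lax.scan."""
    train_state, env_state, last_obs, rng = runner_state

    # select an action by sampling from the actor's categorical distribution
    rng, _rng = jax.random.split(rng)
    pi, value = network.apply(train_state.params, last_obs)
    action = pi.sample(seed=_rng)
    log_prob = pi.log_prob(action)

    # step several environments in parallel with vmap
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, num_envs)
    obsv, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0, None)
    )(rng_step, env_state, action, env_params)

    # store the transition and the updated runner state
    transition = Transition(done, action, value, reward, log_prob, last_obs, info)
    runner_state = (train_state, env_state, obsv, rng)
    return runner_state, transition

# scan the step function to collect a batch of trajectories
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, num_steps)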

2. Generalized Advantage Estimation

After collecting trajectories, we need to compute the advantage function, a crucial component of PPO’s loss function. The advantage function measures how much better a specific action is compared to the average action in a given state:

A(s_t, a_t) = G_t − V(s_t)

Where G_t is the return at time t and V(s_t) is the value of the state at time t.

As the return is generally unknown, we have to approximate the advantage function. A popular solution is generalized advantage estimation[3], defined as follows:

Â_t = Σ_{l=0}^{T−t−1} (γλ)^l · δ_{t+l}

With γ the discount factor, λ a parameter that controls the trade-off between bias and variance in the estimate, and δ_t the temporal difference error at time t:

δ_t = r_t + γ · V(s_{t+1}) − V(s_t)

As we can see, the value of the GAE at time t depends on the GAE at future timesteps. Therefore, we compute it backward, starting from the end of a trajectory. For example, for a trajectory of 3 transitions, we would have:

Â_2 = δ_2
Â_1 = δ_1 + (γλ) · δ_2
Â_0 = δ_0 + (γλ) · δ_1 + (γλ)² · δ_2

Which is equivalent to the following recursive form:

Â_t = δ_t + (γλ) · Â_{t+1}

Once again, we use jax.lax.scan on the trajectory batch (this time in reverse order) to iteratively compute the GAE.

https://medium.com/media/36dd1edacd3ecf53a1d203f46999f828/href

Note that the function returns advantages + traj_batch.value as a second output, which is equivalent to the return according to the first equation of this section.
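For illustration, here is a sketch of how such a reversed scan can be written. It assumes traj_batch is a Transition pytree with done, value, and reward fields, last_val is the value estimate of the final observation, and the γ and λ defaults are placeholder values:

import jax
import jax.numpy as jnp

def calculate_gae(traj_batch, last_val, gamma=0.99, gae_lambda=0.95):
    """Sketch of a GAE computation using a reversed jax.lax.scan."""

    def _get_advantages(carry, transition):
        gae, next_value = carry
        done, value, reward = transition.done, transition.value, transition.reward

        # temporal-difference error, masked when the episode terminates
        delta = reward + gamma * next_value * (1 - done) - value
        # recursive form: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + gamma * gae_lambda * (1 - done) * gae
        return (gae, value), gae

    # scan backwards through the trajectory, starting from the last value estimate
    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val),
        traj_batch,
        reverse=True,
    )
    # second output: the target returns, i.e. advantage + value
    return advantages, advantages + traj_batch.value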

3. Update step

The final block of the training loop defines the loss function, computes its gradient, and performs gradient descent on minibatches. Similarly to previous sections, the update step is an arrangement of several functions in a hierarchical order:

def _update_epoch(update_state, unused):
    """
    Scans update_minibatch over shuffled and permuted
    mini batches created from the trajectory batch.
    """

    def _update_minbatch(train_state, batch_info):
        """
        Wraps loss_fn and computes its gradient over the
        trajectory batch before updating the network parameters.
        """
        ...

        def _loss_fn(params, traj_batch, gae, targets):
            """
            Defines the PPO loss and computes its value.
            """
            ...

Let’s break them down one by one, starting from the innermost function of the update step.

3.1 Loss function

This function aims to define and compute the PPO loss, originally defined as:

L_t(θ) = Ê_t [ L_t^CLIP(θ) − c₁ · L_t^VF(θ) + c₂ · S[π_θ](s_t) ]

Where:

  • L_t^CLIP(θ) = Ê_t [ min( r_t(θ) Â_t, clip(r_t(θ), 1−ε, 1+ε) Â_t ) ] is the clipped surrogate objective, with r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) the probability ratio and Â_t the estimated advantage
  • L_t^VF(θ) = (V_θ(s_t) − V_t^target)² is the value function loss
  • S[π_θ](s_t) is an entropy bonus encouraging exploration, and c₁, c₂ are weighting coefficients

However, the PureJaxRL implementation features some tricks and differences compared to the original PPO paper[4]:

  • The paper defines the PPO loss in the context of gradient ascent whereas the implementation performs gradient descent. Therefore, the sign of each loss component is reversed.
  • The value function term is modified to include an additional clipped term. This could be seen as a way to make the value function updates more conservative (as for the clipped surrogate objective):

    L^VF(θ) = ½ · Ê_t [ max( (V_θ(s_t) − V_t^target)², (clip(V_θ(s_t), V_old(s_t) − ε, V_old(s_t) + ε) − V_t^target)² ) ]
  • The GAE is standardized.

Here’s the complete loss function:

https://medium.com/media/46f2d043c070808a7da5d97342afe905/href
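For reference, a sketch of a PPO loss in this style is shown below. It builds on the previous sketches (network and the Transition fields are assumed), and the clipping and weighting coefficients are placeholder values rather than the exact PureJaxRL configuration:

import jax.numpy as jnp

def _loss_fn(params, traj_batch, gae, targets, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    """Sketch of the clipped PPO loss with a clipped value term and an entropy bonus."""
    # re-evaluate the policy and value function on the stored observations
    pi, value = network.apply(params, traj_batch.obs)
    log_prob = pi.log_prob(traj_batch.action)

    # clipped value loss: penalize value predictions that move too far from the old ones
    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(-clip_eps, clip_eps)
    value_losses = jnp.square(value - targets)
    value_losses_clipped = jnp.square(value_pred_clipped - targets)
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

    # clipped surrogate objective, with standardized advantages
    ratio = jnp.exp(log_prob - traj_batch.log_prob)
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    loss_actor = -jnp.minimum(
        ratio * gae,
        jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae,
    ).mean()

    # entropy bonus to encourage exploration
    entropy = pi.entropy().mean()

    # signs are flipped relative to the paper because we minimize with gradient descent
    total_loss = loss_actor + vf_coef * value_loss - ent_coef * entropy
    return total_loss, (value_loss, loss_actor, entropy)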

3.2 Update Minibatch

The _update_minbatch function is essentially a wrapper around _loss_fn, used to compute its gradient over the trajectory batch and update the model parameters stored in train_state.

https://medium.com/media/a6798898fe3ef8800e8354098b03aaa8/href
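A sketch of such a wrapper is shown below, assuming the _loss_fn sketched in the previous section and a Flax TrainState holding the parameters and optimizer state:

import jax

def _update_minbatch(train_state, batch_info):
    """Sketch: compute the loss gradient on one minibatch and apply the update."""
    traj_batch, advantages, targets = batch_info

    # with has_aux=True, grad_fn returns ((loss, aux), grads)
    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
    total_loss, grads = grad_fn(train_state.params, traj_batch, advantages, targets)

    # apply_gradients runs one optimizer (optax) step and returns the new TrainState
    train_state = train_state.apply_gradients(grads=grads)
    return train_state, total_loss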

3.3 Update Epoch

Finally, _update_epoch wraps _update_minbatch and applies it to minibatches. Once again, jax.lax.scan is used to apply the update function to all minibatches iteratively.

https://medium.com/media/725498766e43cfe26f21b1961bb49d01/href
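A sketch of this outer wrapper might look as follows. Here, num_minibatches and minibatch_size are assumed configuration values, and the trajectory batch is assumed to have a (steps, envs, …) layout:

import jax
import jax.numpy as jnp

def _update_epoch(update_state, unused):
    """Sketch: shuffle the batch, split it into minibatches, and scan the update over them."""
    train_state, traj_batch, advantages, targets, rng = update_state

    # build a random permutation of the flattened batch
    rng, _rng = jax.random.split(rng)
    batch_size = num_minibatches * minibatch_size
    permutation = jax.random.permutation(_rng, batch_size)

    # flatten (steps, envs, ...) into a single batch dimension, then shuffle
    batch = (traj_batch, advantages, targets)
    batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
    shuffled = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)

    # split into minibatches and scan _update_minbatch over them
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape((num_minibatches, -1) + x.shape[1:]), shuffled
    )
    train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)

    update_state = (train_state, traj_batch, advantages, targets, rng)
    return update_state, total_loss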

Conclusion

From there, we can wrap all of the previous functions in an update_step function and use scan one last time for N steps to complete the training loop.

A global view of the training loop would look like this:

https://medium.com/media/8408eb84bc2b05ecd9c2ae8ebebc8c4e/href

We can now run a fully compiled training loop by jit-compiling the train function and calling it on a random key, i.e. jax.jit(train)(rng), or even train multiple agents in parallel with jax.vmap(train)(rngs).
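Concretely, assuming a make_train(config) factory in the PureJaxRL style, the usage pattern looks like this:

import jax

rng = jax.random.PRNGKey(42)

# single agent: compile the entire training loop once, then run it
train_jit = jax.jit(make_train(config))
out = train_jit(rng)

# multiple agents: vmap the training function over a batch of seeds, then jit
rngs = jax.random.split(rng, 8)
train_vjit = jax.jit(jax.vmap(make_train(config)))
outs = train_vjit(rngs)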

There we have it! We covered the essential building blocks of the PPO training loop as well as common programming patterns in JAX.

To go further, I highly recommend reading the full training script in detail and running example notebooks on the PureJaxRL repository.

GitHub - luchris429/purejaxrl: Really Fast End-to-End Jax RL Implementations

Thank you very much for your support, until next time 👋

References:

Full training script, PureJaxRL, Chris Lu, 2023

[1] Explaining and illustrating orthogonal initialization for recurrent neural networks, Smerity, 2016

[2] Initializing neural networks, DeepLearning.ai

[3] Generalized Advantage Estimation in Reinforcement Learning, Siwei Causevic, Towards Data Science, 2023

[4] Proximal Policy Optimization Algorithms, Schulman et al., OpenAI, 2017


Breaking down State-of-the-Art PPO Implementations in JAX was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.
