JaxGCRL

Installation | Quick Start | Environments | Baselines | Citation


Accelerating Goal-Conditioned RL Algorithms and Research

We provide blazingly fast goal-conditioned environments based on MJX and BRAX for quick experimentation with goal-conditioned self-supervised reinforcement learning.

  • Blazingly Fast Training - Train 10 million environment steps in 10 minutes on a single GPU, up to 22$\times$ faster than prior implementations.
  • Comprehensive Benchmarking - Includes 10+ diverse environments and multiple pre-implemented baselines for out-of-the-box evaluation.
  • Modular Implementation - Designed for clarity and scalability, allowing for easy modification of algorithms.

Installation 📂

The conda environment can be set up from the environment.yml file:

conda env create -f environment.yml
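
Then activate the newly created environment (its name is defined at the top of environment.yml):

conda activate <environment_name_from_environment.yml>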

Quick Start 🚀

To verify the installation, run a test experiment using the ./scripts/train.sh file:

chmod +x ./scripts/train.sh; ./scripts/train.sh

Note

If you haven't yet configured wandb, you may be prompted to log in.

Specific configurations can be passed as flags, as in scripts/train.sh. The available flags are described in utils.py:create_parser(). Common flags you may want to change include (an example invocation follows the list):

  • env=...: replace "ant" with any environment name. See utils.py:create_env() for names.
  • Removing --log_wandb: disables logging to wandb, useful if you don't want to use a wandb account.
  • --total_env_steps: shorter or longer runs.
  • --num_envs: based on how many environments your GPU memory allows.
  • --contrastive_loss_fn, --energy_fn, --h_dim, --n_hidden, etc.: algorithmic and architectural changes.
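
For example, a hypothetical invocation built only from the flags listed above (the values are illustrative and the environment is selected as in scripts/train.sh; check utils.py:create_parser() for the exact flag names and defaults):

python training.py --total_env_steps 10000000 --num_envs 512 --log_wandb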

Environment Interaction

Environments can be controlled with the reset and step functions. These methods return a state object, which is a dataclass containing the following fields:

  • state.pipeline_state: the current internal state of the environment
  • state.obs: the current observation
  • state.done: flag indicating whether the agent reached the goal
  • state.metrics: agent performance metrics
  • state.info: additional info

The following code demonstrates how to interact with the environment:

import jax
from utils import create_env

key = jax.random.PRNGKey(0)

# Initialize the environment
env = create_env('ant')

# Use JIT compilation to make the environment's reset and step functions execute faster
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

NUM_STEPS = 1000

# Reset the environment and obtain the initial state
state = jit_env_reset(key)

# Simulate the environment for a fixed number of steps
for _ in range(NUM_STEPS):
    # Generate a random action (8 is the Ant environment's action dimension)
    key, key_act = jax.random.split(key, 2)
    random_action = jax.random.uniform(key_act, shape=(8,), minval=-1, maxval=1)
    
    # Perform an environment step with the generated action
    state = jit_env_step(state, random_action)
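
After the rollout, the fields described above can be read directly from the final state; for example:

# Inspect the final state (fields described above)
print(state.obs.shape)   # current observation
print(state.done)        # flag indicating whether the agent reached the goal
print(state.metrics)     # agent performance metrics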

Wandb support 📈

We strongly recommend using Wandb for tracking and visualizing results. Enable Wandb logging with the --log_wandb flag. The following flags are also available to organize experiments:

  • --project_name
  • --group_name
  • --exp_name

The --log_wandb flag logs metrics to Wandb; without it, metrics are logged to a CSV file by default.

  1. Run an example sweep:
wandb sweep --project exemplary_sweep ./scripts/sweep.yml
  2. Then run a wandb agent with the sweep ID printed by the previous command:
wandb agent <previous_command_output>

We also render videos of the learned policies and log them as wandb artifacts.

Environments 🌎

We currently support a variety of continuous control environments:

  • Locomotion: Half-Cheetah, Ant, Humanoid
  • Locomotion + task: AntMaze, AntBall (AntSoccer), AntPush, HumanoidMaze
  • Simple arm: Reacher, Pusher, Pusher 2-object
  • Manipulation: Reach, Grasp, Push (easy/hard), Binpick (easy/hard)
| Environment | Env name | Code |
| --- | --- | --- |
| Reacher | reacher | link |
| Half Cheetah | cheetah | link |
| Pusher | pusher_easy, pusher_hard | link |
| Ant | ant | link |
| Ant Maze | ant_u_maze, ant_big_maze, ant_hardest_maze | link |
| Ant Soccer | ant_ball | link |
| Ant Push | ant_push | link |
| Humanoid | humanoid | link |
| Humanoid Maze | humanoid_u_maze, humanoid_big_maze, humanoid_hardest_maze | link |
| Arm Reach | arm_reach | link |
| Arm Grasp | arm_grasp | link |
| Arm Push | arm_push_easy, arm_push_hard | link |
| Arm Binpick | arm_binpick_easy, arm_binpick_hard | link |

To add new environments: add an XML to envs/assets, add a Python environment file in envs, and register the environment name in utils.py.
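
As a rough, hypothetical skeleton (the names below are placeholders; copy the actual base class, constructor, and registration pattern from an existing file such as envs/ant.py and utils.py:create_env()):

# Hypothetical skeleton of a new environment file in envs/.
# Existing environments are built on BRAX/MJX; mirror envs/ant.py for the
# real constructor, observation, and reward logic.
from brax.envs.base import PipelineEnv, State


class MyNewEnv(PipelineEnv):
    """Loads its model from an XML placed in envs/assets."""

    def reset(self, rng) -> State:
        # Sample an initial configuration and a goal, then build the State
        raise NotImplementedError

    def step(self, state: State, action) -> State:
        # Advance the simulation and fill in obs, done, and metrics
        raise NotImplementedError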

Baselines 🤖

We currently support the following algorithms:

| Algorithm | How to run | Code |
| --- | --- | --- |
| CRL | python training.py ... | link |
| SAC | python training_sac.py ... | link |
| SAC + HER | python training_sac.py ... --use_her | link |
| TD3 | python training_td3.py ... | link |
| TD3 + HER | python training_td3.py ... --use_her | link |
| PPO | python training_ppo.py ... | link |

Code Structure 📁

The core structure of the codebase is as follows:


β”œβ”€β”€ src: Algorithm code (training, network, replay buffer, etc.)
β”‚   β”œβ”€β”€ train.py: Main file. Defines energy functions + losses, and networks. Collects trajectories, trains networks, runs evaluations.
β”‚   β”œβ”€β”€ replay_buffer.py: Contains replay buffer, including logic for state, action, and goal sampling for training.
β”‚   └── evaluator.py: Runs evaluation and collects metrics.
β”œβ”€β”€ envs: Environments (python files and XMLs)
β”‚   β”œβ”€β”€ ant.py, humanoid.py, ...: Most environments are here.
β”‚   β”œβ”€β”€ assets: Contains XMLs for environments.
β”‚   └── manipulation: Contains all manipulation environments.
β”œβ”€β”€ scripts/train.sh: Modify to choose environment and hyperparameters.
β”œβ”€β”€ utils.py: Logic for script argument processing, rendering, environment names, etc.
└── training.py: Interface file that processes script arguments, calls train.py, initializes wandb, etc.

The architecture can be adjusted in networks.py.

Contributing 🏗️

Help us build JaxGCRL into the best possible tool for the GCRL community. Reach out and start contributing or just add an Issue/PR!

  • Add Franka robot arm environments. [Done by SimpleGeometry]
  • Get around 70% success rate on Ant Big Maze task. [Done by RajGhugare19]
  • Add more complex versions of Ant Sokoban.
  • Integrate environments:
    • Overcooked
    • Hanabi
    • Rubik's cube
    • Sokoban

To run tests (make sure you have access to a GPU):

python -m pytest 

Citing JaxGCRL 📜

If you use JaxGCRL in your work, please cite us as follows:

@inproceedings{bortkiewicz2025accelerating,
    author    = {Bortkiewicz, Micha\l{} and Pa\l{}ucki, W\l{}adek and Myers, Vivek and
                 Dziarmaga, Tadeusz and Arczewski, Tomasz and Kuci\'{n}ski, \L{}ukasz and
                 Eysenbach, Benjamin},
    booktitle = {{International Conference} on {Learning Representations}},
    title     = {{Accelerating Goal-Conditioned RL Algorithms} and {Research}},
    url       = {https://arxiv.org/pdf/2408.11052},
    year      = {2025},
}

Questions ❓

If you have any questions, comments, or suggestions, please reach out to MichaΕ‚ Bortkiewicz (michalbortkiewicz8@gmail.com).

See Also 🙌

There are a number of other libraries that inspired this work; we encourage you to take a look!

JAX-native algorithms:

  • Mava: JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
  • PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
  • Minimax: JAX implementations of autocurricula baselines for RL.
  • JaxIRL: JAX implementation of algorithms for inverse reinforcement learning.

JAX-native environments:

  • Gymnax: Implementations of classic RL tasks including classic control, bsuite and MinAtar.
  • Jumanji: A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • Brax: A fully differentiable physics engine written in JAX, features continuous control tasks.
  • XLand-MiniGrid: Meta-RL gridworld environments inspired by XLand and MiniGrid.
  • Craftax: (Crafter + NetHack) in JAX.
  • JaxMARL: Multi-agent RL in JAX.