MuJoCo XLA (MJX)

Version: 3.8.0 | Verified: Fri May 01 | Auth: no | Language: Python

MJX is a reimplementation of MuJoCo in JAX that runs MuJoCo simulations on GPUs and TPUs through the XLA compiler. Because it is written in JAX, its physics is differentiable, which makes it useful for reinforcement learning and robotics. Current version: 3.8.0. MJX is released alongside MuJoCo, roughly quarterly.

pip install mujoco-mjx
error ModuleNotFoundError: No module named 'mujoco_mjx'
cause There is no top-level mujoco_mjx module: the mujoco-mjx package installs MJX as the mjx submodule of the mujoco package.
fix
Run pip install mujoco-mjx and import it as from mujoco import mjx.
error ValueError: MJX requires JAX. Please install JAX with GPU support.
cause JAX is not installed, or only the CPU build of JAX is installed.
fix
Install a GPU-enabled JAX build, e.g. pip install -U 'jax[cuda12]', or the variant that matches your accelerator.
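To check whether a GPU-enabled install took effect, you can ask JAX which backend it selected. This sketch uses only jax.default_backend() and jax.devices(); nothing in it is MJX-specific.

```python
import jax

# "cpu" means no GPU/TPU build is active: MJX will still run, but without
# accelerator speedups.
backend = jax.default_backend()
print("JAX backend:", backend)
print("Devices:", jax.devices())
```

If this reports cpu on a machine that has a GPU, the CPU-only JAX build is likely still the one being imported.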
error TypeError when calling mjx.step(model, data, ctrl)
cause mjx.step() takes only the model and the data; it has no ctrl parameter. Controls are read from the ctrl field of the data.
fix
Set the controls on the data before stepping: mjx.step(model, data.replace(ctrl=jnp.zeros(model.nu))).
gotcha mjx.Data is immutable: assigning data.ctrl = ... fails, and data.replace(ctrl=...) returns a new Data instead of modifying the old one.
fix Keep the returned value: data = data.replace(ctrl=ctrl), then data = mjx.step(model, data).
gotcha MJX follows JAX's default single precision (float32), whereas MuJoCo's native Python simulation uses double precision (float64), so results from the two can differ.
fix Call jax.config.update('jax_enable_x64', True) at startup, before creating any JAX arrays, if you need float64.
gotcha MJX has no model loader of its own: load the model with MuJoCo first (e.g. mujoco.MjModel.from_xml_path() for MJCF, or mujoco.MjModel.from_binary_path() for .mjb files).
fix Use mjx.put_model(m) to convert the resulting MjModel to MJX.

Create a simple MuJoCo model, convert to MJX, and simulate with random control.

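import jax
import mujoco
from mujoco import mjx

xml = """
<mujoco model="test">
  <worldbody>
    <geom name="floor" type="plane" size="1 1 0.1" />
    <body>
      <joint name="slide" type="slide" axis="1 0 0" />
      <geom name="box" type="box" size="0.2 0.2 0.2" pos="0 0 0.2" />
    </body>
  </worldbody>
  <!-- A motor on the slide joint gives the model one control input (m.nu == 1). -->
  <actuator>
    <motor name="push" joint="slide" />
  </actuator>
</mujoco>"""
m = mujoco.MjModel.from_xml_string(xml)
d = mujoco.MjData(m)
# Copy the model and data onto the default JAX device (GPU/TPU if available).
mjx_model = mjx.put_model(m)
mjx_data = mjx.put_data(m, d)
# Compile the step function once; later calls reuse the compiled kernel.
jit_step = jax.jit(mjx.step)
key = jax.random.PRNGKey(0)
# Step simulation with a fresh random control in [-1, 1] at every step
for _ in range(100):
    key, subkey = jax.random.split(key)
    ctrl = jax.random.uniform(subkey, (m.nu,), minval=-1.0, maxval=1.0)
    # mjx.step takes (model, data); controls are read from data.ctrl.
    mjx_data = jit_step(mjx_model, mjx_data.replace(ctrl=ctrl))
print("Final position:", mjx_data.qpos)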