MuJoCo XLA (MJX)
Version 3.8.0 (verified Fri May 01) · auth: no · language: Python
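MJX (MuJoCo XLA) is a reimplementation of the MuJoCo physics engine in JAX, so simulations can run on GPU and TPU. Because it is written in JAX, a simulation step can be JIT-compiled, vectorized across many environments, and differentiated, which makes it useful for reinforcement learning and robotics. Current version: 3.8.0. MJX is released alongside MuJoCo, roughly quarterly.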
pip install mujoco-mjx
Common errors
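error ModuleNotFoundError: No module named 'mujoco_mjx'
cause mujoco-mjx is not installed, or the code imports a module name that does not exist. MJX is installed as a submodule of the mujoco package; there is no top-level mujoco_mjx module.
fix
Run pip install mujoco-mjx and import it as from mujoco import mjx.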
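A quick way to check which import will work in the current environment is sketched below. It only uses the standard library's importlib.util.find_spec, so it runs even when MuJoCo is missing; the helper name mjx_available is hypothetical.

```python
import importlib.util


def mjx_available() -> bool:
    """True if `from mujoco import mjx` can be resolved in this environment."""
    try:
        # find_spec on a dotted name imports the parent package (mujoco) first.
        return importlib.util.find_spec("mujoco.mjx") is not None
    except ModuleNotFoundError:
        # The parent package `mujoco` itself is not installed.
        return False


# There is no top-level `mujoco_mjx` module, which is why that import fails.
print("mujoco_mjx importable:", importlib.util.find_spec("mujoco_mjx") is not None)
print("mujoco.mjx importable:", mjx_available())
```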
error ValueError: MJX requires JAX. Please install JAX with GPU support.
cause JAX is not installed, or only the CPU build of JAX is installed.
fix
Install JAX with accelerator support, for example pip install -U 'jax[cuda12]' for NVIDIA GPUs with CUDA 12, or the variant that matches your hardware.
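To see whether JAX actually picked up an accelerator after installation, a small check like the one below can help. It is a sketch that relies only on jax.default_backend(); the helper name jax_backend and the "missing" sentinel are hypothetical.

```python
def jax_backend() -> str:
    """Report which backend JAX will use ('cpu', 'gpu', 'tpu'), or 'missing'."""
    try:
        import jax
    except ImportError:
        return "missing"
    # default_backend() returns the platform name of the default device.
    return jax.default_backend()


backend = jax_backend()
if backend == "cpu":
    print("JAX found no accelerator; MJX will run on CPU.")
print("JAX backend:", backend)
```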
error TypeError: step() takes 2 positional arguments but 3 were given
cause Code passes the control vector as a third argument; mjx.step takes only the model and the data, and reads controls from data.ctrl.
fix
Write the controls into the data first, then step: data = data.replace(ctrl=jnp.zeros(model.nu)); data = mjx.step(model, data).
Warnings
gotcha mjx.step takes only the model and the data (mjx.step(model, data)); controls are read from data.ctrl rather than passed as a separate argument.
fix Set controls before stepping: data = data.replace(ctrl=ctrl), then data = mjx.step(model, data).
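gotcha MJX computes in single precision (float32) by default, because that is JAX's default, while the MuJoCo C engine behind the regular Python bindings uses float64. Results from the two can therefore differ slightly.
fix Call jax.config.update('jax_enable_x64', True) at startup, before creating any arrays, if you need float64.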
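A minimal sketch of switching JAX to double precision. It uses only the documented jax_enable_x64 flag; arrays created before the switch keep their float32 dtype, so the flag is set first.

```python
import jax

# Set at startup; arrays created before this call stay float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.zeros(3)
print(x.dtype)  # float64
```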
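deprecated mjx.MjxModel.from_mjb() is deprecated; use mjx.put_model(m) on a mujoco.MjModel instead.
fix Load the model with MuJoCo (for example mujoco.MjModel.from_xml_path, or mujoco.MjModel.from_binary_path for .mjb files), then convert it with mjx.put_model(m).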
Imports
- Model type
  wrong: from mjx import MjxModel
  correct: from mujoco.mjx import Model
- step
  wrong: from mjx import step
  correct: from mujoco.mjx import step
Quickstart
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

xml = """
<mujoco model="test">
  <worldbody>
    <geom name="floor" type="plane" size="1 1 0.1" />
    <body>
      <joint name="slide" type="slide" axis="1 0 0" />
      <geom name="box" type="box" size="0.2 0.2 0.2" pos="0 0 0.2" />
    </body>
  </worldbody>
</mujoco>"""

m = mujoco.MjModel.from_xml_string(xml)
d = mujoco.MjData(m)

# Copy the model and data to the default JAX device (GPU/TPU if available).
mjx_model = mjx.put_model(m)
mjx_data = mjx.put_data(m, d)

# Controls are stored on the data. This model has no actuators, so m.nu == 0.
mjx_data = mjx_data.replace(ctrl=jnp.zeros(m.nu))

# Compile the step function once, then call it in the loop.
jit_step = jax.jit(mjx.step)
for _ in range(100):
    mjx_data = jit_step(mjx_model, mjx_data)

print("Final position:", mjx_data.qpos)