{"id":24084,"library":"mujoco-mjx","title":"MuJoCo XLA (MJX)","description":"MJX is a JAX-based physics engine that accelerates MuJoCo simulations on GPU/TPU. It provides differentiable physics for reinforcement learning and robotics. Current version: 3.8.0. Released ~quarterly alongside MuJoCo.","status":"active","version":"3.8.0","language":"python","source_language":"en","source_url":"https://github.com/google-deepmind/mujoco/tree/main/mjx","tags":["physics","simulation","jax","reinforcement-learning","robotics","gpu-acceleration"],"install":[{"cmd":"pip install mujoco-mjx","lang":"bash","label":"Install from PyPI"}],"dependencies":[{"reason":"core dependency for JAX arrays and autograd","package":"jax","optional":false},{"reason":"required for loading/processing MJCF models that MJX executes","package":"mujoco","optional":false}],"imports":[{"note":"The top-level package is mujoco_mjx, not mjx.","wrong":"from mjx import MjxModel","symbol":"MjxModel","correct":"from mujoco_mjx import MjxModel"},{"note":"The top-level package is mujoco_mjx, not mjx.","wrong":"from mjx import step","symbol":"step","correct":"from mujoco_mjx import step"}],"quickstart":{"code":"import mujoco\nimport mujoco_mjx as mjx\nimport jax.numpy as jnp\n\nxml = \"\"\"\n<mujoco model=\"test\">\n  <worldbody>\n    <geom name=\"floor\" type=\"plane\" size=\"1 1 0.1\" />\n    <body>\n      <joint name=\"slide\" type=\"slide\" axis=\"1 0 0\" />\n      <geom name=\"box\" type=\"box\" size=\"0.2 0.2 0.2\" pos=\"0 0 0.2\" />\n    </body>\n  </worldbody>\n</mujoco>\"\"\"\nm = mujoco.MjModel.from_xml_string(xml)\nd = mujoco.MjData(m)\nmjx_model = mjx.put_model(m)\nmjx_data = mjx.put_data(m, d)\n# Step simulation\nfor _ in range(100):\n    ctrl = jnp.zeros(m.nu)\n    mjx_data = mjx.step(mjx_model, mjx_data, ctrl)\nprint(\"Final position:\", mjx_data.qpos)","lang":"python","description":"Create a simple MuJoCo model, convert it to MJX, and simulate 100 steps with zero control input."},"warnings":[{"fix":"Pass ctrl as the third argument: mjx.step(model, 
data, ctrl).","message":"MJX step() expects a control array (ctrl) argument as of v3.5+; earlier versions used a different signature.","severity":"breaking","affected_versions":"<3.5"},{"fix":"Set jax.config.update('jax_enable_x64', True) if you need float64.","message":"MJX uses JAX's single-precision (float32) by default; comparing its results against a double-precision MuJoCo Python simulation may show mismatches.","severity":"gotcha","affected_versions":"all"},{"fix":"Use mjx.put_model(m) to convert a MuJoCo MjModel to MJX.","message":"mjx.MjxModel.from_mjb() is deprecated; use mjx.put_model(mujoco.MjModel) instead.","severity":"deprecated","affected_versions":">=3.6"}],"env_vars":null,"last_verified":"2026-05-01T00:00:00.000Z","next_check":"2026-07-30T00:00:00.000Z","problems":[{"fix":"Run pip install mujoco-mjx, then import the package with: import mujoco_mjx as mjx.","cause":"mujoco-mjx is not installed or you imported the wrong name.","error":"ModuleNotFoundError: No module named 'mujoco_mjx'"},{"fix":"Install JAX: pip install -U 'jax[cuda12]' or the appropriate variant.","cause":"JAX is not installed, or only the CPU version is installed.","error":"ValueError: MJX requires JAX. Please install JAX with GPU support."},{"fix":"Pass a JAX array for controls: mjx.step(model, data, jnp.zeros(model.nu)).","cause":"Old code calling step without control input; MJX v3.5+ requires ctrl.","error":"TypeError: step() missing 1 required positional argument: 'ctrl'"}],"ecosystem":"pypi","meta_description":null,"install_score":null,"install_tag":null,"quickstart_score":null,"quickstart_tag":null}