# Source code for benchmark

from __future__ import annotations

# ruff: noqa: E402

import argparse
from contextlib import contextmanager
import math
import os
from pathlib import Path
import sys
import time
from typing import Any

import warp as wp

# Directory containing this file. It is put at the front of sys.path (once)
# before the `sim.*` imports below run.
# NOTE(review): presumably this lets the script run without installing the
# `sim` package — confirm.
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from sim.collision.pipeline import SapCollisionPipeline
from sim.loader.control import load_sap_control_sequence
from sim.loader.scene import load_sap_scene, load_sap_scene_config
from sim.resources.collision_model import sap_collision_state_from_state
from sim.solver_sap import SolverSAP


# Scene file used when --scene is not given on the command line.
DEFAULT_SCENE = ROOT.joinpath("assets", "yaml", "unitree_g1_usd.yaml")


def _resolve_path(path: str | Path) -> Path:
    """Return ``path`` as an absolute, resolved ``Path``.

    A leading ``~`` is expanded first; a relative path is then anchored at
    the current working directory before ``Path.resolve()`` is applied.
    """
    candidate = Path(path).expanduser()
    anchored = candidate if candidate.is_absolute() else Path.cwd().joinpath(candidate)
    return anchored.resolve()


@contextmanager
def _temporary_env(values: dict[str, str]) -> Any:
    """Context manager that applies ``values`` to ``os.environ`` for its body.

    On exit (normal or via an exception) every name in ``values`` is put back
    to what it was on entry: restored to its earlier value, or removed if it
    was not set before.
    """
    # Remember the pre-existing value (None == "was not set") for each name.
    saved = {}
    for key in values:
        saved[key] = os.environ.get(key)

    os.environ.update(values)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is not None:
                os.environ[key] = old
            else:
                os.environ.pop(key, None)


def _scene_simulation_config(scene_path: Path) -> dict[str, Any]:
    """Return the ``simulation`` mapping of the scene config.

    An empty dict is returned when the key is absent, its value is falsy, or
    its value is not a dict.
    """
    section = load_sap_scene_config(scene_path).get("simulation")
    if not section or not isinstance(section, dict):
        return {}
    return section


def _scene_default_shape_tau(scene_path: Path) -> float | None:
    """Return ``builder.defaults.shape.tau`` from the scene config as a float.

    ``None`` is returned when any level along that path is missing or is not
    a dict, or when ``tau`` itself is absent or ``None``.
    """
    node = load_sap_scene_config(scene_path)
    # Walk down the nested mappings; a missing/falsy level behaves like {}.
    for key in ("builder", "defaults", "shape"):
        node = node.get(key, {}) or {}
        if not isinstance(node, dict):
            return None
    tau = node.get("tau")
    return None if tau is None else float(tau)


def _effective_dt(args: argparse.Namespace, simulation: dict[str, Any]) -> float:
    """Return the simulation timestep to use, validated.

    Precedence: ``args.dt`` when it is not ``None``, otherwise
    ``simulation["dt"]``, otherwise the built-in default of ``0.003``.
    A ``simulation["dt"]`` that is present but ``None`` is treated as absent.

    Args:
        args: Parsed command-line options; only ``args.dt`` is read.
        simulation: The scene's ``simulation`` mapping.

    Returns:
        The timestep as a positive, finite float.

    Raises:
        ValueError: If the chosen value is zero, negative, NaN or infinite.
            The message names where the value came from (``--dt`` or
            ``simulation.dt``).
    """
    default_dt = 0.003
    if args.dt is not None:
        source, value = "--dt", args.dt
    else:
        source, value = "simulation.dt", simulation.get("dt", default_dt)
        if value is None:
            value = default_dt
    dt = float(value)
    # `dt <= 0.0` alone lets NaN through (every comparison with NaN is False)
    # and also accepts +inf, so finiteness is checked explicitly.
    if not math.isfinite(dt) or dt <= 0.0:
        raise ValueError(f"{source} must be a positive, finite number (got {dt!r}).")
    return dt


def _frame_count(args: argparse.Namespace, dt: float) -> int:
    """Return how many frames to simulate (always at least 1).

    ``args.frames`` wins when it is not ``None``. Otherwise the count is
    ``ceil(duration / dt)``, with a negative ``args.duration`` treated as 0.
    """
    if args.frames is None:
        seconds = max(float(args.duration), 0.0)
        steps = int(math.ceil(seconds / dt))
    else:
        steps = int(args.frames)
    return max(steps, 1)


def _scene_max_rigid_contact_per_env(scene_path: Path, simulation: dict[str, Any]) -> int:
    """Return ``simulation["max_rigid_contact"]`` as an int, clamped to >= 1.

    Raises:
        ValueError: If the key is missing; ``scene_path`` is used only to
            name the offending file in the message.
    """
    if "max_rigid_contact" in simulation:
        configured = int(simulation["max_rigid_contact"])
        return max(configured, 1)
    raise ValueError(f"Scene file {scene_path} must define simulation.max_rigid_contact.")


def _total_rigid_contact_capacity(max_rigid_contact_per_env: int, num_worlds: int) -> int:
    """Return per-world contact capacity times world count.

    Both factors are converted with ``int`` and clamped to at least 1 before
    multiplying, so the result is never below 1.
    """
    per_world = max(int(max_rigid_contact_per_env), 1)
    worlds = max(int(num_worlds), 1)
    return per_world * worlds


def _scene_num_worlds(args: argparse.Namespace, scene_path: Path, simulation: dict[str, Any]) -> int:
    """Return the number of worlds to simulate, clamped to at least 1.

    ``args.num_worlds`` takes precedence when it is not ``None``; otherwise
    ``simulation["num_worlds"]`` is required.

    Raises:
        ValueError: If neither the argument nor the scene provides a value.
    """
    requested = args.num_worlds
    if requested is None:
        if "num_worlds" not in simulation:
            raise ValueError(f"Scene file {scene_path} must define simulation.num_worlds.")
        requested = simulation["num_worlds"]
    return max(int(requested), 1)


def _scene_solver_kwargs(scene_path: Path, simulation: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``simulation["solver"]`` for use as kwargs.

    A missing or falsy ``solver`` entry yields an empty dict.

    Raises:
        ValueError: If the entry is truthy but not a dict; ``scene_path`` is
            used only in the message.
    """
    section = simulation.get("solver") or {}
    if isinstance(section, dict):
        # Copy so callers can mutate the result without touching the config.
        return {**section}
    raise ValueError(f"Scene file {scene_path} simulation.solver must be a mapping.")


def _step_native(
    *,
    solver: SolverSAP,
    collision_pipeline: SapCollisionPipeline,
    state_0,
    state_1,
    control,
    contacts,
    dt: float,
):
    """Advance the simulation by one step from ``state_0`` into ``state_1``.

    The sequence is: clear forces on ``state_0``, run
    ``collision_pipeline.collide`` on the collision view of ``state_0`` with
    ``contacts``, then call ``solver.step``.

    Returns:
        ``(state_1, state_0)`` — the two state arguments in swapped order, so
        a caller can rebind them with ``state_0, state_1 = _step_native(...)``.
    """
    state_0.clear_forces()
    collision_state = sap_collision_state_from_state(state_0)
    collision_pipeline.collide(collision_state, contacts)
    solver.step(state_0, state_1, control, contacts, dt)
    stepped, previous = state_1, state_0
    return stepped, previous


def _cuda_graph_supported(device) -> bool:
    """Report whether a CUDA graph can be captured for ``device``.

    True only when ``device.is_cuda`` is truthy (a missing attribute counts
    as False) and the imported ``warp`` module exposes both
    ``ScopedCapture`` and ``capture_launch``.
    """
    if not bool(getattr(device, "is_cuda", False)):
        return False
    if not hasattr(wp, "ScopedCapture"):
        return False
    return hasattr(wp, "capture_launch")


def _capture_native_step_graph(
    *,
    solver: SolverSAP,
    collision_pipeline: SapCollisionPipeline,
    state_0,
    state_1,
    control,
    contacts,
    dt: float,
    device,
):
    """Try to capture one simulation step as a CUDA graph.

    The captured region is one ``_step_native`` call followed by assigning
    the stepped state back onto ``state_0`` (``_step_native`` returns
    ``(state_1, state_0)``, so this is ``state_0.assign(state_1)``).

    ``solver.sim_time``, ``solver.frame_id`` and
    ``solver._has_contact_solve_v_guess`` are saved before the capture and
    restored afterwards whether or not the capture succeeded.

    Returns:
        ``capture.graph`` on success. ``None`` when
        ``_cuda_graph_supported(device)`` is false or when the capture
        raised; in the latter case the exception is reported on stderr so the
        fallback to per-frame stepping is not silent.
    """
    if not _cuda_graph_supported(device):
        return None

    # Snapshot the solver attributes that are restored in `finally` below.
    sim_time = solver.sim_time
    frame_id = solver.frame_id
    has_contact_solve_v_guess = solver._has_contact_solve_v_guess
    try:
        with wp.ScopedCapture(device=device) as capture:
            next_state, prev_state = _step_native(
                solver=solver,
                collision_pipeline=collision_pipeline,
                state_0=state_0,
                state_1=state_1,
                control=control,
                contacts=contacts,
                dt=dt,
            )
            prev_state.assign(next_state)
    except Exception as exc:
        # Capture is best-effort: keep the fallback, but say why it happened
        # instead of swallowing the error.
        print(
            f"CUDA graph capture failed; falling back to per-frame stepping: {exc!r}",
            file=sys.stderr,
            flush=True,
        )
        return None
    finally:
        solver.sim_time = sim_time
        solver.frame_id = frame_id
        solver._has_contact_solve_v_guess = has_contact_solve_v_guess

    return capture.graph


def run_native(args: argparse.Namespace) -> None:
    """Run the native SAP benchmark loop from parsed command-line arguments.

    It loads a scene, builds the collision pipeline and solver, optionally
    captures a CUDA graph, and prints timing statistics.

    Args:
        args: Parsed options. ``scene`` and ``device`` are read here; ``dt``,
            ``frames``, ``duration`` and ``num_worlds`` are read by the
            helper functions this calls.

    Raises:
        ValueError: If the loaded model reports no joints or no joint DOFs,
            or if one of the configuration helpers rejects the scene/options.
    """
    scene_path = _resolve_path(args.scene)
    simulation = _scene_simulation_config(scene_path)
    dt = _effective_dt(args, simulation)
    frames = _frame_count(args, dt)
    max_rigid_contact_per_env = _scene_max_rigid_contact_per_env(scene_path, simulation)
    num_worlds = _scene_num_worlds(args, scene_path, simulation)
    rigid_contact_capacity = _total_rigid_contact_capacity(max_rigid_contact_per_env, num_worlds)

    device = wp.get_device(args.device)
    loaded = load_sap_scene(
        scene_path,
        device=device,
        rigid_contact_max=rigid_contact_capacity,
        strict=True,
        num_worlds=num_worlds,
    )
    model = loaded.sap_model
    if int(model.joint_count) <= 0 or int(model.joint_dof_count) <= 0:
        raise ValueError("SolverSAP requires a scene with at least one joint DOF.")

    solver = SolverSAP(
        model,
        max_rigid_contact=max_rigid_contact_per_env,
        contact_tau_d=_scene_default_shape_tau(scene_path),
        **_scene_solver_kwargs(scene_path, simulation),
    )
    collision_pipeline = SapCollisionPipeline(loaded.collision_model, rigid_contact_max=rigid_contact_capacity)
    contacts = collision_pipeline.contacts()
    state_0 = loaded.sap_state
    state_1 = model.state()
    control = loaded.sap_control
    control_sequence = load_sap_control_sequence(scene_path, model, state_0, control, device=device)

    # None means "no graph": the loop below then steps frame by frame.
    graph = _capture_native_step_graph(
        solver=solver,
        collision_pipeline=collision_pipeline,
        state_0=state_0,
        state_1=state_1,
        control=control,
        contacts=contacts,
        dt=dt,
        device=device,
    )

    # Synchronize before starting and before stopping the clock so the timed
    # region is bracketed by device syncs. perf_counter() is monotonic, so
    # the elapsed time cannot be distorted by wall-clock adjustments.
    wp.synchronize_device(device)
    t0 = time.perf_counter()

    # NOTE(review): the per-frame print below is assumed to sit inside the
    # loops (the extracted source lost its indentation). If so, console I/O
    # is part of the timed region — confirm that is intended.
    # NOTE(review): _capture_native_step_graph restores solver.sim_time and
    # solver.frame_id after stepping, which suggests solver.step may advance
    # them itself; if it does, the manual increments on the non-graph path
    # double-count — confirm against SolverSAP.step.
    if graph is not None:
        for frame_index in range(frames):
            if control_sequence is not None:
                control_sequence.apply(control, frame_index, dt)
            wp.capture_launch(graph)
            solver.sim_time += dt
            solver.frame_id += 1
            print("frame", solver.frame_id, "sim_time", solver.sim_time)
    else:
        for frame_index in range(frames):
            if control_sequence is not None:
                control_sequence.apply(control, frame_index, dt)
            state_0, state_1 = _step_native(
                solver=solver,
                collision_pipeline=collision_pipeline,
                state_0=state_0,
                state_1=state_1,
                control=control,
                contacts=contacts,
                dt=dt,
            )
            solver.sim_time += dt
            solver.frame_id += 1
            print("frame", solver.frame_id, "sim_time", solver.sim_time)

    wp.synchronize_device(device)
    elapsed = time.perf_counter() - t0

    # Guard against a zero-length interval on very short runs.
    fps = frames / elapsed if elapsed > 0.0 else float("inf")
    realtime_ratio = (frames * dt) / elapsed if elapsed > 0.0 else float("inf")
    print(
        f"scene={scene_path}: device={device} dt={dt:.6f} "
        f"frames={frames} num_worlds={num_worlds} "
        f"max_rigid_contact_per_env={max_rigid_contact_per_env} "
        f"rigid_contact_capacity={rigid_contact_capacity} "
        f"cuda_graph={graph is not None}",
        f"elapsed={elapsed:.3f}s fps={fps:.1f}",
        f"realtime_ratio={realtime_ratio:.3f}x",
        flush=True,
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark entry point.

    Options: ``--scene``, ``--duration``, ``--frames``, ``--dt``,
    ``--num-worlds`` and ``--device``. Every option except ``--scene`` and
    ``--duration`` defaults to ``None``.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ("--scene", {"type": str, "default": str(DEFAULT_SCENE), "help": "YAML scene file."}),
        ("--duration", {"type": float, "default": 2.0, "help": "Simulation duration in seconds."}),
        (
            "--frames",
            {"type": int, "default": None, "help": "Number of simulation frames. Overrides --duration."},
        ),
        (
            "--dt",
            {"type": float, "default": None, "help": "Simulation timestep. Defaults to simulation.dt."},
        ),
        ("--num-worlds", {"type": int, "default": None, "help": "Number of replicated worlds."}),
        (
            "--device",
            {"type": str, "default": None, "help": "Warp device, for example cuda:0 or cpu."},
        ),
    )

    parser = argparse.ArgumentParser(description="Benchmark.")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser


def main() -> None:
    """Parse the command line with ``build_parser`` and run ``run_native``."""
    parser = build_parser()
    run_native(parser.parse_args())


if __name__ == "__main__":
    main()