from __future__ import annotations
import numpy as np
import warp as wp
from sim.contact_jacobian import SapContactJacobian, SapContactJacobianResult
from sim.contact_solve import (
SapContactSolve,
SapContactSolveResult,
normalize_sap_line_search_mode,
)
from sim.sap_helpers import (
_copy_public_body_force_to_sap_body_origin_kernel,
_copy_public_body_force_to_sap_body_origin_kernel_body_qd,
_copy_public_to_sap_force_kernel,
_copy_public_to_sap_force_kernel_f64,
_copy_public_to_sap_force_kernel_f64_from_joint_pose,
_copy_sap_to_public_velocity_f32_from_joint_pose,
_copy_sap_to_public_velocity_f64_from_joint_pose,
copy_public_to_sap_velocity_f32,
copy_public_to_sap_velocity_f64,
copy_public_to_sap_velocity_f64_from_joint_pose,
)
from sim.sap_kinematics import sap_eval_fk
from sim.sap_runtime import (
Model,
SAP_JOINT_BALL,
SAP_JOINT_D6,
SAP_JOINT_DISTANCE,
SAP_JOINT_FIXED,
SAP_JOINT_FREE,
SAP_JOINT_PRISMATIC,
SAP_JOINT_REVOLUTE,
SapContacts,
SapControl,
SapModel,
SapState,
sap_contacts_from_newton,
)
# Compile forward kernels only: with this flag False, Warp does not generate
# adjoint (backward) code for kernels it builds afterwards.
# NOTE(review): this mutates the *global* Warp config at import time, so it
# presumably also affects kernels from other modules compiled later — confirm
# that is intended (a per-module `wp.set_module_options` would be narrower).
wp.config.enable_backward = False
# Element-wise narrowing copy: dst[i] = float32(src[i]), one thread per element.
# The launch dimension decides how many elements are copied.
@wp.kernel
def _copy_f64_to_f32(
    src: wp.array(dtype=wp.float64),
    dst: wp.array(dtype=wp.float32),
):
    idx = wp.tid()
    value = src[idx]
    dst[idx] = wp.float32(value)
# Element-wise widening copy: dst[i] = float64(src[i]), one thread per element.
# The launch dimension decides how many elements are copied.
@wp.kernel
def _copy_f32_to_f64(
    src: wp.array(dtype=wp.float32),
    dst: wp.array(dtype=wp.float64),
):
    idx = wp.tid()
    value = src[idx]
    dst[idx] = wp.float64(value)
def _normalize_precision_knob(value: str, *, option_name: str) -> str:
precision = str(value).strip().lower()
if precision == "f32":
precision = "fp32"
elif precision == "f64":
precision = "fp64"
if precision not in {"fp32", "fp64"}:
raise ValueError(f"{option_name} must be 'fp32'/'f32' or 'fp64'/'f64', got {value!r}.")
return precision
SAP_CONTACT_PRESET_VARIANTS = ("approx32", "approx64", "drake")
_CONTACT_PRESET_KWARGS = {
"approx32": {
"use_f64_boundary_pose": False,
"free_motion_solve_precision": "fp32",
"contact_solve_precision": "fp64",
"contact_linear_solve_precision": "fp32",
"sap_contact_weight_precision": "fp32",
"contact_weight_mode": "body_inertia",
"contact_point_mode": "witness_point",
"position_integration": "midpoint",
},
"approx64": {
"use_f64_boundary_pose": True,
"free_motion_solve_precision": "fp64",
"contact_solve_precision": "fp64",
"contact_linear_solve_precision": "fp64",
"sap_contact_weight_precision": "fp64",
"contact_weight_mode": "body_inertia",
"contact_point_mode": "witness_point",
"position_integration": "midpoint",
},
"drake": {
"use_f64_boundary_pose": True,
"free_motion_solve_precision": "fp64",
"contact_solve_precision": "fp64",
"contact_linear_solve_precision": "fp64",
"sap_contact_weight_precision": "fp64",
"contact_weight_mode": "diag_delassus",
"contact_point_mode": "contact_midpoint",
"position_integration": "sap_euler",
},
}
def _normalize_contact_preset_variant(value: str | None) -> str:
if value is None:
return "approx32"
variant = str(value).strip().lower().replace("-", "_")
if variant == "approx_32":
variant = "approx32"
elif variant == "approx_64":
variant = "approx64"
if variant not in _CONTACT_PRESET_KWARGS:
choices = ", ".join(SAP_CONTACT_PRESET_VARIANTS)
raise ValueError(f"contact_preset_variant must be one of {choices}, got {value!r}.")
return variant
def _contact_preset_kwargs(contact_preset_variant: str | None) -> dict[str, object]:
    """Return a caller-owned shallow copy of the keyword defaults for a preset.

    The name is canonicalized by ``_normalize_contact_preset_variant`` (which
    raises ``ValueError`` for unknown presets). Mutating the returned dict does
    not affect ``_CONTACT_PRESET_KWARGS``.
    """
    variant = _normalize_contact_preset_variant(contact_preset_variant)
    return _CONTACT_PRESET_KWARGS[variant].copy()
def _infer_contact_tau_d_fallback(model: Model) -> float:
for name in ("sap_debug_shape_material_tau", "shape_material_tau"):
src = getattr(model, name, None)
if src is None:
continue
if isinstance(src, wp.array):
values = np.asarray(src.numpy(), dtype=np.float64).reshape(-1)
else:
values = np.asarray(src, dtype=np.float64).reshape(-1)
explicit = values[np.isfinite(values) & (values >= 0.0)]
if explicit.size > 0:
return float(explicit[0])
return 0.0
# Rotate ``v`` by quaternion ``q`` (xyzw storage). ``q`` is normalized first, so
# a slightly denormalized quaternion still yields a pure rotation. A quaternion
# with non-positive norm returns ``v`` unchanged.
@wp.func
def _quat_rotate_vec3(q: wp.quat, v: wp.vec3) -> wp.vec3:
    q_len = wp.sqrt(q.x * q.x + q.y * q.y + q.z * q.z + q.w * q.w)
    if q_len <= wp.float32(0.0):
        return v
    scale = wp.float32(1.0) / q_len
    u = wp.vec3(q.x * scale, q.y * scale, q.z * scale)
    s = q.w * scale
    # Expanded q * v * conj(q): v' = v + s*t + u x t with t = 2 (u x v).
    t = wp.float32(2.0) * wp.cross(u, v)
    return v + s * t + wp.cross(u, t)
# One explicit-Euler step of SAP's quaternion kinematics qdot = 0.5 * [0, w] * q
# (angular velocity applied by left multiplication). It is evaluated in float64,
# renormalized (skipped if the norm is zero) and returned as a float32 quaternion.
# Quaternions are stored xyzw.
@wp.func
def _quat_sap_euler_xyzw(q: wp.quat, w: wp.vec3d, dt: wp.float64) -> wp.quat:
    ax = wp.float64(q.x)
    ay = wp.float64(q.y)
    az = wp.float64(q.z)
    aw = wp.float64(q.w)
    k = wp.float64(0.5) * dt
    # Vector part: q_v + k*(q_w*w + w x q_v); scalar part: q_w - k*(w . q_v).
    nx = ax + k * (aw * w.x + w.y * az - w.z * ay)
    ny = ay + k * (aw * w.y + w.z * ax - w.x * az)
    nz = az + k * (aw * w.z + w.x * ay - w.y * ax)
    nw = aw - k * (w.x * ax + w.y * ay + w.z * az)
    mag = wp.sqrt(nx * nx + ny * ny + nz * nz + nw * nw)
    if mag > wp.float64(0.0):
        rcp = wp.float64(1.0) / mag
        nx = nx * rcp
        ny = ny * rcp
        nz = nz * rcp
        nw = nw * rcp
    return wp.quat(wp.float32(nx), wp.float32(ny), wp.float32(nz), wp.float32(nw))
# Float32 variant of the SAP quaternion Euler step qdot = 0.5 * [0, w] * q
# (xyzw storage). The result is renormalized unless its norm is zero.
@wp.func
def _quat_sap_euler_xyzw_f32(q: wp.quat, w: wp.vec3, dt: wp.float32) -> wp.quat:
    k = wp.float32(0.5) * dt
    # Vector part: q_v + k*(q_w*w + w x q_v); scalar part: q_w - k*(w . q_v).
    nx = q.x + k * (q.w * w.x + w.y * q.z - w.z * q.y)
    ny = q.y + k * (q.w * w.y + w.z * q.x - w.x * q.z)
    nz = q.z + k * (q.w * w.z + w.x * q.y - w.y * q.x)
    nw = q.w - k * (w.x * q.x + w.y * q.y + w.z * q.z)
    mag = wp.sqrt(nx * nx + ny * ny + nz * nz + nw * nw)
    if mag > wp.float32(0.0):
        rcp = wp.float32(1.0) / mag
        nx = nx * rcp
        ny = ny * rcp
        nz = nz * rcp
        nw = nw * rcp
    return wp.quat(nx, ny, nz, nw)
# Float64 variant of the SAP quaternion Euler step qdot = 0.5 * [0, w] * q
# (xyzw storage). The result is renormalized unless its norm is zero.
@wp.func
def _quat_sap_euler_xyzw_d(q: wp.quatd, w: wp.vec3d, dt: wp.float64) -> wp.quatd:
    # Components of a quatd are already float64, so no casts are needed.
    ax = q.x
    ay = q.y
    az = q.z
    aw = q.w
    k = wp.float64(0.5) * dt
    # Vector part: q_v + k*(q_w*w + w x q_v); scalar part: q_w - k*(w . q_v).
    nx = ax + k * (aw * w.x + w.y * az - w.z * ay)
    ny = ay + k * (aw * w.y + w.z * ax - w.x * az)
    nz = az + k * (aw * w.z + w.x * ay - w.y * ax)
    nw = aw - k * (w.x * ax + w.y * ay + w.z * az)
    mag = wp.sqrt(nx * nx + ny * ny + nz * nz + nw * nw)
    if mag > wp.float64(0.0):
        rcp = wp.float64(1.0) / mag
        nx = nx * rcp
        ny = ny * rcp
        nz = nz * rcp
        nw = nw * rcp
    return wp.quatd(nx, ny, nz, nw)
# Rotate the float32 quaternion ``q`` (xyzw) by the exponential map of
# ``w_mid * dt``, computed in float64. The increment is left-multiplied, then
# the result is renormalized and cast back to float32. Rotation angles up to
# 1e-12 (and non-positive) leave ``q`` unchanged.
@wp.func
def _quat_midpoint_xyzw(q: wp.quat, w_mid: wp.vec3d, dt: wp.float64) -> wp.quat:
    theta = wp.length(w_mid) * dt
    result = q
    if theta > wp.float64(1.0e-12):
        n = wp.normalize(w_mid)
        half = wp.float64(0.5) * theta
        sn = wp.sin(half)
        cs = wp.cos(half)
        dq = wp.quatd(n.x * sn, n.y * sn, n.z * sn, cs)
        q64 = wp.quatd(
            wp.float64(q.x),
            wp.float64(q.y),
            wp.float64(q.z),
            wp.float64(q.w),
        )
        rotated = dq * q64
        mag = wp.sqrt(rotated.x * rotated.x + rotated.y * rotated.y + rotated.z * rotated.z + rotated.w * rotated.w)
        if mag > wp.float64(0.0):
            rcp = wp.float64(1.0) / mag
            rotated = wp.quatd(rotated.x * rcp, rotated.y * rcp, rotated.z * rcp, rotated.w * rcp)
        result = wp.quat(
            wp.float32(rotated.x),
            wp.float32(rotated.y),
            wp.float32(rotated.z),
            wp.float32(rotated.w),
        )
    return result
# Float32 exponential-map rotation of ``q`` (xyzw) by ``w_mid * dt``. The
# increment is left-multiplied and the result renormalized. Rotation angles up
# to 1e-12 (and non-positive) leave ``q`` unchanged.
@wp.func
def _quat_midpoint_xyzw_f32(q: wp.quat, w_mid: wp.vec3, dt: wp.float32) -> wp.quat:
    theta = wp.length(w_mid) * dt
    result = q
    if theta > wp.float32(1.0e-12):
        n = wp.normalize(w_mid)
        half = wp.float32(0.5) * theta
        sn = wp.sin(half)
        cs = wp.cos(half)
        dq = wp.quat(n.x * sn, n.y * sn, n.z * sn, cs)
        rotated = dq * q
        mag = wp.sqrt(rotated.x * rotated.x + rotated.y * rotated.y + rotated.z * rotated.z + rotated.w * rotated.w)
        if mag > wp.float32(0.0):
            rcp = wp.float32(1.0) / mag
            rotated = wp.quat(rotated.x * rcp, rotated.y * rcp, rotated.z * rcp, rotated.w * rcp)
        result = rotated
    return result
# Float64 exponential-map rotation of ``q`` (xyzw) by ``w_mid * dt``. The
# increment is left-multiplied and the result renormalized. Rotation angles up
# to 1e-12 (and non-positive) leave ``q`` unchanged.
@wp.func
def _quat_midpoint_xyzw_d(q: wp.quatd, w_mid: wp.vec3d, dt: wp.float64) -> wp.quatd:
    theta = wp.length(w_mid) * dt
    result = q
    if theta > wp.float64(1.0e-12):
        n = wp.normalize(w_mid)
        half = wp.float64(0.5) * theta
        sn = wp.sin(half)
        cs = wp.cos(half)
        dq = wp.quatd(n.x * sn, n.y * sn, n.z * sn, cs)
        result = dq * q
        mag = wp.sqrt(result.x * result.x + result.y * result.y + result.z * result.z + result.w * result.w)
        if mag > wp.float64(0.0):
            rcp = wp.float64(1.0) / mag
            result = wp.quatd(
                result.x * rcp,
                result.y * rcp,
                result.z * rcp,
                result.w * rcp,
            )
    return result
# Advance float32 joint coordinates by one step of size ``dt`` (one thread per
# joint). Each coordinate moves by ``h * v``, where ``v`` comes from the float64
# SAP-order velocities ``v_sap``; that same velocity is also written to
# ``joint_qd_out``.
#
# Per joint type (offsets relative to joint_q_start / joint_qd_start):
# - PRISMATIC / REVOLUTE: one coordinate, one DOF.
# - BALL: quaternion in q[0:4] (xyzw), angular velocity in v[0:3]; advanced
#   with ``_quat_sap_euler_xyzw_f32``.
# - FREE / DISTANCE: origin position in q[0:3], quaternion in q[3:7] (xyzw).
#   The SAP input holds angular velocity in v[0:3] and body-origin linear
#   velocity in v[3:6]. The output velocity is written linear-first: the child
#   body's COM velocity (using the *updated* orientation), then the angular
#   velocity.
# - FIXED: nothing is written.
# - D6 and any other type: one coordinate per DOF for ``axis_count`` DOFs.
#
# NOTE(review): ``axis_count`` sums both columns of ``joint_dof_dim``;
# presumably the (linear, angular) axis counts — confirm against the model.
@wp.kernel
def _integrate_generalized_positions_sap_euler(
    joint_type: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    body_com: wp.array(dtype=wp.vec3d),
    joint_q_in: wp.array(dtype=wp.float32),
    v_sap: wp.array(dtype=wp.float64),
    dt: float,
    joint_q_out: wp.array(dtype=wp.float32),
    joint_qd_out: wp.array(dtype=wp.float32),
):
    joint = wp.tid()
    q_start = joint_q_start[joint]
    qd_start = joint_qd_start[joint]
    axis_count = joint_dof_dim[joint, 0] + joint_dof_dim[joint, 1]
    jtype = joint_type[joint]
    h = wp.float32(dt)
    if jtype == SAP_JOINT_PRISMATIC or jtype == SAP_JOINT_REVOLUTE:
        qd = wp.float32(v_sap[qd_start])
        joint_q_out[q_start] = joint_q_in[q_start] + h * qd
        joint_qd_out[qd_start] = qd
        return
    if jtype == SAP_JOINT_BALL:
        q = wp.quat(
            joint_q_in[q_start + 0],
            joint_q_in[q_start + 1],
            joint_q_in[q_start + 2],
            joint_q_in[q_start + 3],
        )
        w = wp.vec3(
            wp.float32(v_sap[qd_start + 0]),
            wp.float32(v_sap[qd_start + 1]),
            wp.float32(v_sap[qd_start + 2]),
        )
        q_next = _quat_sap_euler_xyzw_f32(q, w, h)
        joint_q_out[q_start + 0] = q_next.x
        joint_q_out[q_start + 1] = q_next.y
        joint_q_out[q_start + 2] = q_next.z
        joint_q_out[q_start + 3] = q_next.w
        joint_qd_out[qd_start + 0] = wp.float32(w.x)
        joint_qd_out[qd_start + 1] = wp.float32(w.y)
        joint_qd_out[qd_start + 2] = wp.float32(w.z)
        return
    if jtype == SAP_JOINT_FREE or jtype == SAP_JOINT_DISTANCE:
        q = wp.quat(
            joint_q_in[q_start + 3],
            joint_q_in[q_start + 4],
            joint_q_in[q_start + 5],
            joint_q_in[q_start + 6],
        )
        # SAP order: angular velocity first ...
        w = wp.vec3(
            wp.float32(v_sap[qd_start + 0]),
            wp.float32(v_sap[qd_start + 1]),
            wp.float32(v_sap[qd_start + 2]),
        )
        # ... then the linear velocity of the body origin.
        v_origin = wp.vec3(
            wp.float32(v_sap[qd_start + 3]),
            wp.float32(v_sap[qd_start + 4]),
            wp.float32(v_sap[qd_start + 5]),
        )
        q_next = _quat_sap_euler_xyzw_f32(q, w, h)
        joint_q_out[q_start + 0] = joint_q_in[q_start + 0] + h * v_origin.x
        joint_q_out[q_start + 1] = joint_q_in[q_start + 1] + h * v_origin.y
        joint_q_out[q_start + 2] = joint_q_in[q_start + 2] + h * v_origin.z
        joint_q_out[q_start + 3] = q_next.x
        joint_q_out[q_start + 4] = q_next.y
        joint_q_out[q_start + 5] = q_next.z
        joint_q_out[q_start + 6] = q_next.w
        child = joint_child[joint]
        body_com_f = wp.vec3(
            wp.float32(body_com[child].x),
            wp.float32(body_com[child].y),
            wp.float32(body_com[child].z),
        )
        # Shift the linear velocity from the body origin to the COM:
        # v_com = v_origin + w x (R(q_next) * com_offset).
        r_com = _quat_rotate_vec3(q_next, body_com_f)
        v_com = v_origin + wp.cross(w, r_com)
        # Output layout is linear first, then angular (reverse of the SAP input).
        joint_qd_out[qd_start + 0] = v_com.x
        joint_qd_out[qd_start + 1] = v_com.y
        joint_qd_out[qd_start + 2] = v_com.z
        joint_qd_out[qd_start + 3] = w.x
        joint_qd_out[qd_start + 4] = w.y
        joint_qd_out[qd_start + 5] = w.z
        return
    if jtype == SAP_JOINT_D6:
        for axis in range(axis_count):
            qd = wp.float32(v_sap[qd_start + axis])
            joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd
            joint_qd_out[qd_start + axis] = qd
        return
    if jtype == SAP_JOINT_FIXED:
        return
    # Any other joint type: one coordinate per DOF, same update as D6.
    for axis in range(axis_count):
        qd = wp.float32(v_sap[qd_start + axis])
        joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd
        joint_qd_out[qd_start + axis] = qd
# Midpoint-rule position update for float32 joint coordinates (one thread per
# joint). Coordinates advance by ``h * 0.5 * (v_prev + v_new)``, using the
# float64 SAP-order velocities ``v_prev_sap`` / ``v_new_sap``. ``joint_qd_out``
# receives the *new* velocity. Quaternions are rotated by the exponential map
# of the averaged angular velocity (``_quat_midpoint_xyzw_f32``).
#
# Per joint type (offsets relative to joint_q_start / joint_qd_start):
# - PRISMATIC / REVOLUTE: one coordinate, one DOF.
# - BALL: quaternion in q[0:4] (xyzw), angular velocity in v[0:3].
# - FREE / DISTANCE: origin position in q[0:3], quaternion in q[3:7] (xyzw).
#   SAP input is angular v[0:3] then body-origin linear v[3:6]. The output is
#   written linear-first: the new COM velocity (new velocities, updated
#   orientation), then the new angular velocity.
# - FIXED: nothing is written.
# - D6 and any other type: one coordinate per DOF for ``axis_count`` DOFs.
@wp.kernel
def _integrate_generalized_positions_midpoint(
    joint_type: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    body_com: wp.array(dtype=wp.vec3d),
    joint_q_in: wp.array(dtype=wp.float32),
    v_prev_sap: wp.array(dtype=wp.float64),
    v_new_sap: wp.array(dtype=wp.float64),
    dt: float,
    joint_q_out: wp.array(dtype=wp.float32),
    joint_qd_out: wp.array(dtype=wp.float32),
):
    joint = wp.tid()
    q_start = joint_q_start[joint]
    qd_start = joint_qd_start[joint]
    axis_count = joint_dof_dim[joint, 0] + joint_dof_dim[joint, 1]
    jtype = joint_type[joint]
    h = wp.float32(dt)
    if jtype == SAP_JOINT_PRISMATIC or jtype == SAP_JOINT_REVOLUTE:
        qd = wp.float32(v_new_sap[qd_start])
        qd_mid = wp.float32(0.5) * (wp.float32(v_prev_sap[qd_start]) + qd)
        joint_q_out[q_start] = joint_q_in[q_start] + h * qd_mid
        joint_qd_out[qd_start] = qd
        return
    if jtype == SAP_JOINT_BALL:
        q = wp.quat(
            joint_q_in[q_start + 0],
            joint_q_in[q_start + 1],
            joint_q_in[q_start + 2],
            joint_q_in[q_start + 3],
        )
        w_old = wp.vec3(
            wp.float32(v_prev_sap[qd_start + 0]),
            wp.float32(v_prev_sap[qd_start + 1]),
            wp.float32(v_prev_sap[qd_start + 2]),
        )
        w_new = wp.vec3(
            wp.float32(v_new_sap[qd_start + 0]),
            wp.float32(v_new_sap[qd_start + 1]),
            wp.float32(v_new_sap[qd_start + 2]),
        )
        # Rotate by the averaged angular velocity over the step.
        q_next = _quat_midpoint_xyzw_f32(q, wp.float32(0.5) * (w_old + w_new), h)
        joint_q_out[q_start + 0] = q_next.x
        joint_q_out[q_start + 1] = q_next.y
        joint_q_out[q_start + 2] = q_next.z
        joint_q_out[q_start + 3] = q_next.w
        joint_qd_out[qd_start + 0] = wp.float32(w_new.x)
        joint_qd_out[qd_start + 1] = wp.float32(w_new.y)
        joint_qd_out[qd_start + 2] = wp.float32(w_new.z)
        return
    if jtype == SAP_JOINT_FREE or jtype == SAP_JOINT_DISTANCE:
        q = wp.quat(
            joint_q_in[q_start + 3],
            joint_q_in[q_start + 4],
            joint_q_in[q_start + 5],
            joint_q_in[q_start + 6],
        )
        # SAP order: angular velocity in [0:3], body-origin linear in [3:6].
        w_old = wp.vec3(
            wp.float32(v_prev_sap[qd_start + 0]),
            wp.float32(v_prev_sap[qd_start + 1]),
            wp.float32(v_prev_sap[qd_start + 2]),
        )
        v_origin_old = wp.vec3(
            wp.float32(v_prev_sap[qd_start + 3]),
            wp.float32(v_prev_sap[qd_start + 4]),
            wp.float32(v_prev_sap[qd_start + 5]),
        )
        w_new = wp.vec3(
            wp.float32(v_new_sap[qd_start + 0]),
            wp.float32(v_new_sap[qd_start + 1]),
            wp.float32(v_new_sap[qd_start + 2]),
        )
        v_origin_new = wp.vec3(
            wp.float32(v_new_sap[qd_start + 3]),
            wp.float32(v_new_sap[qd_start + 4]),
            wp.float32(v_new_sap[qd_start + 5]),
        )
        q_next = _quat_midpoint_xyzw_f32(q, wp.float32(0.5) * (w_old + w_new), h)
        v_origin_mid = wp.float32(0.5) * (v_origin_old + v_origin_new)
        joint_q_out[q_start + 0] = joint_q_in[q_start + 0] + h * v_origin_mid.x
        joint_q_out[q_start + 1] = joint_q_in[q_start + 1] + h * v_origin_mid.y
        joint_q_out[q_start + 2] = joint_q_in[q_start + 2] + h * v_origin_mid.z
        joint_q_out[q_start + 3] = q_next.x
        joint_q_out[q_start + 4] = q_next.y
        joint_q_out[q_start + 5] = q_next.z
        joint_q_out[q_start + 6] = q_next.w
        child = joint_child[joint]
        body_com_f = wp.vec3(
            wp.float32(body_com[child].x),
            wp.float32(body_com[child].y),
            wp.float32(body_com[child].z),
        )
        # Output COM velocity from the *new* velocities and updated orientation:
        # v_com = v_origin_new + w_new x (R(q_next) * com_offset).
        r_com = _quat_rotate_vec3(q_next, body_com_f)
        v_com = v_origin_new + wp.cross(w_new, r_com)
        joint_qd_out[qd_start + 0] = v_com.x
        joint_qd_out[qd_start + 1] = v_com.y
        joint_qd_out[qd_start + 2] = v_com.z
        joint_qd_out[qd_start + 3] = w_new.x
        joint_qd_out[qd_start + 4] = w_new.y
        joint_qd_out[qd_start + 5] = w_new.z
        return
    if jtype == SAP_JOINT_D6:
        for axis in range(axis_count):
            qd = wp.float32(v_new_sap[qd_start + axis])
            qd_mid = wp.float32(0.5) * (wp.float32(v_prev_sap[qd_start + axis]) + qd)
            joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd_mid
            joint_qd_out[qd_start + axis] = qd
        return
    if jtype == SAP_JOINT_FIXED:
        return
    # Any other joint type: one coordinate per DOF, same update as D6.
    for axis in range(axis_count):
        qd = wp.float32(v_new_sap[qd_start + axis])
        qd_mid = wp.float32(0.5) * (wp.float32(v_prev_sap[qd_start + axis]) + qd)
        joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd_mid
        joint_qd_out[qd_start + axis] = qd
# Float64 Euler position update (one thread per joint). Each coordinate
# advances by ``h * v`` using the SAP-order velocities ``v_sap``, and that same
# velocity is written to ``joint_qd_out``.
#
# Per joint type (offsets relative to joint_q_start / joint_qd_start):
# - PRISMATIC / REVOLUTE: one coordinate, one DOF.
# - BALL: quaternion in q[0:4] (xyzw), angular velocity in v[0:3].
# - FREE / DISTANCE: origin position in q[0:3], quaternion in q[3:7] (xyzw).
#   SAP input is angular v[0:3] then body-origin linear v[3:6]. The output is
#   written linear-first.
# - FIXED: nothing is written.
# - D6 and any other type: one coordinate per DOF for ``axis_count`` DOFs.
#
# NOTE(review): unlike the float32 kernel, the FREE/DISTANCE linear output here
# is the body-*origin* velocity (no COM shift; this kernel takes no body_com).
# Confirm callers convert it afterwards (e.g. via a SAP-to-public copy).
@wp.kernel
def _integrate_generalized_positions_sap_euler_f64(
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_q_in: wp.array(dtype=wp.float64),
    v_sap: wp.array(dtype=wp.float64),
    dt: wp.float64,
    joint_q_out: wp.array(dtype=wp.float64),
    joint_qd_out: wp.array(dtype=wp.float64),
):
    joint = wp.tid()
    q_start = joint_q_start[joint]
    qd_start = joint_qd_start[joint]
    axis_count = joint_dof_dim[joint, 0] + joint_dof_dim[joint, 1]
    jtype = joint_type[joint]
    h = wp.float64(dt)
    if jtype == SAP_JOINT_PRISMATIC or jtype == SAP_JOINT_REVOLUTE:
        qd = v_sap[qd_start]
        joint_q_out[q_start] = joint_q_in[q_start] + h * qd
        joint_qd_out[qd_start] = qd
        return
    if jtype == SAP_JOINT_BALL:
        q = wp.quatd(
            joint_q_in[q_start + 0],
            joint_q_in[q_start + 1],
            joint_q_in[q_start + 2],
            joint_q_in[q_start + 3],
        )
        w = wp.vec3d(
            v_sap[qd_start + 0],
            v_sap[qd_start + 1],
            v_sap[qd_start + 2],
        )
        q_next = _quat_sap_euler_xyzw_d(q, w, h)
        joint_q_out[q_start + 0] = q_next.x
        joint_q_out[q_start + 1] = q_next.y
        joint_q_out[q_start + 2] = q_next.z
        joint_q_out[q_start + 3] = q_next.w
        joint_qd_out[qd_start + 0] = w.x
        joint_qd_out[qd_start + 1] = w.y
        joint_qd_out[qd_start + 2] = w.z
        return
    if jtype == SAP_JOINT_FREE or jtype == SAP_JOINT_DISTANCE:
        q = wp.quatd(
            joint_q_in[q_start + 3],
            joint_q_in[q_start + 4],
            joint_q_in[q_start + 5],
            joint_q_in[q_start + 6],
        )
        # SAP order: angular velocity in [0:3], body-origin linear in [3:6].
        w = wp.vec3d(
            v_sap[qd_start + 0],
            v_sap[qd_start + 1],
            v_sap[qd_start + 2],
        )
        v_origin = wp.vec3d(
            v_sap[qd_start + 3],
            v_sap[qd_start + 4],
            v_sap[qd_start + 5],
        )
        q_next = _quat_sap_euler_xyzw_d(q, w, h)
        joint_q_out[q_start + 0] = joint_q_in[q_start + 0] + h * v_origin.x
        joint_q_out[q_start + 1] = joint_q_in[q_start + 1] + h * v_origin.y
        joint_q_out[q_start + 2] = joint_q_in[q_start + 2] + h * v_origin.z
        joint_q_out[q_start + 3] = q_next.x
        joint_q_out[q_start + 4] = q_next.y
        joint_q_out[q_start + 5] = q_next.z
        joint_q_out[q_start + 6] = q_next.w
        # Linear-first output, origin-referenced (see NOTE above).
        joint_qd_out[qd_start + 0] = v_origin.x
        joint_qd_out[qd_start + 1] = v_origin.y
        joint_qd_out[qd_start + 2] = v_origin.z
        joint_qd_out[qd_start + 3] = w.x
        joint_qd_out[qd_start + 4] = w.y
        joint_qd_out[qd_start + 5] = w.z
        return
    if jtype == SAP_JOINT_D6:
        for axis in range(axis_count):
            qd = v_sap[qd_start + axis]
            joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd
            joint_qd_out[qd_start + axis] = qd
        return
    if jtype == SAP_JOINT_FIXED:
        return
    # Any other joint type: one coordinate per DOF, same update as D6.
    for axis in range(axis_count):
        qd = v_sap[qd_start + axis]
        joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd
        joint_qd_out[qd_start + axis] = qd
# Float64 midpoint-rule position update (one thread per joint). Coordinates
# advance by ``h * 0.5 * (v_prev + v_new)``, using the SAP-order velocities
# ``v_prev_sap`` / ``v_new_sap``, and ``joint_qd_out`` receives the *new*
# velocity. Quaternions are rotated by the exponential map of the averaged
# angular velocity (``_quat_midpoint_xyzw_d``).
#
# Per joint type (offsets relative to joint_q_start / joint_qd_start):
# - PRISMATIC / REVOLUTE: one coordinate, one DOF.
# - BALL: quaternion in q[0:4] (xyzw), angular velocity in v[0:3].
# - FREE / DISTANCE: origin position in q[0:3], quaternion in q[3:7] (xyzw).
#   SAP input is angular v[0:3] then body-origin linear v[3:6]. The output is
#   written linear-first.
# - FIXED: nothing is written.
# - D6 and any other type: one coordinate per DOF for ``axis_count`` DOFs.
#
# NOTE(review): unlike the float32 midpoint kernel, the FREE/DISTANCE linear
# output here is the new body-*origin* velocity (no COM shift; no body_com
# input). Confirm callers convert it afterwards.
@wp.kernel
def _integrate_generalized_positions_midpoint_f64(
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_q_in: wp.array(dtype=wp.float64),
    v_prev_sap: wp.array(dtype=wp.float64),
    v_new_sap: wp.array(dtype=wp.float64),
    dt: wp.float64,
    joint_q_out: wp.array(dtype=wp.float64),
    joint_qd_out: wp.array(dtype=wp.float64),
):
    joint = wp.tid()
    q_start = joint_q_start[joint]
    qd_start = joint_qd_start[joint]
    axis_count = joint_dof_dim[joint, 0] + joint_dof_dim[joint, 1]
    jtype = joint_type[joint]
    h = wp.float64(dt)
    if jtype == SAP_JOINT_PRISMATIC or jtype == SAP_JOINT_REVOLUTE:
        qd = v_new_sap[qd_start]
        qd_mid = wp.float64(0.5) * (v_prev_sap[qd_start] + qd)
        joint_q_out[q_start] = joint_q_in[q_start] + h * qd_mid
        joint_qd_out[qd_start] = qd
        return
    if jtype == SAP_JOINT_BALL:
        q = wp.quatd(
            joint_q_in[q_start + 0],
            joint_q_in[q_start + 1],
            joint_q_in[q_start + 2],
            joint_q_in[q_start + 3],
        )
        w_old = wp.vec3d(
            v_prev_sap[qd_start + 0],
            v_prev_sap[qd_start + 1],
            v_prev_sap[qd_start + 2],
        )
        w_new = wp.vec3d(
            v_new_sap[qd_start + 0],
            v_new_sap[qd_start + 1],
            v_new_sap[qd_start + 2],
        )
        # Rotate by the averaged angular velocity over the step.
        q_next = _quat_midpoint_xyzw_d(q, wp.float64(0.5) * (w_old + w_new), h)
        joint_q_out[q_start + 0] = q_next.x
        joint_q_out[q_start + 1] = q_next.y
        joint_q_out[q_start + 2] = q_next.z
        joint_q_out[q_start + 3] = q_next.w
        joint_qd_out[qd_start + 0] = w_new.x
        joint_qd_out[qd_start + 1] = w_new.y
        joint_qd_out[qd_start + 2] = w_new.z
        return
    if jtype == SAP_JOINT_FREE or jtype == SAP_JOINT_DISTANCE:
        q = wp.quatd(
            joint_q_in[q_start + 3],
            joint_q_in[q_start + 4],
            joint_q_in[q_start + 5],
            joint_q_in[q_start + 6],
        )
        # SAP order: angular velocity in [0:3], body-origin linear in [3:6].
        w_old = wp.vec3d(
            v_prev_sap[qd_start + 0],
            v_prev_sap[qd_start + 1],
            v_prev_sap[qd_start + 2],
        )
        v_origin_old = wp.vec3d(
            v_prev_sap[qd_start + 3],
            v_prev_sap[qd_start + 4],
            v_prev_sap[qd_start + 5],
        )
        w_new = wp.vec3d(
            v_new_sap[qd_start + 0],
            v_new_sap[qd_start + 1],
            v_new_sap[qd_start + 2],
        )
        v_origin_new = wp.vec3d(
            v_new_sap[qd_start + 3],
            v_new_sap[qd_start + 4],
            v_new_sap[qd_start + 5],
        )
        q_next = _quat_midpoint_xyzw_d(q, wp.float64(0.5) * (w_old + w_new), h)
        v_origin_mid = wp.float64(0.5) * (v_origin_old + v_origin_new)
        joint_q_out[q_start + 0] = joint_q_in[q_start + 0] + h * v_origin_mid.x
        joint_q_out[q_start + 1] = joint_q_in[q_start + 1] + h * v_origin_mid.y
        joint_q_out[q_start + 2] = joint_q_in[q_start + 2] + h * v_origin_mid.z
        joint_q_out[q_start + 3] = q_next.x
        joint_q_out[q_start + 4] = q_next.y
        joint_q_out[q_start + 5] = q_next.z
        joint_q_out[q_start + 6] = q_next.w
        # Linear-first output, origin-referenced (see NOTE above).
        joint_qd_out[qd_start + 0] = v_origin_new.x
        joint_qd_out[qd_start + 1] = v_origin_new.y
        joint_qd_out[qd_start + 2] = v_origin_new.z
        joint_qd_out[qd_start + 3] = w_new.x
        joint_qd_out[qd_start + 4] = w_new.y
        joint_qd_out[qd_start + 5] = w_new.z
        return
    if jtype == SAP_JOINT_D6:
        for axis in range(axis_count):
            qd = v_new_sap[qd_start + axis]
            qd_mid = wp.float64(0.5) * (v_prev_sap[qd_start + axis] + qd)
            joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd_mid
            joint_qd_out[qd_start + axis] = qd
        return
    if jtype == SAP_JOINT_FIXED:
        return
    # Any other joint type: one coordinate per DOF, same update as D6.
    for axis in range(axis_count):
        qd = v_new_sap[qd_start + axis]
        qd_mid = wp.float64(0.5) * (v_prev_sap[qd_start + axis] + qd)
        joint_q_out[q_start + axis] = joint_q_in[q_start + axis] + h * qd_mid
        joint_qd_out[qd_start + axis] = qd
[docs]
class SolverSAP:
"""SAP-native solver pipeline.
This solver wires together the SAP runtime components:
1. `SapFreeMotion.compute()` computes SAP-order free motion.
2. `SapContactJacobian.compute()` computes contact Jacobians and
env-local dynamics matrices from SAP runtime data.
3. `SapContactSolve.solve()` solves the SAP stage2 problem in SAP-order
generalized velocities.
4. SAP-native integration writes `state_out`.
"""
[docs]
def __init__(
    self,
    model: Model,
    *,
    max_rigid_contact: int = 128,
    fallback_mu: float | None = None,
    fallback_stiffness: float = 1.0e10,
    contact_beta: float = 1.0,
    contact_sigma: float = 1.0e-3,
    contact_tau_d: float | None = None,
    block_size: int | None = None,
    diag_shift: float = 0.0,
    max_iterations: int = 100,
    optimality_abs_tol: float = 1.0e-14,
    optimality_rel_tol: float = 1.0e-6,
    cost_abs_tol: float | None = None,
    cost_rel_tol: float | None = None,
    line_search_max_iterations: int | None = None,
    armijo_c: float = 1.0e-4,
    rho: float = 0.8,
    line_search_relative_slop: float | None = None,
    line_search_variant: str = "monotone_decay",
    contact_preset_variant: str | None = None,
    contact_weight_mode: str | None = None,
    contact_point_mode: str | None = None,
    capture_contact_jacobian_snapshots: bool = False,
    use_f64_boundary_pose: bool | None = None,
    free_motion_solve_precision: str | None = None,
    contact_solve_precision: str | None = None,
    contact_linear_solve_precision: str | None = None,
    sap_contact_weight_precision: str | None = None,
    collect_iteration_stats: bool = False,
    check_line_search_errors: bool = False,
    graph_conditional: bool = True,
    position_integration: str | None = None,
):
    """Build the SAP pipeline components and scratch buffers for ``model``.

    Args:
        model: Must be a ``SapModel`` with at least one joint and one joint DOF.
        max_rigid_contact: Rigid-contact capacity handed to the contact
            Jacobian and contact solve. ``get_max_contact_count`` multiplies it
            by the environment count.
        contact_tau_d: Contact ``tau_d`` fallback. ``None`` infers it from the
            model's shape material tau arrays (``0.0`` when none is usable).
        line_search_variant: Canonicalized by ``normalize_sap_line_search_mode``.
            It also picks the defaults for ``cost_abs_tol``/``cost_rel_tol``
            (``0.0``/``5e-3`` for ``"monotone_decay"``, ``1e-30``/``1e-15``
            otherwise) and for ``line_search_max_iterations``.
        line_search_max_iterations: ``None`` selects 100 for ``"exact_root"``
            and 40 otherwise. An explicit value is always honored.
        rho: Must lie in (0, 1) when the line search is ``"armijo_decay"``.
        contact_preset_variant: ``"approx32"`` (default), ``"approx64"`` or
            ``"drake"``. It supplies the value of every preset-controlled
            keyword left as ``None`` (boundary pose, the four precisions,
            contact weight/point modes, position integration).
        *_precision: ``"fp32"``/``"f32"`` or ``"fp64"``/``"f64"``.
        graph_conditional: Must be ``True``; the Python solve loop is gone.
        position_integration: ``"sap_euler"`` or ``"midpoint"``; dashes are
            accepted in place of underscores.

    Raises:
        TypeError: if ``model`` is not a ``SapModel``.
        ValueError: for a model without joint DOFs or any invalid option above.
    """
    if not isinstance(model, SapModel):
        raise TypeError("SolverSAP requires SapModel; convert frontend data before constructing the solver.")
    if int(model.joint_count) <= 0 or int(model.joint_dof_count) <= 0:
        raise ValueError("SolverSAP requires a model with articulated joint DOFs.")
    # Resolve the preset once; explicitly passed (non-None) knobs win over it.
    self.contact_preset_variant = _normalize_contact_preset_variant(contact_preset_variant)
    preset_kwargs = _contact_preset_kwargs(self.contact_preset_variant)
    if use_f64_boundary_pose is None:
        use_f64_boundary_pose = bool(preset_kwargs["use_f64_boundary_pose"])
    if free_motion_solve_precision is None:
        free_motion_solve_precision = str(preset_kwargs["free_motion_solve_precision"])
    if contact_solve_precision is None:
        contact_solve_precision = str(preset_kwargs["contact_solve_precision"])
    if contact_linear_solve_precision is None:
        contact_linear_solve_precision = str(preset_kwargs["contact_linear_solve_precision"])
    if sap_contact_weight_precision is None:
        sap_contact_weight_precision = str(preset_kwargs["sap_contact_weight_precision"])
    if contact_weight_mode is None:
        contact_weight_mode = str(preset_kwargs["contact_weight_mode"])
    if contact_point_mode is None:
        contact_point_mode = str(preset_kwargs["contact_point_mode"])
    if position_integration is None:
        position_integration = str(preset_kwargs["position_integration"])
    self.model = model
    self.max_rigid_contact = int(max_rigid_contact)
    self.max_iterations = int(max_iterations)
    self.optimality_abs_tol = float(optimality_abs_tol)
    self.optimality_rel_tol = float(optimality_rel_tol)
    self.armijo_c = float(armijo_c)
    self.rho = float(rho)
    self.line_search_relative_slop = line_search_relative_slop
    self.line_search_variant = normalize_sap_line_search_mode(line_search_variant)
    if self.line_search_variant == "armijo_decay" and not (0.0 < float(rho) < 1.0):
        raise ValueError(f"rho must lie in (0, 1), got {rho!r}.")
    # Convergence-tolerance defaults depend on the line-search variant.
    if cost_abs_tol is None:
        cost_abs_tol = 0.0 if self.line_search_variant == "monotone_decay" else 1.0e-30
    if cost_rel_tol is None:
        cost_rel_tol = 5.0e-3 if self.line_search_variant == "monotone_decay" else 1.0e-15
    self.cost_abs_tol = float(cost_abs_tol)
    self.cost_rel_tol = float(cost_rel_tol)
    # Use a None sentinel for the variant-dependent default. Comparing an
    # int argument against the literal default (40) cannot tell "not passed"
    # from "explicitly 40", so an explicit 40 would be silently overridden.
    if line_search_max_iterations is None:
        line_search_max_iterations = 100 if self.line_search_variant == "exact_root" else 40
    self.line_search_max_iterations = int(line_search_max_iterations)
    self.capture_contact_jacobian_snapshots = bool(capture_contact_jacobian_snapshots)
    self.use_f64_boundary_pose = bool(use_f64_boundary_pose)
    self.free_motion_solve_precision = _normalize_precision_knob(
        free_motion_solve_precision,
        option_name="free_motion_solve_precision",
    )
    self.contact_solve_precision = _normalize_precision_knob(
        contact_solve_precision,
        option_name="contact_solve_precision",
    )
    self.contact_linear_solve_precision = _normalize_precision_knob(
        contact_linear_solve_precision,
        option_name="contact_linear_solve_precision",
    )
    self.sap_contact_weight_precision = _normalize_precision_knob(
        sap_contact_weight_precision,
        option_name="sap_contact_weight_precision",
    )
    self.collect_iteration_stats = bool(collect_iteration_stats)
    self.check_line_search_errors = bool(check_line_search_errors)
    if not bool(graph_conditional):
        raise ValueError(
            "SolverSAP only supports graph_conditional=True; "
            "the Python contact-solve loop has been removed."
        )
    self.graph_conditional = True
    self.position_integration = str(position_integration).strip().lower().replace("-", "_")
    if self.position_integration not in {"sap_euler", "midpoint"}:
        raise ValueError(
            "position_integration must be 'sap_euler' or 'midpoint', "
            f"got {position_integration!r}."
        )
    shape_fallback_tau_d = (
        _infer_contact_tau_d_fallback(model) if contact_tau_d is None else float(contact_tau_d)
    )
    self.sap_model = model
    self.contact_jacobian = SapContactJacobian(
        self.sap_model,
        max_rigid_contact=self.max_rigid_contact,
        fallback_mu=fallback_mu,
        fallback_stiffness=fallback_stiffness,
        fallback_tau_d=shape_fallback_tau_d,
        contact_weight_mode=contact_weight_mode,
        contact_point_mode=contact_point_mode,
        capture_local_snapshots=self.capture_contact_jacobian_snapshots,
        use_f64_boundary_pose=self.use_f64_boundary_pose,
        free_motion_solve_precision=self.free_motion_solve_precision,
        sap_contact_weight_precision=self.sap_contact_weight_precision,
    )
    # `SapContactJacobian` owns and reuses the free-motion component so
    # contact preparation and solver v* share the exact same buffers/basis.
    self.free_motion = self.contact_jacobian.free_motion
    # NOTE(review): the contact solve receives 2x the fallback tau_d while the
    # Jacobian receives it once. This is presumably the pairwise sum of two
    # shapes' time scales (Drake combines dissipation time scales additively);
    # confirm.
    self.contact_solve = SapContactSolve(
        self.sap_model,
        max_rigid_contact=self.max_rigid_contact,
        contact_beta=contact_beta,
        contact_sigma=contact_sigma,
        contact_tau_d=shape_fallback_tau_d + shape_fallback_tau_d,
        block_size=block_size,
        diag_shift=diag_shift,
        solve_precision=self.contact_solve_precision,
        linear_solve_precision=self.contact_linear_solve_precision,
    )
    # Device-side flag marking whether a cached contact-solve velocity guess exists.
    self._contact_solve_v_guess_active = wp.zeros(1, dtype=int, device=model.device)
    self._zero_joint_f = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
    self._zero_control = SapControl(joint_f=self._zero_joint_f)
    # Preallocated per-DOF / per-body scratch buffers reused across steps.
    self._integrate_joint_qd_f32 = wp.zeros((model.joint_dof_count,), dtype=wp.float32, device=model.device)
    self._integrate_v_f64 = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
    self._boundary_joint_qd_in_sap = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
    self._boundary_joint_f_sap = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
    self._boundary_joint_qd_out_sap_f32 = wp.zeros(
        (model.joint_dof_count,),
        dtype=wp.float32,
        device=model.device,
    )
    self._boundary_body_f_sap = wp.zeros((model.body_count,), dtype=wp.spatial_vectord, device=model.device)
    self.dof_count = int(model.joint_dof_count)
    self.num_envs = int(getattr(model, "world_count", 1))
    self.dof_per_env = self.dof_count // max(self.num_envs, 1)
    self.last_contacts: object | None = None
    self.last_contact_jacobian_result: SapContactJacobianResult | None = None
    self.last_contact_solve_result: SapContactSolveResult | None = None
    self.reset_runtime_state()
@property
def device(self):
    """Warp device of the solver's model, also used for its scratch buffers."""
    model = self.model
    return model.device
def get_max_contact_count(self) -> int:
    """Return the total rigid-contact capacity across all environments.

    This is ``max_rigid_contact`` (the per-environment capacity) multiplied
    by the environment count, clamped to at least one environment. The
    previous docstring called this the per-environment capacity, but the
    value is the per-environment capacity times ``num_envs``.
    """
    return int(self.max_rigid_contact) * max(int(self.num_envs), 1)
def close(self) -> None:
    """Release solver-owned resources.

    Nothing currently needs explicit release. The method exists so the solver
    offers a symmetric create/close lifecycle.
    """
    return None
def reset_runtime_state(self, state: SapState | None = None) -> None:
    """Clear per-rollout bookkeeping before starting a fresh rollout.

    This resets the simulated time and frame counter, the last-step solve
    statistics, the cached contact-solve velocity guess (host flag and device
    buffer) and the cached contact, Jacobian and solve results. ``state`` is
    accepted but not used.
    """
    # Timestep counters.
    self.sim_time, self.frame_id = 0.0, 0
    # Statistics reported for the most recent step.
    self.last_contact_count, self.last_truncated_contact_count = 0, 0
    self.last_solve_iterations, self.last_line_search_iterations = 0, 0
    self.last_converged = True
    # Forget any cached contact-solve guess on both host and device.
    self._has_contact_solve_v_guess = False
    self._contact_solve_v_guess_active.zero_()
    # Drop cached per-step results.
    self.last_contacts = None
    self.last_contact_jacobian_result = None
    self.last_contact_solve_result = None
def _copy_public_joint_velocity_to_sap(self, state_in: SapState, dst: wp.array) -> None:
    """Convert the public joint velocities of ``state_in`` into SAP order in ``dst``.

    Kernel selection:

    * non-float64 ``joint_qd``: float32 kernel, pose from ``body_q``, COM
      offsets from ``model.body_com``;
    * float64 ``joint_qd`` with float64 ``joint_q``: pose taken from ``joint_q``;
    * float64 ``joint_qd`` with ``transformd`` ``body_q``: pose taken from ``body_q``;
    * any other float64 combination raises ``TypeError``.

    The float64 paths read COM offsets from ``self.free_motion.model_body_com``.
    """
    model = self.model
    public_qd = state_in.joint_qd
    if public_qd.dtype != wp.float64:
        wp.launch(
            copy_public_to_sap_velocity_f32,
            dim=model.joint_count,
            inputs=[
                model.joint_type,
                model.joint_child,
                model.joint_qd_start,
                model.joint_dof_dim,
                state_in.body_q,
                model.body_com,
                public_qd,
                dst,
            ],
            device=model.device,
        )
        return
    if state_in.joint_q.dtype == wp.float64:
        wp.launch(
            copy_public_to_sap_velocity_f64_from_joint_pose,
            dim=model.joint_count,
            inputs=[
                model.joint_type,
                model.joint_child,
                model.joint_q_start,
                model.joint_qd_start,
                model.joint_dof_dim,
                self.free_motion.model_body_com,
                state_in.joint_q,
                public_qd,
                dst,
            ],
            device=model.device,
        )
    elif getattr(state_in.body_q, "dtype", None) == wp.transformd:
        wp.launch(
            copy_public_to_sap_velocity_f64,
            dim=model.joint_count,
            inputs=[
                model.joint_type,
                model.joint_child,
                model.joint_qd_start,
                model.joint_dof_dim,
                state_in.body_q,
                self.free_motion.model_body_com,
                public_qd,
                dst,
            ],
            device=model.device,
        )
    else:
        raise TypeError("Public float64 joint velocities require float64 joint_q or transformd body_q.")
def _copy_public_joint_force_to_sap(self, state_in: SapState, control: SapControl, dst: wp.array) -> None:
    """Convert the public joint forces of ``control`` into SAP order in ``dst``.

    Kernel selection mirrors the velocity conversion:

    * non-float64 ``joint_f``: float32 kernel, pose from ``state_in.body_q``,
      COM offsets from ``model.body_com``;
    * float64 ``joint_f`` with float64 ``joint_q``: pose taken from ``joint_q``;
    * float64 ``joint_f`` with ``transformd`` ``body_q``: pose taken from ``body_q``;
    * any other float64 combination raises ``TypeError``.

    The float64 paths read COM offsets from ``self.free_motion.model_body_com``.
    """
    model = self.model
    public_f = control.joint_f
    if public_f.dtype != wp.float64:
        wp.launch(
            _copy_public_to_sap_force_kernel,
            dim=model.joint_count,
            inputs=[
                model.joint_type,
                model.joint_child,
                model.joint_qd_start,
                model.joint_dof_dim,
                state_in.body_q,
                model.body_com,
                public_f,
                dst,
            ],
            device=model.device,
        )
        return
    if state_in.joint_q.dtype == wp.float64:
        wp.launch(
            _copy_public_to_sap_force_kernel_f64_from_joint_pose,
            dim=model.joint_count,
            inputs=[
                model.joint_type,
                model.joint_child,
                model.joint_q_start,
                model.joint_qd_start,
                model.joint_dof_dim,
                self.free_motion.model_body_com,
                state_in.joint_q,
                public_f,
                dst,
            ],
            device=model.device,
        )
    elif getattr(state_in.body_q, "dtype", None) == wp.transformd:
        wp.launch(
            _copy_public_to_sap_force_kernel_f64,
            dim=model.joint_count,
            inputs=[
                model.joint_type,
                model.joint_child,
                model.joint_qd_start,
                model.joint_dof_dim,
                state_in.body_q,
                self.free_motion.model_body_com,
                public_f,
                dst,
            ],
            device=model.device,
        )
    else:
        raise TypeError("Public float64 joint forces require float64 joint_q or transformd body_q.")
def _copy_public_body_force_to_sap(self, state_in: SapState, dst: wp.array) -> None:
    """Convert the public ``state_in.body_f`` wrenches into ``dst`` for the SAP solver.

    ``dst`` is zeroed when the model has no bodies or ``state_in`` carries no ``body_f``.
    Otherwise the body-origin conversion kernel runs over every body. The float64 variant with
    ``self.free_motion.model_body_com`` is used when ``body_q`` is ``wp.transformd``; otherwise
    the float32 variant with ``model.body_com`` is used.

    Raises:
        TypeError: ``body_f`` is not a ``wp.spatial_vector`` array.
    """
    mdl = self.model
    public_wrench = getattr(state_in, "body_f", None)
    has_bodies = bool(int(mdl.body_count))
    if not has_bodies or public_wrench is None:
        dst.zero_()
        return
    if public_wrench.dtype != wp.spatial_vector:
        raise TypeError("Public body_f conversion requires wp.spatial_vector input.")

    # Pick the kernel/COM pair matching the precision of the body poses.
    if getattr(state_in.body_q, "dtype", None) == wp.transformd:
        kernel = _copy_public_body_force_to_sap_body_origin_kernel_body_qd
        body_com = self.free_motion.model_body_com
    else:
        kernel = _copy_public_body_force_to_sap_body_origin_kernel
        body_com = mdl.body_com
    wp.launch(
        kernel,
        dim=mdl.body_count,
        inputs=[state_in.body_q, body_com, public_wrench, dst],
        device=mdl.device,
    )
def _copy_sap_joint_velocity_to_public(
    self,
    state_out: SapState,
    sap_v: wp.array,
) -> None:
    """Convert SAP-ordered velocities ``sap_v`` into the public ``state_out.joint_qd``.

    Both precisions use the joint-pose kernels (``state_out.joint_q`` plus
    ``self.free_motion.model_body_com``). The float64 kernel is chosen when
    ``state_out.joint_qd`` is float64, and the float32 kernel otherwise.

    Raises:
        TypeError: ``joint_qd`` is float64 while ``joint_q`` is not.
    """
    mdl = self.model
    if state_out.joint_qd.dtype != wp.float64:
        kernel = _copy_sap_to_public_velocity_f32_from_joint_pose
    elif state_out.joint_q.dtype == wp.float64:
        kernel = _copy_sap_to_public_velocity_f64_from_joint_pose
    else:
        raise TypeError("Public float64 output joint_qd requires float64 joint_q.")

    # Both kernels share the exact same argument layout.
    wp.launch(
        kernel,
        dim=mdl.joint_count,
        inputs=[
            mdl.joint_type,
            mdl.joint_child,
            mdl.joint_q_start,
            mdl.joint_qd_start,
            mdl.joint_dof_dim,
            self.free_motion.model_body_com,
            state_out.joint_q,
            sap_v,
            state_out.joint_qd,
        ],
        device=mdl.device,
    )
def _prepare_step_boundary(
    self,
    state_in: SapState,
    state_out: SapState,
    control: SapControl,
) -> tuple[SapState, SapState, SapControl, bool]:
    """Translate public-convention step inputs/outputs into SAP-ordered views.

    Ordering tags are read from ``state_in.joint_qd_order``, ``state_in.body_f_order``,
    ``control.joint_f_order`` and ``state_out.joint_qd_order``; a missing tag means ``"sap"``.
    Inputs tagged ``"public"`` are converted into the solver's boundary scratch buffers and
    wrapped in fresh SAP-ordered ``SapState`` / ``SapControl`` objects. SAP-ordered inputs pass
    through unchanged.

    Args:
        state_in: Step input state.
        state_out: Step output state.
        control: Control inputs for this step.

    Returns:
        ``(solve_state, integrate_state_out, solve_control, output_is_public)``. When
        ``state_out`` is public, ``integrate_state_out`` shares ``state_out``'s ``joint_q`` /
        ``body_q`` but writes joint velocities into ``self._boundary_joint_qd_out_sap_f32``.
        ``output_is_public`` tells the caller to convert the solved velocities back into
        ``state_out.joint_qd``.

    Raises:
        ValueError: An ordering tag is neither ``"sap"`` nor ``"public"``.
        TypeError: Public output is requested but ``state_out`` does not use float32
            ``joint_q`` and ``wp.transform`` ``body_q``.
    """
    state_order = getattr(state_in, "joint_qd_order", "sap")
    force_order = getattr(control, "joint_f_order", "sap")
    body_force_order = getattr(state_in, "body_f_order", "sap")
    output_order = getattr(state_out, "joint_qd_order", "sap")
    if state_order not in ("sap", "public"):
        raise ValueError(f"Unsupported SapState joint_qd_order={state_order!r}.")
    # body_f_order is validated like the other tags so a typo cannot silently pass through.
    if body_force_order not in ("sap", "public"):
        raise ValueError(f"Unsupported SapState body_f_order={body_force_order!r}.")

    solve_state = state_in
    solve_control = control
    integrate_state_out = state_out
    output_is_public = output_order == "public"

    # Joint velocities and body forces are converted independently. A public body_f must be
    # converted even when joint_qd is already SAP-ordered.
    if state_order == "public" or body_force_order == "public":
        if state_order == "public":
            self._copy_public_joint_velocity_to_sap(state_in, self._boundary_joint_qd_in_sap)
            joint_qd = self._boundary_joint_qd_in_sap
        else:
            joint_qd = state_in.joint_qd
        body_f = getattr(state_in, "body_f", None)
        if body_force_order == "public":
            self._copy_public_body_force_to_sap(state_in, self._boundary_body_f_sap)
            body_f = self._boundary_body_f_sap
        solve_state = SapState(
            joint_q=state_in.joint_q,
            joint_qd=joint_qd,
            body_q=state_in.body_q,
            body_qd=getattr(state_in, "body_qd", None),
            body_f=body_f,
            joint_qd_order="sap",
            body_f_order="sap",
            requires_grad=bool(getattr(state_in, "requires_grad", False)),
        )

    if force_order == "public":
        self._copy_public_joint_force_to_sap(state_in, control, self._boundary_joint_f_sap)
        solve_control = SapControl(
            joint_f=self._boundary_joint_f_sap,
            joint_target_pos=control.joint_target_pos,
            joint_target_vel=control.joint_target_vel,
            joint_act=control.joint_act,
            joint_f_order="sap",
        )
    elif force_order != "sap":
        raise ValueError(f"Unsupported SapControl joint_f_order={force_order!r}.")

    if output_is_public:
        if state_out.joint_q.dtype != wp.float32 or state_out.body_q.dtype != wp.transform:
            raise TypeError("Public SolverSAP output currently requires float32 joint_q and body_q.")
        # Integrate into a SAP-ordered scratch velocity buffer. The caller converts it into the
        # public state_out.joint_qd afterwards.
        integrate_state_out = SapState(
            joint_q=state_out.joint_q,
            joint_qd=self._boundary_joint_qd_out_sap_f32,
            body_q=state_out.body_q,
            body_qd=getattr(state_out, "body_qd", None),
            body_f=None,
            joint_qd_order="sap",
            body_f_order="sap",
            requires_grad=False,
        )
    elif output_order != "sap":
        raise ValueError(f"Unsupported SapState output joint_qd_order={output_order!r}.")
    return solve_state, integrate_state_out, solve_control, output_is_public
def integrate_particles(self, model: SapModel, state_in: SapState, state_out: SapState, dt: float) -> None:
    """Particle-integration hook; currently a no-op.

    SolverSAP performs no particle integration here. ``state_out`` is not modified, and all
    arguments (``model``, ``state_in``, ``state_out``, ``dt``) are unused.
    """
    return
def _integrate_state(
    self,
    state_in: SapState,
    state_out: SapState,
    solved_v_sap: wp.array,
    dt: float,
) -> None:
    """Integrate generalized positions with the solved velocities and fill ``state_out``.

    The precision of ``state_out.joint_q`` selects the float64 or float32 integration kernels.
    ``self.position_integration == "midpoint"`` uses the midpoint kernels, which also read the
    step's initial velocities ``self.free_motion.joint_qd_sap_input``. Any other value uses the
    SAP Euler kernels. Afterwards:

    * ``state_out.joint_qd`` receives ``solved_v_sap`` unless it already is that array. It is
      copied directly for float64 and narrowed with ``_copy_f64_to_f32`` for float32.
    * ``state_in.body_f`` is forwarded to ``state_out.body_f`` when the model has bodies and
      both states carry a ``body_f`` array.
    * ``integrate_particles`` is invoked, for both precisions.

    Args:
        state_in: State at the start of the step (SAP ordering).
        state_out: Destination state (SAP ordering), written in place.
        solved_v_sap: Solved SAP-ordered velocities; float64 (``step`` widens float32 results
            before calling).
        dt: Timestep forwarded to the kernels.
    """
    model = self.model
    if state_out.joint_q.dtype == wp.float64:
        if self.position_integration == "midpoint":
            wp.launch(
                _integrate_generalized_positions_midpoint_f64,
                dim=model.joint_count,
                inputs=[
                    model.joint_type,
                    model.joint_q_start,
                    model.joint_qd_start,
                    model.joint_dof_dim,
                    state_in.joint_q,
                    self.free_motion.joint_qd_sap_input,
                    solved_v_sap,
                    float(dt),
                    state_out.joint_q,
                    state_out.joint_qd,
                ],
                device=model.device,
            )
        else:
            wp.launch(
                _integrate_generalized_positions_sap_euler_f64,
                dim=model.joint_count,
                inputs=[
                    model.joint_type,
                    model.joint_q_start,
                    model.joint_qd_start,
                    model.joint_dof_dim,
                    state_in.joint_q,
                    solved_v_sap,
                    float(dt),
                    state_out.joint_q,
                    state_out.joint_qd,
                ],
                device=model.device,
            )
        if state_out.joint_qd is not solved_v_sap:
            wp.copy(dest=state_out.joint_qd, src=solved_v_sap)
    else:
        # The float32 kernels need a float32 velocity output. Route it through scratch storage
        # when the caller's joint_qd is float64; it is overwritten with solved_v_sap below.
        joint_qd_out_f32 = (
            state_out.joint_qd
            if state_out.joint_qd.dtype == wp.float32
            else self._integrate_joint_qd_f32
        )
        if self.position_integration == "midpoint":
            wp.launch(
                _integrate_generalized_positions_midpoint,
                dim=model.joint_count,
                inputs=[
                    model.joint_type,
                    model.joint_child,
                    model.joint_q_start,
                    model.joint_qd_start,
                    model.joint_dof_dim,
                    self.free_motion.model_body_com,
                    state_in.joint_q,
                    self.free_motion.joint_qd_sap_input,
                    solved_v_sap,
                    float(dt),
                    state_out.joint_q,
                    joint_qd_out_f32,
                ],
                device=model.device,
            )
        else:
            wp.launch(
                _integrate_generalized_positions_sap_euler,
                dim=model.joint_count,
                inputs=[
                    model.joint_type,
                    model.joint_child,
                    model.joint_q_start,
                    model.joint_qd_start,
                    model.joint_dof_dim,
                    self.free_motion.model_body_com,
                    state_in.joint_q,
                    solved_v_sap,
                    float(dt),
                    state_out.joint_q,
                    joint_qd_out_f32,
                ],
                device=model.device,
            )
        if state_out.joint_qd is not solved_v_sap:
            if state_out.joint_qd.dtype == wp.float64:
                wp.copy(dest=state_out.joint_qd, src=solved_v_sap)
            else:
                wp.launch(
                    _copy_f64_to_f32,
                    dim=model.joint_dof_count,
                    inputs=[solved_v_sap, state_out.joint_qd],
                    device=model.device,
                )

    # Shared tail for both precisions: forward applied body wrenches, then run the particle hook.
    if (
        int(model.body_count) > 0
        and getattr(state_in, "body_f", None) is not None
        and getattr(state_out, "body_f", None) is not None
    ):
        wp.copy(dest=state_out.body_f, src=state_in.body_f)
    self.integrate_particles(model, state_in, state_out, dt)
def step(
    self,
    state_in: SapState,
    state_out: SapState,
    control: SapControl | None,
    contacts: SapContacts | None,
    dt: float,
) -> SapState:
    """Advance one SAP timestep from ``state_in`` to ``state_out``.

    The sequence is:

    1. Validate the arguments.
    2. Convert public-convention boundary data to SAP ordering.
    3. Build the contact Jacobian and run the SAP contact solve, starting from the free-motion
       velocities and warm-started from the previous solution.
    4. Integrate generalized positions with the solved velocities.
    5. Convert velocities back to the public convention if ``state_out`` requests it.
    6. Evaluate forward kinematics into ``state_out``.

    Per-step diagnostics are stored on the ``self.last_*`` attributes.

    Args:
        state_in: Input state; must be a ``SapState`` whose ``requires_grad`` is falsy.
        state_out: Output state, written in place; must be a ``SapState``.
        control: Control inputs; ``None`` uses the solver's zero control.
        contacts: Active contacts; ``None`` means no contacts. Objects that are not
            ``SapContacts`` are converted with ``sap_contacts_from_newton``.
        dt: Timestep.

    Returns:
        ``state_out``.

    Raises:
        NotImplementedError: ``state_in.requires_grad`` is truthy.
        TypeError: An argument has the wrong type, or the public output precision is unsupported.
        ValueError: A boundary ordering tag is unsupported.
    """
    if getattr(state_in, "requires_grad", False):
        raise NotImplementedError("SolverSAP does not support grad states.")
    if not isinstance(state_in, SapState):
        raise TypeError("SolverSAP.step requires SapState input.")
    if not isinstance(state_out, SapState):
        raise TypeError("SolverSAP.step requires SapState output.")
    if control is None:
        control = self._zero_control
    if not isinstance(control, SapControl):
        raise TypeError("SolverSAP.step requires SapControl.")
    if contacts is None:
        contacts = SapContacts()
    if not isinstance(contacts, SapContacts):
        contacts = sap_contacts_from_newton(contacts)
    self.last_contacts = contacts
    solve_state, integrate_state_out, solve_control, output_is_public = self._prepare_step_boundary(
        state_in,
        state_out,
        control,
    )
    contact_result = self.contact_jacobian.compute(
        solve_state,
        contacts,
        control=solve_control,
        dt=float(dt),
    )
    v0 = self.free_motion.joint_qd_sap_input
    solve_result = self.contact_solve.solve(
        contact_result,
        solve_state,
        solve_control,
        float(dt),
        self.free_motion.free_motion_joint_qd_sap,
        v0=v0,
        v_guess=self.contact_solve.v_flat,
        v_guess_active=self._contact_solve_v_guess_active,
        max_iterations=self.max_iterations,
        optimality_abs_tol=self.optimality_abs_tol,
        optimality_rel_tol=self.optimality_rel_tol,
        cost_abs_tol=self.cost_abs_tol,
        cost_rel_tol=self.cost_rel_tol,
        line_search_max_iterations=self.line_search_max_iterations,
        armijo_c=self.armijo_c,
        rho=self.rho,
        line_search_relative_slop=self.line_search_relative_slop,
        line_search_variant=self.line_search_variant,
        collect_iteration_stats=self.collect_iteration_stats,
        check_line_search_errors=self.check_line_search_errors,
        graph_conditional=self.graph_conditional,
    )
    v_integrate = solve_result.v_flat
    # Integration consumes float64 velocities, so widen a float32 solve result first.
    if v_integrate.dtype == wp.float32:
        wp.launch(
            _copy_f32_to_f64,
            dim=self.dof_count,
            inputs=[v_integrate, self._integrate_v_f64],
            device=self.sap_model.device,
        )
        v_integrate = self._integrate_v_f64
    self._integrate_state(solve_state, integrate_state_out, v_integrate, float(dt))
    if output_is_public:
        self._copy_sap_joint_velocity_to_public(state_out, v_integrate)
    sap_eval_fk(self.model, state_out.joint_q, state_out.joint_qd, state_out)
    self.last_contact_jacobian_result = contact_result
    self.last_contact_solve_result = solve_result
    self.last_contact_count = int(contact_result.contact_count)
    self.last_truncated_contact_count = int(contact_result.truncated_contact_count)
    self.last_solve_iterations = int(solve_result.iterations)
    self.last_line_search_iterations = int(solve_result.line_search_iterations)
    self.last_converged = bool(solve_result.converged)
    self._has_contact_solve_v_guess = True
    self.sim_time += float(dt)
    self.frame_id += 1
    return state_out
def notify_model_changed(self, flags: int) -> None:
    """Accept a model-change notification; this method performs no work.

    The standalone SAP components cache buffers that depend on model topology, so a topology
    change requires constructing a new solver. Scalar model data is read directly from the
    model arrays on every step and therefore needs no refresh. ``flags`` is currently unused.
    """
    return
def update_contacts(self, contacts: SapContacts) -> None:
    """Write contact results back into ``contacts``; not supported by SolverSAP yet.

    Raises:
        NotImplementedError: Always; contact-force writeback is not implemented.
    """
    raise NotImplementedError("SolverSAP does not expose contact-force writeback yet.")
# Public API of this module.
__all__ = ["SolverSAP"]