from __future__ import annotations
from dataclasses import dataclass
from functools import cache
import numpy as np
import warp as wp
wp.config.enable_backward = False
from sim.blocked_cholesky import BlockCholeskySolverBatched
from sim.sap_helpers import (
_copy_f32_to_f64,
_copy_f64,
_dense_cholesky_f64,
_dense_subs_f64,
_sap_calc_across_mobilizer_transform,
_sap_calc_mobilizer_velocity_and_bias_body_world,
_sap_compose_acceleration,
_sap_compose_acceleration_f32,
_sap_compose_child_body_transform,
_sap_compose_velocity,
_sap_compose_velocity_f32,
_sap_dynamic_bias_force_body_origin_world,
_sap_gravity_force_body_origin_world,
_sap_revolute_identity_child_kinematics,
_sap_revolute_identity_child_kinematics_f32_core,
_sap_shift_force,
_sap_shift_velocity,
_sap_shift_velocity_f32,
_sap_spatial,
_sap_spatial_inertia_body_origin_world,
_project_tau_no_drives,
_spatiald_from_spatialf,
_spatialf_from_spatiald,
_transformd_compose,
_transformd_from_transformf,
_transformd_identity,
_transformf_compose,
_transformf_from_transformd,
_vec3d_zero,
_vec3f_from_vec3d,
)
from sim.sap_runtime import (
Control,
Model,
SAP_JOINT_BALL,
SAP_JOINT_D6,
SAP_JOINT_DISTANCE,
SAP_JOINT_FIXED,
SAP_JOINT_FREE,
SAP_JOINT_PRISMATIC,
SAP_JOINT_REVOLUTE,
State,
)
# Number of consecutive output columns computed by one thread in the
# `_eval_mass_times_jacobian_batched*` kernels. Those kernels hard-code four
# accumulators (acc0..acc3) and bounds checks for col+1..col+3, so this value
# must stay 4 unless the kernels are changed with it.
_GEMM_COL_BLOCK = wp.constant(4)
# Output tile rows/cols (M, N) and reduction chunk (K), presumably passed to
# `_make_eval_jacobian_transpose_times_p_gemm_tile*`. The call site is not
# visible here, so confirm there.
_JTP_GEMM_TILE_M = 8
_JTP_GEMM_TILE_N = 8
_JTP_GEMM_TILE_K = 32
@dataclass(frozen=True)
class SapFreeMotionResult:
    """Views into the mutable output buffers owned by `SapFreeMotion`.

    The dataclass is frozen, so its fields cannot be rebound. The arrays it
    references are still shared buffers, so their contents may change when
    the owner writes to them again (presumably on the next solve; confirm
    with `SapFreeMotion`).
    """

    # Free-motion velocity per DOF. Presumably filled as v0 + dt * vdot by
    # `_assemble_sap_free_motion_outputs_kernel`; confirm at the call site.
    v_star: wp.array
    # Free-motion generalized acceleration per DOF.
    vdot: wp.array
    # Dynamics-matrix buffer, or None when none is provided (presumably when
    # it was not requested; confirm with `SapFreeMotion`).
    dynamics_matrix: wp.array | None
# One thread per joint. Copies the solved accelerations of the joint's DOFs
# into `sap_vdot` and forms the explicit free-motion velocity
# `sap_v_star = sap_v0 + dt * vdot` for those same DOFs.
@wp.kernel
def _assemble_sap_free_motion_outputs_kernel(
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    sap_v0: wp.array(dtype=wp.float64),
    sap_vdot_solve: wp.array(dtype=wp.float64),
    dt: wp.float64,
    sap_v_star: wp.array(dtype=wp.float64),
    sap_vdot: wp.array(dtype=wp.float64),
):
    j = wp.tid()
    first_dof = joint_qd_start[j]
    # The joint's DOF count is its linear axes plus its angular axes.
    n_axes = joint_dof_dim[j, 0] + joint_dof_dim[j, 1]
    for k in range(n_axes):
        dof = first_dof + k
        acc = sap_vdot_solve[dof]
        sap_vdot[dof] = acc
        sap_v_star[dof] = sap_v0[dof] + dt * acc
# Elementwise widening copy, one thread per element: each float32 spatial
# vector of `src` is converted to float64 and stored at the same index of `dst`.
@wp.kernel
def _copy_spatial_vector_to_spatial_vectord(
    src: wp.array(dtype=wp.spatial_vector),
    dst: wp.array(dtype=wp.spatial_vectord),
):
    idx = wp.tid()
    widened = _spatiald_from_spatialf(src[idx])
    dst[idx] = widened
# Leaf-to-root pass of inverse dynamics, one thread per articulation
# (`art = wp.tid()` indexes `articulation_start`).
#
# The articulation's joints occupy [articulation_start[art],
# articulation_start[art + 1]) and are visited from the last index back to
# the first. For each joint:
#   1. The child body's force is summed:
#          f_s = body_fb_s[child] + body_ft_s[child] + body_f_ext[child]
#   2. The sum goes to `_project_tau_no_drives`, which writes the joint's
#      entries of `tau`. Presumably this is the motion-subspace projection
#      combined with `joint_f`; the helper lives in sim.sap_helpers.
#   3. If the joint has a parent body, the same force is shifted by
#      (p_parent - p_child) and atomically added to the parent's `body_ft_s`.
#      The parent therefore sees its whole subtree when its own joint is
#      visited.
#
# NOTE(review): correctness relies on every parent joint having a lower index
# than its children within the articulation. Confirm the builder guarantees
# this ordering. `body_ft_s` is accumulated in place, so it presumably must
# be zeroed before launch.
@wp.kernel
def _eval_rigid_tau_no_drives(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    joint_f: wp.array(dtype=wp.float64),
    joint_S_s: wp.array(dtype=wp.spatial_vectord),
    body_q: wp.array(dtype=wp.transformd),
    body_fb_s: wp.array(dtype=wp.spatial_vectord),
    body_f_ext: wp.array(dtype=wp.spatial_vectord),
    body_ft_s: wp.array(dtype=wp.spatial_vectord),
    tau: wp.array(dtype=wp.float64),
):
    art = wp.tid()
    start = articulation_start[art]
    end = articulation_start[art + 1]
    count = end - start
    for offset in range(count):
        # Reverse iteration: children are finished before their parents read
        # the accumulated body_ft_s.
        joint = end - offset - 1
        jtype = joint_type[joint]
        parent = joint_parent[joint]
        child = joint_child[joint]
        dof_start = joint_qd_start[joint]
        lin_axis_count = joint_dof_dim[joint, 0]
        ang_axis_count = joint_dof_dim[joint, 1]
        f_b_s = body_fb_s[child]
        f_t_s = body_ft_s[child]
        f_ext = body_f_ext[child]
        f_s = f_b_s + f_t_s + f_ext
        # Writes this joint's generalized forces into `tau`.
        _project_tau_no_drives(
            jtype,
            joint_S_s,
            joint_f,
            dof_start,
            lin_axis_count,
            ang_axis_count,
            f_s,
            tau,
        )
        # A negative parent index means the joint attaches to the world, so
        # there is no body to accumulate into.
        if parent >= 0:
            p_child = wp.transform_get_translation(body_q[child])
            p_parent = wp.transform_get_translation(body_q[parent])
            # Shift the child's force by the child-to-parent offset and add it
            # to the parent's transmitted force.
            wp.atomic_add(body_ft_s, parent, _sap_shift_force(f_s, p_parent - p_child))
@cache
def _make_eval_rigid_tau_no_drives_tiled(tile_size: int):
    """Return a block-cooperative variant of `_eval_rigid_tau_no_drives`.

    One kernel is built per distinct `tile_size` (memoized by
    `functools.cache`). `tile_size` sizes the shared tile used for the
    per-level barrier. It is presumably expected to equal the launch block
    dimension; confirm at the launch site.
    """

    # Thread layout (art, tid):
    #   - `art` selects the articulation.
    #   - `tid` indexes a tile of `tile_size`, so it is presumably the thread's
    #     lane within its block.
    # `articulation_level_joint_index` is a flattened, padded
    # [articulation, level, slot] table. Levels are processed from the deepest
    # down to level 0. Inside a level, threads stride over the slots by
    # `wp.block_dim()`, and negative entries are skipped.
    # Sibling joints in one level can share a parent, so the parent
    # accumulation uses atomic_add.
    @wp.kernel(enable_backward=False)
    def _eval_rigid_tau_no_drives_tiled(
        articulation_level_joint_index: wp.array(dtype=wp.int32),
        max_articulation_level_count: int,
        max_articulation_level_width: int,
        joint_type: wp.array(dtype=int),
        joint_parent: wp.array(dtype=int),
        joint_child: wp.array(dtype=int),
        joint_qd_start: wp.array(dtype=int),
        joint_dof_dim: wp.array(dtype=int, ndim=2),
        joint_f: wp.array(dtype=wp.float64),
        joint_S_s: wp.array(dtype=wp.spatial_vectord),
        body_q: wp.array(dtype=wp.transformd),
        body_fb_s: wp.array(dtype=wp.spatial_vectord),
        body_f_ext: wp.array(dtype=wp.spatial_vectord),
        body_ft_s: wp.array(dtype=wp.spatial_vectord),
        tau: wp.array(dtype=wp.float64),
    ):
        art, tid = wp.tid()
        for reverse_level in range(max_articulation_level_count):
            # Walk the levels leaf-to-root.
            level_index = max_articulation_level_count - reverse_level - 1
            stride = wp.block_dim()
            slot = tid
            while slot < max_articulation_level_width:
                flat_index = (
                    (art * max_articulation_level_count + level_index) * max_articulation_level_width + slot
                )
                joint = articulation_level_joint_index[flat_index]
                if joint >= 0:
                    jtype = joint_type[joint]
                    parent = joint_parent[joint]
                    child = joint_child[joint]
                    dof_start = joint_qd_start[joint]
                    lin_axis_count = joint_dof_dim[joint, 0]
                    ang_axis_count = joint_dof_dim[joint, 1]
                    f_b_s = body_fb_s[child]
                    f_t_s = body_ft_s[child]
                    f_ext = body_f_ext[child]
                    f_s = f_b_s + f_t_s + f_ext
                    _project_tau_no_drives(
                        jtype,
                        joint_S_s,
                        joint_f,
                        dof_start,
                        lin_axis_count,
                        ang_axis_count,
                        f_s,
                        tau,
                    )
                    if parent >= 0:
                        p_child = wp.transform_get_translation(body_q[child])
                        p_parent = wp.transform_get_translation(body_q[parent])
                        wp.atomic_add(body_ft_s, parent, _sap_shift_force(f_s, p_parent - p_child))
                slot = slot + stride
            if tile_size > 1:
                # The reduction result is discarded. The block-wide tile_sum
                # appears to serve only as a barrier, so that the next
                # (shallower) level reads fully accumulated body_ft_s values.
                sync_values = wp.tile_zeros((tile_size,), dtype=wp.int32, storage="shared")
                sync_values[tid] = wp.int32(0)
                _ = wp.tile_sum(sync_values)

    return _eval_rigid_tau_no_drives_tiled
# Forward-kinematics step for the root joint of each articulation, one thread
# per articulation. It writes the child's world pose `body_q`, spatial
# velocity `body_v_s` and spatial acceleration `body_a_s`.
#
# Only the first slot of level 0 of the padded level table is processed.
# NOTE(review): an articulation with more than one joint in level 0 would
# have the remaining root joints skipped here. Confirm articulations always
# have a single root.
#
# `joint_parent` is not read; the parent is treated as the world frame.
# Several other parameters are also unused (body_inertia, body_mass,
# body_com, body_world, gravity, body_I_s, body_f_s), presumably so all ID
# kernels share one argument list.
@wp.kernel
def _eval_rigid_id_root_level_kernel(
    articulation_level_joint_index: wp.array(dtype=wp.int32),
    max_articulation_level_count: int,
    max_articulation_level_width: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=wp.float64),
    joint_qd: wp.array(dtype=wp.float64),
    joint_axis: wp.array(dtype=wp.vec3d),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    body_inertia: wp.array(dtype=wp.mat33d),
    body_mass: wp.array(dtype=wp.float64),
    body_com: wp.array(dtype=wp.vec3d),
    body_q: wp.array(dtype=wp.transformd),
    joint_X_p: wp.array(dtype=wp.transformd),
    joint_X_c: wp.array(dtype=wp.transformd),
    joint_X_c_identity: wp.array(dtype=wp.int32),
    body_world: wp.array(dtype=wp.int32),
    gravity: wp.array(dtype=wp.vec3d),
    joint_S_s: wp.array(dtype=wp.spatial_vectord),
    body_I_s: wp.array(dtype=wp.spatial_matrixd),
    body_v_s: wp.array(dtype=wp.spatial_vectord),
    body_f_s: wp.array(dtype=wp.spatial_vectord),
    body_a_s: wp.array(dtype=wp.spatial_vectord),
):
    art = wp.tid()
    # Flattened index of [art, level 0, slot 0].
    flat_index = art * max_articulation_level_count * max_articulation_level_width
    joint = articulation_level_joint_index[flat_index]
    # A negative entry means the slot is empty (padding); nothing to do.
    if joint < 0:
        return
    child = joint_child[joint]
    jtype = joint_type[joint]
    q_start = joint_q_start[joint]
    qd_start = joint_qd_start[joint]
    lin_axis_count = joint_dof_dim[joint, 0]
    ang_axis_count = joint_dof_dim[joint, 1]
    # No parent body to compose with: the joint's parent frame in world is
    # joint_X_p itself.
    X_wpj = joint_X_p[joint]
    X_wc = _transformd_identity()
    v_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
    A_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
    # Fast path for revolute joints whose child-frame offset is flagged as identity.
    if jtype == SAP_JOINT_REVOLUTE and joint_X_c_identity[joint] != 0:
        X_wc, v_PB_W, A_PB_W = _sap_revolute_identity_child_kinematics(
            joint_axis,
            X_wpj,
            joint_q,
            q_start,
            joint_qd,
            qd_start,
            joint_S_s,
        )
    else:
        # General path:
        #   1. Compute the across-mobilizer transform X_j.
        #   2. Compose it into the child's world pose.
        #   3. Evaluate the mobilizer's world-frame velocity and (presumably
        #      bias) acceleration terms.
        X_j = _sap_calc_across_mobilizer_transform(
            jtype,
            joint_axis,
            qd_start,
            lin_axis_count,
            ang_axis_count,
            joint_q,
            q_start,
        )
        X_wc = _sap_compose_child_body_transform(X_wpj, X_j, joint_X_c[joint], joint_X_c_identity[joint])
        v_PB_W, A_PB_W = _sap_calc_mobilizer_velocity_and_bias_body_world(
            jtype,
            joint_axis,
            lin_axis_count,
            ang_axis_count,
            X_wpj,
            X_j,
            joint_X_c[joint],
            joint_X_c_identity[joint],
            joint_q,
            q_start,
            joint_qd,
            qd_start,
            joint_S_s,
        )
    body_q[child] = X_wc
    # With no parent body, the child's spatial velocity and acceleration are
    # exactly the mobilizer's terms.
    v_s = v_PB_W
    a_s = A_PB_W
    body_v_s[child] = v_s
    body_a_s[child] = a_s
# Forward-kinematics step for one non-root level (`level_index`) of every
# articulation. Thread layout is (art, slot): one thread per slot of the
# padded level table. Presumably this is launched once per level, in
# increasing level order, after `_eval_rigid_id_root_level_kernel`; confirm
# at the launch site.
#
# Each active slot does two things:
#   1. Computes the child's world pose.
#   2. Composes the parent body's spatial velocity/acceleration with the
#      joint's own terms.
# The parent is assumed to exist (`body_q[parent]` is read without a
# `parent >= 0` check). The results are written to body_q, body_v_s and
# body_a_s for the child.
#
# Parameters body_inertia, body_mass, body_com, body_world, gravity,
# body_I_s and body_f_s are not read here.
@wp.kernel
def _eval_rigid_id_level_parallel_1d_nonroot(
    articulation_level_joint_index: wp.array(dtype=wp.int32),
    max_articulation_level_count: int,
    max_articulation_level_width: int,
    level_index: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=wp.float64),
    joint_qd: wp.array(dtype=wp.float64),
    joint_axis: wp.array(dtype=wp.vec3d),
    joint_dof_dim: wp.array(dtype=int, ndim=2),
    body_inertia: wp.array(dtype=wp.mat33d),
    body_mass: wp.array(dtype=wp.float64),
    body_com: wp.array(dtype=wp.vec3d),
    body_q: wp.array(dtype=wp.transformd),
    joint_X_p: wp.array(dtype=wp.transformd),
    joint_X_c: wp.array(dtype=wp.transformd),
    joint_X_c_identity: wp.array(dtype=wp.int32),
    body_world: wp.array(dtype=wp.int32),
    gravity: wp.array(dtype=wp.vec3d),
    joint_S_s: wp.array(dtype=wp.spatial_vectord),
    body_I_s: wp.array(dtype=wp.spatial_matrixd),
    body_v_s: wp.array(dtype=wp.spatial_vectord),
    body_f_s: wp.array(dtype=wp.spatial_vectord),
    body_a_s: wp.array(dtype=wp.spatial_vectord),
):
    art, slot = wp.tid()
    # Flattened index of [art, level_index, slot] in the padded level table.
    flat_index = (
        (art * max_articulation_level_count + level_index) * max_articulation_level_width + slot
    )
    joint = articulation_level_joint_index[flat_index]
    # A negative entry means the slot is empty (padding).
    if joint < 0:
        return
    parent = joint_parent[joint]
    child = joint_child[joint]
    q_start = joint_q_start[joint]
    qd_start = joint_qd_start[joint]
    jtype = joint_type[joint]
    lin_axis_count = joint_dof_dim[joint, 0]
    ang_axis_count = joint_dof_dim[joint, 1]
    # Joint parent frame in world = parent body pose composed with joint_X_p.
    X_wpj = _transformd_compose(body_q[parent], joint_X_p[joint])
    X_wc = _transformd_identity()
    v_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
    A_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
    # Fast path for revolute joints whose child-frame offset is flagged as identity.
    if jtype == SAP_JOINT_REVOLUTE and joint_X_c_identity[joint] != 0:
        X_wc, v_PB_W, A_PB_W = _sap_revolute_identity_child_kinematics(
            joint_axis,
            X_wpj,
            joint_q,
            q_start,
            joint_qd,
            qd_start,
            joint_S_s,
        )
    else:
        X_j = _sap_calc_across_mobilizer_transform(
            jtype,
            joint_axis,
            qd_start,
            lin_axis_count,
            ang_axis_count,
            joint_q,
            q_start,
        )
        X_wc = _sap_compose_child_body_transform(X_wpj, X_j, joint_X_c[joint], joint_X_c_identity[joint])
        v_PB_W, A_PB_W = _sap_calc_mobilizer_velocity_and_bias_body_world(
            jtype,
            joint_axis,
            lin_axis_count,
            ang_axis_count,
            X_wpj,
            X_j,
            joint_X_c[joint],
            joint_X_c_identity[joint],
            joint_q,
            q_start,
            joint_qd,
            qd_start,
            joint_S_s,
        )
    body_q[child] = X_wc
    v_parent_s = body_v_s[parent]
    a_parent_s = body_a_s[parent]
    # Offset from the parent body origin to the child body origin, in world.
    p_parent = wp.transform_get_translation(body_q[parent])
    p_child = wp.transform_get_translation(X_wc)
    p_PB_W = p_child - p_parent
    # Combine the parent's motion (shifted by p_PB_W) with the joint's own terms.
    v_s = _sap_compose_velocity(v_parent_s, p_PB_W, v_PB_W)
    a_s = _sap_compose_acceleration(a_parent_s, v_parent_s, p_PB_W, v_PB_W, A_PB_W)
    body_v_s[child] = v_s
    body_a_s[child] = a_s
@cache
def _make_eval_rigid_id_tiled_articulations(tile_size: int):
    """Return a block-cooperative root-to-leaf forward-kinematics kernel.

    This kernel fuses `_eval_rigid_id_root_level_kernel` and
    `_eval_rigid_id_level_parallel_1d_nonroot` into one launch: a block walks
    all levels of its articulation in order and synchronizes between levels.
    One kernel is built per `tile_size` (memoized by `functools.cache`);
    `tile_size` presumably matches the launch block dimension, so confirm at
    the launch site.
    """

    # Thread layout (art, tid):
    #   - `art` selects the articulation.
    #   - `tid` indexes a `tile_size` shared tile (presumably the lane within
    #     the block).
    # Inside each level, threads stride over slots by `wp.block_dim()`.
    # Negative table entries are skipped, and a negative parent index means
    # the joint attaches to the world.
    # Parameters body_inertia, body_mass, body_com, body_world, gravity,
    # body_I_s and body_f_s are not read here.
    @wp.kernel(enable_backward=False)
    def _eval_rigid_id_tiled_articulations(
        articulation_level_joint_index: wp.array(dtype=wp.int32),
        max_articulation_level_count: int,
        max_articulation_level_width: int,
        joint_type: wp.array(dtype=int),
        joint_parent: wp.array(dtype=int),
        joint_child: wp.array(dtype=int),
        joint_q_start: wp.array(dtype=int),
        joint_qd_start: wp.array(dtype=int),
        joint_q: wp.array(dtype=wp.float64),
        joint_qd: wp.array(dtype=wp.float64),
        joint_axis: wp.array(dtype=wp.vec3d),
        joint_dof_dim: wp.array(dtype=int, ndim=2),
        body_inertia: wp.array(dtype=wp.mat33d),
        body_mass: wp.array(dtype=wp.float64),
        body_com: wp.array(dtype=wp.vec3d),
        body_q: wp.array(dtype=wp.transformd),
        joint_X_p: wp.array(dtype=wp.transformd),
        joint_X_c: wp.array(dtype=wp.transformd),
        joint_X_c_identity: wp.array(dtype=wp.int32),
        body_world: wp.array(dtype=wp.int32),
        gravity: wp.array(dtype=wp.vec3d),
        joint_S_s: wp.array(dtype=wp.spatial_vectord),
        body_I_s: wp.array(dtype=wp.spatial_matrixd),
        body_v_s: wp.array(dtype=wp.spatial_vectord),
        body_f_s: wp.array(dtype=wp.spatial_vectord),
        body_a_s: wp.array(dtype=wp.spatial_vectord),
    ):
        art, tid = wp.tid()
        # Levels run root-to-leaf, so parents are always written before their children read them.
        for level_index in range(max_articulation_level_count):
            stride = wp.block_dim()
            slot = tid
            while slot < max_articulation_level_width:
                flat_index = (
                    (art * max_articulation_level_count + level_index) * max_articulation_level_width + slot
                )
                joint = articulation_level_joint_index[flat_index]
                if joint >= 0:
                    parent = joint_parent[joint]
                    child = joint_child[joint]
                    q_start = joint_q_start[joint]
                    qd_start = joint_qd_start[joint]
                    jtype = joint_type[joint]
                    lin_axis_count = joint_dof_dim[joint, 0]
                    ang_axis_count = joint_dof_dim[joint, 1]
                    # World-attached joints use joint_X_p directly; otherwise
                    # compose it with the parent body's pose.
                    X_wpj = joint_X_p[joint]
                    if parent >= 0:
                        X_wpj = _transformd_compose(body_q[parent], joint_X_p[joint])
                    X_wc = _transformd_identity()
                    v_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
                    A_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
                    # Fast path for revolute joints whose child-frame offset is flagged as identity.
                    if jtype == SAP_JOINT_REVOLUTE and joint_X_c_identity[joint] != 0:
                        X_wc, v_PB_W, A_PB_W = _sap_revolute_identity_child_kinematics(
                            joint_axis,
                            X_wpj,
                            joint_q,
                            q_start,
                            joint_qd,
                            qd_start,
                            joint_S_s,
                        )
                    else:
                        X_j = _sap_calc_across_mobilizer_transform(
                            jtype,
                            joint_axis,
                            qd_start,
                            lin_axis_count,
                            ang_axis_count,
                            joint_q,
                            q_start,
                        )
                        X_wc = _sap_compose_child_body_transform(
                            X_wpj,
                            X_j,
                            joint_X_c[joint],
                            joint_X_c_identity[joint],
                        )
                        v_PB_W, A_PB_W = _sap_calc_mobilizer_velocity_and_bias_body_world(
                            jtype,
                            joint_axis,
                            lin_axis_count,
                            ang_axis_count,
                            X_wpj,
                            X_j,
                            joint_X_c[joint],
                            joint_X_c_identity[joint],
                            joint_q,
                            q_start,
                            joint_qd,
                            qd_start,
                            joint_S_s,
                        )
                    body_q[child] = X_wc
                    # A world-attached child takes the joint terms as-is; otherwise
                    # compose with the parent body's velocity/acceleration.
                    v_s = v_PB_W
                    a_s = A_PB_W
                    if parent >= 0:
                        v_parent_s = body_v_s[parent]
                        a_parent_s = body_a_s[parent]
                        p_parent = wp.transform_get_translation(body_q[parent])
                        p_child = wp.transform_get_translation(X_wc)
                        p_PB_W = p_child - p_parent
                        v_s = _sap_compose_velocity(v_parent_s, p_PB_W, v_PB_W)
                        a_s = _sap_compose_acceleration(a_parent_s, v_parent_s, p_PB_W, v_PB_W, A_PB_W)
                    body_v_s[child] = v_s
                    body_a_s[child] = a_s
                slot = slot + stride
            # The reduction result is discarded. The block-wide tile_sum appears
            # to serve only as a barrier, so the next level sees this level's
            # body_q/body_v_s/body_a_s writes.
            sync_values = wp.tile_zeros((tile_size,), dtype=wp.int32, storage="shared")
            sync_values[tid] = wp.int32(0)
            _ = wp.tile_sum(sync_values)

    return _eval_rigid_id_tiled_articulations
@cache
def _make_eval_rigid_id_tiled_articulations_f32_revolute(tile_size: int):
    """Return a tiled forward-kinematics kernel with a float32 revolute fast path.

    The kernel behaves like `_make_eval_rigid_id_tiled_articulations`, with
    one difference: revolute joints flagged with an identity child offset are
    evaluated in float32. This uses the float32 mirrors `body_q_f`,
    `body_v_s_f` and `body_a_s_f` for parent data. All other joints are
    evaluated in float64.

    Every processed child is written to both the float64 and the float32
    arrays, so either precision can serve as the parent input for the next
    level. One kernel is built per `tile_size` (memoized by
    `functools.cache`).
    """

    # Thread layout (art, tid) and level/slot striding are the same as in the
    # float64 tiled kernel. Parameters body_inertia, body_mass, body_com,
    # body_world, gravity, body_I_s and body_f_s are not read here.
    @wp.kernel(enable_backward=False)
    def _eval_rigid_id_tiled_articulations_f32_revolute(
        articulation_level_joint_index: wp.array(dtype=wp.int32),
        max_articulation_level_count: int,
        max_articulation_level_width: int,
        joint_type: wp.array(dtype=int),
        joint_parent: wp.array(dtype=int),
        joint_child: wp.array(dtype=int),
        joint_q_start: wp.array(dtype=int),
        joint_qd_start: wp.array(dtype=int),
        joint_q: wp.array(dtype=wp.float64),
        joint_qd: wp.array(dtype=wp.float64),
        joint_axis: wp.array(dtype=wp.vec3d),
        joint_dof_dim: wp.array(dtype=int, ndim=2),
        body_inertia: wp.array(dtype=wp.mat33d),
        body_mass: wp.array(dtype=wp.float64),
        body_com: wp.array(dtype=wp.vec3d),
        body_q: wp.array(dtype=wp.transformd),
        joint_X_p: wp.array(dtype=wp.transformd),
        joint_X_c: wp.array(dtype=wp.transformd),
        joint_X_c_identity: wp.array(dtype=wp.int32),
        body_world: wp.array(dtype=wp.int32),
        gravity: wp.array(dtype=wp.vec3d),
        joint_S_s: wp.array(dtype=wp.spatial_vectord),
        body_I_s: wp.array(dtype=wp.spatial_matrixd),
        body_v_s: wp.array(dtype=wp.spatial_vectord),
        body_f_s: wp.array(dtype=wp.spatial_vectord),
        body_a_s: wp.array(dtype=wp.spatial_vectord),
        body_q_f: wp.array(dtype=wp.transform),
        body_v_s_f: wp.array(dtype=wp.spatial_vector),
        body_a_s_f: wp.array(dtype=wp.spatial_vector),
    ):
        art, tid = wp.tid()
        for level_index in range(max_articulation_level_count):
            stride = wp.block_dim()
            slot = tid
            while slot < max_articulation_level_width:
                flat_index = (
                    (art * max_articulation_level_count + level_index) * max_articulation_level_width + slot
                )
                joint = articulation_level_joint_index[flat_index]
                if joint >= 0:
                    parent = joint_parent[joint]
                    child = joint_child[joint]
                    q_start = joint_q_start[joint]
                    qd_start = joint_qd_start[joint]
                    jtype = joint_type[joint]
                    lin_axis_count = joint_dof_dim[joint, 0]
                    ang_axis_count = joint_dof_dim[joint, 1]
                    if jtype == SAP_JOINT_REVOLUTE and joint_X_c_identity[joint] != 0:
                        # float32 path: parent data comes from the float32 mirrors.
                        X_wpj_f = _transformf_from_transformd(joint_X_p[joint])
                        if parent >= 0:
                            X_wpj_f = _transformf_compose(body_q_f[parent], X_wpj_f)
                        X_wc_f, v_PB_f, A_PB_f = _sap_revolute_identity_child_kinematics_f32_core(
                            joint_axis,
                            X_wpj_f,
                            joint_q,
                            q_start,
                            joint_qd,
                            qd_start,
                            joint_S_s,
                        )
                        v_s_f = v_PB_f
                        a_s_f = A_PB_f
                        if parent >= 0:
                            v_parent_f = body_v_s_f[parent]
                            a_parent_f = body_a_s_f[parent]
                            p_parent_f = wp.transform_get_translation(body_q_f[parent])
                            p_child_f = wp.transform_get_translation(X_wc_f)
                            p_PB_W_f = p_child_f - p_parent_f
                            v_s_f = _sap_compose_velocity_f32(v_parent_f, p_PB_W_f, v_PB_f)
                            a_s_f = _sap_compose_acceleration_f32(a_parent_f, v_parent_f, p_PB_W_f, v_PB_f, A_PB_f)
                        # Store float32 results and widen them into the float64 arrays.
                        body_q_f[child] = X_wc_f
                        body_v_s_f[child] = v_s_f
                        body_a_s_f[child] = a_s_f
                        body_q[child] = _transformd_from_transformf(X_wc_f)
                        body_v_s[child] = _spatiald_from_spatialf(v_s_f)
                        body_a_s[child] = _spatiald_from_spatialf(a_s_f)
                    else:
                        # float64 path: same computation as the float64 tiled kernel.
                        X_wpj = joint_X_p[joint]
                        if parent >= 0:
                            X_wpj = _transformd_compose(body_q[parent], joint_X_p[joint])
                        X_wc = _transformd_identity()
                        v_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
                        A_PB_W = _sap_spatial(_vec3d_zero(), _vec3d_zero())
                        X_j = _sap_calc_across_mobilizer_transform(
                            jtype,
                            joint_axis,
                            qd_start,
                            lin_axis_count,
                            ang_axis_count,
                            joint_q,
                            q_start,
                        )
                        X_wc = _sap_compose_child_body_transform(
                            X_wpj,
                            X_j,
                            joint_X_c[joint],
                            joint_X_c_identity[joint],
                        )
                        v_PB_W, A_PB_W = _sap_calc_mobilizer_velocity_and_bias_body_world(
                            jtype,
                            joint_axis,
                            lin_axis_count,
                            ang_axis_count,
                            X_wpj,
                            X_j,
                            joint_X_c[joint],
                            joint_X_c_identity[joint],
                            joint_q,
                            q_start,
                            joint_qd,
                            qd_start,
                            joint_S_s,
                        )
                        body_q[child] = X_wc
                        v_s = v_PB_W
                        a_s = A_PB_W
                        if parent >= 0:
                            v_parent_s = body_v_s[parent]
                            a_parent_s = body_a_s[parent]
                            p_parent = wp.transform_get_translation(body_q[parent])
                            p_child = wp.transform_get_translation(X_wc)
                            p_PB_W = p_child - p_parent
                            v_s = _sap_compose_velocity(v_parent_s, p_PB_W, v_PB_W)
                            a_s = _sap_compose_acceleration(a_parent_s, v_parent_s, p_PB_W, v_PB_W, A_PB_W)
                        # Store float64 results and narrow them into the float32 mirrors.
                        body_v_s[child] = v_s
                        body_a_s[child] = a_s
                        body_q_f[child] = _transformf_from_transformd(X_wc)
                        body_v_s_f[child] = _spatialf_from_spatiald(v_s)
                        body_a_s_f[child] = _spatialf_from_spatiald(a_s)
                slot = slot + stride
            # The reduction result is discarded. The tile_sum appears to act only
            # as a block-level barrier between levels.
            sync_values = wp.tile_zeros((tile_size,), dtype=wp.int32, storage="shared")
            sync_values[tid] = wp.int32(0)
            _ = wp.tile_sum(sync_values)

    return _eval_rigid_id_tiled_articulations_f32_revolute
# Per-body dynamics terms, one thread per joint, acting on that joint's child
# body. For each child body:
#   - `body_I_s` receives the spatial inertia about the body origin in world.
#   - `body_f_s` receives I_s * a_s + dynamic bias(v_s) - gravity force.
# The inputs body_q, body_v_s and body_a_s are presumably produced by the
# forward-kinematics kernels above; confirm launch order.
@wp.kernel
def _eval_rigid_body_dynamics_parallel(
    joint_child: wp.array(dtype=int),
    body_inertia: wp.array(dtype=wp.mat33d),
    body_mass: wp.array(dtype=wp.float64),
    body_com: wp.array(dtype=wp.vec3d),
    body_q: wp.array(dtype=wp.transformd),
    body_world: wp.array(dtype=wp.int32),
    gravity: wp.array(dtype=wp.vec3d),
    body_I_s: wp.array(dtype=wp.spatial_matrixd),
    body_v_s: wp.array(dtype=wp.spatial_vectord),
    body_f_s: wp.array(dtype=wp.spatial_vectord),
    body_a_s: wp.array(dtype=wp.spatial_vectord),
):
    joint = wp.tid()
    child = joint_child[joint]
    # Joints without a child body contribute nothing.
    if child < 0:
        return
    X_wc = body_q[child]
    v_s = body_v_s[child]
    a_s = body_a_s[child]
    mass = body_mass[child]
    body_inertia_d = body_inertia[child]
    body_com_d = body_com[child]
    world_idx = body_world[child]
    # Negative world indices fall back to the gravity vector of world 0.
    world_g = gravity[wp.max(world_idx, 0)]
    I_s = _sap_spatial_inertia_body_origin_world(X_wc, body_inertia_d, mass, body_com_d)
    f_b_s = I_s * a_s + _sap_dynamic_bias_force_body_origin_world(
        X_wc,
        body_inertia_d,
        mass,
        body_com_d,
        v_s,
    )
    f_g_s = _sap_gravity_force_body_origin_world(X_wc, mass, body_com_d, world_g)
    body_f_s[child] = f_b_s - f_g_s
    body_I_s[child] = I_s
# Fills the dense articulation Jacobian J in float64.
#
# Thread layout (art, joint_slot, axis):
#   - Each thread owns row `joint_slot * 6 + axis` of the articulation's
#     row-major J block. That block starts at articulation_J_start[art] and
#     has `articulation_dof_count` columns.
#   - Threads past the articulation's joint count exit early.
#
# Starting at the row's joint, the thread walks up the `joint_ancestor` chain
# to -1. For every DOF of each joint on that chain it writes the motion
# subspace column `joint_S_s`, shifted from the column body origin to the row
# body origin.
#
# Entries for non-ancestor joints are never written, so J presumably has to
# be zeroed beforehand (confirm at the launch site).
# `joint_qd_start` is indexed at `joint + 1` and at `joint_end`, so it must
# carry a one-past-the-end entry. `max_articulation_joint_count` is not read
# here.
@wp.kernel
def _eval_rigid_jacobian_parallel(
    articulation_start: wp.array(dtype=int),
    articulation_J_start: wp.array(dtype=int),
    joint_ancestor: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vectord),
    body_q: wp.array(dtype=wp.transformd),
    max_articulation_joint_count: int,
    J: wp.array(dtype=wp.float64),
):
    art, joint_slot, axis = wp.tid()
    joint_start = articulation_start[art]
    joint_end = articulation_start[art + 1]
    joint_count = joint_end - joint_start
    # The launch grid is sized for the largest articulation; skip padding slots.
    if joint_slot >= joint_count:
        return
    articulation_dof_start = joint_qd_start[joint_start]
    articulation_dof_end = joint_qd_start[joint_end]
    articulation_dof_count = articulation_dof_end - articulation_dof_start
    J_offset = articulation_J_start[art]
    row = joint_slot * 6 + axis
    joint = joint_start + joint_slot
    row_body = joint_child[joint]
    p_row = wp.transform_get_translation(body_q[row_body])
    # Walk from this joint up to the root; -1 terminates the chain.
    while joint != -1:
        col_body = joint_child[joint]
        p_col = wp.transform_get_translation(body_q[col_body])
        p_col_row = p_row - p_col
        joint_dof_start = joint_qd_start[joint]
        joint_dof_end = joint_qd_start[joint + 1]
        joint_dof_count = joint_dof_end - joint_dof_start
        for dof in range(joint_dof_count):
            col = (joint_dof_start - articulation_dof_start) + dof
            S_Bc = joint_S_s[joint_dof_start + dof]
            # Re-express the column's motion vector at the row body's origin.
            S_Brow = _sap_shift_velocity(S_Bc, p_col_row)
            J[J_offset + row * articulation_dof_count + col] = wp.float64(S_Brow[axis])
        joint = joint_ancestor[joint]
# Float32-arithmetic variant of `_eval_rigid_jacobian_parallel`. It uses the
# same thread layout, ancestor walk and J layout. The body positions, their
# difference and the motion-vector shift are computed in float32, and each
# result is widened back to float64 when stored.
#
# NOTE(review): differences of large world positions lose precision in
# float32, so this is presumably a speed/accuracy trade-off chosen by the
# caller. `max_articulation_joint_count` is not read here.
@wp.kernel
def _eval_rigid_jacobian_parallel_f32_math(
    articulation_start: wp.array(dtype=int),
    articulation_J_start: wp.array(dtype=int),
    joint_ancestor: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vectord),
    body_q: wp.array(dtype=wp.transformd),
    max_articulation_joint_count: int,
    J: wp.array(dtype=wp.float64),
):
    art, joint_slot, axis = wp.tid()
    joint_start = articulation_start[art]
    joint_end = articulation_start[art + 1]
    joint_count = joint_end - joint_start
    # Skip padding slots beyond this articulation's joint count.
    if joint_slot >= joint_count:
        return
    articulation_dof_start = joint_qd_start[joint_start]
    articulation_dof_end = joint_qd_start[joint_end]
    articulation_dof_count = articulation_dof_end - articulation_dof_start
    J_offset = articulation_J_start[art]
    row = joint_slot * 6 + axis
    joint = joint_start + joint_slot
    row_body = joint_child[joint]
    p_row = _vec3f_from_vec3d(wp.transform_get_translation(body_q[row_body]))
    # Walk from this joint up to the root; -1 terminates the chain.
    while joint != -1:
        col_body = joint_child[joint]
        p_col = _vec3f_from_vec3d(wp.transform_get_translation(body_q[col_body]))
        p_col_row = p_row - p_col
        joint_dof_start = joint_qd_start[joint]
        joint_dof_end = joint_qd_start[joint + 1]
        joint_dof_count = joint_dof_end - joint_dof_start
        for dof in range(joint_dof_count):
            col = (joint_dof_start - articulation_dof_start) + dof
            S_Bc = _spatialf_from_spatiald(joint_S_s[joint_dof_start + dof])
            S_Brow = _sap_shift_velocity_f32(S_Bc, p_col_row)
            J[J_offset + row * articulation_dof_count + col] = wp.float64(S_Brow[axis])
        joint = joint_ancestor[joint]
# Scatters each body's 6x6 spatial inertia onto the diagonal of the
# articulation's block-diagonal mass matrix M.
#
# Thread layout (art, joint_slot, element):
#   - `element` is split into (row, col) = divmod(element, 6), so it
#     presumably ranges over 0..35.
#   - M is row-major, (6 * joint_count) square, and starts at
#     articulation_M_start[art].
#
# Off-diagonal blocks, and the blocks of joints without a child body, are
# never written, so M presumably has to be zeroed beforehand (confirm).
# `max_articulation_joint_count` is not read here.
@wp.kernel
def _eval_rigid_mass_parallel(
    articulation_start: wp.array(dtype=int),
    articulation_M_start: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    body_I_s: wp.array(dtype=wp.spatial_matrixd),
    max_articulation_joint_count: int,
    M: wp.array(dtype=wp.float64),
):
    art, joint_slot, element = wp.tid()
    joint_start = articulation_start[art]
    joint_end = articulation_start[art + 1]
    joint_count = joint_end - joint_start
    # Skip padding slots beyond this articulation's joint count.
    if joint_slot >= joint_count:
        return
    row = element // 6
    col = element - row * 6
    # Row stride of this articulation's M.
    stride = joint_count * 6
    joint = joint_start + joint_slot
    body = joint_child[joint]
    if body < 0:
        return
    I = body_I_s[body]
    M_offset = articulation_M_start[art]
    M[M_offset + (joint_slot * 6 + row) * stride + joint_slot * 6 + col] = wp.float64(I[row, col])
# Batched P = M @ J in float64. It exploits the fact that M is
# block-diagonal with 6x6 blocks: row `row` of M is only read inside the
# 6-column block that contains it.
#
# Thread layout (art, row, col_block):
#   - Each thread computes up to _GEMM_COL_BLOCK (= 4) consecutive output
#     columns starting at `col_block * _GEMM_COL_BLOCK`.
#   - Columns at or beyond `j_cols` are skipped.
#
# M is row-major with `m_rows` columns, starting at articulation_m_start.
# J and P are row-major with `j_cols` columns and share the same start offset
# `articulation_j_start`.
@wp.kernel
def _eval_mass_times_jacobian_batched(
    articulation_m_rows: wp.array(dtype=wp.int32),
    articulation_j_cols: wp.array(dtype=wp.int32),
    articulation_m_start: wp.array(dtype=wp.int32),
    articulation_j_start: wp.array(dtype=wp.int32),
    m: wp.array(dtype=wp.float64),
    j: wp.array(dtype=wp.float64),
    p: wp.array(dtype=wp.float64),
):
    art, row, col_block = wp.tid()
    m_rows = articulation_m_rows[art]
    j_cols = articulation_j_cols[art]
    col = col_block * _GEMM_COL_BLOCK
    # The launch grid is sized for the largest articulation; skip out-of-range work.
    if row >= m_rows or col >= j_cols:
        return
    m_start = articulation_m_start[art]
    j_start = articulation_j_start[art]
    acc0 = wp.float64(0.0)
    acc1 = wp.float64(0.0)
    acc2 = wp.float64(0.0)
    acc3 = wp.float64(0.0)
    # First row/column index of the 6x6 diagonal block that contains `row`.
    block_row = (row // 6) * 6
    m_row_start = m_start + row * m_rows + block_row
    j_block_start = j_start + block_row * j_cols + col
    for k_local in range(6):
        mval = m[m_row_start + k_local]
        base = j_block_start + k_local * j_cols
        j0 = j[base]
        j1 = wp.float64(0.0)
        j2 = wp.float64(0.0)
        j3 = wp.float64(0.0)
        # Guard the trailing columns of a partial final column block.
        if col + 1 < j_cols:
            j1 = j[base + 1]
        if col + 2 < j_cols:
            j2 = j[base + 2]
        if col + 3 < j_cols:
            j3 = j[base + 3]
        acc0 = acc0 + mval * j0
        acc1 = acc1 + mval * j1
        acc2 = acc2 + mval * j2
        acc3 = acc3 + mval * j3
    out = j_start + row * j_cols + col
    p[out] = acc0
    if col + 1 < j_cols:
        p[out + 1] = acc1
    if col + 2 < j_cols:
        p[out + 2] = acc2
    if col + 3 < j_cols:
        p[out + 3] = acc3
# Float32-arithmetic variant of `_eval_mass_times_jacobian_batched`. It uses
# the same thread layout, block-diagonal shortcut and memory layout. The M and
# J entries are narrowed to float32 and accumulated in float32, and the
# results are widened back to float64 when stored into P.
@wp.kernel
def _eval_mass_times_jacobian_batched_f32_math(
    articulation_m_rows: wp.array(dtype=wp.int32),
    articulation_j_cols: wp.array(dtype=wp.int32),
    articulation_m_start: wp.array(dtype=wp.int32),
    articulation_j_start: wp.array(dtype=wp.int32),
    m: wp.array(dtype=wp.float64),
    j: wp.array(dtype=wp.float64),
    p: wp.array(dtype=wp.float64),
):
    art, row, col_block = wp.tid()
    m_rows = articulation_m_rows[art]
    j_cols = articulation_j_cols[art]
    col = col_block * _GEMM_COL_BLOCK
    # Skip out-of-range work for smaller articulations.
    if row >= m_rows or col >= j_cols:
        return
    m_start = articulation_m_start[art]
    j_start = articulation_j_start[art]
    acc0 = wp.float32(0.0)
    acc1 = wp.float32(0.0)
    acc2 = wp.float32(0.0)
    acc3 = wp.float32(0.0)
    # First row/column index of the 6x6 diagonal block that contains `row`.
    block_row = (row // 6) * 6
    m_row_start = m_start + row * m_rows + block_row
    j_block_start = j_start + block_row * j_cols + col
    for k_local in range(6):
        mval = wp.float32(m[m_row_start + k_local])
        base = j_block_start + k_local * j_cols
        j0 = wp.float32(j[base])
        j1 = wp.float32(0.0)
        j2 = wp.float32(0.0)
        j3 = wp.float32(0.0)
        # Guard the trailing columns of a partial final column block.
        if col + 1 < j_cols:
            j1 = wp.float32(j[base + 1])
        if col + 2 < j_cols:
            j2 = wp.float32(j[base + 2])
        if col + 3 < j_cols:
            j3 = wp.float32(j[base + 3])
        acc0 = acc0 + mval * j0
        acc1 = acc1 + mval * j1
        acc2 = acc2 + mval * j2
        acc3 = acc3 + mval * j3
    out = j_start + row * j_cols + col
    p[out] = wp.float64(acc0)
    if col + 1 < j_cols:
        p[out + 1] = wp.float64(acc1)
    if col + 2 < j_cols:
        p[out + 2] = wp.float64(acc2)
    if col + 3 < j_cols:
        p[out + 3] = wp.float64(acc3)
# Batched H = J^T @ P in float64. H is (j_cols x j_cols) and row-major,
# starting at articulation_h_start[art].
#
# Thread layout (art, row, col_block):
#   - Each thread covers up to _GEMM_COL_BLOCK (= 4) consecutive columns
#     starting at `col_block * _GEMM_COL_BLOCK`.
#   - Only entries with col <= row (lower triangle) are computed.
#   - Each computed entry is also mirrored to (col, row).
#
# The mirroring assumes H is symmetric. That holds when P = M @ J with a
# symmetric M (presumably the case upstream; confirm).
@wp.kernel
def _eval_jacobian_transpose_times_p_batched(
    articulation_j_rows: wp.array(dtype=wp.int32),
    articulation_j_cols: wp.array(dtype=wp.int32),
    articulation_j_start: wp.array(dtype=wp.int32),
    articulation_h_start: wp.array(dtype=wp.int32),
    j: wp.array(dtype=wp.float64),
    p: wp.array(dtype=wp.float64),
    h: wp.array(dtype=wp.float64),
):
    art, row, col_block = wp.tid()
    j_rows = articulation_j_rows[art]
    j_cols = articulation_j_cols[art]
    col = col_block * _GEMM_COL_BLOCK
    if row >= j_cols or col >= j_cols:
        return
    # The whole column block lies strictly above the diagonal; the mirror writes cover it.
    if col > row:
        return
    j_start = articulation_j_start[art]
    h_start = articulation_h_start[art]
    acc0 = wp.float64(0.0)
    acc1 = wp.float64(0.0)
    acc2 = wp.float64(0.0)
    acc3 = wp.float64(0.0)
    for k in range(j_rows):
        # J^T[row, k] == J[k, row]; J and P share the same layout and start offset.
        jval = j[j_start + k * j_cols + row]
        base = j_start + k * j_cols + col
        p0 = p[base]
        p1 = wp.float64(0.0)
        p2 = wp.float64(0.0)
        p3 = wp.float64(0.0)
        # Only load columns that are in range and on or below the diagonal.
        if col + 1 < j_cols and col + 1 <= row:
            p1 = p[base + 1]
        if col + 2 < j_cols and col + 2 <= row:
            p2 = p[base + 2]
        if col + 3 < j_cols and col + 3 <= row:
            p3 = p[base + 3]
        acc0 = acc0 + jval * p0
        acc1 = acc1 + jval * p1
        acc2 = acc2 + jval * p2
        acc3 = acc3 + jval * p3
    out = h_start + row * j_cols + col
    h[out] = acc0
    # Mirror off-diagonal entries into the upper triangle.
    if col != row:
        h[h_start + col * j_cols + row] = acc0
    if col + 1 < j_cols and col + 1 <= row:
        c = col + 1
        h[out + 1] = acc1
        if c != row:
            h[h_start + c * j_cols + row] = acc1
    if col + 2 < j_cols and col + 2 <= row:
        c = col + 2
        h[out + 2] = acc2
        if c != row:
            h[h_start + c * j_cols + row] = acc2
    if col + 3 < j_cols and col + 3 <= row:
        c = col + 3
        h[out + 3] = acc3
        if c != row:
            h[h_start + c * j_cols + row] = acc3
@cache
def _make_eval_jacobian_transpose_times_p_tiled(tile_size: int):
    """Return a block-reduction kernel for H = J^T @ P (float64).

    Each block of threads computes one entry of the upper triangle
    (col >= row) of H by reducing over J's rows, then writes it to both
    (row, col) and (col, row). One kernel is built per `tile_size` (memoized
    by `functools.cache`). `tile_size` sizes the shared reduction tile and is
    presumably equal to the launch block dimension; confirm at the launch
    site.
    """

    # Thread layout (art, upper, tid):
    #   - `upper` is a linear index into the row-major upper triangle.
    #   - `tid` indexes the `tile_size` shared tile.
    # The early return below depends only on (art, upper), so presumably the
    # whole block takes it together, which keeps the later tile_sum collective
    # safe.
    @wp.kernel(enable_backward=False)
    def _eval_jacobian_transpose_times_p_tiled(
        articulation_j_rows: wp.array(dtype=wp.int32),
        articulation_j_cols: wp.array(dtype=wp.int32),
        articulation_j_start: wp.array(dtype=wp.int32),
        articulation_h_start: wp.array(dtype=wp.int32),
        j: wp.array(dtype=wp.float64),
        p: wp.array(dtype=wp.float64),
        h: wp.array(dtype=wp.float64),
    ):
        art, upper, tid = wp.tid()
        j_rows = articulation_j_rows[art]
        j_cols = articulation_j_cols[art]
        upper_count = (j_cols * (j_cols + 1)) // 2
        if upper >= upper_count:
            return
        # Decode `upper` into (row, col): row r has (j_cols - r) entries, col = r .. j_cols-1.
        row = wp.int32(0)
        rem = wp.int32(upper)
        row_count = wp.int32(j_cols)
        while rem >= row_count:
            rem = rem - row_count
            row = row + 1
            row_count = row_count - 1
        col = row + rem
        j_start = articulation_j_start[art]
        # Each thread sums a strided subset of k; the block then reduces the partial sums.
        acc = wp.float64(0.0)
        k = tid
        stride = wp.block_dim()
        while k < j_rows:
            base = j_start + k * j_cols
            acc = acc + j[base + row] * p[base + col]
            k = k + stride
        values = wp.tile_zeros((tile_size,), dtype=wp.float64, storage="shared")
        values[tid] = acc
        total = wp.tile_sum(values)
        h_start = articulation_h_start[art]
        wp.tile_store(h, total, offset=h_start + row * j_cols + col)
        # Mirror into the lower triangle (assumes H is symmetric).
        if col != row:
            wp.tile_store(h, total, offset=h_start + col * j_cols + row)

    return _eval_jacobian_transpose_times_p_tiled
@cache
def _make_eval_jacobian_transpose_times_p_tiled_f32_math(tile_size: int):
    """Return the float32-accumulation variant of the tiled H = J^T @ P kernel.

    The work split is the same as `_make_eval_jacobian_transpose_times_p_tiled`.
    Each thread's partial dot product is formed in float32. The block-wide
    reduction and the stored result are float64, because the shared tile is
    allocated as float64. One kernel is built per `tile_size` (memoized by
    `functools.cache`).
    """

    # Thread layout (art, upper, tid) and the triangle decoding match the float64 kernel.
    @wp.kernel(enable_backward=False)
    def _eval_jacobian_transpose_times_p_tiled_f32_math(
        articulation_j_rows: wp.array(dtype=wp.int32),
        articulation_j_cols: wp.array(dtype=wp.int32),
        articulation_j_start: wp.array(dtype=wp.int32),
        articulation_h_start: wp.array(dtype=wp.int32),
        j: wp.array(dtype=wp.float64),
        p: wp.array(dtype=wp.float64),
        h: wp.array(dtype=wp.float64),
    ):
        art, upper, tid = wp.tid()
        j_rows = articulation_j_rows[art]
        j_cols = articulation_j_cols[art]
        upper_count = (j_cols * (j_cols + 1)) // 2
        if upper >= upper_count:
            return
        # Decode `upper` into (row, col) with col >= row.
        row = wp.int32(0)
        rem = wp.int32(upper)
        row_count = wp.int32(j_cols)
        while rem >= row_count:
            rem = rem - row_count
            row = row + 1
            row_count = row_count - 1
        col = row + rem
        j_start = articulation_j_start[art]
        acc = wp.float32(0.0)
        k = tid
        stride = wp.block_dim()
        while k < j_rows:
            base = j_start + k * j_cols
            acc = acc + wp.float32(j[base + row]) * wp.float32(p[base + col])
            k = k + stride
        # Partial sums are widened to float64 before the block reduction.
        values = wp.tile_zeros((tile_size,), dtype=wp.float64, storage="shared")
        values[tid] = wp.float64(acc)
        total = wp.tile_sum(values)
        h_start = articulation_h_start[art]
        wp.tile_store(h, total, offset=h_start + row * j_cols + col)
        # Mirror into the lower triangle (assumes H is symmetric).
        if col != row:
            wp.tile_store(h, total, offset=h_start + col * j_cols + row)

    return _eval_jacobian_transpose_times_p_tiled_f32_math
@cache
def _make_eval_jacobian_transpose_times_p_gemm_tile(tile_m: int, tile_n: int, tile_k: int):
    """Return a tiled-GEMM kernel for H = J^T @ P (float64).

    Each block computes one (tile_m x tile_n) output tile of H:
      1. It stages (tile_k x tile_m) slices of J and (tile_k x tile_n) slices
         of P in shared memory.
      2. It accumulates with `wp.tile_matmul`.
      3. It writes back only entries with col <= row, mirrored to (col, row).
    One kernel is built per (tile_m, tile_n, tile_k) (memoized by
    `functools.cache`). `module="unique"` gives each built kernel its own
    Warp module.
    """

    # Thread layout (art, row_tile, col_tile, tid):
    #   - `tid` strides by `wp.block_dim()` over tile elements.
    #   - Out-of-range J/P elements are loaded as zero, so partial edge tiles
    #     are safe.
    @wp.kernel(enable_backward=False, module="unique")
    def _eval_jacobian_transpose_times_p_gemm_tile(
        articulation_j_rows: wp.array(dtype=wp.int32),
        articulation_j_cols: wp.array(dtype=wp.int32),
        articulation_j_start: wp.array(dtype=wp.int32),
        articulation_h_start: wp.array(dtype=wp.int32),
        j: wp.array(dtype=wp.float64),
        p: wp.array(dtype=wp.float64),
        h: wp.array(dtype=wp.float64),
    ):
        art, row_tile, col_tile, tid = wp.tid()
        j_rows = articulation_j_rows[art]
        j_cols = articulation_j_cols[art]
        row0 = row_tile * tile_m
        col0 = col_tile * tile_n
        # Skip tiles outside H, and tiles lying entirely above the diagonal
        # (first column past the tile's last row).
        if row0 >= j_cols or col0 >= j_cols or col0 > row0 + tile_m - 1:
            return
        j_start = articulation_j_start[art]
        acc = wp.tile_zeros((tile_m, tile_n), dtype=wp.float64)
        k0 = wp.int32(0)
        while k0 < j_rows:
            j_tile = wp.tile_zeros((tile_k, tile_m), dtype=wp.float64, storage="shared")
            p_tile = wp.tile_zeros((tile_k, tile_n), dtype=wp.float64, storage="shared")
            # Stage J[k0:k0+tile_k, row0:row0+tile_m] with zero padding.
            linear = tid
            while linear < tile_k * tile_m:
                kk = linear // tile_m
                rr = linear - kk * tile_m
                k = k0 + kk
                row = row0 + rr
                value = wp.float64(0.0)
                if k < j_rows and row < j_cols:
                    value = j[j_start + k * j_cols + row]
                j_tile[kk, rr] = value
                linear = linear + wp.block_dim()
            # Stage P[k0:k0+tile_k, col0:col0+tile_n]; P shares J's layout and start offset.
            linear = tid
            while linear < tile_k * tile_n:
                kk = linear // tile_n
                cc = linear - kk * tile_n
                k = k0 + kk
                col = col0 + cc
                value = wp.float64(0.0)
                if k < j_rows and col < j_cols:
                    value = p[j_start + k * j_cols + col]
                p_tile[kk, cc] = value
                linear = linear + wp.block_dim()
            # acc += j_tile^T @ p_tile
            wp.tile_matmul(wp.tile_transpose(j_tile), p_tile, acc)
            k0 = k0 + tile_k
        h_start = articulation_h_start[art]
        # Write the lower-triangle part of the tile and mirror it to the upper triangle.
        linear = tid
        while linear < tile_m * tile_n:
            rr = linear // tile_n
            cc = linear - rr * tile_n
            row = row0 + rr
            col = col0 + cc
            if row < j_cols and col < j_cols and col <= row:
                value = acc[rr, cc]
                h[h_start + row * j_cols + col] = value
                if col != row:
                    h[h_start + col * j_cols + row] = value
            linear = linear + wp.block_dim()

    return _eval_jacobian_transpose_times_p_gemm_tile
@cache
def _make_eval_jacobian_transpose_times_p_gemm_tile_f32_math(tile_m: int, tile_n: int, tile_k: int):
    """Build the float32-math variant of the tiled ``H = J^T P`` kernel.

    The layout, launch shape and lower-triangle/mirroring behavior are the same
    as the float64 tile kernel. Differences from that kernel:

    * J and P entries are converted to float32 as they are loaded into shared
      tiles.
    * The product is accumulated in a float32 tile.
    * Each result is widened back to float64 when stored into ``h``.

    Results therefore carry float32 rounding. ``@cache`` returns the same
    kernel object for repeated ``(tile_m, tile_n, tile_k)`` arguments.

    Args:
        tile_m: Output tile height (rows of H handled per block).
        tile_n: Output tile width (columns of H handled per block).
        tile_k: Chunk length of the reduction over the J/P row dimension.

    Returns:
        The Warp kernel (``module="unique"``, backward pass disabled).
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _eval_jacobian_transpose_times_p_gemm_tile_f32_math(
        articulation_j_rows: wp.array(dtype=wp.int32),
        articulation_j_cols: wp.array(dtype=wp.int32),
        articulation_j_start: wp.array(dtype=wp.int32),
        articulation_h_start: wp.array(dtype=wp.int32),
        j: wp.array(dtype=wp.float64),
        p: wp.array(dtype=wp.float64),
        h: wp.array(dtype=wp.float64),
    ):
        # One block per (articulation, output row tile, output column tile).
        art, row_tile, col_tile, tid = wp.tid()
        j_rows = articulation_j_rows[art]
        j_cols = articulation_j_cols[art]
        row0 = row_tile * tile_m
        col0 = col_tile * tile_n
        # Block-uniform skip of out-of-range and strictly-upper-triangular tiles.
        if row0 >= j_cols or col0 >= j_cols or col0 > row0 + tile_m - 1:
            return
        j_start = articulation_j_start[art]
        # float32 accumulator: this is where the precision trade-off happens.
        acc = wp.tile_zeros((tile_m, tile_n), dtype=wp.float32)
        k0 = wp.int32(0)
        while k0 < j_rows:
            j_tile = wp.tile_zeros((tile_k, tile_m), dtype=wp.float32, storage="shared")
            p_tile = wp.tile_zeros((tile_k, tile_n), dtype=wp.float32, storage="shared")
            # Zero-padded, block-strided load of a J chunk, narrowed to float32.
            linear = tid
            while linear < tile_k * tile_m:
                kk = linear // tile_m
                rr = linear - kk * tile_m
                k = k0 + kk
                row = row0 + rr
                value = wp.float32(0.0)
                if k < j_rows and row < j_cols:
                    value = wp.float32(j[j_start + k * j_cols + row])
                j_tile[kk, rr] = value
                linear = linear + wp.block_dim()
            # Zero-padded, block-strided load of a P chunk, narrowed to float32.
            linear = tid
            while linear < tile_k * tile_n:
                kk = linear // tile_n
                cc = linear - kk * tile_n
                k = k0 + kk
                col = col0 + cc
                value = wp.float32(0.0)
                if k < j_rows and col < j_cols:
                    value = wp.float32(p[j_start + k * j_cols + col])
                p_tile[kk, cc] = value
                linear = linear + wp.block_dim()
            # acc += J_chunk^T @ P_chunk in float32.
            wp.tile_matmul(wp.tile_transpose(j_tile), p_tile, acc)
            k0 = k0 + tile_k
        h_start = articulation_h_start[art]
        # Widen to float64, store the lower triangle and mirror it to the upper half.
        linear = tid
        while linear < tile_m * tile_n:
            rr = linear // tile_n
            cc = linear - rr * tile_n
            row = row0 + rr
            col = col0 + cc
            if row < j_cols and col < j_cols and col <= row:
                out_value = wp.float64(acc[rr, cc])
                h[h_start + row * j_cols + col] = out_value
                if col != row:
                    h[h_start + col * j_cols + row] = out_value
            linear = linear + wp.block_dim()

    return _eval_jacobian_transpose_times_p_gemm_tile_f32_math
@wp.kernel
def _integrate_joint_velocity_kernel(
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    joint_qd_out: wp.array(dtype=float),
):
    # Explicit Euler velocity step, one DOF per thread: qd_out = qd + dt * qdd.
    dof = wp.tid()
    step = wp.float32(dt)
    joint_qd_out[dof] = joint_qd[dof] + step * joint_qdd[dof]
@wp.kernel
def _assemble_global_sap_dynamics_matrix_kernel(
    h_flat: wp.array(dtype=wp.float64),
    joint_armature: wp.array(dtype=wp.float64),
    dof_articulation_index: wp.array(dtype=wp.int32),
    dof_articulation_local_index: wp.array(dtype=wp.int32),
    articulation_dof_start: wp.array(dtype=wp.int32),
    articulation_h_start: wp.array(dtype=wp.int32),
    articulation_h_rows: wp.array(dtype=wp.int32),
    dynamics_matrix: wp.array(dtype=wp.float64, ndim=2),
):
    # One thread per global (row DOF, column DOF) entry. Entries whose DOFs share
    # an articulation copy that articulation's H block; all other off-diagonal
    # entries are zero. Armature is added on every diagonal entry.
    # `articulation_dof_start` is accepted for signature compatibility but unused.
    row_dof, col_dof = wp.tid()
    art = dof_articulation_index[row_dof]
    entry = wp.float64(0.0)
    if art >= 0 and dof_articulation_index[col_dof] == art:
        n = articulation_h_rows[art]
        local_row = dof_articulation_local_index[row_dof]
        local_col = dof_articulation_local_index[col_dof]
        if local_row >= 0 and local_col >= 0 and local_row < n and local_col < n:
            entry = h_flat[articulation_h_start[art] + local_row * n + local_col]
    if row_dof == col_dof:
        entry = entry + joint_armature[row_dof]
    dynamics_matrix[row_dof, col_dof] = entry
@wp.kernel
def _eval_dense_cholesky_batched_f64(
    A_starts: wp.array(dtype=int),
    A_dim: wp.array(dtype=int),
    R_starts: wp.array(dtype=int),
    A: wp.array(dtype=wp.float64),
    R: wp.array(dtype=wp.float64),
    L: wp.array(dtype=wp.float64),
):
    # One thread per batch entry: hand its dense block (size, offsets) to the
    # shared `_dense_cholesky_f64` helper, which writes the factor into `L`.
    entry = wp.tid()
    n = A_dim[entry]
    a_offset = A_starts[entry]
    r_offset = R_starts[entry]
    _dense_cholesky_f64(n, A, R, a_offset, r_offset, L)
@wp.kernel
def _eval_dense_solve_batched_f64(
    L_start: wp.array(dtype=int),
    L_dim: wp.array(dtype=int),
    b_start: wp.array(dtype=int),
    A: wp.array(dtype=wp.float64),
    L: wp.array(dtype=wp.float64),
    b: wp.array(dtype=wp.float64),
    x: wp.array(dtype=wp.float64),
    tmp: wp.array(dtype=wp.float64),
):
    # One thread per batch entry: solve with the precomputed factor in `L` via
    # `_dense_subs_f64` (presumably forward/back substitution). `A` and `tmp`
    # are accepted for launch-signature compatibility but are not used.
    entry = wp.tid()
    n = L_dim[entry]
    l_offset = L_start[entry]
    rhs_offset = b_start[entry]
    _dense_subs_f64(n, l_offset, rhs_offset, L, b, x)
@wp.kernel
def _pack_articulation_h_to_padded_batched_f64(
    H_start: wp.array(dtype=int),
    H_rows: wp.array(dtype=int),
    dof_start: wp.array(dtype=int),
    H: wp.array(dtype=wp.float64),
    armature: wp.array(dtype=wp.float64),
    max_rows: int,
    out: wp.array(dtype=wp.float64, ndim=3),
):
    # Copy each articulation's n x n block of H (+ armature on the diagonal) into
    # the leading corner of a max_rows x max_rows slot. The rest of the slot is
    # filled with an identity so the padded unknowns stay decoupled.
    art, row, col = wp.tid()
    n = H_rows[art]
    entry = wp.float64(0.0)
    if row < max_rows and col < max_rows:
        if row >= n or col >= n:
            if row == col:
                entry = wp.float64(1.0)
        else:
            entry = H[H_start[art] + row * n + col]
            if row == col:
                entry = entry + armature[dof_start[art] + row]
    out[art, row, col] = entry
@wp.kernel
def _pack_articulation_tau_to_padded_batched_f64(
    H_rows: wp.array(dtype=int),
    dof_start: wp.array(dtype=int),
    tau: wp.array(dtype=wp.float64),
    max_rows: int,
    out: wp.array(dtype=wp.float64, ndim=3),
):
    # Scatter each articulation's DOF forces into column 0 of its padded RHS;
    # padded rows are zeroed.
    art, row = wp.tid()
    entry = wp.float64(0.0)
    if row < max_rows and row < H_rows[art]:
        entry = tau[dof_start[art] + row]
    out[art, row, 0] = entry
@wp.kernel
def _unpack_articulation_solution_from_padded_batched_f64(
    H_rows: wp.array(dtype=int),
    dof_start: wp.array(dtype=int),
    max_rows: int,
    x: wp.array(dtype=wp.float64, ndim=3),
    out: wp.array(dtype=wp.float64),
):
    # Gather column 0 of each padded solution back into the global DOF vector,
    # ignoring padded rows.
    art, row = wp.tid()
    if row >= max_rows or row >= H_rows[art]:
        return
    out[dof_start[art] + row] = x[art, row, 0]
@wp.kernel
def _pack_articulation_h_to_padded_batched_f32(
    H_start: wp.array(dtype=int),
    H_rows: wp.array(dtype=int),
    dof_start: wp.array(dtype=int),
    H: wp.array(dtype=wp.float64),
    armature: wp.array(dtype=wp.float64),
    max_rows: int,
    out: wp.array(dtype=wp.float32, ndim=3),
):
    # float32 variant of the padded H pack. H and armature are each narrowed to
    # float32 before being summed; the padding region gets an identity.
    art, row, col = wp.tid()
    n = H_rows[art]
    entry = wp.float32(0.0)
    if row < max_rows and col < max_rows:
        if row >= n or col >= n:
            if row == col:
                entry = wp.float32(1.0)
        else:
            entry = wp.float32(H[H_start[art] + row * n + col])
            if row == col:
                entry = entry + wp.float32(armature[dof_start[art] + row])
    out[art, row, col] = entry
@wp.kernel
def _pack_articulation_tau_to_padded_batched_f32(
    H_rows: wp.array(dtype=int),
    dof_start: wp.array(dtype=int),
    tau: wp.array(dtype=wp.float64),
    max_rows: int,
    out: wp.array(dtype=wp.float32, ndim=3),
):
    # float32 variant: narrow each articulation's DOF forces into its padded RHS.
    art, row = wp.tid()
    entry = wp.float32(0.0)
    if row < max_rows and row < H_rows[art]:
        entry = wp.float32(tau[dof_start[art] + row])
    out[art, row, 0] = entry
@wp.kernel
def _unpack_articulation_solution_from_padded_batched_f32(
    H_rows: wp.array(dtype=int),
    dof_start: wp.array(dtype=int),
    max_rows: int,
    x: wp.array(dtype=wp.float32, ndim=3),
    out: wp.array(dtype=wp.float64),
):
    # Widen column 0 of each padded float32 solution back into the global
    # float64 DOF vector, ignoring padded rows.
    art, row = wp.tid()
    if row >= max_rows or row >= H_rows[art]:
        return
    out[dof_start[art] + row] = wp.float64(x[art, row, 0])
class SapFreeMotion:
    """Standalone SAP-style free-motion calculation for articulated models.

    Inputs and outputs use SAP's floating mobilizer convention for
    FREE/DISTANCE joints: `[w, v_body_origin]`. All other joint DOFs keep the
    model's declared order. Frontend-specific convention conversion belongs in
    the caller before constructing `SapState` / `SapControl`.

    Each :meth:`compute` call runs these stages:

    1. Tiled rigid-body pass over each articulation tree (``_launch_rigid_id``).
    2. Reduction of joint and body forces to per-DOF ``joint_tau``.
    3. Per-articulation assembly of ``J``, ``M``, ``P = M J`` and ``H = J^T P``.
    4. Solve of ``(H + armature) vdot = joint_tau`` for each articulation.
    5. Write-out of ``v_star`` and ``vdot``.
    """

    def __init__(
        self,
        model: Model,
        *,
        allocate_dynamics_matrix: bool = False,
        use_f64_boundary_pose: bool = True,
        linear_solve_precision: str = "fp64",
    ):
        """Precompute articulation layouts, build kernels and allocate work buffers.

        Args:
            model: SAP-native model with at least one joint and one joint DOF.
            allocate_dynamics_matrix: Allocate the dense global
                ``(joint_dof_count, joint_dof_count)`` matrix now instead of on the
                first ``compute(..., assemble_dynamics_matrix=True)`` call.
            use_f64_boundary_pose: Stored as ``self.use_f64_boundary_pose``. It is
                not consulted by this class's own methods.
            linear_solve_precision: ``"fp32"``/``"f32"`` or ``"fp64"``/``"f64"``
                (case and surrounding whitespace are ignored). ``"fp32"`` selects
                the float32 variants of the rigid pass, Jacobian, mass-times-
                Jacobian, ``J^T P`` and blocked Cholesky stages.

        Raises:
            TypeError: if ``model`` is not a ``Model``.
            ValueError: if the model has no joints/DOFs, the precision string is
                unknown, an articulation's joint parents do not form a tree, or
                a FREE/DISTANCE joint does not have 6 DOFs.
        """
        if not isinstance(model, Model):
            raise TypeError("SapFreeMotion requires SapModel; convert in the frontend adapter before construction.")
        if int(model.joint_count) <= 0 or int(model.joint_dof_count) <= 0:
            raise ValueError("SapFreeMotion requires a model with articulated joint DOFs.")
        self.model = model
        self.use_f64_boundary_pose = bool(use_f64_boundary_pose)
        self.linear_solve_precision = self._normalize_linear_solve_precision(linear_solve_precision)
        self._compute_articulation_indices(model)
        # Block size for the tiled rigid/tau kernels: at least 32 threads and at
        # least as many as the widest tree level (presumably one thread per joint
        # in a level — confirm against the kernel factories).
        self.rigid_tile_size = max(32, int(self.max_articulation_level_width))
        tile_size = int(self.rigid_tile_size)
        self._rigid_id_tiled = _make_eval_rigid_id_tiled_articulations(tile_size)
        self._rigid_id_tiled_f32_revolute = _make_eval_rigid_id_tiled_articulations_f32_revolute(tile_size)
        self._rigid_tau_tiled = _make_eval_rigid_tau_no_drives_tiled(tile_size)
        self._jtp_gemm_tile = _make_eval_jacobian_transpose_times_p_gemm_tile(
            _JTP_GEMM_TILE_M,
            _JTP_GEMM_TILE_N,
            _JTP_GEMM_TILE_K,
        )
        self._jtp_gemm_tile_f32_math = _make_eval_jacobian_transpose_times_p_gemm_tile_f32_math(
            _JTP_GEMM_TILE_M,
            _JTP_GEMM_TILE_N,
            _JTP_GEMM_TILE_K,
        )
        self._build_dof_maps(model)
        self._allocate_buffers(model, allocate_dynamics_matrix=allocate_dynamics_matrix)

    @property
    def device(self):
        """Return the Warp device that owns the free-motion work buffers."""
        return self.model.device

    @staticmethod
    def _normalize_linear_solve_precision(value: str) -> str:
        """Map a user precision string to ``"fp32"`` or ``"fp64"``.

        Raises:
            ValueError: for any value other than fp32/f32/fp64/f64
                (case-insensitive, surrounding whitespace ignored).
        """
        precision = str(value).strip().lower()
        if precision == "f32":
            precision = "fp32"
        elif precision == "f64":
            precision = "fp64"
        if precision not in {"fp32", "fp64"}:
            raise ValueError(
                "linear_solve_precision must be 'fp32'/'f32' or 'fp64'/'f64', "
                f"got {value!r}."
            )
        return precision

    @staticmethod
    def _articulation_joint_levels(joint_parent, first_joint: int, last_joint: int) -> list[list[int]]:
        """Group the joints ``[first_joint, last_joint)`` by depth in their tree.

        A joint whose parent index lies outside the range is a root (depth 0).
        Every other joint is one level deeper than its parent. Level ``d`` of
        the result lists the global indices of the depth-``d`` joints in
        ascending order.

        Args:
            joint_parent: Indexable of parent joint indices for every joint in
                the model (e.g. ``model.joint_parent.numpy()``).
            first_joint: First joint index of the articulation.
            last_joint: One past the last joint index of the articulation.

        Returns:
            One list of joint indices per tree level; ``[]`` for an empty range.

        Raises:
            ValueError: if some joint cannot be reached from a root. This means
                the parent links inside the range form a cycle, so the joint has
                no well-defined level.
        """
        from collections import deque

        joint_count = last_joint - first_joint
        children = [[] for _ in range(joint_count)]
        depth = [-1] * joint_count
        queue = deque()
        for joint in range(first_joint, last_joint):
            local_joint = joint - first_joint
            parent_joint = int(joint_parent[joint])
            if first_joint <= parent_joint < last_joint:
                children[parent_joint - first_joint].append(local_joint)
            else:
                depth[local_joint] = 0
                queue.append(local_joint)
        # Breadth-first walk from the roots assigns every reachable joint its depth.
        while queue:
            local_joint = queue.popleft()
            for child_local in children[local_joint]:
                depth[child_local] = depth[local_joint] + 1
                queue.append(child_local)
        unreachable = [first_joint + local for local, d in enumerate(depth) if d < 0]
        if unreachable:
            # Previously depth -1 silently indexed the last level (or raised
            # IndexError when no root existed); reject it explicitly instead.
            raise ValueError(
                f"Articulation joints {unreachable} are not reachable from a root joint; "
                "joint_parent must describe a tree within each articulation."
            )
        levels = [[] for _ in range(max(depth) + 1 if depth else 0)]
        for local_joint, d in enumerate(depth):
            levels[d].append(first_joint + local_joint)
        return levels

    def _compute_articulation_indices(self, model: Model) -> None:
        """Compute per-articulation offsets and sizes for the flat J, M and H buffers.

        An articulation with ``n`` joints and ``d`` DOFs owns three blocks, packed
        back to back in their flat buffers:

        * a ``6n x d`` Jacobian block;
        * a ``6n x 6n`` mass block;
        * a ``d x d`` joint-space block.

        This method also builds ``articulation_level_joint_index``, a padded
        ``(articulation, level, slot)`` table of joint indices grouped by tree
        depth, with ``-1`` in unused slots. The tiled rigid and tau kernels
        consume this table. All ``max_*`` sizes are clamped to at least 1.

        Raises:
            ValueError: if an articulation's joint parents do not form a tree.
        """
        self.max_articulation_level_count = 0
        self.max_articulation_level_width = 0
        self.max_articulation_joint_count = 0
        self.max_articulation_m_rows = 0
        self.max_articulation_j_cols = 0
        self.J_size = 0
        self.M_size = 0
        self.H_size = 0
        articulation_J_start = []
        articulation_M_start = []
        articulation_H_start = []
        articulation_M_rows = []
        articulation_H_rows = []
        articulation_J_rows = []
        articulation_J_cols = []
        articulation_dof_start = []
        articulation_level_lists = []
        articulation_start = model.articulation_start.numpy()
        joint_parent = model.joint_parent.numpy()
        joint_qd_start = model.joint_qd_start.numpy()
        for art in range(int(model.articulation_count)):
            first_joint = int(articulation_start[art])
            last_joint = int(articulation_start[art + 1])
            joint_count = last_joint - first_joint
            first_dof = int(joint_qd_start[first_joint])
            last_dof = int(joint_qd_start[last_joint])
            dof_count = last_dof - first_dof
            articulation_J_start.append(self.J_size)
            articulation_M_start.append(self.M_size)
            articulation_H_start.append(self.H_size)
            articulation_M_rows.append(joint_count * 6)
            articulation_H_rows.append(dof_count)
            articulation_J_rows.append(joint_count * 6)
            articulation_J_cols.append(dof_count)
            articulation_dof_start.append(first_dof)
            self.max_articulation_joint_count = max(self.max_articulation_joint_count, joint_count)
            self.max_articulation_m_rows = max(self.max_articulation_m_rows, joint_count * 6)
            self.max_articulation_j_cols = max(self.max_articulation_j_cols, dof_count)
            level_lists = self._articulation_joint_levels(joint_parent, first_joint, last_joint)
            articulation_level_lists.append(level_lists)
            self.max_articulation_level_count = max(self.max_articulation_level_count, len(level_lists))
            self.max_articulation_level_width = max(
                self.max_articulation_level_width,
                max((len(joints_at_level) for joints_at_level in level_lists), default=0),
            )
            self.J_size += 6 * joint_count * dof_count
            self.M_size += 6 * joint_count * 6 * joint_count
            self.H_size += dof_count * dof_count
        # Clamp to >= 1 so launch dimensions and the level table are never empty.
        self.max_articulation_level_count = max(self.max_articulation_level_count, 1)
        self.max_articulation_level_width = max(self.max_articulation_level_width, 1)
        self.max_articulation_joint_count = max(self.max_articulation_joint_count, 1)
        self.max_articulation_m_rows = max(self.max_articulation_m_rows, 1)
        self.max_articulation_j_cols = max(self.max_articulation_j_cols, 1)
        self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device)
        self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device)
        self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device)
        self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device)
        self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device)
        self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device)
        self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device)
        self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device)
        level_index = np.full(
            int(model.articulation_count) * self.max_articulation_level_count * self.max_articulation_level_width,
            -1,
            dtype=np.int32,
        )
        for art, level_lists in enumerate(articulation_level_lists):
            for level, joints_at_level in enumerate(level_lists):
                base = (art * self.max_articulation_level_count + level) * self.max_articulation_level_width
                level_index[base : base + len(joints_at_level)] = joints_at_level
        self.articulation_level_joint_index = wp.array(level_index, dtype=wp.int32, device=model.device)

    def _build_dof_maps(self, model: Model) -> None:
        """Build per-DOF lookup tables for the global DOF vector.

        Each table has one entry per DOF and uses ``-1`` where nothing is
        assigned. The tables record:

        * the owning joint;
        * the axis index inside that joint;
        * the owning articulation;
        * the DOF's index local to that articulation.

        Raises:
            ValueError: if a FREE or DISTANCE joint does not have 6 velocity DOFs.
        """
        dof_count = int(model.joint_dof_count)
        joint_for_dof = np.full(dof_count, -1, dtype=np.int32)
        axis_for_dof = np.full(dof_count, -1, dtype=np.int32)
        dof_articulation = np.full(dof_count, -1, dtype=np.int32)
        dof_articulation_local = np.full(dof_count, -1, dtype=np.int32)
        joint_type = model.joint_type.numpy()
        joint_qd_start = model.joint_qd_start.numpy()
        joint_dof_dim = model.joint_dof_dim.numpy()
        articulation_start = model.articulation_start.numpy()
        for art in range(int(model.articulation_count)):
            first_joint = int(articulation_start[art])
            last_joint = int(articulation_start[art + 1])
            first_dof = int(joint_qd_start[first_joint])
            last_dof = int(joint_qd_start[last_joint])
            for dof in range(first_dof, last_dof):
                dof_articulation[dof] = art
                dof_articulation_local[dof] = dof - first_dof
        floating_types = (int(SAP_JOINT_FREE), int(SAP_JOINT_DISTANCE))
        for joint in range(int(model.joint_count)):
            start = int(joint_qd_start[joint])
            axis_count = int(joint_dof_dim[joint, 0] + joint_dof_dim[joint, 1])
            for axis in range(axis_count):
                dof = start + axis
                joint_for_dof[dof] = joint
                axis_for_dof[dof] = axis
            if int(joint_type[joint]) in floating_types:
                if axis_count != 6:
                    raise ValueError("FREE/DISTANCE joints must have 6 velocity DOFs.")
        self.dof_joint_index = wp.array(joint_for_dof, dtype=wp.int32, device=model.device)
        self.dof_axis_index = wp.array(axis_for_dof, dtype=wp.int32, device=model.device)
        self.dof_articulation_index = wp.array(dof_articulation, dtype=wp.int32, device=model.device)
        self.dof_articulation_local_index = wp.array(dof_articulation_local, dtype=wp.int32, device=model.device)

    def _allocate_buffers(self, model: Model, *, allocate_dynamics_matrix: bool) -> None:
        """Allocate float64 model copies, the J/M/P/H/L work arrays, the blocked
        Cholesky solver with its padded buffers, and the per-DOF/per-body buffers."""
        self.model_body_mass = self._make_model_array_f64(model, "sap_debug_body_mass", "body_mass", wp.float64)
        self.model_body_inertia = self._make_model_array_f64(model, "sap_debug_body_inertia", "body_inertia", wp.mat33d)
        self.model_body_com = self._make_model_array_f64(model, "sap_debug_body_com", "body_com", wp.vec3d)
        self.model_joint_axis = self._make_model_array_f64(model, "sap_debug_joint_axis", "joint_axis", wp.vec3d)
        self.model_joint_X_p = self._make_model_array_f64(model, "sap_debug_joint_X_p", "joint_X_p", wp.transformd)
        self.model_joint_X_c = self._make_model_array_f64(model, "sap_debug_joint_X_c", "joint_X_c", wp.transformd)
        self.model_joint_X_c_identity = self._make_transform_identity_flags(
            model,
            "sap_debug_joint_X_c",
            "joint_X_c",
        )
        self.model_gravity = self._make_model_array_f64(model, "sap_debug_gravity", "gravity", wp.vec3d)
        self.model_joint_armature = self._make_model_array_f64(
            model,
            "sap_debug_joint_armature",
            "joint_armature",
            wp.float64,
        )
        self.M = wp.zeros((self.M_size,), dtype=wp.float64, device=model.device)
        self.J = wp.zeros((self.J_size,), dtype=wp.float64, device=model.device)
        # P (mass times Jacobian) shares J's layout and offsets.
        self.P = wp.zeros_like(self.J)
        self.H = wp.zeros((self.H_size,), dtype=wp.float64, device=model.device)
        # Dense Cholesky factor storage for the non-blocked fallback solve.
        self.L = wp.zeros_like(self.H)
        block_dtype = wp.float32 if self.linear_solve_precision == "fp32" else wp.float64
        self._block_chol_uses_f32 = block_dtype == wp.float32
        self.block_solver = BlockCholeskySolverBatched(
            max_num_equations=int(self.max_articulation_j_cols),
            batch_size=int(model.articulation_count),
            block_size=32,
            device=model.device,
            dtype=block_dtype,
        )
        # Size the padded buffers from the solver's own equation count, which
        # may differ from max_articulation_j_cols (e.g. rounded to its block size
        # — TODO confirm); entries beyond the packed region stay zero.
        padded = int(self.block_solver.max_num_equations)
        self.block_chol_a = wp.zeros(
            (int(model.articulation_count), padded, padded),
            dtype=block_dtype,
            device=model.device,
        )
        self.block_chol_rhs = wp.zeros(
            (int(model.articulation_count), padded, 1),
            dtype=block_dtype,
            device=model.device,
        )
        self.block_chol_x = wp.zeros_like(self.block_chol_rhs)
        self.joint_qdd_sap_solve = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.joint_tau = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.joint_solve_tmp = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.joint_S_s = wp.empty((model.joint_dof_count,), dtype=wp.spatial_vectord, device=model.device)
        self.joint_q_input = wp.zeros((model.joint_coord_count,), dtype=wp.float64, device=model.device)
        self.joint_qd_sap_input = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.joint_f_sap_input = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.free_motion_joint_qd_sap = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.free_motion_joint_qdd_sap = wp.zeros((model.joint_dof_count,), dtype=wp.float64, device=model.device)
        self.body_q = wp.empty((model.body_count,), dtype=wp.transformd, device=model.device)
        self.body_q_f = wp.empty((model.body_count,), dtype=wp.transform, device=model.device)
        self.body_I_s = wp.empty((model.body_count,), dtype=wp.spatial_matrixd, device=model.device)
        self.body_v_s = wp.empty((model.body_count,), dtype=wp.spatial_vectord, device=model.device)
        self.body_v_s_f = wp.empty((model.body_count,), dtype=wp.spatial_vector, device=model.device)
        self.body_a_s = wp.empty((model.body_count,), dtype=wp.spatial_vectord, device=model.device)
        self.body_a_s_f = wp.empty((model.body_count,), dtype=wp.spatial_vector, device=model.device)
        self.body_f_s = wp.zeros((model.body_count,), dtype=wp.spatial_vectord, device=model.device)
        self.body_ft_s = wp.zeros((model.body_count,), dtype=wp.spatial_vectord, device=model.device)
        self.body_f_ext_s = wp.zeros((model.body_count,), dtype=wp.spatial_vectord, device=model.device)
        self.dynamics_matrix_sap = None
        if allocate_dynamics_matrix:
            self._ensure_dynamics_matrix_allocated()
        self._result = SapFreeMotionResult(
            v_star=self.free_motion_joint_qd_sap,
            vdot=self.free_motion_joint_qdd_sap,
            dynamics_matrix=self.dynamics_matrix_sap,
        )

    @staticmethod
    def _model_source_numpy(model: Model, exact_name: str, fallback_name: str) -> np.ndarray:
        """Read ``model.<exact_name>`` if present and not None, else
        ``model.<fallback_name>``, and return it as a NumPy array."""
        src = getattr(model, exact_name, None)
        if src is None:
            src = getattr(model, fallback_name)
        if isinstance(src, wp.array):
            return src.numpy()
        return np.asarray(src)

    @staticmethod
    def _make_model_array_f64(model: Model, exact_name: str, fallback_name: str, dtype) -> wp.array:
        """Copy a model attribute to a new float64-based Warp array of ``dtype``
        on the model's device, preferring ``exact_name`` over ``fallback_name``."""
        src_np = SapFreeMotion._model_source_numpy(model, exact_name, fallback_name)
        return wp.array(np.asarray(src_np, dtype=np.float64), dtype=dtype, device=model.device)

    @staticmethod
    def _make_transform_identity_flags(model: Model, exact_name: str, fallback_name: str) -> wp.array:
        """Return an int32 array with 1 where a transform is (numerically) identity.

        Each transform is read as 7 values: translation ``[0:3]``, quaternion
        vector part ``[3:6]`` and scalar part ``[6]``. Identity means all three
        of the following hold within ``1e-12``:

        * the translation norm is zero;
        * the quaternion vector-part norm is zero;
        * the quaternion scalar part equals +1.

        A ``w == -1`` quaternion is not flagged.
        """
        src_np = SapFreeMotion._model_source_numpy(model, exact_name, fallback_name)
        transforms = np.asarray(src_np, dtype=np.float64).reshape((-1, 7))
        identity = (
            (np.linalg.norm(transforms[:, 0:3], axis=1) <= 1.0e-12)
            & (np.linalg.norm(transforms[:, 3:6], axis=1) <= 1.0e-12)
            & (np.abs(transforms[:, 6] - 1.0) <= 1.0e-12)
        )
        return wp.array(identity.astype(np.int32), dtype=wp.int32, device=model.device)

    def _ensure_dynamics_matrix_allocated(self) -> None:
        """Lazily allocate the dense global dynamics matrix and refresh `_result`."""
        if self.dynamics_matrix_sap is not None:
            return
        model = self.model
        self.dynamics_matrix_sap = wp.zeros(
            (model.joint_dof_count, model.joint_dof_count),
            dtype=wp.float64,
            device=model.device,
        )
        self._result = SapFreeMotionResult(
            v_star=self.free_motion_joint_qd_sap,
            vdot=self.free_motion_joint_qdd_sap,
            dynamics_matrix=self.dynamics_matrix_sap,
        )

    def _prepare_joint_q_input(self, state_in: State) -> None:
        """Copy ``state_in.joint_q`` into the float64 buffer, widening float32 input."""
        if state_in.joint_q.dtype == wp.float64:
            kernel = _copy_f64
        else:
            kernel = _copy_f32_to_f64
        wp.launch(
            kernel,
            dim=self.model.joint_coord_count,
            inputs=[state_in.joint_q, self.joint_q_input],
            device=self.model.device,
        )

    def _launch_rigid_id(self, state_in: State) -> None:
        """Run the tiled per-articulation rigid-body pass, then the per-body dynamics.

        The pass fills body poses, spatial inertias, velocities, forces and
        accelerations plus the joint motion subspaces ``joint_S_s``. In fp32
        mode, three extra float32 body buffers are passed and the
        ``f32_revolute`` kernel variant is used.
        """
        model = self.model
        self._prepare_joint_q_input(state_in)
        use_f32_rigid_id = self.linear_solve_precision == "fp32"
        rigid_id_inputs = [
            self.articulation_level_joint_index,
            int(self.max_articulation_level_count),
            int(self.max_articulation_level_width),
            model.joint_type,
            model.joint_parent,
            model.joint_child,
            model.joint_q_start,
            model.joint_qd_start,
            self.joint_q_input,
            self.joint_qd_sap_input,
            self.model_joint_axis,
            model.joint_dof_dim,
            self.model_body_inertia,
            self.model_body_mass,
            self.model_body_com,
            self.body_q,
            self.model_joint_X_p,
            self.model_joint_X_c,
            self.model_joint_X_c_identity,
            model.body_world,
            self.model_gravity,
            self.joint_S_s,
            self.body_I_s,
            self.body_v_s,
            self.body_f_s,
            self.body_a_s,
        ]
        if use_f32_rigid_id:
            rigid_id_inputs.extend(
                [
                    self.body_q_f,
                    self.body_v_s_f,
                    self.body_a_s_f,
                ]
            )
        wp.launch_tiled(
            self._rigid_id_tiled_f32_revolute if use_f32_rigid_id else self._rigid_id_tiled,
            dim=model.articulation_count,
            block_dim=int(self.rigid_tile_size),
            inputs=rigid_id_inputs,
            device=model.device,
        )
        self._launch_rigid_body_dynamics()

    def _launch_rigid_body_dynamics(self) -> None:
        """Launch the per-joint rigid-body dynamics kernel (one thread per joint)."""
        model = self.model
        wp.launch(
            _eval_rigid_body_dynamics_parallel,
            dim=model.joint_count,
            inputs=[
                model.joint_child,
                self.model_body_inertia,
                self.model_body_mass,
                self.model_body_com,
                self.body_q,
                model.body_world,
                self.model_gravity,
                self.body_I_s,
                self.body_v_s,
                self.body_f_s,
                self.body_a_s,
            ],
            device=model.device,
        )

    def _assemble_articulation_matrices(self) -> None:
        """Assemble J, M, ``P = M J`` and ``H = J^T P`` for every articulation.

        Uses the float32-math kernel variants when precision is ``"fp32"``. All
        results are stored in the float64 flat buffers.
        """
        model = self.model
        self.J.zero_()
        self.M.zero_()
        self.H.zero_()
        use_f32_math = self.linear_solve_precision == "fp32"
        wp.launch(
            _eval_rigid_jacobian_parallel_f32_math if use_f32_math else _eval_rigid_jacobian_parallel,
            dim=(model.articulation_count, self.max_articulation_joint_count, 6),
            inputs=[
                model.articulation_start,
                self.articulation_J_start,
                model.joint_ancestor,
                model.joint_child,
                model.joint_qd_start,
                self.joint_S_s,
                self.body_q,
                int(self.max_articulation_joint_count),
            ],
            outputs=[self.J],
            device=model.device,
        )
        wp.launch(
            _eval_rigid_mass_parallel,
            dim=(model.articulation_count, self.max_articulation_joint_count, 36),
            inputs=[
                model.articulation_start,
                self.articulation_M_start,
                model.joint_child,
                self.body_I_s,
                int(self.max_articulation_joint_count),
            ],
            outputs=[self.M],
            device=model.device,
        )
        wp.launch(
            _eval_mass_times_jacobian_batched_f32_math if use_f32_math else _eval_mass_times_jacobian_batched,
            dim=(
                model.articulation_count,
                self.max_articulation_m_rows,
                # ceil(j_cols / 4) column groups per thread row (presumably
                # matching _GEMM_COL_BLOCK inside the kernel).
                (self.max_articulation_j_cols + 3) // 4,
            ),
            inputs=[
                self.articulation_M_rows,
                self.articulation_J_cols,
                self.articulation_M_start,
                self.articulation_J_start,
                self.M,
                self.J,
            ],
            outputs=[self.P],
            device=model.device,
        )
        wp.launch_tiled(
            self._jtp_gemm_tile_f32_math if use_f32_math else self._jtp_gemm_tile,
            dim=(
                model.articulation_count,
                (self.max_articulation_j_cols + _JTP_GEMM_TILE_M - 1) // _JTP_GEMM_TILE_M,
                (self.max_articulation_j_cols + _JTP_GEMM_TILE_N - 1) // _JTP_GEMM_TILE_N,
            ),
            block_dim=128,
            inputs=[
                self.articulation_J_rows,
                self.articulation_J_cols,
                self.articulation_J_start,
                self.articulation_H_start,
                self.J,
                self.P,
            ],
            outputs=[self.H],
            device=model.device,
        )

    def _factor_dynamics_matrix(self) -> None:
        """Dense per-articulation Cholesky of H (with armature) into ``self.L``.

        This is the fallback path for when no blocked solver is available.
        """
        model = self.model
        wp.launch(
            _eval_dense_cholesky_batched_f64,
            dim=model.articulation_count,
            inputs=[
                self.articulation_H_start,
                self.articulation_H_rows,
                self.articulation_dof_start,
                self.H,
                self.model_joint_armature,
            ],
            outputs=[self.L],
            device=model.device,
        )

    def _can_use_block_dynamics_solve(self) -> bool:
        """Return True when the batched blocked Cholesky solver is available."""
        return self.block_solver is not None

    def _solve_articulation_accelerations_blocked(self) -> None:
        """Solve ``(H + armature) vdot = joint_tau`` per articulation with the blocked solver.

        The solve has three steps:

        1. Pack H (plus armature) and ``joint_tau`` into padded batched buffers,
           using identity padding for rows beyond an articulation's DOF count.
        2. Factor and solve.
        3. Unpack the solution into ``joint_qdd_sap_solve``.
        """
        model = self.model
        assert self.block_solver is not None
        assert self.block_chol_a is not None
        assert self.block_chol_rhs is not None
        assert self.block_chol_x is not None
        max_rows = int(self.max_articulation_j_cols)
        pack_h_kernel = (
            _pack_articulation_h_to_padded_batched_f32
            if self._block_chol_uses_f32
            else _pack_articulation_h_to_padded_batched_f64
        )
        pack_tau_kernel = (
            _pack_articulation_tau_to_padded_batched_f32
            if self._block_chol_uses_f32
            else _pack_articulation_tau_to_padded_batched_f64
        )
        unpack_kernel = (
            _unpack_articulation_solution_from_padded_batched_f32
            if self._block_chol_uses_f32
            else _unpack_articulation_solution_from_padded_batched_f64
        )
        wp.launch(
            pack_h_kernel,
            dim=(model.articulation_count, max_rows, max_rows),
            inputs=[
                self.articulation_H_start,
                self.articulation_H_rows,
                self.articulation_dof_start,
                self.H,
                self.model_joint_armature,
                max_rows,
                self.block_chol_a,
            ],
            device=model.device,
        )
        wp.launch(
            pack_tau_kernel,
            dim=(model.articulation_count, max_rows),
            inputs=[
                self.articulation_H_rows,
                self.articulation_dof_start,
                self.joint_tau,
                max_rows,
                self.block_chol_rhs,
            ],
            device=model.device,
        )
        self.block_solver.factorize(self.block_chol_a, max_rows)
        self.block_solver.solve(self.block_chol_rhs, self.block_chol_x)
        wp.launch(
            unpack_kernel,
            dim=(model.articulation_count, max_rows),
            inputs=[
                self.articulation_H_rows,
                self.articulation_dof_start,
                max_rows,
                self.block_chol_x,
                self.joint_qdd_sap_solve,
            ],
            device=model.device,
        )

    def _assemble_sap_dynamics_matrix(self) -> None:
        """Scatter the per-articulation H blocks (+ diagonal armature) into the
        dense global ``(joint_dof_count, joint_dof_count)`` matrix."""
        model = self.model
        self._ensure_dynamics_matrix_allocated()
        wp.launch(
            _assemble_global_sap_dynamics_matrix_kernel,
            dim=(model.joint_dof_count, model.joint_dof_count),
            inputs=[
                self.H,
                self.model_joint_armature,
                self.dof_articulation_index,
                self.dof_articulation_local_index,
                self.articulation_dof_start,
                self.articulation_H_start,
                self.articulation_H_rows,
                self.dynamics_matrix_sap,
            ],
            device=model.device,
        )

    def _prepare_sap_boundary(self, state_in: State, control: Control) -> None:
        """Copy SAP-native state/control views into free-motion work buffers.

        ``joint_qd`` and ``joint_f`` are copied (widened to float64 if needed)
        unless they already alias the work buffers. ``body_f`` is handled as
        follows:

        * missing, or the model has no bodies: external forces are zeroed;
        * ``spatial_vectord``: copied;
        * ``spatial_vector``: widened.

        Raises:
            TypeError: if ``state_in.body_f`` has any other dtype.
        """
        model = self.model
        if state_in.joint_qd is not self.joint_qd_sap_input:
            kernel = _copy_f64 if state_in.joint_qd.dtype == wp.float64 else _copy_f32_to_f64
            wp.launch(
                kernel,
                dim=model.joint_dof_count,
                inputs=[state_in.joint_qd, self.joint_qd_sap_input],
                device=model.device,
            )
        if control.joint_f is not self.joint_f_sap_input:
            kernel = _copy_f64 if control.joint_f.dtype == wp.float64 else _copy_f32_to_f64
            wp.launch(
                kernel,
                dim=model.joint_dof_count,
                inputs=[control.joint_f, self.joint_f_sap_input],
                device=model.device,
            )
        body_f = getattr(state_in, "body_f", None)
        if not int(model.body_count) or body_f is None:
            self.body_f_ext_s.zero_()
        elif body_f is self.body_f_ext_s:
            return
        elif body_f.dtype == wp.spatial_vectord:
            wp.copy(dest=self.body_f_ext_s, src=body_f)
        elif body_f.dtype == wp.spatial_vector:
            wp.launch(
                _copy_spatial_vector_to_spatial_vectord,
                dim=model.body_count,
                inputs=[body_f, self.body_f_ext_s],
                device=model.device,
            )
        else:
            raise TypeError(
                "SapState.body_f must be SAP body-origin forces with dtype "
                "wp.spatial_vectord or wp.spatial_vector."
            )

    def _compute_sap_core(
        self,
        state_in: State,
        dt: float,
        *,
        assemble_dynamics_matrix: bool,
    ) -> SapFreeMotionResult:
        """Run the SAP free-motion solve after boundary conversion.

        Inputs at this layer are the SAP-order `joint_qd_sap_input` and
        `joint_f_sap_input`. Spatial buffers follow SAP's body-origin,
        world-expressed convention.
        """
        model = self.model
        self.body_f_s.zero_()
        self._launch_rigid_id(state_in)
        self.body_ft_s.zero_()
        # Reduce joint forces plus body and external body forces to per-DOF joint_tau.
        wp.launch_tiled(
            self._rigid_tau_tiled,
            dim=model.articulation_count,
            block_dim=int(self.rigid_tile_size),
            inputs=[
                self.articulation_level_joint_index,
                int(self.max_articulation_level_count),
                int(self.max_articulation_level_width),
                model.joint_type,
                model.joint_parent,
                model.joint_child,
                model.joint_qd_start,
                model.joint_dof_dim,
                self.joint_f_sap_input,
                self.joint_S_s,
                self.body_q,
                self.body_f_s,
                self.body_f_ext_s,
                self.body_ft_s,
                self.joint_tau,
            ],
            device=model.device,
        )
        self._assemble_articulation_matrices()
        # DOFs outside any articulation are never written by the solves; keep them zero.
        self.joint_qdd_sap_solve.zero_()
        if self._can_use_block_dynamics_solve():
            self._solve_articulation_accelerations_blocked()
        else:
            self._factor_dynamics_matrix()
            wp.launch(
                _eval_dense_solve_batched_f64,
                dim=model.articulation_count,
                inputs=[
                    self.articulation_H_start,
                    self.articulation_H_rows,
                    self.articulation_dof_start,
                    self.H,
                    self.L,
                    self.joint_tau,
                    self.joint_qdd_sap_solve,
                ],
                outputs=[self.joint_solve_tmp],
                device=model.device,
            )
        wp.launch(
            _assemble_sap_free_motion_outputs_kernel,
            dim=model.joint_count,
            inputs=[
                model.joint_qd_start,
                model.joint_dof_dim,
                self.joint_qd_sap_input,
                self.joint_qdd_sap_solve,
                float(dt),
                self.free_motion_joint_qd_sap,
                self.free_motion_joint_qdd_sap,
            ],
            device=model.device,
        )
        if assemble_dynamics_matrix:
            self._assemble_sap_dynamics_matrix()
        return self._result

    def compute(
        self,
        state_in: State,
        control: Control | None,
        dt: float,
        *,
        assemble_dynamics_matrix: bool = False,
    ) -> SapFreeMotionResult:
        """Compute SAP free motion from SAP-native state/control.

        The returned buffers satisfy `v_star = v0 + dt * vdot0` in SAP order.
        Set `assemble_dynamics_matrix=True` only for parity/debug code that
        needs the global dense `A = M + R` in SAP velocity order. Runtime
        contact/solver paths should use env-local matrix assembly instead.

        The returned result views buffers owned by this instance; they are
        overwritten by the next call.

        Raises:
            NotImplementedError: if ``state_in.requires_grad`` is truthy.
            TypeError: if ``state_in`` is not a ``State`` or ``control`` is not
                a ``Control`` (including ``None``).
        """
        if getattr(state_in, "requires_grad", False):
            raise NotImplementedError("SapFreeMotion does not support grad states.")
        if not isinstance(state_in, State):
            raise TypeError("SapFreeMotion.compute requires SapState; convert before entering SAP components.")
        if not isinstance(control, Control):
            raise TypeError("SapFreeMotion.compute requires SapControl; convert before entering SAP components.")
        self._prepare_sap_boundary(state_in, control)
        return self._compute_sap_core(
            state_in,
            dt,
            assemble_dynamics_matrix=assemble_dynamics_matrix,
        )

    def compute_articulation_free_motion(
        self,
        state_in: State,
        control: Control | None,
        dt: float,
        *,
        assemble_dynamics_matrix: bool = False,
    ):
        """Compatibility alias returning `(v_star, vdot, A_or_None)` in SAP order."""
        result = self.compute(
            state_in,
            control,
            dt,
            assemble_dynamics_matrix=assemble_dynamics_matrix,
        )
        return result.v_star, result.vdot, result.dynamics_matrix