# Source code for sim.contact_solve
from dataclasses import dataclass
from functools import cache
from types import SimpleNamespace
import numpy as np
import warp as wp
from sim.blocked_cholesky import BlockCholeskySolverBatched
from sim.contact_jacobian import SapContactJacobianResult
from sim.sap_helpers import (
_clamp_antiderivative_f32,
_clamp_derivative_f32,
_clamp_scalar_f32,
_clamp_antiderivative_f64,
_clamp_derivative_f64,
_clamp_scalar_f64,
_compute_effective_pd_gains_sap,
_compute_effective_pd_gains_sap_f32,
_contact_projection_cost_from_vc_sap,
_contact_projection_cost_from_vc_sap_f32,
_contact_projection_cost_from_velocity_sap,
_contact_projection_cost_from_velocity_sap_f32,
_sap_armijo_ok,
_sap_armijo_ok_f32,
_m33d,
_m33f,
_v3d,
_v3f,
_zero_m33d,
_zero_m33f,
)
from sim.sap_runtime import (
Control,
Model,
SAP_JOINT_D6,
SAP_JOINT_DISTANCE,
SAP_JOINT_FREE,
SAP_JOINT_PRISMATIC,
SAP_JOINT_REVOLUTE,
SAP_JOINT_TARGET_NONE,
State,
)
# Warp global configuration flag; per the Warp documentation this controls
# whether backward (adjoint) passes are generated for kernels.
wp.config.enable_backward = False

# Pi, used in the beta^2 / (4 * pi^2) regularization factors in the kernels below.
_PI = 3.141592653589793
# Added in quadrature to the tangential norm sqrt(y.x^2 + y.y^2 + tol^2) in the
# contact projection kernels, so that norm is never exactly zero.
_CONTACT_SOFT_NORM_TOL = 1.0e-7
# NOTE(review): not referenced in the visible part of this file — presumably
# the beta used for PD-term regularization; confirm against the caller.
_SAP_PD_BETA = 0.1
# Beta used by the joint-limit build kernel to form r_nr = beta^2/(4 pi^2) * a_inv_diag.
_SAP_LIMIT_BETA = 0.1
# NOTE(review): not referenced in the visible part of this file.
_SAP_LIMIT_STIFFNESS = 1.0e12
# Scales dt * max(|v0|, |v_star|) to form the gap window within which a joint
# limit is marked active.
_SAP_LIMIT_WINDOW_FACTOR = 2.0
# NOTE(review): the three exact-line-search settings are not referenced in the
# visible part of this file — presumably consumed by the "exact_root" variant.
_SAP_EXACT_LINE_SEARCH_ALPHA_MAX = 1.5
_SAP_EXACT_LINE_SEARCH_MAX_ITERATIONS = 100
_SAP_EXACT_LINE_SEARCH_F_TOLERANCE = 1.0e-8
# Integer tags written into the contact_mode arrays by the contact projection kernels.
_CONTACT_MODE_NONE = 0
_CONTACT_MODE_STICTION = 1
_CONTACT_MODE_SLIDING = 2
_CONTACT_MODE_FRICTIONLESS = 3
# NOTE(review): tile sizes are not referenced in the visible part of this file —
# presumably the M/N/K tiling of a tiled contact-Hessian GEMM; confirm.
_CONTACT_HESSIAN_GEMM_TILE_M = 8
_CONTACT_HESSIAN_GEMM_TILE_N = 8
_CONTACT_HESSIAN_GEMM_TILE_K = 32
def normalize_sap_line_search_mode(value: str) -> str:
    """Normalize a line-search variant string to the canonical value accepted by the SAP contact solve.

    The input is converted with ``str()``, stripped of surrounding whitespace,
    lower-cased, and has hyphens replaced by underscores before it is matched.

    Args:
        value: Requested variant, e.g. ``"Exact-Root"``.

    Returns:
        One of ``"monotone_decay"``, ``"armijo_decay"`` or ``"exact_root"``.

    Raises:
        ValueError: If the normalized string is not one of the three variants.
            The message quotes the original (un-normalized) ``value``.
    """
    mode = str(value).strip().lower().replace("-", "_")
    # A single membership test replaces three branches that each returned the
    # literal they had just compared against.
    if mode in ("monotone_decay", "armijo_decay", "exact_root"):
        return mode
    raise ValueError(
        "line_search_variant must be 'monotone_decay', 'armijo_decay', or 'exact_root', "
        f"got {value!r}."
    )
@wp.kernel
def _copy_f32_to_f64(src: wp.array(dtype=wp.float32), dst: wp.array(dtype=wp.float64)):
    # One thread per element: read the float32 value and store it widened to float64.
    idx = wp.tid()
    value = src[idx]
    dst[idx] = wp.float64(value)
@wp.kernel
def _copy_f64(src: wp.array(dtype=wp.float64), dst: wp.array(dtype=wp.float64)):
    # One thread per element: plain float64 -> float64 copy at the same index.
    idx = wp.tid()
    value = src[idx]
    dst[idx] = value
@wp.kernel
def _copy_f32(src: wp.array(dtype=wp.float32), dst: wp.array(dtype=wp.float32)):
    # One thread per element: plain float32 -> float32 copy at the same index.
    idx = wp.tid()
    value = src[idx]
    dst[idx] = value
@wp.kernel
def _copy_f64_to_f32(src: wp.array(dtype=wp.float64), dst: wp.array(dtype=wp.float32)):
    # One thread per element: read the float64 value and store it narrowed to float32.
    idx = wp.tid()
    value = src[idx]
    dst[idx] = wp.float32(value)
@wp.kernel
def _copy_flat_f64_to_env_f32_batched(
    src: wp.array(dtype=wp.float64),
    dof_per_env: int,
    dst: wp.array(dtype=wp.float32, ndim=2),
):
    # 2D launch over (env, column); columns at or beyond dof_per_env do nothing.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    # The flat source is indexed env-major: (env, col) -> env * dof_per_env + col.
    flat = env * dof_per_env + col
    dst[env, col] = wp.float32(src[flat])
@wp.kernel
def _copy_flat_f32_to_env_f64_batched(
    src: wp.array(dtype=wp.float32),
    dof_per_env: int,
    dst: wp.array(dtype=wp.float64, ndim=2),
):
    # 2D launch over (env, column); columns at or beyond dof_per_env do nothing.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    # The flat source is indexed env-major: (env, col) -> env * dof_per_env + col.
    flat = env * dof_per_env + col
    dst[env, col] = wp.float64(src[flat])
@wp.kernel
def _copy_env_f64_to_env_f32_batched(
    src: wp.array(dtype=wp.float64, ndim=2),
    dof_per_env: int,
    dst: wp.array(dtype=wp.float32, ndim=2),
):
    # 2D launch over (env, column); only the first dof_per_env columns are copied.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    value = src[env, col]
    dst[env, col] = wp.float32(value)
@wp.kernel
def _copy_env_f32_to_env_f64_batched(
    src: wp.array(dtype=wp.float32, ndim=2),
    dof_per_env: int,
    dst: wp.array(dtype=wp.float64, ndim=2),
):
    # 2D launch over (env, column); only the first dof_per_env columns are copied.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    value = src[env, col]
    dst[env, col] = wp.float64(value)
@wp.kernel
def _copy_2d_f64_to_f32(src: wp.array(dtype=wp.float64, ndim=2), dst: wp.array(dtype=wp.float32, ndim=2)):
    # One thread per (row, col) entry: narrow float64 to float32. No bounds guard.
    row, col = wp.tid()
    value = src[row, col]
    dst[row, col] = wp.float32(value)
@wp.kernel
def _copy_3d_f64_to_f32(src: wp.array(dtype=wp.float64, ndim=3), dst: wp.array(dtype=wp.float32, ndim=3)):
    # One thread per 3D entry: narrow float64 to float32. No bounds guard.
    a, b, c = wp.tid()
    value = src[a, b, c]
    dst[a, b, c] = wp.float32(value)
@wp.kernel
def _copy_4d_f64_to_f32(src: wp.array(dtype=wp.float64, ndim=4), dst: wp.array(dtype=wp.float32, ndim=4)):
    # One thread per 4D entry: narrow float64 to float32. No bounds guard.
    a, b, c, d = wp.tid()
    value = src[a, b, c, d]
    dst[a, b, c, d] = wp.float32(value)
@wp.kernel
def _copy_vec3d_2d_to_vec3(src: wp.array(dtype=wp.vec3d, ndim=2), dst: wp.array(dtype=wp.vec3, ndim=2)):
    # One thread per (row, col) entry: rebuild the vec3d as a vec3 with each
    # component cast to float32.
    row, col = wp.tid()
    vec = src[row, col]
    x32 = wp.float32(vec.x)
    y32 = wp.float32(vec.y)
    z32 = wp.float32(vec.z)
    dst[row, col] = wp.vec3(x32, y32, z32)
@wp.kernel
def _copy_mat33d_2d_to_mat33(src: wp.array(dtype=wp.mat33d, ndim=2), dst: wp.array(dtype=wp.mat33, ndim=2)):
    # One thread per (row, col) entry: rebuild the mat33d as a mat33, casting the
    # nine entries to float32 in the order (0,0), (0,1), ..., (2,2).
    row, col = wp.tid()
    mat = src[row, col]
    dst[row, col] = wp.mat33(
        wp.float32(mat[0, 0]),
        wp.float32(mat[0, 1]),
        wp.float32(mat[0, 2]),
        wp.float32(mat[1, 0]),
        wp.float32(mat[1, 1]),
        wp.float32(mat[1, 2]),
        wp.float32(mat[2, 0]),
        wp.float32(mat[2, 1]),
        wp.float32(mat[2, 2]),
    )
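# [docs] -- appears to be a documentation-page link marker left over from HTML extraction; not part of the module.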
@dataclass(frozen=True)
class SapContactSolveResult:
    """Views into buffers owned by `SapContactSolve`.

    Frozen dataclass: fields cannot be rebound after construction. Every array
    field is annotated ``wp.array``; shapes and dtypes are not fixed here.
    NOTE(review): the grouping comments below are inferred from field names and
    from the like-named kernel arguments in this file — confirm against
    `SapContactSolve`.
    """

    # Solved velocities (presumably per-env 2D and flat 1D views of the same data).
    v_env: wp.array
    v_flat: wp.array
    # Objective value, its previous value, gradient and Hessian.
    cost: wp.array
    previous_cost: wp.array
    grad: wp.array
    hessian: wp.array
    # Generalized impulses.
    constraint_impulse: wp.array
    dynamics_impulse: wp.array
    # Per-contact quantities (same names as the outputs of the contact projection kernels).
    contact_gamma: wp.array
    contact_g: wp.array
    contact_vc: wp.array
    contact_y: wp.array
    contact_rt: wp.array
    contact_rn: wp.array
    contact_cost: wp.array
    contact_mode: wp.array
    # Per-DOF PD-controller quantities (same names as the PD build/eval kernel outputs).
    pd_active: wp.array
    pd_y: wp.array
    pd_gamma: wp.array
    pd_hdiag: wp.array
    pd_cost: wp.array
    pd_kp_eff: wp.array
    pd_kd_eff: wp.array
    # Per-DOF joint-limit quantities (same names as the limit build/eval kernel outputs).
    limit_lower_active: wp.array
    limit_upper_active: wp.array
    limit_lower_gamma: wp.array
    limit_upper_gamma: wp.array
    limit_grad: wp.array
    limit_hdiag: wp.array
    limit_cost: wp.array
    # Solver diagnostics held in device arrays.
    first_dv: wp.array
    alpha: wp.array
    newton_iterations_env: wp.array
    line_search_iterations_env: wp.array
    newton_active: wp.array
    converged_env: wp.array
    optimality_reached_env: wp.array
    cost_reached_env: wp.array
    # Host-side summary values.
    iterations: int
    line_search_iterations: int
    converged: bool
@cache
def _make_contact_solve_kernel_table(scalar):
    # Precision-specialized kernel factory, memoized per `scalar` by functools.cache.
    # This prologue binds the vector/matrix types and the helper functions that
    # match `scalar`. The definitions that follow refer to these local names
    # (scalar, vec3, mat33, v3, m33, zero_m33, clamp_*, ...) without taking them
    # as parameters, so the names must not be changed.
    # The rest of the factory, including whatever it returns, is outside this block.
    #
    # Raises ValueError when `scalar` is neither wp.float32 nor wp.float64.
    if scalar == wp.float32:
        # Single-precision types and the *_f32 helper variants.
        vec3 = wp.vec3
        mat33 = wp.mat33
        v3 = _v3f
        m33 = _m33f
        zero_m33 = _zero_m33f
        clamp_scalar = _clamp_scalar_f32
        clamp_derivative = _clamp_derivative_f32
        clamp_antiderivative = _clamp_antiderivative_f32
        compute_effective_pd_gains = _compute_effective_pd_gains_sap_f32
        contact_projection_cost_from_vc = _contact_projection_cost_from_vc_sap_f32
        contact_projection_cost_from_velocity = _contact_projection_cost_from_velocity_sap_f32
        sap_armijo_ok = _sap_armijo_ok_f32
    elif scalar == wp.float64:
        # Double-precision types; some helpers carry an _f64 suffix, others no suffix.
        vec3 = wp.vec3d
        mat33 = wp.mat33d
        v3 = _v3d
        m33 = _m33d
        zero_m33 = _zero_m33d
        clamp_scalar = _clamp_scalar_f64
        clamp_derivative = _clamp_derivative_f64
        clamp_antiderivative = _clamp_antiderivative_f64
        compute_effective_pd_gains = _compute_effective_pd_gains_sap
        contact_projection_cost_from_vc = _contact_projection_cost_from_vc_sap
        contact_projection_cost_from_velocity = _contact_projection_cost_from_velocity_sap
        sap_armijo_ok = _sap_armijo_ok
    else:
        raise ValueError(f"Unsupported contact solve dtype {scalar!r}.")
@wp.kernel(module="unique")
def _copy_flat_to_env_batched(
    src: wp.array(dtype=scalar),
    dof_per_env: int,
    dst: wp.array(dtype=scalar, ndim=2),
):
    # `scalar` is not a parameter; it is resolved from the enclosing scope.
    # 2D launch over (env, column); columns at or beyond dof_per_env do nothing.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    # Flat index is env-major: (env, col) -> env * dof_per_env + col.
    flat = env * dof_per_env + col
    dst[env, col] = src[flat]
@wp.kernel(module="unique")
def _copy_env_to_flat_batched(
    src: wp.array(dtype=scalar, ndim=2),
    dof_per_env: int,
    dst: wp.array(dtype=scalar),
):
    # `scalar` is not a parameter; it is resolved from the enclosing scope.
    # 2D launch over (env, column); columns at or beyond dof_per_env do nothing.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    # Flat index is env-major: (env, col) -> env * dof_per_env + col.
    flat = env * dof_per_env + col
    dst[flat] = src[env, col]
@wp.kernel(module="unique")
def _copy_env_to_env_batched(
    src: wp.array(dtype=scalar, ndim=2),
    dof_per_env: int,
    dst: wp.array(dtype=scalar, ndim=2),
):
    # `scalar` is not a parameter; it is resolved from the enclosing scope.
    # 2D launch over (env, column); only the first dof_per_env columns are copied.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    value = src[env, col]
    dst[env, col] = value
@wp.kernel(module="unique")
def _copy_solve_velocity_inputs_flat_batched(
    v_star_src: wp.array(dtype=scalar),
    v0_src: wp.array(dtype=scalar),
    dof_per_env: int,
    v_star_dst: wp.array(dtype=scalar, ndim=2),
    v0_dst: wp.array(dtype=scalar, ndim=2),
    v_dst: wp.array(dtype=scalar, ndim=2),
):
    # Scatters the flat v_star and v0 inputs into per-env 2D buffers and seeds
    # the working velocity v_dst with v0.
    # `scalar` is resolved from the enclosing scope.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    flat = env * dof_per_env + col
    v_init = v0_src[flat]
    v_star_dst[env, col] = v_star_src[flat]
    v0_dst[env, col] = v_init
    v_dst[env, col] = v_init
@wp.kernel(module="unique")
def _copy_solve_velocity_inputs_flat_batched_with_guess_flag(
    v_star_src: wp.array(dtype=scalar),
    v0_src: wp.array(dtype=scalar),
    v_guess_src: wp.array(dtype=scalar),
    use_v_guess: wp.array(dtype=int),
    dof_per_env: int,
    v_star_dst: wp.array(dtype=scalar, ndim=2),
    v0_dst: wp.array(dtype=scalar, ndim=2),
    v_dst: wp.array(dtype=scalar, ndim=2),
):
    # Same scatter as the kernel without the flag, except that the working
    # velocity is seeded from v_guess_src when use_v_guess[0] is non-zero.
    # The seed actually used is also written back into v_guess_src.
    # `scalar` is resolved from the enclosing scope.
    env, col = wp.tid()
    if col >= dof_per_env:
        return
    flat = env * dof_per_env + col
    v_star_dst[env, col] = v_star_src[flat]
    v_init = v0_src[flat]
    v0_dst[env, col] = v_init
    seed = v_init
    if use_v_guess[0] != 0:
        seed = v_guess_src[flat]
    v_dst[env, col] = seed
    v_guess_src[flat] = seed
@wp.kernel(module="unique")
def _initialize_and_mark_unconstrained_free_envs_batched(
    dof_per_env: int,
    contact_count: wp.array(dtype=int),
    pd_active: wp.array(dtype=int, ndim=2),
    limit_lower_active: wp.array(dtype=int, ndim=2),
    limit_upper_active: wp.array(dtype=int, ndim=2),
    participating_dof: wp.array(dtype=int, ndim=2),
    v_star: wp.array(dtype=scalar, ndim=2),
    v: wp.array(dtype=scalar, ndim=2),
    v_flat: wp.array(dtype=scalar),
    first_dv: wp.array(dtype=scalar, ndim=2),
    newton_iterations_env: wp.array(dtype=int),
    ls_iterations_total: wp.array(dtype=int),
    alpha: wp.array(dtype=scalar),
    previous_cost: wp.array(dtype=scalar),
    converged_env: wp.array(dtype=int),
    optimality_reached_env: wp.array(dtype=int),
    cost_reached_env: wp.array(dtype=int),
    stage2_active_env: wp.array(dtype=int),
    newton_active: wp.array(dtype=int),
    stage2_active_count: wp.array(dtype=int),
):
    # One thread per environment. Resets the per-env solver bookkeeping, then
    # either resolves the env immediately (no contacts, no active PD term, no
    # active limit) by setting v = v_star, or flags it for the Newton stage.
    env = wp.tid()
    newton_iterations_env[env] = 0
    ls_iterations_total[env] = 0
    alpha[env] = scalar(1.0)
    previous_cost[env] = scalar(0.0)
    optimality_reached_env[env] = 0
    cost_reached_env[env] = 0
    # Unconstrained means: zero contacts and no DOF with an active PD or limit term.
    unconstrained = contact_count[env] == 0
    for i in range(dof_per_env):
        if pd_active[env, i] == 1 or limit_lower_active[env, i] == 1 or limit_upper_active[env, i] == 1:
            unconstrained = False
    if unconstrained:
        converged_env[env] = 1
        stage2_active_env[env] = 0
        newton_active[env] = 0
        for i in range(dof_per_env):
            # Record the step taken (v_star - v) before overwriting v.
            first_dv[env, i] = v_star[env, i] - v[env, i]
            v[env, i] = v_star[env, i]
            v_flat[env * dof_per_env + i] = v_star[env, i]
    else:
        converged_env[env] = 0
        stage2_active_env[env] = 1
        newton_active[env] = 1
        for i in range(dof_per_env):
            # DOFs flagged as non-participating are set to v_star up front.
            if participating_dof[env, i] == 0:
                first_dv[env, i] = v_star[env, i] - v[env, i]
                v[env, i] = v_star[env, i]
                v_flat[env * dof_per_env + i] = v_star[env, i]
        # Increments element 0 once per env that still needs solving.
        # NOTE(review): nesting was reconstructed from flattened source —
        # confirm this add belongs to the else-branch rather than the inner loop.
        wp.atomic_add(stage2_active_count, 0, 1)
@wp.kernel(module="unique")
def _initialize_newton_loop_state(
    newton_loop_iteration: wp.array(dtype=int),
    newton_max_reached: wp.array(dtype=int),
):
    # Zeroes element 0 of both arrays. No thread index is read, so every
    # launched thread performs the same two writes.
    newton_max_reached[0] = 0
    newton_loop_iteration[0] = 0
@wp.kernel(module="unique")
def _extract_a_diag_data_batched(
    dof_per_env: int,
    a_mat: wp.array(dtype=scalar, ndim=3),
    a_inv_diag: wp.array(dtype=scalar, ndim=2),
    d_scale: wp.array(dtype=scalar, ndim=2),
):
    # For each (env, dof) reads the diagonal entry a_mat[env, dof, dof] and stores
    # 1/diag in a_inv_diag and 1/sqrt(diag) in d_scale.
    # `scalar` is resolved from the enclosing scope.
    env, dof = wp.tid()
    if dof < dof_per_env:
        pivot = a_mat[env, dof, dof]
        # Values below 1e-12 or non-finite values are replaced by 1e-12 so both
        # reciprocals stay finite.
        if pivot < scalar(1.0e-12) or not wp.isfinite(pivot):
            pivot = scalar(1.0e-12)
        a_inv_diag[env, dof] = scalar(1.0) / pivot
        d_scale[env, dof] = scalar(1.0) / wp.sqrt(pivot)
@wp.kernel(module="unique")
def _clear_participating_dofs_batched(
    participating_dof: wp.array(dtype=int, ndim=2),
):
    # One thread per (env, dof) entry: reset the participation flag to 0.
    # There is no bounds guard.
    env, dof = wp.tid()
    participating_dof[env, dof] = 0
@wp.kernel(module="unique")
def _mark_contact_participating_dofs_batched(
    dof_per_env: int,
    max_contacts: int,
    contact_env_count: wp.array(dtype=int),
    contact_env_body0: wp.array(dtype=int, ndim=2),
    contact_env_body1: wp.array(dtype=int, ndim=2),
    contact_env_jacobian: wp.array(dtype=scalar, ndim=4),
    body_dof_start: wp.array(dtype=int),
    body_dof_count: wp.array(dtype=int),
    participating_dof: wp.array(dtype=int, ndim=2),
):
    # One thread per (env, contact, dof). Sets participating_dof[env, dof] = 1 when
    # the dof lies in the DOF range of either contact body, or when any of the
    # three Jacobian rows of the contact has a non-zero entry in that column.
    # Threads only ever write the value 1.
    env, c, i = wp.tid()
    if i >= dof_per_env:
        return
    # Clamp the stored contact count to the buffer capacity.
    count = contact_env_count[env]
    if count > max_contacts:
        count = max_contacts
    if c >= count:
        return
    # Body 0: a negative body index or negative/empty DOF range is skipped.
    b0 = contact_env_body0[env, c]
    if b0 >= 0:
        start0 = body_dof_start[b0]
        count0 = body_dof_count[b0]
        if start0 >= 0 and count0 > 0:
            if i >= start0 and i < start0 + count0:
                participating_dof[env, i] = 1
    # Body 1: same test.
    b1 = contact_env_body1[env, c]
    if b1 >= 0:
        start1 = body_dof_start[b1]
        count1 = body_dof_count[b1]
        if start1 >= 0 and count1 > 0:
            if i >= start1 and i < start1 + count1:
                participating_dof[env, i] = 1
    # Any non-zero Jacobian entry in column i also marks the dof.
    if (
        contact_env_jacobian[env, c, 0, i] != scalar(0.0)
        or contact_env_jacobian[env, c, 1, i] != scalar(0.0)
        or contact_env_jacobian[env, c, 2, i] != scalar(0.0)
    ):
        participating_dof[env, i] = 1
@wp.kernel(module="unique")
def _mark_model_participating_dofs_batched(
    dof_per_env: int,
    pd_active: wp.array(dtype=int, ndim=2),
    limit_lower_active: wp.array(dtype=int, ndim=2),
    limit_upper_active: wp.array(dtype=int, ndim=2),
    participating_dof: wp.array(dtype=int, ndim=2),
):
    # One thread per (env, dof): a dof with an active PD term or an active
    # lower/upper limit is flagged as participating. Flags are only ever set to 1.
    env, dof = wp.tid()
    if dof < dof_per_env:
        if pd_active[env, dof] == 1 or limit_lower_active[env, dof] == 1 or limit_upper_active[env, dof] == 1:
            participating_dof[env, dof] = 1
@wp.kernel(module="unique")
def _build_pd_terms_sap_batched(
    enabled: int,
    dof_per_env: int,
    dof_coord_index: wp.array(dtype=int),
    dof_target_index: wp.array(dtype=int),
    joint_target_mode: wp.array(dtype=int),
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_effort_limit: wp.array(dtype=float),
    joint_q: wp.array(dtype=scalar),
    joint_target_pos: wp.array(dtype=float),
    joint_target_vel: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    a_inv_diag: wp.array(dtype=scalar, ndim=2),
    dt: scalar,
    mode_none: int,
    pd_active: wp.array(dtype=int, ndim=2),
    pd_a: wp.array(dtype=scalar, ndim=2),
    pd_gain: wp.array(dtype=scalar, ndim=2),
    pd_limit: wp.array(dtype=scalar, ndim=2),
    pd_kp_eff: wp.array(dtype=scalar, ndim=2),
    pd_kd_eff: wp.array(dtype=scalar, ndim=2),
):
    # One thread per (env, dof). Clears the PD outputs for the dof, then, if the
    # dof has a usable PD target, fills in the velocity-independent data:
    #   pd_a     = kp_eff * (q_target - q0) + kd_eff * v_target + u0
    #   pd_gain  = dt * kp_eff + kd_eff
    #   pd_limit = effort limit
    # plus the effective gains.
    env, i = wp.tid()
    if i >= dof_per_env:
        return
    # Flat dof index used for the per-dof lookup tables.
    dof = env * dof_per_env + i
    coord = dof_coord_index[dof]
    target_dof = dof_target_index[dof]
    # Default outputs: term inactive, all data zero.
    pd_active[env, i] = 0
    pd_a[env, i] = scalar(0.0)
    pd_gain[env, i] = scalar(0.0)
    pd_limit[env, i] = scalar(0.0)
    pd_kp_eff[env, i] = scalar(0.0)
    pd_kd_eff[env, i] = scalar(0.0)
    # A negative target index falls back to the dof's own flat index.
    if target_dof < 0:
        target_dof = dof
    if enabled == 0 or joint_target_mode[target_dof] == mode_none:
        return
    kp = scalar(joint_target_ke[target_dof])
    kd = scalar(joint_target_kd[target_dof])
    if kp <= scalar(0.0) and kd <= scalar(0.0):
        return
    # A position gain needs a coordinate to read q0 from; skip when there is none.
    if coord < 0 and kp > scalar(0.0):
        return
    # Element 0 of the helper's result is used as kp_eff, element 1 as kd_eff.
    eff = compute_effective_pd_gains(kp, kd, scalar(dt), a_inv_diag[env, i])
    kp_eff = eff[0]
    kd_eff = eff[1]
    gain = scalar(dt) * kp_eff + kd_eff
    if gain <= scalar(0.0):
        return
    q0 = scalar(0.0)
    qd = scalar(0.0)
    if coord >= 0:
        q0 = scalar(joint_q[coord])
        qd = scalar(joint_target_pos[target_dof])
    vd = scalar(joint_target_vel[target_dof])
    u0 = scalar(joint_act[target_dof])
    pd_active[env, i] = 1
    pd_a[env, i] = kp_eff * (qd - q0) + kd_eff * vd + u0
    pd_gain[env, i] = gain
    pd_limit[env, i] = scalar(joint_effort_limit[target_dof])
    pd_kp_eff[env, i] = kp_eff
    pd_kd_eff[env, i] = kd_eff
@wp.kernel(module="unique")
def _eval_pd_terms_sap_batched(
    add_pd: int,
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    pd_active: wp.array(dtype=int, ndim=2),
    pd_a: wp.array(dtype=scalar, ndim=2),
    pd_gain: wp.array(dtype=scalar, ndim=2),
    pd_limit: wp.array(dtype=scalar, ndim=2),
    v: wp.array(dtype=scalar, ndim=2),
    dt: scalar,
    pd_y: wp.array(dtype=scalar, ndim=2),
    pd_gamma: wp.array(dtype=scalar, ndim=2),
    pd_hdiag: wp.array(dtype=scalar, ndim=2),
    pd_cost: wp.array(dtype=scalar, ndim=2),
    total_cost: wp.array(dtype=scalar),
):
    # One thread per (env, dof). Evaluates the PD term at velocity v:
    #   y     = pd_a - pd_gain * v
    #   gamma = dt * clamp_scalar(y, limit)
    #   hdiag = dt * gain * clamp_derivative(y, limit)
    #   cost  = (dt / gain) * clamp_antiderivative(y, limit)   (only when gain > 0)
    # and atomically adds cost to total_cost[env].
    # NOTE(review): the clamp_* helpers are defined elsewhere; from their names
    # they are presumably the clamp to +/-limit, its derivative and its
    # antiderivative — confirm.
    env, i = wp.tid()
    # Inactive path: zero the outputs (for in-range dofs only) and return
    # without touching total_cost.
    if i >= dof_per_env or add_pd == 0 or active_env[env] == 0 or pd_active[env, i] == 0:
        if i < dof_per_env:
            pd_y[env, i] = scalar(0.0)
            pd_gamma[env, i] = scalar(0.0)
            pd_hdiag[env, i] = scalar(0.0)
            pd_cost[env, i] = scalar(0.0)
        return
    gain = pd_gain[env, i]
    y = pd_a[env, i] - gain * v[env, i]
    gamma = scalar(dt) * clamp_scalar(y, pd_limit[env, i])
    hdiag = scalar(dt) * gain * clamp_derivative(y, pd_limit[env, i])
    cost = scalar(0.0)
    # Guard the division by gain.
    if gain > scalar(0.0):
        cost = (scalar(dt) / gain) * clamp_antiderivative(y, pd_limit[env, i])
    pd_y[env, i] = y
    pd_gamma[env, i] = gamma
    pd_hdiag[env, i] = hdiag
    pd_cost[env, i] = cost
    wp.atomic_add(total_cost, env, cost)
@wp.kernel(module="unique")
def _build_limit_terms_sap_batched(
    enabled: int,
    dof_per_env: int,
    dof_coord_index: wp.array(dtype=int),
    limit_supported: wp.array(dtype=int),
    joint_limit_lower: wp.array(dtype=scalar),
    joint_limit_upper: wp.array(dtype=scalar),
    joint_limit_ke: wp.array(dtype=scalar),
    joint_limit_kd: wp.array(dtype=scalar),
    joint_q: wp.array(dtype=scalar),
    v0: wp.array(dtype=scalar, ndim=2),
    v_star: wp.array(dtype=scalar, ndim=2),
    a_inv_diag: wp.array(dtype=scalar, ndim=2),
    dt: scalar,
    lower_active: wp.array(dtype=int, ndim=2),
    upper_active: wp.array(dtype=int, ndim=2),
    lower_vhat: wp.array(dtype=scalar, ndim=2),
    upper_vhat: wp.array(dtype=scalar, ndim=2),
    lower_r: wp.array(dtype=scalar, ndim=2),
    upper_r: wp.array(dtype=scalar, ndim=2),
    lower_rinv: wp.array(dtype=scalar, ndim=2),
    upper_rinv: wp.array(dtype=scalar, ndim=2),
):
    # One thread per (env, dof). Resets the limit data for the dof, then marks the
    # lower and/or upper limit active when the current gap to that limit is within
    # a velocity-dependent window, and stores the bias velocity (vhat) and
    # regularization (r, 1/r) for each active side.
    env, i = wp.tid()
    if i >= dof_per_env:
        return
    dof = env * dof_per_env + i
    coord = dof_coord_index[dof]
    # Defaults: both sides inactive, vhat = 0, r = 1, rinv = 0.
    lower_active[env, i] = 0
    upper_active[env, i] = 0
    lower_vhat[env, i] = scalar(0.0)
    upper_vhat[env, i] = scalar(0.0)
    lower_r[env, i] = scalar(1.0)
    upper_r[env, i] = scalar(1.0)
    lower_rinv[env, i] = scalar(0.0)
    upper_rinv[env, i] = scalar(0.0)
    if enabled == 0 or coord < 0 or limit_supported[dof] == 0:
        return
    ke = scalar(joint_limit_ke[dof])
    if ke <= scalar(0.0):
        return
    q0 = scalar(joint_q[coord])
    lower = scalar(joint_limit_lower[dof])
    upper = scalar(joint_limit_upper[dof])
    # Activation window: factor * dt * max(|v0|, |v_star|).
    speed_scale = wp.max(wp.abs(v0[env, i]), wp.abs(v_star[env, i]))
    window = scalar(_SAP_LIMIT_WINDOW_FACTOR) * scalar(dt) * speed_scale
    kd = scalar(joint_limit_kd[dof])
    # Damping time constant kd / ke; zero when there is no damping.
    tau_d = scalar(0.0)
    if kd > scalar(0.0):
        tau_d = kd / ke
    beta = scalar(_SAP_LIMIT_BETA)
    beta_factor = beta * beta / (scalar(4.0) * scalar(_PI) * scalar(_PI))
    # Regularization is the larger of the beta-based value (a_inv_diag floored at
    # 1e-12) and the stiffness-based value 1 / (dt * ke * (dt + tau_d)).
    r_nr = beta_factor * wp.max(a_inv_diag[env, i], scalar(1.0e-12))
    r_soft = scalar(1.0) / (
        scalar(dt) * ke * (scalar(dt) + tau_d)
    )
    r = wp.max(r_nr, r_soft)
    rinv = scalar(1.0) / r
    # Non-finite bounds are treated as "no limit on this side".
    if wp.isfinite(lower):
        g = q0 - lower
        if g <= window:
            lower_active[env, i] = 1
            lower_vhat[env, i] = -g / (scalar(dt) + tau_d)
            lower_r[env, i] = r
            lower_rinv[env, i] = rinv
    if wp.isfinite(upper):
        g = upper - q0
        if g <= window:
            upper_active[env, i] = 1
            upper_vhat[env, i] = -g / (scalar(dt) + tau_d)
            upper_r[env, i] = r
            upper_rinv[env, i] = rinv
@wp.kernel(module="unique")
def _eval_limit_terms_sap_batched(
    add_limits: int,
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    lower_active: wp.array(dtype=int, ndim=2),
    upper_active: wp.array(dtype=int, ndim=2),
    lower_vhat: wp.array(dtype=scalar, ndim=2),
    upper_vhat: wp.array(dtype=scalar, ndim=2),
    lower_r: wp.array(dtype=scalar, ndim=2),
    upper_r: wp.array(dtype=scalar, ndim=2),
    lower_rinv: wp.array(dtype=scalar, ndim=2),
    upper_rinv: wp.array(dtype=scalar, ndim=2),
    v: wp.array(dtype=scalar, ndim=2),
    lower_gamma_out: wp.array(dtype=scalar, ndim=2),
    upper_gamma_out: wp.array(dtype=scalar, ndim=2),
    limit_grad: wp.array(dtype=scalar, ndim=2),
    limit_hdiag: wp.array(dtype=scalar, ndim=2),
    limit_cost: wp.array(dtype=scalar, ndim=2),
    total_cost: wp.array(dtype=scalar),
):
    # One thread per (env, dof). Evaluates the joint-limit terms at velocity v.
    # Outputs are always written for in-range dofs; they are zero when limits
    # are disabled or the env is inactive.
    env, i = wp.tid()
    if i >= dof_per_env:
        return
    grad = scalar(0.0)
    hdiag = scalar(0.0)
    cost = scalar(0.0)
    lower_gamma = scalar(0.0)
    upper_gamma = scalar(0.0)
    if add_limits == 1 and active_env[env] == 1:
        if lower_active[env, i] == 1:
            # Lower side: gamma = (vhat - v) / r, kept only when positive.
            gamma = lower_rinv[env, i] * (lower_vhat[env, i] - v[env, i])
            if gamma > scalar(0.0):
                lower_gamma = gamma
                grad = grad + gamma
                hdiag = hdiag + lower_rinv[env, i]
                cost = cost + scalar(0.5) * lower_r[env, i] * gamma * gamma
        if upper_active[env, i] == 1:
            # Upper side: v enters with the opposite sign, and gamma is
            # subtracted from grad.
            gamma = upper_rinv[env, i] * (upper_vhat[env, i] + v[env, i])
            if gamma > scalar(0.0):
                upper_gamma = gamma
                grad = grad - gamma
                hdiag = hdiag + upper_rinv[env, i]
                cost = cost + scalar(0.5) * upper_r[env, i] * gamma * gamma
    lower_gamma_out[env, i] = lower_gamma
    upper_gamma_out[env, i] = upper_gamma
    limit_grad[env, i] = grad
    limit_hdiag[env, i] = hdiag
    limit_cost[env, i] = cost
    # Only active envs contribute to the per-env total.
    if active_env[env] == 1:
        wp.atomic_add(total_cost, env, cost)
@wp.kernel(module="unique")
def _projection_eval_contact_sap_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    max_contacts: int,
    contact_count: wp.array(dtype=int),
    contact_jac: wp.array(dtype=scalar, ndim=4),
    contact_phi0: wp.array(dtype=scalar, ndim=2),
    contact_w_eff: wp.array(dtype=scalar, ndim=2),
    contact_mu: wp.array(dtype=scalar, ndim=2),
    contact_k: wp.array(dtype=scalar, ndim=2),
    contact_tau_d: wp.array(dtype=scalar, ndim=2),
    v: wp.array(dtype=scalar, ndim=2),
    beta: scalar,
    sigma: scalar,
    dt: scalar,
    contact_gamma: wp.array(dtype=vec3, ndim=2),
    contact_g: wp.array(dtype=mat33, ndim=2),
    contact_vc: wp.array(dtype=vec3, ndim=2),
    contact_y: wp.array(dtype=vec3, ndim=2),
    contact_rt: wp.array(dtype=scalar, ndim=2),
    contact_rn: wp.array(dtype=scalar, ndim=2),
    contact_cost: wp.array(dtype=scalar, ndim=2),
    contact_mode: wp.array(dtype=int, ndim=2),
    total_cost: wp.array(dtype=scalar),
):
    # One thread per (env, contact). Full contact evaluation at velocity v:
    # contact velocity vc = J v, regularization (rt, rn), unprojected impulse y,
    # projected impulse gamma, its 3x3 derivative g_mat, the contact mode tag
    # and the contact cost, which is also atomically added to total_cost[env].
    # Components x/y are treated as tangential and z as normal.
    env, c = wp.tid()
    zero = v3(scalar(0.0), scalar(0.0), scalar(0.0))
    # Slots beyond the capacity or the env's contact count are left untouched.
    if c >= max_contacts or c >= contact_count[env]:
        return
    # Inactive env: zero every output for this contact, add nothing to the total.
    if active_env[env] == 0:
        contact_gamma[env, c] = zero
        contact_g[env, c] = zero_m33()
        contact_vc[env, c] = zero
        contact_y[env, c] = zero
        contact_rt[env, c] = scalar(0.0)
        contact_rn[env, c] = scalar(0.0)
        contact_cost[env, c] = scalar(0.0)
        contact_mode[env, c] = _CONTACT_MODE_NONE
        return
    # vc = J[env, c] @ v[env], accumulated over the dof columns.
    vc = zero
    for j in range(dof_per_env):
        vj = v[env, j]
        vc = vec3(
            vc.x + contact_jac[env, c, 0, j] * vj,
            vc.y + contact_jac[env, c, 1, j] * vj,
            vc.z + contact_jac[env, c, 2, j] * vj,
        )
    # Effective weight, floored at 1e-12 (non-finite values are also replaced).
    wi = contact_w_eff[env, c]
    if wi < scalar(1.0e-12) or not wp.isfinite(wi):
        wi = scalar(1.0e-12)
    beta64 = scalar(beta)
    beta_factor = beta64 * beta64 / (scalar(4.0) * scalar(_PI) * scalar(_PI))
    rn_hard = beta_factor * wi
    # Non-positive or non-finite stiffness falls back to 1.
    k_c = contact_k[env, c]
    if k_c <= scalar(0.0) or not wp.isfinite(k_c):
        k_c = scalar(1.0)
    rn_soft = scalar(1.0) / (
        scalar(dt) * k_c * (scalar(dt) + wp.max(contact_tau_d[env, c], scalar(0.0)))
    )
    # Normal regularization is the larger of the two values;
    # tangential regularization is sigma * wi.
    rn = wp.max(rn_hard, rn_soft)
    rt = scalar(sigma) * wi
    # Floors keep the reciprocals below finite.
    if rt < scalar(1.0e-30):
        rt = scalar(1.0e-30)
    if rn < scalar(1.0e-30):
        rn = scalar(1.0e-30)
    rt_inv = scalar(1.0) / rt
    rn_inv = scalar(1.0) / rn
    tau_c = wp.max(contact_tau_d[env, c], scalar(0.0))
    vhat_n = -contact_phi0[env, c] / (scalar(dt) + tau_c)
    # Unprojected impulse.
    y = vec3(-rt_inv * vc.x, -rt_inv * vc.y, rn_inv * (vhat_n - vc.z))
    # Negative or non-finite friction coefficients are treated as zero.
    mu = contact_mu[env, c]
    if mu < scalar(0.0) or not wp.isfinite(mu):
        mu = scalar(0.0)
    # Softened tangential norm: the tolerance term makes yr strictly positive.
    yr = wp.sqrt(
        y.x * y.x
        + y.y * y.y
        + scalar(_CONTACT_SOFT_NORM_TOL) * scalar(_CONTACT_SOFT_NORM_TOL)
    )
    t_hat = v3(scalar(0.0), scalar(0.0), scalar(0.0))
    if yr > scalar(0.0):
        t_hat = v3(y.x / yr, y.y / yr, scalar(0.0))
    gamma = zero
    g_mat = zero_m33()
    mode = _CONTACT_MODE_NONE
    if mu <= scalar(1.0e-12):
        # Frictionless: only a positive normal component is kept.
        # NOTE(review): nesting reconstructed from flattened source — confirm the
        # mode assignment sits inside the `y.z > 0` test.
        if y.z > scalar(0.0):
            gamma = v3(scalar(0.0), scalar(0.0), y.z)
            g_mat = m33(
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                rn_inv,
            )
            mode = _CONTACT_MODE_FRICTIONLESS
    else:
        mu_tilde = mu * wp.sqrt(rt / rn)
        mu_hat = mu * rt / rn
        factor = scalar(1.0) / (scalar(1.0) + mu_tilde * mu_tilde)
        if yr <= mu * y.z:
            # Stiction: y is used unchanged; g_mat has (rt_inv, rt_inv, rn_inv)
            # in the 1st, 5th and 9th arguments and zeros elsewhere.
            gamma = y
            g_mat = m33(
                rt_inv,
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                rt_inv,
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                rn_inv,
            )
            mode = _CONTACT_MODE_STICTION
        elif (-mu_hat * yr < y.z) and (y.z < yr / mu):
            # Sliding: tangential impulse is mu * gamma_n along t_hat.
            gamma_n = (y.z + mu_hat * yr) * factor
            gamma = v3(mu * gamma_n * t_hat.x, mu * gamma_n * t_hat.y, gamma_n)
            # p** = t_hat t_hat^T (2x2); pp** = identity - p**.
            p00 = t_hat.x * t_hat.x
            p01 = t_hat.x * t_hat.y
            p10 = t_hat.y * t_hat.x
            p11 = t_hat.y * t_hat.y
            pp00 = scalar(1.0) - p00
            pp01 = -p01
            pp10 = -p10
            pp11 = scalar(1.0) - p11
            gn_over_yr = scalar(0.0)
            if yr > scalar(0.0):
                gn_over_yr = gamma_n / yr
            # Partial derivatives of gamma with respect to y.
            dgt_dyt00 = mu * (gn_over_yr * pp00 + mu_hat * factor * p00)
            dgt_dyt01 = mu * (gn_over_yr * pp01 + mu_hat * factor * p01)
            dgt_dyt10 = mu * (gn_over_yr * pp10 + mu_hat * factor * p10)
            dgt_dyt11 = mu * (gn_over_yr * pp11 + mu_hat * factor * p11)
            dgt_dyn0 = mu * factor * t_hat.x
            dgt_dyn1 = mu * factor * t_hat.y
            dgn_dyt0 = mu_hat * factor * t_hat.x
            dgn_dyt1 = mu_hat * factor * t_hat.y
            # Each partial is scaled by rt_inv (tangential input) or rn_inv
            # (normal input).
            g_mat = m33(
                dgt_dyt00 * rt_inv,
                dgt_dyt01 * rt_inv,
                dgt_dyn0 * rn_inv,
                dgt_dyt10 * rt_inv,
                dgt_dyt11 * rt_inv,
                dgt_dyn1 * rn_inv,
                dgn_dyt0 * rt_inv,
                dgn_dyt1 * rt_inv,
                factor * rn_inv,
            )
            mode = _CONTACT_MODE_SLIDING
        # Otherwise gamma, g_mat and mode keep their zero / NONE defaults.
    cost = scalar(0.5) * (
        rt * (gamma.x * gamma.x + gamma.y * gamma.y) + rn * gamma.z * gamma.z
    )
    contact_gamma[env, c] = gamma
    contact_g[env, c] = g_mat
    contact_vc[env, c] = vc
    contact_y[env, c] = y
    contact_rt[env, c] = rt
    contact_rn[env, c] = rn
    contact_cost[env, c] = cost
    contact_mode[env, c] = mode
    wp.atomic_add(total_cost, env, cost)
@wp.kernel(module="unique")
def _projection_cost_only_contact_sap_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    max_contacts: int,
    contact_count: wp.array(dtype=int),
    contact_jac: wp.array(dtype=scalar, ndim=4),
    contact_phi0: wp.array(dtype=scalar, ndim=2),
    contact_w_eff: wp.array(dtype=scalar, ndim=2),
    contact_mu: wp.array(dtype=scalar, ndim=2),
    contact_k: wp.array(dtype=scalar, ndim=2),
    contact_tau_d: wp.array(dtype=scalar, ndim=2),
    v: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    beta: scalar,
    sigma: scalar,
    dt: scalar,
    contact_cost: wp.array(dtype=scalar, ndim=2),
    total_cost: wp.array(dtype=scalar),
):
    # One thread per (env, contact). Cost-only variant: delegates to the
    # contact_projection_cost_from_velocity helper, stores the result in
    # contact_cost and atomically adds it to total_cost[env].
    env, c = wp.tid()
    # Slots beyond the capacity or the env's contact count are left untouched.
    if c >= max_contacts or c >= contact_count[env]:
        return
    # Inactive env: zero the stored cost and contribute nothing to the total.
    if active_env[env] == 0:
        contact_cost[env, c] = scalar(0.0)
        return
    # NOTE(review): the scalar(0.0) passed after `dv` is presumably a step
    # length applied to dv (i.e. evaluate at v itself) — confirm against the helper.
    cost = contact_projection_cost_from_velocity(
        env,
        c,
        dof_per_env,
        contact_jac,
        contact_phi0,
        contact_w_eff,
        contact_mu,
        contact_k,
        contact_tau_d,
        v,
        dv,
        scalar(0.0),
        beta,
        sigma,
        dt,
    )
    contact_cost[env, c] = cost
    wp.atomic_add(total_cost, env, cost)
@wp.kernel(module="unique")
def _projection_eval_contact_gamma_sap_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    max_contacts: int,
    contact_count: wp.array(dtype=int),
    contact_jac: wp.array(dtype=scalar, ndim=4),
    contact_phi0: wp.array(dtype=scalar, ndim=2),
    contact_w_eff: wp.array(dtype=scalar, ndim=2),
    contact_mu: wp.array(dtype=scalar, ndim=2),
    contact_k: wp.array(dtype=scalar, ndim=2),
    contact_tau_d: wp.array(dtype=scalar, ndim=2),
    v: wp.array(dtype=scalar, ndim=2),
    beta: scalar,
    sigma: scalar,
    dt: scalar,
    contact_gamma: wp.array(dtype=vec3, ndim=2),
    contact_vc: wp.array(dtype=vec3, ndim=2),
    contact_cost: wp.array(dtype=scalar, ndim=2),
    total_cost: wp.array(dtype=scalar),
):
    # One thread per (env, contact). Reduced evaluation: computes the contact
    # velocity vc = J v, the projected impulse gamma and the contact cost, but
    # not the derivative matrix or the mode tag. The cost is also atomically
    # added to total_cost[env].
    env, c = wp.tid()
    zero = v3(scalar(0.0), scalar(0.0), scalar(0.0))
    # Slots beyond the capacity or the env's contact count are left untouched.
    if c >= max_contacts or c >= contact_count[env]:
        return
    # Inactive env: zero this kernel's outputs, add nothing to the total.
    if active_env[env] == 0:
        contact_gamma[env, c] = zero
        contact_vc[env, c] = zero
        contact_cost[env, c] = scalar(0.0)
        return
    # vc = J[env, c] @ v[env].
    vc = zero
    for j in range(dof_per_env):
        vj = v[env, j]
        vc = vec3(
            vc.x + contact_jac[env, c, 0, j] * vj,
            vc.y + contact_jac[env, c, 1, j] * vj,
            vc.z + contact_jac[env, c, 2, j] * vj,
        )
    # Effective weight, floored at 1e-12 (non-finite values are also replaced).
    wi = contact_w_eff[env, c]
    if wi < scalar(1.0e-12) or not wp.isfinite(wi):
        wi = scalar(1.0e-12)
    beta_factor = beta * beta / (scalar(4.0) * scalar(_PI) * scalar(_PI))
    rn_hard = beta_factor * wi
    # Non-positive or non-finite stiffness falls back to 1.
    k_c = contact_k[env, c]
    if k_c <= scalar(0.0) or not wp.isfinite(k_c):
        k_c = scalar(1.0)
    tau_c = wp.max(contact_tau_d[env, c], scalar(0.0))
    rn_soft = scalar(1.0) / (dt * k_c * (dt + tau_c))
    rn = wp.max(rn_hard, rn_soft)
    rt = sigma * wi
    # Floors keep the reciprocals below finite.
    if rt < scalar(1.0e-30):
        rt = scalar(1.0e-30)
    if rn < scalar(1.0e-30):
        rn = scalar(1.0e-30)
    rt_inv = scalar(1.0) / rt
    rn_inv = scalar(1.0) / rn
    vhat_n = -contact_phi0[env, c] / (dt + tau_c)
    # Unprojected impulse.
    y = vec3(-rt_inv * vc.x, -rt_inv * vc.y, rn_inv * (vhat_n - vc.z))
    # Negative or non-finite friction coefficients are treated as zero.
    mu = contact_mu[env, c]
    if mu < scalar(0.0) or not wp.isfinite(mu):
        mu = scalar(0.0)
    # Softened tangential norm: the tolerance term makes yr strictly positive.
    yr = wp.sqrt(
        y.x * y.x
        + y.y * y.y
        + scalar(_CONTACT_SOFT_NORM_TOL) * scalar(_CONTACT_SOFT_NORM_TOL)
    )
    t_hat = zero
    if yr > scalar(0.0):
        t_hat = v3(y.x / yr, y.y / yr, scalar(0.0))
    gamma = zero
    if mu <= scalar(1.0e-12):
        # Frictionless: only a positive normal component is kept.
        if y.z > scalar(0.0):
            gamma = v3(scalar(0.0), scalar(0.0), y.z)
    else:
        mu_tilde = mu * wp.sqrt(rt / rn)
        mu_hat = mu * rt / rn
        factor = scalar(1.0) / (scalar(1.0) + mu_tilde * mu_tilde)
        if yr <= mu * y.z:
            # Stiction: y is used unchanged.
            gamma = y
        elif (-mu_hat * yr < y.z) and (y.z < yr / mu):
            # Sliding: tangential impulse is mu * gamma_n along t_hat.
            gamma_n = (y.z + mu_hat * yr) * factor
            gamma = v3(mu * gamma_n * t_hat.x, mu * gamma_n * t_hat.y, gamma_n)
        # Otherwise gamma stays zero.
    cost = scalar(0.5) * (
        rt * (gamma.x * gamma.x + gamma.y * gamma.y) + rn * gamma.z * gamma.z
    )
    contact_gamma[env, c] = gamma
    contact_vc[env, c] = vc
    contact_cost[env, c] = cost
    wp.atomic_add(total_cost, env, cost)
@wp.kernel(module="unique")
def _projection_eval_contact_hessian_sap_batched(
    active_env: wp.array(dtype=int),
    max_contacts: int,
    contact_count: wp.array(dtype=int),
    contact_phi0: wp.array(dtype=scalar, ndim=2),
    contact_w_eff: wp.array(dtype=scalar, ndim=2),
    contact_mu: wp.array(dtype=scalar, ndim=2),
    contact_k: wp.array(dtype=scalar, ndim=2),
    contact_tau_d: wp.array(dtype=scalar, ndim=2),
    contact_vc: wp.array(dtype=vec3, ndim=2),
    beta: scalar,
    sigma: scalar,
    dt: scalar,
    contact_g: wp.array(dtype=mat33, ndim=2),
    contact_y: wp.array(dtype=vec3, ndim=2),
    contact_rt: wp.array(dtype=scalar, ndim=2),
    contact_rn: wp.array(dtype=scalar, ndim=2),
    contact_mode: wp.array(dtype=int, ndim=2),
):
    # One thread per (env, contact). Derivative-only evaluation: reads the contact
    # velocity from contact_vc (it is an input here, not recomputed from J v) and
    # writes the derivative matrix g_mat, the unprojected impulse y, the
    # regularization (rt, rn) and the mode tag. It writes no impulse or cost.
    env, c = wp.tid()
    zero = v3(scalar(0.0), scalar(0.0), scalar(0.0))
    # Slots beyond the capacity or the env's contact count are left untouched.
    if c >= max_contacts or c >= contact_count[env]:
        return
    # Inactive env: zero this kernel's outputs.
    if active_env[env] == 0:
        contact_g[env, c] = zero_m33()
        contact_y[env, c] = zero
        contact_rt[env, c] = scalar(0.0)
        contact_rn[env, c] = scalar(0.0)
        contact_mode[env, c] = _CONTACT_MODE_NONE
        return
    # Effective weight, floored at 1e-12 (non-finite values are also replaced).
    wi = contact_w_eff[env, c]
    if wi < scalar(1.0e-12) or not wp.isfinite(wi):
        wi = scalar(1.0e-12)
    beta64 = scalar(beta)
    beta_factor = beta64 * beta64 / (scalar(4.0) * scalar(_PI) * scalar(_PI))
    rn_hard = beta_factor * wi
    # Non-positive or non-finite stiffness falls back to 1.
    k_c = contact_k[env, c]
    if k_c <= scalar(0.0) or not wp.isfinite(k_c):
        k_c = scalar(1.0)
    tau_c = wp.max(contact_tau_d[env, c], scalar(0.0))
    rn_soft = scalar(1.0) / (scalar(dt) * k_c * (scalar(dt) + tau_c))
    rn = wp.max(rn_hard, rn_soft)
    rt = scalar(sigma) * wi
    # Floors keep the reciprocals below finite.
    if rt < scalar(1.0e-30):
        rt = scalar(1.0e-30)
    if rn < scalar(1.0e-30):
        rn = scalar(1.0e-30)
    rt_inv = scalar(1.0) / rt
    rn_inv = scalar(1.0) / rn
    vhat_n = -contact_phi0[env, c] / (scalar(dt) + tau_c)
    vc = contact_vc[env, c]
    # Unprojected impulse.
    y = vec3(-rt_inv * vc.x, -rt_inv * vc.y, rn_inv * (vhat_n - vc.z))
    # Negative or non-finite friction coefficients are treated as zero.
    mu = contact_mu[env, c]
    if mu < scalar(0.0) or not wp.isfinite(mu):
        mu = scalar(0.0)
    # Softened tangential norm: the tolerance term makes yr strictly positive.
    yr = wp.sqrt(
        y.x * y.x
        + y.y * y.y
        + scalar(_CONTACT_SOFT_NORM_TOL) * scalar(_CONTACT_SOFT_NORM_TOL)
    )
    t_hat = zero
    if yr > scalar(0.0):
        t_hat = v3(y.x / yr, y.y / yr, scalar(0.0))
    g_mat = zero_m33()
    mode = _CONTACT_MODE_NONE
    if mu <= scalar(1.0e-12):
        # Frictionless: only the last argument (rn_inv) is non-zero.
        # NOTE(review): nesting reconstructed from flattened source — confirm the
        # mode assignment sits inside the `y.z > 0` test.
        if y.z > scalar(0.0):
            g_mat = m33(
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                rn_inv,
            )
            mode = _CONTACT_MODE_FRICTIONLESS
    else:
        mu_tilde = mu * wp.sqrt(rt / rn)
        mu_hat = mu * rt / rn
        factor = scalar(1.0) / (scalar(1.0) + mu_tilde * mu_tilde)
        if yr <= mu * y.z:
            # Stiction: (rt_inv, rt_inv, rn_inv) in the 1st, 5th and 9th arguments.
            g_mat = m33(
                rt_inv,
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                rt_inv,
                scalar(0.0),
                scalar(0.0),
                scalar(0.0),
                rn_inv,
            )
            mode = _CONTACT_MODE_STICTION
        elif (-mu_hat * yr < y.z) and (y.z < yr / mu):
            # Sliding branch.
            gamma_n = (y.z + mu_hat * yr) * factor
            # p** = t_hat t_hat^T (2x2); pp** = identity - p**.
            p00 = t_hat.x * t_hat.x
            p01 = t_hat.x * t_hat.y
            p10 = t_hat.y * t_hat.x
            p11 = t_hat.y * t_hat.y
            pp00 = scalar(1.0) - p00
            pp01 = -p01
            pp10 = -p10
            pp11 = scalar(1.0) - p11
            gn_over_yr = scalar(0.0)
            if yr > scalar(0.0):
                gn_over_yr = gamma_n / yr
            # Partial derivatives of the projected impulse with respect to y.
            dgt_dyt00 = mu * (gn_over_yr * pp00 + mu_hat * factor * p00)
            dgt_dyt01 = mu * (gn_over_yr * pp01 + mu_hat * factor * p01)
            dgt_dyt10 = mu * (gn_over_yr * pp10 + mu_hat * factor * p10)
            dgt_dyt11 = mu * (gn_over_yr * pp11 + mu_hat * factor * p11)
            dgt_dyn0 = mu * factor * t_hat.x
            dgt_dyn1 = mu * factor * t_hat.y
            dgn_dyt0 = mu_hat * factor * t_hat.x
            dgn_dyt1 = mu_hat * factor * t_hat.y
            # Each partial is scaled by rt_inv (tangential input) or rn_inv
            # (normal input).
            g_mat = m33(
                dgt_dyt00 * rt_inv,
                dgt_dyt01 * rt_inv,
                dgt_dyn0 * rn_inv,
                dgt_dyt10 * rt_inv,
                dgt_dyt11 * rt_inv,
                dgt_dyn1 * rn_inv,
                dgn_dyt0 * rt_inv,
                dgn_dyt1 * rt_inv,
                factor * rn_inv,
            )
            mode = _CONTACT_MODE_SLIDING
        # Otherwise g_mat and mode keep their zero / NONE defaults.
    contact_g[env, c] = g_mat
    contact_y[env, c] = y
    contact_rt[env, c] = rt
    contact_rn[env, c] = rn
    contact_mode[env, c] = mode
@wp.kernel(module="unique")
def _accumulate_pd_impulse_batched(
    add_pd: int,
    dof_per_env: int,
    pd_active: wp.array(dtype=int, ndim=2),
    pd_gamma: wp.array(dtype=scalar, ndim=2),
    constraint_impulse: wp.array(dtype=scalar, ndim=2),
):
    # One thread per (env, dof): when PD terms are enabled (add_pd == 1) and the
    # dof's PD term is active, add pd_gamma into constraint_impulse.
    # `scalar` is resolved from the enclosing scope.
    env, dof = wp.tid()
    if add_pd != 1:
        return
    if dof >= dof_per_env:
        return
    if pd_active[env, dof] == 1:
        updated = constraint_impulse[env, dof] + pd_gamma[env, dof]
        constraint_impulse[env, dof] = updated
@wp.kernel(module="unique")
def _accumulate_limit_impulse_batched(
    add_limits: int,
    dof_per_env: int,
    limit_grad: wp.array(dtype=scalar, ndim=2),
    constraint_impulse: wp.array(dtype=scalar, ndim=2),
):
    """Add ``limit_grad`` onto ``constraint_impulse`` for every in-range DOF.

    One thread handles one ``(env, dof)`` pair; the kernel writes nothing
    unless ``add_limits`` is 1.
    """
    env, dof = wp.tid()
    # Skip padding threads and the case where limit terms are switched off.
    if add_limits != 1 or dof >= dof_per_env:
        return
    accumulated = constraint_impulse[env, dof]
    constraint_impulse[env, dof] = accumulated + limit_grad[env, dof]
@cache
def _make_contact_impulse_single_tile_kernel(tile_size: int):
    """Return the contact-impulse reduction kernel for a given ``tile_size``.

    The kernel overwrites ``constraint_impulse[env, dof]`` with the sum over
    contacts of ``contact_jac[env, c, r, dof] * contact_gamma[env, c][r]`` for
    ``r`` in 0..2. The factory result is memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _contact_impulse_single_tile(
        dof_per_env: int,
        max_contacts: int,
        contact_count: wp.array(dtype=int),
        contact_jac: wp.array(dtype=scalar, ndim=4),
        contact_gamma: wp.array(dtype=vec3, ndim=2),
        constraint_impulse: wp.array(dtype=scalar, ndim=2),
    ):
        env, dof, lane = wp.tid()
        if dof >= dof_per_env:
            return
        # Clamp the reported contact count to the buffer capacity.
        active_contacts = wp.min(contact_count[env], max_contacts)
        partial = scalar(0.0)
        step = wp.block_dim()
        # Each lane visits contacts lane, lane + block_dim, ...
        contact = lane
        while contact < active_contacts:
            impulse = contact_gamma[env, contact]
            partial = (
                partial
                + contact_jac[env, contact, 0, dof] * impulse.x
                + contact_jac[env, contact, 1, dof] * impulse.y
                + contact_jac[env, contact, 2, dof] * impulse.z
            )
            contact = contact + step
        # Reduce the per-lane partial sums through a shared tile.
        partials = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        partials[lane] = partial
        reduced = wp.tile_sum(partials)[0]
        if lane == 0:
            constraint_impulse[env, dof] = reduced

    return _contact_impulse_single_tile
@cache
def _make_contact_hessian_single_tile_kernel(tile_size: int):
    """Return the per-entry contact-Hessian reduction kernel for ``tile_size``.

    Each ``(env, upper)`` work item maps the flat index ``upper`` to an
    upper-triangular ``(row, col)`` pair, sums
    ``J[c, r, row] * G[c][r, s] * J[c, s, col]`` over contacts and writes the
    result to both ``hess_contact[env, row, col]`` and its mirror entry.
    The factory result is memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _contact_hessian_single_tile(
        dof_per_env: int,
        upper_count: int,
        max_contacts: int,
        contact_count: wp.array(dtype=int),
        contact_jac: wp.array(dtype=scalar, ndim=4),
        contact_g: wp.array(dtype=mat33, ndim=2),
        hess_contact: wp.array(dtype=scalar, ndim=3),
    ):
        env, upper, lane = wp.tid()
        if upper >= upper_count:
            return
        # Decode the flat index: row ``row`` of the upper triangle holds
        # ``dof_per_env - row`` entries, starting on the diagonal.
        row = wp.int32(0)
        offset = wp.int32(upper)
        row_len = wp.int32(dof_per_env)
        while offset >= row_len:
            offset = offset - row_len
            row_len = row_len - 1
            row = row + 1
        col = row + offset
        # Clamp the reported contact count to the buffer capacity.
        active_contacts = wp.min(contact_count[env], max_contacts)
        partial = scalar(0.0)
        step = wp.block_dim()
        contact = lane
        while contact < active_contacts:
            g_block = contact_g[env, contact]
            for r in range(3):
                jac_row = contact_jac[env, contact, r, row]
                for s in range(3):
                    partial = partial + jac_row * g_block[r, s] * contact_jac[env, contact, s, col]
            contact = contact + step
        # Reduce the per-lane partial sums through a shared tile.
        partials = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        partials[lane] = partial
        reduced = wp.tile_sum(partials)[0]
        if lane == 0:
            hess_contact[env, row, col] = reduced
            if col != row:
                hess_contact[env, col, row] = reduced

    return _contact_hessian_single_tile
@cache
def _make_pack_contact_hessian_gemm_inputs_kernel(tile_k: int, tile_dof: int):
    """Return the kernel that packs the GEMM operands of the contact Hessian.

    For each flat row ``k = 3 * c + r`` and DOF ``d`` the kernel stores
    ``contact_jac[env, c, r, d]`` in ``contact_j_flat[env, k, d]`` and
    ``sum_s G[c][r, s] * contact_jac[env, c, s, d]`` in
    ``contact_gj_flat[env, k, d]``. Entries inside the padded extents that
    do not correspond to a live contact or a real DOF are written as zero.
    The factory result is memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _pack_contact_hessian_gemm_inputs(
        dof_per_env: int,
        contact_capacity: int,
        padded_contact_rows: int,
        padded_dof: int,
        contact_count: wp.array(dtype=int),
        contact_jac: wp.array(dtype=scalar, ndim=4),
        contact_g: wp.array(dtype=mat33, ndim=2),
        contact_j_flat: wp.array(dtype=scalar, ndim=3),
        contact_gj_flat: wp.array(dtype=scalar, ndim=3),
    ):
        env, k_tile, dof_tile, lane = wp.tid()
        row_base = k_tile * tile_k
        col_base = dof_tile * tile_dof
        # Lanes walk the tile_k x tile_dof tile in steps of the block size.
        slot = lane
        while slot < tile_k * tile_dof:
            local_row = slot // tile_dof
            local_col = slot - local_row * tile_dof
            flat_row = row_base + local_row
            dof = col_base + local_col
            jac_entry = scalar(0.0)
            g_jac_entry = scalar(0.0)
            if flat_row < padded_contact_rows and dof < padded_dof:
                # Flat row -> (contact index, component 0..2).
                contact = flat_row // 3
                component = flat_row - contact * 3
                live_contacts = wp.min(contact_count[env], contact_capacity)
                if contact < live_contacts and dof < dof_per_env:
                    jac_entry = contact_jac[env, contact, component, dof]
                    g_block = contact_g[env, contact]
                    g_jac_entry = (
                        g_block[component, 0] * contact_jac[env, contact, 0, dof]
                        + g_block[component, 1] * contact_jac[env, contact, 1, dof]
                        + g_block[component, 2] * contact_jac[env, contact, 2, dof]
                    )
                # Padding entries keep the zero defaults set above.
                contact_j_flat[env, flat_row, dof] = jac_entry
                contact_gj_flat[env, flat_row, dof] = g_jac_entry
            slot = slot + wp.block_dim()

    return _pack_contact_hessian_gemm_inputs
@cache
def _make_contact_hessian_gemm_tile_kernel(tile_m: int, tile_n: int, tile_k: int):
    """Return the tiled GEMM kernel that forms the contact Hessian.

    The kernel accumulates ``contact_j_flat[env]^T @ contact_gj_flat[env]``
    one ``tile_k``-row slab at a time into a ``tile_m x tile_n`` accumulator
    and writes the entries with ``j <= i`` to ``hess_contact[env]``,
    mirroring each off-diagonal entry. The factory result is memoized per
    ``(tile_m, tile_n, tile_k)`` by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _contact_hessian_gemm_tile(
        dof_per_env: int,
        padded_contact_rows: int,
        contact_j_flat: wp.array(dtype=scalar, ndim=3),
        contact_gj_flat: wp.array(dtype=scalar, ndim=3),
        hess_contact: wp.array(dtype=scalar, ndim=3),
    ):
        env, row_tile, col_tile, tid = wp.tid()
        row0 = row_tile * tile_m
        col0 = col_tile * tile_n
        # Skip tiles that start outside the matrix, and tiles whose first
        # column lies beyond the tile's last row (every entry has j > i).
        if row0 >= dof_per_env or col0 >= dof_per_env or col0 > row0 + tile_m - 1:
            return
        j_env = contact_j_flat[env]
        gj_env = contact_gj_flat[env]
        acc = wp.tile_zeros((tile_m, tile_n), dtype=scalar)
        k0 = wp.int32(0)
        # Sweep the flattened contact-row dimension in slabs of tile_k rows.
        # NOTE(review): the loads assume both operands are padded to whole
        # tiles in each dimension — confirm against the allocation site.
        while k0 < padded_contact_rows:
            j_tile = wp.tile_load(
                j_env,
                shape=(tile_k, tile_m),
                offset=(k0, row0),
                storage="shared",
            )
            gj_tile = wp.tile_load(
                gj_env,
                shape=(tile_k, tile_n),
                offset=(k0, col0),
                storage="shared",
            )
            # acc += j_tile^T @ gj_tile
            wp.tile_matmul(wp.tile_transpose(j_tile), gj_tile, acc)
            k0 = k0 + tile_k
        # Write-back: threads stride over the tile entries by block_dim.
        linear = tid
        while linear < tile_m * tile_n:
            rr = linear // tile_n
            cc = linear - rr * tile_n
            i = row0 + rr
            j = col0 + cc
            # Only in-range entries on or below the diagonal are stored;
            # the mirrored write fills the matching upper-triangle entry.
            if i < dof_per_env and j < dof_per_env and j <= i:
                value = acc[rr, cc]
                hess_contact[env, i, j] = value
                if j != i:
                    hess_contact[env, j, i] = value
            linear = linear + wp.block_dim()

    return _contact_hessian_gemm_tile
@cache
def _make_assemble_grad_and_dynamics_impulse_tiled_kernel(tile_size: int):
    """Return the kernel that assembles the gradient and dynamics impulse.

    Per ``(env, row)`` the kernel reduces ``A[row, :] @ (v - v_star)`` and
    ``A[row, :] @ v`` across the block, folds the optional PD and limit
    terms into ``constraint_impulse`` and writes
    ``grad = A (v - v_star) - constraint_impulse`` and
    ``dynamics_impulse = A v``. Inactive environments get zeros in all three
    outputs. The factory result is memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _assemble_grad_and_dynamics_impulse_tiled(
        active_env: wp.array(dtype=int),
        dof_per_env: int,
        v: wp.array(dtype=scalar, ndim=2),
        v_star: wp.array(dtype=scalar, ndim=2),
        a_mat: wp.array(dtype=scalar, ndim=3),
        add_pd: int,
        pd_active: wp.array(dtype=int, ndim=2),
        pd_gamma: wp.array(dtype=scalar, ndim=2),
        add_limits: int,
        limit_grad: wp.array(dtype=scalar, ndim=2),
        constraint_impulse: wp.array(dtype=scalar, ndim=2),
        grad: wp.array(dtype=scalar, ndim=2),
        dynamics_impulse: wp.array(dtype=scalar, ndim=2),
    ):
        env, row, lane = wp.tid()
        if row >= dof_per_env:
            return
        if active_env[env] == 0:
            # Inactive environments are cleared by the first lane only.
            if lane == 0:
                constraint_impulse[env, row] = scalar(0.0)
                grad[env, row] = scalar(0.0)
                dynamics_impulse[env, row] = scalar(0.0)
            return
        partial_residual = scalar(0.0)
        partial_momentum = scalar(0.0)
        step = wp.block_dim()
        col = lane
        while col < dof_per_env:
            a_entry = a_mat[env, row, col]
            v_col = v[env, col]
            partial_residual = partial_residual + a_entry * (v_col - v_star[env, col])
            partial_momentum = partial_momentum + a_entry * v_col
            col = col + step
        # Block-wide reductions of the two row products.
        residual_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        residual_tile[lane] = partial_residual
        momentum_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        momentum_tile[lane] = partial_momentum
        a_res = wp.tile_sum(residual_tile)[0]
        a_v = wp.tile_sum(momentum_tile)[0]
        if lane == 0:
            total_impulse = constraint_impulse[env, row]
            if add_pd == 1 and pd_active[env, row] == 1:
                total_impulse = total_impulse + pd_gamma[env, row]
            if add_limits == 1:
                total_impulse = total_impulse + limit_grad[env, row]
            constraint_impulse[env, row] = total_impulse
            grad[env, row] = a_res - total_impulse
            dynamics_impulse[env, row] = a_v

    return _assemble_grad_and_dynamics_impulse_tiled
@cache
def _make_assemble_model_terms_and_grad_tiled_kernel(tile_size: int):
    """Return the fused kernel that evaluates PD/limit terms and the gradient.

    Per ``(env, i)`` the kernel reduces ``A[i, :] @ (v - v_star)`` and
    ``A[i, :] @ v`` across the block, then the first thread evaluates the
    optional PD and joint-limit terms for DOF ``i``, updates
    ``constraint_impulse``, writes ``grad`` / ``dynamics_impulse`` and all
    per-DOF PD/limit outputs, and atomically adds this DOF's cost share to
    ``total_cost[env]``. The factory result is memoized by
    ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _assemble_model_terms_and_grad_tiled(
        active_env: wp.array(dtype=int),
        dof_per_env: int,
        v: wp.array(dtype=scalar, ndim=2),
        v_star: wp.array(dtype=scalar, ndim=2),
        a_mat: wp.array(dtype=scalar, ndim=3),
        add_pd: int,
        pd_active: wp.array(dtype=int, ndim=2),
        pd_a: wp.array(dtype=scalar, ndim=2),
        pd_gain: wp.array(dtype=scalar, ndim=2),
        pd_limit: wp.array(dtype=scalar, ndim=2),
        dt: scalar,
        pd_y: wp.array(dtype=scalar, ndim=2),
        pd_gamma: wp.array(dtype=scalar, ndim=2),
        pd_hdiag: wp.array(dtype=scalar, ndim=2),
        pd_cost: wp.array(dtype=scalar, ndim=2),
        add_limits: int,
        lower_active: wp.array(dtype=int, ndim=2),
        upper_active: wp.array(dtype=int, ndim=2),
        lower_vhat: wp.array(dtype=scalar, ndim=2),
        upper_vhat: wp.array(dtype=scalar, ndim=2),
        lower_r: wp.array(dtype=scalar, ndim=2),
        upper_r: wp.array(dtype=scalar, ndim=2),
        lower_rinv: wp.array(dtype=scalar, ndim=2),
        upper_rinv: wp.array(dtype=scalar, ndim=2),
        lower_gamma_out: wp.array(dtype=scalar, ndim=2),
        upper_gamma_out: wp.array(dtype=scalar, ndim=2),
        limit_grad: wp.array(dtype=scalar, ndim=2),
        limit_hdiag: wp.array(dtype=scalar, ndim=2),
        limit_cost: wp.array(dtype=scalar, ndim=2),
        constraint_impulse: wp.array(dtype=scalar, ndim=2),
        grad: wp.array(dtype=scalar, ndim=2),
        dynamics_impulse: wp.array(dtype=scalar, ndim=2),
        total_cost: wp.array(dtype=scalar),
    ):
        env, i, tid = wp.tid()
        if i >= dof_per_env:
            return
        if active_env[env] == 0:
            # Inactive environments: the first thread zeroes every per-DOF
            # output. total_cost is not touched on this path.
            if tid == 0:
                pd_y[env, i] = scalar(0.0)
                pd_gamma[env, i] = scalar(0.0)
                pd_hdiag[env, i] = scalar(0.0)
                pd_cost[env, i] = scalar(0.0)
                lower_gamma_out[env, i] = scalar(0.0)
                upper_gamma_out[env, i] = scalar(0.0)
                limit_grad[env, i] = scalar(0.0)
                limit_hdiag[env, i] = scalar(0.0)
                limit_cost[env, i] = scalar(0.0)
                constraint_impulse[env, i] = scalar(0.0)
                grad[env, i] = scalar(0.0)
                dynamics_impulse[env, i] = scalar(0.0)
            return
        # Per-thread partial sums of row i of A times (v - v_star) and v.
        local_a_res = scalar(0.0)
        local_a_v = scalar(0.0)
        stride = wp.block_dim()
        j = tid
        while j < dof_per_env:
            aij = a_mat[env, i, j]
            local_a_res = local_a_res + aij * (v[env, j] - v_star[env, j])
            local_a_v = local_a_v + aij * v[env, j]
            j = j + stride
        # Block-wide reductions through shared tiles.
        res_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        res_values[tid] = local_a_res
        a_res = wp.tile_sum(res_values)[0]
        v_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        v_values[tid] = local_a_v
        a_v = wp.tile_sum(v_values)[0]
        if tid == 0:
            vi = v[env, i]
            # This DOF's share of 0.5 * (v - v_star)^T A (v - v_star).
            model_cost = scalar(0.5) * (vi - v_star[env, i]) * a_res
            pd_y_value = scalar(0.0)
            pd_gamma_value = scalar(0.0)
            pd_hdiag_value = scalar(0.0)
            pd_cost_value = scalar(0.0)
            if add_pd == 1 and pd_active[env, i] == 1:
                gain = pd_gain[env, i]
                pd_y_value = pd_a[env, i] - gain * vi
                pd_gamma_value = scalar(dt) * clamp_scalar(pd_y_value, pd_limit[env, i])
                pd_hdiag_value = scalar(dt) * gain * clamp_derivative(pd_y_value, pd_limit[env, i])
                # The cost divides by the gain, so it is only formed for gain > 0.
                if gain > scalar(0.0):
                    pd_cost_value = (scalar(dt) / gain) * clamp_antiderivative(
                        pd_y_value, pd_limit[env, i]
                    )
                model_cost = model_cost + pd_cost_value
            limit_grad_value = scalar(0.0)
            limit_hdiag_value = scalar(0.0)
            limit_cost_value = scalar(0.0)
            lower_gamma = scalar(0.0)
            upper_gamma = scalar(0.0)
            if add_limits == 1:
                if lower_active[env, i] == 1:
                    limit_gamma = lower_rinv[env, i] * (lower_vhat[env, i] - vi)
                    # Only a positive impulse contributes.
                    if limit_gamma > scalar(0.0):
                        lower_gamma = limit_gamma
                        limit_grad_value = limit_grad_value + limit_gamma
                        limit_hdiag_value = limit_hdiag_value + lower_rinv[env, i]
                        limit_cost_value = (
                            limit_cost_value + scalar(0.5) * lower_r[env, i] * limit_gamma * limit_gamma
                        )
                if upper_active[env, i] == 1:
                    # The upper limit uses (vhat + vi) and enters the
                    # gradient term with the opposite sign.
                    limit_gamma = upper_rinv[env, i] * (upper_vhat[env, i] + vi)
                    if limit_gamma > scalar(0.0):
                        upper_gamma = limit_gamma
                        limit_grad_value = limit_grad_value - limit_gamma
                        limit_hdiag_value = limit_hdiag_value + upper_rinv[env, i]
                        limit_cost_value = (
                            limit_cost_value + scalar(0.5) * upper_r[env, i] * limit_gamma * limit_gamma
                        )
                model_cost = model_cost + limit_cost_value
            # constraint_impulse is read-modify-written: the value already
            # stored there is extended by the PD and limit terms.
            impulse = constraint_impulse[env, i] + pd_gamma_value + limit_grad_value
            constraint_impulse[env, i] = impulse
            grad[env, i] = a_res - impulse
            dynamics_impulse[env, i] = a_v
            pd_y[env, i] = pd_y_value
            pd_gamma[env, i] = pd_gamma_value
            pd_hdiag[env, i] = pd_hdiag_value
            pd_cost[env, i] = pd_cost_value
            lower_gamma_out[env, i] = lower_gamma
            upper_gamma_out[env, i] = upper_gamma
            limit_grad[env, i] = limit_grad_value
            limit_hdiag[env, i] = limit_hdiag_value
            limit_cost[env, i] = limit_cost_value
            # NOTE(review): total_cost is accumulated, not reset, here —
            # presumably the caller initialises it before launch; confirm.
            wp.atomic_add(total_cost, env, model_cost)

    return _assemble_model_terms_and_grad_tiled
@wp.kernel(module="unique")
def _assemble_hessian_total_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    a_mat: wp.array(dtype=scalar, ndim=3),
    hess_contact: wp.array(dtype=scalar, ndim=3),
    has_pd: int,
    has_limits: int,
    pd_hdiag: wp.array(dtype=scalar, ndim=2),
    limit_hdiag: wp.array(dtype=scalar, ndim=2),
    hess: wp.array(dtype=scalar, ndim=3),
):
    """Assemble ``hess = a_mat + hess_contact`` plus optional diagonal terms.

    One thread handles one ``(env, row, col)`` entry. On the diagonal,
    ``pd_hdiag`` is added when ``has_pd`` is 1 and ``limit_hdiag`` when
    ``has_limits`` is 1. Entries of inactive environments are set to zero.
    """
    env, row, col = wp.tid()
    if row >= dof_per_env or col >= dof_per_env:
        return
    if active_env[env] == 0:
        hess[env, row, col] = scalar(0.0)
        return
    entry = a_mat[env, row, col] + hess_contact[env, row, col]
    if row == col:
        # PD and limit contributions are purely diagonal.
        if has_pd == 1:
            entry = entry + pd_hdiag[env, row]
        if has_limits == 1:
            entry = entry + limit_hdiag[env, row]
    hess[env, row, col] = entry
@cache
def _make_compute_base_cost_tiled_kernel(tile_size: int):
    """Return the kernel computing ``0.5 * (v - v_star)^T A (v - v_star)``.

    One block per environment sums the quadratic form over all
    ``dof_per_env * dof_per_env`` matrix entries and stores it in
    ``out_cost[env]``; inactive environments get zero. The factory result is
    memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _compute_base_cost_tiled(
        active_env: wp.array(dtype=int),
        dof_per_env: int,
        v: wp.array(dtype=scalar, ndim=2),
        v_star: wp.array(dtype=scalar, ndim=2),
        a_mat: wp.array(dtype=scalar, ndim=3),
        out_cost: wp.array(dtype=scalar),
    ):
        env, lane = wp.tid()
        if active_env[env] == 0:
            if lane == 0:
                out_cost[env] = scalar(0.0)
            return
        partial = scalar(0.0)
        step = wp.block_dim()
        total_entries = dof_per_env * dof_per_env
        # Lanes stride over the flattened (row, col) index space.
        flat = lane
        while flat < total_entries:
            row = flat // dof_per_env
            col = flat - row * dof_per_env
            res_row = v[env, row] - v_star[env, row]
            res_col = v[env, col] - v_star[env, col]
            partial = partial + scalar(0.5) * res_row * a_mat[env, row, col] * res_col
            flat = flat + step
        partials = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        partials[lane] = partial
        reduced = wp.tile_sum(partials)[0]
        if lane == 0:
            out_cost[env] = reduced

    return _compute_base_cost_tiled
@cache
def _make_compute_line_search_base_coeffs_tiled_kernel(tile_size: int):
    """Return the kernel computing three quadratic-form coefficients per env.

    With ``r = v - v_star`` the kernel writes ``0.5 * r^T A r`` to
    ``out_base0``, ``dv^T A r`` to ``out_base_linear`` and ``dv^T A dv`` to
    ``out_base_quadratic``. Inactive environments get zeros. The factory
    result is memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _compute_line_search_base_coeffs_tiled(
        active_env: wp.array(dtype=int),
        dof_per_env: int,
        v: wp.array(dtype=scalar, ndim=2),
        v_star: wp.array(dtype=scalar, ndim=2),
        dv: wp.array(dtype=scalar, ndim=2),
        a_mat: wp.array(dtype=scalar, ndim=3),
        out_base0: wp.array(dtype=scalar),
        out_base_linear: wp.array(dtype=scalar),
        out_base_quadratic: wp.array(dtype=scalar),
    ):
        env, lane = wp.tid()
        if active_env[env] == 0:
            if lane == 0:
                out_base0[env] = scalar(0.0)
                out_base_linear[env] = scalar(0.0)
                out_base_quadratic[env] = scalar(0.0)
            return
        partial_const = scalar(0.0)
        partial_lin = scalar(0.0)
        partial_quad = scalar(0.0)
        step = wp.block_dim()
        total_entries = dof_per_env * dof_per_env
        # Lanes stride over the flattened (row, col) index space.
        flat = lane
        while flat < total_entries:
            row = flat // dof_per_env
            col = flat - row * dof_per_env
            res_row = v[env, row] - v_star[env, row]
            res_col = v[env, col] - v_star[env, col]
            step_row = dv[env, row]
            step_col = dv[env, col]
            a_entry = a_mat[env, row, col]
            partial_const = partial_const + scalar(0.5) * res_row * a_entry * res_col
            partial_lin = partial_lin + step_row * a_entry * res_col
            partial_quad = partial_quad + step_row * a_entry * step_col
            flat = flat + step
        # Three block-wide reductions, one per coefficient.
        const_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        const_tile[lane] = partial_const
        const_total = wp.tile_sum(const_tile)[0]
        lin_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        lin_tile[lane] = partial_lin
        lin_total = wp.tile_sum(lin_tile)[0]
        quad_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        quad_tile[lane] = partial_quad
        quad_total = wp.tile_sum(quad_tile)[0]
        if lane == 0:
            out_base0[env] = const_total
            out_base_linear[env] = lin_total
            out_base_quadratic[env] = quad_total

    return _compute_line_search_base_coeffs_tiled
@cache
def _make_compute_line_search_contact_delta_velocity_tiled_kernel(tile_size: int):
    """Return the kernel computing ``out_dvc[env, c] = J[env, c] @ dv[env]``.

    Work items whose contact index is at or beyond ``max_contacts`` or
    ``contact_count[env]`` write nothing; live contacts of inactive
    environments are set to the zero vector. The factory result is memoized
    by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _compute_line_search_contact_delta_velocity_tiled(
        active_env: wp.array(dtype=int),
        dof_per_env: int,
        max_contacts: int,
        contact_count: wp.array(dtype=int),
        contact_jac: wp.array(dtype=scalar, ndim=4),
        dv: wp.array(dtype=scalar, ndim=2),
        out_dvc: wp.array(dtype=vec3, ndim=2),
    ):
        env, contact, lane = wp.tid()
        if contact >= max_contacts or contact >= contact_count[env]:
            return
        if active_env[env] == 0:
            if lane == 0:
                out_dvc[env, contact] = v3(scalar(0.0), scalar(0.0), scalar(0.0))
            return
        partial_x = scalar(0.0)
        partial_y = scalar(0.0)
        partial_z = scalar(0.0)
        step = wp.block_dim()
        dof = lane
        while dof < dof_per_env:
            step_dof = dv[env, dof]
            partial_x = partial_x + contact_jac[env, contact, 0, dof] * step_dof
            partial_y = partial_y + contact_jac[env, contact, 1, dof] * step_dof
            partial_z = partial_z + contact_jac[env, contact, 2, dof] * step_dof
            dof = dof + step
        # One shared tile and one reduction per vector component.
        tile_x = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        tile_x[lane] = partial_x
        tile_y = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        tile_y[lane] = partial_y
        tile_z = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        tile_z[lane] = partial_z
        total_x = wp.tile_sum(tile_x)[0]
        total_y = wp.tile_sum(tile_y)[0]
        total_z = wp.tile_sum(tile_z)[0]
        if lane == 0:
            out_dvc[env, contact] = v3(total_x, total_y, total_z)

    return _compute_line_search_contact_delta_velocity_tiled
@cache
def _make_compute_norm_terms_tiled_kernel(tile_size: int):
    """Return the kernel computing three scaled squared norms per environment.

    Over DOFs with ``participating_dof == 1`` the kernel sums the squares of
    ``d_scale * grad``, ``d_scale * dynamics_impulse`` and
    ``d_scale * constraint_impulse`` into ``grad_norm2``, ``p_norm2`` and
    ``jc_norm2``. Inactive environments get zeros. The factory result is
    memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _compute_norm_terms_tiled(
        active_env: wp.array(dtype=int),
        dof_per_env: int,
        participating_dof: wp.array(dtype=int, ndim=2),
        d_scale: wp.array(dtype=scalar, ndim=2),
        grad: wp.array(dtype=scalar, ndim=2),
        dynamics_impulse: wp.array(dtype=scalar, ndim=2),
        constraint_impulse: wp.array(dtype=scalar, ndim=2),
        grad_norm2: wp.array(dtype=scalar),
        p_norm2: wp.array(dtype=scalar),
        jc_norm2: wp.array(dtype=scalar),
    ):
        env, lane = wp.tid()
        if active_env[env] == 0:
            if lane == 0:
                grad_norm2[env] = scalar(0.0)
                p_norm2[env] = scalar(0.0)
                jc_norm2[env] = scalar(0.0)
            return
        partial_grad = scalar(0.0)
        partial_p = scalar(0.0)
        partial_jc = scalar(0.0)
        step = wp.block_dim()
        dof = lane
        while dof < dof_per_env:
            # Non-participating DOFs are excluded from all three norms.
            if participating_dof[env, dof] == 1:
                weight = d_scale[env, dof]
                scaled_grad = weight * grad[env, dof]
                scaled_p = weight * dynamics_impulse[env, dof]
                scaled_jc = weight * constraint_impulse[env, dof]
                partial_grad = partial_grad + scaled_grad * scaled_grad
                partial_p = partial_p + scaled_p * scaled_p
                partial_jc = partial_jc + scaled_jc * scaled_jc
            dof = dof + step
        grad_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        grad_tile[lane] = partial_grad
        p_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        p_tile[lane] = partial_p
        jc_tile = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        jc_tile[lane] = partial_jc
        grad_sum = wp.tile_sum(grad_tile)[0]
        p_sum = wp.tile_sum(p_tile)[0]
        jc_sum = wp.tile_sum(jc_tile)[0]
        if lane == 0:
            grad_norm2[env] = grad_sum
            p_norm2[env] = p_sum
            jc_norm2[env] = jc_sum

    return _compute_norm_terms_tiled
@cache
def _make_compute_norm_terms_and_update_active_tiled_kernel(tile_size: int):
    """Return the kernel that computes norm terms and updates convergence flags.

    Per environment the kernel reduces the scaled squared norms of ``grad``,
    ``dynamics_impulse`` and ``constraint_impulse`` over participating DOFs,
    then evaluates an optimality test and a cost-change test and updates
    ``active_env``, ``converged_env``, the two ``*_reached_env`` flags,
    ``newton_iterations_env`` and ``active_count``. The factory result is
    memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _compute_norm_terms_and_update_active_tiled(
        active_env_input: wp.array(dtype=int),
        dof_per_env: int,
        participating_dof: wp.array(dtype=int, ndim=2),
        d_scale: wp.array(dtype=scalar, ndim=2),
        grad: wp.array(dtype=scalar, ndim=2),
        dynamics_impulse: wp.array(dtype=scalar, ndim=2),
        constraint_impulse: wp.array(dtype=scalar, ndim=2),
        cost: wp.array(dtype=scalar),
        previous_cost: wp.array(dtype=scalar),
        alpha: wp.array(dtype=scalar),
        iteration: int,
        optimality_abs_tol: scalar,
        optimality_rel_tol: scalar,
        cost_abs_tol: scalar,
        cost_rel_tol: scalar,
        cost_min_alpha: scalar,
        single_env_count: int,
        active_env: wp.array(dtype=int),
        converged_env: wp.array(dtype=int),
        optimality_reached_env: wp.array(dtype=int),
        cost_reached_env: wp.array(dtype=int),
        newton_iterations_env: wp.array(dtype=int),
        active_count: wp.array(dtype=int),
        grad_norm2: wp.array(dtype=scalar),
        p_norm2: wp.array(dtype=scalar),
        jc_norm2: wp.array(dtype=scalar),
    ):
        env, tid = wp.tid()
        # Environments that were not active on input, or already converged,
        # only have their norms zeroed and their active flag cleared.
        if active_env_input[env] == 0 or converged_env[env] == 1:
            if tid == 0:
                grad_norm2[env] = scalar(0.0)
                p_norm2[env] = scalar(0.0)
                jc_norm2[env] = scalar(0.0)
                active_env[env] = 0
                # active_count is written directly only when
                # single_env_count is 1; otherwise it is left untouched here.
                if single_env_count == 1:
                    active_count[0] = 0
            return
        grad_acc = scalar(0.0)
        p_acc = scalar(0.0)
        jc_acc = scalar(0.0)
        stride = wp.block_dim()
        i = tid
        while i < dof_per_env:
            if participating_dof[env, i] == 1:
                s = d_scale[env, i]
                g = s * grad[env, i]
                p = s * dynamics_impulse[env, i]
                jc = s * constraint_impulse[env, i]
                grad_acc = grad_acc + g * g
                p_acc = p_acc + p * p
                jc_acc = jc_acc + jc * jc
            i = i + stride
        # Block-wide reductions through shared tiles.
        grad_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        grad_values[tid] = grad_acc
        p_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        p_values[tid] = p_acc
        jc_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        jc_values[tid] = jc_acc
        grad_total = wp.tile_sum(grad_values)[0]
        p_total = wp.tile_sum(p_values)[0]
        jc_total = wp.tile_sum(jc_values)[0]
        if tid == 0:
            grad_norm2[env] = grad_total
            p_norm2[env] = p_total
            jc_norm2[env] = jc_total
            # Clamp at zero before the square root.
            grad_norm = wp.sqrt(wp.max(grad_total, scalar(0.0)))
            p_norm = wp.sqrt(wp.max(p_total, scalar(0.0)))
            jc_norm = wp.sqrt(wp.max(jc_total, scalar(0.0)))
            # Optimality: ||grad|| <= abs_tol + rel_tol * max(||p||, ||jc||).
            opt_tol = scalar(optimality_abs_tol) + scalar(optimality_rel_tol) * wp.max(p_norm, jc_norm)
            opt_reached = 0
            cost_reached = 0
            if grad_norm <= opt_tol:
                opt_reached = 1
            # The cost test is skipped on iteration 0 and whenever the step
            # length alpha is not above cost_min_alpha.
            if iteration > 0 and alpha[env] > cost_min_alpha:
                scale = scalar(0.5) * (wp.abs(cost[env]) + wp.abs(previous_cost[env]))
                tol = scalar(cost_abs_tol) + scalar(cost_rel_tol) * scale
                if wp.abs(cost[env] - previous_cost[env]) < tol:
                    cost_reached = 1
            active = 1
            if opt_reached == 1 or cost_reached == 1:
                active = 0
            active_env[env] = active
            optimality_reached_env[env] = opt_reached
            cost_reached_env[env] = cost_reached
            converged_env[env] = 1 - active
            # The iteration index is recorded only when the env stops here.
            if active == 0:
                newton_iterations_env[env] = iteration
            # Single-env mode stores the flag directly; otherwise still-active
            # envs are counted atomically.
            # NOTE(review): the atomic path presumably relies on active_count
            # being reset by the caller before launch — confirm.
            if single_env_count == 1:
                active_count[0] = active
            elif active == 1:
                wp.atomic_add(active_count, 0, 1)

    return _compute_norm_terms_and_update_active_tiled
@cache
def _make_compute_norm_terms_and_update_active_conditional_tiled_kernel(tile_size: int):
    """Return the convergence-update kernel that reads its iteration from an array.

    Same norm reduction and convergence tests as the kernel that takes
    ``iteration`` by value, but the iteration index is read from
    ``iteration_count[0]`` and an extra check deactivates environments once
    ``iteration >= max_iterations``, setting ``max_reached[0]``. The factory
    result is memoized by ``functools.cache``.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _compute_norm_terms_and_update_active_conditional_tiled(
        active_env_input: wp.array(dtype=int),
        dof_per_env: int,
        participating_dof: wp.array(dtype=int, ndim=2),
        d_scale: wp.array(dtype=scalar, ndim=2),
        grad: wp.array(dtype=scalar, ndim=2),
        dynamics_impulse: wp.array(dtype=scalar, ndim=2),
        constraint_impulse: wp.array(dtype=scalar, ndim=2),
        cost: wp.array(dtype=scalar),
        previous_cost: wp.array(dtype=scalar),
        alpha: wp.array(dtype=scalar),
        iteration_count: wp.array(dtype=int),
        max_iterations: int,
        optimality_abs_tol: scalar,
        optimality_rel_tol: scalar,
        cost_abs_tol: scalar,
        cost_rel_tol: scalar,
        cost_min_alpha: scalar,
        single_env_count: int,
        active_env: wp.array(dtype=int),
        converged_env: wp.array(dtype=int),
        optimality_reached_env: wp.array(dtype=int),
        cost_reached_env: wp.array(dtype=int),
        newton_iterations_env: wp.array(dtype=int),
        active_count: wp.array(dtype=int),
        max_reached: wp.array(dtype=int),
        grad_norm2: wp.array(dtype=scalar),
        p_norm2: wp.array(dtype=scalar),
        jc_norm2: wp.array(dtype=scalar),
    ):
        env, tid = wp.tid()
        # The iteration index is read from an array element, not a launch argument.
        iteration = iteration_count[0]
        # Environments that were not active on input, or already converged,
        # only have their norms zeroed and their active flag cleared.
        if active_env_input[env] == 0 or converged_env[env] == 1:
            if tid == 0:
                grad_norm2[env] = scalar(0.0)
                p_norm2[env] = scalar(0.0)
                jc_norm2[env] = scalar(0.0)
                active_env[env] = 0
                if single_env_count == 1:
                    active_count[0] = 0
            return
        grad_acc = scalar(0.0)
        p_acc = scalar(0.0)
        jc_acc = scalar(0.0)
        stride = wp.block_dim()
        i = tid
        while i < dof_per_env:
            if participating_dof[env, i] == 1:
                s = d_scale[env, i]
                g = s * grad[env, i]
                p = s * dynamics_impulse[env, i]
                jc = s * constraint_impulse[env, i]
                grad_acc = grad_acc + g * g
                p_acc = p_acc + p * p
                jc_acc = jc_acc + jc * jc
            i = i + stride
        # Block-wide reductions through shared tiles.
        grad_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        grad_values[tid] = grad_acc
        p_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        p_values[tid] = p_acc
        jc_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
        jc_values[tid] = jc_acc
        grad_total = wp.tile_sum(grad_values)[0]
        p_total = wp.tile_sum(p_values)[0]
        jc_total = wp.tile_sum(jc_values)[0]
        if tid == 0:
            grad_norm2[env] = grad_total
            p_norm2[env] = p_total
            jc_norm2[env] = jc_total
            # Clamp at zero before the square root.
            grad_norm = wp.sqrt(wp.max(grad_total, scalar(0.0)))
            p_norm = wp.sqrt(wp.max(p_total, scalar(0.0)))
            jc_norm = wp.sqrt(wp.max(jc_total, scalar(0.0)))
            # Optimality: ||grad|| <= abs_tol + rel_tol * max(||p||, ||jc||).
            opt_tol = scalar(optimality_abs_tol) + scalar(optimality_rel_tol) * wp.max(p_norm, jc_norm)
            opt_reached = 0
            cost_reached = 0
            if grad_norm <= opt_tol:
                opt_reached = 1
            # The cost test is skipped on iteration 0 and whenever the step
            # length alpha is not above cost_min_alpha.
            if iteration > 0 and alpha[env] > cost_min_alpha:
                scale = scalar(0.5) * (wp.abs(cost[env]) + wp.abs(previous_cost[env]))
                tol = scalar(cost_abs_tol) + scalar(cost_rel_tol) * scale
                if wp.abs(cost[env] - previous_cost[env]) < tol:
                    cost_reached = 1
            optimality_reached_env[env] = opt_reached
            cost_reached_env[env] = cost_reached
            # Case 1: converged — deactivate and mark as converged.
            if opt_reached == 1 or cost_reached == 1:
                active_env[env] = 0
                converged_env[env] = 1
                newton_iterations_env[env] = iteration
                if single_env_count == 1:
                    active_count[0] = 0
                return
            newton_iterations_env[env] = iteration
            # Case 2: iteration budget exhausted — deactivate and raise the
            # max_reached flag; converged_env is not written on this path.
            if iteration >= max_iterations:
                active_env[env] = 0
                max_reached[0] = 1
                if single_env_count == 1:
                    active_count[0] = 0
                return
            # Case 3: keep iterating.
            active_env[env] = 1
            converged_env[env] = 0
            # NOTE(review): the atomic path presumably relies on active_count
            # being reset by the caller before launch — confirm.
            if single_env_count == 1:
                active_count[0] = 1
            else:
                wp.atomic_add(active_count, 0, 1)

    return _compute_norm_terms_and_update_active_conditional_tiled
@wp.kernel(module="unique")
def _increment_scalar_i32(value: wp.array(dtype=int)):
    """Increase ``value[0]`` by one."""
    current = value[0]
    value[0] = current + 1
@wp.kernel(module="unique")
def _set_scalar_i32(value: wp.array(dtype=int), new_value: int):
    """Store ``new_value`` in ``value[0]``."""
    value[0] = new_value
@cache
def _create_pack_dense_to_padded_batched_kernel(dtype):
    """Return the kernel that copies a dense matrix into a padded buffer.

    Inside the leading ``active_size x active_size`` block the source entry
    is converted to ``dtype`` and ``diag_shift`` is added on the diagonal.
    Outside that block the destination is filled with the identity pattern
    (1 on the diagonal, 0 elsewhere). The factory result is memoized per
    ``dtype`` by ``functools.cache``.
    """

    @wp.kernel(module="unique")
    def _pack_dense_to_padded_batched(
        src: wp.array(dtype=scalar, ndim=3),
        dst: wp.array(dtype=dtype, ndim=3),
        active_size: int,
        diag_shift: scalar,
    ):
        env, row, col = wp.tid()
        if row >= active_size or col >= active_size:
            # Padding region: identity pattern.
            pad = dtype(0.0)
            if row == col:
                pad = dtype(1.0)
            dst[env, row, col] = pad
            return
        entry = dtype(src[env, row, col])
        if row == col:
            entry = entry + dtype(diag_shift)
        dst[env, row, col] = entry

    return _pack_dense_to_padded_batched
@cache
def _create_pack_grad_to_padded_rhs_batched_kernel(dtype):
    """Return the kernel that writes ``-grad`` into column 0 of a padded RHS.

    Rows below ``active_size`` receive the negated gradient converted to
    ``dtype``; the remaining rows are set to zero. The factory result is
    memoized per ``dtype`` by ``functools.cache``.
    """

    @wp.kernel(module="unique")
    def _pack_grad_to_padded_rhs_batched(
        grad: wp.array(dtype=scalar, ndim=2),
        rhs: wp.array(dtype=dtype, ndim=3),
        active_size: int,
    ):
        env, row = wp.tid()
        # Padding rows keep the zero default.
        packed = dtype(0.0)
        if row < active_size:
            packed = -dtype(grad[env, row])
        rhs[env, row, 0] = packed

    return _pack_grad_to_padded_rhs_batched
@cache
def _create_pack_dense_and_grad_to_padded_batched_kernel(dtype):
    """Return the fused kernel that packs both the matrix and the RHS.

    ``chol_a`` receives the ``dtype``-converted Hessian with ``diag_shift``
    added on the diagonal inside the ``active_size`` block and the identity
    pattern outside it. Threads with column index 0 additionally write
    ``-grad`` (or zero for padding rows) into ``rhs[env, row, 0]``. The
    factory result is memoized per ``dtype`` by ``functools.cache``.
    """

    @wp.kernel(module="unique")
    def _pack_dense_and_grad_to_padded_batched(
        hessian: wp.array(dtype=scalar, ndim=3),
        grad: wp.array(dtype=scalar, ndim=2),
        chol_a: wp.array(dtype=dtype, ndim=3),
        rhs: wp.array(dtype=dtype, ndim=3),
        active_size: int,
        diag_shift: scalar,
    ):
        env, row, col = wp.tid()
        if row >= active_size or col >= active_size:
            # Padding region: identity pattern.
            pad = dtype(0.0)
            if row == col:
                pad = dtype(1.0)
            chol_a[env, row, col] = pad
        else:
            entry = dtype(hessian[env, row, col])
            if row == col:
                entry = entry + dtype(diag_shift)
            chol_a[env, row, col] = entry
        # The first column of threads also fills the right-hand side.
        if col == 0:
            packed = dtype(0.0)
            if row < active_size:
                packed = -dtype(grad[env, row])
            rhs[env, row, 0] = packed

    return _pack_dense_and_grad_to_padded_batched
@cache
def _create_unpack_solution_batched_kernel(dtype):
    """Return the kernel that copies column 0 of ``src`` into ``dst``.

    Only rows below ``active_size`` are copied, converted to ``scalar``.
    The factory result is memoized per ``dtype`` by ``functools.cache``.
    """

    @wp.kernel(module="unique")
    def _unpack_solution_batched(
        src: wp.array(dtype=dtype, ndim=3),
        dst: wp.array(dtype=scalar, ndim=2),
        active_size: int,
    ):
        env, row = wp.tid()
        # Padding rows carry no solution data.
        if row >= active_size:
            return
        dst[env, row] = scalar(src[env, row, 0])

    return _unpack_solution_batched
@cache
def _create_unpack_solution_and_first_batched_kernel(dtype):
    """Return the kernel that unpacks the solution and captures the first step.

    Column 0 of ``src`` is copied into ``dst`` for rows below
    ``active_size``. The same value is also stored in ``first_dv`` when the
    environment is active and ``loop_iteration[0]`` is 0. The factory result
    is memoized per ``dtype`` by ``functools.cache``.
    """

    @wp.kernel(module="unique")
    def _unpack_solution_and_first_batched(
        active_env: wp.array(dtype=int),
        loop_iteration: wp.array(dtype=int),
        src: wp.array(dtype=dtype, ndim=3),
        dst: wp.array(dtype=scalar, ndim=2),
        first_dv: wp.array(dtype=scalar, ndim=2),
        active_size: int,
    ):
        env, row = wp.tid()
        if row >= active_size:
            return
        step_value = scalar(src[env, row, 0])
        dst[env, row] = step_value
        # Snapshot taken only on the first loop iteration of active envs.
        if active_env[env] == 1 and loop_iteration[0] == 0:
            first_dv[env, row] = step_value

    return _unpack_solution_and_first_batched
@wp.kernel(module="unique")
def _compute_search_direction_data_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    a_mat: wp.array(dtype=scalar, ndim=3),
    v: wp.array(dtype=scalar, ndim=2),
    v_star: wp.array(dtype=scalar, ndim=2),
    grad: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    dp: wp.array(dtype=scalar, ndim=2),
    dell0: wp.array(dtype=scalar),
    dell_a0: wp.array(dtype=scalar),
    d2ell_a: wp.array(dtype=scalar),
):
    """Compute ``dp = A dv`` and accumulate three per-environment dot products.

    One thread per ``(env, row)`` writes ``dp[env, row]`` and atomically adds
    its terms of ``grad . dv`` to ``dell0``, ``dp . (v - v_star)`` to
    ``dell_a0`` and ``dv . dp`` to ``d2ell_a``. Inactive environments and
    out-of-range rows are skipped.
    NOTE(review): the three accumulators are only added to here, so the
    caller presumably zeroes them beforehand — confirm.
    """
    env, row = wp.tid()
    if row >= dof_per_env or active_env[env] == 0:
        return
    # Row ``row`` of the matrix-vector product A @ dv.
    a_dv = scalar(0.0)
    for col in range(dof_per_env):
        a_dv = a_dv + a_mat[env, row, col] * dv[env, col]
    dp[env, row] = a_dv
    step_row = dv[env, row]
    wp.atomic_add(dell0, env, grad[env, row] * step_row)
    wp.atomic_add(dell_a0, env, a_dv * (v[env, row] - v_star[env, row]))
    wp.atomic_add(d2ell_a, env, step_row * a_dv)
@wp.kernel(module="unique")
def _compute_search_direction_data_serial_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    a_mat: wp.array(dtype=scalar, ndim=3),
    v: wp.array(dtype=scalar, ndim=2),
    v_star: wp.array(dtype=scalar, ndim=2),
    grad: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    dp: wp.array(dtype=scalar, ndim=2),
    dell0: wp.array(dtype=scalar),
    dell_a0: wp.array(dtype=scalar),
    d2ell_a: wp.array(dtype=scalar),
):
    """Serial variant: one thread per environment, no atomics.

    For active environments the thread computes ``dp = A dv`` row by row and
    sums ``grad . dv``, ``dp . (v - v_star)`` and ``dv . dp``. The three
    sums are stored in ``dell0``, ``dell_a0`` and ``d2ell_a``; inactive
    environments store zeros and leave ``dp`` untouched.
    """
    env = wp.tid()
    sum_grad_dv = scalar(0.0)
    sum_dp_res = scalar(0.0)
    sum_dv_dp = scalar(0.0)
    if active_env[env] == 1:
        for row in range(dof_per_env):
            # Row ``row`` of the matrix-vector product A @ dv.
            a_dv = scalar(0.0)
            for col in range(dof_per_env):
                a_dv = a_dv + a_mat[env, row, col] * dv[env, col]
            dp[env, row] = a_dv
            sum_grad_dv = sum_grad_dv + grad[env, row] * dv[env, row]
            sum_dp_res = sum_dp_res + a_dv * (v[env, row] - v_star[env, row])
            sum_dv_dp = sum_dv_dp + dv[env, row] * a_dv
    dell0[env] = sum_grad_dv
    dell_a0[env] = sum_dp_res
    d2ell_a[env] = sum_dv_dp
@wp.kernel(module="unique")
def _axpy_to_trial_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    x: wp.array(dtype=scalar, ndim=2),
    direction: wp.array(dtype=scalar, ndim=2),
    alpha: wp.array(dtype=scalar),
    out: wp.array(dtype=scalar, ndim=2),
):
    """Write ``out = x + alpha[env] * direction`` for active environments.

    Inactive environments get a plain copy of ``x``. One thread handles one
    ``(env, dof)`` pair; out-of-range DOFs are skipped.
    """
    env, dof = wp.tid()
    if dof >= dof_per_env:
        return
    trial = x[env, dof]
    # Only active environments take the step along ``direction``.
    if active_env[env] == 1:
        trial = trial + alpha[env] * direction[env, dof]
    out[env, dof] = trial
@wp.kernel(module="unique")
def _compute_line_derivative_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    v_trial: wp.array(dtype=scalar, ndim=2),
    v_star: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    dp: wp.array(dtype=scalar, ndim=2),
    constraint_impulse: wp.array(dtype=scalar, ndim=2),
    derivative: wp.array(dtype=scalar),
):
    """Atomically accumulate the line derivative per environment.

    Each ``(env, dof)`` thread adds
    ``dp * (v_trial - v_star) - dv * constraint_impulse`` to
    ``derivative[env]``. Inactive environments and out-of-range DOFs are
    skipped.
    NOTE(review): ``derivative`` is only added to, so the caller presumably
    zeroes it beforehand — confirm.
    """
    env, dof = wp.tid()
    if dof >= dof_per_env or active_env[env] == 0:
        return
    term = (
        dp[env, dof] * (v_trial[env, dof] - v_star[env, dof])
        - dv[env, dof] * constraint_impulse[env, dof]
    )
    wp.atomic_add(derivative, env, term)
@wp.kernel(module="unique")
def _compute_line_derivative_serial_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    v_trial: wp.array(dtype=scalar, ndim=2),
    v_star: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    dp: wp.array(dtype=scalar, ndim=2),
    constraint_impulse: wp.array(dtype=scalar, ndim=2),
    derivative: wp.array(dtype=scalar),
):
    """Serial variant of the line derivative: one thread per environment.

    Sums ``dp * (v_trial - v_star) - dv * constraint_impulse`` over all DOFs
    of an active environment and stores the result in ``derivative[env]``;
    inactive environments store zero.
    """
    env = wp.tid()
    total = scalar(0.0)
    if active_env[env] == 1:
        for dof in range(dof_per_env):
            term = (
                dp[env, dof] * (v_trial[env, dof] - v_star[env, dof])
                - dv[env, dof] * constraint_impulse[env, dof]
            )
            total = total + term
    derivative[env] = total
@wp.kernel(module="unique")
def _compute_line_second_derivative_serial_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    max_contacts: int,
    contact_count: wp.array(dtype=int),
    contact_jac: wp.array(dtype=scalar, ndim=4),
    contact_g: wp.array(dtype=mat33, ndim=2),
    has_pd: int,
    has_limits: int,
    pd_hdiag: wp.array(dtype=scalar, ndim=2),
    limit_hdiag: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    d2ell_a: wp.array(dtype=scalar),
    derivative2: wp.array(dtype=scalar),
):
    """Compute the second line derivative with one thread per environment.

    Starting from ``d2ell_a[env]`` the kernel adds, for every live contact,
    ``(J dv)^T G (J dv)`` and, for every DOF, the optional diagonal terms
    ``(pd_hdiag + limit_hdiag) * dv^2``. Inactive environments store zero.
    """
    env = wp.tid()
    if active_env[env] == 0:
        derivative2[env] = scalar(0.0)
        return
    total = d2ell_a[env]
    # Clamp the reported contact count to the buffer capacity.
    live_contacts = wp.min(contact_count[env], max_contacts)
    contact = int(0)
    while contact < live_contacts:
        # Contact-space step: the three rows of J[contact] times dv.
        dvc_x = scalar(0.0)
        dvc_y = scalar(0.0)
        dvc_z = scalar(0.0)
        for dof in range(dof_per_env):
            step_dof = dv[env, dof]
            dvc_x = dvc_x + contact_jac[env, contact, 0, dof] * step_dof
            dvc_y = dvc_y + contact_jac[env, contact, 1, dof] * step_dof
            dvc_z = dvc_z + contact_jac[env, contact, 2, dof] * step_dof
        g_block = contact_g[env, contact]
        g_dvc_x = g_block[0, 0] * dvc_x + g_block[0, 1] * dvc_y + g_block[0, 2] * dvc_z
        g_dvc_y = g_block[1, 0] * dvc_x + g_block[1, 1] * dvc_y + g_block[1, 2] * dvc_z
        g_dvc_z = g_block[2, 0] * dvc_x + g_block[2, 1] * dvc_y + g_block[2, 2] * dvc_z
        total = total + dvc_x * g_dvc_x + dvc_y * g_dvc_y + dvc_z * g_dvc_z
        contact = contact + 1
    for dof in range(dof_per_env):
        # Diagonal PD / limit curvature for this DOF.
        diag = scalar(0.0)
        if has_pd == 1:
            diag = diag + pd_hdiag[env, dof]
        if has_limits == 1:
            diag = diag + limit_hdiag[env, dof]
        total = total + diag * dv[env, dof] * dv[env, dof]
    derivative2[env] = total
@wp.kernel(module="unique")
def _replace_trial_cost_with_sap_line_search_cost_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    max_contacts: int,
    contact_count: wp.array(dtype=int),
    contact_cost: wp.array(dtype=scalar, ndim=2),
    has_pd: int,
    has_limits: int,
    pd_cost: wp.array(dtype=scalar, ndim=2),
    limit_cost: wp.array(dtype=scalar, ndim=2),
    momentum_cost0: wp.array(dtype=scalar),
    dell_a0: wp.array(dtype=scalar),
    d2ell_a: wp.array(dtype=scalar),
    alpha: wp.array(dtype=scalar),
    trial_cost: wp.array(dtype=scalar),
):
    """Overwrite ``trial_cost[env]`` with momentum plus regularizer cost.

    The momentum part is the quadratic
    ``momentum_cost0 + alpha * dell_a0 + 0.5 * alpha^2 * d2ell_a``; the
    regularizer part sums ``contact_cost`` over live contacts and the
    optional ``pd_cost`` / ``limit_cost`` over DOFs. Inactive environments
    are left unchanged.
    """
    env = wp.tid()
    if active_env[env] == 0:
        return
    regularizer = scalar(0.0)
    # Clamp the reported contact count to the buffer capacity.
    live_contacts = wp.min(contact_count[env], max_contacts)
    contact = int(0)
    while contact < live_contacts:
        regularizer = regularizer + contact_cost[env, contact]
        contact = contact + 1
    for dof in range(dof_per_env):
        if has_pd == 1:
            regularizer = regularizer + pd_cost[env, dof]
        if has_limits == 1:
            regularizer = regularizer + limit_cost[env, dof]
    step = alpha[env]
    momentum = (
        momentum_cost0[env]
        + step * dell_a0[env]
        + scalar(0.5) * step * step * d2ell_a[env]
    )
    trial_cost[env] = momentum + regularizer
@wp.kernel(module="unique")
def _init_unit_decay_line_search_state(
    newton_active: wp.array(dtype=int),
    current_cost: wp.array(dtype=scalar),
    alpha: wp.array(dtype=scalar),
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
):
    """Reset the per-environment state of the unit-step decay line search.

    Every environment starts at ``alpha = 1`` with cleared accepted/status/iteration
    fields and ``accepted_cost`` seeded from ``current_cost``. The search is switched on
    (``ls_active = 1``) only where ``newton_active[env] == 1``.
    """
    env = wp.tid()

    alpha[env] = scalar(1.0)
    accepted_cost[env] = current_cost[env]
    ls_accepted[env] = 0
    ls_status[env] = 0
    ls_iterations[env] = 0

    if newton_active[env] == 1:
        ls_active[env] = 1
    else:
        ls_active[env] = 0
@wp.kernel(module="unique")
def _update_unit_decay_line_search_state(
    trial_cost: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    alpha: wp.array(dtype=scalar),
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
    next_active_count: wp.array(dtype=int),
    min_alpha: scalar,
    cost_relax_r: scalar,
    cost_relax_a: scalar,
    decay: scalar,
):
    """Advance the unit-step decay line search by one trial for each active environment.

    The trial is accepted when
    ``trial_cost <= current_cost + cost_relax_a + cost_relax_r * |current_cost|``.
    Otherwise ``alpha`` is multiplied by ``decay``; if it drops below ``min_alpha`` the
    environment is retired with ``ls_status = -2``, else ``next_active_count[0]`` is
    incremented so the caller can tell that work remains.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    ls_iterations[env] = ls_iterations[env] + 1

    reference = current_cost[env]
    candidate = trial_cost[env]
    slack = cost_relax_a + cost_relax_r * wp.abs(reference)

    if candidate <= reference + slack:
        ls_active[env] = 0
        ls_accepted[env] = 1
        accepted_cost[env] = candidate
        return

    shrunk = alpha[env] * decay
    alpha[env] = shrunk
    if shrunk < min_alpha:
        # Step became too small: give up on this environment.
        ls_active[env] = 0
        ls_status[env] = -2
    else:
        wp.atomic_add(next_active_count, 0, 1)
@cache
def _make_unit_decay_line_search_fused_parallel_kernel(tile_size: int):
    """Return the fused unit-step decay line-search kernel for ``tile_size``.

    The result is memoized per ``tile_size`` by ``functools.cache``. ``tile_size`` is the
    length of the shared tile used to reduce the per-thread regularizer partial sums.
    """

    @wp.kernel(enable_backward=False, module="unique")
    def _run_unit_decay_line_search_fused_parallel_batched(
        newton_active: wp.array(dtype=int),
        dof_per_env: int,
        max_contacts: int,
        contact_count: wp.array(dtype=int),
        contact_vc0: wp.array(dtype=vec3, ndim=2),
        contact_dvc: wp.array(dtype=vec3, ndim=2),
        contact_phi0: wp.array(dtype=scalar, ndim=2),
        contact_w_eff: wp.array(dtype=scalar, ndim=2),
        contact_mu: wp.array(dtype=scalar, ndim=2),
        contact_k: wp.array(dtype=scalar, ndim=2),
        contact_tau_d: wp.array(dtype=scalar, ndim=2),
        pd_active: wp.array(dtype=int, ndim=2),
        pd_a: wp.array(dtype=scalar, ndim=2),
        pd_gain: wp.array(dtype=scalar, ndim=2),
        pd_limit: wp.array(dtype=scalar, ndim=2),
        limit_lower_active: wp.array(dtype=int, ndim=2),
        limit_upper_active: wp.array(dtype=int, ndim=2),
        limit_lower_vhat: wp.array(dtype=scalar, ndim=2),
        limit_upper_vhat: wp.array(dtype=scalar, ndim=2),
        limit_lower_r: wp.array(dtype=scalar, ndim=2),
        limit_upper_r: wp.array(dtype=scalar, ndim=2),
        limit_lower_rinv: wp.array(dtype=scalar, ndim=2),
        limit_upper_rinv: wp.array(dtype=scalar, ndim=2),
        v: wp.array(dtype=scalar, ndim=2),
        dv: wp.array(dtype=scalar, ndim=2),
        v_flat: wp.array(dtype=scalar),
        line_base0: wp.array(dtype=scalar),
        line_base_linear: wp.array(dtype=scalar),
        line_base_quadratic: wp.array(dtype=scalar),
        beta: scalar,
        sigma: scalar,
        dt: scalar,
        has_contact_terms: int,
        has_pd_terms: int,
        has_limit_terms: int,
        max_iterations: int,
        decay: scalar,
        min_alpha: scalar,
        cost_relax_r: scalar,
        cost_relax_a: scalar,
        alpha_out: wp.array(dtype=scalar),
        current_cost: wp.array(dtype=scalar),
        previous_cost: wp.array(dtype=scalar),
        accepted_cost: wp.array(dtype=scalar),
        ls_active: wp.array(dtype=int),
        ls_accepted: wp.array(dtype=int),
        ls_status: wp.array(dtype=int),
        ls_iterations: wp.array(dtype=int),
        ls_iterations_total: wp.array(dtype=int),
    ):
        """Run the whole decay line search of one environment and commit the result.

        The launch is 2D: the first index is the environment, the second a thread that
        strides (by ``wp.block_dim()``) over that environment's contacts and dofs. Each
        trial evaluates the quadratic momentum cost plus the contact / PD / limit
        regularizer, reduced across threads through a shared tile. On acceptance
        ``v`` advances by ``alpha * dv``; in every case ``v`` is mirrored to ``v_flat``
        and thread 0 writes the line-search outputs and cost bookkeeping.

        NOTE(review): ``regularizer_values[tid]`` indexes a tile of ``tile_size``
        entries, so the launch presumably uses a block of exactly ``tile_size``
        threads -- confirm at the call site.
        """
        env, tid = wp.tid()
        current = current_cost[env]
        stride = wp.block_dim()

        if newton_active[env] == 0:
            # Inactive environment: just mirror v and publish a neutral result.
            i = tid
            while i < dof_per_env:
                v_flat[env * dof_per_env + i] = v[env, i]
                i = i + stride
            if tid == 0:
                alpha_out[env] = scalar(1.0)
                ls_active[env] = 0
                ls_accepted[env] = 0
                ls_status[env] = 0
                ls_iterations[env] = 0
                accepted_cost[env] = current
            return

        base0 = line_base0[env]
        base_linear = line_base_linear[env]
        base_quadratic = line_base_quadratic[env]

        # The relaxation slack only depends on the starting cost, so compute it once.
        slack = cost_relax_a + cost_relax_r * wp.abs(current)

        step = scalar(1.0)
        trial_cost = current
        accepted_flag = int(0)
        status_value = int(0)
        evaluations = int(0)

        for _ in range(max_iterations):
            momentum_cost = base0 + step * base_linear + scalar(0.5) * step * step * base_quadratic

            # Per-thread share of the regularizer at the current step length.
            partial = scalar(0.0)

            if has_contact_terms != 0:
                num_contacts = wp.min(contact_count[env], max_contacts)
                c = tid
                while c < num_contacts:
                    vc_start = contact_vc0[env, c]
                    vc_delta = contact_dvc[env, c]
                    vc = v3(
                        vc_start.x + step * vc_delta.x,
                        vc_start.y + step * vc_delta.y,
                        vc_start.z + step * vc_delta.z,
                    )
                    partial = partial + contact_projection_cost_from_vc(
                        env,
                        c,
                        vc,
                        contact_phi0,
                        contact_w_eff,
                        contact_mu,
                        contact_k,
                        contact_tau_d,
                        beta,
                        sigma,
                        dt,
                    )
                    c = c + stride

            i = tid
            while i < dof_per_env:
                vi = v[env, i] + step * dv[env, i]

                if has_pd_terms != 0 and pd_active[env, i] == 1:
                    gain = pd_gain[env, i]
                    if gain > scalar(0.0):
                        y = pd_a[env, i] - gain * vi
                        partial = partial + (dt / gain) * clamp_antiderivative(y, pd_limit[env, i])

                if has_limit_terms != 0:
                    if limit_lower_active[env, i] == 1:
                        lower_gamma = limit_lower_rinv[env, i] * (limit_lower_vhat[env, i] - vi)
                        if lower_gamma > scalar(0.0):
                            partial = partial + scalar(0.5) * limit_lower_r[env, i] * lower_gamma * lower_gamma
                    if limit_upper_active[env, i] == 1:
                        upper_gamma = limit_upper_rinv[env, i] * (limit_upper_vhat[env, i] + vi)
                        if upper_gamma > scalar(0.0):
                            partial = partial + scalar(0.5) * limit_upper_r[env, i] * upper_gamma * upper_gamma

                i = i + stride

            # Block-wide reduction: every thread obtains the same total.
            regularizer_values = wp.tile_zeros((tile_size,), dtype=scalar, storage="shared")
            regularizer_values[tid] = partial
            regularizer_cost = wp.tile_sum(regularizer_values)[0]

            trial_cost = momentum_cost + regularizer_cost
            evaluations = evaluations + 1

            if trial_cost <= current + slack:
                accepted_flag = 1
                break

            step = step * decay
            if step < min_alpha:
                status_value = -2
                break

        # Commit: advance v only when a step was accepted, always refresh v_flat.
        i = tid
        while i < dof_per_env:
            vi = v[env, i]
            if accepted_flag == 1:
                vi = vi + step * dv[env, i]
            v[env, i] = vi
            v_flat[env * dof_per_env + i] = vi
            i = i + stride

        if tid == 0:
            alpha_out[env] = step
            ls_active[env] = 0
            ls_accepted[env] = accepted_flag
            ls_status[env] = status_value
            ls_iterations[env] = evaluations

            # A rejected search keeps the starting cost everywhere.
            committed = current
            if accepted_flag == 1:
                committed = trial_cost
            accepted_cost[env] = committed
            previous_cost[env] = current
            current_cost[env] = committed
            ls_iterations_total[env] = ls_iterations_total[env] + evaluations

    return _run_unit_decay_line_search_fused_parallel_batched
@wp.kernel(module="unique")
def _init_sap_backtracking_state(
    newton_active: wp.array(dtype=int),
    current_cost: wp.array(dtype=scalar),
    dell0: wp.array(dtype=scalar),
    alpha_max: scalar,
    relative_slop: scalar,
    alpha: wp.array(dtype=scalar),
    alpha_prev: wp.array(dtype=scalar),
    ell_prev: wp.array(dtype=scalar),
    ell_slop: wp.array(dtype=scalar),
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
):
    """Reset the backtracking line-search state of every environment.

    All environments get cleared flags and ``accepted_cost = current_cost``. Where
    ``newton_active[env] == 1`` the search additionally starts at ``alpha_max``, seeds
    the previous sample with the current cost, and stores the slop
    ``relative_slop / 10 * max(1, |current_cost|)``. The search is only activated when
    the initial slope ``dell0`` is finite and strictly negative; otherwise
    ``ls_status`` is set to -1.
    """
    env = wp.tid()
    ell0 = current_cost[env]

    ls_active[env] = 0
    ls_accepted[env] = 0
    ls_status[env] = 0
    ls_iterations[env] = 0
    accepted_cost[env] = ell0

    if newton_active[env] == 0:
        return

    start = scalar(alpha_max)
    alpha[env] = start
    alpha_prev[env] = start
    ell_prev[env] = ell0
    ell_slop[env] = (scalar(relative_slop) / scalar(10.0)) * wp.max(scalar(1.0), wp.abs(ell0))

    slope0 = dell0[env]
    if slope0 < scalar(0.0) and wp.isfinite(slope0):
        ls_active[env] = 1
    else:
        # Not a usable descent direction (non-negative, infinite or NaN slope).
        ls_status[env] = -1
@wp.kernel(module="unique")
def _accept_sap_alpha_max(
    trial_cost: wp.array(dtype=scalar),
    trial_derivative: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    relative_slop: scalar,
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    alpha: wp.array(dtype=scalar),
    alpha_prev: wp.array(dtype=scalar),
    ell_prev: wp.array(dtype=scalar),
    ell_slop: wp.array(dtype=scalar),
    accepted_cost: wp.array(dtype=scalar),
    next_active_count: wp.array(dtype=int),
):
    """Accept the first trial when the cost slope there is negative or below the slop.

    The slop is refreshed to
    ``relative_slop / 10 * max(1, 0.5 * (|trial_cost| + |current_cost|))``. Accepted
    environments leave the search with ``accepted_cost = trial_cost``; the others record
    the trial as the previous sample and bump ``next_active_count[0]``.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    ell = trial_cost[env]
    slope = trial_derivative[env]

    magnitude = scalar(0.5) * (wp.abs(ell) + wp.abs(current_cost[env]))
    slop = (scalar(relative_slop) / scalar(10.0)) * wp.max(scalar(1.0), magnitude)
    ell_slop[env] = slop

    if slope < scalar(0.0) or slope < slop:
        ls_active[env] = 0
        ls_accepted[env] = 1
        accepted_cost[env] = ell
    else:
        alpha_prev[env] = alpha[env]
        ell_prev[env] = ell
        wp.atomic_add(next_active_count, 0, 1)
@wp.kernel(module="unique")
def _scale_sap_backtracking_alpha(
    ls_active: wp.array(dtype=int),
    alpha: wp.array(dtype=scalar),
    rho: scalar,
):
    """Multiply ``alpha[env]`` by ``rho`` for every environment whose search is active."""
    env = wp.tid()
    if ls_active[env] != 1:
        return
    alpha[env] = alpha[env] * scalar(rho)
@wp.kernel(module="unique")
def _update_sap_backtracking_iteration(
    trial_cost: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    dell0: wp.array(dtype=scalar),
    alpha: wp.array(dtype=scalar),
    alpha_prev: wp.array(dtype=scalar),
    ell_prev: wp.array(dtype=scalar),
    ell_slop: wp.array(dtype=scalar),
    armijo_c: scalar,
    iteration: int,
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
    next_active_count: wp.array(dtype=int),
):
    """Process one backtracking sample for each active environment.

    A secant slope is formed from the current and previous samples. The current sample
    is accepted when that slope is within ``ell_slop``. When instead the cost rose
    relative to the previous sample and the current sample passes ``sap_armijo_ok``, the
    previous sample is accepted if it also passes, otherwise the current one.
    Environments that accept record ``iteration`` in ``ls_iterations``; the rest shift
    the current sample into the "previous" slots and bump ``next_active_count[0]``.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    ell = trial_cost[env]
    step = alpha[env]
    last_step = alpha_prev[env]
    last_ell = ell_prev[env]
    ell0 = current_cost[env]
    slope0 = dell0[env]
    c1 = scalar(armijo_c)

    # Finite-difference slope between consecutive samples (zero when the steps coincide).
    step_gap = step - last_step
    secant_slope = scalar(0.0)
    if wp.abs(step_gap) > scalar(0.0):
        secant_slope = (ell - last_ell) / step_gap

    keep_current = int(0)
    keep_previous = int(0)
    if wp.abs(secant_slope) < ell_slop[env]:
        keep_current = 1
    elif ell > last_ell and sap_armijo_ok(step, ell, ell0, slope0, c1):
        if sap_armijo_ok(last_step, last_ell, ell0, slope0, c1):
            keep_previous = 1
        else:
            keep_current = 1

    if keep_current == 1 or keep_previous == 1:
        ls_active[env] = 0
        ls_accepted[env] = 1
        ls_iterations[env] = iteration
        if keep_previous == 1:
            alpha[env] = last_step
            accepted_cost[env] = last_ell
        else:
            accepted_cost[env] = ell
        return

    alpha_prev[env] = step
    ell_prev[env] = ell
    wp.atomic_add(next_active_count, 0, 1)
@wp.kernel(module="unique")
def _update_sap_backtracking_iteration_conditional(
    trial_cost: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    dell0: wp.array(dtype=scalar),
    alpha: wp.array(dtype=scalar),
    alpha_prev: wp.array(dtype=scalar),
    ell_prev: wp.array(dtype=scalar),
    ell_slop: wp.array(dtype=scalar),
    armijo_c: scalar,
    iteration_count: wp.array(dtype=int),
    max_iterations: int,
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
    next_active_count: wp.array(dtype=int),
):
    """Backtracking update that reads the iteration index from device memory.

    Applies the same acceptance rules as ``_update_sap_backtracking_iteration`` with the
    iteration taken from ``iteration_count[0]``. In addition, once
    ``iteration >= max_iterations - 1`` an environment that has not accepted is retired
    here: it is accepted at the current sample (with ``ls_iterations = max_iterations``)
    if that sample passes ``sap_armijo_ok``, otherwise ``ls_status`` becomes -2.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    iteration = iteration_count[0]
    ell = trial_cost[env]
    step = alpha[env]
    last_step = alpha_prev[env]
    last_ell = ell_prev[env]
    ell0 = current_cost[env]
    slope0 = dell0[env]
    c1 = scalar(armijo_c)

    # Finite-difference slope between consecutive samples (zero when the steps coincide).
    step_gap = step - last_step
    secant_slope = scalar(0.0)
    if wp.abs(step_gap) > scalar(0.0):
        secant_slope = (ell - last_ell) / step_gap

    keep_current = int(0)
    keep_previous = int(0)
    if wp.abs(secant_slope) < ell_slop[env]:
        keep_current = 1
    elif ell > last_ell and sap_armijo_ok(step, ell, ell0, slope0, c1):
        if sap_armijo_ok(last_step, last_ell, ell0, slope0, c1):
            keep_previous = 1
        else:
            keep_current = 1

    if keep_current == 1 or keep_previous == 1:
        ls_active[env] = 0
        ls_accepted[env] = 1
        ls_iterations[env] = iteration
        if keep_previous == 1:
            alpha[env] = last_step
            accepted_cost[env] = last_ell
        else:
            accepted_cost[env] = ell
        return

    if iteration >= max_iterations - 1:
        # Out of iterations: retire the environment, accepting only on Armijo success.
        ls_active[env] = 0
        if sap_armijo_ok(step, ell, ell0, slope0, c1):
            ls_accepted[env] = 1
            ls_iterations[env] = max_iterations
            accepted_cost[env] = ell
        else:
            ls_status[env] = -2
        return

    alpha_prev[env] = step
    ell_prev[env] = ell
    wp.atomic_add(next_active_count, 0, 1)
@wp.kernel(module="unique")
def _finalize_sap_backtracking(
    trial_cost: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    dell0: wp.array(dtype=scalar),
    alpha: wp.array(dtype=scalar),
    armijo_c: scalar,
    max_iterations: int,
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
):
    """Retire every environment whose backtracking search is still active.

    An environment is accepted at its current ``alpha`` / ``trial_cost`` (recording
    ``max_iterations`` in ``ls_iterations``) when that sample passes ``sap_armijo_ok``;
    otherwise ``ls_status`` is set to -2. In both cases ``ls_active`` is cleared, so a
    failed environment no longer looks like an in-progress search. This matches the
    max-iteration branch of ``_update_sap_backtracking_iteration_conditional`` and the
    other failure paths, which all clear ``ls_active`` together with the status code.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    # Terminal step: the search is over for this environment whatever the outcome.
    ls_active[env] = 0

    if sap_armijo_ok(alpha[env], trial_cost[env], current_cost[env], dell0[env], scalar(armijo_c)):
        ls_accepted[env] = 1
        ls_iterations[env] = max_iterations
        accepted_cost[env] = trial_cost[env]
    else:
        ls_status[env] = -2
@wp.kernel(module="unique")
def _copy_i32_batched(src: wp.array(dtype=int), dst: wp.array(dtype=int)):
    """Copy one integer per thread from ``src`` to the same index of ``dst``."""
    idx = wp.tid()
    value = src[idx]
    dst[idx] = value
@wp.kernel(module="unique")
def _init_sap_exact_alpha_max_state(
    newton_active: wp.array(dtype=int),
    current_cost: wp.array(dtype=scalar),
    dell0: wp.array(dtype=scalar),
    alpha_max: scalar,
    alpha: wp.array(dtype=scalar),
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
):
    """Reset the exact line-search state and place every environment at ``alpha_max``.

    All flags are cleared and ``accepted_cost`` is seeded from ``current_cost``. The
    search is activated only where ``newton_active[env] == 1`` and the initial slope
    ``dell0`` is finite and strictly negative; a Newton-active environment failing that
    test gets ``ls_status = -1``.
    """
    env = wp.tid()

    alpha[env] = scalar(alpha_max)
    accepted_cost[env] = current_cost[env]
    ls_active[env] = 0
    ls_accepted[env] = 0
    ls_status[env] = 0
    ls_iterations[env] = 0

    if newton_active[env] == 0:
        return

    slope0 = dell0[env]
    if slope0 < scalar(0.0) and wp.isfinite(slope0):
        ls_active[env] = 1
    else:
        # Not a usable descent direction (non-negative, infinite or NaN slope).
        ls_status[env] = -1
@wp.kernel(module="unique")
def _init_sap_exact_root_state(
    trial_cost: wp.array(dtype=scalar),
    trial_derivative: wp.array(dtype=scalar),
    trial_second_derivative: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    dell0: wp.array(dtype=scalar),
    alpha_max: scalar,
    cost_abs_tol: scalar,
    cost_rel_tol: scalar,
    root_tolerance: scalar,
    alpha: wp.array(dtype=scalar),
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    accepted_cost: wp.array(dtype=scalar),
    exact_scale: wp.array(dtype=scalar),
    exact_x_lower: wp.array(dtype=scalar),
    exact_x_upper: wp.array(dtype=scalar),
    exact_f_lower: wp.array(dtype=scalar),
    exact_f_upper: wp.array(dtype=scalar),
    exact_root: wp.array(dtype=scalar),
    exact_minus_dx: wp.array(dtype=scalar),
    exact_minus_dx_previous: wp.array(dtype=scalar),
    exact_x_tolerance: wp.array(dtype=scalar),
    next_active_count: wp.array(dtype=int),
):
    """Decide early exits and set up the root bracket for the exact line search.

    For each active environment, in order:

    * non-positive ``trial_derivative``: accept ``alpha_max`` with ``trial_cost``;
    * ``-dell0 < cost_abs_tol + cost_rel_tol * trial_cost``: accept ``alpha = 1``;
    * scaled upper derivative ``trial_derivative / -dell0`` not finite or not positive:
      retire with ``ls_status = -3``;
    * initial guess ``min(-dell0 / trial_second_derivative, alpha_max)`` not finite,
      negative, or giving a non-positive tolerance: retire with ``ls_status = -4``.

    Otherwise the bracket ``[0, alpha_max]`` (scaled derivative values ``-1`` and
    ``f_upper``), the guess, the step history and the x-tolerance are stored, ``alpha``
    moves to the guess and ``next_active_count[0]`` is incremented.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    upper = scalar(alpha_max)
    slope_at_upper = trial_derivative[env]

    if slope_at_upper <= scalar(0.0):
        ls_active[env] = 0
        ls_accepted[env] = 1
        alpha[env] = upper
        accepted_cost[env] = trial_cost[env]
        return

    descent = -dell0[env]
    if descent < scalar(cost_abs_tol) + scalar(cost_rel_tol) * trial_cost[env]:
        ls_active[env] = 0
        ls_accepted[env] = 1
        alpha[env] = scalar(1.0)
        return

    # Derivatives are normalized by the initial descent rate.
    f_upper = slope_at_upper / descent
    if not wp.isfinite(f_upper) or f_upper <= scalar(0.0):
        ls_active[env] = 0
        ls_status[env] = -3
        return

    # Explicit comparison (rather than a min) so a NaN guess survives to the check below.
    alpha_guess = descent / trial_second_derivative[env]
    if alpha_guess > upper:
        alpha_guess = upper

    x_tolerance = scalar(root_tolerance) * alpha_guess
    if (
        not wp.isfinite(alpha_guess)
        or alpha_guess < scalar(0.0)
        or alpha_guess > upper
        or x_tolerance <= scalar(0.0)
    ):
        ls_active[env] = 0
        ls_status[env] = -4
        return

    exact_scale[env] = descent
    exact_x_lower[env] = scalar(0.0)
    exact_x_upper[env] = upper
    exact_f_lower[env] = scalar(-1.0)
    exact_f_upper[env] = f_upper
    exact_root[env] = alpha_guess
    exact_minus_dx[env] = -upper
    exact_minus_dx_previous[env] = -upper
    exact_x_tolerance[env] = x_tolerance
    alpha[env] = alpha_guess
    accepted_cost[env] = current_cost[env]
    wp.atomic_add(next_active_count, 0, 1)
@wp.kernel(module="unique")
def _update_sap_exact_root_state(
    trial_derivative: wp.array(dtype=scalar),
    trial_second_derivative: wp.array(dtype=scalar),
    root_tolerance: scalar,
    iteration_count: wp.array(dtype=int),
    max_iterations: int,
    alpha: wp.array(dtype=scalar),
    ls_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    ls_status: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    exact_scale: wp.array(dtype=scalar),
    exact_x_lower: wp.array(dtype=scalar),
    exact_x_upper: wp.array(dtype=scalar),
    exact_f_lower: wp.array(dtype=scalar),
    exact_f_upper: wp.array(dtype=scalar),
    exact_root: wp.array(dtype=scalar),
    exact_minus_dx: wp.array(dtype=scalar),
    exact_minus_dx_previous: wp.array(dtype=scalar),
    exact_x_tolerance: wp.array(dtype=scalar),
    next_active_count: wp.array(dtype=int),
):
    """Perform one bracketed Newton / bisection step of the exact line search.

    ``f`` and ``df`` are the trial first and second derivatives divided by
    ``exact_scale``. The bracket end replaced by the current root is the lower one when
    ``f`` and the stored upper value differ in sign, otherwise the upper one. The search
    then

    * accepts the current root when ``|f| < root_tolerance``;
    * retires with ``ls_status = -5`` when ``iteration_count[0] >= max_iterations``;
    * otherwise proposes a Newton step, replaced by the bracket midpoint when
      ``2 |f| > |previous step * df|``, when the Newton root leaves the bracket, or when
      it is not finite;
    * accepts the proposed root when the step magnitude is below ``exact_x_tolerance``;
      otherwise stores it as the next trial and bumps ``next_active_count[0]``.
    """
    env = wp.tid()
    if ls_active[env] == 0:
        return

    iteration = iteration_count[0]
    scale = exact_scale[env]
    f = trial_derivative[env] / scale
    df = trial_second_derivative[env] / scale
    root = exact_root[env]
    f_upper = exact_f_upper[env]

    # Shrink the bracket around the sign change, mirroring it in local copies.
    x_lo = exact_x_lower[env]
    x_hi = exact_x_upper[env]
    if (f < scalar(0.0) and f_upper >= scalar(0.0)) or (
        f >= scalar(0.0) and f_upper < scalar(0.0)
    ):
        x_lo = root
        exact_x_lower[env] = root
        exact_f_lower[env] = f
    else:
        x_hi = root
        exact_x_upper[env] = root
        exact_f_upper[env] = f

    if wp.abs(f) < scalar(root_tolerance):
        ls_active[env] = 0
        ls_accepted[env] = 1
        ls_iterations[env] = iteration
        alpha[env] = root
        return

    if iteration >= max_iterations:
        ls_active[env] = 0
        ls_status[env] = -5
        return

    # The slowness test must use the step history before it is shifted below.
    use_bisect = int(0)
    if scalar(2.0) * wp.abs(f) > wp.abs(exact_minus_dx_previous[env] * df):
        use_bisect = 1
    exact_minus_dx_previous[env] = exact_minus_dx[env]

    bisect_minus_dx = scalar(0.5) * (x_lo - x_hi)
    bisect_root = x_lo - bisect_minus_dx

    newton_minus_dx = scalar(0.0)
    newton_root = root
    if wp.abs(df) > scalar(0.0):
        newton_minus_dx = f / df
        newton_root = root - newton_minus_dx

    if newton_root < x_lo or newton_root > x_hi:
        use_bisect = 1
    if not wp.isfinite(newton_root) or not wp.isfinite(newton_minus_dx):
        use_bisect = 1

    next_root = newton_root
    next_minus_dx = newton_minus_dx
    if use_bisect == 1:
        next_root = bisect_root
        next_minus_dx = bisect_minus_dx

    if wp.abs(next_minus_dx) < exact_x_tolerance[env]:
        ls_active[env] = 0
        ls_accepted[env] = 1
        ls_iterations[env] = iteration
        alpha[env] = next_root
        return

    exact_root[env] = next_root
    exact_minus_dx[env] = next_minus_dx
    alpha[env] = next_root
    wp.atomic_add(next_active_count, 0, 1)
@wp.kernel(module="unique")
def _store_exact_accepted_cost(
    ls_accepted: wp.array(dtype=int),
    trial_cost: wp.array(dtype=scalar),
    accepted_cost: wp.array(dtype=scalar),
):
    """Copy ``trial_cost`` into ``accepted_cost`` for environments flagged as accepted."""
    env = wp.tid()
    if ls_accepted[env] != 1:
        return
    accepted_cost[env] = trial_cost[env]
@wp.kernel(module="unique")
def _commit_line_search_step_batched(
    newton_active: wp.array(dtype=int),
    ls_accepted: wp.array(dtype=int),
    dof_per_env: int,
    alpha: wp.array(dtype=scalar),
    v: wp.array(dtype=scalar, ndim=2),
    dv: wp.array(dtype=scalar, ndim=2),
    v_flat: wp.array(dtype=scalar),
    current_cost: wp.array(dtype=scalar),
    previous_cost: wp.array(dtype=scalar),
    accepted_cost: wp.array(dtype=scalar),
):
    """Apply the accepted step to ``v`` and refresh the flat velocity copy.

    One thread per (environment, dof). Where the environment is Newton-active and its
    line search accepted, ``v[env, i]`` advances by ``alpha[env] * dv[env, i]`` and the
    thread with ``i == 0`` rolls ``current_cost`` into ``previous_cost`` before
    replacing it with ``accepted_cost``. ``v_flat`` is rewritten from ``v`` for every
    in-range thread, stepped or not.
    """
    env, dof = wp.tid()
    if dof >= dof_per_env:
        return

    value = v[env, dof]
    if newton_active[env] == 1 and ls_accepted[env] == 1:
        if dof == 0:
            # Cost bookkeeping is per environment, so only one thread touches it.
            previous_cost[env] = current_cost[env]
            current_cost[env] = accepted_cost[env]
        value = value + alpha[env] * dv[env, dof]
        v[env, dof] = value

    v_flat[env * dof_per_env + dof] = value
@wp.kernel(module="unique")
def _accumulate_line_search_iterations_batched(
    ls_accepted: wp.array(dtype=int),
    ls_iterations: wp.array(dtype=int),
    total_iterations: wp.array(dtype=int),
):
    """Add ``ls_iterations[env]`` to ``total_iterations[env]`` for accepted environments."""
    env = wp.tid()
    if ls_accepted[env] != 1:
        return
    total_iterations[env] = total_iterations[env] + ls_iterations[env]
@wp.kernel(module="unique")
def _copy_first_dv_active_batched(
    active_env: wp.array(dtype=int),
    dof_per_env: int,
    dv: wp.array(dtype=scalar, ndim=2),
    first_dv: wp.array(dtype=scalar, ndim=2),
):
    """Copy ``dv`` into ``first_dv`` for in-range dofs of environments flagged active."""
    env, dof = wp.tid()
    if dof >= dof_per_env:
        return
    if active_env[env] != 1:
        return
    first_dv[env, dof] = dv[env, dof]
@wp.kernel(module="unique")
def _count_active_int(
    active: wp.array(dtype=int),
    count: wp.array(dtype=int),
):
    """Atomically add one to ``count[0]`` for every entry of ``active`` equal to 1."""
    idx = wp.tid()
    if active[idx] != 1:
        return
    wp.atomic_add(count, 0, 1)
@wp.kernel(module="unique")
def _activate_unconverged_envs_batched(
    converged_env: wp.array(dtype=int),
    active_env: wp.array(dtype=int),
):
    """Set ``active_env[env]`` to ``1 - converged_env[env]``."""
    env = wp.tid()
    converged = converged_env[env]
    # Plain subtraction (not a branch) so the mapping is unchanged for any stored value.
    active_env[env] = 1 - converged
return SimpleNamespace(
_copy_flat_to_env_batched=_copy_flat_to_env_batched,
_copy_env_to_flat_batched=_copy_env_to_flat_batched,
_copy_env_to_env_batched=_copy_env_to_env_batched,
_copy_solve_velocity_inputs_flat_batched=_copy_solve_velocity_inputs_flat_batched,
_copy_solve_velocity_inputs_flat_batched_with_guess_flag=_copy_solve_velocity_inputs_flat_batched_with_guess_flag,
_initialize_and_mark_unconstrained_free_envs_batched=_initialize_and_mark_unconstrained_free_envs_batched,
_initialize_newton_loop_state=_initialize_newton_loop_state,
_extract_a_diag_data_batched=_extract_a_diag_data_batched,
_clear_participating_dofs_batched=_clear_participating_dofs_batched,
_mark_contact_participating_dofs_batched=_mark_contact_participating_dofs_batched,
_mark_model_participating_dofs_batched=_mark_model_participating_dofs_batched,
_build_pd_terms_sap_batched=_build_pd_terms_sap_batched,
_eval_pd_terms_sap_batched=_eval_pd_terms_sap_batched,
_build_limit_terms_sap_batched=_build_limit_terms_sap_batched,
_eval_limit_terms_sap_batched=_eval_limit_terms_sap_batched,
_projection_eval_contact_sap_batched=_projection_eval_contact_sap_batched,
_projection_cost_only_contact_sap_batched=_projection_cost_only_contact_sap_batched,
_projection_eval_contact_gamma_sap_batched=_projection_eval_contact_gamma_sap_batched,
_projection_eval_contact_hessian_sap_batched=_projection_eval_contact_hessian_sap_batched,
_accumulate_pd_impulse_batched=_accumulate_pd_impulse_batched,
_accumulate_limit_impulse_batched=_accumulate_limit_impulse_batched,
_make_contact_impulse_single_tile_kernel=_make_contact_impulse_single_tile_kernel,
_make_contact_hessian_single_tile_kernel=_make_contact_hessian_single_tile_kernel,
_make_pack_contact_hessian_gemm_inputs_kernel=_make_pack_contact_hessian_gemm_inputs_kernel,
_make_contact_hessian_gemm_tile_kernel=_make_contact_hessian_gemm_tile_kernel,
_make_assemble_grad_and_dynamics_impulse_tiled_kernel=_make_assemble_grad_and_dynamics_impulse_tiled_kernel,
_make_assemble_model_terms_and_grad_tiled_kernel=_make_assemble_model_terms_and_grad_tiled_kernel,
_assemble_hessian_total_batched=_assemble_hessian_total_batched,
_make_compute_base_cost_tiled_kernel=_make_compute_base_cost_tiled_kernel,
_make_compute_line_search_base_coeffs_tiled_kernel=_make_compute_line_search_base_coeffs_tiled_kernel,
_make_compute_line_search_contact_delta_velocity_tiled_kernel=_make_compute_line_search_contact_delta_velocity_tiled_kernel,
_make_compute_norm_terms_tiled_kernel=_make_compute_norm_terms_tiled_kernel,
_make_compute_norm_terms_and_update_active_tiled_kernel=_make_compute_norm_terms_and_update_active_tiled_kernel,
_make_compute_norm_terms_and_update_active_conditional_tiled_kernel=_make_compute_norm_terms_and_update_active_conditional_tiled_kernel,
_increment_scalar_i32=_increment_scalar_i32,
_set_scalar_i32=_set_scalar_i32,
_create_pack_dense_to_padded_batched_kernel=_create_pack_dense_to_padded_batched_kernel,
_create_pack_grad_to_padded_rhs_batched_kernel=_create_pack_grad_to_padded_rhs_batched_kernel,
_create_pack_dense_and_grad_to_padded_batched_kernel=_create_pack_dense_and_grad_to_padded_batched_kernel,
_create_unpack_solution_batched_kernel=_create_unpack_solution_batched_kernel,
_create_unpack_solution_and_first_batched_kernel=_create_unpack_solution_and_first_batched_kernel,
_compute_search_direction_data_batched=_compute_search_direction_data_batched,
_compute_search_direction_data_serial_batched=_compute_search_direction_data_serial_batched,
_axpy_to_trial_batched=_axpy_to_trial_batched,
_compute_line_derivative_batched=_compute_line_derivative_batched,
_compute_line_derivative_serial_batched=_compute_line_derivative_serial_batched,
_compute_line_second_derivative_serial_batched=_compute_line_second_derivative_serial_batched,
_replace_trial_cost_with_sap_line_search_cost_batched=_replace_trial_cost_with_sap_line_search_cost_batched,
_init_unit_decay_line_search_state=_init_unit_decay_line_search_state,
_update_unit_decay_line_search_state=_update_unit_decay_line_search_state,
_make_unit_decay_line_search_fused_parallel_kernel=_make_unit_decay_line_search_fused_parallel_kernel,
_init_sap_backtracking_state=_init_sap_backtracking_state,
_accept_sap_alpha_max=_accept_sap_alpha_max,
_scale_sap_backtracking_alpha=_scale_sap_backtracking_alpha,
_update_sap_backtracking_iteration=_update_sap_backtracking_iteration,
_update_sap_backtracking_iteration_conditional=_update_sap_backtracking_iteration_conditional,
_finalize_sap_backtracking=_finalize_sap_backtracking,
_copy_i32_batched=_copy_i32_batched,
_init_sap_exact_alpha_max_state=_init_sap_exact_alpha_max_state,
_init_sap_exact_root_state=_init_sap_exact_root_state,
_update_sap_exact_root_state=_update_sap_exact_root_state,
_store_exact_accepted_cost=_store_exact_accepted_cost,
_commit_line_search_step_batched=_commit_line_search_step_batched,
_accumulate_line_search_iterations_batched=_accumulate_line_search_iterations_batched,
_copy_first_dv_active_batched=_copy_first_dv_active_batched,
_count_active_int=_count_active_int,
_activate_unconverged_envs_batched=_activate_unconverged_envs_batched,
)
[docs]
class SapContactSolve:
"""SAP stage2 contact solve in SAP-order generalized velocities.
This class consumes `SapContactJacobianResult` buffers. Runtime collision
detection stays outside this module and is adapted to `SapContacts` before
entering the solver components.
"""
[docs]
def __init__(
self,
model: Model,
*,
max_rigid_contact: int = 128,
contact_beta: float = 1.0,
contact_sigma: float = 1.0e-3,
contact_tau_d: float = 0.1,
block_size: int | None = None,
diag_shift: float = 0.0,
contact_assembly_tile_size: int = 256,
solve_precision: str = "fp64",
linear_solve_precision: str = "fp64",
):
if not isinstance(model, Model):
raise TypeError("SapContactSolve requires SapModel; convert in the frontend adapter before construction.")
if int(model.joint_dof_count) <= 0:
raise ValueError("SapContactSolve requires a model with positive joint_dof_count.")
self.model = model
self.device = model.device
self.dof_count = int(model.joint_dof_count)
self.body_count = int(model.body_count)
self.num_envs = int(getattr(model, "world_count", 1))
if (
self.num_envs <= 0
or self.dof_count % self.num_envs != 0
or self.body_count % self.num_envs != 0
):
raise ValueError(
"SapContactSolve requires contiguous equal-sized env dof/body blocks "
f"(num_envs={self.num_envs}, dof_count={self.dof_count}, body_count={self.body_count})."
)
self.dof_per_env = self.dof_count // self.num_envs
self.bodies_per_env = self.body_count // self.num_envs
self.max_rigid_contact = int(max_rigid_contact)
if self.max_rigid_contact <= 0:
self.max_rigid_contact = 1
self.contact_beta = float(contact_beta)
self.contact_sigma = float(contact_sigma)
self.contact_tau_d = float(contact_tau_d)
self.diag_shift = float(diag_shift)
self.solve_precision = self._normalize_solve_precision(solve_precision)
self.solve_dtype = wp.float32 if self.solve_precision == "fp32" else wp.float64
self.numpy_dtype = np.float32 if self.solve_precision == "fp32" else np.float64
self.vec3_dtype = wp.vec3 if self.solve_precision == "fp32" else wp.vec3d
self.mat33_dtype = wp.mat33 if self.solve_precision == "fp32" else wp.mat33d
self.k = _make_contact_solve_kernel_table(self.solve_dtype)
self.linear_solve_precision = self._normalize_linear_solve_precision(linear_solve_precision)
self.contact_assembly_tile_size = int(contact_assembly_tile_size)
if self.contact_assembly_tile_size <= 0:
self.contact_assembly_tile_size = 256
self.unit_line_search_tile_size = 128
self.unit_line_search_contact_vc_tile_size = 128
if block_size is None:
block_size = 32 if wp.get_device(self.device).is_cuda and self.dof_per_env > 32 else 16
self.block_size = int(block_size)
self.linear_solve_dtype = wp.float32 if self.linear_solve_precision == "fp32" else wp.float64
self._pack_dense_to_padded_batched = self.k._create_pack_dense_to_padded_batched_kernel(
self.linear_solve_dtype
)
self._pack_grad_to_padded_rhs_batched = self.k._create_pack_grad_to_padded_rhs_batched_kernel(
self.linear_solve_dtype
)
self._pack_dense_and_grad_to_padded_batched = self.k._create_pack_dense_and_grad_to_padded_batched_kernel(
self.linear_solve_dtype
)
self._unpack_solution_batched = self.k._create_unpack_solution_batched_kernel(self.linear_solve_dtype)
self._unpack_solution_and_first_batched = self.k._create_unpack_solution_and_first_batched_kernel(
self.linear_solve_dtype
)
self._base_cost_tiled = self.k._make_compute_base_cost_tiled_kernel(self.unit_line_search_tile_size)
self._unit_line_search_base_coeffs = self.k._make_compute_line_search_base_coeffs_tiled_kernel(
self.unit_line_search_tile_size
)
self._unit_line_search_contact_delta_velocity = self.k._make_compute_line_search_contact_delta_velocity_tiled_kernel(
self.unit_line_search_contact_vc_tile_size
)
self._grad_dynamics_impulse_tiled = self.k._make_assemble_grad_and_dynamics_impulse_tiled_kernel(
self.unit_line_search_tile_size
)
self._contact_hessian_gemm_tile = self.k._make_contact_hessian_gemm_tile_kernel(
_CONTACT_HESSIAN_GEMM_TILE_M,
_CONTACT_HESSIAN_GEMM_TILE_N,
_CONTACT_HESSIAN_GEMM_TILE_K,
)
self._pack_contact_hessian_gemm_inputs = self.k._make_pack_contact_hessian_gemm_inputs_kernel(
_CONTACT_HESSIAN_GEMM_TILE_K,
max(_CONTACT_HESSIAN_GEMM_TILE_M, _CONTACT_HESSIAN_GEMM_TILE_N),
)
self._model_terms_grad_tiled = self.k._make_assemble_model_terms_and_grad_tiled_kernel(
self.unit_line_search_tile_size
)
self._norm_terms_tiled = self.k._make_compute_norm_terms_tiled_kernel(self.unit_line_search_tile_size)
self._norm_terms_update_active_tiled = self.k._make_compute_norm_terms_and_update_active_tiled_kernel(
self.unit_line_search_tile_size
)
self._norm_terms_update_active_conditional_tiled = (
self.k._make_compute_norm_terms_and_update_active_conditional_tiled_kernel(
self.unit_line_search_tile_size
)
)
self._unit_line_search_fused_parallel = self.k._make_unit_decay_line_search_fused_parallel_kernel(
self.unit_line_search_tile_size
)
self.block_solver = BlockCholeskySolverBatched(
max_num_equations=self.dof_per_env,
batch_size=self.num_envs,
block_size=self.block_size,
device=self.device,
dtype=self.linear_solve_dtype,
)
self.padded_dof = self.block_solver.max_num_equations
self._mode_none = int(SAP_JOINT_TARGET_NONE)
self.dof_coord_index, self.dof_target_index, self.limit_supported = self._build_dof_maps(model)
self.body_dof_start, self.body_dof_count = self._build_body_dof_maps(model)
self._model_can_have_pd_terms = self._detect_model_pd_terms()
self._model_can_have_limit_terms = self._detect_model_limit_terms()
self._has_pd_terms = False
self._has_limit_terms = False
self.zero_control = wp.zeros((self.dof_count,), dtype=self.solve_dtype, device=self.device)
self.joint_q_input = wp.zeros((model.joint_coord_count,), dtype=self.solve_dtype, device=self.device)
self.joint_limit_lower_solve = wp.array(
np.asarray(model.joint_limit_lower.numpy(), dtype=self.numpy_dtype).reshape(-1),
dtype=self.solve_dtype,
device=self.device,
)
self.joint_limit_upper_solve = wp.array(
np.asarray(model.joint_limit_upper.numpy(), dtype=self.numpy_dtype).reshape(-1),
dtype=self.solve_dtype,
device=self.device,
)
self.joint_limit_ke_solve = wp.array(
np.asarray(model.joint_limit_ke.numpy(), dtype=self.numpy_dtype).reshape(-1),
dtype=self.solve_dtype,
device=self.device,
)
self.joint_limit_kd_solve = wp.array(
np.asarray(model.joint_limit_kd.numpy(), dtype=self.numpy_dtype).reshape(-1),
dtype=self.solve_dtype,
device=self.device,
)
shape_dof = (self.num_envs, self.dof_per_env)
shape_mat = (self.num_envs, self.dof_per_env, self.dof_per_env)
shape_contact = (self.num_envs, self.max_rigid_contact)
self.contact_hessian_gemm_padded_contact_rows = (
(self.max_rigid_contact * 3 + _CONTACT_HESSIAN_GEMM_TILE_K - 1)
// _CONTACT_HESSIAN_GEMM_TILE_K
* _CONTACT_HESSIAN_GEMM_TILE_K
)
contact_hessian_tile_dof = max(_CONTACT_HESSIAN_GEMM_TILE_M, _CONTACT_HESSIAN_GEMM_TILE_N)
self.contact_hessian_gemm_padded_dof = (
(self.dof_per_env + contact_hessian_tile_dof - 1)
// contact_hessian_tile_dof
* contact_hessian_tile_dof
)
shape_contact_gemm = (
self.num_envs,
self.contact_hessian_gemm_padded_contact_rows,
self.contact_hessian_gemm_padded_dof,
)
self.v_env = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.v_flat = wp.zeros((self.dof_count,), dtype=self.solve_dtype, device=self.device)
self.v_star_env = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.v0_env = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.v_trial = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.a_inv_diag = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.d_scale = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.participating_dof = wp.zeros(shape_dof, dtype=int, device=self.device)
self.cost = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.previous_cost = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.accepted_cost = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.grad = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.constraint_impulse = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.dynamics_impulse = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.hess_contact = wp.zeros(shape_mat, dtype=self.solve_dtype, device=self.device)
self.hessian = wp.zeros(shape_mat, dtype=self.solve_dtype, device=self.device)
self.contact_hessian_j_flat = wp.zeros(shape_contact_gemm, dtype=self.solve_dtype, device=self.device)
self.contact_hessian_gj_flat = wp.zeros(shape_contact_gemm, dtype=self.solve_dtype, device=self.device)
self.contact_assembly_upper_count = self.dof_per_env * (self.dof_per_env + 1) // 2
self.trial_constraint_impulse = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_pd_y = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_pd_gamma = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_pd_hdiag = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_pd_cost = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_limit_lower_gamma = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_limit_upper_gamma = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_limit_grad = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_limit_hdiag = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_limit_cost = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.trial_contact_gamma = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.trial_contact_g = wp.zeros(shape_contact, dtype=self.mat33_dtype, device=self.device)
self.trial_contact_vc = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.trial_contact_y = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.trial_contact_rt = wp.zeros(shape_contact, dtype=self.solve_dtype, device=self.device)
self.trial_contact_rn = wp.zeros(shape_contact, dtype=self.solve_dtype, device=self.device)
self.trial_contact_cost = wp.zeros(shape_contact, dtype=self.solve_dtype, device=self.device)
self.trial_contact_mode = wp.zeros(shape_contact, dtype=int, device=self.device)
self.trial_cost = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.trial_derivative = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.trial_second_derivative = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.contact_gamma = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.contact_g = wp.zeros(shape_contact, dtype=self.mat33_dtype, device=self.device)
self.contact_vc = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.contact_y = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.contact_rt = wp.zeros(shape_contact, dtype=self.solve_dtype, device=self.device)
self.contact_rn = wp.zeros(shape_contact, dtype=self.solve_dtype, device=self.device)
self.contact_cost = wp.zeros(shape_contact, dtype=self.solve_dtype, device=self.device)
self.contact_mode = wp.zeros(shape_contact, dtype=int, device=self.device)
self.contact_tau_d_fallback = wp.full(shape_contact, self.contact_tau_d, dtype=self.solve_dtype, device=self.device)
self.pd_active = wp.zeros(shape_dof, dtype=int, device=self.device)
self.pd_a = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_gain = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_limit = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_kp_eff = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_kd_eff = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_y = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_gamma = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_hdiag = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.pd_cost = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_lower_active = wp.zeros(shape_dof, dtype=int, device=self.device)
self.limit_upper_active = wp.zeros(shape_dof, dtype=int, device=self.device)
self.limit_lower_vhat = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_upper_vhat = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_lower_r = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_upper_r = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_lower_rinv = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_upper_rinv = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_lower_gamma = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_upper_gamma = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_grad = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_hdiag = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.limit_cost = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.newton_active = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.converged_env = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.optimality_reached_env = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.cost_reached_env = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.stage2_active_env = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.stage2_active_count = wp.zeros((1,), dtype=int, device=self.device)
self.all_env_active = wp.ones((self.num_envs,), dtype=int, device=self.device)
self.active_count = wp.zeros((1,), dtype=int, device=self.device)
self.newton_loop_iteration = wp.zeros((1,), dtype=int, device=self.device)
self.newton_max_reached = wp.zeros((1,), dtype=int, device=self.device)
self.grad_norm2 = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.p_norm2 = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.jc_norm2 = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.chol_a = wp.zeros(
(self.num_envs, self.padded_dof, self.padded_dof),
dtype=self.linear_solve_dtype,
device=self.device,
)
self.chol_rhs = wp.zeros(
(self.num_envs, self.padded_dof, 1),
dtype=self.linear_solve_dtype,
device=self.device,
)
self.chol_x = wp.zeros(
(self.num_envs, self.padded_dof, 1),
dtype=self.linear_solve_dtype,
device=self.device,
)
self.dv = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.first_dv = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.dp = wp.zeros(shape_dof, dtype=self.solve_dtype, device=self.device)
self.dell0 = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.dell_a0 = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.d2ell_a = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.line_momentum_cost = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.line_search_base0 = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.line_search_base_linear = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.line_search_base_quadratic = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.line_search_contact_dvc = wp.zeros(shape_contact, dtype=self.vec3_dtype, device=self.device)
self.alpha = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.newton_iterations_env = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.alpha_prev = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.ell_prev = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.ell_slop = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.ls_active = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.ls_accepted = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.ls_status = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.ls_iterations = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.ls_iterations_total = wp.zeros((self.num_envs,), dtype=int, device=self.device)
self.ls_active_count = wp.zeros((1,), dtype=int, device=self.device)
self.ls_loop_iteration = wp.zeros((1,), dtype=int, device=self.device)
self.exact_scale = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_x_lower = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_x_upper = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_f_lower = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_f_upper = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_root = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_minus_dx = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_minus_dx_previous = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self.exact_x_tolerance = wp.zeros((self.num_envs,), dtype=self.solve_dtype, device=self.device)
self._contact_result_f32 = None
if self.solve_dtype == wp.float32:
self._contact_env_phi0_f32 = wp.zeros(shape_contact, dtype=wp.float32, device=self.device)
self._contact_env_jacobian_f32 = wp.zeros(
(self.num_envs, self.max_rigid_contact, 3, self.dof_per_env),
dtype=wp.float32,
device=self.device,
)
self._contact_env_w_eff_f32 = wp.zeros(shape_contact, dtype=wp.float32, device=self.device)
self._contact_env_mu_f32 = wp.zeros(shape_contact, dtype=wp.float32, device=self.device)
self._contact_env_stiffness_f32 = wp.zeros(shape_contact, dtype=wp.float32, device=self.device)
self._contact_env_tau_d_f32 = wp.zeros(shape_contact, dtype=wp.float32, device=self.device)
self._dynamics_matrix_env_f32 = wp.zeros(shape_mat, dtype=wp.float32, device=self.device)
self._body_jacobian_local_f32 = wp.zeros(
(int(model.body_count), 6, self.dof_per_env),
dtype=wp.float32,
device=self.device,
)
self._contact_env_R_WC_f32 = wp.zeros(
(self.num_envs, self.max_rigid_contact, 3, 3),
dtype=wp.float32,
device=self.device,
)
self._contact_env_point_f32 = wp.zeros(shape_contact, dtype=wp.vec3, device=self.device)
self._contact_env_witness0_f32 = wp.zeros_like(self._contact_env_point_f32)
self._contact_env_witness1_f32 = wp.zeros_like(self._contact_env_point_f32)
self.last_iterations = 0
self.last_line_search_iterations = 0
@staticmethod
def _normalize_solve_precision(value: str) -> str:
    """Return the canonical solve-precision label for ``value``.

    The input is stringified, stripped of surrounding whitespace and
    lower-cased. The short spellings ``"f32"`` and ``"f64"`` are mapped to
    ``"fp32"`` and ``"fp64"``.

    Raises:
        ValueError: If the normalized text is neither ``"fp32"`` nor ``"fp64"``.
    """
    short_forms = {"f32": "fp32", "f64": "fp64"}
    text = str(value).strip().lower()
    canonical = short_forms.get(text, text)
    if canonical in ("fp32", "fp64"):
        return canonical
    raise ValueError(
        "solve_precision must be 'fp32'/'f32' or 'fp64'/'f64', "
        f"got {value!r}."
    )
@staticmethod
def _normalize_linear_solve_precision(value: str) -> str:
    """Canonicalize a linear-solve precision label.

    ``value`` is converted to ``str``, stripped and lower-cased; ``"f32"``
    and ``"f64"`` are accepted as aliases of ``"fp32"`` and ``"fp64"``.

    Raises:
        ValueError: If the result is not ``"fp32"`` or ``"fp64"``.
    """
    label = str(value).strip().lower()
    if label in ("f32", "f64"):
        # Short spelling: insert the missing "p" ("f32" -> "fp32").
        label = "fp" + label[1:]
    if label == "fp32" or label == "fp64":
        return label
    raise ValueError(
        "linear_solve_precision must be 'fp32'/'f32' or 'fp64'/'f64', "
        f"got {value!r}."
    )
def _build_dof_maps(self, model: Model) -> tuple[wp.array, wp.array, wp.array]:
    """Build the per-DOF lookup arrays used by the PD and limit kernels.

    Returns a tuple of three int arrays on ``self.device``, each of length
    ``self.dof_count``:

    * coordinate index per DOF (``-1`` where no coordinate is mapped),
    * target index per DOF (identity, except that for free/distance joints
      the first three DOFs and the following three are swapped),
    * limit-support flag per DOF (``1`` only for DOFs of prismatic,
      revolute or D6 joints that have exactly one axis).
    """
    total_dofs = self.dof_count
    coord_of_dof = np.full(total_dofs, -1, dtype=np.int32)
    target_of_dof = np.arange(total_dofs, dtype=np.int32)
    supports_limit = np.zeros(total_dofs, dtype=np.int32)

    joint_kinds = model.joint_type.numpy()
    q_starts = model.joint_q_start.numpy()
    qd_starts = model.joint_qd_start.numpy()
    dof_dims = model.joint_dof_dim.numpy()

    swapped_kinds = (int(SAP_JOINT_FREE), int(SAP_JOINT_DISTANCE))
    mapped_kinds = (
        int(SAP_JOINT_PRISMATIC),
        int(SAP_JOINT_REVOLUTE),
        int(SAP_JOINT_D6),
    )

    for j, raw_kind in enumerate(joint_kinds):
        kind = int(raw_kind)
        first_dof = int(qd_starts[j])

        if kind in swapped_kinds:
            # Exchange the two 3-DOF halves of the joint in the target map.
            for k in range(3):
                target_of_dof[first_dof + k] = first_dof + k + 3
                target_of_dof[first_dof + k + 3] = first_dof + k

        if kind in mapped_kinds:
            first_coord = int(q_starts[j])
            n_axes = int(dof_dims[j, 0] + dof_dims[j, 1])
            single_axis = n_axes == 1
            for k in range(n_axes):
                dof = first_dof + k
                if dof < 0 or dof >= total_dofs:
                    continue
                coord_of_dof[dof] = first_coord + k
                if single_axis:
                    supports_limit[dof] = 1

    return (
        wp.array(coord_of_dof, dtype=int, device=self.device),
        wp.array(target_of_dof, dtype=int, device=self.device),
        wp.array(supports_limit, dtype=int, device=self.device),
    )
def _build_body_dof_maps(self, model: Model) -> tuple[wp.array, wp.array]:
    """Map every body to the DOF range of the articulation it belongs to.

    For each articulation the joints with at least one axis contribute
    their ``joint_qd_start`` and axis count. Every in-range child body of
    the articulation's joints then receives the smallest of those starts,
    shifted by ``env * self.dof_per_env`` with
    ``env = body // self.bodies_per_env``, together with the summed axis
    count. Bodies never visited keep start ``-1`` and count ``0``.

    Returns:
        ``(body_dof_start, body_dof_count)`` as int arrays on ``self.device``.
    """
    n_bodies = self.body_count
    start_of_body = np.full(n_bodies, -1, dtype=np.int32)
    count_of_body = np.zeros(n_bodies, dtype=np.int32)

    art_starts = np.asarray(model.articulation_start.numpy(), dtype=np.int32).reshape(-1)
    child_of_joint = np.asarray(model.joint_child.numpy(), dtype=np.int32).reshape(-1)
    qd_starts = np.asarray(model.joint_qd_start.numpy(), dtype=np.int32).reshape(-1)
    dof_dims = np.asarray(model.joint_dof_dim.numpy(), dtype=np.int32)

    for art in range(int(model.articulation_count)):
        joints = range(int(art_starts[art]), int(art_starts[art + 1]))

        # Gather (start, axis count) for the joints that actually carry DOFs.
        active_starts: list[int] = []
        total_axes = 0
        for j in joints:
            n_axes = int(dof_dims[j, 0] + dof_dims[j, 1])
            if n_axes > 0:
                active_starts.append(int(qd_starts[j]))
                total_axes += n_axes

        if active_starts and total_axes > 0:
            lowest_start = min(active_starts)
            for j in joints:
                body = int(child_of_joint[j])
                if 0 <= body < n_bodies:
                    env = body // self.bodies_per_env
                    start_of_body[body] = lowest_start - env * self.dof_per_env
                    count_of_body[body] = total_axes

    return (
        wp.array(start_of_body, dtype=int, device=self.device),
        wp.array(count_of_body, dtype=int, device=self.device),
    )
def _detect_model_pd_terms(self) -> bool:
    """Report whether any joint target in the model could yield a PD term.

    A flattened entry counts when its target mode differs from
    ``self._mode_none`` and its ``joint_target_ke`` or ``joint_target_kd``
    is strictly positive. Only the common prefix of the three arrays is
    inspected; an empty prefix gives ``False``.
    """
    model = self.model
    modes = np.asarray(model.joint_target_mode.numpy(), dtype=np.int64).reshape(-1)
    stiffness = np.asarray(model.joint_target_ke.numpy(), dtype=np.float64).reshape(-1)
    damping = np.asarray(model.joint_target_kd.numpy(), dtype=np.float64).reshape(-1)

    n = min(modes.size, stiffness.size, damping.size)
    if n <= 0:
        return False

    driven = modes[:n] != self._mode_none
    has_gain = (stiffness[:n] > 0.0) | (damping[:n] > 0.0)
    return bool(np.any(driven & has_gain))
def _detect_model_limit_terms(self) -> bool:
    """Report whether any DOF could produce a joint-limit term.

    A flattened entry counts when ``self.limit_supported`` is non-zero,
    ``joint_limit_ke`` is strictly positive, and at least one of the lower
    or upper limits is finite. Only the common prefix of the four arrays is
    inspected; an empty prefix gives ``False``.
    """
    def flat(handle, dtype):
        # Host copy of a device array, flattened to 1-D in the given dtype.
        return np.asarray(handle.numpy(), dtype=dtype).reshape(-1)

    supported = flat(self.limit_supported, np.int64)
    lo = flat(self.model.joint_limit_lower, np.float64)
    hi = flat(self.model.joint_limit_upper, np.float64)
    stiffness = flat(self.model.joint_limit_ke, np.float64)

    n = min(supported.size, lo.size, hi.size, stiffness.size)
    if n <= 0:
        return False

    enabled = (supported[:n] != 0) & (stiffness[:n] > 0.0)
    bounded = np.isfinite(lo[:n]) | np.isfinite(hi[:n])
    return bool(np.any(enabled & bounded))
def _as_control_array(self, arr) -> wp.array:
    """Return ``arr`` unchanged, or ``self.zero_control`` when ``arr`` is ``None``."""
    return self.zero_control if arr is None else arr
def _prepare_joint_q_input(self, state: State) -> None:
    """Copy ``state.joint_q`` into ``self.joint_q_input``.

    The copy kernel is chosen from ``self.solve_dtype`` and the dtype of
    ``state.joint_q``: a same-precision copy when they match, otherwise
    the f32->f64 or f64->f32 variant.
    """
    source_q = state.joint_q
    solving_in_f64 = self.solve_dtype == wp.float64

    if solving_in_f64 and source_q.dtype == wp.float64:
        copy_kernel = _copy_f64
    elif solving_in_f64:
        copy_kernel = _copy_f32_to_f64
    elif source_q.dtype == wp.float32:
        copy_kernel = _copy_f32
    else:
        copy_kernel = _copy_f64_to_f32

    wp.launch(
        copy_kernel,
        dim=self.model.joint_coord_count,
        inputs=[source_q, self.joint_q_input],
        device=self.device,
    )
def _build_participating_dof_mask(self, contact_result: SapContactJacobianResult) -> None:
    """Rebuild ``self.participating_dof`` with three consecutive launches.

    The mask is first cleared, then marked from the contact data (body
    indices, contact Jacobian and the body->DOF range maps), then marked
    from the PD and lower/upper limit activity flags.
    """
    slots = self._contact_capacity(contact_result)
    mask = self.participating_dof
    per_dof_grid = (self.num_envs, self.dof_per_env)

    # 1) reset
    wp.launch(
        self.k._clear_participating_dofs_batched,
        dim=per_dof_grid,
        inputs=[mask],
        device=self.device,
    )

    # 2) contact contributions, one thread per (env, contact slot, dof)
    contact_inputs = [
        self.dof_per_env,
        slots,
        contact_result.contact_env_count,
        contact_result.contact_env_body0,
        contact_result.contact_env_body1,
        contact_result.contact_env_jacobian,
        self.body_dof_start,
        self.body_dof_count,
        mask,
    ]
    wp.launch(
        self.k._mark_contact_participating_dofs_batched,
        dim=(self.num_envs, slots, self.dof_per_env),
        inputs=contact_inputs,
        device=self.device,
    )

    # 3) PD / joint-limit contributions
    model_inputs = [
        self.dof_per_env,
        self.pd_active,
        self.limit_lower_active,
        self.limit_upper_active,
        mask,
    ]
    wp.launch(
        self.k._mark_model_participating_dofs_batched,
        dim=per_dof_grid,
        inputs=model_inputs,
        device=self.device,
    )
def _contact_capacity(self, contact_result: SapContactJacobianResult) -> int:
    """Return the number of contact slots to process per environment.

    This is the second dimension of ``contact_result.contact_env_jacobian``
    capped at ``self.max_rigid_contact`` and never less than 1.
    """
    slots = int(contact_result.contact_env_jacobian.shape[1])
    limit = self.max_rigid_contact
    capped = slots if slots < limit else limit
    return capped if capped > 1 else 1
def _compute_base_cost(
    self,
    active_env: wp.array,
    v: wp.array,
    a_mat: wp.array,
    out_cost: wp.array,
) -> None:
    """Launch the tiled base-cost kernel with one tile per environment.

    The kernel receives ``active_env``, the DOF count per environment,
    ``v``, ``self.v_star_env`` and ``a_mat``; ``out_cost`` is passed last
    (presumably the output buffer -- confirm against the kernel).
    """
    kernel_inputs = [
        active_env,
        self.dof_per_env,
        v,
        self.v_star_env,
        a_mat,
        out_cost,
    ]
    wp.launch_tiled(
        self._base_cost_tiled,
        dim=self.num_envs,
        block_dim=self.unit_line_search_tile_size,
        inputs=kernel_inputs,
        device=self.device,
    )
def _assemble_grad_and_dynamics_impulse(
    self,
    active_env: wp.array,
    v: wp.array,
    a_mat: wp.array,
    constraint_impulse: wp.array,
    add_constraint_terms: bool = False,
) -> None:
    """Launch the tiled gradient / dynamics-impulse kernel.

    The kernel is launched over ``(num_envs, dof_per_env)`` and is handed
    ``self.grad`` and ``self.dynamics_impulse`` as its trailing arguments.
    The PD and limit flags passed to it are 1 only when
    ``add_constraint_terms`` is truthy and the matching
    ``self._has_pd_terms`` / ``self._has_limit_terms`` is set.
    """
    include_constraints = bool(add_constraint_terms)
    pd_flag = 1 if include_constraints and self._has_pd_terms else 0
    limit_flag = 1 if include_constraints and self._has_limit_terms else 0

    kernel_inputs = [
        active_env,
        self.dof_per_env,
        v,
        self.v_star_env,
        a_mat,
        pd_flag,
        self.pd_active,
        self.pd_gamma,
        limit_flag,
        self.limit_grad,
        constraint_impulse,
        self.grad,
        self.dynamics_impulse,
    ]
    wp.launch_tiled(
        self._grad_dynamics_impulse_tiled,
        dim=(self.num_envs, self.dof_per_env),
        block_dim=self.unit_line_search_tile_size,
        inputs=kernel_inputs,
        device=self.device,
    )
def _assemble_model_terms_and_grad(
    self,
    active_env: wp.array,
    v: wp.array,
    a_mat: wp.array,
    dt: float,
    *,
    pd_y: wp.array,
    pd_gamma: wp.array,
    pd_hdiag: wp.array,
    pd_cost: wp.array,
    lower_gamma: wp.array,
    upper_gamma: wp.array,
    limit_grad: wp.array,
    limit_hdiag: wp.array,
    limit_cost: wp.array,
    constraint_impulse: wp.array,
    cost: wp.array,
) -> None:
    """Launch the tiled kernel that handles PD/limit terms and the gradient.

    The kernel runs over ``(num_envs, dof_per_env)``. Its argument list is
    assembled from four groups whose concatenation order must match the
    kernel signature: state, PD terms, limit terms, and the trailing
    impulse/gradient/cost buffers. The keyword-only arrays are the caller's
    buffers for the PD and limit quantities; the solver's own prepared
    ``self.pd_*`` and ``self.limit_*`` arrays are passed alongside them.
    """
    state_args = [active_env, self.dof_per_env, v, self.v_star_env, a_mat]
    pd_args = [
        int(self._has_pd_terms),
        self.pd_active,
        self.pd_a,
        self.pd_gain,
        self.pd_limit,
        float(dt),
        pd_y,
        pd_gamma,
        pd_hdiag,
        pd_cost,
    ]
    limit_args = [
        int(self._has_limit_terms),
        self.limit_lower_active,
        self.limit_upper_active,
        self.limit_lower_vhat,
        self.limit_upper_vhat,
        self.limit_lower_r,
        self.limit_upper_r,
        self.limit_lower_rinv,
        self.limit_upper_rinv,
        lower_gamma,
        upper_gamma,
        limit_grad,
        limit_hdiag,
        limit_cost,
    ]
    result_args = [constraint_impulse, self.grad, self.dynamics_impulse, cost]

    wp.launch_tiled(
        self._model_terms_grad_tiled,
        dim=(self.num_envs, self.dof_per_env),
        block_dim=self.unit_line_search_tile_size,
        inputs=state_args + pd_args + limit_args + result_args,
        device=self.device,
    )
def _compute_norm_terms(self, active_env: wp.array) -> None:
    """Launch the tiled norm-terms kernel with one tile per environment.

    The kernel is given the participating-DOF mask, ``self.d_scale``, the
    gradient and the dynamics/constraint impulses, followed by
    ``self.grad_norm2``, ``self.p_norm2`` and ``self.jc_norm2``.
    """
    operands = [
        active_env,
        self.dof_per_env,
        self.participating_dof,
        self.d_scale,
        self.grad,
        self.dynamics_impulse,
        self.constraint_impulse,
    ]
    norm_buffers = [self.grad_norm2, self.p_norm2, self.jc_norm2]
    wp.launch_tiled(
        self._norm_terms_tiled,
        dim=self.num_envs,
        block_dim=self.unit_line_search_tile_size,
        inputs=operands + norm_buffers,
        device=self.device,
    )
def _compute_norm_terms_and_update_active(
    self,
    active_env: wp.array,
    iteration: int,
    *,
    optimality_abs_tol: float,
    optimality_rel_tol: float,
    cost_abs_tol: float,
    cost_rel_tol: float,
    cost_min_alpha: float,
) -> None:
    """Launch the fused norm-terms / active-set update kernel.

    ``iteration`` and the tolerances are passed to the kernel as host
    scalars. ``self.active_count`` is zeroed before the launch unless the
    solver has exactly one environment, in which case the kernel is told so
    through a flag.
    """
    is_single_env = self.num_envs == 1
    if not is_single_env:
        self.active_count.zero_()

    operands = [
        active_env,
        self.dof_per_env,
        self.participating_dof,
        self.d_scale,
        self.grad,
        self.dynamics_impulse,
        self.constraint_impulse,
        self.cost,
        self.previous_cost,
        self.alpha,
    ]
    scalars = [
        int(iteration),
        float(optimality_abs_tol),
        float(optimality_rel_tol),
        float(cost_abs_tol),
        float(cost_rel_tol),
        float(cost_min_alpha),
        int(is_single_env),
    ]
    status_buffers = [
        self.newton_active,
        self.converged_env,
        self.optimality_reached_env,
        self.cost_reached_env,
        self.newton_iterations_env,
        self.active_count,
    ]
    norm_buffers = [self.grad_norm2, self.p_norm2, self.jc_norm2]

    wp.launch_tiled(
        self._norm_terms_update_active_tiled,
        dim=self.num_envs,
        block_dim=self.unit_line_search_tile_size,
        inputs=operands + scalars + status_buffers + norm_buffers,
        device=self.device,
    )
def _compute_norm_terms_and_update_active_conditional(
    self,
    active_env: wp.array,
    *,
    max_iterations: int,
    optimality_abs_tol: float,
    optimality_rel_tol: float,
    cost_abs_tol: float,
    cost_rel_tol: float,
    cost_min_alpha: float,
) -> None:
    """Launch the conditional fused norm-terms / active-set update kernel.

    Unlike the non-conditional variant, the kernel is handed the array
    ``self.newton_loop_iteration`` and ``max_iterations`` rather than a
    host iteration number, and additionally ``self.newton_max_reached``.
    ``self.active_count`` is zeroed first unless there is exactly one
    environment.
    """
    is_single_env = self.num_envs == 1
    if not is_single_env:
        self.active_count.zero_()

    operands = [
        active_env,
        self.dof_per_env,
        self.participating_dof,
        self.d_scale,
        self.grad,
        self.dynamics_impulse,
        self.constraint_impulse,
        self.cost,
        self.previous_cost,
        self.alpha,
        self.newton_loop_iteration,
    ]
    scalars = [
        int(max_iterations),
        float(optimality_abs_tol),
        float(optimality_rel_tol),
        float(cost_abs_tol),
        float(cost_rel_tol),
        float(cost_min_alpha),
        int(is_single_env),
    ]
    status_buffers = [
        self.newton_active,
        self.converged_env,
        self.optimality_reached_env,
        self.cost_reached_env,
        self.newton_iterations_env,
        self.active_count,
        self.newton_max_reached,
    ]
    norm_buffers = [self.grad_norm2, self.p_norm2, self.jc_norm2]

    wp.launch_tiled(
        self._norm_terms_update_active_conditional_tiled,
        dim=self.num_envs,
        block_dim=self.unit_line_search_tile_size,
        inputs=operands + scalars + status_buffers + norm_buffers,
        device=self.device,
    )
def _solver_update_active(
    self,
    contact_result: SapContactJacobianResult,
    active_env: wp.array,
    dt: float,
    *,
    max_iterations: int,
    optimality_abs_tol: float,
    optimality_rel_tol: float,
    cost_abs_tol: float,
    cost_rel_tol: float,
    cost_min_alpha: float,
) -> None:
    """Re-evaluate the problem at ``self.v_env`` and refresh the active set.

    Steps:
      1. ``_evaluate_problem`` is called without the Hessian and with
         constraint terms deferred, writing into the solver's accepted
         (non-trial) buffers.
      2. If that call returns a falsy value, the gradient and dynamics
         impulse are assembled here with constraint terms included.
         (Presumably the return value signals that the gradient is already
         up to date -- confirm against ``_evaluate_problem``.)
      3. The conditional norm-terms / active-update kernel is launched.
    """
    accepted_buffers = dict(
        cost=self.cost,
        constraint_impulse=self.constraint_impulse,
        contact_gamma=self.contact_gamma,
        contact_g=self.contact_g,
        contact_vc=self.contact_vc,
        contact_y=self.contact_y,
        contact_rt=self.contact_rt,
        contact_rn=self.contact_rn,
        contact_cost=self.contact_cost,
        contact_mode=self.contact_mode,
        pd_y=self.pd_y,
        pd_gamma=self.pd_gamma,
        pd_hdiag=self.pd_hdiag,
        pd_cost=self.pd_cost,
        lower_gamma=self.limit_lower_gamma,
        upper_gamma=self.limit_upper_gamma,
        limit_grad=self.limit_grad,
        limit_hdiag=self.limit_hdiag,
        limit_cost=self.limit_cost,
    )
    gradient_done = self._evaluate_problem(
        contact_result,
        self.v_env,
        active_env,
        dt,
        include_hessian=False,
        defer_constraint_terms=True,
        **accepted_buffers,
    )

    if not gradient_done:
        self._assemble_grad_and_dynamics_impulse(
            active_env,
            self.v_env,
            contact_result.dynamics_matrix_env,
            self.constraint_impulse,
            add_constraint_terms=True,
        )

    self._compute_norm_terms_and_update_active_conditional(
        active_env,
        max_iterations=int(max_iterations),
        optimality_abs_tol=float(optimality_abs_tol),
        optimality_rel_tol=float(optimality_rel_tol),
        cost_abs_tol=float(cost_abs_tol),
        cost_rel_tol=float(cost_rel_tol),
        cost_min_alpha=float(cost_min_alpha),
    )
def _cast_contact_result(self, contact_result: SapContactJacobianResult) -> SapContactJacobianResult:
    """Return a contact result whose floating-point arrays match the fp32 solve.

    ``contact_result`` is returned untouched when the solver runs in
    float64 or when its ``dynamics_matrix_env`` is already float32.
    Otherwise every floating-point array is copied into the solver's
    preallocated ``*_f32`` scratch buffers and a new
    ``SapContactJacobianResult`` referencing those buffers is cached in
    ``self._contact_result_f32`` and returned. Count and body-index arrays
    are shared with the input.

    ``contact_env_tau_d`` is optional on the input. When it is ``None``
    nothing is copied, and the returned result also carries ``None``
    instead of the scratch buffer, which would otherwise hold zeros or
    values left over from an earlier call.
    """
    if self.solve_dtype == wp.float64 or contact_result.dynamics_matrix_env.dtype == wp.float32:
        return contact_result

    tau_src = contact_result.contact_env_tau_d
    has_tau = tau_src is not None

    # (kernel, source, destination); each launch covers the destination's shape.
    copy_plan = [
        (_copy_3d_f64_to_f32, contact_result.dynamics_matrix_env, self._dynamics_matrix_env_f32),
        (_copy_2d_f64_to_f32, contact_result.contact_env_phi0, self._contact_env_phi0_f32),
        (_copy_4d_f64_to_f32, contact_result.contact_env_jacobian, self._contact_env_jacobian_f32),
        (_copy_2d_f64_to_f32, contact_result.contact_env_w_eff, self._contact_env_w_eff_f32),
        (_copy_2d_f64_to_f32, contact_result.contact_env_mu, self._contact_env_mu_f32),
        (_copy_2d_f64_to_f32, contact_result.contact_env_stiffness, self._contact_env_stiffness_f32),
    ]
    if has_tau:
        copy_plan.append((_copy_2d_f64_to_f32, tau_src, self._contact_env_tau_d_f32))
    copy_plan += [
        (_copy_3d_f64_to_f32, contact_result.body_jacobian_local, self._body_jacobian_local_f32),
        (_copy_4d_f64_to_f32, contact_result.contact_env_R_WC, self._contact_env_R_WC_f32),
        (_copy_vec3d_2d_to_vec3, contact_result.contact_env_point, self._contact_env_point_f32),
        (_copy_vec3d_2d_to_vec3, contact_result.contact_env_witness0, self._contact_env_witness0_f32),
        (_copy_vec3d_2d_to_vec3, contact_result.contact_env_witness1, self._contact_env_witness1_f32),
    ]
    for kernel, src, dst in copy_plan:
        wp.launch(kernel, dim=dst.shape, inputs=[src, dst], device=self.device)

    self._contact_result_f32 = SapContactJacobianResult(
        contact_count=contact_result.contact_count,
        truncated_contact_count=contact_result.truncated_contact_count,
        contact_env_count=contact_result.contact_env_count,
        contact_env_phi0=self._contact_env_phi0_f32,
        contact_env_jacobian=self._contact_env_jacobian_f32,
        contact_env_w_eff=self._contact_env_w_eff_f32,
        contact_env_mu=self._contact_env_mu_f32,
        contact_env_stiffness=self._contact_env_stiffness_f32,
        # Keep "no per-contact tau_d" visible to consumers that test for None.
        contact_env_tau_d=self._contact_env_tau_d_f32 if has_tau else None,
        contact_env_R_WC=self._contact_env_R_WC_f32,
        contact_env_point=self._contact_env_point_f32,
        contact_env_witness0=self._contact_env_witness0_f32,
        contact_env_witness1=self._contact_env_witness1_f32,
        contact_env_body0=contact_result.contact_env_body0,
        contact_env_body1=contact_result.contact_env_body1,
        body_jacobian_local=self._body_jacobian_local_f32,
        dynamics_matrix_env=self._dynamics_matrix_env_f32,
    )
    return self._contact_result_f32
def _load_velocity(self, src: wp.array, dst: wp.array) -> None:
    """Copy a velocity array into an env-local solver buffer.

    ``src`` may be flat with shape ``(self.dof_count,)`` or env-local with
    shape ``(self.num_envs, self.dof_per_env)``. If its dtype equals
    ``self.solve_dtype`` a same-precision copy is done (``wp.copy`` for the
    env-local layout); otherwise the f64->f32 kernel is used when solving
    in float32 and the f32->f64 kernel in all other cases.

    Raises:
        ValueError: If ``src`` has neither of the two supported shapes.
    """
    is_flat = src.shape == (self.dof_count,)
    if not is_flat and src.shape != (self.num_envs, self.dof_per_env):
        raise ValueError(
            "velocity input must be flat SAP-order `(joint_dof_count,)` "
            "or env-local `(num_envs, dof_per_env)`, got "
            f"{src.shape!r}"
        )

    same_precision = src.dtype == self.solve_dtype
    if same_precision and not is_flat:
        # Layout and dtype already match the destination buffer.
        wp.copy(dst, src)
        return

    if same_precision:
        convert_kernel = self.k._copy_flat_to_env_batched
    elif self.solve_dtype == wp.float32:
        convert_kernel = (
            _copy_flat_f64_to_env_f32_batched if is_flat else _copy_env_f64_to_env_f32_batched
        )
    else:
        convert_kernel = (
            _copy_flat_f32_to_env_f64_batched if is_flat else _copy_env_f32_to_env_f64_batched
        )

    wp.launch(
        convert_kernel,
        dim=(self.num_envs, self.dof_per_env),
        inputs=[src, self.dof_per_env, dst],
        device=self.device,
    )
def prepare(
    self,
    contact_result: SapContactJacobianResult,
    state: State,
    control: Control | None,
    dt: float,
    v_star: wp.array,
    *,
    v0: wp.array | None = None,
    v_guess: wp.array | None = None,
    v_guess_active: wp.array | None = None,
) -> None:
    """Prepare contact-solve buffers for the current active contact set before iterative solve
    evaluation.

    Args:
        contact_result: Contact Jacobian data; cast to the solve precision
            through ``_cast_contact_result`` when necessary.
        state: Source of ``joint_q``. Must be a ``State`` instance.
        control: Source of ``joint_target_pos`` / ``joint_target_vel`` /
            ``joint_act``; attributes that are missing or ``None`` are
            replaced by ``self.zero_control``. Must be a ``Control``
            instance; ``None`` is rejected even though the annotation
            allows it.
        dt: Time step forwarded to the PD and limit term kernels.
        v_star: Loaded into ``self.v_star_env``.
        v0: Loaded into ``self.v0_env``. Defaults to ``v_guess`` when that
            is given, otherwise to ``v_star``.
        v_guess: Optional initial iterate loaded into ``self.v_env``.
        v_guess_active: Optional flag array forwarded to the fused
            flat-copy kernel together with ``v_guess``.

    Raises:
        TypeError: If ``state`` is not a ``State`` or ``control`` is not a
            ``Control``.
    """
    # Validate before any kernel is launched, so a bad call does not leave
    # the cached fp32 contact result or the scratch buffers half-updated.
    if not isinstance(state, State):
        raise TypeError("SapContactSolve.prepare requires SapState.")
    if not isinstance(control, Control):
        raise TypeError("SapContactSolve.prepare requires SapControl.")

    contact_result = self._cast_contact_result(contact_result)
    if v0 is None:
        v0 = v_guess if v_guess is not None else v_star

    self._has_pd_terms = self._model_can_have_pd_terms
    self._has_limit_terms = self._model_can_have_limit_terms

    flat_shape = (self.dof_count,)
    env_grid = (self.num_envs, self.dof_per_env)

    if (
        v_guess_active is not None
        and v_guess is not None
        and v_star.shape == flat_shape
        and v0.shape == flat_shape
        and v_guess.shape == flat_shape
        and v_star.dtype == self.solve_dtype
        and v0.dtype == self.solve_dtype
        and v_guess.dtype == self.solve_dtype
    ):
        # Fast path: one fused launch loads v_star, v0 and the flagged guess.
        wp.launch(
            self.k._copy_solve_velocity_inputs_flat_batched_with_guess_flag,
            dim=env_grid,
            inputs=[
                v_star,
                v0,
                v_guess,
                v_guess_active,
                self.dof_per_env,
                self.v_star_env,
                self.v0_env,
                self.v_env,
            ],
            device=self.device,
        )
    elif (
        v_guess is None
        and v_star.shape == flat_shape
        and v0.shape == flat_shape
        and v_star.dtype == self.solve_dtype
        and v0.dtype == self.solve_dtype
    ):
        # Fast path without a guess: one fused launch fills all three buffers.
        wp.launch(
            self.k._copy_solve_velocity_inputs_flat_batched,
            dim=env_grid,
            inputs=[
                v_star,
                v0,
                self.dof_per_env,
                self.v_star_env,
                self.v0_env,
                self.v_env,
            ],
            device=self.device,
        )
    else:
        # Generic path: per-array load with layout / precision conversion.
        # NOTE(review): v_guess_active is not consulted here, so a guess
        # reaching this branch is always loaded in full -- confirm intended.
        self._load_velocity(v_star, self.v_star_env)
        self._load_velocity(v0, self.v0_env)
        self._load_velocity(v_guess if v_guess is not None else v0, self.v_env)

    self._prepare_joint_q_input(state)

    wp.launch(
        self.k._extract_a_diag_data_batched,
        dim=env_grid,
        inputs=[
            self.dof_per_env,
            contact_result.dynamics_matrix_env,
            self.a_inv_diag,
            self.d_scale,
        ],
        device=self.device,
    )
    wp.launch(
        self.k._build_pd_terms_sap_batched,
        dim=env_grid,
        inputs=[
            int(self._has_pd_terms),
            self.dof_per_env,
            self.dof_coord_index,
            self.dof_target_index,
            self.model.joint_target_mode,
            self.model.joint_target_ke,
            self.model.joint_target_kd,
            self.model.joint_effort_limit,
            self.joint_q_input,
            self._as_control_array(getattr(control, "joint_target_pos", None)),
            self._as_control_array(getattr(control, "joint_target_vel", None)),
            self._as_control_array(getattr(control, "joint_act", None)),
            self.a_inv_diag,
            float(dt),
            self._mode_none,
            self.pd_active,
            self.pd_a,
            self.pd_gain,
            self.pd_limit,
            self.pd_kp_eff,
            self.pd_kd_eff,
        ],
        device=self.device,
    )
    wp.launch(
        self.k._build_limit_terms_sap_batched,
        dim=env_grid,
        inputs=[
            int(self._has_limit_terms),
            self.dof_per_env,
            self.dof_coord_index,
            self.limit_supported,
            self.joint_limit_lower_solve,
            self.joint_limit_upper_solve,
            self.joint_limit_ke_solve,
            self.joint_limit_kd_solve,
            self.joint_q_input,
            self.v0_env,
            self.v_star_env,
            self.a_inv_diag,
            float(dt),
            self.limit_lower_active,
            self.limit_upper_active,
            self.limit_lower_vhat,
            self.limit_upper_vhat,
            self.limit_lower_r,
            self.limit_upper_r,
            self.limit_lower_rinv,
            self.limit_upper_rinv,
        ],
        device=self.device,
    )
    # The mask kernel reads pd_active / limit_*_active, so it runs last.
    self._build_participating_dof_mask(contact_result)
def _assemble_contact_impulse_from_terms(
    self,
    contact_result: SapContactJacobianResult,
    *,
    contact_capacity: int,
    contact_gamma: wp.array,
    constraint_impulse: wp.array,
) -> None:
    """Launch the tiled contact-impulse kernel over every (env, dof) pair.

    The kernel comes from ``self.k._make_contact_impulse_single_tile_kernel``
    called with ``self.contact_assembly_tile_size``; the same value is used as
    the launch ``block_dim``. The kernel receives the per-env contact count and
    Jacobian from ``contact_result``, then ``contact_gamma`` and finally
    ``constraint_impulse``.

    NOTE(review): presumably the kernel writes the Jacobian-transpose product of
    ``contact_gamma`` into ``constraint_impulse`` -- confirm in the kernel source.
    """
    block_dim = int(self.contact_assembly_tile_size)
    impulse_kernel = self.k._make_contact_impulse_single_tile_kernel(block_dim)
    kernel_args = [
        self.dof_per_env,
        contact_capacity,
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        contact_gamma,
        constraint_impulse,
    ]
    wp.launch_tiled(
        impulse_kernel,
        dim=(self.num_envs, self.dof_per_env),
        inputs=kernel_args,
        block_dim=block_dim,
        device=self.device,
    )
def _assemble_contact_hessian_from_terms(
    self,
    contact_result: SapContactJacobianResult,
    *,
    contact_capacity: int,
    contact_g: wp.array,
) -> None:
    """Run the two tiled launches that fill ``self.hess_contact``.

    1. ``self._pack_contact_hessian_gemm_inputs`` takes the contact Jacobian
       and ``contact_g`` and is handed ``self.contact_hessian_j_flat`` /
       ``self.contact_hessian_gj_flat`` as its last two arguments.
    2. ``self._contact_hessian_gemm_tile`` takes those two flat buffers and
       ``self.hess_contact``.

    NOTE(review): presumably this forms J^T G J per environment -- confirm in
    the kernel sources.
    """
    tile_m = _CONTACT_HESSIAN_GEMM_TILE_M
    tile_n = _CONTACT_HESSIAN_GEMM_TILE_N
    tile_k = _CONTACT_HESSIAN_GEMM_TILE_K
    dof = self.dof_per_env
    padded_rows = self.contact_hessian_gemm_padded_contact_rows
    padded_dof = self.contact_hessian_gemm_padded_dof

    # The pack grid walks the padded buffers: K-tiles along the contact rows,
    # and the larger of the M/N tile sizes along the DOF axis.
    pack_tile_dof = max(tile_m, tile_n)
    pack_grid = (self.num_envs, padded_rows // tile_k, padded_dof // pack_tile_dof)
    pack_args = [
        dof,
        contact_capacity,
        padded_rows,
        padded_dof,
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        contact_g,
        self.contact_hessian_j_flat,
        self.contact_hessian_gj_flat,
    ]
    wp.launch_tiled(
        self._pack_contact_hessian_gemm_inputs,
        dim=pack_grid,
        inputs=pack_args,
        block_dim=128,
        device=self.device,
    )

    # The GEMM grid covers the unpadded DOF range; -(-a // b) is integer
    # ceiling division, identical to (a + b - 1) // b.
    gemm_grid = (self.num_envs, -(-dof // tile_m), -(-dof // tile_n))
    gemm_args = [
        dof,
        padded_rows,
        self.contact_hessian_j_flat,
        self.contact_hessian_gj_flat,
        self.hess_contact,
    ]
    wp.launch_tiled(
        self._contact_hessian_gemm_tile,
        dim=gemm_grid,
        inputs=gemm_args,
        block_dim=128,
        device=self.device,
    )
def _ensure_hessian_terms_for_active_envs(
    self,
    contact_result: SapContactJacobianResult,
    active_env: wp.array,
    dt: float,
) -> int:
    """Launch the contact-Hessian projection kernel and return the contact capacity.

    ``_projection_eval_contact_hessian_sap_batched`` is launched over
    ``(num_envs, contact_capacity)``. It is given the stored ``self.contact_vc``
    and, as trailing arguments, ``self.contact_g``, ``self.contact_y``,
    ``self.contact_rt``, ``self.contact_rn`` and ``self.contact_mode``.

    If ``contact_result`` has no ``contact_env_tau_d`` attribute (or it is
    ``None``), ``self.contact_tau_d_fallback`` is filled with
    ``self.contact_tau_d`` and passed instead.

    Returns:
        The value of ``self._contact_capacity(contact_result)``.
    """
    capacity = self._contact_capacity(contact_result)

    tau_d = getattr(contact_result, "contact_env_tau_d", None)
    if tau_d is None:
        tau_d = self.contact_tau_d_fallback
        tau_d.fill_(self.contact_tau_d)

    kernel_args = [
        active_env,
        capacity,
        contact_result.contact_env_count,
        contact_result.contact_env_phi0,
        contact_result.contact_env_w_eff,
        contact_result.contact_env_mu,
        contact_result.contact_env_stiffness,
        tau_d,
        self.contact_vc,
        self.contact_beta,
        self.contact_sigma,
        float(dt),
        self.contact_g,
        self.contact_y,
        self.contact_rt,
        self.contact_rn,
        self.contact_mode,
    ]
    wp.launch(
        self.k._projection_eval_contact_hessian_sap_batched,
        dim=(self.num_envs, capacity),
        inputs=kernel_args,
        device=self.device,
    )
    return capacity
def _evaluate_cost_terms(
    self,
    contact_result: SapContactJacobianResult,
    v: wp.array,
    active_env: wp.array,
    dt: float,
    *,
    cost: wp.array,
    contact_gamma: wp.array,
    contact_g: wp.array,
    contact_vc: wp.array,
    contact_y: wp.array,
    contact_rt: wp.array,
    contact_rn: wp.array,
    contact_cost: wp.array,
    contact_mode: wp.array,
    pd_y: wp.array,
    pd_gamma: wp.array,
    pd_hdiag: wp.array,
    pd_cost: wp.array,
    lower_gamma: wp.array,
    upper_gamma: wp.array,
    limit_grad: wp.array,
    limit_hdiag: wp.array,
    limit_cost: wp.array,
) -> int:
    """Evaluate base, contact, PD and joint-limit terms at ``v``.

    Steps, in order:
      1. ``self._compute_base_cost`` with the dynamics matrix and ``cost``.
      2. ``_projection_eval_contact_sap_batched`` over ``(num_envs, capacity)``.
         This variant is also handed ``contact_g``, ``contact_y``,
         ``contact_rt``, ``contact_rn`` and ``contact_mode``.
      3. ``_eval_pd_terms_sap_batched`` over ``(num_envs, dof_per_env)``.
      4. ``_eval_limit_terms_sap_batched`` over ``(num_envs, dof_per_env)``.

    Every kernel takes ``cost`` as its last argument.
    NOTE(review): presumably each kernel accumulates into ``cost`` -- confirm.

    Returns:
        The value of ``self._contact_capacity(contact_result)``.
    """
    capacity = self._contact_capacity(contact_result)

    # Use the per-contact tau_d when the result provides it; otherwise fill the
    # fallback buffer with the solver-wide value.
    tau_d = getattr(contact_result, "contact_env_tau_d", None)
    if tau_d is None:
        tau_d = self.contact_tau_d_fallback
        tau_d.fill_(self.contact_tau_d)

    self._compute_base_cost(active_env, v, contact_result.dynamics_matrix_env, cost)

    env_dof_grid = (self.num_envs, self.dof_per_env)

    contact_args = [
        active_env,
        self.dof_per_env,
        capacity,
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        contact_result.contact_env_phi0,
        contact_result.contact_env_w_eff,
        contact_result.contact_env_mu,
        contact_result.contact_env_stiffness,
        tau_d,
        v,
        self.contact_beta,
        self.contact_sigma,
        float(dt),
        contact_gamma,
        contact_g,
        contact_vc,
        contact_y,
        contact_rt,
        contact_rn,
        contact_cost,
        contact_mode,
        cost,
    ]
    wp.launch(
        self.k._projection_eval_contact_sap_batched,
        dim=(self.num_envs, capacity),
        inputs=contact_args,
        device=self.device,
    )

    pd_args = [
        int(self._has_pd_terms),
        active_env,
        self.dof_per_env,
        self.pd_active,
        self.pd_a,
        self.pd_gain,
        self.pd_limit,
        v,
        float(dt),
        pd_y,
        pd_gamma,
        pd_hdiag,
        pd_cost,
        cost,
    ]
    wp.launch(
        self.k._eval_pd_terms_sap_batched,
        dim=env_dof_grid,
        inputs=pd_args,
        device=self.device,
    )

    limit_args = [
        int(self._has_limit_terms),
        active_env,
        self.dof_per_env,
        self.limit_lower_active,
        self.limit_upper_active,
        self.limit_lower_vhat,
        self.limit_upper_vhat,
        self.limit_lower_r,
        self.limit_upper_r,
        self.limit_lower_rinv,
        self.limit_upper_rinv,
        v,
        lower_gamma,
        upper_gamma,
        limit_grad,
        limit_hdiag,
        limit_cost,
        cost,
    ]
    wp.launch(
        self.k._eval_limit_terms_sap_batched,
        dim=env_dof_grid,
        inputs=limit_args,
        device=self.device,
    )
    return capacity
def _evaluate_cost_terms_no_contact_hessian(
    self,
    contact_result: SapContactJacobianResult,
    v: wp.array,
    active_env: wp.array,
    dt: float,
    *,
    cost: wp.array,
    contact_gamma: wp.array,
    contact_vc: wp.array,
    contact_cost: wp.array,
    pd_y: wp.array,
    pd_gamma: wp.array,
    pd_hdiag: wp.array,
    pd_cost: wp.array,
    lower_gamma: wp.array,
    upper_gamma: wp.array,
    limit_grad: wp.array,
    limit_hdiag: wp.array,
    limit_cost: wp.array,
) -> int:
    """Evaluate base, contact, PD and joint-limit terms without contact Hessian buffers.

    Same sequence as ``_evaluate_cost_terms`` except that the contact stage
    launches ``_projection_eval_contact_gamma_sap_batched``. That kernel only
    takes ``contact_gamma``, ``contact_vc``, ``contact_cost`` and ``cost`` after
    the shared arguments.

    Returns:
        The value of ``self._contact_capacity(contact_result)``.
    """
    capacity = self._contact_capacity(contact_result)

    # Use the per-contact tau_d when the result provides it; otherwise fill the
    # fallback buffer with the solver-wide value.
    tau_d = getattr(contact_result, "contact_env_tau_d", None)
    if tau_d is None:
        tau_d = self.contact_tau_d_fallback
        tau_d.fill_(self.contact_tau_d)

    self._compute_base_cost(active_env, v, contact_result.dynamics_matrix_env, cost)

    env_dof_grid = (self.num_envs, self.dof_per_env)

    contact_args = [
        active_env,
        self.dof_per_env,
        capacity,
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        contact_result.contact_env_phi0,
        contact_result.contact_env_w_eff,
        contact_result.contact_env_mu,
        contact_result.contact_env_stiffness,
        tau_d,
        v,
        self.contact_beta,
        self.contact_sigma,
        float(dt),
        contact_gamma,
        contact_vc,
        contact_cost,
        cost,
    ]
    wp.launch(
        self.k._projection_eval_contact_gamma_sap_batched,
        dim=(self.num_envs, capacity),
        inputs=contact_args,
        device=self.device,
    )

    pd_args = [
        int(self._has_pd_terms),
        active_env,
        self.dof_per_env,
        self.pd_active,
        self.pd_a,
        self.pd_gain,
        self.pd_limit,
        v,
        float(dt),
        pd_y,
        pd_gamma,
        pd_hdiag,
        pd_cost,
        cost,
    ]
    wp.launch(
        self.k._eval_pd_terms_sap_batched,
        dim=env_dof_grid,
        inputs=pd_args,
        device=self.device,
    )

    limit_args = [
        int(self._has_limit_terms),
        active_env,
        self.dof_per_env,
        self.limit_lower_active,
        self.limit_upper_active,
        self.limit_lower_vhat,
        self.limit_upper_vhat,
        self.limit_lower_r,
        self.limit_upper_r,
        self.limit_lower_rinv,
        self.limit_upper_rinv,
        v,
        lower_gamma,
        upper_gamma,
        limit_grad,
        limit_hdiag,
        limit_cost,
        cost,
    ]
    wp.launch(
        self.k._eval_limit_terms_sap_batched,
        dim=env_dof_grid,
        inputs=limit_args,
        device=self.device,
    )
    return capacity
def _evaluate_problem(
    self,
    contact_result: SapContactJacobianResult,
    v: wp.array,
    active_env: wp.array,
    dt: float,
    *,
    include_hessian: bool,
    cost: wp.array,
    constraint_impulse: wp.array,
    contact_gamma: wp.array,
    contact_g: wp.array,
    contact_vc: wp.array,
    contact_y: wp.array,
    contact_rt: wp.array,
    contact_rn: wp.array,
    contact_cost: wp.array,
    contact_mode: wp.array,
    pd_y: wp.array,
    pd_gamma: wp.array,
    pd_hdiag: wp.array,
    pd_cost: wp.array,
    lower_gamma: wp.array,
    upper_gamma: wp.array,
    limit_grad: wp.array,
    limit_hdiag: wp.array,
    limit_cost: wp.array,
    defer_constraint_terms: bool = False,
) -> bool:
    """Evaluate the problem at ``v`` and optionally assemble the Hessian.

    Two paths exist:

    * Non-deferred (default): delegates to ``_evaluate_cost_terms`` or
      ``_evaluate_cost_terms_no_contact_hessian`` depending on
      ``include_hessian``. It then assembles the contact impulse and launches
      the PD and limit impulse accumulation kernels. Returns ``False``.
    * Deferred (``defer_constraint_terms`` truthy): zeroes ``cost`` and launches
      only the contact projection kernel. It then assembles the contact impulse
      and hands the PD/limit buffers to ``self._assemble_model_terms_and_grad``.
      Returns ``True``.

    In both paths ``self._assemble_hessian_from_terms`` runs last when
    ``include_hessian`` is truthy.

    Returns:
        ``True`` on the deferred path, ``False`` otherwise. ``evaluate_itemwise``
        uses this as a "gradient already assembled" flag.
    """
    model_term_buffers = dict(
        pd_y=pd_y,
        pd_gamma=pd_gamma,
        pd_hdiag=pd_hdiag,
        pd_cost=pd_cost,
        lower_gamma=lower_gamma,
        upper_gamma=upper_gamma,
        limit_grad=limit_grad,
        limit_hdiag=limit_hdiag,
        limit_cost=limit_cost,
    )

    if not defer_constraint_terms:
        if include_hessian:
            capacity = self._evaluate_cost_terms(
                contact_result,
                v,
                active_env,
                dt,
                cost=cost,
                contact_gamma=contact_gamma,
                contact_g=contact_g,
                contact_vc=contact_vc,
                contact_y=contact_y,
                contact_rt=contact_rt,
                contact_rn=contact_rn,
                contact_cost=contact_cost,
                contact_mode=contact_mode,
                **model_term_buffers,
            )
        else:
            capacity = self._evaluate_cost_terms_no_contact_hessian(
                contact_result,
                v,
                active_env,
                dt,
                cost=cost,
                contact_gamma=contact_gamma,
                contact_vc=contact_vc,
                contact_cost=contact_cost,
                **model_term_buffers,
            )

        self._assemble_contact_impulse_from_terms(
            contact_result,
            contact_capacity=capacity,
            contact_gamma=contact_gamma,
            constraint_impulse=constraint_impulse,
        )

        env_dof_grid = (self.num_envs, self.dof_per_env)
        # `defer_constraint_terms` is falsy on this path; the factor is kept so
        # the flag values passed to the kernels are unchanged.
        pd_enabled = int(self._has_pd_terms and not defer_constraint_terms)
        wp.launch(
            self.k._accumulate_pd_impulse_batched,
            dim=env_dof_grid,
            inputs=[
                pd_enabled,
                self.dof_per_env,
                self.pd_active,
                pd_gamma,
                constraint_impulse,
            ],
            device=self.device,
        )
        limit_enabled = int(self._has_limit_terms and not defer_constraint_terms)
        wp.launch(
            self.k._accumulate_limit_impulse_batched,
            dim=env_dof_grid,
            inputs=[
                limit_enabled,
                self.dof_per_env,
                limit_grad,
                constraint_impulse,
            ],
            device=self.device,
        )

        if include_hessian:
            self._assemble_hessian_from_terms(
                contact_result,
                active_env=active_env,
                contact_capacity=capacity,
                contact_g=contact_g,
                pd_hdiag=pd_hdiag,
                limit_hdiag=limit_hdiag,
            )
        return False

    # ---- Deferred path ----
    capacity = self._contact_capacity(contact_result)
    tau_d = getattr(contact_result, "contact_env_tau_d", None)
    if tau_d is None:
        tau_d = self.contact_tau_d_fallback
        tau_d.fill_(self.contact_tau_d)

    cost.zero_()

    # Both contact kernels share this leading argument list and differ only in
    # the trailing buffers.
    shared_contact_args = [
        active_env,
        self.dof_per_env,
        capacity,
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        contact_result.contact_env_phi0,
        contact_result.contact_env_w_eff,
        contact_result.contact_env_mu,
        contact_result.contact_env_stiffness,
        tau_d,
        v,
        self.contact_beta,
        self.contact_sigma,
        float(dt),
    ]
    if include_hessian:
        contact_kernel = self.k._projection_eval_contact_sap_batched
        trailing_args = [
            contact_gamma,
            contact_g,
            contact_vc,
            contact_y,
            contact_rt,
            contact_rn,
            contact_cost,
            contact_mode,
            cost,
        ]
    else:
        contact_kernel = self.k._projection_eval_contact_gamma_sap_batched
        trailing_args = [
            contact_gamma,
            contact_vc,
            contact_cost,
            cost,
        ]
    wp.launch(
        contact_kernel,
        dim=(self.num_envs, capacity),
        inputs=shared_contact_args + trailing_args,
        device=self.device,
    )

    self._assemble_contact_impulse_from_terms(
        contact_result,
        contact_capacity=capacity,
        contact_gamma=contact_gamma,
        constraint_impulse=constraint_impulse,
    )
    self._assemble_model_terms_and_grad(
        active_env,
        v,
        contact_result.dynamics_matrix_env,
        dt,
        constraint_impulse=constraint_impulse,
        cost=cost,
        **model_term_buffers,
    )
    if include_hessian:
        self._assemble_hessian_from_terms(
            contact_result,
            active_env=active_env,
            contact_capacity=capacity,
            contact_g=contact_g,
            pd_hdiag=pd_hdiag,
            limit_hdiag=limit_hdiag,
        )
    return True
def _assemble_hessian_from_terms(
    self,
    contact_result: SapContactJacobianResult,
    *,
    active_env: wp.array | None = None,
    contact_capacity: int | None = None,
    contact_g: wp.array | None = None,
    pd_hdiag: wp.array | None = None,
    limit_hdiag: wp.array | None = None,
) -> None:
    """Assemble the contact Hessian, then launch the total-Hessian kernel.

    Any argument left as ``None`` falls back to the solver's own state:
    ``self.all_env_active``, ``self._contact_capacity(contact_result)``,
    ``self.contact_g``, ``self.pd_hdiag`` and ``self.limit_hdiag``.

    ``_assemble_hessian_total_batched`` is launched over
    ``(num_envs, dof_per_env, dof_per_env)``. It is given the dynamics matrix,
    ``self.hess_contact``, the PD/limit enable flags and diagonals, and
    ``self.hessian`` as the last argument.
    """
    env_mask = self.all_env_active if active_env is None else active_env
    capacity = (
        self._contact_capacity(contact_result)
        if contact_capacity is None
        else contact_capacity
    )
    g_terms = self.contact_g if contact_g is None else contact_g
    pd_diag = self.pd_hdiag if pd_hdiag is None else pd_hdiag
    limit_diag = self.limit_hdiag if limit_hdiag is None else limit_hdiag

    self._assemble_contact_hessian_from_terms(
        contact_result,
        contact_capacity=capacity,
        contact_g=g_terms,
    )

    dof = self.dof_per_env
    total_args = [
        env_mask,
        dof,
        contact_result.dynamics_matrix_env,
        self.hess_contact,
        int(self._has_pd_terms),
        int(self._has_limit_terms),
        pd_diag,
        limit_diag,
        self.hessian,
    ]
    wp.launch(
        self.k._assemble_hessian_total_batched,
        dim=(self.num_envs, dof, dof),
        inputs=total_args,
        device=self.device,
    )
def evaluate_itemwise(
    self,
    contact_result: SapContactJacobianResult,
    state: State,
    control: Control | None,
    dt: float,
    v_star: wp.array,
    *,
    v: wp.array | None = None,
    v0: wp.array | None = None,
) -> SapContactSolveResult:
    """Evaluate the SAP problem once at the prepared velocity, without iterating.

    Calls ``self.prepare`` (``v`` is forwarded as ``v_guess``). It then
    evaluates the problem on all environments with the Hessian included and
    constraint terms deferred, writing into the solver's primary buffers.
    Finally ``self.v_env`` is copied into ``self.v_flat``.

    Returns:
        ``self._make_result(0, 0, True)``.
    """
    self.prepare(contact_result, state, control, dt, v_star, v0=v0, v_guess=v)

    primary_buffers = dict(
        cost=self.cost,
        constraint_impulse=self.constraint_impulse,
        contact_gamma=self.contact_gamma,
        contact_g=self.contact_g,
        contact_vc=self.contact_vc,
        contact_y=self.contact_y,
        contact_rt=self.contact_rt,
        contact_rn=self.contact_rn,
        contact_cost=self.contact_cost,
        contact_mode=self.contact_mode,
        pd_y=self.pd_y,
        pd_gamma=self.pd_gamma,
        pd_hdiag=self.pd_hdiag,
        pd_cost=self.pd_cost,
        lower_gamma=self.limit_lower_gamma,
        upper_gamma=self.limit_upper_gamma,
        limit_grad=self.limit_grad,
        limit_hdiag=self.limit_hdiag,
        limit_cost=self.limit_cost,
    )
    gradient_assembled = self._evaluate_problem(
        contact_result,
        self.v_env,
        self.all_env_active,
        dt,
        include_hessian=True,
        defer_constraint_terms=True,
        **primary_buffers,
    )

    # Fallback for when the evaluation did not already assemble the gradient.
    if not gradient_assembled:
        self._assemble_grad_and_dynamics_impulse(
            self.all_env_active,
            self.v_env,
            contact_result.dynamics_matrix_env,
            self.constraint_impulse,
            add_constraint_terms=True,
        )

    flat_copy_args = [self.v_env, self.dof_per_env, self.v_flat]
    wp.launch(
        self.k._copy_env_to_flat_batched,
        dim=(self.num_envs, self.dof_per_env),
        inputs=flat_copy_args,
        device=self.device,
    )
    return self._make_result(0, 0, True)
def _solve_newton_direction(self, *, store_first_dv: bool = False) -> None:
    """Pack the Hessian and gradient, run the masked block solver, and unpack into ``self.dv``.

    Steps:
      1. ``self._pack_dense_and_grad_to_padded_batched`` is launched with
         ``self.hessian``/``self.grad`` followed by ``self.chol_a``/``self.chol_rhs``.
      2. ``self.block_solver`` factorizes and solves, both masked by
         ``self.newton_active``.
      3. The solution ``self.chol_x`` is unpacked into ``self.dv``. With
         ``store_first_dv`` the unpack kernel variant that also receives
         ``self.first_dv`` and the Newton loop counter is used.
    """
    pack_args = [
        self.hessian,
        self.grad,
        self.chol_a,
        self.chol_rhs,
        self.dof_per_env,
        self.diag_shift,
    ]
    wp.launch(
        self._pack_dense_and_grad_to_padded_batched,
        dim=(self.num_envs, self.padded_dof, self.padded_dof),
        inputs=pack_args,
        device=self.device,
    )

    solver = self.block_solver
    solver.factorize_masked(self.chol_a, self.dof_per_env, self.newton_active)
    solver.solve_masked(self.chol_rhs, self.chol_x, self.newton_active)

    unpack_grid = (self.num_envs, self.padded_dof)
    if store_first_dv:
        wp.launch(
            self._unpack_solution_and_first_batched,
            dim=unpack_grid,
            inputs=[
                self.newton_active,
                self.newton_loop_iteration,
                self.chol_x,
                self.dv,
                self.first_dv,
                self.dof_per_env,
            ],
            device=self.device,
        )
    else:
        wp.launch(
            self._unpack_solution_batched,
            dim=unpack_grid,
            inputs=[self.chol_x, self.dv, self.dof_per_env],
            device=self.device,
        )
def _run_sap_backtracking(
    self,
    contact_result: SapContactJacobianResult,
    dt: float,
    *,
    armijo_c: float,
    rho: float,
    alpha_max: float,
    max_iterations: int,
    relative_slop: float,
    check_errors: bool = True,
) -> None:
    """Run the backtracking line search along ``self.dv`` and commit the step.

    Does nothing when ``int(max_iterations) <= 0``. Otherwise:
      1. Zero the directional-derivative buffers and launch
         ``_compute_search_direction_data_serial_batched``.
      2. Compute the base cost into ``self.line_momentum_cost``.
      3. Initialise the line-search state with ``alpha_max``, evaluate the
         trial point and launch ``_accept_sap_alpha_max``.
      4. If more than one iteration is allowed, loop
         ``_run_sap_backtracking_body`` via ``wp.capture_while`` on
         ``self.ls_active_count``.
      5. Optionally raise on line-search errors, then accumulate the iteration
         counts and commit the step.
    """
    if int(max_iterations) <= 0:
        return

    env_grid = self.num_envs
    env_dof_grid = (self.num_envs, self.dof_per_env)
    dynamics = contact_result.dynamics_matrix_env

    for derivative_buffer in (self.dell0, self.dell_a0, self.d2ell_a):
        derivative_buffer.zero_()

    direction_args = [
        self.newton_active,
        self.dof_per_env,
        dynamics,
        self.v_env,
        self.v_star_env,
        self.grad,
        self.dv,
        self.dp,
        self.dell0,
        self.dell_a0,
        self.d2ell_a,
    ]
    wp.launch(
        self.k._compute_search_direction_data_serial_batched,
        dim=env_grid,
        inputs=direction_args,
        device=self.device,
    )
    self._compute_base_cost(self.newton_active, self.v_env, dynamics, self.line_momentum_cost)

    init_args = [
        self.newton_active,
        self.cost,
        self.dell0,
        float(alpha_max),
        float(relative_slop),
        self.alpha,
        self.alpha_prev,
        self.ell_prev,
        self.ell_slop,
        self.ls_active,
        self.ls_accepted,
        self.ls_status,
        self.ls_iterations,
        self.accepted_cost,
    ]
    wp.launch(
        self.k._init_sap_backtracking_state,
        dim=env_grid,
        inputs=init_args,
        device=self.device,
    )

    # First trial point at the initial alpha.
    trial_args = [self.ls_active, self.dof_per_env, self.v_env, self.dv, self.alpha, self.v_trial]
    wp.launch(
        self.k._axpy_to_trial_batched,
        dim=env_dof_grid,
        inputs=trial_args,
        device=self.device,
    )
    self._evaluate_trial(contact_result, dt)
    self._replace_trial_cost_with_sap_line_search_cost(contact_result)
    self._compute_trial_derivative()

    self.ls_active_count.zero_()
    accept_args = [
        self.trial_cost,
        self.trial_derivative,
        self.cost,
        float(relative_slop),
        self.ls_active,
        self.ls_accepted,
        self.alpha,
        self.alpha_prev,
        self.ell_prev,
        self.ell_slop,
        self.accepted_cost,
        self.ls_active_count,
    ]
    wp.launch(
        self.k._accept_sap_alpha_max,
        dim=env_grid,
        inputs=accept_args,
        device=self.device,
    )

    if int(max_iterations) > 1:
        self.ls_loop_iteration.fill_(1)
        wp.capture_while(
            self.ls_active_count,
            while_body=self._run_sap_backtracking_body,
            contact_result=contact_result,
            dt=float(dt),
            armijo_c=float(armijo_c),
            rho=float(rho),
            max_iterations=int(max_iterations),
        )

    if check_errors:
        self._raise_line_search_errors_if_any(stage="backtracking")

    wp.launch(
        self.k._accumulate_line_search_iterations_batched,
        dim=env_grid,
        inputs=[self.ls_accepted, self.ls_iterations, self.ls_iterations_total],
        device=self.device,
    )
    commit_args = [
        self.newton_active,
        self.ls_accepted,
        self.dof_per_env,
        self.alpha,
        self.v_env,
        self.dv,
        self.v_flat,
        self.cost,
        self.previous_cost,
        self.accepted_cost,
    ]
    wp.launch(
        self.k._commit_line_search_step_batched,
        dim=env_dof_grid,
        inputs=commit_args,
        device=self.device,
    )
def _run_sap_backtracking_body(
    self,
    *,
    contact_result: SapContactJacobianResult,
    dt: float,
    armijo_c: float,
    rho: float,
    max_iterations: int,
) -> None:
    """Run one iteration of the backtracking loop driven by ``wp.capture_while``.

    Launches ``_scale_sap_backtracking_alpha`` with ``rho``, rebuilds and
    evaluates the trial point, and zeroes ``self.ls_active_count``. It then
    launches ``_update_sap_backtracking_iteration_conditional``, whose last
    argument is that counter, and increments ``self.ls_loop_iteration``.
    """
    device = self.device
    env_grid = self.num_envs

    wp.launch(
        self.k._scale_sap_backtracking_alpha,
        dim=env_grid,
        inputs=[self.ls_active, self.alpha, float(rho)],
        device=device,
    )

    trial_args = [self.ls_active, self.dof_per_env, self.v_env, self.dv, self.alpha, self.v_trial]
    wp.launch(
        self.k._axpy_to_trial_batched,
        dim=(self.num_envs, self.dof_per_env),
        inputs=trial_args,
        device=device,
    )
    self._evaluate_trial(contact_result, dt)
    self._replace_trial_cost_with_sap_line_search_cost(contact_result)

    # The update kernel below is handed this counter; it is also the
    # capture_while condition.
    self.ls_active_count.zero_()
    update_args = [
        self.trial_cost,
        self.cost,
        self.dell0,
        self.alpha,
        self.alpha_prev,
        self.ell_prev,
        self.ell_slop,
        float(armijo_c),
        self.ls_loop_iteration,
        int(max_iterations),
        self.ls_active,
        self.ls_accepted,
        self.ls_status,
        self.ls_iterations,
        self.accepted_cost,
        self.ls_active_count,
    ]
    wp.launch(
        self.k._update_sap_backtracking_iteration_conditional,
        dim=env_grid,
        inputs=update_args,
        device=device,
    )
    wp.launch(
        self.k._increment_scalar_i32,
        dim=1,
        inputs=[self.ls_loop_iteration],
        device=device,
    )
def _run_unit_decay_line_search(
    self,
    contact_result: SapContactJacobianResult,
    dt: float,
    *,
    max_iterations: int,
    decay: float,
    min_alpha: float,
    cost_relax_r: float,
    cost_relax_a: float,
    check_errors: bool = True,
) -> None:
    """Run the fused "unit decay" line search with three tiled launches.

    1. ``self._unit_line_search_base_coeffs`` is handed the three
       ``self.line_search_base*`` buffers as trailing arguments.
    2. ``self._unit_line_search_contact_delta_velocity`` is handed
       ``self.line_search_contact_dvc`` as its trailing argument.
    3. ``self._unit_line_search_fused_parallel`` is handed the contact, PD and
       limit data, the scalars converted with ``self.solve_dtype``, and the
       step/cost/line-search state buffers.

    If ``check_errors`` is truthy, ``self._raise_line_search_errors_if_any`` is
    called with stage ``"unit_device"``.
    """
    tau_d = getattr(contact_result, "contact_env_tau_d", None)
    if tau_d is None:
        tau_d = self.contact_tau_d_fallback
        tau_d.fill_(self.contact_tau_d)
    capacity = self._contact_capacity(contact_result)

    tile_size = self.unit_line_search_tile_size

    wp.launch_tiled(
        self._unit_line_search_base_coeffs,
        dim=self.num_envs,
        inputs=[
            self.newton_active,
            self.dof_per_env,
            self.v_env,
            self.v_star_env,
            self.dv,
            contact_result.dynamics_matrix_env,
            self.line_search_base0,
            self.line_search_base_linear,
            self.line_search_base_quadratic,
        ],
        block_dim=tile_size,
        device=self.device,
    )
    wp.launch_tiled(
        self._unit_line_search_contact_delta_velocity,
        dim=(self.num_envs, capacity),
        inputs=[
            self.newton_active,
            self.dof_per_env,
            capacity,
            contact_result.contact_env_count,
            contact_result.contact_env_jacobian,
            self.dv,
            self.line_search_contact_dvc,
        ],
        block_dim=self.unit_line_search_contact_vc_tile_size,
        device=self.device,
    )

    # The fused kernel's arguments are grouped by role; the concatenation order
    # below is the kernel's positional order.
    contact_args = [
        self.newton_active,
        self.dof_per_env,
        capacity,
        contact_result.contact_env_count,
        self.contact_vc,
        self.line_search_contact_dvc,
        contact_result.contact_env_phi0,
        contact_result.contact_env_w_eff,
        contact_result.contact_env_mu,
        contact_result.contact_env_stiffness,
        tau_d,
    ]
    pd_args = [
        self.pd_active,
        self.pd_a,
        self.pd_gain,
        self.pd_limit,
    ]
    limit_args = [
        self.limit_lower_active,
        self.limit_upper_active,
        self.limit_lower_vhat,
        self.limit_upper_vhat,
        self.limit_lower_r,
        self.limit_upper_r,
        self.limit_lower_rinv,
        self.limit_upper_rinv,
    ]
    velocity_args = [
        self.v_env,
        self.dv,
        self.v_flat,
        self.line_search_base0,
        self.line_search_base_linear,
        self.line_search_base_quadratic,
    ]
    as_solve = self.solve_dtype
    scalar_args = [
        as_solve(self.contact_beta),
        as_solve(self.contact_sigma),
        as_solve(dt),
        1,
        int(self._has_pd_terms),
        int(self._has_limit_terms),
        int(max_iterations),
        as_solve(decay),
        as_solve(min_alpha),
        as_solve(cost_relax_r),
        as_solve(cost_relax_a),
    ]
    state_args = [
        self.alpha,
        self.cost,
        self.previous_cost,
        self.accepted_cost,
        self.ls_active,
        self.ls_accepted,
        self.ls_status,
        self.ls_iterations,
        self.ls_iterations_total,
    ]
    wp.launch_tiled(
        self._unit_line_search_fused_parallel,
        dim=self.num_envs,
        inputs=contact_args + pd_args + limit_args + velocity_args + scalar_args + state_args,
        block_dim=tile_size,
        device=self.device,
    )

    if check_errors:
        self._raise_line_search_errors_if_any(stage="unit_device")
def _run_conditional_line_search(
self,
contact_result: SapContactJacobianResult,
dt: float,
*,
line_search_variant: str,
line_search_max_iterations: int,
armijo_c: float,
rho: float,
line_search_relative_slop: float,
cost_abs_tol: float,
cost_rel_tol: float,
check_errors: bool,
) -> None:
if line_search_variant == "monotone_decay":
self._run_unit_decay_line_search(
contact_result,
dt,
max_iterations=int(line_search_max_iterations),
decay=0.5,
min_alpha=1.0e-8,
cost_relax_r=1.0e-12,
cost_relax_a=1.0e-14,
check_errors=bool(check_errors),
)
return
if line_search_variant == "armijo_decay":
self._run_sap_backtracking(
contact_result,
dt,
armijo_c=float(armijo_c),
rho=float(rho),
alpha_max=1.0 / float(rho),
max_iterations=int(line_search_max_iterations),
relative_slop=float(line_search_relative_slop),
check_errors=bool(check_errors),
)
return
if line_search_variant == "exact_root":
self._run_sap_exact(
contact_result,
dt,
max_iterations=int(line_search_max_iterations),
cost_abs_tol=float(cost_abs_tol),
cost_rel_tol=float(cost_rel_tol),
check_errors=bool(check_errors),
)
return
raise ValueError(f"Unsupported SAP line search variant {line_search_variant!r}.")
def _run_unit_conditional_newton_body(
self,
*,
contact_result: SapContactJacobianResult,
dt: float,
max_iterations: int,
optimality_abs_tol: float,
optimality_rel_tol: float,
cost_abs_tol: float,
cost_rel_tol: float,
cost_min_alpha: float,
line_search_max_iterations: int,
line_search_variant: str,
armijo_c: float,
rho: float,
line_search_relative_slop: float,
check_line_search_errors: bool,
) -> None:
self._run_unit_conditional_newton_step(
contact_result=contact_result,
dt=float(dt),
line_search_max_iterations=int(line_search_max_iterations),
line_search_variant=line_search_variant,
armijo_c=float(armijo_c),
rho=float(rho),
line_search_relative_slop=float(line_search_relative_slop),
cost_abs_tol=float(cost_abs_tol),
cost_rel_tol=float(cost_rel_tol),
check_line_search_errors=bool(check_line_search_errors),
)
self._solver_update_active(
contact_result,
self.stage2_active_env,
float(dt),
max_iterations=int(max_iterations),
optimality_abs_tol=float(optimality_abs_tol),
optimality_rel_tol=float(optimality_rel_tol),
cost_abs_tol=float(cost_abs_tol),
cost_rel_tol=float(cost_rel_tol),
cost_min_alpha=float(cost_min_alpha),
)
def _run_unit_conditional_newton_step(
    self,
    *,
    contact_result: SapContactJacobianResult,
    dt: float,
    line_search_max_iterations: int,
    line_search_variant: str,
    armijo_c: float,
    rho: float,
    line_search_relative_slop: float,
    cost_abs_tol: float,
    cost_rel_tol: float,
    check_line_search_errors: bool,
) -> None:
    """Perform one Newton step for the environments in ``self.newton_active``.

    Sequence: refresh the contact Hessian terms, assemble the Hessian, and
    solve for the Newton direction with ``store_first_dv=True``. Then run the
    selected line search and increment ``self.newton_loop_iteration``.
    """
    active = self.newton_active
    capacity = self._ensure_hessian_terms_for_active_envs(contact_result, active, dt)
    self._assemble_hessian_from_terms(
        contact_result,
        active_env=self.newton_active,
        contact_capacity=capacity,
    )
    self._solve_newton_direction(store_first_dv=True)

    line_search_options = dict(
        line_search_variant=line_search_variant,
        line_search_max_iterations=int(line_search_max_iterations),
        armijo_c=float(armijo_c),
        rho=float(rho),
        line_search_relative_slop=float(line_search_relative_slop),
        cost_abs_tol=float(cost_abs_tol),
        cost_rel_tol=float(cost_rel_tol),
        check_errors=bool(check_line_search_errors),
    )
    self._run_conditional_line_search(contact_result, dt, **line_search_options)

    wp.launch(
        self.k._increment_scalar_i32,
        dim=1,
        inputs=[self.newton_loop_iteration],
        device=self.device,
    )
def _run_unit_conditional_newton_loop(
    self,
    contact_result: SapContactJacobianResult,
    dt: float,
    *,
    max_iterations: int,
    optimality_abs_tol: float,
    optimality_rel_tol: float,
    cost_abs_tol: float,
    cost_rel_tol: float,
    line_search_max_iterations: int,
    line_search_variant: str,
    armijo_c: float,
    rho: float,
    line_search_relative_slop: float,
    cost_min_alpha: float,
    collect_iteration_stats: bool,
    check_line_search_errors: bool,
    loop_counters_initialized: bool = False,
    v_flat_seeded: bool = False,
) -> SapContactSolveResult:
    """Drive the Newton loop with ``wp.capture_while`` on ``self.active_count``.

    * Unless ``loop_counters_initialized``, the loop counters are reset via
      ``_initialize_newton_loop_state``.
    * The active set is refreshed once before the loop.
    * When ``int(max_iterations) > 0`` the loop body
      ``_run_unit_conditional_newton_body`` is captured. It is always given
      ``check_line_search_errors=False``.
    * Unless ``v_flat_seeded``, ``self.v_env`` is copied into ``self.v_flat``.

    ``collect_iteration_stats`` and ``check_line_search_errors`` are accepted
    but not read by this method.

    Returns:
        ``self._make_result(...)`` with both iteration counts set to ``-1`` and
        ``converged`` set to ``True``.
    """
    step_dt = float(dt)
    iteration_cap = int(max_iterations)
    # Tolerances shared by the active-set update and the loop body.
    tolerance_kwargs = dict(
        max_iterations=iteration_cap,
        optimality_abs_tol=float(optimality_abs_tol),
        optimality_rel_tol=float(optimality_rel_tol),
        cost_abs_tol=float(cost_abs_tol),
        cost_rel_tol=float(cost_rel_tol),
        cost_min_alpha=float(cost_min_alpha),
    )

    if not loop_counters_initialized:
        wp.launch(
            self.k._initialize_newton_loop_state,
            dim=1,
            inputs=[self.newton_loop_iteration, self.newton_max_reached],
            device=self.device,
        )

    self._solver_update_active(
        contact_result,
        self.stage2_active_env,
        step_dt,
        **tolerance_kwargs,
    )

    if iteration_cap > 0:
        wp.capture_while(
            self.active_count,
            while_body=self._run_unit_conditional_newton_body,
            contact_result=contact_result,
            dt=step_dt,
            line_search_max_iterations=int(line_search_max_iterations),
            line_search_variant=line_search_variant,
            armijo_c=float(armijo_c),
            rho=float(rho),
            line_search_relative_slop=float(line_search_relative_slop),
            check_line_search_errors=False,
            **tolerance_kwargs,
        )

    if not v_flat_seeded:
        wp.launch(
            self.k._copy_env_to_flat_batched,
            dim=(self.num_envs, self.dof_per_env),
            inputs=[self.v_env, self.dof_per_env, self.v_flat],
            device=self.device,
        )

    # Iteration counts are stored as -1 on this path.
    self.last_iterations = -1
    self.last_line_search_iterations = -1
    return self._make_result(self.last_iterations, self.last_line_search_iterations, True)
def _run_unit_conditional_newton_loop_capture(
self,
*,
contact_result: SapContactJacobianResult,
dt: float,
max_iterations: int,
optimality_abs_tol: float,
optimality_rel_tol: float,
cost_abs_tol: float,
cost_rel_tol: float,
line_search_max_iterations: int,
line_search_variant: str,
armijo_c: float,
rho: float,
line_search_relative_slop: float,
cost_min_alpha: float,
collect_iteration_stats: bool,
check_line_search_errors: bool,
loop_counters_initialized: bool,
v_flat_seeded: bool,
) -> None:
self._run_unit_conditional_newton_loop(
contact_result,
float(dt),
max_iterations=int(max_iterations),
optimality_abs_tol=float(optimality_abs_tol),
optimality_rel_tol=float(optimality_rel_tol),
cost_abs_tol=float(cost_abs_tol),
cost_rel_tol=float(cost_rel_tol),
line_search_max_iterations=int(line_search_max_iterations),
line_search_variant=line_search_variant,
armijo_c=float(armijo_c),
rho=float(rho),
line_search_relative_slop=float(line_search_relative_slop),
cost_min_alpha=float(cost_min_alpha),
collect_iteration_stats=bool(collect_iteration_stats),
check_line_search_errors=bool(check_line_search_errors),
loop_counters_initialized=bool(loop_counters_initialized),
v_flat_seeded=bool(v_flat_seeded),
)
def _run_sap_exact(
    self,
    contact_result: SapContactJacobianResult,
    dt: float,
    *,
    max_iterations: int,
    cost_abs_tol: float,
    cost_rel_tol: float,
    check_errors: bool = True,
) -> None:
    """Run the "exact root" line search along ``self.dv`` and commit the step.

    Does nothing when ``int(max_iterations) <= 0``. Otherwise:
      1. Zero the directional-derivative buffers, launch
         ``_compute_search_direction_data_serial_batched`` and compute the base
         cost into ``self.line_momentum_cost``.
      2. Initialise at ``_SAP_EXACT_LINE_SEARCH_ALPHA_MAX`` and evaluate the
         trial point with the contact Hessian and both trial derivatives.
         Then launch ``_init_sap_exact_root_state``.
      3. Loop ``_run_sap_exact_root_body`` via ``wp.capture_while`` on
         ``self.ls_active_count``.
      4. Copy ``self.ls_accepted`` into ``self.ls_active``, re-evaluate the
         trial cost at the final ``alpha`` and store it as the accepted cost.
      5. Accumulate the iteration counts and commit the step. Errors are
         raised only afterwards, and only if ``check_errors`` is truthy.
    """
    if int(max_iterations) <= 0:
        return

    device = self.device
    env_grid = self.num_envs
    env_dof_grid = (self.num_envs, self.dof_per_env)
    dynamics = contact_result.dynamics_matrix_env

    for derivative_buffer in (self.dell0, self.dell_a0, self.d2ell_a):
        derivative_buffer.zero_()

    direction_args = [
        self.newton_active,
        self.dof_per_env,
        dynamics,
        self.v_env,
        self.v_star_env,
        self.grad,
        self.dv,
        self.dp,
        self.dell0,
        self.dell_a0,
        self.d2ell_a,
    ]
    wp.launch(
        self.k._compute_search_direction_data_serial_batched,
        dim=env_grid,
        inputs=direction_args,
        device=device,
    )
    self._compute_base_cost(self.newton_active, self.v_env, dynamics, self.line_momentum_cost)

    alpha_max = float(_SAP_EXACT_LINE_SEARCH_ALPHA_MAX)
    alpha_init_args = [
        self.newton_active,
        self.cost,
        self.dell0,
        float(alpha_max),
        self.alpha,
        self.ls_active,
        self.ls_accepted,
        self.ls_status,
        self.ls_iterations,
        self.accepted_cost,
    ]
    wp.launch(
        self.k._init_sap_exact_alpha_max_state,
        dim=env_grid,
        inputs=alpha_init_args,
        device=device,
    )

    # Trial evaluation at the initial alpha.
    wp.launch(
        self.k._axpy_to_trial_batched,
        dim=env_dof_grid,
        inputs=[self.ls_active, self.dof_per_env, self.v_env, self.dv, self.alpha, self.v_trial],
        device=device,
    )
    self._evaluate_trial(contact_result, dt, include_contact_hessian=True)
    self._replace_trial_cost_with_sap_line_search_cost(contact_result)
    self._compute_trial_derivative()
    self._compute_trial_second_derivative(contact_result)

    self.ls_active_count.zero_()
    root_init_args = [
        self.trial_cost,
        self.trial_derivative,
        self.trial_second_derivative,
        self.cost,
        self.dell0,
        float(alpha_max),
        float(cost_abs_tol),
        float(cost_rel_tol),
        float(_SAP_EXACT_LINE_SEARCH_F_TOLERANCE),
        self.alpha,
        self.ls_active,
        self.ls_accepted,
        self.ls_status,
        self.accepted_cost,
        self.exact_scale,
        self.exact_x_lower,
        self.exact_x_upper,
        self.exact_f_lower,
        self.exact_f_upper,
        self.exact_root,
        self.exact_minus_dx,
        self.exact_minus_dx_previous,
        self.exact_x_tolerance,
        self.ls_active_count,
    ]
    wp.launch(
        self.k._init_sap_exact_root_state,
        dim=env_grid,
        inputs=root_init_args,
        device=device,
    )

    self.ls_loop_iteration.fill_(1)
    wp.capture_while(
        self.ls_active_count,
        while_body=self._run_sap_exact_root_body,
        contact_result=contact_result,
        dt=float(dt),
        max_iterations=int(max_iterations),
    )

    # Re-evaluate the cost at the final alpha for the accepted environments.
    wp.launch(
        self.k._copy_i32_batched,
        dim=env_grid,
        inputs=[self.ls_accepted, self.ls_active],
        device=device,
    )
    wp.launch(
        self.k._axpy_to_trial_batched,
        dim=env_dof_grid,
        inputs=[self.ls_active, self.dof_per_env, self.v_env, self.dv, self.alpha, self.v_trial],
        device=device,
    )
    self._evaluate_trial_cost_only(contact_result, dt)
    self._replace_trial_cost_with_sap_line_search_cost(contact_result)
    wp.launch(
        self.k._store_exact_accepted_cost,
        dim=env_grid,
        inputs=[self.ls_accepted, self.trial_cost, self.accepted_cost],
        device=device,
    )

    wp.launch(
        self.k._accumulate_line_search_iterations_batched,
        dim=env_grid,
        inputs=[self.ls_accepted, self.ls_iterations, self.ls_iterations_total],
        device=device,
    )
    commit_args = [
        self.newton_active,
        self.ls_accepted,
        self.dof_per_env,
        self.alpha,
        self.v_env,
        self.dv,
        self.v_flat,
        self.cost,
        self.previous_cost,
        self.accepted_cost,
    ]
    wp.launch(
        self.k._commit_line_search_step_batched,
        dim=env_dof_grid,
        inputs=commit_args,
        device=device,
    )

    if check_errors:
        self._raise_line_search_errors_if_any(stage="exact")
def _run_sap_exact_root_body(
    self,
    *,
    contact_result: SapContactJacobianResult,
    dt: float,
    max_iterations: int,
) -> None:
    """Run one iteration of the exact-root loop driven by ``wp.capture_while``.

    Rebuilds the trial point at the current ``self.alpha`` and evaluates it
    with the contact Hessian and both trial derivatives. It then zeroes
    ``self.ls_active_count``, launches ``_update_sap_exact_root_state`` (whose
    last argument is that counter) and increments ``self.ls_loop_iteration``.
    """
    device = self.device

    trial_args = [self.ls_active, self.dof_per_env, self.v_env, self.dv, self.alpha, self.v_trial]
    wp.launch(
        self.k._axpy_to_trial_batched,
        dim=(self.num_envs, self.dof_per_env),
        inputs=trial_args,
        device=device,
    )
    self._evaluate_trial(contact_result, dt, include_contact_hessian=True)
    self._replace_trial_cost_with_sap_line_search_cost(contact_result)
    self._compute_trial_derivative()
    self._compute_trial_second_derivative(contact_result)

    # The update kernel below is handed this counter; it is also the
    # capture_while condition.
    self.ls_active_count.zero_()
    update_args = [
        self.trial_derivative,
        self.trial_second_derivative,
        float(_SAP_EXACT_LINE_SEARCH_F_TOLERANCE),
        self.ls_loop_iteration,
        int(max_iterations),
        self.alpha,
        self.ls_active,
        self.ls_accepted,
        self.ls_status,
        self.ls_iterations,
        self.exact_scale,
        self.exact_x_lower,
        self.exact_x_upper,
        self.exact_f_lower,
        self.exact_f_upper,
        self.exact_root,
        self.exact_minus_dx,
        self.exact_minus_dx_previous,
        self.exact_x_tolerance,
        self.ls_active_count,
    ]
    wp.launch(
        self.k._update_sap_exact_root_state,
        dim=self.num_envs,
        inputs=update_args,
        device=device,
    )
    wp.launch(
        self.k._increment_scalar_i32,
        dim=1,
        inputs=[self.ls_loop_iteration],
        device=device,
    )
def _evaluate_trial(
    self,
    contact_result: SapContactJacobianResult,
    dt: float,
    *,
    include_contact_hessian: bool = False,
) -> None:
    """Evaluate the problem at ``v_trial`` and store the results in the ``trial_*`` buffers.

    This is a thin forwarder to ``_evaluate_problem``: the velocity argument is
    ``self.v_trial``, the activity mask is ``self.ls_active``, and every output
    keyword is bound to the matching ``trial_*`` attribute.

    Args:
        contact_result: Contact Jacobian data, passed through unchanged.
        dt: Time step, passed through unchanged.
        include_contact_hessian: Coerced to ``bool`` and forwarded as
            ``include_hessian``.
    """
    trial_outputs = {
        "cost": self.trial_cost,
        "constraint_impulse": self.trial_constraint_impulse,
        # Contact terms.
        "contact_gamma": self.trial_contact_gamma,
        "contact_g": self.trial_contact_g,
        "contact_vc": self.trial_contact_vc,
        "contact_y": self.trial_contact_y,
        "contact_rt": self.trial_contact_rt,
        "contact_rn": self.trial_contact_rn,
        "contact_cost": self.trial_contact_cost,
        "contact_mode": self.trial_contact_mode,
        # PD terms.
        "pd_y": self.trial_pd_y,
        "pd_gamma": self.trial_pd_gamma,
        "pd_hdiag": self.trial_pd_hdiag,
        "pd_cost": self.trial_pd_cost,
        # Limit terms.
        "lower_gamma": self.trial_limit_lower_gamma,
        "upper_gamma": self.trial_limit_upper_gamma,
        "limit_grad": self.trial_limit_grad,
        "limit_hdiag": self.trial_limit_hdiag,
        "limit_cost": self.trial_limit_cost,
    }
    want_hessian = bool(include_contact_hessian)
    self._evaluate_problem(
        contact_result,
        self.v_trial,
        self.ls_active,
        dt,
        include_hessian=want_hessian,
        **trial_outputs,
    )
def _evaluate_trial_cost_only(self, contact_result: SapContactJacobianResult, dt: float) -> None:
    """Re-evaluate the cost at ``v_trial`` without a contact Hessian pass.

    ``trial_cost`` is zeroed, handed to ``_compute_base_cost``, and then passed
    as the final input to the contact, PD and limit kernels in turn
    (presumably each accumulates into it -- confirm against the kernels).

    Args:
        contact_result: Contact Jacobian data. If it has no
            ``contact_env_tau_d`` attribute (or the attribute is ``None``),
            ``contact_tau_d_fallback`` is filled with ``self.contact_tau_d``
            and used instead.
        dt: Time step; converted with ``float`` before being given to kernels.
    """
    self.trial_cost.zero_()
    capacity = self._contact_capacity(contact_result)

    # Per-environment tau_d is optional on the contact result.
    tau_d = getattr(contact_result, "contact_env_tau_d", None)
    if tau_d is None:
        self.contact_tau_d_fallback.fill_(self.contact_tau_d)
        tau_d = self.contact_tau_d_fallback

    self._compute_base_cost(
        self.ls_active,
        self.v_trial,
        contact_result.dynamics_matrix_env,
        self.trial_cost,
    )

    kernels = self.k
    device = self.device
    per_dof_grid = (self.num_envs, self.dof_per_env)

    contact_inputs = [
        self.ls_active,
        self.dof_per_env,
        capacity,
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        contact_result.contact_env_phi0,
        contact_result.contact_env_w_eff,
        contact_result.contact_env_mu,
        contact_result.contact_env_stiffness,
        tau_d,
        self.v_trial,
        self.dv,
        self.contact_beta,
        self.contact_sigma,
        float(dt),
        self.trial_contact_cost,
        self.trial_cost,
    ]
    wp.launch(
        kernels._projection_cost_only_contact_sap_batched,
        dim=(self.num_envs, capacity),
        inputs=contact_inputs,
        device=device,
    )

    pd_inputs = [
        int(self._has_pd_terms),
        self.ls_active,
        self.dof_per_env,
        self.pd_active,
        self.pd_a,
        self.pd_gain,
        self.pd_limit,
        self.v_trial,
        float(dt),
        self.trial_pd_y,
        self.trial_pd_gamma,
        self.trial_pd_hdiag,
        self.trial_pd_cost,
        self.trial_cost,
    ]
    wp.launch(
        kernels._eval_pd_terms_sap_batched,
        dim=per_dof_grid,
        inputs=pd_inputs,
        device=device,
    )

    limit_inputs = [
        int(self._has_limit_terms),
        self.ls_active,
        self.dof_per_env,
        self.limit_lower_active,
        self.limit_upper_active,
        self.limit_lower_vhat,
        self.limit_upper_vhat,
        self.limit_lower_r,
        self.limit_upper_r,
        self.limit_lower_rinv,
        self.limit_upper_rinv,
        self.v_trial,
        self.trial_limit_lower_gamma,
        self.trial_limit_upper_gamma,
        self.trial_limit_grad,
        self.trial_limit_hdiag,
        self.trial_limit_cost,
        self.trial_cost,
    ]
    wp.launch(
        kernels._eval_limit_terms_sap_batched,
        dim=per_dof_grid,
        inputs=limit_inputs,
        device=device,
    )
def _replace_trial_cost_with_sap_line_search_cost(self, contact_result: SapContactJacobianResult) -> None:
    """Launch the kernel that rewrites ``trial_cost`` with the SAP line-search cost.

    One thread per environment. The kernel receives the per-term trial costs
    (contact, PD, limit), the ``line_momentum_cost`` / ``dell_a0`` / ``d2ell_a``
    buffers, the current ``alpha`` and, last, ``trial_cost``.

    Args:
        contact_result: Supplies ``contact_env_count`` and, through
            ``_contact_capacity``, the contact capacity.
    """
    contact_terms = [
        self._contact_capacity(contact_result),
        contact_result.contact_env_count,
        self.trial_contact_cost,
    ]
    constraint_terms = [
        int(self._has_pd_terms),
        int(self._has_limit_terms),
        self.trial_pd_cost,
        self.trial_limit_cost,
    ]
    line_terms = [
        self.line_momentum_cost,
        self.dell_a0,
        self.d2ell_a,
        self.alpha,
    ]
    wp.launch(
        self.k._replace_trial_cost_with_sap_line_search_cost_batched,
        dim=self.num_envs,
        inputs=[
            self.ls_active,
            self.dof_per_env,
            *contact_terms,
            *constraint_terms,
            *line_terms,
            self.trial_cost,
        ],
        device=self.device,
    )
def _compute_trial_derivative(self) -> None:
    """Zero ``trial_derivative`` and launch the line-derivative kernel.

    The kernel runs with one thread per environment and receives ``v_trial``,
    ``v_star_env``, ``dv``, ``dp`` and ``trial_constraint_impulse`` as inputs,
    with ``trial_derivative`` as the final argument.
    """
    result = self.trial_derivative
    result.zero_()

    kernel_inputs = [
        self.ls_active,
        self.dof_per_env,
        self.v_trial,
        self.v_star_env,
        self.dv,
        self.dp,
        self.trial_constraint_impulse,
        result,
    ]
    wp.launch(
        self.k._compute_line_derivative_serial_batched,
        dim=self.num_envs,
        inputs=kernel_inputs,
        device=self.device,
    )
def _compute_trial_second_derivative(self, contact_result: SapContactJacobianResult) -> None:
    """Zero ``trial_second_derivative`` and launch the second-derivative kernel.

    The kernel runs with one thread per environment. Its inputs are the contact
    Jacobian and ``trial_contact_g``, the PD and limit Hessian diagonals with
    their enable flags, ``dv``, ``d2ell_a`` and finally
    ``trial_second_derivative``.

    Args:
        contact_result: Supplies ``contact_env_count``,
            ``contact_env_jacobian`` and, through ``_contact_capacity``,
            the contact capacity.
    """
    result = self.trial_second_derivative
    result.zero_()

    contact_inputs = [
        self._contact_capacity(contact_result),
        contact_result.contact_env_count,
        contact_result.contact_env_jacobian,
        self.trial_contact_g,
    ]
    diagonal_inputs = [
        int(self._has_pd_terms),
        int(self._has_limit_terms),
        self.trial_pd_hdiag,
        self.trial_limit_hdiag,
    ]
    wp.launch(
        self.k._compute_line_second_derivative_serial_batched,
        dim=self.num_envs,
        inputs=[
            self.ls_active,
            self.dof_per_env,
            *contact_inputs,
            *diagonal_inputs,
            self.dv,
            self.d2ell_a,
            result,
        ],
        device=self.device,
    )
def _raise_line_search_errors_if_any(self, *, stage: str) -> None:
    """Raise if any environment reports a negative line-search status.

    ``ls_status`` is read back through ``.numpy()``. When no entry is negative
    the method returns without touching the other buffers. Otherwise the first
    offending environment is reported together with its ``dell0``, ``cost``
    and ``alpha`` values.

    Args:
        stage: Label embedded in the error message to identify the caller.

    Raises:
        RuntimeError: If at least one ``ls_status`` entry is below zero.
    """
    status = self.ls_status.numpy()
    failing = np.nonzero(status < 0)[0]
    if failing.size > 0:
        # Only the first failing environment is reported.
        env = int(failing[0])
        code = int(status[env])
        # These buffers are read back only on the failure path.
        dell0, cost, alpha = (
            float(buffer.numpy()[env]) for buffer in (self.dell0, self.cost, self.alpha)
        )
        raise RuntimeError(
            "SapContactSolve line search failed "
            f"(stage={stage}, env={env}, status={code}, dell0={dell0:.6e}, "
            f"cost={cost:.6e}, alpha={alpha:.6e})."
        )
[docs]
def solve(
    self,
    contact_result: SapContactJacobianResult,
    state: State,
    control: Control | None,
    dt: float,
    v_star: wp.array,
    *,
    v0: wp.array | None = None,
    v_guess: wp.array | None = None,
    v_guess_active: wp.array | None = None,
    max_iterations: int = 100,
    optimality_abs_tol: float = 1.0e-14,
    optimality_rel_tol: float = 1.0e-6,
    cost_abs_tol: float | None = None,
    cost_rel_tol: float | None = None,
    line_search_max_iterations: int = 40,
    armijo_c: float = 1.0e-4,
    rho: float = 0.8,
    line_search_relative_slop: float | None = None,
    line_search_variant: str = "monotone_decay",
    collect_iteration_stats: bool = True,
    check_line_search_errors: bool = True,
    graph_conditional: bool = True,
) -> SapContactSolveResult:
    """Solve the SAP velocity objective for the active contacts and write the next generalized velocity
    in SAP order.

    The method resolves defaults, validates options, calls ``prepare``,
    launches ``_initialize_and_mark_unconstrained_free_envs_batched`` and then
    runs the Newton loop through ``wp.capture_if`` conditioned on
    ``stage2_active_count``.

    Args:
        contact_result: Contact Jacobian data; passed through
            ``_cast_contact_result`` before use.
        state: Forwarded to ``prepare``.
        control: Forwarded to ``prepare``; may be ``None``.
        dt: Time step; forwarded to ``prepare`` and, as ``float``, to the loop.
        v_star: Forwarded to ``prepare``.
        v0: Optional, forwarded to ``prepare``.
        v_guess: Optional, forwarded to ``prepare``. The loop is told the flat
            velocity is already seeded when this is the very object
            ``self.v_flat``.
        v_guess_active: Optional, forwarded to ``prepare``. When given, it is
            set to ``1`` by a ``_set_scalar_i32`` launch after the loop.
        max_iterations: Newton iteration bound forwarded to the loop.
        optimality_abs_tol: Forwarded to the loop.
        optimality_rel_tol: Forwarded to the loop.
        cost_abs_tol: Defaults to ``0.0`` for ``monotone_decay`` and
            ``1.0e-30`` otherwise.
        cost_rel_tol: Defaults to ``5.0e-3`` for ``monotone_decay`` and
            ``1.0e-15`` otherwise.
        line_search_max_iterations: Forwarded to the loop. With ``exact_root``
            a value equal to ``40`` is replaced by
            ``_SAP_EXACT_LINE_SEARCH_MAX_ITERATIONS``.
        armijo_c: Forwarded to the loop.
        rho: Forwarded to the loop; range-checked only for ``armijo_decay``.
        line_search_relative_slop: Defaults to ``1000`` times the machine
            epsilon of ``self.numpy_dtype``.
        line_search_variant: Normalized by ``normalize_sap_line_search_mode``.
        collect_iteration_stats: Forwarded to the loop.
        check_line_search_errors: Forwarded to the loop.
        graph_conditional: Must be truthy.

    Returns:
        The value of ``_make_result`` with ``iterations`` and
        ``line_search_iterations`` both ``-1`` and ``converged`` set to
        ``True``; these host-side scalars are placeholders in this code path.

    Raises:
        ValueError: If ``line_search_variant`` is not recognized, if
            ``graph_conditional`` is falsy, or if ``rho`` is outside ``(0, 1)``
            while ``armijo_decay`` is selected.
    """
    if line_search_relative_slop is None:
        line_search_relative_slop = 1000.0 * np.finfo(self.numpy_dtype).eps

    contact_result = self._cast_contact_result(contact_result)
    line_search_variant = normalize_sap_line_search_mode(line_search_variant)

    if not bool(graph_conditional):
        raise ValueError(
            "SapContactSolve only supports graph_conditional=True; "
            "the Python Newton-loop fallback has been removed."
        )

    if line_search_variant == "armijo_decay":
        # Written as a negated chained comparison so that NaN is rejected too.
        if not (0.0 < float(rho) < 1.0):
            raise ValueError(f"rho must lie in (0, 1), got {rho!r}.")

    if line_search_variant == "exact_root":
        # NOTE(review): 40 doubles as the signature default, so an explicit 40
        # cannot be told apart from "not specified" here.
        if int(line_search_max_iterations) == 40:
            line_search_max_iterations = int(_SAP_EXACT_LINE_SEARCH_MAX_ITERATIONS)

    monotone = line_search_variant == "monotone_decay"
    if cost_abs_tol is None:
        if monotone:
            cost_abs_tol = 0.0
        else:
            cost_abs_tol = 1.0e-30
    if cost_rel_tol is None:
        if monotone:
            cost_rel_tol = 5.0e-3
        else:
            cost_rel_tol = 1.0e-15
    if monotone:
        cost_min_alpha = 0.0
    else:
        cost_min_alpha = 0.5

    self.prepare(
        contact_result,
        state,
        control,
        dt,
        v_star,
        v0=v0,
        v_guess=v_guess,
        v_guess_active=v_guess_active,
    )

    self.last_iterations = 0
    self.last_line_search_iterations = 0
    self.stage2_active_count.zero_()

    init_inputs = [
        self.dof_per_env,
        contact_result.contact_env_count,
        self.pd_active,
        self.limit_lower_active,
        self.limit_upper_active,
        self.participating_dof,
        self.v_star_env,
        self.v_env,
        self.v_flat,
        self.first_dv,
        self.newton_iterations_env,
        self.ls_iterations_total,
        self.alpha,
        self.previous_cost,
        self.converged_env,
        self.optimality_reached_env,
        self.cost_reached_env,
        self.stage2_active_env,
        self.newton_active,
        self.stage2_active_count,
    ]
    wp.launch(
        self.k._initialize_and_mark_unconstrained_free_envs_batched,
        dim=self.num_envs,
        inputs=init_inputs,
        device=self.device,
    )

    # Every option is coerced to a plain Python scalar before it reaches the loop.
    loop_options = {
        "contact_result": contact_result,
        "dt": float(dt),
        "max_iterations": int(max_iterations),
        "optimality_abs_tol": float(optimality_abs_tol),
        "optimality_rel_tol": float(optimality_rel_tol),
        "cost_abs_tol": float(cost_abs_tol),
        "cost_rel_tol": float(cost_rel_tol),
        "line_search_max_iterations": int(line_search_max_iterations),
        "line_search_variant": line_search_variant,
        "armijo_c": float(armijo_c),
        "rho": float(rho),
        "line_search_relative_slop": float(line_search_relative_slop),
        "cost_min_alpha": float(cost_min_alpha),
        "collect_iteration_stats": bool(collect_iteration_stats),
        "check_line_search_errors": bool(check_line_search_errors),
        "loop_counters_initialized": False,
        "v_flat_seeded": (v_guess is self.v_flat),
    }
    wp.capture_if(
        self.stage2_active_count,
        on_true=self._run_unit_conditional_newton_loop_capture,
        **loop_options,
    )

    if v_guess_active is not None:
        wp.launch(
            self.k._set_scalar_i32,
            dim=1,
            inputs=[v_guess_active, 1],
            device=self.device,
        )

    # Host-side counters are reported as -1 from this path.
    self.last_iterations = -1
    self.last_line_search_iterations = -1
    return self._make_result(self.last_iterations, self.last_line_search_iterations, True)
def _make_result(self, iterations: int, line_search_iterations: int, converged: bool) -> SapContactSolveResult:
    """Bundle the solver's buffers and the given scalars into a ``SapContactSolveResult``.

    Each result field is bound to the attribute of ``self`` with the same name.
    The one exception is ``line_search_iterations_env``, which is taken from
    ``self.ls_iterations_total``.

    Args:
        iterations: Stored after conversion with ``int``.
        line_search_iterations: Stored after conversion with ``int``.
        converged: Stored after conversion with ``bool``.

    Returns:
        A new ``SapContactSolveResult`` referencing the solver's buffers.
    """
    # Result fields whose name matches the attribute on ``self``.
    same_named_fields = (
        "v_env",
        "v_flat",
        "cost",
        "previous_cost",
        "grad",
        "hessian",
        "constraint_impulse",
        "dynamics_impulse",
        "contact_gamma",
        "contact_g",
        "contact_vc",
        "contact_y",
        "contact_rt",
        "contact_rn",
        "contact_cost",
        "contact_mode",
        "pd_active",
        "pd_y",
        "pd_gamma",
        "pd_hdiag",
        "pd_cost",
        "pd_kp_eff",
        "pd_kd_eff",
        "limit_lower_active",
        "limit_upper_active",
        "limit_lower_gamma",
        "limit_upper_gamma",
        "limit_grad",
        "limit_hdiag",
        "limit_cost",
        "first_dv",
        "alpha",
        "newton_iterations_env",
        "newton_active",
        "converged_env",
        "optimality_reached_env",
        "cost_reached_env",
    )
    fields = {name: getattr(self, name) for name in same_named_fields}
    fields["line_search_iterations_env"] = self.ls_iterations_total
    fields["iterations"] = int(iterations)
    fields["line_search_iterations"] = int(line_search_iterations)
    fields["converged"] = bool(converged)
    return SapContactSolveResult(**fields)
# Names exported by ``from <module> import *``. The ``_CONTACT_MODE_*``
# constants are listed explicitly because a star-import skips
# underscore-prefixed names unless ``__all__`` names them.
__all__ = [
    "SapContactSolve",
    "SapContactSolveResult",
    "normalize_sap_line_search_mode",
    "_CONTACT_MODE_NONE",
    "_CONTACT_MODE_STICTION",
    "_CONTACT_MODE_SLIDING",
    "_CONTACT_MODE_FRICTIONLESS",
]