pub fn tropical_backward_b<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T>
Computes the gradient with respect to matrix B in a tropical matrix multiplication.
Given a forward pass C = A ⊗ B that tracked argmax indices, and the upstream gradient dL/dC, this function computes dL/dB.
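To make the argmax tracking concrete, here is a minimal sketch of a max-plus forward pass. It assumes row-major storage (A is m×k, B is k×n, C and argmax are m×n) and that ties keep the smallest inner index. The function name maxplus_matmul_with_argmax is hypothetical and is not part of tropical_gemm; it only illustrates what an argmax entry records: the inner index p at which A[i,p] + B[p,j] attains the maximum.

```rust
/// Hypothetical sketch (not the crate's implementation): max-plus matmul
/// C = A ⊗ B that records, for each C[i,j], the inner index p that won the max.
/// Assumes row-major storage: A is m×k, B is k×n, C and argmax are m×n.
/// Ties keep the smallest p; this is an assumption and the crate may differ.
fn maxplus_matmul_with_argmax(
    a: &[f64],
    m: usize,
    k: usize,
    b: &[f64],
    n: usize,
) -> (Vec<f64>, Vec<u32>) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    // NEG_INFINITY is the tropical zero (identity of max) for max-plus.
    let mut c = vec![f64::NEG_INFINITY; m * n];
    let mut argmax = vec![0u32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                // Tropical product is ordinary +, tropical sum is max.
                let v = a[i * k + p] + b[p * n + j];
                // Strict > keeps the first (smallest) p on ties.
                if v > c[i * n + j] {
                    c[i * n + j] = v;
                    argmax[i * n + j] = p as u32;
                }
            }
        }
    }
    (c, argmax)
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
    let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
    let (c, argmax) = maxplus_matmul_with_argmax(&a, 2, 3, &b, 2);
    // Prints: C = [8.0, 9.0, 11.0, 12.0], argmax = [2, 2, 2, 2]
    println!("C = {:?}, argmax = {:?}", c, argmax);
}
```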
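In a tropical matmul each C[i,j] is selected from a single term A[i,p] + B[p,j] (for max-plus, C[i,j] = max_p (A[i,p] + B[p,j])), so C[i,j] depends on B only through B[argmax[i,j], j], with unit derivative. The gradient is therefore routed as follows, where p indexes the rows of B (0 ≤ p < k):
dL/dB[p,j] = Σ_i { dL/dC[i,j] if argmax[i,j] == p, else 0 }
§Arguments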
- grad_c: Upstream gradient dL/dC, size m×n
- argmax: Argmax indices from the forward pass, size m×n
- m: Number of rows in C (used for iteration)
- k: Number of rows in B (the inner dimension, equal to the number of columns in A)
- n: Number of columns in B (and in C)
§Returns
Gradient dL/dB of size k×n
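The routing rule above amounts to a single scatter-add over the entries of C. The following is a minimal illustration, not the crate's implementation: backward_b_sketch is a hypothetical name, and it assumes row-major storage, argmax values in 0..k, and that T::default() is the additive zero (true for the primitive numeric types).

```rust
use std::ops::AddAssign;

/// Hypothetical sketch of the dL/dB routing rule, not the crate's code.
/// Assumes row-major storage: grad_c and argmax are m×n, the result is k×n,
/// and every argmax entry is a valid row index of B (0..k).
fn backward_b_sketch<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n);
    assert_eq!(argmax.len(), m * n);
    // T::default() is assumed to be the additive zero (true for f32, f64, ints).
    let mut grad_b = vec![T::default(); k * n];
    for i in 0..m {
        for j in 0..n {
            // Only the row of B that won the max for C[i,j] receives its gradient.
            let p = argmax[i * n + j] as usize;
            assert!(p < k, "argmax index out of range");
            // Accumulate: several C[i,j] in column j may share the same winning row.
            grad_b[p * n + j] += grad_c[i * n + j];
        }
    }
    grad_b
}

fn main() {
    // Assumed argmax for the 2x3 by 3x2 max-plus example: row 2 of B wins everywhere.
    let grad_c = [1.0f64; 4];
    let argmax = [2u32; 4];
    let grad_b = backward_b_sketch(&grad_c, &argmax, 2, 3, 2);
    // Prints: [0.0, 0.0, 0.0, 0.0, 2.0, 2.0]
    println!("{:?}", grad_b);
}
```

Because several entries in the same column of C can share a winning row, the update has to accumulate with += rather than overwrite, which is consistent with the AddAssign bound in the signature.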
§Example
use tropical_gemm::{tropical_matmul_with_argmax, tropical_backward_b, TropicalMaxPlus};
let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
// Forward pass: C = A ⊗ B with m = 2, k = 3, n = 2, recording argmax
let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);
// Upstream gradient
let grad_c = [1.0f64; 4]; // 2x2
// Backward pass for B (m = 2, k = 3, n = 2)
let grad_b = tropical_backward_b::<f64>(&grad_c, result.argmax_slice(), 2, 3, 2);
assert_eq!(grad_b.len(), 6); // 3x2
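As a hand check of the example, assuming max-plus semantics (⊗ is ordinary +, ⊕ is max), row-major storage, and 0-based indices, the third term wins in every entry of C. So argmax is 2 everywhere, and each column of grad_c (two ones) sums into the last row of B:

```latex
\begin{aligned}
C_{00} &= \max(1+1,\ 2+3,\ 3+5) = 8, & C_{01} &= \max(1+2,\ 2+4,\ 3+6) = 9,\\
C_{10} &= \max(4+1,\ 5+3,\ 6+5) = 11, & C_{11} &= \max(4+2,\ 5+4,\ 6+6) = 12,\\
\frac{\partial L}{\partial B_{2j}} &= \sum_{i} \frac{\partial L}{\partial C_{ij}} = 1 + 1 = 2, & \frac{\partial L}{\partial B_{0j}} &= \frac{\partial L}{\partial B_{1j}} = 0.
\end{aligned}
```

Under those assumptions, grad_b would equal [0.0, 0.0, 0.0, 0.0, 2.0, 2.0], a stronger check than the length assertion alone.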