pub fn tropical_backward_a_batched<T: Copy + Default + AddAssign + Send + Sync>(
grad_c_batch: &[Vec<T>],
argmax_batch: &[Vec<u32>],
m: usize,
k: usize,
n: usize,
) -> Vec<Vec<T>>
Batched backward pass computing the gradient of the loss with respect to A.
Computes dL/dA[i] for each batch element.
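The description does not spell out how the gradient is routed. Assuming a max-plus forward pass, C[i][j] = max over kk of (A[i][kk] + B[kk][j]), the derivative of C[i][j] with respect to A[i][kk] is 1 for the winning index and 0 otherwise. Each upstream gradient is therefore added to exactly one entry of dL/dA. The sketch below shows this for a single batch element. It assumes row-major flat buffers and that `argmax[i * n + j]` holds the winning `kk`. The name `tropical_backward_a_single` is hypothetical and is not part of this crate's documented API.

```rust
use std::ops::AddAssign;

/// Hypothetical single-element backward pass w.r.t. A (not this crate's API).
/// Assumes row-major storage: grad_c and argmax are m×n, the result is m×k,
/// and argmax[i * n + j] is the inner index that won the forward max for C[i][j].
fn tropical_backward_a_single<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "grad_c must be m×n");
    assert_eq!(argmax.len(), m * n, "argmax must be m×n");
    // Every entry starts at T::default() (zero for numeric types), so entries
    // of A that never won a max receive no gradient.
    let mut grad_a = vec![T::default(); m * k];
    for i in 0..m {
        for j in 0..n {
            let kk = argmax[i * n + j] as usize;
            assert!(kk < k, "argmax index out of range");
            // Route dL/dC[i][j] to the single A entry that produced C[i][j].
            grad_a[i * k + kk] += grad_c[i * n + j];
        }
    }
    grad_a
}
```

Several outputs in the same row can select the same `kk`, so the sketch accumulates with `+=` and does not overwrite.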
§Arguments
- grad_c_batch - Batch of upstream gradients, each of size m×n
- argmax_batch - Batch of argmax indices from the forward pass
- m - Number of rows in A
- k - Number of columns in A
- n - Number of columns in C
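The documentation says only that `argmax_batch` comes from the forward pass. The sketch below shows one plausible way such indices could be recorded. It computes a max-plus product for one batch element and stores, for every C[i][j], the inner index that won the max. Under this assumption, each index lies in `0..k` and each argmax buffer has the same m×n shape as the upstream gradient. That shape matches the dimensions listed above: A is m×k, B is k×n, and C is m×n. The function name and the choice of max-plus (rather than min-plus) are hypothetical.

```rust
/// Hypothetical max-plus forward pass for one batch element:
/// C[i][j] = max over kk of (A[i][kk] + B[kk][j]), with A row-major m×k,
/// B row-major k×n and C row-major m×n. Returns (C, argmax), where
/// argmax[i * n + j] is the kk that attained the maximum.
fn tropical_matmul_maxplus_with_argmax(
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> (Vec<f32>, Vec<u32>) {
    assert_eq!(a.len(), m * k, "A must be m×k");
    assert_eq!(b.len(), k * n, "B must be k×n");
    assert!(k > 0, "inner dimension must be non-empty");
    // NEG_INFINITY is the additive identity of the max-plus semiring.
    let mut c = vec![f32::NEG_INFINITY; m * n];
    let mut argmax = vec![0u32; m * n];
    for i in 0..m {
        for j in 0..n {
            for kk in 0..k {
                let v = a[i * k + kk] + b[kk * n + j];
                // Strict `>` keeps the first maximizer when values tie.
                if v > c[i * n + j] {
                    c[i * n + j] = v;
                    argmax[i * n + j] = kk as u32;
                }
            }
        }
    }
    (c, argmax)
}
```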
§Returns
Vector of gradients dL/dA[i], each of size m×k
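Each batch element's gradient depends only on its own `grad_c` and `argmax` buffers, so the batch can be processed independently. This is presumably why `T` carries `Send + Sync` bounds. The sketch below mirrors the documented signature under the same assumptions: row-major buffers, and argmax entries that hold the winning inner index. It parallelizes with `std::thread::scope`. The name `backward_a_batched_sketch` is hypothetical. The crate's actual implementation may use a thread pool such as rayon, or no parallelism at all.

```rust
use std::ops::AddAssign;
use std::thread;

/// Hypothetical sketch mirroring the documented signature (not the crate's code).
/// Assumes each grad_c and argmax buffer is row-major m×n and that argmax
/// entries hold the inner index that won the forward max.
fn backward_a_batched_sketch<T: Copy + Default + AddAssign + Send + Sync>(
    grad_c_batch: &[Vec<T>],
    argmax_batch: &[Vec<u32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>> {
    assert_eq!(grad_c_batch.len(), argmax_batch.len(), "batch sizes differ");
    thread::scope(|s| {
        // Spawn one scoped thread per batch element; each reads only its own inputs.
        let handles: Vec<_> = grad_c_batch
            .iter()
            .zip(argmax_batch)
            .map(|(grad_c, argmax)| {
                s.spawn(move || {
                    assert_eq!(grad_c.len(), m * n, "grad_c must be m×n");
                    assert_eq!(argmax.len(), m * n, "argmax must be m×n");
                    let mut grad_a = vec![T::default(); m * k];
                    for i in 0..m {
                        for j in 0..n {
                            let kk = argmax[i * n + j] as usize;
                            assert!(kk < k, "argmax index out of range");
                            // Scatter the upstream gradient to the winning A entry.
                            grad_a[i * k + kk] += grad_c[i * n + j];
                        }
                    }
                    grad_a
                })
            })
            .collect();
        // Join in spawn order so output[b] corresponds to input[b].
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}
```

Spawning one OS thread per batch element is reasonable only for small batches. A bounded pool, or chunking the batch across a fixed number of threads, avoids oversubscription when the batch is large.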