tropical_backward_a_batched

Function tropical_backward_a_batched 

pub fn tropical_backward_a_batched<T: Copy + Default + AddAssign + Send + Sync>(
    grad_c_batch: &[Vec<T>],
    argmax_batch: &[Vec<u32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>>

Batched backward pass for the gradient with respect to A.

Computes dL/dA[i] for each batch element i.
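As a minimal sketch of what computing dL/dA[i] usually means for a single max-plus (tropical) product, assume C = A ⊗ B with C[i,j] = max over kk of (A[i,kk] + B[kk,j]), row-major flat storage, and that the argmax buffer holds the winning kk for each (i, j). Under those assumptions the upstream gradient of each C[i,j] flows only to the one entry A[i, argmax[i,j]]. The helper `tropical_backward_a_single` is hypothetical and not part of this crate's API.

```rust
use std::ops::AddAssign;

/// Hypothetical single-element backward pass (not this crate's API).
/// Assumes row-major storage: grad_c and argmax are m×n, the result is m×k,
/// and argmax[i * n + j] is the kk that won max_kk (A[i,kk] + B[kk,j]).
fn tropical_backward_a_single<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "grad_c must have m*n entries");
    assert_eq!(argmax.len(), m * n, "argmax must have m*n entries");
    // Start from zero (T::default()) and scatter-add upstream gradients.
    let mut grad_a = vec![T::default(); m * k];
    for i in 0..m {
        for j in 0..n {
            let kk = argmax[i * n + j] as usize;
            // In debug builds, catch indices that would spill into the next row.
            debug_assert!(kk < k, "argmax index out of range");
            // Only the winning entry A[i, kk] receives the gradient of C[i, j].
            grad_a[i * k + kk] += grad_c[i * n + j];
        }
    }
    grad_a
}
```

Several outputs in the same row can pick the same kk, which is why the update accumulates with `+=` instead of assigning.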

§Arguments

  • grad_c_batch - Batch of upstream gradients, each of size m×n
  • argmax_batch - Batch of argmax indices from the forward pass
  • m - Number of rows in A
  • k - Number of columns in A
  • n - Number of columns in C
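The `Send + Sync` bounds in the signature suggest that batch elements may be processed in parallel. The sketch below is a hypothetical reimplementation that takes the same arguments, under the same row-major and argmax assumptions as above. It runs one scoped thread per batch element via `std::thread::scope`; this is an assumption for illustration, not necessarily how this crate parallelizes. The names `backward_a_one` and `tropical_backward_a_batched_sketch` are made up.

```rust
use std::ops::AddAssign;
use std::thread;

/// Hypothetical per-element backward (row-major; argmax holds winning kk per C[i,j]).
fn backward_a_one<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "each upstream gradient must be m×n");
    assert_eq!(argmax.len(), m * n, "each argmax buffer must be m×n");
    let mut grad_a = vec![T::default(); m * k];
    for i in 0..m {
        for j in 0..n {
            let kk = argmax[i * n + j] as usize;
            grad_a[i * k + kk] += grad_c[i * n + j];
        }
    }
    grad_a
}

/// Hypothetical batched wrapper with the same bounds as the documented function.
/// T: Sync lets threads share &Vec<T>; T: Send lets each thread return its Vec<T>.
pub fn tropical_backward_a_batched_sketch<T: Copy + Default + AddAssign + Send + Sync>(
    grad_c_batch: &[Vec<T>],
    argmax_batch: &[Vec<u32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>> {
    assert_eq!(grad_c_batch.len(), argmax_batch.len(), "batch sizes must match");
    thread::scope(|s| {
        // Spawn one worker per batch element; scoped threads may borrow the inputs.
        let handles: Vec<_> = grad_c_batch
            .iter()
            .zip(argmax_batch)
            .map(|(g, a)| s.spawn(move || backward_a_one(g, a, m, k, n)))
            .collect();
        // Join in spawn order so output i corresponds to input i.
        let results: Vec<Vec<T>> = handles
            .into_iter()
            .map(|h| h.join().expect("worker thread panicked"))
            .collect();
        results
    })
}
```

One OS thread per element keeps the sketch dependency-free, but for large batches a work-stealing pool (for example rayon's parallel iterators) is the more common choice.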

§Returns

Vector of gradients dL/dA[i], each of size m×k
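To show how the returned vectors line up with a forward pass, here is a hypothetical end-to-end check. A naive max-plus forward records the argmax for each output (ties keep the first kk, an assumption), and a sequential batched backward then produces one m×k gradient per batch element. All function names here are illustrative assumptions, not this crate's API.

```rust
use std::ops::AddAssign;

/// Hypothetical forward: C[i,j] = max_kk (A[i,kk] + B[kk,j]), row-major.
/// Returns C (m×n) and the winning kk for each (i, j); ties keep the first kk.
fn tropical_forward_with_argmax(
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> (Vec<f32>, Vec<u32>) {
    let mut c = vec![f32::NEG_INFINITY; m * n];
    let mut argmax = vec![0u32; m * n];
    for i in 0..m {
        for j in 0..n {
            for kk in 0..k {
                let v = a[i * k + kk] + b[kk * n + j];
                // Strict comparison keeps the first maximizing index on ties.
                if v > c[i * n + j] {
                    c[i * n + j] = v;
                    argmax[i * n + j] = kk as u32;
                }
            }
        }
    }
    (c, argmax)
}

/// Hypothetical sequential batched backward: one m×k gradient per element.
fn backward_a_batched_seq<T: Copy + Default + AddAssign>(
    grad_c_batch: &[Vec<T>],
    argmax_batch: &[Vec<u32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>> {
    grad_c_batch
        .iter()
        .zip(argmax_batch)
        .map(|(g, am)| {
            let mut grad_a = vec![T::default(); m * k];
            for i in 0..m {
                for j in 0..n {
                    // Route dL/dC[i,j] to the A entry that won the max.
                    grad_a[i * k + am[i * n + j] as usize] += g[i * n + j];
                }
            }
            grad_a
        })
        .collect()
}
```

With A = [[1, 4], [3, 0]], B = [[0, 2], [1, 0]] and an all-ones upstream gradient, row 0 of C always selects A[0,1] and row 1 always selects A[1,0], so dL/dA = [[0, 2], [2, 0]].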