Function tropical_backward_b_batched

pub fn tropical_backward_b_batched<T: Copy + Default + AddAssign + Send + Sync>(
    grad_c_batch: &[Vec<T>],
    argmax_batch: &[Vec<u32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>>

Batched backward pass that computes the gradient with respect to B.

Computes dL/dB[i] independently for each batch element i.
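The routing rule behind this gradient can be written out as follows. This assumes a max-plus forward pass C = A ⊗ B with A of size m×k and B of size k×n, and that the stored argmax κ_ij is the row of B that won output C_ij; for a min-plus forward pass only the forward selection changes, not the backward routing. The rule is applied separately to each batch element.

```latex
C_{ij} = \max_{\ell} \left( A_{i\ell} + B_{\ell j} \right), \qquad
\kappa_{ij} = \operatorname{arg\,max}_{\ell} \left( A_{i\ell} + B_{\ell j} \right)

\frac{\partial L}{\partial B_{\ell j}}
  = \sum_{i=1}^{m} \mathbb{1}\left[ \kappa_{ij} = \ell \right] \frac{\partial L}{\partial C_{ij}}
```

Each upstream gradient entry flows to exactly one entry of B, so rows of B that never win a maximum in column j receive zero gradient there.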

Arguments

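  • grad_c_batch - Batch of upstream gradients dL/dC, each of size m×n
  • argmax_batch - Batch of argmax indices saved by the forward pass, one index per element of C (m×n per batch element)
  • m - Number of rows in C (and in A)
  • k - Number of rows in B (the inner dimension)
  • n - Number of columns in B (and in C)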

Returns

A vector of gradients dL/dB[i], one per batch element, each of size k×n.
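As a reference for what the function is expected to return, here is a minimal sequential sketch. It assumes row-major storage and that argmax_batch[b][i * n + j] holds the row of B that won C[i][j]. The names backward_b_single and backward_b_batched_sketch are hypothetical and not part of this crate. The real function's Send + Sync bounds suggest it may process batch elements in parallel; this sketch does not.

```rust
use std::ops::AddAssign;

/// Hypothetical single-matrix reference for dL/dB.
/// Assumes row-major layout: grad_c and argmax are m×n, the result is k×n,
/// and argmax[i * n + j] is the row of B that won C[i][j] in the forward pass.
fn backward_b_single<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n);
    assert_eq!(argmax.len(), m * n);
    let mut grad_b = vec![T::default(); k * n];
    for i in 0..m {
        for j in 0..n {
            let winner = argmax[i * n + j] as usize;
            debug_assert!(winner < k, "argmax index out of range");
            // Only the winning entry B[winner][j] receives C[i][j]'s gradient.
            grad_b[winner * n + j] += grad_c[i * n + j];
        }
    }
    grad_b
}

/// Hypothetical batched wrapper: applies the single-matrix rule to each
/// batch element independently, in order.
fn backward_b_batched_sketch<T: Copy + Default + AddAssign>(
    grad_c_batch: &[Vec<T>],
    argmax_batch: &[Vec<u32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>> {
    assert_eq!(grad_c_batch.len(), argmax_batch.len());
    grad_c_batch
        .iter()
        .zip(argmax_batch)
        .map(|(g, a)| backward_b_single(g, a, m, k, n))
        .collect()
}

fn main() {
    // One batch element with m = 2, k = 3, n = 2.
    let grad_c = vec![vec![1.0_f32, 2.0, 3.0, 4.0]];
    // C[0][0] and C[1][0] were won by row 1; C[0][1] by row 0; C[1][1] by row 2.
    let argmax = vec![vec![1_u32, 0, 1, 2]];
    let grad_b = backward_b_batched_sketch(&grad_c, &argmax, 2, 3, 2);
    println!("{:?}", grad_b); // prints [[0.0, 2.0, 4.0, 0.0, 0.0, 4.0]]
}
```

In this example both outputs in column 0 were won by row 1 of B, so dB[1][0] accumulates 1 + 3 = 4, while entries of B that never won stay at zero.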