tropical_backward_a

Function tropical_backward_a 

pub fn tropical_backward_a<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T>

Compute gradient with respect to matrix A in tropical matmul.

Given the forward pass C = A ⊗ B (run with argmax tracking) and the upstream gradient dL/dC, this computes dL/dA.
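To make "argmax tracking" concrete, here is a minimal sketch of a max-plus forward pass that records, for each C[i,j], the inner index that won the max. It assumes row-major storage and f64 entries. The function name maxplus_matmul_with_argmax_sketch is hypothetical; it is not the crate's tropical_matmul_with_argmax, whose storage layout and tie-breaking rule are not stated on this page.

```rust
/// Hypothetical reference forward pass: C = A ⊗ B in the max-plus semiring,
/// ASSUMING row-major storage (A is m×k, B is k×n, C is m×n).
/// Returns (C, argmax), where argmax[i*n + j] is the index s that won
/// the max in C[i,j] = max_s (A[i,s] + B[s,j]).
fn maxplus_matmul_with_argmax_sketch(
    a: &[f64],
    b: &[f64],
    m: usize,
    k: usize,
    n: usize,
) -> (Vec<f64>, Vec<u32>) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert!(k > 0, "inner dimension must be non-zero");
    // -inf is the additive identity (the "zero") of max-plus.
    let mut c = vec![f64::NEG_INFINITY; m * n];
    let mut argmax = vec![0u32; m * n];
    for i in 0..m {
        for j in 0..n {
            for s in 0..k {
                let candidate = a[i * k + s] + b[s * n + j];
                // Strict '>' keeps the first winning index on ties (an assumption).
                if candidate > c[i * n + j] {
                    c[i * n + j] = candidate;
                    argmax[i * n + j] = s as u32;
                }
            }
        }
    }
    (c, argmax)
}

fn main() {
    let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
    let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
    let (c, argmax) = maxplus_matmul_with_argmax_sketch(&a, &b, 2, 3, 2);
    println!("c = {:?}, argmax = {:?}", c, argmax);
}
```

On the matrices from the Example section, read as row-major, this prints c = [8.0, 9.0, 11.0, 12.0], argmax = [2, 2, 2, 2]: the last column of A paired with the last row of B wins every entry.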

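In max-plus algebra, C[i,j] = max_k (A[i,k] + B[k,j]), so each C[i,j] depends on row i of A only through the winning entry A[i, argmax[i,j]], whose partial derivative is 1; every other entry of A has derivative 0 (ties are resolved by the recorded argmax). The gradient is therefore routed as follows: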

dL/dA[i,k] = Σ_j { dL/dC[i,j] if argmax[i,j] == k, else 0 }
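The routing rule above amounts to a scatter-add: every upstream entry dL/dC[i,j] is added to exactly one entry of row i of dL/dA. The sketch below implements that rule directly. It assumes row-major storage for grad_c, argmax and the result, and the name backward_a_sketch is hypothetical. Its generic bounds mirror the documented signature: Copy to read entries out of the slice, Default presumably as the zero that initializes the output, and AddAssign to accumulate contributions that share the same argmax.

```rust
use std::ops::AddAssign;

/// Hypothetical reference for dL/dA[i,s] = Σ_j { dL/dC[i,j] if argmax[i,j] == s, else 0 },
/// ASSUMING row-major storage for grad_c (m×n), argmax (m×n) and the result (m×k).
fn backward_a_sketch<T: Copy + Default + AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n);
    assert_eq!(argmax.len(), m * n);
    // Start from T::default() (0 for numeric types) everywhere.
    let mut grad_a = vec![T::default(); m * k];
    for i in 0..m {
        for j in 0..n {
            let s = argmax[i * n + j] as usize;
            assert!(s < k, "argmax index out of range");
            // Each upstream entry flows to exactly one entry of row i of A.
            grad_a[i * k + s] += grad_c[i * n + j];
        }
    }
    grad_a
}

fn main() {
    let grad_c = [1.0, 2.0, 3.0, 4.0]; // 2x2
    let argmax = [0u32, 2, 1, 2]; // 2x2, values in 0..3
    let grad_a = backward_a_sketch(&grad_c, &argmax, 2, 3, 2);
    println!("{:?}", grad_a); // [1.0, 0.0, 2.0, 0.0, 3.0, 4.0]
}
```

With an all-ones grad_c and every argmax equal to 2, the same sketch returns [0.0, 0.0, 2.0, 0.0, 0.0, 2.0]: only column 2 of A receives gradient, once per output column.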

Arguments

  • grad_c - Upstream gradient dL/dC, with m×n entries
  • argmax - Argmax indices from the forward pass, with m×n entries; each value must lie in 0..k
  • m - Number of rows in A (and in C)
  • k - Number of columns in A (and rows in B)
  • n - Number of columns in B and C (used to index grad_c and argmax)
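Because k and n are both plain usize values, swapping them is an easy mistake. This page does not say how the function reacts to mismatched lengths, so the hypothetical helper check_backward_a_inputs below sketches checks a caller might run before the call. It assumes, as the list above states, that grad_c and argmax each hold m×n entries and that every argmax value indexes a column of A.

```rust
/// Hypothetical pre-flight check for a backward-A call.
/// ASSUMES grad_c and argmax both hold m*n entries and argmax values lie in 0..k.
fn check_backward_a_inputs(
    grad_c_len: usize,
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Result<(), String> {
    if grad_c_len != m * n {
        return Err(format!("grad_c has {} entries, expected m*n = {}", grad_c_len, m * n));
    }
    if argmax.len() != m * n {
        return Err(format!("argmax has {} entries, expected m*n = {}", argmax.len(), m * n));
    }
    // Every recorded winner must be a valid column index of A.
    if let Some(&bad) = argmax.iter().find(|&&s| s as usize >= k) {
        return Err(format!("argmax value {} is out of range for k = {}", bad, k));
    }
    Ok(())
}

fn main() {
    // Same shapes as the Example section: m = 2, k = 3, n = 2.
    println!("{:?}", check_backward_a_inputs(4, &[2, 2, 2, 2], 2, 3, 2)); // Ok(())
    // k and n swapped: m*n becomes 6, so the length check fails.
    println!("{:?}", check_backward_a_inputs(4, &[2, 2, 2, 2], 2, 2, 3));
}
```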

Returns

Gradient dL/dA of size m×k

Example

use tropical_gemm::{tropical_matmul_with_argmax, tropical_backward_a, TropicalMaxPlus};

let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2

// Forward pass
let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

// Upstream gradient (e.g., all ones)
let grad_c = [1.0f64; 4]; // 2x2

// Backward pass for A
let grad_a = tropical_backward_a::<f64>(&grad_c, result.argmax_slice(), 2, 3, 2);
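assert_eq!(grad_a.len(), 6); // 2x3
// Each of the 4 upstream entries is routed to exactly one entry of A,
// so the total gradient is conserved: 4 × 1.0 = 4.0
assert_eq!(grad_a.iter().sum::<f64>(), 4.0);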
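Since the routing rule is the derivative of the max-plus product, it can be sanity-checked numerically. The sketch below takes the scalar loss L = Σ grad_c[i,j] · C[i,j] (so dL/dC = grad_c), perturbs each entry of A by ±eps, and compares the central difference against the routed gradient. To stay self-contained it re-implements a hypothetical row-major max-plus forward pass instead of calling tropical_matmul_with_argmax; the names forward, loss and gradient_check are illustrative only.

```rust
/// Hypothetical max-plus forward pass, ASSUMING row-major storage.
/// Returns C and the winning inner index for each C[i,j].
fn forward(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> (Vec<f64>, Vec<u32>) {
    let mut c = vec![f64::NEG_INFINITY; m * n];
    let mut arg = vec![0u32; m * n];
    for i in 0..m {
        for j in 0..n {
            for s in 0..k {
                let v = a[i * k + s] + b[s * n + j];
                if v > c[i * n + j] {
                    c[i * n + j] = v;
                    arg[i * n + j] = s as u32;
                }
            }
        }
    }
    (c, arg)
}

/// Scalar loss L = Σ grad_c[i,j] * C[i,j], chosen so that dL/dC = grad_c.
fn loss(a: &[f64], b: &[f64], grad_c: &[f64], m: usize, k: usize, n: usize) -> f64 {
    let (c, _) = forward(a, b, m, k, n);
    c.iter().zip(grad_c).map(|(x, g)| x * g).sum()
}

/// Returns the largest |numeric - analytic| over all entries of A.
fn gradient_check(a: &[f64], b: &[f64], grad_c: &[f64], m: usize, k: usize, n: usize) -> f64 {
    // Analytic gradient via the routing rule (scatter-add by argmax).
    let (_, arg) = forward(a, b, m, k, n);
    let mut analytic = vec![0.0; m * k];
    for i in 0..m {
        for j in 0..n {
            analytic[i * k + arg[i * n + j] as usize] += grad_c[i * n + j];
        }
    }
    // Central finite differences, one entry of A at a time.
    let eps = 1e-4;
    let mut max_err: f64 = 0.0;
    for idx in 0..m * k {
        let mut plus = a.to_vec();
        plus[idx] += eps;
        let mut minus = a.to_vec();
        minus[idx] -= eps;
        let numeric =
            (loss(&plus, b, grad_c, m, k, n) - loss(&minus, b, grad_c, m, k, n)) / (2.0 * eps);
        max_err = max_err.max((numeric - analytic[idx]).abs());
    }
    max_err
}

fn main() {
    let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3, row-major (assumed)
    let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2, row-major (assumed)
    let grad_c = [1.0f64, 2.0, 3.0, 4.0]; // 2x2
    let err = gradient_check(&a, &b, &grad_c, 2, 3, 2);
    println!("max |numeric - analytic| = {:e}", err);
}
```

The check is only meaningful away from ties: if two candidates A[i,s] + B[s,j] are within eps of each other, the perturbation can change the winner and the comparison reports a spurious error. On these inputs the winning candidate leads by at least 3, so it does not.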