pub struct MatWithArgmax<S: TropicalWithArgmax> {
pub values: Mat<S>,
pub argmax: Vec<u32>,
}
Result of matrix multiplication with argmax tracking.
Fields§
values: Mat<S> — The result matrix values.
argmax: Vec<u32> — The argmax indices (for each C[i,j], the inner index k that produced it).
Implementations§
impl<S: TropicalWithArgmax<Index = u32>> MatWithArgmax<S>
pub fn get_value(&self, i: usize, j: usize) -> S::Scalar
Get the scalar value at position (i, j).
This is a convenience method that extracts the underlying scalar without requiring a trait import.
pub fn get_argmax(&self, i: usize, j: usize) -> u32
Get the argmax index at position (i, j).
pub fn argmax_slice(&self) -> &[u32]
Get the argmax indices as a slice.
This is useful for backward pass computation.
pub fn backward_a<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
Compute gradient with respect to matrix A.
Given the upstream gradient dL/dC, computes dL/dA using the argmax indices from the forward pass.
For C = A ⊗ B where C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j]): dL/dA[i,k] = Σ_j { dL/dC[i,j] if argmax[i,j] == k }
§Arguments
- grad_c — Gradient of the loss with respect to C, dimensions m×n
- k — Number of columns in A (the inner dimension)
§Returns
Gradient of the loss with respect to A, dimensions m×k
§Example
use tropical_gemm::{Mat, MaxPlus, TropicalMaxPlus};
let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
// Forward pass with argmax
let result = a.matmul_argmax(&b);
// Backward pass: grad_c is upstream gradient (e.g., all ones)
let grad_c = Mat::<MaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
let grad_a = result.backward_a(&grad_c, 3); // k=3 (columns in A)
assert_eq!(grad_a.nrows(), 2);
assert_eq!(grad_a.ncols(), 3);
pub fn backward_b<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
Compute gradient with respect to matrix B.
Given the upstream gradient dL/dC, computes dL/dB using the argmax indices from the forward pass.
For C = A ⊗ B where C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j]): dL/dB[k,j] = Σ_i { dL/dC[i,j] if argmax[i,j] == k }
§Arguments
- grad_c — Gradient of the loss with respect to C, dimensions m×n
- k — Number of rows in B (the inner dimension)
§Returns
Gradient of the loss with respect to B, dimensions k×n
§Example
use tropical_gemm::{Mat, MaxPlus, TropicalMaxPlus};
let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
// Forward pass with argmax
let result = a.matmul_argmax(&b);
// Backward pass: grad_c is upstream gradient
let grad_c = Mat::<MaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
let grad_b = result.backward_b(&grad_c, 3); // k=3 (rows in B)
assert_eq!(grad_b.nrows(), 3);
assert_eq!(grad_b.ncols(), 2);
Auto Trait Implementations§
impl<S> Freeze for MatWithArgmax<S>
impl<S> RefUnwindSafe for MatWithArgmax<S> where
S: RefUnwindSafe,
impl<S> Send for MatWithArgmax<S>
impl<S> Sync for MatWithArgmax<S>
impl<S> Unpin for MatWithArgmax<S> where
S: Unpin,
impl<S> UnwindSafe for MatWithArgmax<S> where
S: UnwindSafe,
Blanket Implementations§
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.