Function tropical_matmul_batched_with_argmax 

pub fn tropical_matmul_batched_with_argmax<T>(
    a_batch: &[Vec<T::Scalar>],
    b_batch: &[Vec<T::Scalar>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<GemmWithArgmax<T>>
where
    T::Scalar: Send + Sync,
    T: Send + Sync + TropicalWithArgmax<Index = u32> + KernelDispatch,

Batched tropical matrix multiplication with argmax tracking.

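C[i] = A[i] ⊗ B[i] for i = 0..batch_size, where batch_size is the number of matrices in a_batch (and in b_batch). Each entry is the tropical sum C[i][r][c] = ⊕_j (A[i][r][j] ⊗ B[i][j][c]) over j = 0..k, and the argmax records, for every entry, the inner index j that attains that optimum.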
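To make the definition concrete, here is a minimal sketch of the same computation for the max-plus semiring (⊕ = max, ⊗ = +). It is not this crate's kernel: it assumes f64 scalars, row-major storage for every matrix, and k ≥ 1, and the names ArgmaxSketch, maxplus_matmul_with_argmax and maxplus_matmul_batched_with_argmax are hypothetical. ArgmaxSketch only stands in for GemmWithArgmax, whose fields are not shown on this page.

```rust
// Sketch only: max-plus semiring, f64 scalars, row-major layout (assumptions,
// not taken from the crate). All names here are hypothetical.

/// Hypothetical stand-in for GemmWithArgmax: the product values plus, for each
/// entry of C, the inner index j that attained the maximum.
#[derive(Debug, Clone, PartialEq)]
pub struct ArgmaxSketch {
    pub values: Vec<f64>, // m*n entries of C, row-major
    pub argmax: Vec<u32>, // m*n winning inner indices (u32, mirroring Index = u32)
}

/// C[r][c] = max over j of (A[r][j] + B[j][c]); A is m×k, B is k×n, k >= 1.
pub fn maxplus_matmul_with_argmax(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> ArgmaxSketch {
    assert_eq!(a.len(), m * k, "A must hold m*k elements");
    assert_eq!(b.len(), k * n, "B must hold k*n elements");
    let mut values = vec![f64::NEG_INFINITY; m * n]; // -inf is the max-plus zero
    let mut argmax = vec![0u32; m * n];
    for r in 0..m {
        for c in 0..n {
            for j in 0..k {
                let cand = a[r * k + j] + b[j * n + c]; // ⊗ is ordinary addition
                // ⊕ is max; the strict comparison keeps the first (smallest) j on ties.
                if cand > values[r * n + c] {
                    values[r * n + c] = cand;
                    argmax[r * n + c] = j as u32;
                }
            }
        }
    }
    ArgmaxSketch { values, argmax }
}

/// Batched form: element i of the result is A[i] ⊗ B[i], in batch order.
pub fn maxplus_matmul_batched_with_argmax(
    a_batch: &[Vec<f64>],
    b_batch: &[Vec<f64>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<ArgmaxSketch> {
    assert_eq!(a_batch.len(), b_batch.len(), "a_batch and b_batch must have the same length");
    a_batch
        .iter()
        .zip(b_batch)
        .map(|(a, b)| maxplus_matmul_with_argmax(a, b, m, k, n))
        .collect()
}
```

Ties go to the smallest j here only because a candidate must be strictly larger to replace the current best; the crate's own tie-breaking rule is not documented on this page.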

§Arguments

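  • a_batch: Slice of batch_size matrices A[i], each an m×k matrix stored as a flat Vec of m·k scalars
  • b_batch: Slice of batch_size matrices B[i], each a k×n matrix stored as a flat Vec of k·n scalars
  • m: Number of rows in each A matrix (and in each result)
  • k: Number of columns in each A matrix and rows in each B matrix (the dimension reduced by ⊕)
  • n: Number of columns in each B matrix (and in each result)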
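The argument list implies a shape contract for every batch item. A hypothetical helper, validate_batch_shapes, that a caller could run before the call might look like this; whether tropical_matmul_batched_with_argmax performs these checks itself, and whether it panics on a mismatch, is not stated on this page.

```rust
/// Hypothetical pre-call check (not part of the crate): both slices hold
/// batch_size matrices, every A[i] has m*k elements and every B[i] has k*n.
/// Returns batch_size on success, or a message describing the first mismatch.
pub fn validate_batch_shapes<S>(
    a_batch: &[Vec<S>],
    b_batch: &[Vec<S>],
    m: usize,
    k: usize,
    n: usize,
) -> Result<usize, String> {
    if a_batch.len() != b_batch.len() {
        return Err(format!(
            "a_batch has {} matrices but b_batch has {}",
            a_batch.len(),
            b_batch.len()
        ));
    }
    for (i, (a, b)) in a_batch.iter().zip(b_batch).enumerate() {
        if a.len() != m * k {
            return Err(format!("A[{i}] has {} elements, expected m*k = {}", a.len(), m * k));
        }
        if b.len() != k * n {
            return Err(format!("B[{i}] has {} elements, expected k*n = {}", b.len(), k * n));
        }
    }
    Ok(a_batch.len())
}
```

The check is generic over the scalar type S because it only inspects lengths, never values.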

§Returns

A Vec of batch_size GemmWithArgmax results; element i holds C[i] = A[i] ⊗ B[i] together with its argmax indices.
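The Send + Sync bounds in the signature suggest that batch items can be computed on several threads, though this page does not say how the crate schedules them. The sketch below, again assuming max-plus over row-major f64 data and using hypothetical names (maxplus_one, maxplus_batched_parallel), shows that a parallel evaluation can still return results in batch order: each item gets its own scoped thread, and the handles are joined in spawn order.

```rust
use std::thread;

// Sketch only: max-plus semiring, f64 scalars, row-major layout (assumptions).

/// Max-plus product with argmax for one pair; k >= 1 assumed.
fn maxplus_one(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> (Vec<f64>, Vec<u32>) {
    let mut values = Vec::with_capacity(m * n);
    let mut argmax = Vec::with_capacity(m * n);
    for r in 0..m {
        for c in 0..n {
            let (mut best, mut best_j) = (f64::NEG_INFINITY, 0u32);
            for j in 0..k {
                let cand = a[r * k + j] + b[j * n + c];
                if cand > best {
                    best = cand;
                    best_j = j as u32;
                }
            }
            values.push(best);
            argmax.push(best_j);
        }
    }
    (values, argmax)
}

/// Spawns one scoped thread per batch item. All handles are collected before
/// any join, so the items run concurrently; joining in spawn order keeps
/// result i aligned with (a_batch[i], b_batch[i]).
pub fn maxplus_batched_parallel(
    a_batch: &[Vec<f64>],
    b_batch: &[Vec<f64>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<(Vec<f64>, Vec<u32>)> {
    assert_eq!(a_batch.len(), b_batch.len(), "batch sizes must match");
    thread::scope(|s| {
        let handles: Vec<_> = a_batch
            .iter()
            .zip(b_batch)
            .map(|(a, b)| s.spawn(move || maxplus_one(a, b, m, k, n)))
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}
```

One thread per item keeps the example short; for large batches a work-stealing pool such as rayon's parallel iterators would be the usual choice. The shared input slices are only read, which is why Sync on the scalar type is enough to hand them to worker threads.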