pub fn tropical_matmul_batched<T>(
    a_batch: &[Vec<T::Scalar>],
    b_batch: &[Vec<T::Scalar>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<T>>
Batched tropical matrix multiplication: C[i] = A[i] ⊗ B[i] for i = 0..batch_size
All matrices in the batch must have the same dimensions:
- Each A[i] is m × k
- Each B[i] is k × n
- Each C[i] is m × n
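Elementwise, each product follows the ordinary matrix-product pattern with the semiring operations (⊕, ⊗) in place of (+, ×). The block below states this for row index r and column index c. The second line specializes it to the max-plus semiring used in the example, assuming the standard definition where ⊕ is max and ⊗ is ordinary addition.

```latex
C_i[r, c] = \bigoplus_{j=0}^{k-1} A_i[r, j] \otimes B_i[j, c],
  \qquad 0 \le r < m,\; 0 \le c < n
\text{max-plus:}\quad C_i[r, c] = \max_{0 \le j < k} \bigl( A_i[r, j] + B_i[j, c] \bigr)
```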
§Arguments
- a_batch: Slice of batch_size matrices, each of size m×k in row-major order
- b_batch: Slice of batch_size matrices, each of size k×n in row-major order
- m: Number of rows in each A matrix
- k: Number of columns in each A matrix and rows in each B matrix
- n: Number of columns in each B matrix
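The row-major layout fixes how long each inner Vec must be: m·k elements for every A[i] and k·n for every B[i]. The sketch below checks those lengths up front. `check_batch_shapes` is a hypothetical helper, not part of tropical_gemm, and this page does not say how the library itself reacts to mismatched shapes (for example, whether it panics).

```rust
/// Hypothetical helper (not part of tropical_gemm): checks that every
/// matrix in a batch has the row-major length implied by m, k, n.
/// Returns the batch size on success.
fn check_batch_shapes<S>(
    a_batch: &[Vec<S>],
    b_batch: &[Vec<S>],
    m: usize,
    k: usize,
    n: usize,
) -> Result<usize, String> {
    if a_batch.len() != b_batch.len() {
        return Err(format!(
            "batch size mismatch: {} A matrices vs {} B matrices",
            a_batch.len(),
            b_batch.len()
        ));
    }
    for (i, (a, b)) in a_batch.iter().zip(b_batch).enumerate() {
        // Row-major m×k: element (r, j) lives at a[r * k + j].
        if a.len() != m * k {
            return Err(format!("A[{i}] has {} elements, expected {}", a.len(), m * k));
        }
        // Row-major k×n: element (j, c) lives at b[j * n + c].
        if b.len() != k * n {
            return Err(format!("B[{i}] has {} elements, expected {}", b.len(), k * n));
        }
    }
    Ok(a_batch.len())
}
```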
§Returns
Vector of batch_size result matrices, each of size m×n
§Example
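```rust
use tropical_gemm::{tropical_matmul_batched, TropicalMaxPlus};

// Two 2x2 matrix multiplications, each matrix stored in row-major order
let a_batch = vec![
    vec![1.0f32, 2.0, 3.0, 4.0], // A[0]: 2x2
    vec![5.0f32, 6.0, 7.0, 8.0], // A[1]: 2x2
];
let b_batch = vec![
    vec![1.0f32, 2.0, 3.0, 4.0], // B[0]: 2x2
    vec![1.0f32, 2.0, 3.0, 4.0], // B[1]: 2x2
];

// The inputs are plain f32 values, so T::Scalar is f32 here.
let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f32>>(&a_batch, &b_batch, 2, 2, 2);
assert_eq!(c_batch.len(), 2);    // one result per batch entry
assert_eq!(c_batch[0].len(), 4); // each result is m × n = 2 × 2
// Under max-plus, C[0] holds [5, 6, 7, 8] and C[1] holds [9, 10, 11, 12].
```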
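To see which values the example should produce, the sketch below computes the same batched max-plus product on plain f32 values with a naive triple loop. `maxplus_matmul_batched_ref` is a hypothetical reference function, not the tropical_gemm implementation. It assumes the standard max-plus semiring, with ⊕ as max, ⊗ as ordinary addition, and negative infinity as the ⊕ identity.

```rust
/// Hypothetical reference (not the tropical_gemm implementation): batched
/// max-plus product on plain f32 values in row-major order.
/// C[i][r*n + c] = max over j of (A[i][r*k + j] + B[i][j*n + c]).
fn maxplus_matmul_batched_ref(
    a_batch: &[Vec<f32>],
    b_batch: &[Vec<f32>],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<Vec<f32>> {
    assert_eq!(a_batch.len(), b_batch.len(), "batch sizes differ");
    a_batch
        .iter()
        .zip(b_batch)
        .map(|(a, b)| {
            assert_eq!(a.len(), m * k, "A matrix has wrong length");
            assert_eq!(b.len(), k * n, "B matrix has wrong length");
            // Start every entry at the max-plus "zero" (identity for max).
            let mut c = vec![f32::NEG_INFINITY; m * n];
            for r in 0..m {
                for col in 0..n {
                    for j in 0..k {
                        // ⊗ is ordinary +, ⊕ is max.
                        let v = a[r * k + j] + b[j * n + col];
                        if v > c[r * n + col] {
                            c[r * n + col] = v;
                        }
                    }
                }
            }
            c
        })
        .collect()
}
```

Running it on the example inputs gives [5, 6, 7, 8] for C[0] and [9, 10, 11, 12] for C[1]. For instance, C[0][0, 0] = max(1 + 1, 2 + 3) = 5.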