pub fn tropical_matmul_strided_batched<T>(
a: &[T::Scalar],
b: &[T::Scalar],
batch_size: usize,
m: usize,
k: usize,
n: usize,
) -> Vec<T>
Strided batched GEMM: computes C[i] = A[i] ⊗ B[i] from contiguous memory.
This is more efficient than tropical_matmul_batched when all matrices
are stored contiguously in memory with fixed strides.
§Arguments
a: Contiguous array of all A matrices (batch_size × m × k elements)
b: Contiguous array of all B matrices (batch_size × k × n elements)
batch_size: Number of matrix pairs
m: Rows in each A
k: Columns in each A / rows in each B
n: Columns in each B
§Returns
Contiguous array of all C matrices (batch_size × m × n elements)
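To make the strided layout concrete, here is a minimal reference sketch of the same computation for the max-plus semiring over plain f32 slices. It is not the crate's implementation: the function name maxplus_strided_batched_ref is hypothetical, and row-major storage within each matrix is an assumption that the text above does not state. Batch i of A starts at offset i × m × k, batch i of B at i × k × n, and batch i of C at i × m × n.

```rust
/// Hypothetical reference implementation (not part of tropical_gemm):
/// max-plus strided batched matmul, assuming row-major matrices.
fn maxplus_strided_batched_ref(
    a: &[f32],
    b: &[f32],
    batch_size: usize,
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    // Each input must hold exactly batch_size matrices back to back.
    assert_eq!(a.len(), batch_size * m * k);
    assert_eq!(b.len(), batch_size * k * n);

    let mut c = vec![f32::NEG_INFINITY; batch_size * m * n];
    for batch in 0..batch_size {
        // Fixed strides: one matrix worth of elements per batch entry.
        let a_mat = &a[batch * m * k..(batch + 1) * m * k];
        let b_mat = &b[batch * k * n..(batch + 1) * k * n];
        let c_mat = &mut c[batch * m * n..(batch + 1) * m * n];

        for i in 0..m {
            for j in 0..n {
                // Tropical (max-plus) product: "add" is max, "multiply" is +.
                let mut acc = f32::NEG_INFINITY;
                for p in 0..k {
                    acc = acc.max(a_mat[i * k + p] + b_mat[p * n + j]);
                }
                c_mat[i * n + j] = acc;
            }
        }
    }
    c
}
```

Under the row-major assumption, this sketch maps the two 2×2 pairs 1..4 / 1..4 and 5..8 / 1..4 to the eight values 5 through 12 in order. The crate's actual element order may differ if it uses a different layout.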
§Example
use tropical_gemm::{tropical_matmul_strided_batched, TropicalMaxPlus};
// Two 2x2 matrix pairs stored contiguously
let a = vec![
    1.0f32, 2.0, 3.0, 4.0, // A[0]
    5.0, 6.0, 7.0, 8.0,    // A[1]
];
let b = vec![
    1.0f32, 2.0, 3.0, 4.0, // B[0]
    1.0, 2.0, 3.0, 4.0,    // B[1]
];
let c = tropical_matmul_strided_batched::<TropicalMaxPlus<f32>>(&a, &b, 2, 2, 2, 2);
assert_eq!(c.len(), 8); // 2 batches × 2×2 results