tropical_matmul_strided_batched

Function tropical_matmul_strided_batched 

pub fn tropical_matmul_strided_batched<T>(
    a: &[T::Scalar],
    b: &[T::Scalar],
    batch_size: usize,
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T>

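Strided batched GEMM: computes C[i] = A[i] ⊗ B[i] for every i in 0..batch_size, where ⊗ is matrix multiplication in the tropical semiring T, reading all inputs from contiguous memory.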
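For the max-plus semiring (TropicalMaxPlus), ⊕ is max and ⊗ is ordinary addition, so each entry of one product is C[r][c] = max over p of (A[r][p] + B[p][c]). The sketch below spells this out for a single matrix pair on plain f32 values. It is a naive reference for the semantics, not the crate's kernel. It assumes row-major storage, which this page does not state, and the name maxplus_matmul_ref is hypothetical.

```rust
/// Naive max-plus product of one m×k matrix `a` and one k×n matrix `b`.
/// Hypothetical reference helper; assumes row-major storage.
fn maxplus_matmul_ref(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![f32::NEG_INFINITY; m * n]; // every entry is filled in below
    for r in 0..m {
        for col in 0..n {
            // -inf is the tropical zero for max-plus, so an empty sum (k = 0) stays -inf.
            let mut acc = f32::NEG_INFINITY;
            for p in 0..k {
                // ⊗ is +, ⊕ is max
                acc = acc.max(a[r * k + p] + b[p * n + col]);
            }
            c[r * n + col] = acc;
        }
    }
    c
}

fn main() {
    let a = [1.0f32, 2.0, 3.0, 4.0];
    let b = [1.0f32, 2.0, 3.0, 4.0];
    let c = maxplus_matmul_ref(&a, &b, 2, 2, 2);
    println!("{:?}", c); // prints [5.0, 6.0, 7.0, 8.0]
}
```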

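This is more efficient than tropical_matmul_batched when all A matrices are packed back to back in one buffer and all B matrices in another, so consecutive matrices are a fixed stride apart (m × k elements for A, k × n elements for B).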
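If the matrices start out as separate Vecs, one way to produce the contiguous layout is to concatenate them in batch order, after which matrix i begins at offset i × rows × cols. A minimal sketch using the standard library's slice concat; the helper name pack_batch is hypothetical.

```rust
/// Packs equally sized matrices into one contiguous buffer, in batch order.
/// Hypothetical helper; panics if any matrix has the wrong number of elements.
fn pack_batch(mats: &[Vec<f32>], rows: usize, cols: usize) -> Vec<f32> {
    for (i, mat) in mats.iter().enumerate() {
        assert_eq!(mat.len(), rows * cols, "matrix {i} has the wrong size");
    }
    // `concat` copies the matrices back to back, so matrix i starts at i * rows * cols.
    mats.concat()
}

fn main() {
    let a0 = vec![1.0f32, 2.0, 3.0, 4.0];
    let a1 = vec![5.0f32, 6.0, 7.0, 8.0];
    let a = pack_batch(&[a0, a1], 2, 2);
    assert_eq!(a.len(), 2 * 2 * 2);
    assert_eq!(a[4], 5.0); // A[1] starts at offset 1 * m * k = 4
}
```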

§Arguments

  • a: Contiguous array of all A matrices (batch_size × m × k elements)
  • b: Contiguous array of all B matrices (batch_size × k × n elements)
  • batch_size: Number of matrix pairs
  • m: Rows in each A
  • k: Columns in A / rows in B
  • n: Columns in each B
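These arguments fully determine the strides: A[i] occupies a[i·m·k .. (i+1)·m·k], B[i] occupies b[i·k·n .. (i+1)·k·n], and C[i] lands at offset i·m·n. The loop below is a naive reference for the max-plus case on plain f32 values, assuming row-major matrices. The name maxplus_strided_batched_ref is hypothetical and this is not the crate's optimized kernel, so its numbers match the library only if the library uses the same layout.

```rust
/// Naive strided batched max-plus GEMM (hypothetical reference, row-major assumed).
fn maxplus_strided_batched_ref(
    a: &[f32],
    b: &[f32],
    batch_size: usize,
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    assert_eq!(a.len(), batch_size * m * k, "a must hold batch_size × m × k elements");
    assert_eq!(b.len(), batch_size * k * n, "b must hold batch_size × k × n elements");
    let mut c = vec![f32::NEG_INFINITY; batch_size * m * n];
    for i in 0..batch_size {
        // Fixed strides: each batch entry starts one whole matrix further along.
        let (a_off, b_off, c_off) = (i * m * k, i * k * n, i * m * n);
        for r in 0..m {
            for col in 0..n {
                let mut acc = f32::NEG_INFINITY; // tropical zero for max-plus
                for p in 0..k {
                    acc = acc.max(a[a_off + r * k + p] + b[b_off + p * n + col]);
                }
                c[c_off + r * n + col] = acc;
            }
        }
    }
    c
}

fn main() {
    // Same buffers as in the §Example section.
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let b = [1.0f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0];
    let c = maxplus_strided_batched_ref(&a, &b, 2, 2, 2, 2);
    println!("{:?}", c); // prints [5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
}
```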

§Returns

A Vec<T> holding all C matrices back to back (batch_size × m × n elements), so C[i] starts at offset i × m × n.
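Because the result is one flat buffer, a single batch entry can be recovered by slicing at multiples of m × n. A minimal sketch; batch_entry is a hypothetical helper, generic over the element type so it would also work on the returned Vec<T>, and the u32 data in main is only a stand-in.

```rust
/// Returns the m×n block for batch entry `i` from a contiguous batched result.
/// Hypothetical helper; panics if `i` is out of range.
fn batch_entry<T>(c: &[T], i: usize, m: usize, n: usize) -> &[T] {
    let len = m * n;
    &c[i * len..(i + 1) * len]
}

fn main() {
    // Stand-in data: two 2×2 results stored back to back.
    let c: Vec<u32> = (0..8).collect();
    assert_eq!(batch_entry(&c, 1, 2, 2), [4, 5, 6, 7]);
}
```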

§Example

use tropical_gemm::{tropical_matmul_strided_batched, TropicalMaxPlus};

// Two 2×2 matrix pairs stored contiguously
let a = vec![
    1.0f32, 2.0, 3.0, 4.0, // A[0]
    5.0, 6.0, 7.0, 8.0,    // A[1]
];
let b = vec![
    1.0f32, 2.0, 3.0, 4.0, // B[0]
    1.0, 2.0, 3.0, 4.0,    // B[1]
];

// Arguments after the buffers: batch_size = 2, m = 2, k = 2, n = 2
let c = tropical_matmul_strided_batched::<TropicalMaxPlus<f32>>(&a, &b, 2, 2, 2, 2);
assert_eq!(c.len(), 8); // 2 batches × (2 × 2) results