tropical_gemm/
api.rs

1use crate::core::{GemmWithArgmax, Transpose};
2use crate::simd::{tropical_gemm_dispatch, KernelDispatch};
3use crate::types::{TropicalSemiring, TropicalWithArgmax};
4
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7
8/// Simple tropical matrix multiplication: C = A ⊗ B
9///
10/// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
11///
12/// # Arguments
13/// - `a`: Matrix A data in row-major order
14/// - `m`: Number of rows in A
15/// - `k`: Number of columns in A / rows in B
16/// - `b`: Matrix B data in row-major order
17/// - `n`: Number of columns in B
18///
19/// # Returns
20/// Result matrix C of size m×n in row-major order
21///
22/// # Example
23///
24/// ```
25/// use tropical_gemm::{tropical_matmul, TropicalMaxPlus};
26///
27/// let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
28/// let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
29///
30/// let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2);
31/// assert_eq!(c.len(), 4); // 2x2 result
32/// ```
33pub fn tropical_matmul<T: TropicalSemiring + KernelDispatch>(
34    a: &[T::Scalar],
35    m: usize,
36    k: usize,
37    b: &[T::Scalar],
38    n: usize,
39) -> Vec<T> {
40    assert_eq!(a.len(), m * k, "A dimensions mismatch");
41    assert_eq!(b.len(), k * n, "B dimensions mismatch");
42
43    let mut c = vec![T::tropical_zero(); m * n];
44
45    unsafe {
46        tropical_gemm_dispatch::<T>(
47            m,
48            n,
49            k,
50            a.as_ptr(),
51            k,
52            Transpose::NoTrans,
53            b.as_ptr(),
54            n,
55            Transpose::NoTrans,
56            c.as_mut_ptr(),
57            n,
58        );
59    }
60
61    c
62}
63
64/// Tropical matrix multiplication with argmax tracking.
65///
66/// Returns both the result matrix and the argmax indices indicating
67/// which k produced each optimal C[i,j].
68///
69/// # Example
70///
71/// ```
72/// use tropical_gemm::{tropical_matmul_with_argmax, TropicalMaxPlus};
73///
74/// let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
75/// let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
76///
77/// let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);
78/// assert_eq!(result.m, 2);
79/// assert_eq!(result.n, 2);
80/// ```
81pub fn tropical_matmul_with_argmax<T: TropicalWithArgmax<Index = u32> + KernelDispatch>(
82    a: &[T::Scalar],
83    m: usize,
84    k: usize,
85    b: &[T::Scalar],
86    n: usize,
87) -> GemmWithArgmax<T> {
88    assert_eq!(a.len(), m * k, "A dimensions mismatch");
89    assert_eq!(b.len(), k * n, "B dimensions mismatch");
90
91    let mut result = GemmWithArgmax::new(m, n);
92
93    unsafe {
94        crate::core::tropical_gemm_with_argmax_portable::<T>(
95            m,
96            n,
97            k,
98            a.as_ptr(),
99            k,
100            Transpose::NoTrans,
101            b.as_ptr(),
102            n,
103            Transpose::NoTrans,
104            &mut result,
105        );
106    }
107
108    result
109}
110
/// Builder for configuring tropical GEMM operations.
///
/// Provides a fluent API: create the builder with the problem dimensions,
/// optionally mark either operand as transposed, then run the product.
///
/// NOTE(review): earlier wording also advertised alpha/beta scaling and
/// output preferences; the fields below only store dimensions and transpose
/// flags — confirm whether those options exist elsewhere before relying on them.
///
/// # Example
///
/// ```
/// use tropical_gemm::{TropicalGemm, TropicalMaxPlus, TropicalSemiring};
///
/// let a = vec![1.0f32; 6]; // 2x3
/// let b = vec![1.0f32; 6]; // 3x2
/// let mut c = vec![TropicalMaxPlus::tropical_zero(); 4]; // 2x2
///
/// TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3)
///     .execute(&a, 3, &b, 2, &mut c, 2);
/// ```
pub struct TropicalGemm<T: TropicalSemiring> {
    // Rows of the output C (and of the effective A operand).
    m: usize,
    // Columns of the output C (and of the effective B operand).
    n: usize,
    // Shared inner dimension of the product.
    k: usize,
    // Whether A is passed to the kernel as transposed.
    trans_a: Transpose,
    // Whether B is passed to the kernel as transposed.
    trans_b: Transpose,
    // Ties the builder to a semiring type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
136
137impl<T: TropicalSemiring + KernelDispatch> TropicalGemm<T> {
138    /// Create a new GEMM builder.
139    pub fn new(m: usize, n: usize, k: usize) -> Self {
140        Self {
141            m,
142            n,
143            k,
144            trans_a: Transpose::NoTrans,
145            trans_b: Transpose::NoTrans,
146            _phantom: std::marker::PhantomData,
147        }
148    }
149
150    /// Transpose matrix A.
151    pub fn trans_a(mut self) -> Self {
152        self.trans_a = Transpose::Trans;
153        self
154    }
155
156    /// Transpose matrix B.
157    pub fn trans_b(mut self) -> Self {
158        self.trans_b = Transpose::Trans;
159        self
160    }
161
162    /// Execute the GEMM operation.
163    ///
164    /// # Arguments
165    /// - `a`: Matrix A data
166    /// - `lda`: Leading dimension of A
167    /// - `b`: Matrix B data
168    /// - `ldb`: Leading dimension of B
169    /// - `c`: Output matrix C (must be pre-allocated)
170    /// - `ldc`: Leading dimension of C
171    pub fn execute(
172        self,
173        a: &[T::Scalar],
174        lda: usize,
175        b: &[T::Scalar],
176        ldb: usize,
177        c: &mut [T],
178        ldc: usize,
179    ) {
180        unsafe {
181            tropical_gemm_dispatch::<T>(
182                self.m,
183                self.n,
184                self.k,
185                a.as_ptr(),
186                lda,
187                self.trans_a,
188                b.as_ptr(),
189                ldb,
190                self.trans_b,
191                c.as_mut_ptr(),
192                ldc,
193            );
194        }
195    }
196}
197
/// BLAS-style GEMM interface.
///
/// C = A ⊗ B
///
/// Thin wrapper: every argument is forwarded unchanged, in the same order,
/// to `tropical_gemm_dispatch::<T>`.
///
/// # Arguments
/// - `m`, `n`, `k`: Problem dimensions (C is m×n, `k` is the inner dimension)
/// - `a`, `lda`, `trans_a`: Pointer to A, its leading dimension and transpose flag
/// - `b`, `ldb`, `trans_b`: Pointer to B, its leading dimension and transpose flag
/// - `c`, `ldc`: Pointer to the output C and its leading dimension
///
/// # Safety
/// All pointers must be valid for the specified dimensions.
///
/// NOTE(review): presumably `a` and `b` must be readable and `c` writable for
/// every element addressed through the dimensions, leading dimensions and
/// transpose flags — confirm the exact contract against `tropical_gemm_dispatch`.
pub unsafe fn tropical_gemm<T: TropicalSemiring + KernelDispatch>(
    m: usize,
    n: usize,
    k: usize,
    a: *const T::Scalar,
    lda: usize,
    trans_a: Transpose,
    b: *const T::Scalar,
    ldb: usize,
    trans_b: Transpose,
    c: *mut T,
    ldc: usize,
) {
    // No validation happens here; the caller's guarantees are passed straight through.
    tropical_gemm_dispatch::<T>(m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc);
}
219
220/// Batched tropical matrix multiplication: C[i] = A[i] ⊗ B[i] for i = 0..batch_size
221///
222/// All matrices in the batch must have the same dimensions:
223/// - Each A[i] is m × k
224/// - Each B[i] is k × n
225/// - Each C[i] is m × n
226///
227/// # Arguments
228/// - `a_batch`: Slice of batch_size matrices, each of size m×k in row-major order
229/// - `b_batch`: Slice of batch_size matrices, each of size k×n in row-major order
230/// - `m`: Number of rows in each A matrix
231/// - `k`: Number of columns in A / rows in B
232/// - `n`: Number of columns in each B matrix
233///
234/// # Returns
235/// Vector of batch_size result matrices, each of size m×n
236///
237/// # Example
238///
239/// ```
240/// use tropical_gemm::{tropical_matmul_batched, TropicalMaxPlus};
241///
242/// // Two 2x2 matrix multiplications
243/// let a_batch = vec![
244///     vec![1.0f32, 2.0, 3.0, 4.0],  // A[0]: 2x2
245///     vec![5.0f32, 6.0, 7.0, 8.0],  // A[1]: 2x2
246/// ];
247/// let b_batch = vec![
248///     vec![1.0f32, 2.0, 3.0, 4.0],  // B[0]: 2x2
249///     vec![1.0f32, 2.0, 3.0, 4.0],  // B[1]: 2x2
250/// ];
251///
252/// let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f32>>(&a_batch, &b_batch, 2, 2, 2);
253/// assert_eq!(c_batch.len(), 2);
254/// ```
255pub fn tropical_matmul_batched<T: TropicalSemiring + KernelDispatch>(
256    a_batch: &[Vec<T::Scalar>],
257    b_batch: &[Vec<T::Scalar>],
258    m: usize,
259    k: usize,
260    n: usize,
261) -> Vec<Vec<T>>
262where
263    T::Scalar: Send + Sync,
264    T: Send + Sync,
265{
266    assert_eq!(
267        a_batch.len(),
268        b_batch.len(),
269        "Batch sizes must match: A has {} matrices, B has {}",
270        a_batch.len(),
271        b_batch.len()
272    );
273
274    let batch_size = a_batch.len();
275    if batch_size == 0 {
276        return Vec::new();
277    }
278
279    // Validate dimensions
280    for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
281        assert_eq!(
282            a.len(),
283            m * k,
284            "A[{}] dimensions mismatch: expected {}, got {}",
285            i,
286            m * k,
287            a.len()
288        );
289        assert_eq!(
290            b.len(),
291            k * n,
292            "B[{}] dimensions mismatch: expected {}, got {}",
293            i,
294            k * n,
295            b.len()
296        );
297    }
298
299    #[cfg(feature = "parallel")]
300    {
301        a_batch
302            .par_iter()
303            .zip(b_batch.par_iter())
304            .map(|(a, b)| tropical_matmul::<T>(a, m, k, b, n))
305            .collect()
306    }
307
308    #[cfg(not(feature = "parallel"))]
309    {
310        a_batch
311            .iter()
312            .zip(b_batch.iter())
313            .map(|(a, b)| tropical_matmul::<T>(a, m, k, b, n))
314            .collect()
315    }
316}
317
318/// Batched tropical matrix multiplication with argmax tracking.
319///
320/// C[i] = A[i] ⊗ B[i] for i = 0..batch_size, with argmax indices.
321///
322/// # Arguments
323/// - `a_batch`: Slice of batch_size matrices, each of size m×k
324/// - `b_batch`: Slice of batch_size matrices, each of size k×n
325/// - `m`: Number of rows in each A matrix
326/// - `k`: Number of columns in A / rows in B
327/// - `n`: Number of columns in each B matrix
328///
329/// # Returns
330/// Vector of batch_size GemmWithArgmax results
331pub fn tropical_matmul_batched_with_argmax<T: TropicalWithArgmax<Index = u32> + KernelDispatch>(
332    a_batch: &[Vec<T::Scalar>],
333    b_batch: &[Vec<T::Scalar>],
334    m: usize,
335    k: usize,
336    n: usize,
337) -> Vec<GemmWithArgmax<T>>
338where
339    T::Scalar: Send + Sync,
340    T: Send + Sync,
341{
342    assert_eq!(
343        a_batch.len(),
344        b_batch.len(),
345        "Batch sizes must match: A has {} matrices, B has {}",
346        a_batch.len(),
347        b_batch.len()
348    );
349
350    let batch_size = a_batch.len();
351    if batch_size == 0 {
352        return Vec::new();
353    }
354
355    // Validate dimensions
356    for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
357        assert_eq!(
358            a.len(),
359            m * k,
360            "A[{}] dimensions mismatch: expected {}, got {}",
361            i,
362            m * k,
363            a.len()
364        );
365        assert_eq!(
366            b.len(),
367            k * n,
368            "B[{}] dimensions mismatch: expected {}, got {}",
369            i,
370            k * n,
371            b.len()
372        );
373    }
374
375    #[cfg(feature = "parallel")]
376    {
377        a_batch
378            .par_iter()
379            .zip(b_batch.par_iter())
380            .map(|(a, b)| tropical_matmul_with_argmax::<T>(a, m, k, b, n))
381            .collect()
382    }
383
384    #[cfg(not(feature = "parallel"))]
385    {
386        a_batch
387            .iter()
388            .zip(b_batch.iter())
389            .map(|(a, b)| tropical_matmul_with_argmax::<T>(a, m, k, b, n))
390            .collect()
391    }
392}
393
394/// Strided batched GEMM: computes C[i] = A[i] ⊗ B[i] from contiguous memory.
395///
396/// This is more efficient than `tropical_matmul_batched` when all matrices
397/// are stored contiguously in memory with fixed strides.
398///
399/// # Arguments
400/// - `a`: Contiguous array of all A matrices (batch_size × m × k elements)
401/// - `b`: Contiguous array of all B matrices (batch_size × k × n elements)
402/// - `batch_size`: Number of matrix pairs
403/// - `m`: Rows in each A
404/// - `k`: Columns in A / rows in B
405/// - `n`: Columns in each B
406///
407/// # Returns
408/// Contiguous array of all C matrices (batch_size × m × n elements)
409///
410/// # Example
411///
412/// ```
413/// use tropical_gemm::{tropical_matmul_strided_batched, TropicalMaxPlus};
414///
415/// // Two 2x2 matrix pairs stored contiguously
416/// let a = vec![
417///     1.0f32, 2.0, 3.0, 4.0,  // A[0]
418///     5.0, 6.0, 7.0, 8.0,      // A[1]
419/// ];
420/// let b = vec![
421///     1.0f32, 2.0, 3.0, 4.0,  // B[0]
422///     1.0, 2.0, 3.0, 4.0,      // B[1]
423/// ];
424///
425/// let c = tropical_matmul_strided_batched::<TropicalMaxPlus<f32>>(&a, &b, 2, 2, 2, 2);
426/// assert_eq!(c.len(), 8); // 2 batches × 2×2 results
427/// ```
428pub fn tropical_matmul_strided_batched<T: TropicalSemiring + KernelDispatch>(
429    a: &[T::Scalar],
430    b: &[T::Scalar],
431    batch_size: usize,
432    m: usize,
433    k: usize,
434    n: usize,
435) -> Vec<T>
436where
437    T::Scalar: Send + Sync + Copy,
438    T: Send + Sync,
439{
440    let a_stride = m * k;
441    let b_stride = k * n;
442    let c_stride = m * n;
443
444    assert_eq!(
445        a.len(),
446        batch_size * a_stride,
447        "A size mismatch: expected {}, got {}",
448        batch_size * a_stride,
449        a.len()
450    );
451    assert_eq!(
452        b.len(),
453        batch_size * b_stride,
454        "B size mismatch: expected {}, got {}",
455        batch_size * b_stride,
456        b.len()
457    );
458
459    if batch_size == 0 {
460        return Vec::new();
461    }
462
463    let mut c = vec![T::tropical_zero(); batch_size * c_stride];
464
465    #[cfg(feature = "parallel")]
466    {
467        c.par_chunks_mut(c_stride)
468            .enumerate()
469            .for_each(|(i, c_chunk)| {
470                let a_slice = &a[i * a_stride..(i + 1) * a_stride];
471                let b_slice = &b[i * b_stride..(i + 1) * b_stride];
472
473                unsafe {
474                    tropical_gemm_dispatch::<T>(
475                        m,
476                        n,
477                        k,
478                        a_slice.as_ptr(),
479                        k,
480                        Transpose::NoTrans,
481                        b_slice.as_ptr(),
482                        n,
483                        Transpose::NoTrans,
484                        c_chunk.as_mut_ptr(),
485                        n,
486                    );
487                }
488            });
489    }
490
491    #[cfg(not(feature = "parallel"))]
492    {
493        for i in 0..batch_size {
494            let a_slice = &a[i * a_stride..(i + 1) * a_stride];
495            let b_slice = &b[i * b_stride..(i + 1) * b_stride];
496            let c_slice = &mut c[i * c_stride..(i + 1) * c_stride];
497
498            unsafe {
499                tropical_gemm_dispatch::<T>(
500                    m,
501                    n,
502                    k,
503                    a_slice.as_ptr(),
504                    k,
505                    Transpose::NoTrans,
506                    b_slice.as_ptr(),
507                    n,
508                    Transpose::NoTrans,
509                    c_slice.as_mut_ptr(),
510                    n,
511                );
512            }
513        }
514    }
515
516    c
517}
518
519// ============================================================================
520// Backward Pass (Gradient Computation)
521// ============================================================================
522
/// Compute gradient with respect to matrix A in tropical matmul.
///
/// Given the forward pass C = A ⊗ B with argmax tracking, and upstream
/// gradient dL/dC, computes dL/dA.
///
/// Each upstream entry is routed to the single A element that won the
/// forward reduction for that output cell:
/// ```text
/// dL/dA[i,k] = Σ_j { dL/dC[i,j] if argmax[i,j] == k, else 0 }
/// ```
/// Entries whose recorded index is not below `k` contribute nothing.
///
/// # Arguments
///
/// * `grad_c` - Upstream gradient dL/dC, size m×n
/// * `argmax` - Argmax indices from forward pass, size m×n
/// * `m` - Number of rows in A
/// * `k` - Number of columns in A
/// * `n` - Number of columns in C (used for argmax indexing)
///
/// # Returns
///
/// Gradient dL/dA of size m×k
///
/// # Panics
///
/// Panics if `grad_c` or `argmax` does not contain exactly `m * n` elements.
///
/// # Example
///
/// ```
/// use tropical_gemm::{tropical_matmul_with_argmax, tropical_backward_a, TropicalMaxPlus};
///
/// let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
/// let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
///
/// // Forward pass
/// let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);
///
/// // Upstream gradient (e.g., all ones)
/// let grad_c = [1.0f64; 4]; // 2x2
///
/// // Backward pass for A
/// let grad_a = tropical_backward_a::<f64>(&grad_c, result.argmax_slice(), 2, 3, 2);
/// assert_eq!(grad_a.len(), 6); // 2x3
/// ```
pub fn tropical_backward_a<T: Copy + Default + std::ops::AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "grad_c size mismatch");
    assert_eq!(argmax.len(), m * n, "argmax size mismatch");

    let mut grad_a = vec![T::default(); m * k];

    // Walk C in row-major order; `flat / n` recovers the row of each entry.
    // When n == 0 the slices are empty, so the division is never evaluated.
    for (flat, (&upstream, &winner)) in grad_c.iter().zip(argmax).enumerate() {
        let source_k = winner as usize;
        if source_k < k {
            grad_a[(flat / n) * k + source_k] += upstream;
        }
    }

    grad_a
}
586
/// Compute gradient with respect to matrix B in tropical matmul.
///
/// Given the forward pass C = A ⊗ B with argmax tracking, and upstream
/// gradient dL/dC, computes dL/dB.
///
/// Each upstream entry is routed to the single B element that won the
/// forward reduction for that output cell:
/// ```text
/// dL/dB[k,j] = Σ_i { dL/dC[i,j] if argmax[i,j] == k, else 0 }
/// ```
/// Entries whose recorded index is not below `k` contribute nothing.
///
/// # Arguments
///
/// * `grad_c` - Upstream gradient dL/dC, size m×n
/// * `argmax` - Argmax indices from forward pass, size m×n
/// * `m` - Number of rows in C (used for iteration)
/// * `k` - Number of rows in B
/// * `n` - Number of columns in B
///
/// # Returns
///
/// Gradient dL/dB of size k×n
///
/// # Panics
///
/// Panics if `grad_c` or `argmax` does not contain exactly `m * n` elements.
///
/// # Example
///
/// ```
/// use tropical_gemm::{tropical_matmul_with_argmax, tropical_backward_b, TropicalMaxPlus};
///
/// let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
/// let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
///
/// // Forward pass
/// let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);
///
/// // Upstream gradient
/// let grad_c = [1.0f64; 4]; // 2x2
///
/// // Backward pass for B
/// let grad_b = tropical_backward_b::<f64>(&grad_c, result.argmax_slice(), 2, 3, 2);
/// assert_eq!(grad_b.len(), 6); // 3x2
/// ```
pub fn tropical_backward_b<T: Copy + Default + std::ops::AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "grad_c size mismatch");
    assert_eq!(argmax.len(), m * n, "argmax size mismatch");

    let mut grad_b = vec![T::default(); k * n];

    // Walk C in row-major order; `flat % n` recovers the column of each entry.
    // When n == 0 the slices are empty, so the remainder is never evaluated.
    for (flat, (&upstream, &winner)) in grad_c.iter().zip(argmax).enumerate() {
        let source_k = winner as usize;
        if source_k < k {
            grad_b[source_k * n + flat % n] += upstream;
        }
    }

    grad_b
}
650
651/// Batched backward pass for gradient with respect to A.
652///
653/// Computes dL/dA[i] for each batch element.
654///
655/// # Arguments
656///
657/// * `grad_c_batch` - Batch of upstream gradients, each size m×n
658/// * `argmax_batch` - Batch of argmax indices from forward pass
659/// * `m` - Number of rows in A
660/// * `k` - Number of columns in A
661/// * `n` - Number of columns in C
662///
663/// # Returns
664///
665/// Vector of gradients dL/dA[i], each of size m×k
666pub fn tropical_backward_a_batched<T: Copy + Default + std::ops::AddAssign + Send + Sync>(
667    grad_c_batch: &[Vec<T>],
668    argmax_batch: &[Vec<u32>],
669    m: usize,
670    k: usize,
671    n: usize,
672) -> Vec<Vec<T>> {
673    assert_eq!(
674        grad_c_batch.len(),
675        argmax_batch.len(),
676        "Batch sizes must match"
677    );
678
679    #[cfg(feature = "parallel")]
680    {
681        grad_c_batch
682            .par_iter()
683            .zip(argmax_batch.par_iter())
684            .map(|(grad_c, argmax)| tropical_backward_a(grad_c, argmax, m, k, n))
685            .collect()
686    }
687
688    #[cfg(not(feature = "parallel"))]
689    {
690        grad_c_batch
691            .iter()
692            .zip(argmax_batch.iter())
693            .map(|(grad_c, argmax)| tropical_backward_a(grad_c, argmax, m, k, n))
694            .collect()
695    }
696}
697
698/// Batched backward pass for gradient with respect to B.
699///
700/// Computes dL/dB[i] for each batch element.
701///
702/// # Arguments
703///
704/// * `grad_c_batch` - Batch of upstream gradients, each size m×n
705/// * `argmax_batch` - Batch of argmax indices from forward pass
706/// * `m` - Number of rows in C
707/// * `k` - Number of rows in B
708/// * `n` - Number of columns in B
709///
710/// # Returns
711///
712/// Vector of gradients dL/dB[i], each of size k×n
713pub fn tropical_backward_b_batched<T: Copy + Default + std::ops::AddAssign + Send + Sync>(
714    grad_c_batch: &[Vec<T>],
715    argmax_batch: &[Vec<u32>],
716    m: usize,
717    k: usize,
718    n: usize,
719) -> Vec<Vec<T>> {
720    assert_eq!(
721        grad_c_batch.len(),
722        argmax_batch.len(),
723        "Batch sizes must match"
724    );
725
726    #[cfg(feature = "parallel")]
727    {
728        grad_c_batch
729            .par_iter()
730            .zip(argmax_batch.par_iter())
731            .map(|(grad_c, argmax)| tropical_backward_b(grad_c, argmax, m, k, n))
732            .collect()
733    }
734
735    #[cfg(not(feature = "parallel"))]
736    {
737        grad_c_batch
738            .iter()
739            .zip(argmax_batch.iter())
740            .map(|(grad_c, argmax)| tropical_backward_b(grad_c, argmax, m, k, n))
741            .collect()
742    }
743}
744
745#[cfg(test)]
746mod tests {
747    use super::*;
748    use crate::types::TropicalMaxPlus;
749
750    #[test]
751    fn test_tropical_matmul() {
752        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
753        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
754
755        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);
756
757        // C[0,0] = max(1+1, 2+3, 3+5) = 8
758        assert_eq!(c[0].0, 8.0);
759        // C[0,1] = max(1+2, 2+4, 3+6) = 9
760        assert_eq!(c[1].0, 9.0);
761        // C[1,0] = max(4+1, 5+3, 6+5) = 11
762        assert_eq!(c[2].0, 11.0);
763        // C[1,1] = max(4+2, 5+4, 6+6) = 12
764        assert_eq!(c[3].0, 12.0);
765    }
766
767    #[test]
768    fn test_tropical_matmul_with_argmax() {
769        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
770        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
771
772        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);
773
774        assert_eq!(result.get(0, 0).0, 8.0);
775        assert_eq!(result.get_argmax(0, 0), 2); // k=2 produced max
776
777        assert_eq!(result.get(1, 1).0, 12.0);
778        assert_eq!(result.get_argmax(1, 1), 2); // k=2 produced max
779    }
780
781    #[test]
782    fn test_builder_api() {
783        let a = vec![1.0f32; 6];
784        let b = vec![1.0f32; 6];
785        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
786
787        TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3).execute(&a, 3, &b, 2, &mut c, 2);
788
789        // C[0,0] = max(1+1, 1+1, 1+1) = 2 (tropical mul is addition, tropical add is max)
790        assert_eq!(c[0].0, 2.0);
791    }
792
793    #[test]
794    fn test_builder_api_trans_a() {
795        // A is 3x2 stored as column-major (actually 2x3 in row-major transposed)
796        // A^T is 2x3, B is 3x2, result is 2x2
797        let a = vec![1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0]; // col-major 3x2
798        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // row-major 3x2
799        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
800
801        TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3)
802            .trans_a()
803            .execute(&a, 2, &b, 2, &mut c, 2);
804
805        // A^T = [[1, 2, 3], [4, 5, 6]]
806        // B = [[1, 2], [3, 4], [5, 6]]
807        // C[0,0] = max(1+1, 2+3, 3+5) = 8
808        assert_eq!(c[0].0, 8.0);
809    }
810
811    #[test]
812    fn test_builder_api_trans_b() {
813        // A is 2x3, B^T is 2x3 stored as column-major, result is 2x2
814        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // row-major 2x3
815        let b = vec![1.0f32, 3.0, 5.0, 2.0, 4.0, 6.0]; // col-major 2x3
816        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
817
818        TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3)
819            .trans_b()
820            .execute(&a, 3, &b, 3, &mut c, 2);
821
822        // A = [[1, 2, 3], [4, 5, 6]]
823        // B^T = [[1, 2], [3, 4], [5, 6]]
824        // C[0,0] = max(1+1, 2+3, 3+5) = 8
825        assert_eq!(c[0].0, 8.0);
826    }
827
828    #[test]
829    fn test_tropical_matmul_min_plus() {
830        use crate::types::TropicalMinPlus;
831
832        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
833        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
834
835        let c = tropical_matmul::<TropicalMinPlus<f64>>(&a, 2, 3, &b, 2);
836
837        // C[0,0] = min(1+1, 2+3, 3+5) = 2
838        assert_eq!(c[0].0, 2.0);
839        // C[0,1] = min(1+2, 2+4, 3+6) = 3
840        assert_eq!(c[1].0, 3.0);
841        // C[1,0] = min(4+1, 5+3, 6+5) = 5
842        assert_eq!(c[2].0, 5.0);
843        // C[1,1] = min(4+2, 5+4, 6+6) = 6
844        assert_eq!(c[3].0, 6.0);
845    }
846
847    #[test]
848    fn test_tropical_matmul_max_mul() {
849        use crate::types::TropicalMaxMul;
850
851        let a = vec![2.0f64, 3.0, 4.0, 5.0];
852        let b = vec![1.0f64, 2.0, 3.0, 4.0];
853
854        let c = tropical_matmul::<TropicalMaxMul<f64>>(&a, 2, 2, &b, 2);
855
856        // C[0,0] = max(2*1, 3*3) = max(2, 9) = 9
857        assert_eq!(c[0].0, 9.0);
858        // C[0,1] = max(2*2, 3*4) = max(4, 12) = 12
859        assert_eq!(c[1].0, 12.0);
860        // C[1,0] = max(4*1, 5*3) = max(4, 15) = 15
861        assert_eq!(c[2].0, 15.0);
862        // C[1,1] = max(4*2, 5*4) = max(8, 20) = 20
863        assert_eq!(c[3].0, 20.0);
864    }
865
866    #[test]
867    fn test_tropical_matmul_f32() {
868        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
869        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
870
871        let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2);
872
873        assert!((c[0].0 - 8.0).abs() < 1e-6);
874        assert!((c[1].0 - 9.0).abs() < 1e-6);
875        assert!((c[2].0 - 11.0).abs() < 1e-6);
876        assert!((c[3].0 - 12.0).abs() < 1e-6);
877    }
878
879    #[test]
880    fn test_non_square_matrices() {
881        // 3x2 * 2x4 = 3x4
882        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
883        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
884
885        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 3, 2, &b, 4);
886
887        assert_eq!(c.len(), 12);
888        // C[0,0] = max(1+1, 2+5) = 7
889        assert_eq!(c[0].0, 7.0);
890    }
891
892    #[test]
893    fn test_single_element() {
894        let a = vec![5.0f64];
895        let b = vec![3.0f64];
896
897        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 1, 1, &b, 1);
898
899        assert_eq!(c.len(), 1);
900        assert_eq!(c[0].0, 8.0); // 5 + 3 = 8
901    }
902
903    #[test]
904    fn test_larger_matrix() {
905        let n = 16;
906        let a: Vec<f64> = (0..n * n).map(|i| i as f64).collect();
907        let b: Vec<f64> = (0..n * n).map(|i| (n * n - 1 - i) as f64).collect();
908
909        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, n, n, &b, n);
910
911        assert_eq!(c.len(), n * n);
912        // Just verify it doesn't panic and produces reasonable results
913        for val in &c {
914            assert!(val.0.is_finite());
915        }
916    }
917
918    #[test]
919    fn test_tropical_matmul_i32() {
920        let a = vec![1i32, 2, 3, 4, 5, 6];
921        let b = vec![1i32, 2, 3, 4, 5, 6];
922
923        let c = tropical_matmul::<TropicalMaxPlus<i32>>(&a, 2, 3, &b, 2);
924
925        assert_eq!(c[0].0, 8);
926        assert_eq!(c[1].0, 9);
927        assert_eq!(c[2].0, 11);
928        assert_eq!(c[3].0, 12);
929    }
930
931    #[test]
932    fn test_tropical_matmul_i64() {
933        let a = vec![1i64, 2, 3, 4, 5, 6];
934        let b = vec![1i64, 2, 3, 4, 5, 6];
935
936        let c = tropical_matmul::<TropicalMaxPlus<i64>>(&a, 2, 3, &b, 2);
937
938        assert_eq!(c[0].0, 8);
939        assert_eq!(c[1].0, 9);
940        assert_eq!(c[2].0, 11);
941        assert_eq!(c[3].0, 12);
942    }
943
944    #[test]
945    fn test_tropical_matmul_minplus_i32() {
946        use crate::types::TropicalMinPlus;
947
948        let a = vec![1i32, 2, 3, 4, 5, 6];
949        let b = vec![1i32, 2, 3, 4, 5, 6];
950
951        let c = tropical_matmul::<TropicalMinPlus<i32>>(&a, 2, 3, &b, 2);
952
953        assert_eq!(c[0].0, 2);
954        assert_eq!(c[1].0, 3);
955        assert_eq!(c[2].0, 5);
956        assert_eq!(c[3].0, 6);
957    }
958
959    #[test]
960    fn test_unsafe_tropical_gemm() {
961        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
962        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
963        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
964
965        unsafe {
966            tropical_gemm::<TropicalMaxPlus<f64>>(
967                2,
968                2,
969                3,
970                a.as_ptr(),
971                3,
972                Transpose::NoTrans,
973                b.as_ptr(),
974                2,
975                Transpose::NoTrans,
976                c.as_mut_ptr(),
977                2,
978            );
979        }
980
981        assert_eq!(c[0].0, 8.0);
982        assert_eq!(c[1].0, 9.0);
983        assert_eq!(c[2].0, 11.0);
984        assert_eq!(c[3].0, 12.0);
985    }
986
987    #[test]
988    fn test_minplus_with_argmax() {
989        use crate::types::TropicalMinPlus;
990
991        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
992        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
993
994        let result = tropical_matmul_with_argmax::<TropicalMinPlus<f64>>(&a, 2, 3, &b, 2);
995
996        // C[0,0] = min(1+1, 2+3, 3+5) = 2 at k=0
997        assert_eq!(result.get(0, 0).0, 2.0);
998        assert_eq!(result.get_argmax(0, 0), 0);
999
1000        // C[1,1] = min(4+2, 5+4, 6+6) = 6 at k=0
1001        assert_eq!(result.get(1, 1).0, 6.0);
1002        assert_eq!(result.get_argmax(1, 1), 0);
1003    }
1004
1005    #[test]
1006    fn test_maxmul_with_argmax() {
1007        use crate::types::TropicalMaxMul;
1008
1009        let a = vec![2.0f64, 3.0, 4.0, 5.0];
1010        let b = vec![1.0f64, 2.0, 3.0, 4.0];
1011
1012        let result = tropical_matmul_with_argmax::<TropicalMaxMul<f64>>(&a, 2, 2, &b, 2);
1013
1014        // C[0,0] = max(2*1, 3*3) = 9 at k=1
1015        assert_eq!(result.get(0, 0).0, 9.0);
1016        assert_eq!(result.get_argmax(0, 0), 1);
1017    }
1018
1019    #[test]
1020    fn test_gemmwithargmax_dimensions() {
1021        let a = vec![1.0f64; 12]; // 3x4
1022        let b = vec![1.0f64; 20]; // 4x5
1023
1024        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 3, 4, &b, 5);
1025
1026        assert_eq!(result.m, 3);
1027        assert_eq!(result.n, 5);
1028        assert_eq!(result.values.len(), 15);
1029        assert_eq!(result.argmax.len(), 15);
1030    }
1031
#[test]
fn test_identity_like_matrix() {
    // A carries 0 on the diagonal and -inf off it; the assertions below
    // expect such an A to leave B unchanged under max-plus.
    let neg_inf = f64::NEG_INFINITY;
    let a = vec![0.0f64, neg_inf, neg_inf, 0.0];
    let b = vec![1.0f64, 2.0, 3.0, 4.0];

    let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 2, 2, &b, 2);

    // Every entry of C must reproduce the entry of B at the same flat index.
    let expected = [1.0, 2.0, 3.0, 4.0];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(c[idx].0, *want);
    }
}
1046
#[test]
fn test_tropical_matmul_batched() {
    // Two independent (2x3) * (3x2) products; both use the same right operand.
    let rhs = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
    let a_batch = vec![
        vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0], // 2x3
        vec![2.0f64, 3.0, 4.0, 5.0, 6.0, 7.0], // 2x3
    ];
    let b_batch = vec![rhs.clone(), rhs];

    let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f64>>(&a_batch, &b_batch, 2, 3, 2);

    assert_eq!(c_batch.len(), 2);

    // (batch entry, flat index into that entry's 2x2 result, expected value)
    let checks: [(usize, usize, f64); 4] = [
        (0, 0, 8.0),  // C[0][0,0] = max(1+1, 2+3, 3+5)
        (0, 3, 12.0), // C[0][1,1] = max(4+2, 5+4, 6+6)
        (1, 0, 9.0),  // C[1][0,0] = max(2+1, 3+3, 4+5)
        (1, 3, 13.0), // C[1][1,1] = max(5+2, 6+4, 7+6)
    ];
    for (entry, flat, want) in checks.iter() {
        assert_eq!(c_batch[*entry][*flat].0, *want);
    }
}
1072
#[test]
fn test_tropical_matmul_batched_empty() {
    // With no batch entries on either side, the returned batch must be empty
    // even though non-zero dimensions are passed.
    let no_lhs = Vec::<Vec<f64>>::new();
    let no_rhs = Vec::<Vec<f64>>::new();

    let out = tropical_matmul_batched::<TropicalMaxPlus<f64>>(&no_lhs, &no_rhs, 2, 2, 2);

    assert!(out.is_empty());
}
1082
#[test]
fn test_tropical_matmul_batched_with_argmax() {
    // Both batch entries share the same 2x3 left operand; the right operands
    // differ only in their first element (1 vs 10), which changes which k
    // wins for C[0,0].
    let lhs = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
    let a_batch = vec![lhs.clone(), lhs];
    let b_batch = vec![
        vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],  // 3x2
        vec![10.0f64, 2.0, 3.0, 4.0, 5.0, 6.0], // 3x2
    ];

    let results = tropical_matmul_batched_with_argmax::<TropicalMaxPlus<f64>>(
        &a_batch, &b_batch, 2, 3, 2,
    );

    assert_eq!(results.len(), 2);

    // Entry 0: C[0,0] = max(1+1, 2+3, 3+5) = 8, reached at k=2.
    let first = &results[0];
    assert_eq!(first.get(0, 0).0, 8.0);
    assert_eq!(first.get_argmax(0, 0), 2);

    // Entry 1: C[0,0] = max(1+10, 2+3, 3+5) = 11, reached at k=0.
    let second = &results[1];
    assert_eq!(second.get(0, 0).0, 11.0);
    assert_eq!(second.get_argmax(0, 0), 0);
}
1108
#[test]
fn test_tropical_matmul_batched_with_argmax_empty() {
    // Empty input batches must yield an empty list of argmax-tracking results.
    let no_lhs = Vec::<Vec<f64>>::new();
    let no_rhs = Vec::<Vec<f64>>::new();

    let out = tropical_matmul_batched_with_argmax::<TropicalMaxPlus<f64>>(
        &no_lhs, &no_rhs, 2, 2, 2,
    );

    assert!(out.is_empty());
}
1120
#[test]
fn test_tropical_matmul_strided_batched() {
    // Two 2x2 matrices per operand, laid out back to back in one flat buffer.
    let a: Vec<f64> = vec![
        1.0, 2.0, 3.0, 4.0, // A[0]
        5.0, 6.0, 7.0, 8.0, // A[1]
    ];
    // B[0] and B[1] are the same matrix, so build the buffer by repetition.
    let b: Vec<f64> = [1.0f64, 2.0, 3.0, 4.0].repeat(2);

    let c = tropical_matmul_strided_batched::<TropicalMaxPlus<f64>>(&a, &b, 2, 2, 2, 2);

    assert_eq!(c.len(), 8);

    // (flat index into the concatenated result, expected value)
    let checks: [(usize, f64); 4] = [
        (0, 5.0),  // C[0][0,0] = max(1+1, 2+3)
        (3, 8.0),  // C[0][1,1] = max(3+2, 4+4)
        (4, 9.0),  // C[1][0,0] = max(5+1, 6+3)
        (7, 12.0), // C[1][1,1] = max(7+2, 8+4)
    ];
    for (flat, want) in checks.iter() {
        assert_eq!(c[*flat].0, *want);
    }
}
1147
#[test]
fn test_tropical_matmul_strided_batched_empty() {
    // Empty flat buffers with a leading size argument of 0 must give an
    // empty result.
    let no_a = Vec::<f64>::new();
    let no_b = Vec::<f64>::new();

    let out = tropical_matmul_strided_batched::<TropicalMaxPlus<f64>>(&no_a, &no_b, 0, 2, 2, 2);

    assert!(out.is_empty());
}
1157
#[test]
fn test_tropical_matmul_strided_batched_minplus() {
    use crate::types::TropicalMinPlus;

    // Two 2x2 matrices per operand, stored contiguously.
    let a = vec![
        1.0f64, 2.0, 3.0, 4.0, // A[0]
        5.0, 6.0, 7.0, 8.0, // A[1]
    ];
    let b = vec![
        1.0f64, 2.0, 3.0, 4.0, // B[0]
        1.0, 2.0, 3.0, 4.0, // B[1]
    ];

    let c = tropical_matmul_strided_batched::<TropicalMinPlus<f64>>(&a, &b, 2, 2, 2, 2);

    assert_eq!(c.len(), 8);

    // C[0][0,0] = min(1+1, 2+3) = 2
    assert_eq!(c[0].0, 2.0);
    // C[0][1,1] = min(3+2, 4+4) = 5
    assert_eq!(c[3].0, 5.0);

    // The second batch entry is checked as well, so that a wrong offset into
    // the flat buffers cannot go unnoticed under min-plus.
    // C[1][0,0] = min(5+1, 6+3) = 6
    assert_eq!(c[4].0, 6.0);
    // C[1][1,1] = min(7+2, 8+4) = 9
    assert_eq!(c[7].0, 9.0);
}
1180
#[test]
fn test_tropical_matmul_batched_larger() {
    // Smoke test on a bigger batch: only result shapes and finiteness are
    // checked, not individual values.
    let batch_size = 10;
    let (m, k, n) = (8, 6, 4);

    let mut a_batch: Vec<Vec<f64>> = Vec::new();
    let mut b_batch: Vec<Vec<f64>> = Vec::new();
    for i in 0..batch_size {
        // Each A continues the running count where the previous entry stopped.
        let a_entry: Vec<f64> = (0..m * k).map(|j| (i * m * k + j) as f64).collect();
        // Every B is the same 0, 1, 2, ... ramp.
        let b_entry: Vec<f64> = (0..k * n).map(|j| j as f64).collect();
        a_batch.push(a_entry);
        b_batch.push(b_entry);
    }

    let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f64>>(&a_batch, &b_batch, m, k, n);

    assert_eq!(c_batch.len(), batch_size);
    for matrix in c_batch.iter() {
        assert_eq!(matrix.len(), m * n);
        // Finite inputs are expected to produce finite outputs.
        assert!(matrix.iter().all(|cell| cell.0.is_finite()));
    }
}
1206
1207    // ========================================================================
1208    // Backward pass tests
1209    // ========================================================================
1210
#[test]
fn test_tropical_backward_a() {
    // Shapes: A is 2x3, B is 3x2, so C is 2x2.
    let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

    // Forward pass, keeping track of the winning k per output cell.
    let forward = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

    // The last term wins in every cell:
    //   C[0,0] = max(1+1, 2+3, 3+5) = 8
    //   C[0,1] = max(1+2, 2+4, 3+6) = 9
    //   C[1,0] = max(4+1, 5+3, 6+5) = 11
    //   C[1,1] = max(4+2, 5+4, 6+6) = 12
    for i in 0..2 {
        for j in 0..2 {
            assert_eq!(forward.get_argmax(i, j), 2);
        }
    }

    // Upstream gradient of all ones.
    let grad_c = vec![1.0f64; 4];

    let grad_a = tropical_backward_a(&grad_c, forward.argmax_slice(), 2, 3, 2);

    // With argmax = 2 everywhere, each row of A receives its two units of
    // gradient in column 2 and nothing in columns 0 and 1.
    let expected = [
        0.0, 0.0, 2.0, // A[0,0], A[0,1], A[0,2]
        0.0, 0.0, 2.0, // A[1,0], A[1,1], A[1,2]
    ];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(grad_a[idx], *want);
    }
}
1246
#[test]
fn test_tropical_backward_b() {
    // A is 2x3 and B is 3x2, both holding the values 1..6.
    let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

    let forward = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

    // Upstream gradient of all ones.
    let grad_c = vec![1.0f64; 4];

    let grad_b = tropical_backward_b(&grad_c, forward.argmax_slice(), 2, 3, 2);

    // The assertions expect all gradient to land in the last row of B:
    // B[2,0] collects from C[0,0] and C[1,0], B[2,1] from C[0,1] and C[1,1].
    let expected = [
        0.0, 0.0, // B[0,0], B[0,1]
        0.0, 0.0, // B[1,0], B[1,1]
        2.0, 2.0, // B[2,0], B[2,1]
    ];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(grad_b[idx], *want);
    }
}
1270
#[test]
fn test_tropical_backward_varied_argmax() {
    // Operands chosen so that different k-indices win in different cells:
    // A = [[10, 1], [1, 10]]
    // B = [[1, 10], [10, 1]]
    let a = vec![10.0f64, 1.0, 1.0, 10.0];
    let b = vec![1.0f64, 10.0, 10.0, 1.0];
    let (m, k, n) = (2usize, 2usize, 2usize);

    let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, m, k, &b, n);

    // C[0,0] = max(10+1, 1+10) = 11, tie between k=0 and k=1
    // C[0,1] = max(10+10, 1+1) = 20, argmax=0
    // C[1,0] = max(1+1, 10+10) = 20, argmax=1
    // C[1,1] = max(1+10, 10+1) = 11, tie between k=0 and k=1
    //
    // Only the unambiguous cells are pinned to a specific k. The tied cells
    // are presumably resolved towards the lower k, but that is not asserted
    // here so the test does not depend on the tie-breaking rule.
    assert_eq!(result.get_argmax(0, 1), 0);
    assert_eq!(result.get_argmax(1, 0), 1);

    let grad_c = vec![1.0f64; m * n];
    let grad_a = tropical_backward_a(&grad_c, result.argmax_slice(), m, k, n);
    let grad_b = tropical_backward_b(&grad_c, result.argmax_slice(), m, k, n);

    assert_eq!(grad_a.len(), 4);
    assert_eq!(grad_b.len(), 4);

    // Reference scatter built from the reported argmax: the gradient of
    // C[i,j] goes to A[i, k*] and B[k*, j], where k* is the winning index.
    // This checks that gradients follow the argmax whichever way ties fall.
    let mut expected_a = vec![0.0f64; m * k];
    let mut expected_b = vec![0.0f64; k * n];
    for i in 0..m {
        for j in 0..n {
            let winner = result.get_argmax(i, j) as usize;
            let upstream = grad_c[i * n + j];
            expected_a[i * k + winner] += upstream;
            expected_b[winner * n + j] += upstream;
        }
    }
    assert_eq!(grad_a, expected_a);
    assert_eq!(grad_b, expected_b);

    // Each output cell routes its unit gradient to exactly one element of A
    // and one of B, so both totals equal the number of output cells.
    let total_grad_a: f64 = grad_a.iter().sum();
    let total_grad_b: f64 = grad_b.iter().sum();
    assert_eq!(total_grad_a, 4.0);
    assert_eq!(total_grad_b, 4.0);
}
1300
#[test]
fn test_tropical_backward_batched() {
    // One forward pass supplies the argmax used for both batch entries.
    let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let forward = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

    // Same argmax in both entries; upstream gradients of 1s and 2s respectively.
    let winners = forward.argmax_slice().to_vec();
    let argmax_batch = vec![winners.clone(), winners];
    let grad_c_batch = vec![vec![1.0f64; 4], vec![2.0f64; 4]];

    let grad_a_batch = tropical_backward_a_batched(&grad_c_batch, &argmax_batch, 2, 3, 2);
    let grad_b_batch = tropical_backward_b_batched(&grad_c_batch, &argmax_batch, 2, 3, 2);

    assert_eq!(grad_a_batch.len(), 2);
    assert_eq!(grad_b_batch.len(), 2);

    // Entry 0 (upstream 1) yields 2 at the probed slots; entry 1 (upstream 2)
    // yields twice that.
    let probes: [(usize, f64); 2] = [(0, 2.0), (1, 4.0)];
    for (entry, want) in probes.iter() {
        assert_eq!(grad_a_batch[*entry][2], *want);
        assert_eq!(grad_b_batch[*entry][4], *want);
    }
}
1329}