Skip to main content

tropical_gemm/core/
gemm.rs

1use super::argmax::GemmWithArgmax;
2use super::kernel::{Microkernel, MicrokernelWithArgmax, PortableMicrokernel};
3use super::packing::{pack_a, pack_b, packed_a_size, packed_b_size, Layout, Transpose};
4use super::tiling::{BlockIterator, TilingParams};
5use crate::types::{TropicalSemiring, TropicalWithArgmax};
6
7/// Tropical GEMM: C = A ⊗ B
8///
9/// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
10///
11/// This is a portable (non-SIMD) implementation using BLIS-style blocking
12/// for cache efficiency.
13///
14/// # Parameters
15/// - `m`: Number of rows in A and C
16/// - `n`: Number of columns in B and C
17/// - `k`: Number of columns in A / rows in B
18/// - `a`: Pointer to matrix A data
19/// - `lda`: Leading dimension of A
20/// - `trans_a`: Whether A is transposed
21/// - `b`: Pointer to matrix B data
22/// - `ldb`: Leading dimension of B
23/// - `trans_b`: Whether B is transposed
24/// - `c`: Pointer to matrix C data (output)
25/// - `ldc`: Leading dimension of C
26///
27/// # Safety
28/// - All pointers must be valid for the specified dimensions
29/// - Memory regions must not overlap inappropriately
30pub unsafe fn tropical_gemm_portable<T: TropicalSemiring>(
31    m: usize,
32    n: usize,
33    k: usize,
34    a: *const T::Scalar,
35    lda: usize,
36    trans_a: Transpose,
37    b: *const T::Scalar,
38    ldb: usize,
39    trans_b: Transpose,
40    c: *mut T,
41    ldc: usize,
42) {
43    let params = TilingParams::PORTABLE;
44    let kernel = PortableMicrokernel;
45
46    tropical_gemm_inner::<T, PortableMicrokernel>(
47        m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
48    );
49}
50
51/// Tropical GEMM with custom kernel and tiling parameters.
52///
53/// # Safety
54/// Same requirements as `tropical_gemm_portable`
55pub unsafe fn tropical_gemm_inner<T: TropicalSemiring, K: Microkernel<T>>(
56    m: usize,
57    n: usize,
58    k: usize,
59    a: *const T::Scalar,
60    lda: usize,
61    trans_a: Transpose,
62    b: *const T::Scalar,
63    ldb: usize,
64    trans_b: Transpose,
65    c: *mut T,
66    ldc: usize,
67    params: &TilingParams,
68    kernel: &K,
69) {
70    if m == 0 || n == 0 || k == 0 {
71        return;
72    }
73
74    // TODO(#34): Avoid repeated allocation by accepting caller-provided workspace.
75    // For repeated GEMM calls, consider adding a workspace-based API:
76    //   pub struct GemmWorkspace<T> { packed_a: Vec<T>, packed_b: Vec<T> }
77    //   pub fn tropical_gemm_with_workspace(..., workspace: &mut GemmWorkspace<T>)
78    let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
79    let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
80
81    // BLIS-style 5-loop blocking
82    // Loop 5: blocks of n
83    for (jc, nc) in BlockIterator::new(n, params.nc) {
84        // Loop 4: blocks of k
85        for (pc, kc) in BlockIterator::new(k, params.kc) {
86            // Pack B panel: kc × nc
87            pack_b::<T::Scalar>(
88                kc,
89                nc,
90                b_panel_ptr(b, pc, jc, ldb, trans_b),
91                ldb,
92                Layout::RowMajor,
93                trans_b,
94                packed_b.as_mut_ptr(),
95                K::NR,
96            );
97
98            // Loop 3: blocks of m
99            for (ic, mc) in BlockIterator::new(m, params.mc) {
100                // Pack A panel: mc × kc
101                pack_a::<T::Scalar>(
102                    mc,
103                    kc,
104                    a_panel_ptr(a, ic, pc, lda, trans_a),
105                    lda,
106                    Layout::RowMajor,
107                    trans_a,
108                    packed_a.as_mut_ptr(),
109                    K::MR,
110                );
111
112                // Loop 2: micro-blocks of n
113                let n_blocks = nc.div_ceil(K::NR);
114                for jr in 0..n_blocks {
115                    let j_start = jr * K::NR;
116                    let nr = (nc - j_start).min(K::NR);
117
118                    // Loop 1: micro-blocks of m
119                    let m_blocks = mc.div_ceil(K::MR);
120                    for ir in 0..m_blocks {
121                        let i_start = ir * K::MR;
122                        let mr = (mc - i_start).min(K::MR);
123
124                        // Microkernel
125                        let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
126                        let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
127                        let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
128
129                        kernel.execute(mr, nr, kc, a_ptr, b_ptr, c_ptr, ldc);
130                    }
131                }
132            }
133        }
134    }
135}
136
137/// Tropical GEMM with argmax tracking.
138///
139/// Same as `tropical_gemm_portable` but also computes argmax indices.
140///
141/// # Safety
142/// Same requirements as `tropical_gemm_portable`
143pub unsafe fn tropical_gemm_with_argmax_portable<T: TropicalWithArgmax<Index = u32>>(
144    m: usize,
145    n: usize,
146    k: usize,
147    a: *const T::Scalar,
148    lda: usize,
149    trans_a: Transpose,
150    b: *const T::Scalar,
151    ldb: usize,
152    trans_b: Transpose,
153    result: &mut GemmWithArgmax<T>,
154) {
155    let params = TilingParams::PORTABLE;
156    let kernel = PortableMicrokernel;
157
158    tropical_gemm_with_argmax_inner::<T, PortableMicrokernel>(
159        m, n, k, a, lda, trans_a, b, ldb, trans_b, result, &params, &kernel,
160    );
161}
162
163/// Tropical GEMM with argmax tracking and custom kernel.
164///
165/// # Safety
166/// Same requirements as `tropical_gemm_portable`
167pub unsafe fn tropical_gemm_with_argmax_inner<
168    T: TropicalWithArgmax<Index = u32>,
169    K: MicrokernelWithArgmax<T>,
170>(
171    m: usize,
172    n: usize,
173    k: usize,
174    a: *const T::Scalar,
175    lda: usize,
176    trans_a: Transpose,
177    b: *const T::Scalar,
178    ldb: usize,
179    trans_b: Transpose,
180    result: &mut GemmWithArgmax<T>,
181    params: &TilingParams,
182    kernel: &K,
183) {
184    if m == 0 || n == 0 || k == 0 {
185        return;
186    }
187
188    let ldc = result.ld;
189    let (c, argmax) = result.as_mut_ptrs();
190
191    // TODO(#34): Avoid repeated allocation by accepting caller-provided workspace.
192    let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
193    let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
194
195    // BLIS-style 5-loop blocking
196    for (jc, nc) in BlockIterator::new(n, params.nc) {
197        for (pc, kc) in BlockIterator::new(k, params.kc) {
198            pack_b::<T::Scalar>(
199                kc,
200                nc,
201                b_panel_ptr(b, pc, jc, ldb, trans_b),
202                ldb,
203                Layout::RowMajor,
204                trans_b,
205                packed_b.as_mut_ptr(),
206                K::NR,
207            );
208
209            for (ic, mc) in BlockIterator::new(m, params.mc) {
210                pack_a::<T::Scalar>(
211                    mc,
212                    kc,
213                    a_panel_ptr(a, ic, pc, lda, trans_a),
214                    lda,
215                    Layout::RowMajor,
216                    trans_a,
217                    packed_a.as_mut_ptr(),
218                    K::MR,
219                );
220
221                let n_blocks = nc.div_ceil(K::NR);
222                for jr in 0..n_blocks {
223                    let j_start = jr * K::NR;
224                    let nr = (nc - j_start).min(K::NR);
225
226                    let m_blocks = mc.div_ceil(K::MR);
227                    for ir in 0..m_blocks {
228                        let i_start = ir * K::MR;
229                        let mr = (mc - i_start).min(K::MR);
230
231                        let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
232                        let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
233                        let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
234                        let argmax_ptr = argmax.add((ic + i_start) * ldc + (jc + j_start));
235
236                        kernel.execute_with_argmax(
237                            mr, nr, kc, pc, a_ptr, b_ptr, c_ptr, argmax_ptr, ldc,
238                        );
239                    }
240                }
241            }
242        }
243    }
244
245    // Canonicalize the argmax index of tropical-zero "no contribution" cells.
246    // Integer in-band sentinels drift under the guard-free `+` and let the
247    // accumulator adopt a spurious k; reset those cells to the deterministic
248    // seed (0) so the whole repo agrees on one value. Done as a single O(m*n)
249    // sweep here (kept out of the hot per-block write-back to preserve its
250    // vectorization), and it folds away entirely for float types, whose
251    // `is_no_contribution` is a const `false`.
252    for i in 0..m {
253        for j in 0..n {
254            if result.get(i, j).is_no_contribution() {
255                *result.get_argmax_mut(i, j) = 0;
256            }
257        }
258    }
259}
260
261/// Get pointer to A panel considering transpose.
262#[inline]
263unsafe fn a_panel_ptr<T>(
264    a: *const T,
265    row: usize,
266    col: usize,
267    lda: usize,
268    trans: Transpose,
269) -> *const T {
270    match trans {
271        Transpose::NoTrans => a.add(row * lda + col),
272        Transpose::Trans => a.add(col * lda + row),
273    }
274}
275
276/// Get pointer to B panel considering transpose.
277#[inline]
278unsafe fn b_panel_ptr<T>(
279    b: *const T,
280    row: usize,
281    col: usize,
282    ldb: usize,
283    trans: Transpose,
284) -> *const T {
285    match trans {
286        Transpose::NoTrans => b.add(row * ldb + col),
287        Transpose::Trans => b.add(col * ldb + row),
288    }
289}
290
291use crate::types::TropicalScalar;
292
#[cfg(test)]
mod tests {
    //! Unit tests for the portable tropical GEMM drivers (plain and argmax).
    //! They cover transposes, empty dimensions, argmax tie-breaking, the
    //! integer "no contribution" argmax canonicalization, and a k-block
    //! boundary.

    use super::*;
    use crate::types::TropicalMaxPlus;

    #[test]
    fn test_simple_gemm() {
        let m = 2;
        let n = 2;
        let k = 3;

        // A: 2x3 matrix
        let a: [f64; 6] = [
            1.0, 2.0, 3.0, // row 0
            4.0, 5.0, 6.0, // row 1
        ];

        // B: 3x2 matrix
        let b: [f64; 6] = [
            1.0, 2.0, // row 0
            3.0, 4.0, // row 1
            5.0, 6.0, // row 2
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        // C[0,0] = max(1+1, 2+3, 3+5) = max(2, 5, 8) = 8
        assert_eq!(c[0].0, 8.0);
        // C[0,1] = max(1+2, 2+4, 3+6) = max(3, 6, 9) = 9
        assert_eq!(c[1].0, 9.0);
        // C[1,0] = max(4+1, 5+3, 6+5) = max(5, 8, 11) = 11
        assert_eq!(c[2].0, 11.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = max(6, 9, 12) = 12
        assert_eq!(c[3].0, 12.0);
    }

    #[test]
    fn test_gemm_with_argmax() {
        let m = 2;
        let n = 2;
        let k = 3;

        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // C[0,0] = max(1+1, 2+3, 3+5) = 8 at k=2
        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);

        // C[1,1] = max(4+2, 5+4, 6+6) = 12 at k=2
        assert_eq!(result.get(1, 1).0, 12.0);
        assert_eq!(result.get_argmax(1, 1), 2);
    }

    #[test]
    fn test_gemm_with_argmax_all_positions() {
        // Test that argmax correctly tracks the optimal k for all positions
        let m = 2;
        let n = 2;
        let k = 3;

        // Design A and B so each C[i,j] has a different optimal k
        // A: 2x3, B: 3x2
        // C[i,j] = max_k(A[i,k] + B[k,j])
        let a: [f64; 6] = [
            10.0, 1.0, 1.0, // row 0: k=0 dominates for C[0,*]
            1.0, 1.0, 10.0, // row 1: k=2 dominates for C[1,*]
        ];
        let b: [f64; 6] = [
            10.0, 1.0, // row 0: col 0 prefers k=0
            1.0, 10.0, // row 1: col 1 prefers k=1
            1.0, 1.0, // row 2
        ];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // C[0,0] = max(10+10, 1+1, 1+1) = 20 at k=0
        assert_eq!(result.get(0, 0).0, 20.0);
        assert_eq!(result.get_argmax(0, 0), 0);

        // C[0,1] = max(10+1, 1+10, 1+1) = 11 at k=0 or k=1 (both give 11)
        assert_eq!(result.get(0, 1).0, 11.0);
        // Tie between k=0 and k=1: the earliest k is kept.
        assert_eq!(result.get_argmax(0, 1), 0);

        // C[1,0] = max(1+10, 1+1, 10+1) = 11 at k=0 or k=2
        assert_eq!(result.get(1, 0).0, 11.0);
        assert_eq!(result.get_argmax(1, 0), 0); // tie: earliest k (0) is kept

        // C[1,1] = max(1+1, 1+10, 10+1) = 11 at k=1 or k=2
        assert_eq!(result.get(1, 1).0, 11.0);
        assert_eq!(result.get_argmax(1, 1), 1); // tie: earliest k (1) is kept
    }

    #[test]
    fn test_gemm_minplus_with_argmax() {
        use crate::types::TropicalMinPlus;

        let m = 2;
        let n = 2;
        let k = 3;

        // For MinPlus, argmax tracks argmin
        let a: [f64; 6] = [
            1.0, 5.0, 3.0, // row 0
            2.0, 4.0, 6.0, // row 1
        ];
        let b: [f64; 6] = [
            1.0, 2.0, // row 0
            3.0, 4.0, // row 1
            5.0, 6.0, // row 2
        ];

        let mut result: GemmWithArgmax<TropicalMinPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMinPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // C[0,0] = min(1+1, 5+3, 3+5) = min(2, 8, 8) = 2 at k=0
        assert_eq!(result.get(0, 0).0, 2.0);
        assert_eq!(result.get_argmax(0, 0), 0);

        // C[0,1] = min(1+2, 5+4, 3+6) = min(3, 9, 9) = 3 at k=0
        assert_eq!(result.get(0, 1).0, 3.0);
        assert_eq!(result.get_argmax(0, 1), 0);

        // C[1,0] = min(2+1, 4+3, 6+5) = min(3, 7, 11) = 3 at k=0
        assert_eq!(result.get(1, 0).0, 3.0);
        assert_eq!(result.get_argmax(1, 0), 0);

        // C[1,1] = min(2+2, 4+4, 6+6) = min(4, 8, 12) = 4 at k=0
        assert_eq!(result.get(1, 1).0, 4.0);
        assert_eq!(result.get_argmax(1, 1), 0);
    }

    #[test]
    fn test_gemm_larger_with_argmax() {
        // Larger matrix to exercise the micro-tile loops. A cache-block (k)
        // boundary is covered by `test_gemm_with_argmax_crosses_k_block`.
        let m = 8;
        let n = 8;
        let k = 8;

        let a: Vec<f64> = (0..m * k).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..k * n).map(|i| (k * n - 1 - i) as f64).collect();

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                k,
                Transpose::NoTrans,
                b.as_ptr(),
                n,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // Verify all results are finite and argmax indices are valid
        for i in 0..m {
            for j in 0..n {
                assert!(result.get(i, j).0.is_finite());
                assert!(result.get_argmax(i, j) < k as u32);
            }
        }
    }

    #[test]
    fn test_gemm_with_argmax_crosses_k_block() {
        // Make k exceed one k-block, so the kernel runs on the same C tile more
        // than once. It must then accumulate across blocks and report argmax as
        // a global k index (offset by the block start `pc`).
        let kc = TilingParams::PORTABLE.kc;
        let m = 3;
        let n = 5;
        let k = kc + m + 2;

        // B holds small integer values in [0, 10). A is zero except for one
        // dominant entry per row, placed in the second k-block. Every C[i, j]
        // therefore has the unique maximizer k* = kc + i (no ties).
        let b: Vec<f64> = (0..k * n).map(|idx| (idx % 10) as f64).collect();
        let mut a = vec![0.0f64; m * k];
        for i in 0..m {
            a[i * k + kc + i] = 1000.0;
        }

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);
        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                k,
                Transpose::NoTrans,
                b.as_ptr(),
                n,
                Transpose::NoTrans,
                &mut result,
            );
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                k,
                Transpose::NoTrans,
                b.as_ptr(),
                n,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        for i in 0..m {
            for j in 0..n {
                let expected = 1000.0 + b[(kc + i) * n + j];
                assert_eq!(result.get(i, j).0, expected);
                assert_eq!(result.get_argmax(i, j) as usize, kc + i);
                assert_eq!(c[i * n + j].0, expected);
            }
        }
    }

    #[test]
    fn test_gemm_trans_a() {
        // Test with A transposed.
        // The buffer holds the 3x2 row-major matrix [[1, 4], [2, 5], [3, 6]]
        // (lda = 2), i.e. op(A) = A^T = [[1, 2, 3], [4, 5, 6]] stored
        // column-major.
        let m = 2;
        let n = 2;
        let k = 3;

        let a: [f64; 6] = [
            1.0, 4.0, // column 0 of op(A)
            2.0, 5.0, // column 1 of op(A)
            3.0, 6.0, // column 2 of op(A)
        ];

        let b: [f64; 6] = [
            1.0, 2.0, // row 0
            3.0, 4.0, // row 1
            5.0, 6.0, // row 2
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                2,
                Transpose::Trans, // lda=2: stored buffer is 3x2 row-major
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        // op(A) = [[1, 2, 3], [4, 5, 6]]
        // B = [[1, 2], [3, 4], [5, 6]]
        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[0].0, 8.0);
        // C[0,1] = max(1+2, 2+4, 3+6) = 9
        assert_eq!(c[1].0, 9.0);
        // C[1,0] = max(4+1, 5+3, 6+5) = 11
        assert_eq!(c[2].0, 11.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert_eq!(c[3].0, 12.0);
    }

    #[test]
    fn test_gemm_trans_b() {
        // Test with B transposed.
        // The buffer holds the 2x3 row-major matrix [[1, 3, 5], [2, 4, 6]]
        // (ldb = 3), i.e. op(B) = [[1, 2], [3, 4], [5, 6]] stored column-major.
        let m = 2;
        let n = 2;
        let k = 3;

        let a: [f64; 6] = [
            1.0, 2.0, 3.0, // row 0
            4.0, 5.0, 6.0, // row 1
        ];

        let b: [f64; 6] = [
            1.0, 3.0, 5.0, // column 0 of op(B)
            2.0, 4.0, 6.0, // column 1 of op(B)
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                3,
                Transpose::Trans, // ldb=3: stored buffer is 2x3 row-major
                c.as_mut_ptr(),
                n,
            );
        }

        // A = [[1, 2, 3], [4, 5, 6]]
        // op(B) = [[1, 2], [3, 4], [5, 6]]
        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[0].0, 8.0);
        assert_eq!(c[1].0, 9.0);
        assert_eq!(c[2].0, 11.0);
        assert_eq!(c[3].0, 12.0);
    }

    #[test]
    fn test_gemm_trans_both() {
        // Test with both A and B transposed
        let m = 2;
        let n = 2;
        let k = 3;

        // Stored 3x2 row-major; op(A) = [[1, 2, 3], [4, 5, 6]]
        let a: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        // Stored 2x3 row-major; op(B) = [[1, 2], [3, 4], [5, 6]]
        let b: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                2,
                Transpose::Trans,
                b.as_ptr(),
                3,
                Transpose::Trans,
                c.as_mut_ptr(),
                n,
            );
        }

        assert_eq!(c[0].0, 8.0);
        assert_eq!(c[1].0, 9.0);
        assert_eq!(c[2].0, 11.0);
        assert_eq!(c[3].0, 12.0);
    }

    #[test]
    fn test_gemm_empty_m() {
        let m = 0;
        let n = 2;
        let k = 3;

        let a: [f64; 0] = [];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut c: Vec<TropicalMaxPlus<f64>> = vec![];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        // Should complete without panic
        assert!(c.is_empty());
    }

    #[test]
    fn test_gemm_empty_n() {
        let m = 2;
        let n = 0;
        let k = 3;

        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 0] = [];
        let mut c: Vec<TropicalMaxPlus<f64>> = vec![];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        assert!(c.is_empty());
    }

    #[test]
    fn test_gemm_empty_k() {
        let m = 2;
        let n = 2;
        let k = 0;

        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                0,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        // C should remain initialized to tropical_zero
        for val in &c {
            assert!(val.0.is_infinite() && val.0 < 0.0);
        }
    }

    #[test]
    fn test_gemm_with_argmax_empty_k() {
        let m = 2;
        let n = 2;
        let k = 0;

        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                0,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // Should complete without panic and leave every cell untouched:
        // still the tropical zero with the seed argmax 0.
        assert_eq!(result.m, 2);
        assert_eq!(result.n, 2);
        for i in 0..m {
            for j in 0..n {
                assert_eq!(result.get(i, j).0, f64::NEG_INFINITY);
                assert_eq!(result.get_argmax(i, j), 0);
            }
        }
    }

    #[test]
    fn test_gemm_with_argmax_trans_a() {
        let m = 2;
        let n = 2;
        let k = 3;

        let a: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                2,
                Transpose::Trans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_b() {
        let m = 2;
        let n = 2;
        let k = 3;

        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                3,
                Transpose::Trans,
                &mut result,
            );
        }

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_gemm_with_argmax_int_zero_cell_canonicalized() {
        use crate::types::TropicalScalar;

        // Row 0 of A is the tropical zero (`-∞` sentinel), so every product for
        // C[0, *] is a (drifted) tropical zero with no real contribution. Its
        // argmax must canonicalize to the seed `0`, not drift to a
        // data-dependent k.
        let m = 2;
        let n = 2;
        let k = 3;
        let neg = <i32 as TropicalScalar>::neg_infinity();
        let a: [i32; 6] = [
            neg, neg, neg, // row 0: all tropical zero
            1, 2, 3, // row 1: finite
        ];
        let b: [i32; 6] = [
            4, 5, // row 0
            6, 7, // row 1
            8, 9, // row 2
        ];

        let mut result: GemmWithArgmax<TropicalMaxPlus<i32>> = GemmWithArgmax::new(m, n);
        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<i32>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // Row 0: no contribution → value stays in `-∞` territory, argmax = 0.
        for j in 0..n {
            assert!(
                result.get(0, j).0.is_drifted_neg_zero(),
                "C[0,{j}] should be in tropical-zero territory"
            );
            assert_eq!(
                result.get_argmax(0, j),
                0,
                "zero-cell argmax must canonicalize to 0, not drift"
            );
        }
        // Row 1: real contributions → finite value, true argmax_k.
        // C[1,0] = max(1+4, 2+6, 3+8) = 11 at k=2; C[1,1] = max(1+5,2+7,3+9) = 12 at k=2.
        assert_eq!(result.get(1, 0).0, 11);
        assert_eq!(result.get_argmax(1, 0), 2);
        assert_eq!(result.get(1, 1).0, 12);
        assert_eq!(result.get_argmax(1, 1), 2);
    }

    #[test]
    fn test_gemm_with_argmax_float_zero_cell_keeps_seed() {
        // Float `-∞` is exact (never drifts), so a no-contribution cell already
        // keeps the seed index 0. The canonicalization hook must not change this.
        let m = 2;
        let n = 2;
        let k = 3;
        let a: [f64; 6] = [
            f64::NEG_INFINITY,
            f64::NEG_INFINITY,
            f64::NEG_INFINITY,
            1.0,
            2.0,
            3.0,
        ];
        let b: [f64; 6] = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);
        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        for j in 0..n {
            assert_eq!(result.get(0, j).0, f64::NEG_INFINITY);
            assert_eq!(result.get_argmax(0, j), 0);
        }
    }
}