tropical_gemm/core/
gemm.rs

1use super::argmax::GemmWithArgmax;
2use super::kernel::{Microkernel, MicrokernelWithArgmax, PortableMicrokernel};
3use super::packing::{pack_a, pack_b, packed_a_size, packed_b_size, Layout, Transpose};
4use super::tiling::{BlockIterator, TilingParams};
5use crate::types::{TropicalSemiring, TropicalWithArgmax};
6
7/// Tropical GEMM: C = A ⊗ B
8///
9/// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
10///
11/// This is a portable (non-SIMD) implementation using BLIS-style blocking
12/// for cache efficiency.
13///
14/// # Parameters
15/// - `m`: Number of rows in A and C
16/// - `n`: Number of columns in B and C
17/// - `k`: Number of columns in A / rows in B
18/// - `a`: Pointer to matrix A data
19/// - `lda`: Leading dimension of A
20/// - `trans_a`: Whether A is transposed
21/// - `b`: Pointer to matrix B data
22/// - `ldb`: Leading dimension of B
23/// - `trans_b`: Whether B is transposed
24/// - `c`: Pointer to matrix C data (output)
25/// - `ldc`: Leading dimension of C
26///
27/// # Safety
28/// - All pointers must be valid for the specified dimensions
29/// - Memory regions must not overlap inappropriately
30pub unsafe fn tropical_gemm_portable<T: TropicalSemiring>(
31    m: usize,
32    n: usize,
33    k: usize,
34    a: *const T::Scalar,
35    lda: usize,
36    trans_a: Transpose,
37    b: *const T::Scalar,
38    ldb: usize,
39    trans_b: Transpose,
40    c: *mut T,
41    ldc: usize,
42) {
43    let params = TilingParams::PORTABLE;
44    let kernel = PortableMicrokernel;
45
46    tropical_gemm_inner::<T, PortableMicrokernel>(
47        m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
48    );
49}
50
51/// Tropical GEMM with custom kernel and tiling parameters.
52///
53/// # Safety
54/// Same requirements as `tropical_gemm_portable`
55pub unsafe fn tropical_gemm_inner<T: TropicalSemiring, K: Microkernel<T>>(
56    m: usize,
57    n: usize,
58    k: usize,
59    a: *const T::Scalar,
60    lda: usize,
61    trans_a: Transpose,
62    b: *const T::Scalar,
63    ldb: usize,
64    trans_b: Transpose,
65    c: *mut T,
66    ldc: usize,
67    params: &TilingParams,
68    kernel: &K,
69) {
70    if m == 0 || n == 0 || k == 0 {
71        return;
72    }
73
74    // TODO(#34): Avoid repeated allocation by accepting caller-provided workspace.
75    // For repeated GEMM calls, consider adding a workspace-based API:
76    //   pub struct GemmWorkspace<T> { packed_a: Vec<T>, packed_b: Vec<T> }
77    //   pub fn tropical_gemm_with_workspace(..., workspace: &mut GemmWorkspace<T>)
78    let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
79    let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
80
81    // BLIS-style 5-loop blocking
82    // Loop 5: blocks of n
83    for (jc, nc) in BlockIterator::new(n, params.nc) {
84        // Loop 4: blocks of k
85        for (pc, kc) in BlockIterator::new(k, params.kc) {
86            // Pack B panel: kc × nc
87            pack_b::<T::Scalar>(
88                kc,
89                nc,
90                b_panel_ptr(b, pc, jc, ldb, trans_b),
91                ldb,
92                Layout::RowMajor,
93                trans_b,
94                packed_b.as_mut_ptr(),
95                K::NR,
96            );
97
98            // Loop 3: blocks of m
99            for (ic, mc) in BlockIterator::new(m, params.mc) {
100                // Pack A panel: mc × kc
101                pack_a::<T::Scalar>(
102                    mc,
103                    kc,
104                    a_panel_ptr(a, ic, pc, lda, trans_a),
105                    lda,
106                    Layout::RowMajor,
107                    trans_a,
108                    packed_a.as_mut_ptr(),
109                    K::MR,
110                );
111
112                // Loop 2: micro-blocks of n
113                let n_blocks = nc.div_ceil(K::NR);
114                for jr in 0..n_blocks {
115                    let j_start = jr * K::NR;
116                    let nr = (nc - j_start).min(K::NR);
117
118                    // Loop 1: micro-blocks of m
119                    let m_blocks = mc.div_ceil(K::MR);
120                    for ir in 0..m_blocks {
121                        let i_start = ir * K::MR;
122                        let mr = (mc - i_start).min(K::MR);
123
124                        // Microkernel
125                        let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
126                        let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
127                        let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
128
129                        kernel.execute(mr, nr, kc, a_ptr, b_ptr, c_ptr, ldc);
130                    }
131                }
132            }
133        }
134    }
135}
136
137/// Tropical GEMM with argmax tracking.
138///
139/// Same as `tropical_gemm_portable` but also computes argmax indices.
140///
141/// # Safety
142/// Same requirements as `tropical_gemm_portable`
143pub unsafe fn tropical_gemm_with_argmax_portable<T: TropicalWithArgmax<Index = u32>>(
144    m: usize,
145    n: usize,
146    k: usize,
147    a: *const T::Scalar,
148    lda: usize,
149    trans_a: Transpose,
150    b: *const T::Scalar,
151    ldb: usize,
152    trans_b: Transpose,
153    result: &mut GemmWithArgmax<T>,
154) {
155    let params = TilingParams::PORTABLE;
156    let kernel = PortableMicrokernel;
157
158    tropical_gemm_with_argmax_inner::<T, PortableMicrokernel>(
159        m, n, k, a, lda, trans_a, b, ldb, trans_b, result, &params, &kernel,
160    );
161}
162
163/// Tropical GEMM with argmax tracking and custom kernel.
164///
165/// # Safety
166/// Same requirements as `tropical_gemm_portable`
167pub unsafe fn tropical_gemm_with_argmax_inner<
168    T: TropicalWithArgmax<Index = u32>,
169    K: MicrokernelWithArgmax<T>,
170>(
171    m: usize,
172    n: usize,
173    k: usize,
174    a: *const T::Scalar,
175    lda: usize,
176    trans_a: Transpose,
177    b: *const T::Scalar,
178    ldb: usize,
179    trans_b: Transpose,
180    result: &mut GemmWithArgmax<T>,
181    params: &TilingParams,
182    kernel: &K,
183) {
184    if m == 0 || n == 0 || k == 0 {
185        return;
186    }
187
188    let ldc = result.ld;
189    let (c, argmax) = result.as_mut_ptrs();
190
191    // TODO(#34): Avoid repeated allocation by accepting caller-provided workspace.
192    let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
193    let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
194
195    // BLIS-style 5-loop blocking
196    for (jc, nc) in BlockIterator::new(n, params.nc) {
197        for (pc, kc) in BlockIterator::new(k, params.kc) {
198            pack_b::<T::Scalar>(
199                kc,
200                nc,
201                b_panel_ptr(b, pc, jc, ldb, trans_b),
202                ldb,
203                Layout::RowMajor,
204                trans_b,
205                packed_b.as_mut_ptr(),
206                K::NR,
207            );
208
209            for (ic, mc) in BlockIterator::new(m, params.mc) {
210                pack_a::<T::Scalar>(
211                    mc,
212                    kc,
213                    a_panel_ptr(a, ic, pc, lda, trans_a),
214                    lda,
215                    Layout::RowMajor,
216                    trans_a,
217                    packed_a.as_mut_ptr(),
218                    K::MR,
219                );
220
221                let n_blocks = nc.div_ceil(K::NR);
222                for jr in 0..n_blocks {
223                    let j_start = jr * K::NR;
224                    let nr = (nc - j_start).min(K::NR);
225
226                    let m_blocks = mc.div_ceil(K::MR);
227                    for ir in 0..m_blocks {
228                        let i_start = ir * K::MR;
229                        let mr = (mc - i_start).min(K::MR);
230
231                        let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
232                        let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
233                        let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
234                        let argmax_ptr = argmax.add((ic + i_start) * ldc + (jc + j_start));
235
236                        kernel.execute_with_argmax(
237                            mr, nr, kc, pc, a_ptr, b_ptr, c_ptr, argmax_ptr, ldc,
238                        );
239                    }
240                }
241            }
242        }
243    }
244}
245
246/// Get pointer to A panel considering transpose.
247#[inline]
248unsafe fn a_panel_ptr<T>(
249    a: *const T,
250    row: usize,
251    col: usize,
252    lda: usize,
253    trans: Transpose,
254) -> *const T {
255    match trans {
256        Transpose::NoTrans => a.add(row * lda + col),
257        Transpose::Trans => a.add(col * lda + row),
258    }
259}
260
261/// Get pointer to B panel considering transpose.
262#[inline]
263unsafe fn b_panel_ptr<T>(
264    b: *const T,
265    row: usize,
266    col: usize,
267    ldb: usize,
268    trans: Transpose,
269) -> *const T {
270    match trans {
271        Transpose::NoTrans => b.add(row * ldb + col),
272        Transpose::Trans => b.add(col * ldb + row),
273    }
274}
275
276use crate::types::TropicalScalar;
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    type MaxPlus = TropicalMaxPlus<f64>;

    /// Logical A = [[1, 2, 3], [4, 5, 6]], stored row by row (ld = 3).
    const A_2X3: [f64; 6] = [
        1.0, 2.0, 3.0, // row 0
        4.0, 5.0, 6.0, // row 1
    ];

    /// The same logical A, stored as its 3x2 transpose (ld = 2).
    /// Must be passed with `Transpose::Trans`.
    const A_2X3_STORED_TRANSPOSED: [f64; 6] = [
        1.0, 4.0, // column 0 of A
        2.0, 5.0, // column 1 of A
        3.0, 6.0, // column 2 of A
    ];

    /// Logical B = [[1, 2], [3, 4], [5, 6]], stored row by row (ld = 2).
    const B_3X2: [f64; 6] = [
        1.0, 2.0, // row 0
        3.0, 4.0, // row 1
        5.0, 6.0, // row 2
    ];

    /// The same logical B, stored as its 2x3 transpose (ld = 3).
    /// Must be passed with `Transpose::Trans`.
    const B_3X2_STORED_TRANSPOSED: [f64; 6] = [
        1.0, 3.0, 5.0, // column 0 of B
        2.0, 4.0, 6.0, // column 1 of B
    ];

    /// Max-plus product of `A_2X3` and `B_3X2`, row by row:
    /// C[0,0] = max(1+1, 2+3, 3+5) = 8
    /// C[0,1] = max(1+2, 2+4, 3+6) = 9
    /// C[1,0] = max(4+1, 5+3, 6+5) = 11
    /// C[1,1] = max(4+2, 5+4, 6+6) = 12
    const C_2X2: [f64; 4] = [8.0, 9.0, 11.0, 12.0];

    /// A borrowed matrix operand: backing storage, leading dimension and
    /// transpose flag, bundled so the helpers below stay readable.
    struct Operand<'a> {
        data: &'a [f64],
        ld: usize,
        trans: Transpose,
    }

    /// Operand used as stored.
    fn plain(data: &[f64], ld: usize) -> Operand<'_> {
        Operand {
            data,
            ld,
            trans: Transpose::NoTrans,
        }
    }

    /// Operand whose storage holds the transpose of the logical matrix.
    fn transposed(data: &[f64], ld: usize) -> Operand<'_> {
        Operand {
            data,
            ld,
            trans: Transpose::Trans,
        }
    }

    /// Run the portable max-plus GEMM into a fresh, densely packed m×n
    /// buffer that starts out filled with the tropical zero.
    fn maxplus_gemm(m: usize, n: usize, k: usize, a: Operand<'_>, b: Operand<'_>) -> Vec<MaxPlus> {
        let mut c = vec![MaxPlus::tropical_zero(); m * n];
        // SAFETY: every caller in this module passes slices that cover the
        // elements addressed by (m, n, k) and the leading dimensions, and
        // `c` holds exactly m * n elements with ldc = n.
        unsafe {
            tropical_gemm_portable::<MaxPlus>(
                m,
                n,
                k,
                a.data.as_ptr(),
                a.ld,
                a.trans,
                b.data.as_ptr(),
                b.ld,
                b.trans,
                c.as_mut_ptr(),
                n,
            );
        }
        c
    }

    /// Run the portable max-plus GEMM with argmax tracking into a fresh
    /// m×n result.
    fn maxplus_gemm_with_argmax(
        m: usize,
        n: usize,
        k: usize,
        a: Operand<'_>,
        b: Operand<'_>,
    ) -> GemmWithArgmax<MaxPlus> {
        let mut result: GemmWithArgmax<MaxPlus> = GemmWithArgmax::new(m, n);
        // SAFETY: as in `maxplus_gemm`; the result is sized for m×n.
        unsafe {
            tropical_gemm_with_argmax_portable::<MaxPlus>(
                m,
                n,
                k,
                a.data.as_ptr(),
                a.ld,
                a.trans,
                b.data.as_ptr(),
                b.ld,
                b.trans,
                &mut result,
            );
        }
        result
    }

    /// Strip the semiring wrapper so results can be compared as plain floats.
    fn values(c: &[MaxPlus]) -> Vec<f64> {
        c.iter().map(|entry| entry.0).collect()
    }

    #[test]
    fn test_simple_gemm() {
        let c = maxplus_gemm(2, 2, 3, plain(&A_2X3, 3), plain(&B_3X2, 2));
        assert_eq!(values(&c), C_2X2);
    }

    #[test]
    fn test_gemm_with_argmax() {
        let result = maxplus_gemm_with_argmax(2, 2, 3, plain(&A_2X3, 3), plain(&B_3X2, 2));

        // C[0,0] = max(1+1, 2+3, 3+5) = 8 at k=2
        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);

        // C[1,1] = max(4+2, 5+4, 6+6) = 12 at k=2
        assert_eq!(result.get(1, 1).0, 12.0);
        assert_eq!(result.get_argmax(1, 1), 2);
    }

    #[test]
    fn test_gemm_with_argmax_all_positions() {
        // A and B are chosen so that different entries of C are won by
        // different k, including ties.
        // C[i,j] = max_k(A[i,k] + B[k,j])
        let a: [f64; 6] = [
            10.0, 1.0, 1.0, // row 0: k=0 is large
            1.0, 1.0, 10.0, // row 1: k=2 is large
        ];
        let b: [f64; 6] = [
            10.0, 1.0, // row 0: favours column 0
            1.0, 10.0, // row 1: favours column 1
            1.0, 1.0, // row 2
        ];

        let result = maxplus_gemm_with_argmax(2, 2, 3, plain(&a, 3), plain(&b, 2));

        // (i, j, expected value, expected argmax)
        // On a tie the earliest k is expected to be reported.
        let expected: [(usize, usize, f64, u32); 4] = [
            // max(10+10, 1+1, 1+1) = 20, unique at k=0
            (0, 0, 20.0, 0),
            // max(10+1, 1+10, 1+1) = 11, tie between k=0 and k=1
            (0, 1, 11.0, 0),
            // max(1+10, 1+1, 10+1) = 11, tie between k=0 and k=2
            (1, 0, 11.0, 0),
            // max(1+1, 1+10, 10+1) = 11, tie between k=1 and k=2
            (1, 1, 11.0, 1),
        ];
        for (i, j, value, arg) in expected {
            assert_eq!(result.get(i, j).0, value, "value at ({i}, {j})");
            assert_eq!(result.get_argmax(i, j), arg, "argmax at ({i}, {j})");
        }
    }

    #[test]
    fn test_gemm_minplus_with_argmax() {
        use crate::types::TropicalMinPlus;

        let m = 2;
        let n = 2;
        let k = 3;

        // For MinPlus, the tracked index is the argmin.
        let a: [f64; 6] = [
            1.0, 5.0, 3.0, // row 0
            2.0, 4.0, 6.0, // row 1
        ];

        let mut result: GemmWithArgmax<TropicalMinPlus<f64>> = GemmWithArgmax::new(m, n);

        // SAFETY: `a` is 2x3 with lda = 3, `B_3X2` is 3x2 with ldb = 2, and
        // `result` is sized for 2x2.
        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMinPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                B_3X2.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // (i, j, expected value, expected argmin)
        let expected: [(usize, usize, f64, u32); 4] = [
            // min(1+1, 5+3, 3+5) = 2 at k=0
            (0, 0, 2.0, 0),
            // min(1+2, 5+4, 3+6) = 3 at k=0
            (0, 1, 3.0, 0),
            // min(2+1, 4+3, 6+5) = 3 at k=0
            (1, 0, 3.0, 0),
            // min(2+2, 4+4, 6+6) = 4 at k=0
            (1, 1, 4.0, 0),
        ];
        for (i, j, value, arg) in expected {
            assert_eq!(result.get(i, j).0, value, "value at ({i}, {j})");
            assert_eq!(result.get_argmax(i, j), arg, "argmin at ({i}, {j})");
        }
    }

    #[test]
    fn test_gemm_larger_with_argmax() {
        // Larger than the 2x2 cases above; checked against a naive reference.
        let m = 8;
        let n = 8;
        let k = 8;

        let a: Vec<f64> = (0..m * k).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..k * n).map(|i| (k * n - 1 - i) as f64).collect();

        let result = maxplus_gemm_with_argmax(m, n, k, plain(&a, k), plain(&b, n));

        // All inputs are small integers, so the f64 sums are exact and can
        // be compared with `==`.
        for i in 0..m {
            for j in 0..n {
                let best = (0..k)
                    .map(|p| a[i * k + p] + b[p * n + j])
                    .fold(f64::NEG_INFINITY, f64::max);

                assert!(result.get(i, j).0.is_finite());
                assert_eq!(result.get(i, j).0, best, "value at ({i}, {j})");

                // The reported index must be in range and must attain the
                // maximum; this holds regardless of how ties are broken.
                assert!(result.get_argmax(i, j) < k as u32);
                let arg = result.get_argmax(i, j) as usize;
                assert_eq!(a[i * k + arg] + b[arg * n + j], best, "argmax at ({i}, {j})");
            }
        }
    }

    #[test]
    fn test_gemm_trans_a() {
        // A is supplied as its 3x2 transpose (lda = 2).
        let c = maxplus_gemm(
            2,
            2,
            3,
            transposed(&A_2X3_STORED_TRANSPOSED, 2),
            plain(&B_3X2, 2),
        );
        assert_eq!(values(&c), C_2X2);
    }

    #[test]
    fn test_gemm_trans_b() {
        // B is supplied as its 2x3 transpose (ldb = 3).
        let c = maxplus_gemm(
            2,
            2,
            3,
            plain(&A_2X3, 3),
            transposed(&B_3X2_STORED_TRANSPOSED, 3),
        );
        assert_eq!(values(&c), C_2X2);
    }

    #[test]
    fn test_gemm_trans_both() {
        // Both operands are supplied as their transposes.
        let c = maxplus_gemm(
            2,
            2,
            3,
            transposed(&A_2X3_STORED_TRANSPOSED, 2),
            transposed(&B_3X2_STORED_TRANSPOSED, 3),
        );
        assert_eq!(values(&c), C_2X2);
    }

    #[test]
    fn test_gemm_empty_m() {
        // m = 0: A and C are empty; the call must return without touching them.
        let c = maxplus_gemm(0, 2, 3, plain(&[], 3), plain(&B_3X2, 2));
        assert!(c.is_empty());
    }

    #[test]
    fn test_gemm_empty_n() {
        // n = 0: B and C are empty.
        let c = maxplus_gemm(2, 0, 3, plain(&A_2X3, 3), plain(&[], 2));
        assert!(c.is_empty());
    }

    #[test]
    fn test_gemm_empty_k() {
        // k = 0: both inputs are empty and C must be left as initialised.
        let c = maxplus_gemm(2, 2, 0, plain(&[], 0), plain(&[], 2));

        // C should remain initialized to tropical_zero (-inf for max-plus)
        for val in &c {
            assert!(val.0.is_infinite() && val.0 < 0.0);
        }
    }

    #[test]
    fn test_gemm_with_argmax_empty_k() {
        let result = maxplus_gemm_with_argmax(2, 2, 0, plain(&[], 0), plain(&[], 2));

        // Should complete without panic and leave the shape untouched
        assert_eq!(result.m, 2);
        assert_eq!(result.n, 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_a() {
        let result = maxplus_gemm_with_argmax(
            2,
            2,
            3,
            transposed(&A_2X3_STORED_TRANSPOSED, 2),
            plain(&B_3X2, 2),
        );

        // C[0,0] = max(1+1, 2+3, 3+5) = 8 at k=2
        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_b() {
        let result = maxplus_gemm_with_argmax(
            2,
            2,
            3,
            plain(&A_2X3, 3),
            transposed(&B_3X2_STORED_TRANSPOSED, 3),
        );

        // C[0,0] = max(1+1, 2+3, 3+5) = 8 at k=2
        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }
}