1use super::argmax::GemmWithArgmax;
2use super::kernel::{Microkernel, MicrokernelWithArgmax, PortableMicrokernel};
3use super::packing::{pack_a, pack_b, packed_a_size, packed_b_size, Layout, Transpose};
4use super::tiling::{BlockIterator, TilingParams};
5use crate::types::{TropicalSemiring, TropicalWithArgmax};
6
7pub unsafe fn tropical_gemm_portable<T: TropicalSemiring>(
31 m: usize,
32 n: usize,
33 k: usize,
34 a: *const T::Scalar,
35 lda: usize,
36 trans_a: Transpose,
37 b: *const T::Scalar,
38 ldb: usize,
39 trans_b: Transpose,
40 c: *mut T,
41 ldc: usize,
42) {
43 let params = TilingParams::PORTABLE;
44 let kernel = PortableMicrokernel;
45
46 tropical_gemm_inner::<T, PortableMicrokernel>(
47 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
48 );
49}
50
51pub unsafe fn tropical_gemm_inner<T: TropicalSemiring, K: Microkernel<T>>(
56 m: usize,
57 n: usize,
58 k: usize,
59 a: *const T::Scalar,
60 lda: usize,
61 trans_a: Transpose,
62 b: *const T::Scalar,
63 ldb: usize,
64 trans_b: Transpose,
65 c: *mut T,
66 ldc: usize,
67 params: &TilingParams,
68 kernel: &K,
69) {
70 if m == 0 || n == 0 || k == 0 {
71 return;
72 }
73
74 let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
79 let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
80
81 for (jc, nc) in BlockIterator::new(n, params.nc) {
84 for (pc, kc) in BlockIterator::new(k, params.kc) {
86 pack_b::<T::Scalar>(
88 kc,
89 nc,
90 b_panel_ptr(b, pc, jc, ldb, trans_b),
91 ldb,
92 Layout::RowMajor,
93 trans_b,
94 packed_b.as_mut_ptr(),
95 K::NR,
96 );
97
98 for (ic, mc) in BlockIterator::new(m, params.mc) {
100 pack_a::<T::Scalar>(
102 mc,
103 kc,
104 a_panel_ptr(a, ic, pc, lda, trans_a),
105 lda,
106 Layout::RowMajor,
107 trans_a,
108 packed_a.as_mut_ptr(),
109 K::MR,
110 );
111
112 let n_blocks = nc.div_ceil(K::NR);
114 for jr in 0..n_blocks {
115 let j_start = jr * K::NR;
116 let nr = (nc - j_start).min(K::NR);
117
118 let m_blocks = mc.div_ceil(K::MR);
120 for ir in 0..m_blocks {
121 let i_start = ir * K::MR;
122 let mr = (mc - i_start).min(K::MR);
123
124 let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
126 let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
127 let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
128
129 kernel.execute(mr, nr, kc, a_ptr, b_ptr, c_ptr, ldc);
130 }
131 }
132 }
133 }
134 }
135}
136
137pub unsafe fn tropical_gemm_with_argmax_portable<T: TropicalWithArgmax<Index = u32>>(
144 m: usize,
145 n: usize,
146 k: usize,
147 a: *const T::Scalar,
148 lda: usize,
149 trans_a: Transpose,
150 b: *const T::Scalar,
151 ldb: usize,
152 trans_b: Transpose,
153 result: &mut GemmWithArgmax<T>,
154) {
155 let params = TilingParams::PORTABLE;
156 let kernel = PortableMicrokernel;
157
158 tropical_gemm_with_argmax_inner::<T, PortableMicrokernel>(
159 m, n, k, a, lda, trans_a, b, ldb, trans_b, result, ¶ms, &kernel,
160 );
161}
162
163pub unsafe fn tropical_gemm_with_argmax_inner<
168 T: TropicalWithArgmax<Index = u32>,
169 K: MicrokernelWithArgmax<T>,
170>(
171 m: usize,
172 n: usize,
173 k: usize,
174 a: *const T::Scalar,
175 lda: usize,
176 trans_a: Transpose,
177 b: *const T::Scalar,
178 ldb: usize,
179 trans_b: Transpose,
180 result: &mut GemmWithArgmax<T>,
181 params: &TilingParams,
182 kernel: &K,
183) {
184 if m == 0 || n == 0 || k == 0 {
185 return;
186 }
187
188 let ldc = result.ld;
189 let (c, argmax) = result.as_mut_ptrs();
190
191 let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
193 let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
194
195 for (jc, nc) in BlockIterator::new(n, params.nc) {
197 for (pc, kc) in BlockIterator::new(k, params.kc) {
198 pack_b::<T::Scalar>(
199 kc,
200 nc,
201 b_panel_ptr(b, pc, jc, ldb, trans_b),
202 ldb,
203 Layout::RowMajor,
204 trans_b,
205 packed_b.as_mut_ptr(),
206 K::NR,
207 );
208
209 for (ic, mc) in BlockIterator::new(m, params.mc) {
210 pack_a::<T::Scalar>(
211 mc,
212 kc,
213 a_panel_ptr(a, ic, pc, lda, trans_a),
214 lda,
215 Layout::RowMajor,
216 trans_a,
217 packed_a.as_mut_ptr(),
218 K::MR,
219 );
220
221 let n_blocks = nc.div_ceil(K::NR);
222 for jr in 0..n_blocks {
223 let j_start = jr * K::NR;
224 let nr = (nc - j_start).min(K::NR);
225
226 let m_blocks = mc.div_ceil(K::MR);
227 for ir in 0..m_blocks {
228 let i_start = ir * K::MR;
229 let mr = (mc - i_start).min(K::MR);
230
231 let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
232 let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
233 let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
234 let argmax_ptr = argmax.add((ic + i_start) * ldc + (jc + j_start));
235
236 kernel.execute_with_argmax(
237 mr, nr, kc, pc, a_ptr, b_ptr, c_ptr, argmax_ptr, ldc,
238 );
239 }
240 }
241 }
242 }
243 }
244
245 for i in 0..m {
253 for j in 0..n {
254 if result.get(i, j).is_no_contribution() {
255 *result.get_argmax_mut(i, j) = 0;
256 }
257 }
258 }
259}
260
261#[inline]
263unsafe fn a_panel_ptr<T>(
264 a: *const T,
265 row: usize,
266 col: usize,
267 lda: usize,
268 trans: Transpose,
269) -> *const T {
270 match trans {
271 Transpose::NoTrans => a.add(row * lda + col),
272 Transpose::Trans => a.add(col * lda + row),
273 }
274}
275
276#[inline]
278unsafe fn b_panel_ptr<T>(
279 b: *const T,
280 row: usize,
281 col: usize,
282 ldb: usize,
283 trans: Transpose,
284) -> *const T {
285 match trans {
286 Transpose::NoTrans => b.add(row * ldb + col),
287 Transpose::Trans => b.add(col * ldb + row),
288 }
289}
290
291use crate::types::TropicalScalar;
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    /// Problem shape `(m, n, k)`: `op(A)` is `m x k`, `op(B)` is `k x n`.
    type Shape = (usize, usize, usize);

    /// One GEMM operand: backing storage, leading dimension, transpose flag.
    type Operand<'a, S> = (&'a [S], usize, Transpose);

    /// Runs the portable max-plus GEMM over `f64` operands into a fresh output
    /// (`m * n` cells seeded with the tropical zero, leading dimension `n`).
    fn run_maxplus(
        (m, n, k): Shape,
        (a, lda, trans_a): Operand<'_, f64>,
        (b, ldb, trans_b): Operand<'_, f64>,
    ) -> Vec<TropicalMaxPlus<f64>> {
        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];
        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.as_ptr(),
                lda,
                trans_a,
                b.as_ptr(),
                ldb,
                trans_b,
                c.as_mut_ptr(),
                n,
            );
        }
        c
    }

    /// Runs the portable argmax GEMM into a caller-allocated `result`.
    fn run_with_argmax<T: TropicalWithArgmax<Index = u32>>(
        result: &mut GemmWithArgmax<T>,
        (m, n, k): Shape,
        (a, lda, trans_a): Operand<'_, T::Scalar>,
        (b, ldb, trans_b): Operand<'_, T::Scalar>,
    ) {
        unsafe {
            tropical_gemm_with_argmax_portable::<T>(
                m,
                n,
                k,
                a.as_ptr(),
                lda,
                trans_a,
                b.as_ptr(),
                ldb,
                trans_b,
                result,
            );
        }
    }

    /// Raw scalar of every output cell, row-major.
    fn scalars(c: &[TropicalMaxPlus<f64>]) -> Vec<f64> {
        c.iter().map(|cell| cell.0).collect()
    }

    #[test]
    fn test_simple_gemm() {
        // A = [[1, 2, 3], [4, 5, 6]], B = [[1, 2], [3, 4], [5, 6]].
        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = run_maxplus(
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        // e.g. C[0,0] = max(1+1, 2+3, 3+5) = 8.
        assert_eq!(scalars(&c), [8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_gemm_with_argmax() {
        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);

        assert_eq!(result.get(1, 1).0, 12.0);
        assert_eq!(result.get_argmax(1, 1), 2);
    }

    #[test]
    fn test_gemm_with_argmax_all_positions() {
        // A = [[10, 1, 1], [1, 1, 10]], B = [[10, 1], [1, 10], [1, 1]].
        // Ties resolve to the earliest k.
        let a: [f64; 6] = [10.0, 1.0, 1.0, 1.0, 1.0, 10.0];
        let b: [f64; 6] = [10.0, 1.0, 1.0, 10.0, 1.0, 1.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        let expected: [(usize, usize, f64, u32); 4] = [
            (0, 0, 20.0, 0),
            (0, 1, 11.0, 0),
            (1, 0, 11.0, 0),
            (1, 1, 11.0, 1),
        ];
        for (i, j, value, arg) in expected {
            assert_eq!(result.get(i, j).0, value);
            assert_eq!(result.get_argmax(i, j), arg);
        }
    }

    #[test]
    fn test_gemm_minplus_with_argmax() {
        use crate::types::TropicalMinPlus;

        let a: [f64; 6] = [1.0, 5.0, 3.0, 2.0, 4.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMinPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        let expected: [(usize, usize, f64, u32); 4] = [
            (0, 0, 2.0, 0),
            (0, 1, 3.0, 0),
            (1, 0, 3.0, 0),
            (1, 1, 4.0, 0),
        ];
        for (i, j, value, arg) in expected {
            assert_eq!(result.get(i, j).0, value);
            assert_eq!(result.get_argmax(i, j), arg);
        }
    }

    #[test]
    fn test_gemm_larger_with_argmax() {
        let (m, n, k) = (8, 8, 8);

        let a: Vec<f64> = (0..m * k).map(|i| i as f64).collect();
        // Descending values k*n-1, ..., 0.
        let b: Vec<f64> = (0..k * n).rev().map(|i| i as f64).collect();

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);
        run_with_argmax(
            &mut result,
            (m, n, k),
            (&a[..], k, Transpose::NoTrans),
            (&b[..], n, Transpose::NoTrans),
        );

        for i in 0..m {
            for j in 0..n {
                assert!(result.get(i, j).0.is_finite());
                assert!(result.get_argmax(i, j) < k as u32);
            }
        }
    }

    #[test]
    fn test_gemm_trans_a() {
        // A stored transposed (k x m): columns of the logical A.
        let a: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = run_maxplus(
            (2, 2, 3),
            (&a[..], 2, Transpose::Trans),
            (&b[..], 2, Transpose::NoTrans),
        );

        assert_eq!(scalars(&c), [8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_gemm_trans_b() {
        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // B stored transposed (n x k).
        let b: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];

        let c = run_maxplus(
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 3, Transpose::Trans),
        );

        assert_eq!(scalars(&c), [8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_gemm_trans_both() {
        let a: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];

        let c = run_maxplus(
            (2, 2, 3),
            (&a[..], 2, Transpose::Trans),
            (&b[..], 3, Transpose::Trans),
        );

        assert_eq!(scalars(&c), [8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_gemm_empty_m() {
        let a: [f64; 0] = [];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = run_maxplus(
            (0, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        assert!(c.is_empty());
    }

    #[test]
    fn test_gemm_empty_n() {
        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 0] = [];

        let c = run_maxplus(
            (2, 0, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        assert!(c.is_empty());
    }

    #[test]
    fn test_gemm_empty_k() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];

        let c = run_maxplus(
            (2, 2, 0),
            (&a[..], 0, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        // With k == 0 the output keeps its tropical-zero seed (-inf for max-plus).
        assert!(c.iter().all(|cell| cell.0.is_infinite() && cell.0 < 0.0));
    }

    #[test]
    fn test_gemm_with_argmax_empty_k() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 0),
            (&a[..], 0, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        assert_eq!(result.m, 2);
        assert_eq!(result.n, 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_a() {
        let a: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 2, Transpose::Trans),
            (&b[..], 2, Transpose::NoTrans),
        );

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_b() {
        let a: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 3, Transpose::Trans),
        );

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_gemm_with_argmax_int_zero_cell_canonicalized() {
        use crate::types::TropicalScalar;

        let neg = <i32 as TropicalScalar>::neg_infinity();
        // Row 0 of A is all tropical-zero, so C[0, *] receives no real contribution.
        let a: [i32; 6] = [neg, neg, neg, 1, 2, 3];
        let b: [i32; 6] = [4, 5, 6, 7, 8, 9];

        let mut result: GemmWithArgmax<TropicalMaxPlus<i32>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        for j in 0..2 {
            assert!(
                result.get(0, j).0.is_drifted_neg_zero(),
                "C[0,{j}] should be in tropical-zero territory"
            );
            assert_eq!(
                result.get_argmax(0, j),
                0,
                "zero-cell argmax must canonicalize to 0, not drift"
            );
        }
        assert_eq!(result.get(1, 0).0, 11);
        assert_eq!(result.get_argmax(1, 0), 2);
        assert_eq!(result.get(1, 1).0, 12);
        assert_eq!(result.get_argmax(1, 1), 2);
    }

    #[test]
    fn test_gemm_with_argmax_float_zero_cell_keeps_seed() {
        let a: [f64; 6] = [
            f64::NEG_INFINITY,
            f64::NEG_INFINITY,
            f64::NEG_INFINITY,
            1.0,
            2.0,
            3.0,
        ];
        let b: [f64; 6] = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);
        run_with_argmax(
            &mut result,
            (2, 2, 3),
            (&a[..], 3, Transpose::NoTrans),
            (&b[..], 2, Transpose::NoTrans),
        );

        for j in 0..2 {
            assert_eq!(result.get(0, j).0, f64::NEG_INFINITY);
            assert_eq!(result.get_argmax(0, j), 0);
        }
    }
}