1use super::argmax::GemmWithArgmax;
2use super::kernel::{Microkernel, MicrokernelWithArgmax, PortableMicrokernel};
3use super::packing::{pack_a, pack_b, packed_a_size, packed_b_size, Layout, Transpose};
4use super::tiling::{BlockIterator, TilingParams};
5use crate::types::{TropicalSemiring, TropicalWithArgmax};
6
7pub unsafe fn tropical_gemm_portable<T: TropicalSemiring>(
31 m: usize,
32 n: usize,
33 k: usize,
34 a: *const T::Scalar,
35 lda: usize,
36 trans_a: Transpose,
37 b: *const T::Scalar,
38 ldb: usize,
39 trans_b: Transpose,
40 c: *mut T,
41 ldc: usize,
42) {
43 let params = TilingParams::PORTABLE;
44 let kernel = PortableMicrokernel;
45
46 tropical_gemm_inner::<T, PortableMicrokernel>(
47 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
48 );
49}
50
51pub unsafe fn tropical_gemm_inner<T: TropicalSemiring, K: Microkernel<T>>(
56 m: usize,
57 n: usize,
58 k: usize,
59 a: *const T::Scalar,
60 lda: usize,
61 trans_a: Transpose,
62 b: *const T::Scalar,
63 ldb: usize,
64 trans_b: Transpose,
65 c: *mut T,
66 ldc: usize,
67 params: &TilingParams,
68 kernel: &K,
69) {
70 if m == 0 || n == 0 || k == 0 {
71 return;
72 }
73
74 let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
79 let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
80
81 for (jc, nc) in BlockIterator::new(n, params.nc) {
84 for (pc, kc) in BlockIterator::new(k, params.kc) {
86 pack_b::<T::Scalar>(
88 kc,
89 nc,
90 b_panel_ptr(b, pc, jc, ldb, trans_b),
91 ldb,
92 Layout::RowMajor,
93 trans_b,
94 packed_b.as_mut_ptr(),
95 K::NR,
96 );
97
98 for (ic, mc) in BlockIterator::new(m, params.mc) {
100 pack_a::<T::Scalar>(
102 mc,
103 kc,
104 a_panel_ptr(a, ic, pc, lda, trans_a),
105 lda,
106 Layout::RowMajor,
107 trans_a,
108 packed_a.as_mut_ptr(),
109 K::MR,
110 );
111
112 let n_blocks = nc.div_ceil(K::NR);
114 for jr in 0..n_blocks {
115 let j_start = jr * K::NR;
116 let nr = (nc - j_start).min(K::NR);
117
118 let m_blocks = mc.div_ceil(K::MR);
120 for ir in 0..m_blocks {
121 let i_start = ir * K::MR;
122 let mr = (mc - i_start).min(K::MR);
123
124 let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
126 let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
127 let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
128
129 kernel.execute(mr, nr, kc, a_ptr, b_ptr, c_ptr, ldc);
130 }
131 }
132 }
133 }
134 }
135}
136
137pub unsafe fn tropical_gemm_with_argmax_portable<T: TropicalWithArgmax<Index = u32>>(
144 m: usize,
145 n: usize,
146 k: usize,
147 a: *const T::Scalar,
148 lda: usize,
149 trans_a: Transpose,
150 b: *const T::Scalar,
151 ldb: usize,
152 trans_b: Transpose,
153 result: &mut GemmWithArgmax<T>,
154) {
155 let params = TilingParams::PORTABLE;
156 let kernel = PortableMicrokernel;
157
158 tropical_gemm_with_argmax_inner::<T, PortableMicrokernel>(
159 m, n, k, a, lda, trans_a, b, ldb, trans_b, result, ¶ms, &kernel,
160 );
161}
162
163pub unsafe fn tropical_gemm_with_argmax_inner<
168 T: TropicalWithArgmax<Index = u32>,
169 K: MicrokernelWithArgmax<T>,
170>(
171 m: usize,
172 n: usize,
173 k: usize,
174 a: *const T::Scalar,
175 lda: usize,
176 trans_a: Transpose,
177 b: *const T::Scalar,
178 ldb: usize,
179 trans_b: Transpose,
180 result: &mut GemmWithArgmax<T>,
181 params: &TilingParams,
182 kernel: &K,
183) {
184 if m == 0 || n == 0 || k == 0 {
185 return;
186 }
187
188 let ldc = result.ld;
189 let (c, argmax) = result.as_mut_ptrs();
190
191 let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
193 let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
194
195 for (jc, nc) in BlockIterator::new(n, params.nc) {
197 for (pc, kc) in BlockIterator::new(k, params.kc) {
198 pack_b::<T::Scalar>(
199 kc,
200 nc,
201 b_panel_ptr(b, pc, jc, ldb, trans_b),
202 ldb,
203 Layout::RowMajor,
204 trans_b,
205 packed_b.as_mut_ptr(),
206 K::NR,
207 );
208
209 for (ic, mc) in BlockIterator::new(m, params.mc) {
210 pack_a::<T::Scalar>(
211 mc,
212 kc,
213 a_panel_ptr(a, ic, pc, lda, trans_a),
214 lda,
215 Layout::RowMajor,
216 trans_a,
217 packed_a.as_mut_ptr(),
218 K::MR,
219 );
220
221 let n_blocks = nc.div_ceil(K::NR);
222 for jr in 0..n_blocks {
223 let j_start = jr * K::NR;
224 let nr = (nc - j_start).min(K::NR);
225
226 let m_blocks = mc.div_ceil(K::MR);
227 for ir in 0..m_blocks {
228 let i_start = ir * K::MR;
229 let mr = (mc - i_start).min(K::MR);
230
231 let a_ptr = packed_a.as_ptr().add(ir * K::MR * kc);
232 let b_ptr = packed_b.as_ptr().add(jr * K::NR * kc);
233 let c_ptr = c.add((ic + i_start) * ldc + (jc + j_start));
234 let argmax_ptr = argmax.add((ic + i_start) * ldc + (jc + j_start));
235
236 kernel.execute_with_argmax(
237 mr, nr, kc, pc, a_ptr, b_ptr, c_ptr, argmax_ptr, ldc,
238 );
239 }
240 }
241 }
242 }
243 }
244}
245
246#[inline]
248unsafe fn a_panel_ptr<T>(
249 a: *const T,
250 row: usize,
251 col: usize,
252 lda: usize,
253 trans: Transpose,
254) -> *const T {
255 match trans {
256 Transpose::NoTrans => a.add(row * lda + col),
257 Transpose::Trans => a.add(col * lda + row),
258 }
259}
260
261#[inline]
263unsafe fn b_panel_ptr<T>(
264 b: *const T,
265 row: usize,
266 col: usize,
267 ldb: usize,
268 trans: Transpose,
269) -> *const T {
270 match trans {
271 Transpose::NoTrans => b.add(row * ldb + col),
272 Transpose::Trans => b.add(col * ldb + row),
273 }
274}
275
276use crate::types::TropicalScalar;
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    /// A = [[1, 2, 3], [4, 5, 6]] stored row by row (2x3, leading dimension 3).
    const A_2X3: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    /// The same A stored as its transpose (3x2, leading dimension 2).
    const A_2X3_T: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
    /// B = [[1, 2], [3, 4], [5, 6]] stored row by row (3x2, leading dimension 2).
    const B_3X2: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    /// The same B stored as its transpose (2x3, leading dimension 3).
    const B_3X2_T: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
    /// Max-plus product A (x) B, row by row:
    /// C[i][j] = max over p of (A[i][p] + B[p][j]).
    const MAX_PLUS_AB: [f64; 4] = [8.0, 9.0, 11.0, 12.0];

    /// One GEMM input: backing storage, leading dimension and transpose flag.
    struct Operand<'a> {
        data: &'a [f64],
        ld: usize,
        trans: Transpose,
    }

    /// Shorthand constructor for [`Operand`].
    fn operand(data: &[f64], ld: usize, trans: Transpose) -> Operand<'_> {
        Operand { data, ld, trans }
    }

    /// Runs the plain max-plus GEMM into a fresh `m x n` buffer (`ldc = n`)
    /// pre-filled with the tropical zero, and returns the scalar of each entry.
    fn max_plus_values(m: usize, n: usize, k: usize, a: Operand<'_>, b: Operand<'_>) -> Vec<f64> {
        let mut c: Vec<TropicalMaxPlus<f64>> = vec![TropicalMaxPlus::tropical_zero(); m * n];

        // SAFETY: every caller in this module passes operands whose storage
        // covers the region implied by the dimensions, and `c` holds m * n
        // entries.
        unsafe {
            tropical_gemm_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.data.as_ptr(),
                a.ld,
                a.trans,
                b.data.as_ptr(),
                b.ld,
                b.trans,
                c.as_mut_ptr(),
                n,
            );
        }

        c.iter().map(|entry| entry.0).collect()
    }

    /// Runs the max-plus GEMM with argmax tracking into a fresh `m x n` result.
    fn max_plus_with_argmax(
        m: usize,
        n: usize,
        k: usize,
        a: Operand<'_>,
        b: Operand<'_>,
    ) -> GemmWithArgmax<TropicalMaxPlus<f64>> {
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(m, n);

        // SAFETY: every caller in this module passes operands whose storage
        // covers the region implied by the dimensions; `result` was created
        // for m x n.
        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMaxPlus<f64>>(
                m,
                n,
                k,
                a.data.as_ptr(),
                a.ld,
                a.trans,
                b.data.as_ptr(),
                b.ld,
                b.trans,
                &mut result,
            );
        }

        result
    }

    #[test]
    fn test_simple_gemm() {
        let values = max_plus_values(
            2,
            2,
            3,
            operand(&A_2X3, 3, Transpose::NoTrans),
            operand(&B_3X2, 2, Transpose::NoTrans),
        );

        assert_eq!(values, MAX_PLUS_AB);
    }

    #[test]
    fn test_gemm_with_argmax() {
        let result = max_plus_with_argmax(
            2,
            2,
            3,
            operand(&A_2X3, 3, Transpose::NoTrans),
            operand(&B_3X2, 2, Transpose::NoTrans),
        );

        // C[0][0] = max(1+1, 2+3, 3+5) = 8, attained at p = 2.
        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);

        // C[1][1] = max(4+2, 5+4, 6+6) = 12, attained at p = 2.
        assert_eq!(result.get(1, 1).0, 12.0);
        assert_eq!(result.get_argmax(1, 1), 2);
    }

    #[test]
    fn test_gemm_with_argmax_all_positions() {
        // A = [[10, 1, 1], [1, 1, 10]], B = [[10, 1], [1, 10], [1, 1]].
        let a: [f64; 6] = [10.0, 1.0, 1.0, 1.0, 1.0, 10.0];
        let b: [f64; 6] = [10.0, 1.0, 1.0, 10.0, 1.0, 1.0];

        let result = max_plus_with_argmax(
            2,
            2,
            3,
            operand(&a, 3, Transpose::NoTrans),
            operand(&b, 2, Transpose::NoTrans),
        );

        // ((row, col), expected value, expected argmax). Three of the entries
        // are ties; the expectations pin the smallest winning index.
        let expected: [((usize, usize), f64, u32); 4] = [
            ((0, 0), 20.0, 0),
            ((0, 1), 11.0, 0),
            ((1, 0), 11.0, 0),
            ((1, 1), 11.0, 1),
        ];
        for ((i, j), value, arg) in expected {
            assert_eq!(result.get(i, j).0, value);
            assert_eq!(result.get_argmax(i, j), arg);
        }
    }

    #[test]
    fn test_gemm_minplus_with_argmax() {
        use crate::types::TropicalMinPlus;

        // A = [[1, 5, 3], [2, 4, 6]], B = [[1, 2], [3, 4], [5, 6]].
        let a: [f64; 6] = [1.0, 5.0, 3.0, 2.0, 4.0, 6.0];
        let b = B_3X2;

        let mut result: GemmWithArgmax<TropicalMinPlus<f64>> = GemmWithArgmax::new(2, 2);

        // SAFETY: `a` is 2x3 with lda = 3, `b` is 3x2 with ldb = 2, and
        // `result` was created for 2x2.
        unsafe {
            tropical_gemm_with_argmax_portable::<TropicalMinPlus<f64>>(
                2,
                2,
                3,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                &mut result,
            );
        }

        // Every minimum is attained at p = 0.
        let expected: [((usize, usize), f64, u32); 4] = [
            ((0, 0), 2.0, 0),
            ((0, 1), 3.0, 0),
            ((1, 0), 3.0, 0),
            ((1, 1), 4.0, 0),
        ];
        for ((i, j), value, arg) in expected {
            assert_eq!(result.get(i, j).0, value);
            assert_eq!(result.get_argmax(i, j), arg);
        }
    }

    #[test]
    fn test_gemm_larger_with_argmax() {
        let (m, n, k) = (8, 8, 8);

        // A counts up from 0, B counts down to 0.
        let a: Vec<f64> = (0..m * k).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..k * n).rev().map(|i| i as f64).collect();

        let result = max_plus_with_argmax(
            m,
            n,
            k,
            operand(&a, k, Transpose::NoTrans),
            operand(&b, n, Transpose::NoTrans),
        );

        // Sanity only: every entry is finite and every index is in range.
        for (i, j) in (0..m).flat_map(|i| (0..n).map(move |j| (i, j))) {
            assert!(result.get(i, j).0.is_finite());
            assert!(result.get_argmax(i, j) < k as u32);
        }
    }

    #[test]
    fn test_gemm_trans_a() {
        let values = max_plus_values(
            2,
            2,
            3,
            operand(&A_2X3_T, 2, Transpose::Trans),
            operand(&B_3X2, 2, Transpose::NoTrans),
        );

        assert_eq!(values, MAX_PLUS_AB);
    }

    #[test]
    fn test_gemm_trans_b() {
        let values = max_plus_values(
            2,
            2,
            3,
            operand(&A_2X3, 3, Transpose::NoTrans),
            operand(&B_3X2_T, 3, Transpose::Trans),
        );

        assert_eq!(values, MAX_PLUS_AB);
    }

    #[test]
    fn test_gemm_trans_both() {
        let values = max_plus_values(
            2,
            2,
            3,
            operand(&A_2X3_T, 2, Transpose::Trans),
            operand(&B_3X2_T, 3, Transpose::Trans),
        );

        assert_eq!(values, MAX_PLUS_AB);
    }

    #[test]
    fn test_gemm_empty_m() {
        let values = max_plus_values(
            0,
            2,
            3,
            operand(&[], 3, Transpose::NoTrans),
            operand(&B_3X2, 2, Transpose::NoTrans),
        );

        assert!(values.is_empty());
    }

    #[test]
    fn test_gemm_empty_n() {
        let values = max_plus_values(
            2,
            0,
            3,
            operand(&A_2X3, 3, Transpose::NoTrans),
            operand(&[], 2, Transpose::NoTrans),
        );

        assert!(values.is_empty());
    }

    #[test]
    fn test_gemm_empty_k() {
        let values = max_plus_values(
            2,
            2,
            0,
            operand(&[], 0, Transpose::NoTrans),
            operand(&[], 2, Transpose::NoTrans),
        );

        // With no reduction terms the output keeps its initial tropical zero,
        // which is negative infinity for max-plus.
        assert!(values.iter().all(|v| v.is_infinite() && *v < 0.0));
    }

    #[test]
    fn test_gemm_with_argmax_empty_k() {
        let result = max_plus_with_argmax(
            2,
            2,
            0,
            operand(&[], 0, Transpose::NoTrans),
            operand(&[], 2, Transpose::NoTrans),
        );

        // The call must return without disturbing the result's shape.
        assert_eq!(result.m, 2);
        assert_eq!(result.n, 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_a() {
        let result = max_plus_with_argmax(
            2,
            2,
            3,
            operand(&A_2X3_T, 2, Transpose::Trans),
            operand(&B_3X2, 2, Transpose::NoTrans),
        );

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_gemm_with_argmax_trans_b() {
        let result = max_plus_with_argmax(
            2,
            2,
            3,
            operand(&A_2X3, 3, Transpose::NoTrans),
            operand(&B_3X2_T, 3, Transpose::Trans),
        );

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }
}