tropical_gemm/simd/kernels/
portable.rs1use crate::core::{Microkernel, MicrokernelWithArgmax};
2use crate::types::{TropicalSemiring, TropicalWithArgmax};
3
/// Zero-sized kernel handle.
///
/// The value carries no state; its trait implementations in this file forward
/// every call to `crate::core::PortableMicrokernel`.
///
/// `Debug` is derived so the kernel can be logged and embedded in other
/// `#[derive(Debug)]` types.
#[derive(Debug, Default, Clone, Copy)]
pub struct PortableKernel;
10
11impl<T: TropicalSemiring> Microkernel<T> for PortableKernel {
12 const MR: usize = 4;
13 const NR: usize = 4;
14
15 unsafe fn execute(
16 &self,
17 mr: usize,
18 nr: usize,
19 k: usize,
20 a: *const T::Scalar,
21 b: *const T::Scalar,
22 c: *mut T,
23 ldc: usize,
24 ) {
25 let core_kernel = crate::core::PortableMicrokernel;
27 core_kernel.execute(mr, nr, k, a, b, c, ldc);
28 }
29}
30
31impl<T: TropicalWithArgmax<Index = u32>> MicrokernelWithArgmax<T> for PortableKernel {
32 unsafe fn execute_with_argmax(
33 &self,
34 mr: usize,
35 nr: usize,
36 k: usize,
37 k_offset: usize,
38 a: *const T::Scalar,
39 b: *const T::Scalar,
40 c: *mut T,
41 argmax: *mut u32,
42 ldc: usize,
43 ) {
44 let core_kernel = crate::core::PortableMicrokernel;
45 core_kernel.execute_with_argmax(mr, nr, k, k_offset, a, b, c, argmax, ldc);
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    /// Smoke test for the plain kernel: a 2x2 tile with `k = 2`.
    #[test]
    fn test_portable_kernel_execute() {
        let (mr, nr, k) = (2, 2, 2);
        let ldc = 2;

        // Eight scalars per operand; only the first two of each group of four
        // are non-zero.
        // NOTE(review): presumably each group of four is one k-step padded to
        // MR / NR — confirm against the core kernel's expected layout.
        let lhs: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];
        let rhs: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];
        let mut out = [TropicalMaxPlus::tropical_zero(); 4];

        let kernel = PortableKernel;
        unsafe {
            kernel.execute(mr, nr, k, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr(), ldc);
        }

        // Presumably max(1 + 1, 3 + 3) under max-plus — TODO confirm.
        assert_eq!(out[0].0, 6.0);
    }

    /// Smoke test for the argmax kernel: a 2x2 tile with `k = 2`, no offset.
    #[test]
    fn test_portable_kernel_execute_with_argmax() {
        let (mr, nr, k) = (2, 2, 2);
        let ldc = 2;
        let k_offset = 0;

        let lhs: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0];
        let rhs: [f64; 8] = [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0];
        let mut out = [TropicalMaxPlus::tropical_zero(); 4];
        let mut indices = [0u32; 4];

        let kernel = PortableKernel;
        unsafe {
            kernel.execute_with_argmax(
                mr,
                nr,
                k,
                k_offset,
                lhs.as_ptr(),
                rhs.as_ptr(),
                out.as_mut_ptr(),
                indices.as_mut_ptr(),
                ldc,
            );
        }

        // Presumably max(1 + 1, 10 + 1), won by the second k-step (index 1)
        // — TODO confirm against the core kernel.
        assert_eq!(out[0].0, 11.0);
        assert_eq!(indices[0], 1);
    }

    /// `Default` must be available and the tile constants must both be 4.
    #[test]
    fn test_portable_kernel_default() {
        type MaxPlusF64 = TropicalMaxPlus<f64>;

        let kernel = PortableKernel::default();
        assert_eq!(<PortableKernel as Microkernel<MaxPlusF64>>::MR, 4);
        assert_eq!(<PortableKernel as Microkernel<MaxPlusF64>>::NR, 4);
        // Touch the value so the binding is not reported as unused.
        let _ = kernel;
    }
}