// tropical_gemm/simd/kernels/portable.rs

1use crate::core::{Microkernel, MicrokernelWithArgmax};
2use crate::types::{TropicalSemiring, TropicalWithArgmax};
3
/// Portable microkernel for targets without a dedicated SIMD implementation.
///
/// This is a zero-sized handle: both trait implementations in this module
/// forward their arguments unchanged to [`crate::core::PortableMicrokernel`]
/// and add no logic of their own.
///
/// NOTE(review): earlier documentation described this kernel as built on the
/// `wide` crate. Nothing in this module references `wide`, so confirm that
/// claim against the core implementation before relying on it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PortableKernel;
10
11impl<T: TropicalSemiring> Microkernel<T> for PortableKernel {
12    const MR: usize = 4;
13    const NR: usize = 4;
14
15    unsafe fn execute(
16        &self,
17        mr: usize,
18        nr: usize,
19        k: usize,
20        a: *const T::Scalar,
21        b: *const T::Scalar,
22        c: *mut T,
23        ldc: usize,
24    ) {
25        // Delegate to the core portable implementation
26        let core_kernel = crate::core::PortableMicrokernel;
27        core_kernel.execute(mr, nr, k, a, b, c, ldc);
28    }
29}
30
31impl<T: TropicalWithArgmax<Index = u32>> MicrokernelWithArgmax<T> for PortableKernel {
32    unsafe fn execute_with_argmax(
33        &self,
34        mr: usize,
35        nr: usize,
36        k: usize,
37        k_offset: usize,
38        a: *const T::Scalar,
39        b: *const T::Scalar,
40        c: *mut T,
41        argmax: *mut u32,
42        ldc: usize,
43    ) {
44        let core_kernel = crate::core::PortableMicrokernel;
45        core_kernel.execute_with_argmax(mr, nr, k, k_offset, a, b, c, argmax, ldc);
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    /// Value used for `mr`, `nr`, `k` and `ldc` in the kernel tests below.
    const DIM: usize = 2;

    /// Smoke test for the plain kernel: only output element 0 is checked.
    #[test]
    fn test_portable_kernel_execute() {
        let lhs: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];
        let rhs: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];
        let mut out = [TropicalMaxPlus::<f64>::tropical_zero(); 4];

        unsafe {
            PortableKernel.execute(
                DIM,
                DIM,
                DIM,
                lhs.as_ptr(),
                rhs.as_ptr(),
                out.as_mut_ptr(),
                DIM,
            );
        }

        // Element 0 combines (1 + 1) and (3 + 3) under max-plus, giving 6.
        assert_eq!(out[0].0, 6.0);
    }

    /// Smoke test for the argmax kernel: checks value and index of element 0.
    #[test]
    fn test_portable_kernel_execute_with_argmax() {
        let lhs: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0];
        let rhs: [f64; 8] = [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0];
        let mut out = [TropicalMaxPlus::<f64>::tropical_zero(); 4];
        let mut winners = [0u32; 4];
        let k_start = 0;

        unsafe {
            PortableKernel.execute_with_argmax(
                DIM,
                DIM,
                DIM,
                k_start,
                lhs.as_ptr(),
                rhs.as_ptr(),
                out.as_mut_ptr(),
                winners.as_mut_ptr(),
                DIM,
            );
        }

        // Element 0 combines (1 + 1) and (10 + 1); the larger term, 11,
        // comes from k = 1.
        assert_eq!(out[0].0, 11.0);
        assert_eq!(winners[0], 1);
    }

    /// `Default` construction works and the tile constants are reachable
    /// through the trait.
    #[test]
    fn test_portable_kernel_default() {
        let _kernel = PortableKernel::default();
        assert_eq!(<PortableKernel as Microkernel<TropicalMaxPlus<f64>>>::MR, 4);
        assert_eq!(<PortableKernel as Microkernel<TropicalMaxPlus<f64>>>::NR, 4);
    }
}