tropical_gemm/simd/
dispatch.rs

1use super::detect::{simd_level, SimdLevel};
2use super::kernels::*;
3use crate::core::{tropical_gemm_inner, TilingParams, Transpose};
4use crate::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring};
5
6/// Runtime-dispatched GEMM that selects the best kernel for the current CPU.
7///
8/// # Safety
9/// Same requirements as `tropical_gemm_inner`
10pub unsafe fn tropical_gemm_dispatch<T: TropicalSemiring + KernelDispatch>(
11    m: usize,
12    n: usize,
13    k: usize,
14    a: *const T::Scalar,
15    lda: usize,
16    trans_a: Transpose,
17    b: *const T::Scalar,
18    ldb: usize,
19    trans_b: Transpose,
20    c: *mut T,
21    ldc: usize,
22) {
23    T::dispatch_gemm(m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc);
24}
25
26/// Trait for types that support kernel dispatch.
27pub trait KernelDispatch: TropicalSemiring {
28    /// Dispatch to the appropriate kernel based on CPU features.
29    unsafe fn dispatch_gemm(
30        m: usize,
31        n: usize,
32        k: usize,
33        a: *const Self::Scalar,
34        lda: usize,
35        trans_a: Transpose,
36        b: *const Self::Scalar,
37        ldb: usize,
38        trans_b: Transpose,
39        c: *mut Self,
40        ldc: usize,
41    );
42}
43
44impl KernelDispatch for TropicalMaxPlus<f32> {
45    unsafe fn dispatch_gemm(
46        m: usize,
47        n: usize,
48        k: usize,
49        a: *const f32,
50        lda: usize,
51        trans_a: Transpose,
52        b: *const f32,
53        ldb: usize,
54        trans_b: Transpose,
55        c: *mut Self,
56        ldc: usize,
57    ) {
58        match simd_level() {
59            #[cfg(target_arch = "x86_64")]
60            SimdLevel::Avx2 | SimdLevel::Avx512 => {
61                let kernel = Avx2MaxPlusF32Kernel;
62                let params = TilingParams::F32_AVX2;
63                tropical_gemm_inner::<Self, _>(
64                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
65                );
66            }
67            #[cfg(target_arch = "aarch64")]
68            SimdLevel::Neon => {
69                let kernel = NeonMaxPlusF32Kernel;
70                let params = TilingParams::new(128, 128, 256, 4, 4);
71                tropical_gemm_inner::<Self, _>(
72                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
73                );
74            }
75            _ => {
76                let kernel = PortableKernel;
77                let params = TilingParams::PORTABLE;
78                tropical_gemm_inner::<Self, _>(
79                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
80                );
81            }
82        }
83    }
84}
85
86impl KernelDispatch for TropicalMaxPlus<f64> {
87    unsafe fn dispatch_gemm(
88        m: usize,
89        n: usize,
90        k: usize,
91        a: *const f64,
92        lda: usize,
93        trans_a: Transpose,
94        b: *const f64,
95        ldb: usize,
96        trans_b: Transpose,
97        c: *mut Self,
98        ldc: usize,
99    ) {
100        match simd_level() {
101            #[cfg(target_arch = "x86_64")]
102            SimdLevel::Avx2 | SimdLevel::Avx512 => {
103                let kernel = Avx2MaxPlusF64Kernel;
104                let params = TilingParams::F64_AVX2;
105                tropical_gemm_inner::<Self, _>(
106                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
107                );
108            }
109            #[cfg(target_arch = "aarch64")]
110            SimdLevel::Neon => {
111                let kernel = NeonMaxPlusF64Kernel;
112                let params = TilingParams::new(64, 64, 128, 2, 2);
113                tropical_gemm_inner::<Self, _>(
114                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
115                );
116            }
117            _ => {
118                let kernel = PortableKernel;
119                let params = TilingParams::PORTABLE;
120                tropical_gemm_inner::<Self, _>(
121                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
122                );
123            }
124        }
125    }
126}
127
128impl KernelDispatch for TropicalMinPlus<f32> {
129    unsafe fn dispatch_gemm(
130        m: usize,
131        n: usize,
132        k: usize,
133        a: *const f32,
134        lda: usize,
135        trans_a: Transpose,
136        b: *const f32,
137        ldb: usize,
138        trans_b: Transpose,
139        c: *mut Self,
140        ldc: usize,
141    ) {
142        match simd_level() {
143            #[cfg(target_arch = "x86_64")]
144            SimdLevel::Avx2 | SimdLevel::Avx512 => {
145                let kernel = Avx2MinPlusF32Kernel;
146                let params = TilingParams::F32_AVX2;
147                tropical_gemm_inner::<Self, _>(
148                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
149                );
150            }
151            #[cfg(target_arch = "aarch64")]
152            SimdLevel::Neon => {
153                let kernel = NeonMinPlusF32Kernel;
154                let params = TilingParams::new(128, 128, 256, 4, 4);
155                tropical_gemm_inner::<Self, _>(
156                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
157                );
158            }
159            _ => {
160                let kernel = PortableKernel;
161                let params = TilingParams::PORTABLE;
162                tropical_gemm_inner::<Self, _>(
163                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
164                );
165            }
166        }
167    }
168}
169
170impl KernelDispatch for TropicalMaxMul<f32> {
171    unsafe fn dispatch_gemm(
172        m: usize,
173        n: usize,
174        k: usize,
175        a: *const f32,
176        lda: usize,
177        trans_a: Transpose,
178        b: *const f32,
179        ldb: usize,
180        trans_b: Transpose,
181        c: *mut Self,
182        ldc: usize,
183    ) {
184        match simd_level() {
185            #[cfg(target_arch = "x86_64")]
186            SimdLevel::Avx2 | SimdLevel::Avx512 => {
187                let kernel = Avx2MaxMulF32Kernel;
188                let params = TilingParams::F32_AVX2;
189                tropical_gemm_inner::<Self, _>(
190                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
191                );
192            }
193            _ => {
194                let kernel = PortableKernel;
195                let params = TilingParams::PORTABLE;
196                tropical_gemm_inner::<Self, _>(
197                    m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
198                );
199            }
200        }
201    }
202}
203
204// Fallback implementations for other types
205macro_rules! impl_kernel_dispatch_portable {
206    ($($t:ty),*) => {
207        $(
208            impl KernelDispatch for $t {
209                unsafe fn dispatch_gemm(
210                    m: usize,
211                    n: usize,
212                    k: usize,
213                    a: *const Self::Scalar,
214                    lda: usize,
215                    trans_a: Transpose,
216                    b: *const Self::Scalar,
217                    ldb: usize,
218                    trans_b: Transpose,
219                    c: *mut Self,
220                    ldc: usize,
221                ) {
222                    let kernel = PortableKernel;
223                    let params = TilingParams::PORTABLE;
224                    tropical_gemm_inner::<Self, _>(
225                        m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, &params, &kernel,
226                    );
227                }
228            }
229        )*
230    };
231}
232
233impl_kernel_dispatch_portable!(
234    TropicalMinPlus<f64>,
235    TropicalMaxMul<f64>,
236    TropicalMaxPlus<i32>,
237    TropicalMaxPlus<i64>,
238    TropicalMinPlus<i32>,
239    TropicalMinPlus<i64>,
240    TropicalMaxMul<i32>,
241    TropicalMaxMul<i64>
242);
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates one smoke test: a 2x2x2 product is sent through
    /// `tropical_gemm_dispatch` for the given semiring type and the first
    /// element of the output buffer is compared with the expected value.
    macro_rules! dispatch_2x2_case {
        ($test_name:ident, $semiring:ty, $lhs:expr, $rhs:expr, $want:expr) => {
            #[test]
            fn $test_name() {
                let lhs = $lhs.to_vec();
                let rhs = $rhs.to_vec();
                let mut out = vec![<$semiring>::tropical_zero(); 4];

                // All three buffers hold 4 elements, matching the 2x2x2
                // dimensions and the leading dimension of 2 passed here.
                unsafe {
                    tropical_gemm_dispatch::<$semiring>(
                        2,
                        2,
                        2,
                        lhs.as_ptr(),
                        2,
                        Transpose::NoTrans,
                        rhs.as_ptr(),
                        2,
                        Transpose::NoTrans,
                        out.as_mut_ptr(),
                        2,
                    );
                }

                assert_eq!(out[0].0, $want);
            }
        };
    }

    // Max-plus: C[0,0] = max(A[0,0]+B[0,0], A[0,1]+B[1,0]) = max(1+1, 2+3) = 5
    dispatch_2x2_case!(
        test_dispatch_maxplus_f32,
        TropicalMaxPlus<f32>,
        [1.0f32, 2.0, 3.0, 4.0],
        [1.0f32, 2.0, 3.0, 4.0],
        5.0
    );
    dispatch_2x2_case!(
        test_dispatch_maxplus_f64,
        TropicalMaxPlus<f64>,
        [1.0f64, 2.0, 3.0, 4.0],
        [1.0f64, 2.0, 3.0, 4.0],
        5.0
    );
    dispatch_2x2_case!(
        test_dispatch_maxplus_i32,
        TropicalMaxPlus<i32>,
        [1i32, 2, 3, 4],
        [1i32, 2, 3, 4],
        5
    );
    dispatch_2x2_case!(
        test_dispatch_maxplus_i64,
        TropicalMaxPlus<i64>,
        [1i64, 2, 3, 4],
        [1i64, 2, 3, 4],
        5
    );

    // Min-plus: C[0,0] = min(A[0,0]+B[0,0], A[0,1]+B[1,0]) = min(1+1, 2+3) = 2
    dispatch_2x2_case!(
        test_dispatch_minplus_f32,
        TropicalMinPlus<f32>,
        [1.0f32, 2.0, 3.0, 4.0],
        [1.0f32, 2.0, 3.0, 4.0],
        2.0
    );
    dispatch_2x2_case!(
        test_dispatch_minplus_f64,
        TropicalMinPlus<f64>,
        [1.0f64, 2.0, 3.0, 4.0],
        [1.0f64, 2.0, 3.0, 4.0],
        2.0
    );
    dispatch_2x2_case!(
        test_dispatch_minplus_i32,
        TropicalMinPlus<i32>,
        [1i32, 2, 3, 4],
        [1i32, 2, 3, 4],
        2
    );
    dispatch_2x2_case!(
        test_dispatch_minplus_i64,
        TropicalMinPlus<i64>,
        [1i64, 2, 3, 4],
        [1i64, 2, 3, 4],
        2
    );

    // Max-mul: C[0,0] = max(A[0,0]*B[0,0], A[0,1]*B[1,0]) = max(2*1, 3*3) = 9
    dispatch_2x2_case!(
        test_dispatch_maxmul_f32,
        TropicalMaxMul<f32>,
        [2.0f32, 3.0, 4.0, 5.0],
        [1.0f32, 2.0, 3.0, 4.0],
        9.0
    );
    dispatch_2x2_case!(
        test_dispatch_maxmul_f64,
        TropicalMaxMul<f64>,
        [2.0f64, 3.0, 4.0, 5.0],
        [1.0f64, 2.0, 3.0, 4.0],
        9.0
    );
    dispatch_2x2_case!(
        test_dispatch_maxmul_i32,
        TropicalMaxMul<i32>,
        [2i32, 3, 4, 5],
        [1i32, 2, 3, 4],
        9
    );
    dispatch_2x2_case!(
        test_dispatch_maxmul_i64,
        TropicalMaxMul<i64>,
        [2i64, 3, 4, 5],
        [1i64, 2, 3, 4],
        9
    );

    #[test]
    fn test_dispatch_larger_matrix() {
        // A 16x16x16 product, larger than the 2x2 cases above.
        let (m, n, k) = (16, 16, 16);

        // Operands are filled with the repeating ramp 0, 1, ..., 9.
        let ramp = |len: usize| -> Vec<f32> { (0..len).map(|idx| (idx % 10) as f32).collect() };
        let a = ramp(m * k);
        let b = ramp(k * n);
        let mut c = vec![TropicalMaxPlus::tropical_zero(); m * n];

        // Buffer lengths are m*k, k*n and m*n, in line with the dimensions and
        // leading dimensions passed to the call.
        unsafe {
            tropical_gemm_dispatch::<TropicalMaxPlus<f32>>(
                m,
                n,
                k,
                a.as_ptr(),
                k,
                Transpose::NoTrans,
                b.as_ptr(),
                n,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                n,
            );
        }

        // Smoke check only: the call returned and at least one output entry
        // is greater than negative infinity.
        assert!(c.iter().any(|entry| entry.0 > f32::NEG_INFINITY));
    }
}
583}