1use super::detect::{simd_level, SimdLevel};
2use super::kernels::*;
3use crate::core::{tropical_gemm_inner, TilingParams, Transpose};
4use crate::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring};
5
6pub unsafe fn tropical_gemm_dispatch<T: TropicalSemiring + KernelDispatch>(
11 m: usize,
12 n: usize,
13 k: usize,
14 a: *const T::Scalar,
15 lda: usize,
16 trans_a: Transpose,
17 b: *const T::Scalar,
18 ldb: usize,
19 trans_b: Transpose,
20 c: *mut T,
21 ldc: usize,
22) {
23 T::dispatch_gemm(m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc);
24}
25
26pub trait KernelDispatch: TropicalSemiring {
28 unsafe fn dispatch_gemm(
30 m: usize,
31 n: usize,
32 k: usize,
33 a: *const Self::Scalar,
34 lda: usize,
35 trans_a: Transpose,
36 b: *const Self::Scalar,
37 ldb: usize,
38 trans_b: Transpose,
39 c: *mut Self,
40 ldc: usize,
41 );
42}
43
44impl KernelDispatch for TropicalMaxPlus<f32> {
45 unsafe fn dispatch_gemm(
46 m: usize,
47 n: usize,
48 k: usize,
49 a: *const f32,
50 lda: usize,
51 trans_a: Transpose,
52 b: *const f32,
53 ldb: usize,
54 trans_b: Transpose,
55 c: *mut Self,
56 ldc: usize,
57 ) {
58 match simd_level() {
59 #[cfg(target_arch = "x86_64")]
60 SimdLevel::Avx2 | SimdLevel::Avx512 => {
61 let kernel = Avx2MaxPlusF32Kernel;
62 let params = TilingParams::F32_AVX2;
63 tropical_gemm_inner::<Self, _>(
64 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
65 );
66 }
67 #[cfg(target_arch = "aarch64")]
68 SimdLevel::Neon => {
69 let kernel = NeonMaxPlusF32Kernel;
70 let params = TilingParams::new(128, 128, 256, 4, 4);
71 tropical_gemm_inner::<Self, _>(
72 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
73 );
74 }
75 _ => {
76 let kernel = PortableKernel;
77 let params = TilingParams::PORTABLE;
78 tropical_gemm_inner::<Self, _>(
79 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
80 );
81 }
82 }
83 }
84}
85
86impl KernelDispatch for TropicalMaxPlus<f64> {
87 unsafe fn dispatch_gemm(
88 m: usize,
89 n: usize,
90 k: usize,
91 a: *const f64,
92 lda: usize,
93 trans_a: Transpose,
94 b: *const f64,
95 ldb: usize,
96 trans_b: Transpose,
97 c: *mut Self,
98 ldc: usize,
99 ) {
100 match simd_level() {
101 #[cfg(target_arch = "x86_64")]
102 SimdLevel::Avx2 | SimdLevel::Avx512 => {
103 let kernel = Avx2MaxPlusF64Kernel;
104 let params = TilingParams::F64_AVX2;
105 tropical_gemm_inner::<Self, _>(
106 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
107 );
108 }
109 #[cfg(target_arch = "aarch64")]
110 SimdLevel::Neon => {
111 let kernel = NeonMaxPlusF64Kernel;
112 let params = TilingParams::new(64, 64, 128, 2, 2);
113 tropical_gemm_inner::<Self, _>(
114 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
115 );
116 }
117 _ => {
118 let kernel = PortableKernel;
119 let params = TilingParams::PORTABLE;
120 tropical_gemm_inner::<Self, _>(
121 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
122 );
123 }
124 }
125 }
126}
127
128impl KernelDispatch for TropicalMinPlus<f32> {
129 unsafe fn dispatch_gemm(
130 m: usize,
131 n: usize,
132 k: usize,
133 a: *const f32,
134 lda: usize,
135 trans_a: Transpose,
136 b: *const f32,
137 ldb: usize,
138 trans_b: Transpose,
139 c: *mut Self,
140 ldc: usize,
141 ) {
142 match simd_level() {
143 #[cfg(target_arch = "x86_64")]
144 SimdLevel::Avx2 | SimdLevel::Avx512 => {
145 let kernel = Avx2MinPlusF32Kernel;
146 let params = TilingParams::F32_AVX2;
147 tropical_gemm_inner::<Self, _>(
148 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
149 );
150 }
151 #[cfg(target_arch = "aarch64")]
152 SimdLevel::Neon => {
153 let kernel = NeonMinPlusF32Kernel;
154 let params = TilingParams::new(128, 128, 256, 4, 4);
155 tropical_gemm_inner::<Self, _>(
156 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
157 );
158 }
159 _ => {
160 let kernel = PortableKernel;
161 let params = TilingParams::PORTABLE;
162 tropical_gemm_inner::<Self, _>(
163 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
164 );
165 }
166 }
167 }
168}
169
170impl KernelDispatch for TropicalMaxMul<f32> {
171 unsafe fn dispatch_gemm(
172 m: usize,
173 n: usize,
174 k: usize,
175 a: *const f32,
176 lda: usize,
177 trans_a: Transpose,
178 b: *const f32,
179 ldb: usize,
180 trans_b: Transpose,
181 c: *mut Self,
182 ldc: usize,
183 ) {
184 match simd_level() {
185 #[cfg(target_arch = "x86_64")]
186 SimdLevel::Avx2 | SimdLevel::Avx512 => {
187 let kernel = Avx2MaxMulF32Kernel;
188 let params = TilingParams::F32_AVX2;
189 tropical_gemm_inner::<Self, _>(
190 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
191 );
192 }
193 _ => {
194 let kernel = PortableKernel;
195 let params = TilingParams::PORTABLE;
196 tropical_gemm_inner::<Self, _>(
197 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
198 );
199 }
200 }
201 }
202}
203
204macro_rules! impl_kernel_dispatch_portable {
206 ($($t:ty),*) => {
207 $(
208 impl KernelDispatch for $t {
209 unsafe fn dispatch_gemm(
210 m: usize,
211 n: usize,
212 k: usize,
213 a: *const Self::Scalar,
214 lda: usize,
215 trans_a: Transpose,
216 b: *const Self::Scalar,
217 ldb: usize,
218 trans_b: Transpose,
219 c: *mut Self,
220 ldc: usize,
221 ) {
222 let kernel = PortableKernel;
223 let params = TilingParams::PORTABLE;
224 tropical_gemm_inner::<Self, _>(
225 m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc, ¶ms, &kernel,
226 );
227 }
228 }
229 )*
230 };
231}
232
233impl_kernel_dispatch_portable!(
234 TropicalMinPlus<f64>,
235 TropicalMaxMul<f64>,
236 TropicalMaxPlus<i32>,
237 TropicalMaxPlus<i64>,
238 TropicalMinPlus<i32>,
239 TropicalMinPlus<i64>,
240 TropicalMaxMul<i32>,
241 TropicalMaxMul<i64>
242);
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `tropical_gemm_dispatch` for `$semiring` on row-major, non-transposed operands
    /// (`A` is `m × k`, `B` is `k × n`) and returns the scalar values of the `m × n` result.
    /// The result buffer starts out filled with `tropical_zero()`.
    ///
    /// This is a macro rather than a generic fn so that it can read the concrete `.0` field of
    /// each semiring type.
    macro_rules! run_dispatch {
        ($semiring:ty, $m:expr, $n:expr, $k:expr, $a:expr, $b:expr) => {{
            let (m, n, k): (usize, usize, usize) = ($m, $n, $k);
            let a = $a;
            let b = $b;
            // Guard the raw-pointer call below: the buffers must cover the whole operands.
            assert_eq!(a.len(), m * k, "A must hold m * k elements");
            assert_eq!(b.len(), k * n, "B must hold k * n elements");
            let mut c = vec![<$semiring>::tropical_zero(); m * n];
            // SAFETY: `a`, `b` and `c` are live, contiguous buffers of exactly m*k, k*n and m*n
            // elements, laid out row-major with leading dimensions k, n and n, and `c` is a
            // fresh allocation that cannot overlap the inputs.
            unsafe {
                tropical_gemm_dispatch::<$semiring>(
                    m,
                    n,
                    k,
                    a.as_ptr(),
                    k,
                    Transpose::NoTrans,
                    b.as_ptr(),
                    n,
                    Transpose::NoTrans,
                    c.as_mut_ptr(),
                    n,
                );
            }
            c.iter().map(|x| x.0).collect::<Vec<_>>()
        }};
    }

    /// Naive row-major max-plus product, used as an oracle for the optimised kernels.
    fn reference_maxplus_f32(m: usize, n: usize, k: usize, a: &[f32], b: &[f32]) -> Vec<f32> {
        let mut c = Vec::with_capacity(m * n);
        for i in 0..m {
            for j in 0..n {
                let best = (0..k)
                    .map(|p| a[i * k + p] + b[p * n + j])
                    .fold(f32::NEG_INFINITY, f32::max);
                c.push(best);
            }
        }
        c
    }

    // The 2×2 cases below use A = B = [[1, 2], [3, 4]] for the additive semirings and
    // A = [[2, 3], [4, 5]], B = [[1, 2], [3, 4]] for max-mul. Every entry of C is checked.

    #[test]
    fn test_dispatch_maxplus_f32() {
        let c = run_dispatch!(
            TropicalMaxPlus<f32>,
            2,
            2,
            2,
            [1.0f32, 2.0, 3.0, 4.0],
            [1.0f32, 2.0, 3.0, 4.0]
        );
        assert_eq!(c, [5.0f32, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_dispatch_maxplus_f64() {
        let c = run_dispatch!(
            TropicalMaxPlus<f64>,
            2,
            2,
            2,
            [1.0f64, 2.0, 3.0, 4.0],
            [1.0f64, 2.0, 3.0, 4.0]
        );
        assert_eq!(c, [5.0f64, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_dispatch_minplus_f32() {
        let c = run_dispatch!(
            TropicalMinPlus<f32>,
            2,
            2,
            2,
            [1.0f32, 2.0, 3.0, 4.0],
            [1.0f32, 2.0, 3.0, 4.0]
        );
        assert_eq!(c, [2.0f32, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_dispatch_minplus_f64() {
        let c = run_dispatch!(
            TropicalMinPlus<f64>,
            2,
            2,
            2,
            [1.0f64, 2.0, 3.0, 4.0],
            [1.0f64, 2.0, 3.0, 4.0]
        );
        assert_eq!(c, [2.0f64, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_dispatch_maxmul_f32() {
        let c = run_dispatch!(
            TropicalMaxMul<f32>,
            2,
            2,
            2,
            [2.0f32, 3.0, 4.0, 5.0],
            [1.0f32, 2.0, 3.0, 4.0]
        );
        assert_eq!(c, [9.0f32, 12.0, 15.0, 20.0]);
    }

    #[test]
    fn test_dispatch_maxmul_f64() {
        let c = run_dispatch!(
            TropicalMaxMul<f64>,
            2,
            2,
            2,
            [2.0f64, 3.0, 4.0, 5.0],
            [1.0f64, 2.0, 3.0, 4.0]
        );
        assert_eq!(c, [9.0f64, 12.0, 15.0, 20.0]);
    }

    #[test]
    fn test_dispatch_maxplus_i32() {
        let c = run_dispatch!(TropicalMaxPlus<i32>, 2, 2, 2, [1i32, 2, 3, 4], [1i32, 2, 3, 4]);
        assert_eq!(c, [5i32, 6, 7, 8]);
    }

    #[test]
    fn test_dispatch_maxplus_i64() {
        let c = run_dispatch!(TropicalMaxPlus<i64>, 2, 2, 2, [1i64, 2, 3, 4], [1i64, 2, 3, 4]);
        assert_eq!(c, [5i64, 6, 7, 8]);
    }

    #[test]
    fn test_dispatch_minplus_i32() {
        let c = run_dispatch!(TropicalMinPlus<i32>, 2, 2, 2, [1i32, 2, 3, 4], [1i32, 2, 3, 4]);
        assert_eq!(c, [2i32, 3, 4, 5]);
    }

    #[test]
    fn test_dispatch_minplus_i64() {
        let c = run_dispatch!(TropicalMinPlus<i64>, 2, 2, 2, [1i64, 2, 3, 4], [1i64, 2, 3, 4]);
        assert_eq!(c, [2i64, 3, 4, 5]);
    }

    #[test]
    fn test_dispatch_maxmul_i32() {
        let c = run_dispatch!(TropicalMaxMul<i32>, 2, 2, 2, [2i32, 3, 4, 5], [1i32, 2, 3, 4]);
        assert_eq!(c, [9i32, 12, 15, 20]);
    }

    #[test]
    fn test_dispatch_maxmul_i64() {
        let c = run_dispatch!(TropicalMaxMul<i64>, 2, 2, 2, [2i64, 3, 4, 5], [1i64, 2, 3, 4]);
        assert_eq!(c, [9i64, 12, 15, 20]);
    }

    #[test]
    fn test_dispatch_larger_matrix() {
        let (m, n, k): (usize, usize, usize) = (16, 16, 16);
        let a: Vec<f32> = (0..m * k).map(|i| (i % 10) as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i % 10) as f32).collect();

        let c = run_dispatch!(TropicalMaxPlus<f32>, m, n, k, &a, &b);

        // Small integers are exact in f32, so the optimised and naive results must match exactly.
        assert_eq!(c, reference_maxplus_f32(m, n, k, &a, &b));
    }

    #[test]
    fn test_dispatch_non_square_matrix() {
        // Prime dimensions force partial tiles whatever register-tile size the kernel uses.
        let (m, n, k): (usize, usize, usize) = (7, 13, 11);
        let a: Vec<f32> = (0..m * k).map(|i| ((i * 7) % 10) as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i * 3) % 10) as f32).collect();

        let c = run_dispatch!(TropicalMaxPlus<f32>, m, n, k, &a, &b);

        assert_eq!(c, reference_maxplus_f32(m, n, k, &a, &b));
    }
}