Skip to main content

tropical_gemm/simd/kernels/
avx2.rs

1use crate::core::Microkernel;
2use crate::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus};
3use wide::{f32x8, f64x4};
4
5/// AVX2 microkernel for TropicalMaxPlus<f32>.
6///
7/// Uses 8x8 register blocking with f32x8 vectors.
8/// Total: 8 accumulators × 8 lanes = 64 elements in registers.
9#[derive(Default, Clone, Copy)]
10pub struct Avx2MaxPlusF32Kernel;
11
12impl Microkernel<TropicalMaxPlus<f32>> for Avx2MaxPlusF32Kernel {
13    const MR: usize = 8;
14    const NR: usize = 8;
15
16    #[target_feature(enable = "avx2")]
17    unsafe fn execute(
18        &self,
19        mr: usize,
20        nr: usize,
21        k: usize,
22        a: *const f32,
23        b: *const f32,
24        c: *mut TropicalMaxPlus<f32>,
25        ldc: usize,
26    ) {
27        // Initialize accumulators with -inf
28        let neg_inf = f32x8::splat(f32::NEG_INFINITY);
29        let mut acc = [neg_inf; 8];
30
31        // Load existing C values into accumulators
32        for i in 0..mr {
33            let mut row_acc = [f32::NEG_INFINITY; 8];
34            for j in 0..nr {
35                row_acc[j] = (*c.add(i * ldc + j)).0;
36            }
37            acc[i] = f32x8::from(row_acc);
38        }
39
40        // Main computation loop
41        for p in 0..k {
42            // Load A column (mr elements, padded to 8)
43            let mut a_vals = [0.0f32; 8];
44            for i in 0..mr {
45                a_vals[i] = *a.add(p * Self::MR + i);
46            }
47
48            // Load B row (nr elements, padded to 8)
49            let mut b_vals = [0.0f32; 8];
50            for j in 0..nr {
51                b_vals[j] = *b.add(p * Self::NR + j);
52            }
53            let b_vec = f32x8::from(b_vals);
54
55            // For each row of A
56            for i in 0..mr {
57                // Tropical mul: a[i] + b[j] for all j
58                let a_broadcast = f32x8::splat(a_vals[i]);
59                let product = a_broadcast + b_vec;
60
61                // Tropical add: max(acc, product).
62                // fast_max == hardware maxps (1 instr). Plain max() adds an
63                // is_nan compare + blend (3 instr) we don't need: tropical
64                // MaxPlus inputs are finite or -inf, so +inf+-inf (the only
65                // NaN source) never occurs.
66                acc[i] = acc[i].fast_max(product);
67            }
68        }
69
70        // Write back results
71        for i in 0..mr {
72            let row: [f32; 8] = acc[i].into();
73            for j in 0..nr {
74                *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
75            }
76        }
77    }
78}
79
80/// AVX2 microkernel for TropicalMaxPlus<f64>.
81#[derive(Default, Clone, Copy)]
82pub struct Avx2MaxPlusF64Kernel;
83
84impl Microkernel<TropicalMaxPlus<f64>> for Avx2MaxPlusF64Kernel {
85    const MR: usize = 4;
86    const NR: usize = 4;
87
88    #[target_feature(enable = "avx2")]
89    unsafe fn execute(
90        &self,
91        mr: usize,
92        nr: usize,
93        k: usize,
94        a: *const f64,
95        b: *const f64,
96        c: *mut TropicalMaxPlus<f64>,
97        ldc: usize,
98    ) {
99        let neg_inf = f64x4::splat(f64::NEG_INFINITY);
100        let mut acc = [neg_inf; 4];
101
102        // Load existing C
103        for i in 0..mr {
104            let mut row_acc = [f64::NEG_INFINITY; 4];
105            for j in 0..nr {
106                row_acc[j] = (*c.add(i * ldc + j)).0;
107            }
108            acc[i] = f64x4::from(row_acc);
109        }
110
111        // Main loop
112        for p in 0..k {
113            let mut a_vals = [0.0f64; 4];
114            for i in 0..mr {
115                a_vals[i] = *a.add(p * Self::MR + i);
116            }
117
118            let mut b_vals = [0.0f64; 4];
119            for j in 0..nr {
120                b_vals[j] = *b.add(p * Self::NR + j);
121            }
122            let b_vec = f64x4::from(b_vals);
123
124            for i in 0..mr {
125                let a_broadcast = f64x4::splat(a_vals[i]);
126                let product = a_broadcast + b_vec;
127                // fast_max: no-NaN hardware max (see f32 kernel above).
128                acc[i] = acc[i].fast_max(product);
129            }
130        }
131
132        // Write back
133        for i in 0..mr {
134            let row: [f64; 4] = acc[i].into();
135            for j in 0..nr {
136                *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
137            }
138        }
139    }
140}
141
142/// AVX2 microkernel for TropicalMinPlus<f32>.
143#[derive(Default, Clone, Copy)]
144pub struct Avx2MinPlusF32Kernel;
145
146impl Microkernel<TropicalMinPlus<f32>> for Avx2MinPlusF32Kernel {
147    const MR: usize = 8;
148    const NR: usize = 8;
149
150    #[target_feature(enable = "avx2")]
151    unsafe fn execute(
152        &self,
153        mr: usize,
154        nr: usize,
155        k: usize,
156        a: *const f32,
157        b: *const f32,
158        c: *mut TropicalMinPlus<f32>,
159        ldc: usize,
160    ) {
161        let pos_inf = f32x8::splat(f32::INFINITY);
162        let mut acc = [pos_inf; 8];
163
164        // Load existing C
165        for i in 0..mr {
166            let mut row_acc = [f32::INFINITY; 8];
167            for j in 0..nr {
168                row_acc[j] = (*c.add(i * ldc + j)).0;
169            }
170            acc[i] = f32x8::from(row_acc);
171        }
172
173        // Main loop
174        for p in 0..k {
175            let mut a_vals = [0.0f32; 8];
176            for i in 0..mr {
177                a_vals[i] = *a.add(p * Self::MR + i);
178            }
179
180            let mut b_vals = [0.0f32; 8];
181            for j in 0..nr {
182                b_vals[j] = *b.add(p * Self::NR + j);
183            }
184            let b_vec = f32x8::from(b_vals);
185
186            for i in 0..mr {
187                let a_broadcast = f32x8::splat(a_vals[i]);
188                let product = a_broadcast + b_vec;
189                // MinPlus: tropical add = min.
190                // fast_min == hardware minps (1 instr); MinPlus inputs are
191                // finite or +inf, so the NaN-blend in plain min() is dead work.
192                acc[i] = acc[i].fast_min(product);
193            }
194        }
195
196        // Write back
197        for i in 0..mr {
198            let row: [f32; 8] = acc[i].into();
199            for j in 0..nr {
200                *c.add(i * ldc + j) = TropicalMinPlus(row[j]);
201            }
202        }
203    }
204}
205
206/// AVX2 microkernel for TropicalMaxMul<f32>.
207#[derive(Default, Clone, Copy)]
208pub struct Avx2MaxMulF32Kernel;
209
210impl Microkernel<TropicalMaxMul<f32>> for Avx2MaxMulF32Kernel {
211    const MR: usize = 8;
212    const NR: usize = 8;
213
214    #[target_feature(enable = "avx2")]
215    unsafe fn execute(
216        &self,
217        mr: usize,
218        nr: usize,
219        k: usize,
220        a: *const f32,
221        b: *const f32,
222        c: *mut TropicalMaxMul<f32>,
223        ldc: usize,
224    ) {
225        let zero = f32x8::splat(0.0);
226        let mut acc = [zero; 8];
227
228        // Load existing C
229        for i in 0..mr {
230            let mut row_acc = [0.0f32; 8];
231            for j in 0..nr {
232                row_acc[j] = (*c.add(i * ldc + j)).0;
233            }
234            acc[i] = f32x8::from(row_acc);
235        }
236
237        // Main loop
238        for p in 0..k {
239            let mut a_vals = [0.0f32; 8];
240            for i in 0..mr {
241                a_vals[i] = *a.add(p * Self::MR + i);
242            }
243
244            let mut b_vals = [0.0f32; 8];
245            for j in 0..nr {
246                b_vals[j] = *b.add(p * Self::NR + j);
247            }
248            let b_vec = f32x8::from(b_vals);
249
250            for i in 0..mr {
251                let a_broadcast = f32x8::splat(a_vals[i]);
252                // MaxMul: tropical mul = standard mul
253                let product = a_broadcast * b_vec;
254                // tropical add = max. fast_max: MaxMul inputs are finite (zero
255                // = 0.0), products finite → no NaN, so skip the NaN blend.
256                acc[i] = acc[i].fast_max(product);
257            }
258        }
259
260        // Write back
261        for i in 0..mr {
262            let row: [f32; 8] = acc[i].into();
263            for j in 0..nr {
264                *c.add(i * ldc + j) = TropicalMaxMul(row[j]);
265            }
266        }
267    }
268}
269
// Gate the whole module on x86_64 once, rather than gating each test. That way
// the imports below are never left unused on other targets.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;
    use crate::types::TropicalSemiring;

    /// Returns `true` when the running CPU supports AVX2. Otherwise it prints
    /// a skip notice. Calling an `avx2` target-feature function on a CPU
    /// without AVX2 is undefined behavior, so tests must bail out early.
    fn avx2_available() -> bool {
        let available = is_x86_feature_detected!("avx2");
        if !available {
            println!("AVX2 not available, skipping test");
        }
        available
    }

    #[test]
    fn test_avx2_max_plus_f32() {
        if !avx2_available() {
            return;
        }

        let kernel = Avx2MaxPlusF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        // A: 2x3 packed
        let a: [f32; 24] = [
            1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 0
            2.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 1
            3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 2
        ];

        // B: 3x2 packed
        let b: [f32; 24] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 1
            5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 2
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert!((c[0].0 - 8.0).abs() < 1e-6);
        // C[0,1] = max(1+2, 2+4, 3+6) = 9
        assert!((c[1].0 - 9.0).abs() < 1e-6);
        // C[1,0] = max(4+1, 5+3, 6+5) = 11
        assert!((c[2].0 - 11.0).abs() < 1e-6);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert!((c[3].0 - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_avx2_max_plus_f32_accumulates_into_strided_c() {
        if !avx2_available() {
            return;
        }

        let kernel = Avx2MaxPlusF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 1;

        // Single packed A column [1, 2] and B row [3, 4], padded to 8 lanes.
        let a: [f32; 8] = [1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let b: [f32; 8] = [3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];

        // 2x3 row-major C (ldc = 3) with non-zero initial contents. Column 2
        // lies outside the 2x2 tile and must not be touched.
        let mut c = vec![
            TropicalMaxPlus(10.0f32),
            TropicalMaxPlus(0.0),
            TropicalMaxPlus(99.0), // padding
            TropicalMaxPlus(0.0),
            TropicalMaxPlus(20.0),
            TropicalMaxPlus(-5.0), // padding
        ];
        let ldc = 3;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // Existing C values win where they exceed A[i][0] + B[0][j].
        assert!((c[0].0 - 10.0).abs() < 1e-6); // max(10, 1+3)
        assert!((c[1].0 - 5.0).abs() < 1e-6); // max(0, 1+4)
        assert!((c[3].0 - 5.0).abs() < 1e-6); // max(0, 2+3)
        assert!((c[4].0 - 20.0).abs() < 1e-6); // max(20, 2+4)
        // Columns outside the tile are left as they were.
        assert!((c[2].0 - 99.0).abs() < 1e-6);
        assert!((c[5].0 - -5.0).abs() < 1e-6);
    }

    #[test]
    fn test_avx2_min_plus_f32() {
        if !avx2_available() {
            return;
        }

        let kernel = Avx2MinPlusF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        // A: 2x3 packed
        let a: [f32; 24] = [
            1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 0
            2.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 1
            3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 2
        ];

        // B: 3x2 packed
        let b: [f32; 24] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 1
            5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 2
        ];

        let mut c = vec![TropicalMinPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = min(1+1, 2+3, 3+5) = 2
        assert!((c[0].0 - 2.0).abs() < 1e-6);
        // C[0,1] = min(1+2, 2+4, 3+6) = 3
        assert!((c[1].0 - 3.0).abs() < 1e-6);
        // C[1,0] = min(4+1, 5+3, 6+5) = 5
        assert!((c[2].0 - 5.0).abs() < 1e-6);
        // C[1,1] = min(4+2, 5+4, 6+6) = 6
        assert!((c[3].0 - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_avx2_min_plus_f32_k_zero_keeps_c() {
        if !avx2_available() {
            return;
        }

        let kernel = Avx2MinPlusF32Kernel;

        // With k = 0 the main loop never runs, so C must round-trip unchanged.
        let a = [0.0f32; 8];
        let b = [0.0f32; 8];
        let mut c = vec![
            TropicalMinPlus(1.0f32),
            TropicalMinPlus(f32::INFINITY),
            TropicalMinPlus(-3.0),
            TropicalMinPlus(7.5),
        ];

        unsafe {
            kernel.execute(2, 2, 0, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2);
        }

        assert!((c[0].0 - 1.0).abs() < 1e-6);
        assert!(c[1].0.is_infinite() && c[1].0.is_sign_positive());
        assert!((c[2].0 - -3.0).abs() < 1e-6);
        assert!((c[3].0 - 7.5).abs() < 1e-6);
    }

    #[test]
    fn test_avx2_max_mul_f32() {
        if !avx2_available() {
            return;
        }

        let kernel = Avx2MaxMulF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        // A: 2x2 packed
        let a: [f32; 16] = [
            2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 0
            4.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 1
        ];

        // B: 2x2 packed
        let b: [f32; 16] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 1
        ];

        let mut c = vec![TropicalMaxMul::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = max(2*1, 4*3) = max(2, 12) = 12
        assert!((c[0].0 - 12.0).abs() < 1e-6);
        // C[0,1] = max(2*2, 4*4) = max(4, 16) = 16
        assert!((c[1].0 - 16.0).abs() < 1e-6);
        // C[1,0] = max(3*1, 5*3) = max(3, 15) = 15
        assert!((c[2].0 - 15.0).abs() < 1e-6);
        // C[1,1] = max(3*2, 5*4) = max(6, 20) = 20
        assert!((c[3].0 - 20.0).abs() < 1e-6);
    }

    #[test]
    fn test_avx2_max_plus_f64() {
        if !avx2_available() {
            return;
        }

        let kernel = Avx2MaxPlusF64Kernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        // A: 2x2 packed (4 f64 per column for MR = 4 padding)
        let a: [f64; 8] = [
            1.0, 2.0, 0.0, 0.0, // col 0
            3.0, 4.0, 0.0, 0.0, // col 1
        ];

        // B: 2x2 packed (4 f64 per row for NR = 4 padding)
        let b: [f64; 8] = [
            1.0, 2.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, // row 1
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = max(1+1, 3+3) = 6
        assert!((c[0].0 - 6.0).abs() < 1e-10);
        // C[0,1] = max(1+2, 3+4) = 7
        assert!((c[1].0 - 7.0).abs() < 1e-10);
        // C[1,0] = max(2+1, 4+3) = 7
        assert!((c[2].0 - 7.0).abs() < 1e-10);
        // C[1,1] = max(2+2, 4+4) = 8
        assert!((c[3].0 - 8.0).abs() < 1e-10);
    }
}