tropical_gemm/simd/kernels/
avx2.rs

1use crate::core::Microkernel;
2use crate::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus};
3use wide::{f32x8, f64x4};
4
5/// AVX2 microkernel for TropicalMaxPlus<f32>.
6///
7/// Uses 8x8 register blocking with f32x8 vectors.
8/// Total: 8 accumulators × 8 lanes = 64 elements in registers.
9#[derive(Default, Clone, Copy)]
10pub struct Avx2MaxPlusF32Kernel;
11
12impl Microkernel<TropicalMaxPlus<f32>> for Avx2MaxPlusF32Kernel {
13    const MR: usize = 8;
14    const NR: usize = 8;
15
16    #[target_feature(enable = "avx2")]
17    unsafe fn execute(
18        &self,
19        mr: usize,
20        nr: usize,
21        k: usize,
22        a: *const f32,
23        b: *const f32,
24        c: *mut TropicalMaxPlus<f32>,
25        ldc: usize,
26    ) {
27        // Initialize accumulators with -inf
28        let neg_inf = f32x8::splat(f32::NEG_INFINITY);
29        let mut acc = [neg_inf; 8];
30
31        // Load existing C values into accumulators
32        for i in 0..mr {
33            let mut row_acc = [f32::NEG_INFINITY; 8];
34            for j in 0..nr {
35                row_acc[j] = (*c.add(i * ldc + j)).0;
36            }
37            acc[i] = f32x8::from(row_acc);
38        }
39
40        // Main computation loop
41        for p in 0..k {
42            // Load A column (mr elements, padded to 8)
43            let mut a_vals = [0.0f32; 8];
44            for i in 0..mr {
45                a_vals[i] = *a.add(p * Self::MR + i);
46            }
47
48            // Load B row (nr elements, padded to 8)
49            let mut b_vals = [0.0f32; 8];
50            for j in 0..nr {
51                b_vals[j] = *b.add(p * Self::NR + j);
52            }
53            let b_vec = f32x8::from(b_vals);
54
55            // For each row of A
56            for i in 0..mr {
57                // Tropical mul: a[i] + b[j] for all j
58                let a_broadcast = f32x8::splat(a_vals[i]);
59                let product = a_broadcast + b_vec;
60
61                // Tropical add: max(acc, product)
62                acc[i] = acc[i].max(product);
63            }
64        }
65
66        // Write back results
67        for i in 0..mr {
68            let row: [f32; 8] = acc[i].into();
69            for j in 0..nr {
70                *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
71            }
72        }
73    }
74}
75
76/// AVX2 microkernel for TropicalMaxPlus<f64>.
77#[derive(Default, Clone, Copy)]
78pub struct Avx2MaxPlusF64Kernel;
79
80impl Microkernel<TropicalMaxPlus<f64>> for Avx2MaxPlusF64Kernel {
81    const MR: usize = 4;
82    const NR: usize = 4;
83
84    #[target_feature(enable = "avx2")]
85    unsafe fn execute(
86        &self,
87        mr: usize,
88        nr: usize,
89        k: usize,
90        a: *const f64,
91        b: *const f64,
92        c: *mut TropicalMaxPlus<f64>,
93        ldc: usize,
94    ) {
95        let neg_inf = f64x4::splat(f64::NEG_INFINITY);
96        let mut acc = [neg_inf; 4];
97
98        // Load existing C
99        for i in 0..mr {
100            let mut row_acc = [f64::NEG_INFINITY; 4];
101            for j in 0..nr {
102                row_acc[j] = (*c.add(i * ldc + j)).0;
103            }
104            acc[i] = f64x4::from(row_acc);
105        }
106
107        // Main loop
108        for p in 0..k {
109            let mut a_vals = [0.0f64; 4];
110            for i in 0..mr {
111                a_vals[i] = *a.add(p * Self::MR + i);
112            }
113
114            let mut b_vals = [0.0f64; 4];
115            for j in 0..nr {
116                b_vals[j] = *b.add(p * Self::NR + j);
117            }
118            let b_vec = f64x4::from(b_vals);
119
120            for i in 0..mr {
121                let a_broadcast = f64x4::splat(a_vals[i]);
122                let product = a_broadcast + b_vec;
123                acc[i] = acc[i].max(product);
124            }
125        }
126
127        // Write back
128        for i in 0..mr {
129            let row: [f64; 4] = acc[i].into();
130            for j in 0..nr {
131                *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
132            }
133        }
134    }
135}
136
137/// AVX2 microkernel for TropicalMinPlus<f32>.
138#[derive(Default, Clone, Copy)]
139pub struct Avx2MinPlusF32Kernel;
140
141impl Microkernel<TropicalMinPlus<f32>> for Avx2MinPlusF32Kernel {
142    const MR: usize = 8;
143    const NR: usize = 8;
144
145    #[target_feature(enable = "avx2")]
146    unsafe fn execute(
147        &self,
148        mr: usize,
149        nr: usize,
150        k: usize,
151        a: *const f32,
152        b: *const f32,
153        c: *mut TropicalMinPlus<f32>,
154        ldc: usize,
155    ) {
156        let pos_inf = f32x8::splat(f32::INFINITY);
157        let mut acc = [pos_inf; 8];
158
159        // Load existing C
160        for i in 0..mr {
161            let mut row_acc = [f32::INFINITY; 8];
162            for j in 0..nr {
163                row_acc[j] = (*c.add(i * ldc + j)).0;
164            }
165            acc[i] = f32x8::from(row_acc);
166        }
167
168        // Main loop
169        for p in 0..k {
170            let mut a_vals = [0.0f32; 8];
171            for i in 0..mr {
172                a_vals[i] = *a.add(p * Self::MR + i);
173            }
174
175            let mut b_vals = [0.0f32; 8];
176            for j in 0..nr {
177                b_vals[j] = *b.add(p * Self::NR + j);
178            }
179            let b_vec = f32x8::from(b_vals);
180
181            for i in 0..mr {
182                let a_broadcast = f32x8::splat(a_vals[i]);
183                let product = a_broadcast + b_vec;
184                // MinPlus: tropical add = min
185                acc[i] = acc[i].min(product);
186            }
187        }
188
189        // Write back
190        for i in 0..mr {
191            let row: [f32; 8] = acc[i].into();
192            for j in 0..nr {
193                *c.add(i * ldc + j) = TropicalMinPlus(row[j]);
194            }
195        }
196    }
197}
198
199/// AVX2 microkernel for TropicalMaxMul<f32>.
200#[derive(Default, Clone, Copy)]
201pub struct Avx2MaxMulF32Kernel;
202
203impl Microkernel<TropicalMaxMul<f32>> for Avx2MaxMulF32Kernel {
204    const MR: usize = 8;
205    const NR: usize = 8;
206
207    #[target_feature(enable = "avx2")]
208    unsafe fn execute(
209        &self,
210        mr: usize,
211        nr: usize,
212        k: usize,
213        a: *const f32,
214        b: *const f32,
215        c: *mut TropicalMaxMul<f32>,
216        ldc: usize,
217    ) {
218        let zero = f32x8::splat(0.0);
219        let mut acc = [zero; 8];
220
221        // Load existing C
222        for i in 0..mr {
223            let mut row_acc = [0.0f32; 8];
224            for j in 0..nr {
225                row_acc[j] = (*c.add(i * ldc + j)).0;
226            }
227            acc[i] = f32x8::from(row_acc);
228        }
229
230        // Main loop
231        for p in 0..k {
232            let mut a_vals = [0.0f32; 8];
233            for i in 0..mr {
234                a_vals[i] = *a.add(p * Self::MR + i);
235            }
236
237            let mut b_vals = [0.0f32; 8];
238            for j in 0..nr {
239                b_vals[j] = *b.add(p * Self::NR + j);
240            }
241            let b_vec = f32x8::from(b_vals);
242
243            for i in 0..mr {
244                let a_broadcast = f32x8::splat(a_vals[i]);
245                // MaxMul: tropical mul = standard mul
246                let product = a_broadcast * b_vec;
247                // tropical add = max
248                acc[i] = acc[i].max(product);
249            }
250        }
251
252        // Write back
253        for i in 0..mr {
254            let row: [f32; 8] = acc[i].into();
255            for j in 0..nr {
256                *c.add(i * ldc + j) = TropicalMaxMul(row[j]);
257            }
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalSemiring;

    /// Max-plus f32 kernel on a 2x2 tile with k = 3.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f32() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let (mr, nr, k) = (2, 2, 3);

        // A (2x3): each packed column is padded to 8 entries.
        let packed_a: [f32; 24] = [
            1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 0
            2.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 1
            3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 2
        ];

        // B (3x2): each packed row is padded to 8 entries.
        let packed_b: [f32; 24] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 1
            5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 2
        ];

        let ldc = 2;
        let mut tile = vec![TropicalMaxPlus::tropical_zero(); 4];

        unsafe {
            Avx2MaxPlusF32Kernel.execute(
                mr,
                nr,
                k,
                packed_a.as_ptr(),
                packed_b.as_ptr(),
                tile.as_mut_ptr(),
                ldc,
            );
        }

        // Row-major expectations:
        //   C[0,0] = max(1+1, 2+3, 3+5) = 8
        //   C[0,1] = max(1+2, 2+4, 3+6) = 9
        //   C[1,0] = max(4+1, 5+3, 6+5) = 11
        //   C[1,1] = max(4+2, 5+4, 6+6) = 12
        let expected = [8.0f32, 9.0, 11.0, 12.0];
        for (cell, want) in tile.iter().zip(expected) {
            assert!((cell.0 - want).abs() < 1e-6);
        }
    }

    /// Min-plus f32 kernel on a 2x2 tile with k = 3.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_min_plus_f32() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let (mr, nr, k) = (2, 2, 3);

        // A (2x3): each packed column is padded to 8 entries.
        let packed_a: [f32; 24] = [
            1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 0
            2.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 1
            3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 2
        ];

        // B (3x2): each packed row is padded to 8 entries.
        let packed_b: [f32; 24] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 1
            5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 2
        ];

        let ldc = 2;
        let mut tile = vec![TropicalMinPlus::tropical_zero(); 4];

        unsafe {
            Avx2MinPlusF32Kernel.execute(
                mr,
                nr,
                k,
                packed_a.as_ptr(),
                packed_b.as_ptr(),
                tile.as_mut_ptr(),
                ldc,
            );
        }

        // Row-major expectations:
        //   C[0,0] = min(1+1, 2+3, 3+5) = 2
        //   C[0,1] = min(1+2, 2+4, 3+6) = 3
        //   C[1,0] = min(4+1, 5+3, 6+5) = 5
        //   C[1,1] = min(4+2, 5+4, 6+6) = 6
        let expected = [2.0f32, 3.0, 5.0, 6.0];
        for (cell, want) in tile.iter().zip(expected) {
            assert!((cell.0 - want).abs() < 1e-6);
        }
    }

    /// Max-mul f32 kernel on a 2x2 tile with k = 2.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_mul_f32() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let (mr, nr, k) = (2, 2, 2);

        // A (2x2): each packed column is padded to 8 entries.
        let packed_a: [f32; 16] = [
            2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 0
            4.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // col 1
        ];

        // B (2x2): each packed row is padded to 8 entries.
        let packed_b: [f32; 16] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // row 1
        ];

        let ldc = 2;
        let mut tile = vec![TropicalMaxMul::tropical_zero(); 4];

        unsafe {
            Avx2MaxMulF32Kernel.execute(
                mr,
                nr,
                k,
                packed_a.as_ptr(),
                packed_b.as_ptr(),
                tile.as_mut_ptr(),
                ldc,
            );
        }

        // Row-major expectations:
        //   C[0,0] = max(2*1, 4*3) = max(2, 12) = 12
        //   C[0,1] = max(2*2, 4*4) = max(4, 16) = 16
        //   C[1,0] = max(3*1, 5*3) = max(3, 15) = 15
        //   C[1,1] = max(3*2, 5*4) = max(6, 20) = 20
        let expected = [12.0f32, 16.0, 15.0, 20.0];
        for (cell, want) in tile.iter().zip(expected) {
            assert!((cell.0 - want).abs() < 1e-6);
        }
    }

    /// Max-plus f64 kernel on a 2x2 tile with k = 2.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f64() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let (mr, nr, k) = (2, 2, 2);

        // A (2x2): each packed column is padded to 4 entries.
        let packed_a: [f64; 8] = [
            1.0, 2.0, 0.0, 0.0, // col 0
            3.0, 4.0, 0.0, 0.0, // col 1
        ];

        // B (2x2): each packed row is padded to 4 entries.
        let packed_b: [f64; 8] = [
            1.0, 2.0, 0.0, 0.0, // row 0
            3.0, 4.0, 0.0, 0.0, // row 1
        ];

        let ldc = 2;
        let mut tile = vec![TropicalMaxPlus::tropical_zero(); 4];

        unsafe {
            Avx2MaxPlusF64Kernel.execute(
                mr,
                nr,
                k,
                packed_a.as_ptr(),
                packed_b.as_ptr(),
                tile.as_mut_ptr(),
                ldc,
            );
        }

        // Row-major expectations:
        //   C[0,0] = max(1+1, 3+3) = 6
        //   C[0,1] = max(1+2, 3+4) = 7
        //   C[1,0] = max(2+1, 4+3) = 7
        //   C[1,1] = max(2+2, 4+4) = 8
        let expected = [6.0f64, 7.0, 7.0, 8.0];
        for (cell, want) in tile.iter().zip(expected) {
            assert!((cell.0 - want).abs() < 1e-10);
        }
    }
}