1use crate::core::Microkernel;
2use crate::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus};
3use wide::{f32x8, f64x4};
4
5#[derive(Default, Clone, Copy)]
10pub struct Avx2MaxPlusF32Kernel;
11
12impl Microkernel<TropicalMaxPlus<f32>> for Avx2MaxPlusF32Kernel {
13 const MR: usize = 8;
14 const NR: usize = 8;
15
16 #[target_feature(enable = "avx2")]
17 unsafe fn execute(
18 &self,
19 mr: usize,
20 nr: usize,
21 k: usize,
22 a: *const f32,
23 b: *const f32,
24 c: *mut TropicalMaxPlus<f32>,
25 ldc: usize,
26 ) {
27 let neg_inf = f32x8::splat(f32::NEG_INFINITY);
29 let mut acc = [neg_inf; 8];
30
31 for i in 0..mr {
33 let mut row_acc = [f32::NEG_INFINITY; 8];
34 for j in 0..nr {
35 row_acc[j] = (*c.add(i * ldc + j)).0;
36 }
37 acc[i] = f32x8::from(row_acc);
38 }
39
40 for p in 0..k {
42 let mut a_vals = [0.0f32; 8];
44 for i in 0..mr {
45 a_vals[i] = *a.add(p * Self::MR + i);
46 }
47
48 let mut b_vals = [0.0f32; 8];
50 for j in 0..nr {
51 b_vals[j] = *b.add(p * Self::NR + j);
52 }
53 let b_vec = f32x8::from(b_vals);
54
55 for i in 0..mr {
57 let a_broadcast = f32x8::splat(a_vals[i]);
59 let product = a_broadcast + b_vec;
60
61 acc[i] = acc[i].max(product);
63 }
64 }
65
66 for i in 0..mr {
68 let row: [f32; 8] = acc[i].into();
69 for j in 0..nr {
70 *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
71 }
72 }
73 }
74}
75
76#[derive(Default, Clone, Copy)]
78pub struct Avx2MaxPlusF64Kernel;
79
80impl Microkernel<TropicalMaxPlus<f64>> for Avx2MaxPlusF64Kernel {
81 const MR: usize = 4;
82 const NR: usize = 4;
83
84 #[target_feature(enable = "avx2")]
85 unsafe fn execute(
86 &self,
87 mr: usize,
88 nr: usize,
89 k: usize,
90 a: *const f64,
91 b: *const f64,
92 c: *mut TropicalMaxPlus<f64>,
93 ldc: usize,
94 ) {
95 let neg_inf = f64x4::splat(f64::NEG_INFINITY);
96 let mut acc = [neg_inf; 4];
97
98 for i in 0..mr {
100 let mut row_acc = [f64::NEG_INFINITY; 4];
101 for j in 0..nr {
102 row_acc[j] = (*c.add(i * ldc + j)).0;
103 }
104 acc[i] = f64x4::from(row_acc);
105 }
106
107 for p in 0..k {
109 let mut a_vals = [0.0f64; 4];
110 for i in 0..mr {
111 a_vals[i] = *a.add(p * Self::MR + i);
112 }
113
114 let mut b_vals = [0.0f64; 4];
115 for j in 0..nr {
116 b_vals[j] = *b.add(p * Self::NR + j);
117 }
118 let b_vec = f64x4::from(b_vals);
119
120 for i in 0..mr {
121 let a_broadcast = f64x4::splat(a_vals[i]);
122 let product = a_broadcast + b_vec;
123 acc[i] = acc[i].max(product);
124 }
125 }
126
127 for i in 0..mr {
129 let row: [f64; 4] = acc[i].into();
130 for j in 0..nr {
131 *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
132 }
133 }
134 }
135}
136
137#[derive(Default, Clone, Copy)]
139pub struct Avx2MinPlusF32Kernel;
140
141impl Microkernel<TropicalMinPlus<f32>> for Avx2MinPlusF32Kernel {
142 const MR: usize = 8;
143 const NR: usize = 8;
144
145 #[target_feature(enable = "avx2")]
146 unsafe fn execute(
147 &self,
148 mr: usize,
149 nr: usize,
150 k: usize,
151 a: *const f32,
152 b: *const f32,
153 c: *mut TropicalMinPlus<f32>,
154 ldc: usize,
155 ) {
156 let pos_inf = f32x8::splat(f32::INFINITY);
157 let mut acc = [pos_inf; 8];
158
159 for i in 0..mr {
161 let mut row_acc = [f32::INFINITY; 8];
162 for j in 0..nr {
163 row_acc[j] = (*c.add(i * ldc + j)).0;
164 }
165 acc[i] = f32x8::from(row_acc);
166 }
167
168 for p in 0..k {
170 let mut a_vals = [0.0f32; 8];
171 for i in 0..mr {
172 a_vals[i] = *a.add(p * Self::MR + i);
173 }
174
175 let mut b_vals = [0.0f32; 8];
176 for j in 0..nr {
177 b_vals[j] = *b.add(p * Self::NR + j);
178 }
179 let b_vec = f32x8::from(b_vals);
180
181 for i in 0..mr {
182 let a_broadcast = f32x8::splat(a_vals[i]);
183 let product = a_broadcast + b_vec;
184 acc[i] = acc[i].min(product);
186 }
187 }
188
189 for i in 0..mr {
191 let row: [f32; 8] = acc[i].into();
192 for j in 0..nr {
193 *c.add(i * ldc + j) = TropicalMinPlus(row[j]);
194 }
195 }
196 }
197}
198
199#[derive(Default, Clone, Copy)]
201pub struct Avx2MaxMulF32Kernel;
202
203impl Microkernel<TropicalMaxMul<f32>> for Avx2MaxMulF32Kernel {
204 const MR: usize = 8;
205 const NR: usize = 8;
206
207 #[target_feature(enable = "avx2")]
208 unsafe fn execute(
209 &self,
210 mr: usize,
211 nr: usize,
212 k: usize,
213 a: *const f32,
214 b: *const f32,
215 c: *mut TropicalMaxMul<f32>,
216 ldc: usize,
217 ) {
218 let zero = f32x8::splat(0.0);
219 let mut acc = [zero; 8];
220
221 for i in 0..mr {
223 let mut row_acc = [0.0f32; 8];
224 for j in 0..nr {
225 row_acc[j] = (*c.add(i * ldc + j)).0;
226 }
227 acc[i] = f32x8::from(row_acc);
228 }
229
230 for p in 0..k {
232 let mut a_vals = [0.0f32; 8];
233 for i in 0..mr {
234 a_vals[i] = *a.add(p * Self::MR + i);
235 }
236
237 let mut b_vals = [0.0f32; 8];
238 for j in 0..nr {
239 b_vals[j] = *b.add(p * Self::NR + j);
240 }
241 let b_vec = f32x8::from(b_vals);
242
243 for i in 0..mr {
244 let a_broadcast = f32x8::splat(a_vals[i]);
245 let product = a_broadcast * b_vec;
247 acc[i] = acc[i].max(product);
249 }
250 }
251
252 for i in 0..mr {
254 let row: [f32; 8] = acc[i].into();
255 for j in 0..nr {
256 *c.add(i * ldc + j) = TropicalMaxMul(row[j]);
257 }
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalSemiring;

    /// Returns `true` (after logging a skip notice) when the host lacks AVX2.
    #[cfg(target_arch = "x86_64")]
    fn avx2_missing() -> bool {
        if is_x86_feature_detected!("avx2") {
            return false;
        }
        println!("AVX2 not available, skipping test");
        true
    }

    /// Packs row-major `A` (`rows[i][p]`) into the panel layout the kernels
    /// read: element `(i, p)` lands at `p * stride + i`; unused slots are zero.
    #[cfg(target_arch = "x86_64")]
    fn pack_a<T: Copy + Default>(rows: &[&[T]], stride: usize) -> Vec<T> {
        let k = rows.first().map_or(0, |r| r.len());
        let mut panel = vec![T::default(); k * stride];
        for (i, row) in rows.iter().enumerate() {
            for (p, &value) in row.iter().enumerate() {
                panel[p * stride + i] = value;
            }
        }
        panel
    }

    /// Packs row-major `B` (`rows[p][j]`) so element `(p, j)` lands at
    /// `p * stride + j`; unused slots are zero.
    #[cfg(target_arch = "x86_64")]
    fn pack_b<T: Copy + Default>(rows: &[&[T]], stride: usize) -> Vec<T> {
        let mut panel = vec![T::default(); rows.len() * stride];
        for (p, row) in rows.iter().enumerate() {
            panel[p * stride..p * stride + row.len()].copy_from_slice(row);
        }
        panel
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f32() {
        if avx2_missing() {
            return;
        }

        // A = [[1, 2, 3], [4, 5, 6]], B = [[1, 2], [3, 4], [5, 6]].
        let (mr, nr, k, ldc) = (2, 2, 3, 2);
        let a = pack_a::<f32>(
            &[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]],
            Avx2MaxPlusF32Kernel::MR,
        );
        let b = pack_b::<f32>(
            &[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]],
            Avx2MaxPlusF32Kernel::NR,
        );
        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];

        unsafe {
            Avx2MaxPlusF32Kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        for (got, want) in c.iter().map(|x| x.0).zip([8.0, 9.0, 11.0, 12.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_min_plus_f32() {
        if avx2_missing() {
            return;
        }

        // Same operands as the max-plus case, reduced with min instead of max.
        let (mr, nr, k, ldc) = (2, 2, 3, 2);
        let a = pack_a::<f32>(
            &[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]],
            Avx2MinPlusF32Kernel::MR,
        );
        let b = pack_b::<f32>(
            &[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]],
            Avx2MinPlusF32Kernel::NR,
        );
        let mut c = vec![TropicalMinPlus::tropical_zero(); 4];

        unsafe {
            Avx2MinPlusF32Kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        for (got, want) in c.iter().map(|x| x.0).zip([2.0, 3.0, 5.0, 6.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_mul_f32() {
        if avx2_missing() {
            return;
        }

        // A = [[2, 4], [3, 5]], B = [[1, 2], [3, 4]].
        let (mr, nr, k, ldc) = (2, 2, 2, 2);
        let a = pack_a::<f32>(&[&[2.0, 4.0], &[3.0, 5.0]], Avx2MaxMulF32Kernel::MR);
        let b = pack_b::<f32>(&[&[1.0, 2.0], &[3.0, 4.0]], Avx2MaxMulF32Kernel::NR);
        let mut c = vec![TropicalMaxMul::tropical_zero(); 4];

        unsafe {
            Avx2MaxMulF32Kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        for (got, want) in c.iter().map(|x| x.0).zip([12.0, 16.0, 15.0, 20.0]) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f64() {
        if avx2_missing() {
            return;
        }

        // A = [[1, 3], [2, 4]], B = [[1, 2], [3, 4]].
        let (mr, nr, k, ldc) = (2, 2, 2, 2);
        let a = pack_a::<f64>(&[&[1.0, 3.0], &[2.0, 4.0]], Avx2MaxPlusF64Kernel::MR);
        let b = pack_b::<f64>(&[&[1.0, 2.0], &[3.0, 4.0]], Avx2MaxPlusF64Kernel::NR);
        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];

        unsafe {
            Avx2MaxPlusF64Kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        for (got, want) in c.iter().map(|x| x.0).zip([6.0, 7.0, 7.0, 8.0]) {
            assert!((got - want).abs() < 1e-10);
        }
    }
}