1use crate::core::Microkernel;
2use crate::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus};
3use wide::{f32x8, f64x4};
4
5#[derive(Default, Clone, Copy)]
10pub struct Avx2MaxPlusF32Kernel;
11
12impl Microkernel<TropicalMaxPlus<f32>> for Avx2MaxPlusF32Kernel {
13 const MR: usize = 8;
14 const NR: usize = 8;
15
16 #[target_feature(enable = "avx2")]
17 unsafe fn execute(
18 &self,
19 mr: usize,
20 nr: usize,
21 k: usize,
22 a: *const f32,
23 b: *const f32,
24 c: *mut TropicalMaxPlus<f32>,
25 ldc: usize,
26 ) {
27 let neg_inf = f32x8::splat(f32::NEG_INFINITY);
29 let mut acc = [neg_inf; 8];
30
31 for i in 0..mr {
33 let mut row_acc = [f32::NEG_INFINITY; 8];
34 for j in 0..nr {
35 row_acc[j] = (*c.add(i * ldc + j)).0;
36 }
37 acc[i] = f32x8::from(row_acc);
38 }
39
40 for p in 0..k {
42 let mut a_vals = [0.0f32; 8];
44 for i in 0..mr {
45 a_vals[i] = *a.add(p * Self::MR + i);
46 }
47
48 let mut b_vals = [0.0f32; 8];
50 for j in 0..nr {
51 b_vals[j] = *b.add(p * Self::NR + j);
52 }
53 let b_vec = f32x8::from(b_vals);
54
55 for i in 0..mr {
57 let a_broadcast = f32x8::splat(a_vals[i]);
59 let product = a_broadcast + b_vec;
60
61 acc[i] = acc[i].fast_max(product);
67 }
68 }
69
70 for i in 0..mr {
72 let row: [f32; 8] = acc[i].into();
73 for j in 0..nr {
74 *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
75 }
76 }
77 }
78}
79
80#[derive(Default, Clone, Copy)]
82pub struct Avx2MaxPlusF64Kernel;
83
84impl Microkernel<TropicalMaxPlus<f64>> for Avx2MaxPlusF64Kernel {
85 const MR: usize = 4;
86 const NR: usize = 4;
87
88 #[target_feature(enable = "avx2")]
89 unsafe fn execute(
90 &self,
91 mr: usize,
92 nr: usize,
93 k: usize,
94 a: *const f64,
95 b: *const f64,
96 c: *mut TropicalMaxPlus<f64>,
97 ldc: usize,
98 ) {
99 let neg_inf = f64x4::splat(f64::NEG_INFINITY);
100 let mut acc = [neg_inf; 4];
101
102 for i in 0..mr {
104 let mut row_acc = [f64::NEG_INFINITY; 4];
105 for j in 0..nr {
106 row_acc[j] = (*c.add(i * ldc + j)).0;
107 }
108 acc[i] = f64x4::from(row_acc);
109 }
110
111 for p in 0..k {
113 let mut a_vals = [0.0f64; 4];
114 for i in 0..mr {
115 a_vals[i] = *a.add(p * Self::MR + i);
116 }
117
118 let mut b_vals = [0.0f64; 4];
119 for j in 0..nr {
120 b_vals[j] = *b.add(p * Self::NR + j);
121 }
122 let b_vec = f64x4::from(b_vals);
123
124 for i in 0..mr {
125 let a_broadcast = f64x4::splat(a_vals[i]);
126 let product = a_broadcast + b_vec;
127 acc[i] = acc[i].fast_max(product);
129 }
130 }
131
132 for i in 0..mr {
134 let row: [f64; 4] = acc[i].into();
135 for j in 0..nr {
136 *c.add(i * ldc + j) = TropicalMaxPlus(row[j]);
137 }
138 }
139 }
140}
141
142#[derive(Default, Clone, Copy)]
144pub struct Avx2MinPlusF32Kernel;
145
146impl Microkernel<TropicalMinPlus<f32>> for Avx2MinPlusF32Kernel {
147 const MR: usize = 8;
148 const NR: usize = 8;
149
150 #[target_feature(enable = "avx2")]
151 unsafe fn execute(
152 &self,
153 mr: usize,
154 nr: usize,
155 k: usize,
156 a: *const f32,
157 b: *const f32,
158 c: *mut TropicalMinPlus<f32>,
159 ldc: usize,
160 ) {
161 let pos_inf = f32x8::splat(f32::INFINITY);
162 let mut acc = [pos_inf; 8];
163
164 for i in 0..mr {
166 let mut row_acc = [f32::INFINITY; 8];
167 for j in 0..nr {
168 row_acc[j] = (*c.add(i * ldc + j)).0;
169 }
170 acc[i] = f32x8::from(row_acc);
171 }
172
173 for p in 0..k {
175 let mut a_vals = [0.0f32; 8];
176 for i in 0..mr {
177 a_vals[i] = *a.add(p * Self::MR + i);
178 }
179
180 let mut b_vals = [0.0f32; 8];
181 for j in 0..nr {
182 b_vals[j] = *b.add(p * Self::NR + j);
183 }
184 let b_vec = f32x8::from(b_vals);
185
186 for i in 0..mr {
187 let a_broadcast = f32x8::splat(a_vals[i]);
188 let product = a_broadcast + b_vec;
189 acc[i] = acc[i].fast_min(product);
193 }
194 }
195
196 for i in 0..mr {
198 let row: [f32; 8] = acc[i].into();
199 for j in 0..nr {
200 *c.add(i * ldc + j) = TropicalMinPlus(row[j]);
201 }
202 }
203 }
204}
205
206#[derive(Default, Clone, Copy)]
208pub struct Avx2MaxMulF32Kernel;
209
210impl Microkernel<TropicalMaxMul<f32>> for Avx2MaxMulF32Kernel {
211 const MR: usize = 8;
212 const NR: usize = 8;
213
214 #[target_feature(enable = "avx2")]
215 unsafe fn execute(
216 &self,
217 mr: usize,
218 nr: usize,
219 k: usize,
220 a: *const f32,
221 b: *const f32,
222 c: *mut TropicalMaxMul<f32>,
223 ldc: usize,
224 ) {
225 let zero = f32x8::splat(0.0);
226 let mut acc = [zero; 8];
227
228 for i in 0..mr {
230 let mut row_acc = [0.0f32; 8];
231 for j in 0..nr {
232 row_acc[j] = (*c.add(i * ldc + j)).0;
233 }
234 acc[i] = f32x8::from(row_acc);
235 }
236
237 for p in 0..k {
239 let mut a_vals = [0.0f32; 8];
240 for i in 0..mr {
241 a_vals[i] = *a.add(p * Self::MR + i);
242 }
243
244 let mut b_vals = [0.0f32; 8];
245 for j in 0..nr {
246 b_vals[j] = *b.add(p * Self::NR + j);
247 }
248 let b_vec = f32x8::from(b_vals);
249
250 for i in 0..mr {
251 let a_broadcast = f32x8::splat(a_vals[i]);
252 let product = a_broadcast * b_vec;
254 acc[i] = acc[i].fast_max(product);
257 }
258 }
259
260 for i in 0..mr {
262 let row: [f32; 8] = acc[i].into();
263 for j in 0..nr {
264 *c.add(i * ldc + j) = TropicalMaxMul(row[j]);
265 }
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalSemiring;

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f32() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let kernel = Avx2MaxPlusF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        // Packed A (column p at offset p * MR): A = [[1, 2, 3], [4, 5, 6]].
        let a: [f32; 24] = [
            1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            2.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
        ];

        // Packed B (row p at offset p * NR): B = [[1, 2], [3, 4], [5, 6]].
        let b: [f32; 24] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = max_p(A[i][p] + B[p][j]).
        assert!((c[0].0 - 8.0).abs() < 1e-6);
        assert!((c[1].0 - 9.0).abs() < 1e-6);
        assert!((c[2].0 - 11.0).abs() < 1e-6);
        assert!((c[3].0 - 12.0).abs() < 1e-6);
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f32_accumulates_into_strided_c() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let kernel = Avx2MaxPlusF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 1;

        // A = [[1], [2]], B = [[10, 20]]  =>  A ⊗ B = [[11, 21], [12, 22]].
        let a: [f32; 8] = [1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let b: [f32; 8] = [10.0, 20.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];

        // ldc = 3: column 2 of each row lies outside the 2x2 tile and must not change.
        // Existing C values must be max-combined with the product, not overwritten.
        let ldc = 3;
        let mut c = [100.0f32, 5.0, 42.0, -1.0, 30.0, 42.0].map(TropicalMaxPlus);

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        let expected = [100.0f32, 21.0, 42.0, 12.0, 30.0, 42.0];
        for (got, want) in c.iter().zip(expected) {
            assert!((got.0 - want).abs() < 1e-6, "got {}, want {}", got.0, want);
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_min_plus_f32() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let kernel = Avx2MinPlusF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        // Packed A (column p at offset p * MR): A = [[1, 2, 3], [4, 5, 6]].
        let a: [f32; 24] = [
            1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            2.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
        ];

        // Packed B (row p at offset p * NR): B = [[1, 2], [3, 4], [5, 6]].
        let b: [f32; 24] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
        ];

        let mut c = vec![TropicalMinPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = min_p(A[i][p] + B[p][j]).
        assert!((c[0].0 - 2.0).abs() < 1e-6);
        assert!((c[1].0 - 3.0).abs() < 1e-6);
        assert!((c[2].0 - 5.0).abs() < 1e-6);
        assert!((c[3].0 - 6.0).abs() < 1e-6);
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_mul_f32() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let kernel = Avx2MaxMulF32Kernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        // Packed A (column p at offset p * MR): A = [[2, 4], [3, 5]].
        let a: [f32; 16] = [
            2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            4.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
        ];

        // Packed B (row p at offset p * NR): B = [[1, 2], [3, 4]].
        let b: [f32; 16] = [
            1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
            3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, //
        ];

        let mut c = vec![TropicalMaxMul::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = max_p(A[i][p] * B[p][j]).
        assert!((c[0].0 - 12.0).abs() < 1e-6);
        assert!((c[1].0 - 16.0).abs() < 1e-6);
        assert!((c[2].0 - 15.0).abs() < 1e-6);
        assert!((c[3].0 - 20.0).abs() < 1e-6);
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_max_plus_f64() {
        if !is_x86_feature_detected!("avx2") {
            println!("AVX2 not available, skipping test");
            return;
        }

        let kernel = Avx2MaxPlusF64Kernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        // Packed A (column p at offset p * MR): A = [[1, 3], [2, 4]].
        let a: [f64; 8] = [
            1.0, 2.0, 0.0, 0.0, //
            3.0, 4.0, 0.0, 0.0, //
        ];

        // Packed B (row p at offset p * NR): B = [[1, 2], [3, 4]].
        let b: [f64; 8] = [
            1.0, 2.0, 0.0, 0.0, //
            3.0, 4.0, 0.0, 0.0, //
        ];

        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = max_p(A[i][p] + B[p][j]).
        assert!((c[0].0 - 6.0).abs() < 1e-10);
        assert!((c[1].0 - 7.0).abs() < 1e-10);
        assert!((c[2].0 - 7.0).abs() < 1e-10);
        assert!((c[3].0 - 8.0).abs() < 1e-10);
    }
}