1use crate::types::{TropicalSemiring, TropicalWithArgmax};
2
3pub trait Microkernel<T: TropicalSemiring> {
8 const MR: usize;
10
11 const NR: usize;
13
14 unsafe fn execute(
25 &self,
26 mr: usize,
27 nr: usize,
28 k: usize,
29 a: *const T::Scalar,
30 b: *const T::Scalar,
31 c: *mut T,
32 ldc: usize,
33 );
34}
35
36pub trait MicrokernelWithArgmax<T: TropicalWithArgmax<Index = u32>>: Microkernel<T> {
38 unsafe fn execute_with_argmax(
47 &self,
48 mr: usize,
49 nr: usize,
50 k: usize,
51 k_offset: usize,
52 a: *const T::Scalar,
53 b: *const T::Scalar,
54 c: *mut T,
55 argmax: *mut u32,
56 ldc: usize,
57 );
58}
59
/// Portable microkernel with a fixed 4 x 4 tile.
///
/// It is generic over the semiring and uses only plain scalar Rust (no
/// target-specific intrinsics).
#[derive(Debug, Default, Clone, Copy)]
pub struct PortableMicrokernel;

impl PortableMicrokernel {
    /// Tile height: rows of `C` per call, and the per-step stride of the
    /// packed `A` panel.
    pub const MR: usize = 4;

    /// Tile width: columns of `C` per call, and the per-step stride of the
    /// packed `B` panel.
    pub const NR: usize = 4;
}
71
72impl<T: TropicalSemiring> Microkernel<T> for PortableMicrokernel {
73 const MR: usize = 4;
74 const NR: usize = 4;
75
76 unsafe fn execute(
77 &self,
78 mr: usize,
79 nr: usize,
80 k: usize,
81 a: *const T::Scalar,
82 b: *const T::Scalar,
83 c: *mut T,
84 ldc: usize,
85 ) {
86 const MR: usize = 4;
87 const NR: usize = 4;
88
89 let mut acc = [[T::tropical_zero(); NR]; MR];
91 for i in 0..mr {
92 for j in 0..nr {
93 acc[i][j] = *c.add(i * ldc + j);
94 }
95 }
96
97 for p in 0..k {
99 for i in 0..mr {
100 let a_val = T::from_scalar(*a.add(p * MR + i));
101 for j in 0..nr {
102 let b_val = T::from_scalar(*b.add(p * NR + j));
103 let product = a_val.tropical_mul(b_val);
104 acc[i][j] = acc[i][j].tropical_add(product);
105 }
106 }
107 }
108
109 for i in 0..mr {
111 for j in 0..nr {
112 *c.add(i * ldc + j) = acc[i][j];
113 }
114 }
115 }
116}
117
118impl<T: TropicalWithArgmax<Index = u32>> MicrokernelWithArgmax<T> for PortableMicrokernel {
119 unsafe fn execute_with_argmax(
120 &self,
121 mr: usize,
122 nr: usize,
123 k: usize,
124 k_offset: usize,
125 a: *const T::Scalar,
126 b: *const T::Scalar,
127 c: *mut T,
128 argmax: *mut u32,
129 ldc: usize,
130 ) {
131 const MR: usize = 4;
132 const NR: usize = 4;
133
134 let mut acc = [[T::tropical_zero(); NR]; MR];
136 let mut idx = [[0u32; NR]; MR];
137 for i in 0..mr {
138 for j in 0..nr {
139 acc[i][j] = *c.add(i * ldc + j);
140 idx[i][j] = *argmax.add(i * ldc + j);
141 }
142 }
143
144 for p in 0..k {
146 let current_k = (k_offset + p) as u32;
147 for i in 0..mr {
148 let a_val = T::from_scalar(*a.add(p * MR + i));
149 for j in 0..nr {
150 let b_val = T::from_scalar(*b.add(p * NR + j));
151 let product = a_val.tropical_mul(b_val);
152 let (new_acc, new_idx) =
153 acc[i][j].tropical_add_argmax(idx[i][j], product, current_k);
154 acc[i][j] = new_acc;
155 idx[i][j] = new_idx;
156 }
157 }
158 }
159
160 for i in 0..mr {
162 for j in 0..nr {
163 *c.add(i * ldc + j) = acc[i][j];
164 *argmax.add(i * ldc + j) = idx[i][j];
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    // All tests pack A and B with the kernel's per-step strides MR = NR = 4:
    // step p of A occupies a[p * 4 .. p * 4 + 4], step p of B occupies
    // b[p * 4 .. p * 4 + 4], and unused slots are zero padding.

    #[test]
    fn test_portable_kernel() {
        // A = [[1, 2, 3], [4, 5, 6]], B = [[1, 2], [3, 4], [5, 6]].
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = max_p (A[i][p] + B[p][j])
        assert_eq!(c[0].0, 8.0); // max(1+1, 2+3, 3+5)
        assert_eq!(c[1].0, 9.0); // max(1+2, 2+4, 3+6)
        assert_eq!(c[2].0, 11.0); // max(4+1, 5+3, 6+5)
        assert_eq!(c[3].0, 12.0); // max(4+2, 5+4, 6+6)
    }

    #[test]
    fn test_portable_kernel_minplus() {
        use crate::types::TropicalMinPlus;

        // Same A and B as `test_portable_kernel`.
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMinPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = min_p (A[i][p] + B[p][j])
        assert_eq!(c[0].0, 2.0);
        assert_eq!(c[1].0, 3.0);
        assert_eq!(c[2].0, 5.0);
        assert_eq!(c[3].0, 6.0);
    }

    #[test]
    fn test_portable_kernel_maxmul() {
        use crate::types::TropicalMaxMul;

        // A = [[2, 4], [3, 5]], B = [[1, 2], [3, 4]].
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        let a: [f64; 8] = [2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0];
        let b: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];

        let mut c = [TropicalMaxMul::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[i][j] = max_p (A[i][p] * B[p][j])
        assert_eq!(c[0].0, 12.0); // max(2*1, 4*3)
        assert_eq!(c[1].0, 16.0); // max(2*2, 4*4)
        assert_eq!(c[2].0, 15.0); // max(3*1, 5*3)
        assert_eq!(c[3].0, 20.0); // max(3*2, 5*4)
    }

    #[test]
    fn test_portable_kernel_strided_accumulate() {
        // Row stride ldc = 3 > nr: column 2 of each row lies outside the tile
        // and must stay untouched. Existing C values take part in the max.
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 1;
        let ldc = 3;

        // A = [[1], [2]], B = [[1, 2]]; products are [[2, 3], [3, 4]].
        let a: [f64; 4] = [1.0, 2.0, 0.0, 0.0];
        let b: [f64; 4] = [1.0, 2.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 6];
        c[0].0 = 5.0;
        c[1].0 = 0.0;
        c[2].0 = 99.0;
        c[3].0 = 0.0;
        c[4].0 = 5.0;
        c[5].0 = 99.0;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        assert_eq!(c[0].0, 5.0); // max(5, 2)
        assert_eq!(c[1].0, 3.0); // max(0, 3)
        assert_eq!(c[2].0, 99.0); // outside the tile
        assert_eq!(c[3].0, 3.0); // max(0, 3)
        assert_eq!(c[4].0, 5.0); // max(5, 4)
        assert_eq!(c[5].0, 99.0); // outside the tile
    }

    #[test]
    fn test_portable_kernel_with_argmax() {
        // Same A and B as `test_portable_kernel`; every maximum comes from p = 2.
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let mut argmax = [0u32; 4];
        let ldc = 2;
        let k_offset = 0;

        unsafe {
            kernel.execute_with_argmax(
                mr,
                nr,
                k,
                k_offset,
                a.as_ptr(),
                b.as_ptr(),
                c.as_mut_ptr(),
                argmax.as_mut_ptr(),
                ldc,
            );
        }

        assert_eq!(c[0].0, 8.0);
        assert_eq!(argmax[0], 2);

        assert_eq!(c[1].0, 9.0);
        assert_eq!(argmax[1], 2);

        assert_eq!(c[2].0, 11.0);
        assert_eq!(argmax[2], 2);

        assert_eq!(c[3].0, 12.0);
        assert_eq!(argmax[3], 2);
    }

    #[test]
    fn test_portable_kernel_with_argmax_offset() {
        // A = [[1, 10], [2, 20]], B = all ones. The maximum comes from local
        // step p = 1, which is recorded as global index k_offset + 1 = 6.
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        let a: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0];
        let b: [f64; 8] = [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let mut argmax = [0u32; 4];
        let ldc = 2;
        let k_offset = 5;

        unsafe {
            kernel.execute_with_argmax(
                mr,
                nr,
                k,
                k_offset,
                a.as_ptr(),
                b.as_ptr(),
                c.as_mut_ptr(),
                argmax.as_mut_ptr(),
                ldc,
            );
        }

        assert_eq!(c[0].0, 11.0); // max(1+1, 10+1)
        assert_eq!(argmax[0], 6);

        assert_eq!(c[1].0, 11.0); // max(1+1, 10+1)
        assert_eq!(argmax[1], 6);

        assert_eq!(c[2].0, 21.0); // max(2+1, 20+1)
        assert_eq!(argmax[2], 6);

        assert_eq!(c[3].0, 21.0); // max(2+1, 20+1)
        assert_eq!(argmax[3], 6);
    }

    #[test]
    fn test_portable_kernel_f32() {
        // A = [[1, 3], [2, 4]], B = [[1, 2], [3, 4]] in f32.
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        let a: [f32; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];
        let b: [f32; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        assert!((c[0].0 - 6.0).abs() < 1e-6); // max(1+1, 3+3)
        assert!((c[1].0 - 7.0).abs() < 1e-6); // max(1+2, 3+4)
        assert!((c[2].0 - 7.0).abs() < 1e-6); // max(2+1, 4+3)
        assert!((c[3].0 - 8.0).abs() < 1e-6); // max(2+2, 4+4)
    }
}