// tropical_gemm/core/kernel.rs

1use crate::types::{TropicalSemiring, TropicalWithArgmax};
2
3/// Trait for GEMM microkernels.
4///
5/// A microkernel computes a small block of C += A * B using register blocking.
6/// The dimensions mr x nr define the "register tile" that fits in CPU registers.
7pub trait Microkernel<T: TropicalSemiring> {
8    /// Rows of the microkernel (typically 4-8 for f32).
9    const MR: usize;
10
11    /// Columns of the microkernel (typically 4-8 for f32).
12    const NR: usize;
13
14    /// Execute the microkernel.
15    ///
16    /// Computes C[0..mr, 0..nr] = A[0..mr, 0..k] ⊗ B[0..k, 0..nr]
17    /// where the result is combined with existing C values using tropical addition.
18    ///
19    /// # Safety
20    /// - `a` must point to at least `mr * k` elements (packed column-major)
21    /// - `b` must point to at least `k * nr` elements (packed row-major)
22    /// - `c` must point to at least `mr * ldc` elements
23    /// - `mr <= Self::MR` and `nr <= Self::NR`
24    unsafe fn execute(
25        &self,
26        mr: usize,
27        nr: usize,
28        k: usize,
29        a: *const T::Scalar,
30        b: *const T::Scalar,
31        c: *mut T,
32        ldc: usize,
33    );
34}
35
36/// Trait for microkernels that track argmax during computation.
37pub trait MicrokernelWithArgmax<T: TropicalWithArgmax<Index = u32>>: Microkernel<T> {
38    /// Execute the microkernel with argmax tracking.
39    ///
40    /// Same as `execute`, but also fills `argmax` with the k-index that
41    /// produced each optimal C[i,j] value.
42    ///
43    /// # Safety
44    /// Same requirements as `execute`, plus:
45    /// - `argmax` must point to at least `mr * ldc` elements
46    unsafe fn execute_with_argmax(
47        &self,
48        mr: usize,
49        nr: usize,
50        k: usize,
51        k_offset: usize,
52        a: *const T::Scalar,
53        b: *const T::Scalar,
54        c: *mut T,
55        argmax: *mut u32,
56        ldc: usize,
57    );
58}
59
/// Portable (non-SIMD) microkernel implementation.
///
/// A zero-sized, scalar kernel with a fixed 4x4 register tile that works for any
/// [`TropicalSemiring`]. It carries no state, so it is freely copyable and comparable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortableMicrokernel;

64/// Constants for PortableMicrokernel
65impl PortableMicrokernel {
66    /// Microkernel row dimension.
67    pub const MR: usize = 4;
68    /// Microkernel column dimension.
69    pub const NR: usize = 4;
70}
71
72impl<T: TropicalSemiring> Microkernel<T> for PortableMicrokernel {
73    const MR: usize = 4;
74    const NR: usize = 4;
75
76    unsafe fn execute(
77        &self,
78        mr: usize,
79        nr: usize,
80        k: usize,
81        a: *const T::Scalar,
82        b: *const T::Scalar,
83        c: *mut T,
84        ldc: usize,
85    ) {
86        const MR: usize = 4;
87        const NR: usize = 4;
88
89        // Initialize accumulators from C
90        let mut acc = [[T::tropical_zero(); NR]; MR];
91        for i in 0..mr {
92            for j in 0..nr {
93                acc[i][j] = *c.add(i * ldc + j);
94            }
95        }
96
97        // Main loop
98        for p in 0..k {
99            for i in 0..mr {
100                let a_val = T::from_scalar(*a.add(p * MR + i));
101                for j in 0..nr {
102                    let b_val = T::from_scalar(*b.add(p * NR + j));
103                    let product = a_val.tropical_mul(b_val);
104                    acc[i][j] = acc[i][j].tropical_add(product);
105                }
106            }
107        }
108
109        // Write back
110        for i in 0..mr {
111            for j in 0..nr {
112                *c.add(i * ldc + j) = acc[i][j];
113            }
114        }
115    }
116}
117
118impl<T: TropicalWithArgmax<Index = u32>> MicrokernelWithArgmax<T> for PortableMicrokernel {
119    unsafe fn execute_with_argmax(
120        &self,
121        mr: usize,
122        nr: usize,
123        k: usize,
124        k_offset: usize,
125        a: *const T::Scalar,
126        b: *const T::Scalar,
127        c: *mut T,
128        argmax: *mut u32,
129        ldc: usize,
130    ) {
131        const MR: usize = 4;
132        const NR: usize = 4;
133
134        // Initialize accumulators from C and existing argmax
135        let mut acc = [[T::tropical_zero(); NR]; MR];
136        let mut idx = [[0u32; NR]; MR];
137        for i in 0..mr {
138            for j in 0..nr {
139                acc[i][j] = *c.add(i * ldc + j);
140                idx[i][j] = *argmax.add(i * ldc + j);
141            }
142        }
143
144        // Main loop with argmax tracking
145        for p in 0..k {
146            let current_k = (k_offset + p) as u32;
147            for i in 0..mr {
148                let a_val = T::from_scalar(*a.add(p * MR + i));
149                for j in 0..nr {
150                    let b_val = T::from_scalar(*b.add(p * NR + j));
151                    let product = a_val.tropical_mul(b_val);
152                    let (new_acc, new_idx) =
153                        acc[i][j].tropical_add_argmax(idx[i][j], product, current_k);
154                    acc[i][j] = new_acc;
155                    idx[i][j] = new_idx;
156                }
157            }
158        }
159
160        // Write back
161        for i in 0..mr {
162            for j in 0..nr {
163                *c.add(i * ldc + j) = acc[i][j];
164                *argmax.add(i * ldc + j) = idx[i][j];
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    // Packed layout used throughout:
    // - A is stored column by column in chunks of `PortableMicrokernel::MR` (= 4) elements.
    // - B is stored row by row in chunks of `PortableMicrokernel::NR` (= 4) elements.
    // - Lanes beyond `mr`/`nr` are zero padding.

    #[test]
    fn test_portable_kernel() {
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        // A: 2x3 matrix (packed column-major in MR chunks)
        // A = [[1, 2, 3],
        //      [4, 5, 6]]
        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];

        // B: 3x2 matrix (packed row-major in NR chunks)
        // B = [[1, 2],
        //      [3, 4],
        //      [5, 6]]
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        // C: 2x2 output
        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = max(A[0,0]+B[0,0], A[0,1]+B[1,0], A[0,2]+B[2,0])
        //        = max(1+1, 2+3, 3+5) = max(2, 5, 8) = 8
        assert_eq!(c[0].0, 8.0);

        // C[0,1] = max(1+2, 2+4, 3+6) = max(3, 6, 9) = 9
        assert_eq!(c[1].0, 9.0);

        // C[1,0] = max(4+1, 5+3, 6+5) = max(5, 8, 11) = 11
        assert_eq!(c[2].0, 11.0);

        // C[1,1] = max(4+2, 5+4, 6+6) = max(6, 9, 12) = 12
        assert_eq!(c[3].0, 12.0);
    }

    #[test]
    fn test_portable_kernel_minplus() {
        use crate::types::TropicalMinPlus;

        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMinPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = min(1+1, 2+3, 3+5) = min(2, 5, 8) = 2
        assert_eq!(c[0].0, 2.0);
        // C[0,1] = min(1+2, 2+4, 3+6) = min(3, 6, 9) = 3
        assert_eq!(c[1].0, 3.0);
        // C[1,0] = min(4+1, 5+3, 6+5) = min(5, 8, 11) = 5
        assert_eq!(c[2].0, 5.0);
        // C[1,1] = min(4+2, 5+4, 6+6) = min(6, 9, 12) = 6
        assert_eq!(c[3].0, 6.0);
    }

    #[test]
    fn test_portable_kernel_maxmul() {
        use crate::types::TropicalMaxMul;

        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        // A: [[2, 4], [3, 5]]
        let a: [f64; 8] = [2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0];
        // B: [[1, 2], [3, 4]]
        let b: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];

        let mut c = [TropicalMaxMul::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = max(2*1, 4*3) = max(2, 12) = 12
        assert_eq!(c[0].0, 12.0);
        // C[0,1] = max(2*2, 4*4) = max(4, 16) = 16
        assert_eq!(c[1].0, 16.0);
        // C[1,0] = max(3*1, 5*3) = max(3, 15) = 15
        assert_eq!(c[2].0, 15.0);
        // C[1,1] = max(3*2, 5*4) = max(6, 20) = 20
        assert_eq!(c[3].0, 20.0);
    }

    #[test]
    fn test_portable_kernel_accumulates_with_row_stride() {
        // ldc = 3 > nr = 2: the third element of each C row lies outside the tile and
        // must stay untouched, while in-tile cells fold into their existing values.
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 6];
        c[0] = TropicalMaxPlus::from_scalar(100.0); // beats every product
        c[2] = TropicalMaxPlus::from_scalar(-7.0); // outside the tile
        c[5] = TropicalMaxPlus::from_scalar(-7.0); // outside the tile
        let ldc = 3;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        assert_eq!(c[0].0, 100.0);
        assert_eq!(c[1].0, 9.0);
        assert_eq!(c[2].0, -7.0);
        assert_eq!(c[3].0, 11.0);
        assert_eq!(c[4].0, 12.0);
        assert_eq!(c[5].0, -7.0);
    }

    #[test]
    fn test_portable_kernel_with_argmax() {
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let mut argmax = [0u32; 4];
        let ldc = 2;
        let k_offset = 0;

        unsafe {
            kernel.execute_with_argmax(
                mr,
                nr,
                k,
                k_offset,
                a.as_ptr(),
                b.as_ptr(),
                c.as_mut_ptr(),
                argmax.as_mut_ptr(),
                ldc,
            );
        }

        // C[0,0] = max(1+1, 2+3, 3+5) = 8 at k=2
        assert_eq!(c[0].0, 8.0);
        assert_eq!(argmax[0], 2);

        // C[0,1] = max(1+2, 2+4, 3+6) = 9 at k=2
        assert_eq!(c[1].0, 9.0);
        assert_eq!(argmax[1], 2);

        // C[1,0] = max(4+1, 5+3, 6+5) = 11 at k=2
        assert_eq!(c[2].0, 11.0);
        assert_eq!(argmax[2], 2);

        // C[1,1] = max(4+2, 5+4, 6+6) = 12 at k=2
        assert_eq!(c[3].0, 12.0);
        assert_eq!(argmax[3], 2);
    }

    #[test]
    fn test_portable_kernel_with_argmax_offset() {
        // Test that k_offset is correctly applied
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        let a: [f64; 8] = [1.0, 2.0, 0.0, 0.0, 10.0, 20.0, 0.0, 0.0];
        let b: [f64; 8] = [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let mut argmax = [0u32; 4];
        let ldc = 2;
        let k_offset = 5; // Start from global k=5

        unsafe {
            kernel.execute_with_argmax(
                mr,
                nr,
                k,
                k_offset,
                a.as_ptr(),
                b.as_ptr(),
                c.as_mut_ptr(),
                argmax.as_mut_ptr(),
                ldc,
            );
        }

        // A[:,1] has larger values, so k=1 (global k=6) should win in every cell.
        // C[0,0] = C[0,1] = max(1+1, 10+1) = 11 at local k=1, global k=6
        assert_eq!(c[0].0, 11.0);
        assert_eq!(argmax[0], 6); // k_offset + 1
        assert_eq!(c[1].0, 11.0);
        assert_eq!(argmax[1], 6);

        // C[1,0] = C[1,1] = max(2+1, 20+1) = 21 at local k=1, global k=6
        assert_eq!(c[2].0, 21.0);
        assert_eq!(argmax[2], 6);
        assert_eq!(c[3].0, 21.0);
        assert_eq!(argmax[3], 6);
    }

    #[test]
    fn test_portable_kernel_with_argmax_keeps_existing_winner() {
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 3;

        let a: [f64; 12] = [1.0, 4.0, 0.0, 0.0, 2.0, 5.0, 0.0, 0.0, 3.0, 6.0, 0.0, 0.0];
        let b: [f64; 12] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let mut argmax = [0u32; 4];
        // C[0,0] already holds a value larger than any product, attributed to k=42.
        c[0] = TropicalMaxPlus::from_scalar(100.0);
        argmax[0] = 42;
        let ldc = 2;
        let k_offset = 0;

        unsafe {
            kernel.execute_with_argmax(
                mr,
                nr,
                k,
                k_offset,
                a.as_ptr(),
                b.as_ptr(),
                c.as_mut_ptr(),
                argmax.as_mut_ptr(),
                ldc,
            );
        }

        assert_eq!(c[0].0, 100.0);
        assert_eq!(argmax[0], 42);
        assert_eq!(c[1].0, 9.0);
        assert_eq!(argmax[1], 2);
        assert_eq!(c[2].0, 11.0);
        assert_eq!(argmax[2], 2);
        assert_eq!(c[3].0, 12.0);
        assert_eq!(argmax[3], 2);
    }

    #[test]
    fn test_portable_kernel_f32() {
        let kernel = PortableMicrokernel;
        let mr = 2;
        let nr = 2;
        let k = 2;

        // A = [[1, 3], [2, 4]], B = [[1, 2], [3, 4]]
        let a: [f32; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];
        let b: [f32; 8] = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0];

        let mut c = [TropicalMaxPlus::tropical_zero(); 4];
        let ldc = 2;

        unsafe {
            kernel.execute(mr, nr, k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), ldc);
        }

        // C[0,0] = max(1+1, 3+3) = 6
        assert!((c[0].0 - 6.0).abs() < 1e-6);
        // C[0,1] = max(1+2, 3+4) = 7
        assert!((c[1].0 - 7.0).abs() < 1e-6);
        // C[1,0] = max(2+1, 4+3) = 7
        assert!((c[2].0 - 7.0).abs() < 1e-6);
        // C[1,1] = max(2+2, 4+4) = 8
        assert!((c[3].0 - 8.0).abs() < 1e-6);
    }
}
383}