tropical_gemm/core/packing.rs

1use crate::types::TropicalScalar;
2
/// Storage order of a dense matrix in memory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Layout {
    /// Elements of one row are adjacent in memory (C order).
    RowMajor,
    /// Elements of one column are adjacent in memory (Fortran order).
    ColMajor,
}
11
/// Selects whether an operand is used as stored or transposed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transpose {
    /// Use the stored matrix as-is: `op(X) = X`.
    NoTrans,
    /// Use the transpose of the stored matrix: `op(X) = X^T`.
    Trans,
}
20
21/// Pack a panel of matrix A into a contiguous buffer.
22///
23/// The packed format stores `mc` rows in column-major order within
24/// blocks of `mr` rows. This improves cache locality during the
25/// microkernel computation.
26///
27/// # Layout
28/// For A with dimensions m×k:
29/// ```text
30/// Original A (row-major, m=6, k=4, mr=4):
31/// [ a00 a01 a02 a03 ]
32/// [ a10 a11 a12 a13 ]
33/// [ a20 a21 a22 a23 ]
34/// [ a30 a31 a32 a33 ]
35/// [ a40 a41 a42 a43 ]
36/// [ a50 a51 a52 a53 ]
37///
38/// Packed (column-major within mr×k blocks):
39/// Block 0 (rows 0-3): a00 a10 a20 a30 | a01 a11 a21 a31 | a02 a12 a22 a32 | a03 a13 a23 a33
40/// Block 1 (rows 4-5): a40 a50 0   0   | a41 a51 0   0   | a42 a52 0   0   | a43 a53 0   0
41/// ```
42///
43/// # Safety
44/// - `a` must point to valid memory for at least `m * lda` elements
45/// - `packed` must have capacity for at least `((m + mr - 1) / mr) * mr * k` elements
46pub unsafe fn pack_a<T: TropicalScalar>(
47    m: usize,
48    k: usize,
49    a: *const T,
50    lda: usize,
51    layout: Layout,
52    trans: Transpose,
53    packed: *mut T,
54    mr: usize,
55) {
56    let zero = T::scalar_zero();
57
58    let mut packed_idx = 0;
59
60    // Process full mr×k blocks
61    let m_blocks = m / mr;
62    let m_rem = m % mr;
63
64    for block in 0..m_blocks {
65        let row_start = block * mr;
66        for col in 0..k {
67            for row_offset in 0..mr {
68                let row = row_start + row_offset;
69                let val = get_element(a, row, col, lda, layout, trans);
70                *packed.add(packed_idx) = val;
71                packed_idx += 1;
72            }
73        }
74    }
75
76    // Process remaining rows (if any)
77    if m_rem > 0 {
78        let row_start = m_blocks * mr;
79        for col in 0..k {
80            for row_offset in 0..mr {
81                let row = row_start + row_offset;
82                let val = if row < m {
83                    get_element(a, row, col, lda, layout, trans)
84                } else {
85                    zero
86                };
87                *packed.add(packed_idx) = val;
88                packed_idx += 1;
89            }
90        }
91    }
92}
93
94/// Pack a panel of matrix B into a contiguous buffer.
95///
96/// The packed format stores `nc` columns in row-major order within
97/// blocks of `nr` columns.
98///
99/// # Layout
100/// For B with dimensions k×n:
101/// ```text
102/// Original B (row-major, k=3, n=6, nr=4):
103/// [ b00 b01 b02 b03 b04 b05 ]
104/// [ b10 b11 b12 b13 b14 b15 ]
105/// [ b20 b21 b22 b23 b24 b25 ]
106///
107/// Packed (row-major within k×nr blocks):
108/// Block 0 (cols 0-3): b00 b01 b02 b03 | b10 b11 b12 b13 | b20 b21 b22 b23
109/// Block 1 (cols 4-5): b04 b05 0   0   | b14 b15 0   0   | b24 b25 0   0
110/// ```
111///
112/// # Safety
113/// - `b` must point to valid memory for at least `k * ldb` or `ldb * n` elements
114/// - `packed` must have capacity for at least `((n + nr - 1) / nr) * nr * k` elements
115pub unsafe fn pack_b<T: TropicalScalar>(
116    k: usize,
117    n: usize,
118    b: *const T,
119    ldb: usize,
120    layout: Layout,
121    trans: Transpose,
122    packed: *mut T,
123    nr: usize,
124) {
125    let zero = T::scalar_zero();
126
127    let mut packed_idx = 0;
128
129    // Process full k×nr blocks
130    let n_blocks = n / nr;
131    let n_rem = n % nr;
132
133    for block in 0..n_blocks {
134        let col_start = block * nr;
135        for row in 0..k {
136            for col_offset in 0..nr {
137                let col = col_start + col_offset;
138                let val = get_element(b, row, col, ldb, layout, trans);
139                *packed.add(packed_idx) = val;
140                packed_idx += 1;
141            }
142        }
143    }
144
145    // Process remaining columns (if any)
146    if n_rem > 0 {
147        let col_start = n_blocks * nr;
148        for row in 0..k {
149            for col_offset in 0..nr {
150                let col = col_start + col_offset;
151                let val = if col < n {
152                    get_element(b, row, col, ldb, layout, trans)
153                } else {
154                    zero
155                };
156                *packed.add(packed_idx) = val;
157                packed_idx += 1;
158            }
159        }
160    }
161}
162
163/// Get element from matrix considering layout and transpose.
164#[inline(always)]
165unsafe fn get_element<T: Copy>(
166    ptr: *const T,
167    row: usize,
168    col: usize,
169    ld: usize,
170    layout: Layout,
171    trans: Transpose,
172) -> T {
173    let (actual_row, actual_col) = match trans {
174        Transpose::NoTrans => (row, col),
175        Transpose::Trans => (col, row),
176    };
177
178    let idx = match layout {
179        Layout::RowMajor => actual_row * ld + actual_col,
180        Layout::ColMajor => actual_col * ld + actual_row,
181    };
182
183    *ptr.add(idx)
184}
185
/// Number of elements [`pack_a`] writes for an `m × k` operand packed in
/// blocks of `mr` rows.
///
/// `m` is rounded up to the next multiple of `mr`, which covers the
/// zero-padded final block, and the result is multiplied by `k`.
///
/// # Panics
/// Panics if `mr == 0` or if the padded size does not fit in `usize`. The
/// overflow check matters because this value sizes the buffer that `pack_a`
/// writes through a raw pointer. A silently wrapped (too small) size would
/// let it write out of bounds.
#[inline]
pub fn packed_a_size(m: usize, k: usize, mr: usize) -> usize {
    m.checked_next_multiple_of(mr)
        .and_then(|m_padded| m_padded.checked_mul(k))
        .expect("packed_a_size: mr must be non-zero and the padded size must fit in usize")
}
192
/// Number of elements [`pack_b`] writes for a `k × n` operand packed in
/// blocks of `nr` columns.
///
/// `n` is rounded up to the next multiple of `nr`, which covers the
/// zero-padded final block, and the result is multiplied by `k`.
///
/// # Panics
/// Panics if `nr == 0` or if the padded size does not fit in `usize`. The
/// overflow check matters because this value sizes the buffer that `pack_b`
/// writes through a raw pointer. A silently wrapped (too small) size would
/// let it write out of bounds.
#[inline]
pub fn packed_b_size(k: usize, n: usize, nr: usize) -> usize {
    n.checked_next_multiple_of(nr)
        .and_then(|n_padded| k.checked_mul(n_padded))
        .expect("packed_b_size: nr must be non-zero and the padded size must fit in usize")
}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_a_row_major() {
        let a: [f64; 6] = [
            1.0, 2.0, 3.0, // row 0
            4.0, 5.0, 6.0, // row 1
        ];
        let m = 2;
        let k = 3;
        let mr = 4;
        let lda = 3;

        let mut packed = vec![0.0f64; packed_a_size(m, k, mr)];

        unsafe {
            pack_a(
                m,
                k,
                a.as_ptr(),
                lda,
                Layout::RowMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                mr,
            );
        }

        // Expected: column 0: [1,4,0,0], column 1: [2,5,0,0], column 2: [3,6,0,0]
        assert_eq!(packed[0], 1.0); // a[0,0]
        assert_eq!(packed[1], 4.0); // a[1,0]
        assert_eq!(packed[2], 0.0); // padding
        assert_eq!(packed[3], 0.0); // padding
        assert_eq!(packed[4], 2.0); // a[0,1]
        assert_eq!(packed[5], 5.0); // a[1,1]
    }

    #[test]
    fn test_pack_a_col_major() {
        // Column-major: columns are stored contiguously
        // Matrix: [[1, 2, 3], [4, 5, 6]]
        // Col-major storage: [1, 4, 2, 5, 3, 6]
        let a: [f64; 6] = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let m = 2;
        let k = 3;
        let mr = 4;
        let lda = 2; // Leading dimension for col-major

        let mut packed = vec![0.0f64; packed_a_size(m, k, mr)];

        unsafe {
            pack_a(
                m,
                k,
                a.as_ptr(),
                lda,
                Layout::ColMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                mr,
            );
        }

        // Same result as row-major since we're extracting the same logical matrix
        assert_eq!(packed[0], 1.0); // a[0,0]
        assert_eq!(packed[1], 4.0); // a[1,0]
        assert_eq!(packed[4], 2.0); // a[0,1]
        assert_eq!(packed[5], 5.0); // a[1,1]
    }

    #[test]
    fn test_pack_a_col_major_with_padded_lda() {
        // Logical 2x3 matrix [[1, 2, 3], [4, 5, 6]] stored column-major inside a
        // 3-row allocation (lda = 3). The unused third row (99.0) must not
        // appear in the packed output.
        let a: [f64; 9] = [1.0, 4.0, 99.0, 2.0, 5.0, 99.0, 3.0, 6.0, 99.0];
        let m = 2;
        let k = 3;
        let mr = 2;
        let lda = 3;

        // Sentinel fill: every slot must be overwritten by pack_a.
        let mut packed = vec![-1.0f64; packed_a_size(m, k, mr)];

        unsafe {
            pack_a(
                m,
                k,
                a.as_ptr(),
                lda,
                Layout::ColMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                mr,
            );
        }

        assert_eq!(packed, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_pack_a_full_block_then_remainder() {
        // Mirrors the `pack_a` doc example: m=6, k=4, mr=4 gives one full block
        // followed by a 2-row remainder block padded up to mr.
        let m = 6;
        let k = 4;
        let mr = 4;
        let lda = 4;
        // Row-major A with a[r][c] = 4 * r + c + 1.
        let a: Vec<f64> = (1..=24).map(f64::from).collect();

        // Sentinel fill: every slot, padding included, must be written by pack_a.
        let mut packed = vec![-1.0f64; packed_a_size(m, k, mr)];

        unsafe {
            pack_a(
                m,
                k,
                a.as_ptr(),
                lda,
                Layout::RowMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                mr,
            );
        }

        let expected = [
            // Block 0 (rows 0-3), one column of 4 at a time.
            1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0,
            // Block 1 (rows 4-5) plus two padding slots per column.
            17.0, 21.0, 0.0, 0.0, 18.0, 22.0, 0.0, 0.0, 19.0, 23.0, 0.0, 0.0, 20.0, 24.0, 0.0, 0.0,
        ];
        assert_eq!(packed, expected);
    }

    #[test]
    fn test_pack_b_row_major() {
        let b: [f64; 6] = [
            1.0, 2.0, // row 0
            3.0, 4.0, // row 1
            5.0, 6.0, // row 2
        ];
        let k = 3;
        let n = 2;
        let nr = 4;
        let ldb = 2;

        let mut packed = vec![0.0f64; packed_b_size(k, n, nr)];

        unsafe {
            pack_b(
                k,
                n,
                b.as_ptr(),
                ldb,
                Layout::RowMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                nr,
            );
        }

        // Expected: row 0: [1,2,0,0], row 1: [3,4,0,0], row 2: [5,6,0,0]
        assert_eq!(packed[0], 1.0); // b[0,0]
        assert_eq!(packed[1], 2.0); // b[0,1]
        assert_eq!(packed[2], 0.0); // padding
        assert_eq!(packed[3], 0.0); // padding
        assert_eq!(packed[4], 3.0); // b[1,0]
        assert_eq!(packed[5], 4.0); // b[1,1]
    }

    #[test]
    fn test_pack_b_col_major() {
        // Column-major: columns are stored contiguously
        // Matrix B (k=3, n=2): [[1, 2], [3, 4], [5, 6]]
        // Col-major storage: [1, 3, 5, 2, 4, 6]
        let b: [f64; 6] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let k = 3;
        let n = 2;
        let nr = 4;
        let ldb = 3; // Leading dimension for col-major (number of rows)

        let mut packed = vec![0.0f64; packed_b_size(k, n, nr)];

        unsafe {
            pack_b(
                k,
                n,
                b.as_ptr(),
                ldb,
                Layout::ColMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                nr,
            );
        }

        // Expected: same logical values as row-major
        assert_eq!(packed[0], 1.0); // b[0,0]
        assert_eq!(packed[1], 2.0); // b[0,1]
        assert_eq!(packed[4], 3.0); // b[1,0]
        assert_eq!(packed[5], 4.0); // b[1,1]
    }

    #[test]
    fn test_pack_b_full_block_then_remainder() {
        // Mirrors the `pack_b` doc example: k=3, n=6, nr=4 gives one full block
        // followed by a 2-column remainder block padded up to nr.
        let k = 3;
        let n = 6;
        let nr = 4;
        let ldb = 6;
        // Row-major B with b[r][c] = 6 * r + c + 1.
        let b: Vec<f64> = (1..=18).map(f64::from).collect();

        // Sentinel fill: every slot, padding included, must be written by pack_b.
        let mut packed = vec![-1.0f64; packed_b_size(k, n, nr)];

        unsafe {
            pack_b(
                k,
                n,
                b.as_ptr(),
                ldb,
                Layout::RowMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                nr,
            );
        }

        let expected = [
            // Block 0 (cols 0-3), one row of 4 at a time.
            1.0, 2.0, 3.0, 4.0, 7.0, 8.0, 9.0, 10.0, 13.0, 14.0, 15.0, 16.0,
            // Block 1 (cols 4-5) plus two padding slots per row.
            5.0, 6.0, 0.0, 0.0, 11.0, 12.0, 0.0, 0.0, 17.0, 18.0, 0.0, 0.0,
        ];
        assert_eq!(packed, expected);
    }

    #[test]
    fn test_pack_a_with_transpose() {
        // Test packing with transpose
        let a: [f64; 6] = [
            1.0, 2.0, // row 0 (becomes col 0 after trans)
            3.0, 4.0, // row 1
            5.0, 6.0, // row 2
        ];
        let m = 2; // After transpose: original 2 columns become 2 rows
        let k = 3; // After transpose: original 3 rows become 3 cols
        let mr = 4;
        let lda = 2;

        let mut packed = vec![0.0f64; packed_a_size(m, k, mr)];

        unsafe {
            pack_a(
                m,
                k,
                a.as_ptr(),
                lda,
                Layout::RowMajor,
                Transpose::Trans,
                packed.as_mut_ptr(),
                mr,
            );
        }

        // A^T = [[1, 3, 5], [2, 4, 6]]
        assert_eq!(packed[0], 1.0); // a^T[0,0]
        assert_eq!(packed[1], 2.0); // a^T[1,0]
        assert_eq!(packed[4], 3.0); // a^T[0,1]
        assert_eq!(packed[5], 4.0); // a^T[1,1]
    }

    #[test]
    fn test_pack_b_with_transpose() {
        // Test packing B with transpose
        let b: [f64; 6] = [
            1.0, 2.0, 3.0, // row 0
            4.0, 5.0, 6.0, // row 1
        ];
        let k = 3; // After transpose: original 3 cols become 3 rows
        let n = 2; // After transpose: original 2 rows become 2 cols
        let nr = 4;
        let ldb = 3;

        let mut packed = vec![0.0f64; packed_b_size(k, n, nr)];

        unsafe {
            pack_b(
                k,
                n,
                b.as_ptr(),
                ldb,
                Layout::RowMajor,
                Transpose::Trans,
                packed.as_mut_ptr(),
                nr,
            );
        }

        // B^T = [[1, 4], [2, 5], [3, 6]]
        assert_eq!(packed[0], 1.0); // b^T[0,0]
        assert_eq!(packed[1], 4.0); // b^T[0,1]
        assert_eq!(packed[4], 2.0); // b^T[1,0]
        assert_eq!(packed[5], 5.0); // b^T[1,1]
    }

    #[test]
    fn test_pack_a_exact_mr() {
        // Test when m is exactly divisible by mr (no remainder path)
        let a: [f64; 12] = [
            1.0, 2.0, 3.0, // row 0
            4.0, 5.0, 6.0, // row 1
            7.0, 8.0, 9.0, // row 2
            10.0, 11.0, 12.0, // row 3
        ];
        let m = 4;
        let k = 3;
        let mr = 4;
        let lda = 3;

        let mut packed = vec![0.0f64; packed_a_size(m, k, mr)];

        unsafe {
            pack_a(
                m,
                k,
                a.as_ptr(),
                lda,
                Layout::RowMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                mr,
            );
        }

        // No padding needed
        assert_eq!(packed[0], 1.0);
        assert_eq!(packed[1], 4.0);
        assert_eq!(packed[2], 7.0);
        assert_eq!(packed[3], 10.0);
    }

    #[test]
    fn test_pack_b_exact_nr() {
        // Test when n is exactly divisible by nr (no remainder path)
        let b: [f64; 12] = [
            1.0, 2.0, 3.0, 4.0, // row 0
            5.0, 6.0, 7.0, 8.0, // row 1
            9.0, 10.0, 11.0, 12.0, // row 2
        ];
        let k = 3;
        let n = 4;
        let nr = 4;
        let ldb = 4;

        let mut packed = vec![0.0f64; packed_b_size(k, n, nr)];

        unsafe {
            pack_b(
                k,
                n,
                b.as_ptr(),
                ldb,
                Layout::RowMajor,
                Transpose::NoTrans,
                packed.as_mut_ptr(),
                nr,
            );
        }

        // No padding needed
        assert_eq!(packed[0], 1.0);
        assert_eq!(packed[1], 2.0);
        assert_eq!(packed[2], 3.0);
        assert_eq!(packed[3], 4.0);
    }

    #[test]
    fn test_packed_a_size() {
        // Exact multiple of mr
        assert_eq!(packed_a_size(8, 10, 4), 8 * 10);
        // Needs padding: m=5, mr=4 -> m_padded=8
        assert_eq!(packed_a_size(5, 10, 4), 8 * 10);
        // m=1, mr=4 -> m_padded=4
        assert_eq!(packed_a_size(1, 10, 4), 4 * 10);
    }

    #[test]
    fn test_packed_b_size() {
        // Exact multiple of nr
        assert_eq!(packed_b_size(10, 8, 4), 10 * 8);
        // Needs padding: n=5, nr=4 -> n_padded=8
        assert_eq!(packed_b_size(10, 5, 4), 10 * 8);
        // n=1, nr=4 -> n_padded=4
        assert_eq!(packed_b_size(10, 1, 4), 10 * 4);
    }

    #[test]
    fn test_layout_debug() {
        assert_eq!(format!("{:?}", Layout::RowMajor), "RowMajor");
        assert_eq!(format!("{:?}", Layout::ColMajor), "ColMajor");
    }

    #[test]
    fn test_layout_clone_eq() {
        let l1 = Layout::RowMajor;
        let l2 = l1;
        assert_eq!(l1, l2);
        assert_ne!(l1, Layout::ColMajor);
    }

    #[test]
    fn test_transpose_debug() {
        assert_eq!(format!("{:?}", Transpose::NoTrans), "NoTrans");
        assert_eq!(format!("{:?}", Transpose::Trans), "Trans");
    }

    #[test]
    fn test_transpose_clone_eq() {
        let t1 = Transpose::Trans;
        let t2 = t1;
        assert_eq!(t1, t2);
        assert_ne!(t1, Transpose::NoTrans);
    }
}