1use crate::types::TropicalScalar;
2
/// Storage order of a dense matrix addressed with a leading dimension `ld`.
///
/// The packing routines in this module locate element `(r, c)` of a stored
/// matrix as follows:
/// - [`Layout::RowMajor`]: at offset `r * ld + c` (rows are contiguous).
/// - [`Layout::ColMajor`]: at offset `c * ld + r` (columns are contiguous).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Layout {
    /// Element `(r, c)` lives at offset `r * ld + c`.
    RowMajor,
    /// Element `(r, c)` lives at offset `c * ld + r`.
    ColMajor,
}
11
/// Whether an operand is used as stored or as the transpose of what is stored.
///
/// With [`Transpose::Trans`], element `(r, c)` of the logical operand `op(X)`
/// is read from element `(c, r)` of the stored matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// `op(X) = X`: logical element `(r, c)` is stored element `(r, c)`.
    NoTrans,
    /// `op(X) = X^T`: logical element `(r, c)` is stored element `(c, r)`.
    Trans,
}
20
21pub unsafe fn pack_a<T: TropicalScalar>(
47 m: usize,
48 k: usize,
49 a: *const T,
50 lda: usize,
51 layout: Layout,
52 trans: Transpose,
53 packed: *mut T,
54 mr: usize,
55) {
56 let zero = T::scalar_zero();
57
58 let mut packed_idx = 0;
59
60 let m_blocks = m / mr;
62 let m_rem = m % mr;
63
64 for block in 0..m_blocks {
65 let row_start = block * mr;
66 for col in 0..k {
67 for row_offset in 0..mr {
68 let row = row_start + row_offset;
69 let val = get_element(a, row, col, lda, layout, trans);
70 *packed.add(packed_idx) = val;
71 packed_idx += 1;
72 }
73 }
74 }
75
76 if m_rem > 0 {
78 let row_start = m_blocks * mr;
79 for col in 0..k {
80 for row_offset in 0..mr {
81 let row = row_start + row_offset;
82 let val = if row < m {
83 get_element(a, row, col, lda, layout, trans)
84 } else {
85 zero
86 };
87 *packed.add(packed_idx) = val;
88 packed_idx += 1;
89 }
90 }
91 }
92}
93
94pub unsafe fn pack_b<T: TropicalScalar>(
116 k: usize,
117 n: usize,
118 b: *const T,
119 ldb: usize,
120 layout: Layout,
121 trans: Transpose,
122 packed: *mut T,
123 nr: usize,
124) {
125 let zero = T::scalar_zero();
126
127 let mut packed_idx = 0;
128
129 let n_blocks = n / nr;
131 let n_rem = n % nr;
132
133 for block in 0..n_blocks {
134 let col_start = block * nr;
135 for row in 0..k {
136 for col_offset in 0..nr {
137 let col = col_start + col_offset;
138 let val = get_element(b, row, col, ldb, layout, trans);
139 *packed.add(packed_idx) = val;
140 packed_idx += 1;
141 }
142 }
143 }
144
145 if n_rem > 0 {
147 let col_start = n_blocks * nr;
148 for row in 0..k {
149 for col_offset in 0..nr {
150 let col = col_start + col_offset;
151 let val = if col < n {
152 get_element(b, row, col, ldb, layout, trans)
153 } else {
154 zero
155 };
156 *packed.add(packed_idx) = val;
157 packed_idx += 1;
158 }
159 }
160 }
161}
162
163#[inline(always)]
165unsafe fn get_element<T: Copy>(
166 ptr: *const T,
167 row: usize,
168 col: usize,
169 ld: usize,
170 layout: Layout,
171 trans: Transpose,
172) -> T {
173 let (actual_row, actual_col) = match trans {
174 Transpose::NoTrans => (row, col),
175 Transpose::Trans => (col, row),
176 };
177
178 let idx = match layout {
179 Layout::RowMajor => actual_row * ld + actual_col,
180 Layout::ColMajor => actual_col * ld + actual_row,
181 };
182
183 *ptr.add(idx)
184}
185
/// Number of elements `pack_a` writes for an `m x k` operand and panel height `mr`.
///
/// `m` is rounded up to the next multiple of `mr` (to account for padding
/// rows) and multiplied by `k`.
///
/// # Panics
///
/// Panics if `mr` is zero.
#[inline]
pub fn packed_a_size(m: usize, k: usize, mr: usize) -> usize {
    k * m.next_multiple_of(mr)
}
192
/// Number of elements `pack_b` writes for a `k x n` operand and panel width `nr`.
///
/// `n` is rounded up to the next multiple of `nr` (to account for padding
/// columns) and multiplied by `k`.
///
/// # Panics
///
/// Panics if `nr` is zero.
#[inline]
pub fn packed_b_size(k: usize, n: usize, nr: usize) -> usize {
    n.next_multiple_of(nr) * k
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Pre-fill value for packed buffers. Any slot a packer forgets to write
    /// keeps this value and makes the full-buffer comparisons below fail.
    const UNWRITTEN: f64 = -12345.0;

    /// The value the packers write into padding slots.
    fn pad() -> f64 {
        <f64 as TropicalScalar>::scalar_zero()
    }

    /// Packs the `m x k` operand `op(A)` into a fresh, sentinel-filled buffer.
    fn run_pack_a(
        a: &[f64],
        m: usize,
        k: usize,
        lda: usize,
        layout: Layout,
        trans: Transpose,
        mr: usize,
    ) -> Vec<f64> {
        let mut packed = vec![UNWRITTEN; packed_a_size(m, k, mr)];
        // SAFETY: every test passes a slice holding all of op(A) for the given
        // layout/lda, and `packed` has exactly `packed_a_size(m, k, mr)` slots.
        unsafe {
            pack_a(m, k, a.as_ptr(), lda, layout, trans, packed.as_mut_ptr(), mr);
        }
        packed
    }

    /// Packs the `k x n` operand `op(B)` into a fresh, sentinel-filled buffer.
    fn run_pack_b(
        b: &[f64],
        k: usize,
        n: usize,
        ldb: usize,
        layout: Layout,
        trans: Transpose,
        nr: usize,
    ) -> Vec<f64> {
        let mut packed = vec![UNWRITTEN; packed_b_size(k, n, nr)];
        // SAFETY: every test passes a slice holding all of op(B) for the given
        // layout/ldb, and `packed` has exactly `packed_b_size(k, n, nr)` slots.
        unsafe {
            pack_b(k, n, b.as_ptr(), ldb, layout, trans, packed.as_mut_ptr(), nr);
        }
        packed
    }

    #[test]
    fn test_pack_a_row_major() {
        // A = [[1, 2, 3], [4, 5, 6]] stored row-major (lda = 3).
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = pad();
        let packed = run_pack_a(&a, 2, 3, 3, Layout::RowMajor, Transpose::NoTrans, 4);
        // One panel of height 4: each column of A padded from 2 to 4 rows.
        assert_eq!(packed, vec![1.0, 4.0, p, p, 2.0, 5.0, p, p, 3.0, 6.0, p, p]);
    }

    #[test]
    fn test_pack_a_col_major() {
        // Same A as the row-major test, stored column-major (lda = 2).
        let a = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let p = pad();
        let packed = run_pack_a(&a, 2, 3, 2, Layout::ColMajor, Transpose::NoTrans, 4);
        assert_eq!(packed, vec![1.0, 4.0, p, p, 2.0, 5.0, p, p, 3.0, 6.0, p, p]);
    }

    #[test]
    fn test_pack_b_row_major() {
        // B = [[1, 2], [3, 4], [5, 6]] stored row-major (ldb = 2).
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = pad();
        let packed = run_pack_b(&b, 3, 2, 2, Layout::RowMajor, Transpose::NoTrans, 4);
        // One panel of width 4: each row of B padded from 2 to 4 columns.
        assert_eq!(packed, vec![1.0, 2.0, p, p, 3.0, 4.0, p, p, 5.0, 6.0, p, p]);
    }

    #[test]
    fn test_pack_b_col_major() {
        // Same B as the row-major test, stored column-major (ldb = 3).
        let b = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
        let p = pad();
        let packed = run_pack_b(&b, 3, 2, 3, Layout::ColMajor, Transpose::NoTrans, 4);
        assert_eq!(packed, vec![1.0, 2.0, p, p, 3.0, 4.0, p, p, 5.0, 6.0, p, p]);
    }

    #[test]
    fn test_pack_a_with_transpose() {
        // Stored 3x2 row-major [[1, 2], [3, 4], [5, 6]] (lda = 2);
        // op(A) = [[1, 3, 5], [2, 4, 6]].
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = pad();
        let packed = run_pack_a(&a, 2, 3, 2, Layout::RowMajor, Transpose::Trans, 4);
        assert_eq!(packed, vec![1.0, 2.0, p, p, 3.0, 4.0, p, p, 5.0, 6.0, p, p]);
    }

    #[test]
    fn test_pack_b_with_transpose() {
        // Stored 2x3 row-major [[1, 2, 3], [4, 5, 6]] (ldb = 3);
        // op(B) = [[1, 4], [2, 5], [3, 6]].
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = pad();
        let packed = run_pack_b(&b, 3, 2, 3, Layout::RowMajor, Transpose::Trans, 4);
        assert_eq!(packed, vec![1.0, 4.0, p, p, 2.0, 5.0, p, p, 3.0, 6.0, p, p]);
    }

    #[test]
    fn test_pack_a_col_major_with_transpose() {
        // ColMajor + Trans addresses op(A)(r, c) at r * lda + c, so with lda = 2
        // op(A) = [[1, 2], [3, 4], [5, 6]].
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = pad();
        let packed = run_pack_a(&a, 3, 2, 2, Layout::ColMajor, Transpose::Trans, 4);
        assert_eq!(packed, vec![1.0, 3.0, 5.0, p, 2.0, 4.0, 6.0, p]);
    }

    #[test]
    fn test_pack_a_exact_mr() {
        // A is 4x3 row-major; m == mr, so there is no padding at all.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let packed = run_pack_a(&a, 4, 3, 3, Layout::RowMajor, Transpose::NoTrans, 4);
        assert_eq!(
            packed,
            vec![1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0]
        );
    }

    #[test]
    fn test_pack_b_exact_nr() {
        // B is 3x4 row-major; n == nr, so the packed form equals the row-major data.
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let packed = run_pack_b(&b, 3, 4, 4, Layout::RowMajor, Transpose::NoTrans, 4);
        assert_eq!(packed, b.to_vec());
    }

    #[test]
    fn test_pack_a_multiple_panels_with_remainder() {
        // A is 5x2 row-major, mr = 2: two full panels plus a one-row panel.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let p = pad();
        let packed = run_pack_a(&a, 5, 2, 2, Layout::RowMajor, Transpose::NoTrans, 2);
        assert_eq!(
            packed,
            vec![1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0, 9.0, p, 10.0, p]
        );
    }

    #[test]
    fn test_pack_b_multiple_panels_with_remainder() {
        // B is 2x3 row-major, nr = 2: one full panel plus a one-column panel.
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let p = pad();
        let packed = run_pack_b(&b, 2, 3, 3, Layout::RowMajor, Transpose::NoTrans, 2);
        assert_eq!(packed, vec![1.0, 2.0, 4.0, 5.0, 3.0, p, 6.0, p]);
    }

    #[test]
    fn test_pack_respects_leading_dimension() {
        // 2x2 matrix [[1, 2], [3, 4]] stored row-major with a spare column (ld = 3).
        let x = [1.0, 2.0, 99.0, 3.0, 4.0, 99.0];
        let packed_a = run_pack_a(&x, 2, 2, 3, Layout::RowMajor, Transpose::NoTrans, 2);
        assert_eq!(packed_a, vec![1.0, 3.0, 2.0, 4.0]);
        let packed_b = run_pack_b(&x, 2, 2, 3, Layout::RowMajor, Transpose::NoTrans, 2);
        assert_eq!(packed_b, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_pack_empty() {
        // Zero rows / columns need no storage and must not read the source.
        assert!(run_pack_a(&[], 0, 3, 3, Layout::RowMajor, Transpose::NoTrans, 4).is_empty());
        assert!(run_pack_b(&[], 3, 0, 3, Layout::RowMajor, Transpose::NoTrans, 4).is_empty());
    }

    #[test]
    fn test_packed_a_size() {
        assert_eq!(packed_a_size(8, 10, 4), 8 * 10);
        assert_eq!(packed_a_size(5, 10, 4), 8 * 10);
        assert_eq!(packed_a_size(1, 10, 4), 4 * 10);
        assert_eq!(packed_a_size(0, 10, 4), 0);
    }

    #[test]
    fn test_packed_b_size() {
        assert_eq!(packed_b_size(10, 8, 4), 10 * 8);
        assert_eq!(packed_b_size(10, 5, 4), 10 * 8);
        assert_eq!(packed_b_size(10, 1, 4), 10 * 4);
        assert_eq!(packed_b_size(10, 0, 4), 0);
    }

    #[test]
    fn test_layout_debug() {
        assert_eq!(format!("{:?}", Layout::RowMajor), "RowMajor");
        assert_eq!(format!("{:?}", Layout::ColMajor), "ColMajor");
    }

    #[test]
    fn test_layout_clone_eq() {
        let l1 = Layout::RowMajor;
        let l2 = l1;
        assert_eq!(l1, l2);
        assert_ne!(l1, Layout::ColMajor);
    }

    #[test]
    fn test_transpose_debug() {
        assert_eq!(format!("{:?}", Transpose::NoTrans), "NoTrans");
        assert_eq!(format!("{:?}", Transpose::Trans), "Trans");
    }

    #[test]
    fn test_transpose_clone_eq() {
        let t1 = Transpose::Trans;
        let t2 = t1;
        assert_eq!(t1, t2);
        assert_ne!(t1, Transpose::NoTrans);
    }
}