1mod contract;
4
5use super::traits::{Backend, BackendScalar, Storage};
6use crate::algebra::{Algebra, Scalar, Standard};
7use std::any::TypeId;
8
/// Reference CPU backend.
///
/// Dispatches `Standard<f32>`/`Standard<f64>` GEMMs to `faer`, tropical
/// GEMMs to the `tropical_gemm` crate when the `tropical-kernels` feature
/// is enabled, and falls back to a generic semiring triple loop otherwise.
#[derive(Clone, Debug, Default)]
pub struct Cpu;
12
13impl Cpu {
14 pub(crate) fn gemm_internal<A: Algebra>(
22 &self,
23 a: &[A::Scalar],
24 m: usize,
25 k: usize,
26 b: &[A::Scalar],
27 n: usize,
28 ) -> Vec<A::Scalar> {
29 if TypeId::of::<A>() == TypeId::of::<Standard<f32>>() {
31 let a_f32: &[f32] = unsafe { std::mem::transmute(a) };
33 let b_f32: &[f32] = unsafe { std::mem::transmute(b) };
34 let result = faer_gemm_f32(a_f32, m, k, b_f32, n);
35 return unsafe { std::mem::transmute::<Vec<f32>, Vec<A::Scalar>>(result) };
36 }
37 if TypeId::of::<A>() == TypeId::of::<Standard<f64>>() {
38 let a_f64: &[f64] = unsafe { std::mem::transmute(a) };
39 let b_f64: &[f64] = unsafe { std::mem::transmute(b) };
40 let result = faer_gemm_f64(a_f64, m, k, b_f64, n);
41 return unsafe { std::mem::transmute::<Vec<f64>, Vec<A::Scalar>>(result) };
42 }
43
44 #[cfg(feature = "tropical-kernels")]
46 {
47 if let Some(result) = try_tropical_gemm::<A>(a, m, k, b, n) {
48 return result;
49 }
50 }
51
52 generic_gemm::<A>(a, m, k, b, n)
54 }
55
56 pub(crate) fn gemm_with_argmax_internal<A: Algebra<Index = u32>>(
61 &self,
62 a: &[A::Scalar],
63 m: usize,
64 k: usize,
65 b: &[A::Scalar],
66 n: usize,
67 ) -> (Vec<A::Scalar>, Vec<u32>) {
68 #[cfg(feature = "tropical-kernels")]
70 {
71 if let Some(result) = try_tropical_gemm_with_argmax::<A>(a, m, k, b, n) {
72 return result;
73 }
74 }
75
76 generic_gemm_with_argmax::<A>(a, m, k, b, n)
78 }
79
80 #[allow(dead_code)]
83 pub(crate) fn gemm_backward_a_internal<A: Algebra>(
84 &self,
85 grad_c: &[A::Scalar],
86 argmax: &[u32],
87 _b: &[A::Scalar],
88 m: usize,
89 k: usize,
90 n: usize,
91 ) -> Vec<A::Scalar> {
92 let mut grad_a = vec![A::Scalar::default(); m * k];
93
94 if A::needs_argmax() {
98 for j in 0..n {
99 for i in 0..m {
100 let idx = argmax[j * m + i] as usize; grad_a[idx * m + i] += grad_c[j * m + i];
103 }
104 }
105 }
106
107 grad_a
108 }
109
110 #[allow(dead_code)]
113 pub(crate) fn gemm_backward_b_internal<A: Algebra>(
114 &self,
115 grad_c: &[A::Scalar],
116 argmax: &[u32],
117 _a: &[A::Scalar],
118 m: usize,
119 k: usize,
120 n: usize,
121 ) -> Vec<A::Scalar> {
122 let mut grad_b = vec![A::Scalar::default(); k * n];
123
124 if A::needs_argmax() {
126 for j in 0..n {
127 for i in 0..m {
128 let idx = argmax[j * m + i] as usize; grad_b[j * k + idx] += grad_c[j * m + i];
131 }
132 }
133 }
134
135 grad_b
136 }
137
138 pub(crate) fn gemm_batched_internal<A: Algebra>(
140 &self,
141 a: &[A::Scalar],
142 batch_size: usize,
143 m: usize,
144 k: usize,
145 b: &[A::Scalar],
146 n: usize,
147 ) -> Vec<A::Scalar> {
148 let a_batch_stride = m * k;
149 let b_batch_stride = k * n;
150 let c_batch_stride = m * n;
151
152 let mut c = vec![A::zero().to_scalar(); batch_size * m * n];
153
154 for batch in 0..batch_size {
155 let a_offset = batch * a_batch_stride;
156 let b_offset = batch * b_batch_stride;
157 let c_offset = batch * c_batch_stride;
158
159 let a_slice = &a[a_offset..a_offset + a_batch_stride];
160 let b_slice = &b[b_offset..b_offset + b_batch_stride];
161
162 let c_batch = generic_gemm::<A>(a_slice, m, k, b_slice, n);
163 c[c_offset..c_offset + c_batch_stride].copy_from_slice(&c_batch);
164 }
165
166 c
167 }
168
169 pub(crate) fn gemm_batched_with_argmax_internal<A: Algebra<Index = u32>>(
171 &self,
172 a: &[A::Scalar],
173 batch_size: usize,
174 m: usize,
175 k: usize,
176 b: &[A::Scalar],
177 n: usize,
178 ) -> (Vec<A::Scalar>, Vec<u32>) {
179 let a_batch_stride = m * k;
180 let b_batch_stride = k * n;
181 let c_batch_stride = m * n;
182
183 let mut c = vec![A::zero().to_scalar(); batch_size * m * n];
184 let mut argmax = vec![0u32; batch_size * m * n];
185
186 for batch in 0..batch_size {
187 let a_offset = batch * a_batch_stride;
188 let b_offset = batch * b_batch_stride;
189 let c_offset = batch * c_batch_stride;
190
191 let a_slice = &a[a_offset..a_offset + a_batch_stride];
192 let b_slice = &b[b_offset..b_offset + b_batch_stride];
193
194 let (c_batch, argmax_batch) = generic_gemm_with_argmax::<A>(a_slice, m, k, b_slice, n);
195 c[c_offset..c_offset + c_batch_stride].copy_from_slice(&c_batch);
196 argmax[c_offset..c_offset + c_batch_stride].copy_from_slice(&argmax_batch);
197 }
198
199 (c, argmax)
200 }
201}
202
203impl<T: Scalar> Storage<T> for Vec<T> {
204 #[inline]
205 fn len(&self) -> usize {
206 Vec::len(self)
207 }
208
209 #[inline]
210 fn get(&self, index: usize) -> T {
211 self[index]
212 }
213
214 #[inline]
215 fn set(&mut self, index: usize, value: T) {
216 self[index] = value;
217 }
218
219 #[inline]
220 fn to_vec(&self) -> Vec<T> {
221 self.clone()
222 }
223
224 #[inline]
225 fn from_slice(data: &[T]) -> Self {
226 data.to_vec()
227 }
228
229 #[inline]
230 fn zeros(len: usize) -> Self {
231 vec![T::default(); len]
232 }
233}
234
235impl Backend for Cpu {
236 type Storage<T: Scalar> = Vec<T>;
237
238 fn name() -> &'static str {
239 "cpu"
240 }
241
242 fn synchronize(&self) {
243 }
245
246 fn alloc<T: Scalar>(&self, len: usize) -> Vec<T> {
247 vec![T::default(); len]
248 }
249
250 fn from_slice<T: Scalar>(&self, data: &[T]) -> Vec<T> {
251 data.to_vec()
252 }
253
254 fn contract<A: Algebra>(
255 &self,
256 a: &Self::Storage<A::Scalar>,
257 shape_a: &[usize],
258 strides_a: &[usize],
259 modes_a: &[i32],
260 b: &Self::Storage<A::Scalar>,
261 shape_b: &[usize],
262 strides_b: &[usize],
263 modes_b: &[i32],
264 shape_c: &[usize],
265 modes_c: &[i32],
266 ) -> Self::Storage<A::Scalar>
267 where
268 A::Scalar: BackendScalar<Self>,
269 {
270 contract::contract::<A>(
271 self, a, shape_a, strides_a, modes_a,
272 b, shape_b, strides_b, modes_b,
273 shape_c, modes_c,
274 )
275 }
276
277 fn contract_with_argmax<A: Algebra<Index = u32>>(
278 &self,
279 a: &Self::Storage<A::Scalar>,
280 shape_a: &[usize],
281 strides_a: &[usize],
282 modes_a: &[i32],
283 b: &Self::Storage<A::Scalar>,
284 shape_b: &[usize],
285 strides_b: &[usize],
286 modes_b: &[i32],
287 shape_c: &[usize],
288 modes_c: &[i32],
289 ) -> (Self::Storage<A::Scalar>, Self::Storage<u32>)
290 where
291 A::Scalar: BackendScalar<Self>,
292 {
293 contract::contract_with_argmax::<A>(
294 self, a, shape_a, strides_a, modes_a,
295 b, shape_b, strides_b, modes_b,
296 shape_c, modes_c,
297 )
298 }
299
300 fn copy_strided<T: Scalar>(
301 &self,
302 src: &Vec<T>,
303 shape: &[usize],
304 strides: &[usize],
305 offset: usize,
306 ) -> Vec<T> {
307 let numel: usize = shape.iter().product();
308 let mut dst = vec![T::default(); numel];
309
310 let mut indices = vec![0usize; shape.len()];
312 for dst_elem in dst.iter_mut() {
313 let src_offset: usize = offset
315 + indices
316 .iter()
317 .zip(strides.iter())
318 .map(|(i, s)| i * s)
319 .sum::<usize>();
320
321 *dst_elem = src[src_offset];
322
323 for dim in 0..shape.len() {
325 indices[dim] += 1;
326 if indices[dim] < shape[dim] {
327 break;
328 }
329 indices[dim] = 0;
330 }
331 }
332
333 dst
334 }
335}
336
337fn faer_gemm_f32(a: &[f32], m: usize, k: usize, b: &[f32], n: usize) -> Vec<f32> {
341 use faer::Mat;
342
343 let a_mat = Mat::from_fn(m, k, |i, j| a[j * m + i]);
346 let b_mat = Mat::from_fn(k, n, |i, j| b[j * k + i]);
347
348 let c_mat = &a_mat * &b_mat;
350
351 let mut c = vec![0.0f32; m * n];
353 for j in 0..n {
354 for i in 0..m {
355 c[j * m + i] = c_mat[(i, j)];
356 }
357 }
358 c
359}
360
361fn faer_gemm_f64(a: &[f64], m: usize, k: usize, b: &[f64], n: usize) -> Vec<f64> {
363 use faer::Mat;
364
365 let a_mat = Mat::from_fn(m, k, |i, j| a[j * m + i]);
366 let b_mat = Mat::from_fn(k, n, |i, j| b[j * k + i]);
367
368 let c_mat = &a_mat * &b_mat;
369
370 let mut c = vec![0.0f64; m * n];
371 for j in 0..n {
372 for i in 0..m {
373 c[j * m + i] = c_mat[(i, j)];
374 }
375 }
376 c
377}
378
379fn generic_gemm<A: Algebra>(
381 a: &[A::Scalar],
382 m: usize,
383 k: usize,
384 b: &[A::Scalar],
385 n: usize,
386) -> Vec<A::Scalar> {
387 let mut c = vec![A::zero().to_scalar(); m * n];
388
389 for j in 0..n {
391 for i in 0..m {
392 let mut acc = A::zero();
393 for kk in 0..k {
394 let a_val = A::from_scalar(a[kk * m + i]); let b_val = A::from_scalar(b[j * k + kk]); let prod = a_val.mul(b_val);
397 acc = acc.add(prod);
398 }
399 c[j * m + i] = acc.to_scalar();
400 }
401 }
402
403 c
404}
405
406fn generic_gemm_with_argmax<A: Algebra<Index = u32>>(
408 a: &[A::Scalar],
409 m: usize,
410 k: usize,
411 b: &[A::Scalar],
412 n: usize,
413) -> (Vec<A::Scalar>, Vec<u32>) {
414 let mut c = vec![A::zero().to_scalar(); m * n];
415 let mut argmax = vec![0u32; m * n];
416
417 for j in 0..n {
419 for i in 0..m {
420 let mut acc = A::zero();
421 let mut best_k = 0u32;
422
423 for kk in 0..k {
424 let a_val = A::from_scalar(a[kk * m + i]); let b_val = A::from_scalar(b[j * k + kk]); let prod = a_val.mul(b_val);
427 let (new_acc, winner) = acc.add_with_argmax(best_k, prod, kk as u32);
428 acc = new_acc;
429 best_k = winner;
430 }
431
432 c[j * m + i] = acc.to_scalar();
433 argmax[j * m + i] = best_k;
434 }
435 }
436
437 (c, argmax)
438}
439
/// Optimized tropical GEMM dispatch.
///
/// If `A` is one of the concrete tropical algebras with a dedicated kernel
/// in the `tropical_gemm` crate, run that kernel and return its result;
/// otherwise return `None` so the caller can fall back to the generic loop.
/// Operands are column-major (`a`: m x k, `b`: k x n).
#[cfg(feature = "tropical-kernels")]
fn try_tropical_gemm<A: Algebra>(
    a: &[A::Scalar],
    m: usize,
    k: usize,
    b: &[A::Scalar],
    n: usize,
) -> Option<Vec<A::Scalar>> {
    use crate::algebra::{MaxMul, MaxPlus, MinPlus};
    use std::any::TypeId;
    use tropical_gemm::{
        tropical_matmul, TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring,
    };

    // One arm per supported (algebra, kernel, primitive) triple; the six
    // original branches were identical except for these three types.
    macro_rules! kernel {
        ($alg:ty, $trop:ty, $prim:ty) => {
            if TypeId::of::<A>() == TypeId::of::<$alg>() {
                // SAFETY: the TypeId check proves `A == $alg`, whose scalar
                // is `$prim`; the transmutes only rename the element type
                // and do not change layout.
                let a_prim: &[$prim] = unsafe { std::mem::transmute(a) };
                let b_prim: &[$prim] = unsafe { std::mem::transmute(b) };

                let result: Vec<$trop> = tropical_matmul::<$trop>(a_prim, m, k, b_prim, n);
                let scalars: Vec<$prim> = result.into_iter().map(|x| x.value()).collect();

                // SAFETY: `A::Scalar == $prim` as established above.
                return Some(unsafe { std::mem::transmute(scalars) });
            }
        };
    }

    kernel!(MaxPlus<f32>, TropicalMaxPlus<f32>, f32);
    kernel!(MaxPlus<f64>, TropicalMaxPlus<f64>, f64);
    kernel!(MinPlus<f32>, TropicalMinPlus<f32>, f32);
    kernel!(MinPlus<f64>, TropicalMinPlus<f64>, f64);
    kernel!(MaxMul<f32>, TropicalMaxMul<f32>, f32);
    kernel!(MaxMul<f64>, TropicalMaxMul<f64>, f64);

    None
}
522
/// Optimized tropical GEMM-with-argmax dispatch.
///
/// Mirrors [`try_tropical_gemm`]: runs the dedicated `tropical_gemm`
/// argmax kernel when `A` is a supported concrete tropical algebra, and
/// returns `None` otherwise so the generic loop can take over.
#[cfg(feature = "tropical-kernels")]
fn try_tropical_gemm_with_argmax<A: Algebra<Index = u32>>(
    a: &[A::Scalar],
    m: usize,
    k: usize,
    b: &[A::Scalar],
    n: usize,
) -> Option<(Vec<A::Scalar>, Vec<u32>)> {
    use crate::algebra::{MaxMul, MaxPlus, MinPlus};
    use std::any::TypeId;
    use tropical_gemm::{
        tropical_matmul_with_argmax, TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus,
        TropicalSemiring,
    };

    // One arm per supported (algebra, kernel, primitive) triple; the six
    // original branches were identical except for these three types.
    macro_rules! kernel {
        ($alg:ty, $trop:ty, $prim:ty) => {
            if TypeId::of::<A>() == TypeId::of::<$alg>() {
                // SAFETY: the TypeId check proves `A == $alg`, whose scalar
                // is `$prim`; the transmutes only rename the element type
                // and do not change layout.
                let a_prim: &[$prim] = unsafe { std::mem::transmute(a) };
                let b_prim: &[$prim] = unsafe { std::mem::transmute(b) };

                let result = tropical_matmul_with_argmax::<$trop>(a_prim, m, k, b_prim, n);

                // Gather into column-major buffers.
                // NOTE(review): accessors are called as (j, i) = (col, row);
                // confirm this matches tropical_gemm's accessor convention.
                let mut scalars: Vec<$prim> = Vec::with_capacity(m * n);
                let mut argmax = Vec::with_capacity(m * n);
                for j in 0..n {
                    for i in 0..m {
                        scalars.push(result.get(j, i).value());
                        argmax.push(result.get_argmax(j, i));
                    }
                }

                // SAFETY: `A::Scalar == $prim` as established above.
                return Some((unsafe { std::mem::transmute(scalars) }, argmax));
            }
        };
    }

    kernel!(MaxPlus<f32>, TropicalMaxPlus<f32>, f32);
    kernel!(MaxPlus<f64>, TropicalMaxPlus<f64>, f64);
    kernel!(MinPlus<f32>, TropicalMinPlus<f32>, f32);
    kernel!(MinPlus<f64>, TropicalMinPlus<f64>, f64);
    kernel!(MaxMul<f32>, TropicalMaxMul<f32>, f32);
    kernel!(MaxMul<f64>, TropicalMaxMul<f64>, f64);

    None
}
647
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Standard;

    #[cfg(feature = "tropical")]
    use crate::algebra::MaxPlus;

    #[test]
    fn test_cpu_gemm_standard() {
        let cpu = Cpu;
        // Column-major 2x2: both operands are [[1, 3], [2, 4]].
        let lhs = vec![1.0f32, 2.0, 3.0, 4.0];
        let rhs = vec![1.0f32, 2.0, 3.0, 4.0];

        let out = cpu.gemm_internal::<Standard<f32>>(&lhs, 2, 2, &rhs, 2);

        assert_eq!(out, vec![7.0, 10.0, 15.0, 22.0]);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_cpu_gemm_maxplus() {
        let cpu = Cpu;
        let lhs = vec![1.0f32, 2.0, 3.0, 4.0];
        let rhs = vec![1.0f32, 2.0, 3.0, 4.0];

        // MaxPlus accumulates max over sums instead of sum over products.
        let out = cpu.gemm_internal::<MaxPlus<f32>>(&lhs, 2, 2, &rhs, 2);

        assert_eq!(out, vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_cpu_gemm_with_argmax() {
        let cpu = Cpu;
        let lhs = vec![1.0f32, 2.0, 3.0, 4.0];
        let rhs = vec![1.0f32, 2.0, 3.0, 4.0];

        let (out, winners) = cpu.gemm_with_argmax_internal::<MaxPlus<f32>>(&lhs, 2, 2, &rhs, 2);

        assert_eq!(out, vec![5.0, 6.0, 7.0, 8.0]);
        // With this input the second k index wins every max.
        assert_eq!(winners, vec![1, 1, 1, 1]);
    }

    #[test]
    fn test_copy_strided() {
        let cpu = Cpu;
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Shape [3, 2] with strides [2, 1] reads a transposed view of the
        // contiguous layout.
        let dst = cpu.copy_strided(&src, &[3, 2], &[2, 1], 0);

        assert_eq!(dst, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_optimized_maxplus() {
        use crate::algebra::MaxPlus;

        let cpu = Cpu;
        let (m, k, n) = (64, 64, 64);

        let lhs: Vec<f32> = (0..m * k).map(|i| (i % 100) as f32).collect();
        let rhs: Vec<f32> = (0..k * n).map(|i| (i % 100) as f32).collect();

        // The optimized kernel must agree with the generic reference loop.
        let fast = cpu.gemm_internal::<MaxPlus<f32>>(&lhs, m, k, &rhs, n);
        let reference = generic_gemm::<MaxPlus<f32>>(&lhs, m, k, &rhs, n);

        for (i, (opt, gen)) in fast.iter().zip(reference.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-6,
                "MaxPlus mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }
    }

    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_optimized_minplus() {
        use crate::algebra::MinPlus;

        let cpu = Cpu;
        let (m, k, n) = (32, 32, 32);

        let lhs: Vec<f32> = (0..m * k).map(|i| (i % 50) as f32).collect();
        let rhs: Vec<f32> = (0..k * n).map(|i| (i % 50) as f32).collect();

        // The optimized kernel must agree with the generic reference loop.
        let fast = cpu.gemm_internal::<MinPlus<f32>>(&lhs, m, k, &rhs, n);
        let reference = generic_gemm::<MinPlus<f32>>(&lhs, m, k, &rhs, n);

        for (i, (opt, gen)) in fast.iter().zip(reference.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-6,
                "MinPlus mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }
    }

    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_optimized_maxmul() {
        use crate::algebra::MaxMul;

        let cpu = Cpu;
        let (m, k, n) = (16, 16, 16);

        // Strictly positive inputs keep MaxMul products well-behaved.
        let lhs: Vec<f32> = (0..m * k).map(|i| ((i % 10) as f32) * 0.1 + 0.1).collect();
        let rhs: Vec<f32> = (0..k * n).map(|i| ((i % 10) as f32) * 0.1 + 0.1).collect();

        let fast = cpu.gemm_internal::<MaxMul<f32>>(&lhs, m, k, &rhs, n);
        let reference = generic_gemm::<MaxMul<f32>>(&lhs, m, k, &rhs, n);

        for (i, (opt, gen)) in fast.iter().zip(reference.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-5,
                "MaxMul mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }
    }

    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_with_argmax_optimized() {
        use crate::algebra::MaxPlus;

        let cpu = Cpu;
        let (m, k, n) = (32, 32, 32);

        let lhs: Vec<f32> = (0..m * k).map(|i| (i % 100) as f32).collect();
        let rhs: Vec<f32> = (0..k * n).map(|i| (i % 100) as f32).collect();

        let (fast, fast_arg) = cpu.gemm_with_argmax_internal::<MaxPlus<f32>>(&lhs, m, k, &rhs, n);
        let (reference, ref_arg) = generic_gemm_with_argmax::<MaxPlus<f32>>(&lhs, m, k, &rhs, n);

        for (i, (opt, gen)) in fast.iter().zip(reference.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-6,
                "MaxPlus with argmax: value mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }

        for (i, (opt, gen)) in fast_arg.iter().zip(ref_arg.iter()).enumerate() {
            assert_eq!(
                opt, gen,
                "MaxPlus with argmax: argmax mismatch at index {}: opt={}, gen={}",
                i, opt, gen
            );
        }
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_gemm_backward() {
        let cpu = Cpu;
        // 2x3 times 3x2 (column-major buffers).
        let lhs = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let rhs = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        let (_out, winners) = cpu.gemm_with_argmax_internal::<MaxPlus<f32>>(&lhs, 2, 3, &rhs, 2);

        let grad_c = vec![1.0f32; 4];
        let grad_a = cpu.gemm_backward_a_internal::<MaxPlus<f32>>(&grad_c, &winners, &rhs, 2, 3, 2);
        let grad_b = cpu.gemm_backward_b_internal::<MaxPlus<f32>>(&grad_c, &winners, &lhs, 2, 3, 2);

        assert_eq!(grad_a.len(), 6);
        assert_eq!(grad_b.len(), 6);

        // Every output element routes exactly one unit of gradient, so the
        // scattered totals must equal the incoming gradient's total.
        let grad_a_sum: f32 = grad_a.iter().sum();
        let grad_b_sum: f32 = grad_b.iter().sum();
        let grad_c_sum: f32 = grad_c.iter().sum();

        assert_eq!(grad_a_sum, grad_c_sum, "grad_a sum should equal grad_c sum");
        assert_eq!(grad_b_sum, grad_c_sum, "grad_b sum should equal grad_c sum");
    }
}
867}