1use crate::core::{GemmWithArgmax, Transpose};
2use crate::simd::{tropical_gemm_dispatch, KernelDispatch};
3use crate::types::{TropicalSemiring, TropicalWithArgmax};
4
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7
8pub fn tropical_matmul<T: TropicalSemiring + KernelDispatch>(
34 a: &[T::Scalar],
35 m: usize,
36 k: usize,
37 b: &[T::Scalar],
38 n: usize,
39) -> Vec<T> {
40 assert_eq!(a.len(), m * k, "A dimensions mismatch");
41 assert_eq!(b.len(), k * n, "B dimensions mismatch");
42
43 let mut c = vec![T::tropical_zero(); m * n];
44
45 unsafe {
46 tropical_gemm_dispatch::<T>(
47 m,
48 n,
49 k,
50 a.as_ptr(),
51 k,
52 Transpose::NoTrans,
53 b.as_ptr(),
54 n,
55 Transpose::NoTrans,
56 c.as_mut_ptr(),
57 n,
58 );
59 }
60
61 c
62}
63
64pub fn tropical_matmul_with_argmax<T: TropicalWithArgmax<Index = u32> + KernelDispatch>(
82 a: &[T::Scalar],
83 m: usize,
84 k: usize,
85 b: &[T::Scalar],
86 n: usize,
87) -> GemmWithArgmax<T> {
88 assert_eq!(a.len(), m * k, "A dimensions mismatch");
89 assert_eq!(b.len(), k * n, "B dimensions mismatch");
90
91 let mut result = GemmWithArgmax::new(m, n);
92
93 unsafe {
94 crate::core::tropical_gemm_with_argmax_portable::<T>(
95 m,
96 n,
97 k,
98 a.as_ptr(),
99 k,
100 Transpose::NoTrans,
101 b.as_ptr(),
102 n,
103 Transpose::NoTrans,
104 &mut result,
105 );
106 }
107
108 result
109}
110
111pub struct TropicalGemm<T: TropicalSemiring> {
129 m: usize,
130 n: usize,
131 k: usize,
132 trans_a: Transpose,
133 trans_b: Transpose,
134 _phantom: std::marker::PhantomData<T>,
135}
136
137impl<T: TropicalSemiring + KernelDispatch> TropicalGemm<T> {
138 pub fn new(m: usize, n: usize, k: usize) -> Self {
140 Self {
141 m,
142 n,
143 k,
144 trans_a: Transpose::NoTrans,
145 trans_b: Transpose::NoTrans,
146 _phantom: std::marker::PhantomData,
147 }
148 }
149
150 pub fn trans_a(mut self) -> Self {
152 self.trans_a = Transpose::Trans;
153 self
154 }
155
156 pub fn trans_b(mut self) -> Self {
158 self.trans_b = Transpose::Trans;
159 self
160 }
161
162 pub fn execute(
172 self,
173 a: &[T::Scalar],
174 lda: usize,
175 b: &[T::Scalar],
176 ldb: usize,
177 c: &mut [T],
178 ldc: usize,
179 ) {
180 unsafe {
181 tropical_gemm_dispatch::<T>(
182 self.m,
183 self.n,
184 self.k,
185 a.as_ptr(),
186 lda,
187 self.trans_a,
188 b.as_ptr(),
189 ldb,
190 self.trans_b,
191 c.as_mut_ptr(),
192 ldc,
193 );
194 }
195 }
196}
197
198pub unsafe fn tropical_gemm<T: TropicalSemiring + KernelDispatch>(
205 m: usize,
206 n: usize,
207 k: usize,
208 a: *const T::Scalar,
209 lda: usize,
210 trans_a: Transpose,
211 b: *const T::Scalar,
212 ldb: usize,
213 trans_b: Transpose,
214 c: *mut T,
215 ldc: usize,
216) {
217 tropical_gemm_dispatch::<T>(m, n, k, a, lda, trans_a, b, ldb, trans_b, c, ldc);
218}
219
220pub fn tropical_matmul_batched<T: TropicalSemiring + KernelDispatch>(
256 a_batch: &[Vec<T::Scalar>],
257 b_batch: &[Vec<T::Scalar>],
258 m: usize,
259 k: usize,
260 n: usize,
261) -> Vec<Vec<T>>
262where
263 T::Scalar: Send + Sync,
264 T: Send + Sync,
265{
266 assert_eq!(
267 a_batch.len(),
268 b_batch.len(),
269 "Batch sizes must match: A has {} matrices, B has {}",
270 a_batch.len(),
271 b_batch.len()
272 );
273
274 let batch_size = a_batch.len();
275 if batch_size == 0 {
276 return Vec::new();
277 }
278
279 for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
281 assert_eq!(
282 a.len(),
283 m * k,
284 "A[{}] dimensions mismatch: expected {}, got {}",
285 i,
286 m * k,
287 a.len()
288 );
289 assert_eq!(
290 b.len(),
291 k * n,
292 "B[{}] dimensions mismatch: expected {}, got {}",
293 i,
294 k * n,
295 b.len()
296 );
297 }
298
299 #[cfg(feature = "parallel")]
300 {
301 a_batch
302 .par_iter()
303 .zip(b_batch.par_iter())
304 .map(|(a, b)| tropical_matmul::<T>(a, m, k, b, n))
305 .collect()
306 }
307
308 #[cfg(not(feature = "parallel"))]
309 {
310 a_batch
311 .iter()
312 .zip(b_batch.iter())
313 .map(|(a, b)| tropical_matmul::<T>(a, m, k, b, n))
314 .collect()
315 }
316}
317
318pub fn tropical_matmul_batched_with_argmax<T: TropicalWithArgmax<Index = u32> + KernelDispatch>(
332 a_batch: &[Vec<T::Scalar>],
333 b_batch: &[Vec<T::Scalar>],
334 m: usize,
335 k: usize,
336 n: usize,
337) -> Vec<GemmWithArgmax<T>>
338where
339 T::Scalar: Send + Sync,
340 T: Send + Sync,
341{
342 assert_eq!(
343 a_batch.len(),
344 b_batch.len(),
345 "Batch sizes must match: A has {} matrices, B has {}",
346 a_batch.len(),
347 b_batch.len()
348 );
349
350 let batch_size = a_batch.len();
351 if batch_size == 0 {
352 return Vec::new();
353 }
354
355 for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
357 assert_eq!(
358 a.len(),
359 m * k,
360 "A[{}] dimensions mismatch: expected {}, got {}",
361 i,
362 m * k,
363 a.len()
364 );
365 assert_eq!(
366 b.len(),
367 k * n,
368 "B[{}] dimensions mismatch: expected {}, got {}",
369 i,
370 k * n,
371 b.len()
372 );
373 }
374
375 #[cfg(feature = "parallel")]
376 {
377 a_batch
378 .par_iter()
379 .zip(b_batch.par_iter())
380 .map(|(a, b)| tropical_matmul_with_argmax::<T>(a, m, k, b, n))
381 .collect()
382 }
383
384 #[cfg(not(feature = "parallel"))]
385 {
386 a_batch
387 .iter()
388 .zip(b_batch.iter())
389 .map(|(a, b)| tropical_matmul_with_argmax::<T>(a, m, k, b, n))
390 .collect()
391 }
392}
393
394pub fn tropical_matmul_strided_batched<T: TropicalSemiring + KernelDispatch>(
429 a: &[T::Scalar],
430 b: &[T::Scalar],
431 batch_size: usize,
432 m: usize,
433 k: usize,
434 n: usize,
435) -> Vec<T>
436where
437 T::Scalar: Send + Sync + Copy,
438 T: Send + Sync,
439{
440 let a_stride = m * k;
441 let b_stride = k * n;
442 let c_stride = m * n;
443
444 assert_eq!(
445 a.len(),
446 batch_size * a_stride,
447 "A size mismatch: expected {}, got {}",
448 batch_size * a_stride,
449 a.len()
450 );
451 assert_eq!(
452 b.len(),
453 batch_size * b_stride,
454 "B size mismatch: expected {}, got {}",
455 batch_size * b_stride,
456 b.len()
457 );
458
459 if batch_size == 0 {
460 return Vec::new();
461 }
462
463 let mut c = vec![T::tropical_zero(); batch_size * c_stride];
464
465 #[cfg(feature = "parallel")]
466 {
467 c.par_chunks_mut(c_stride)
468 .enumerate()
469 .for_each(|(i, c_chunk)| {
470 let a_slice = &a[i * a_stride..(i + 1) * a_stride];
471 let b_slice = &b[i * b_stride..(i + 1) * b_stride];
472
473 unsafe {
474 tropical_gemm_dispatch::<T>(
475 m,
476 n,
477 k,
478 a_slice.as_ptr(),
479 k,
480 Transpose::NoTrans,
481 b_slice.as_ptr(),
482 n,
483 Transpose::NoTrans,
484 c_chunk.as_mut_ptr(),
485 n,
486 );
487 }
488 });
489 }
490
491 #[cfg(not(feature = "parallel"))]
492 {
493 for i in 0..batch_size {
494 let a_slice = &a[i * a_stride..(i + 1) * a_stride];
495 let b_slice = &b[i * b_stride..(i + 1) * b_stride];
496 let c_slice = &mut c[i * c_stride..(i + 1) * c_stride];
497
498 unsafe {
499 tropical_gemm_dispatch::<T>(
500 m,
501 n,
502 k,
503 a_slice.as_ptr(),
504 k,
505 Transpose::NoTrans,
506 b_slice.as_ptr(),
507 n,
508 Transpose::NoTrans,
509 c_slice.as_mut_ptr(),
510 n,
511 );
512 }
513 }
514 }
515
516 c
517}
518
/// Routes the upstream gradient of a tropical matmul back to the left
/// operand `A` (`m × k`, row-major).
///
/// For each output cell `(i, j)`, with flat position `i * n + j`, the whole
/// gradient `grad_c[i * n + j]` is added to `A[i][argmax[i * n + j]]`. Argmax
/// entries `>= k` contribute nothing. Contributions to the same element are
/// summed in row-major order of `C`.
///
/// # Panics
///
/// Panics if `grad_c` or `argmax` does not have exactly `m * n` elements.
pub fn tropical_backward_a<T: Copy + Default + std::ops::AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "grad_c size mismatch");
    assert_eq!(argmax.len(), m * n, "argmax size mismatch");

    let mut grad_a = vec![T::default(); m * k];

    // Flat position `pos` in C corresponds to row `pos / n`. The loop body
    // only runs when C is non-empty, so `n > 0` there.
    for (pos, (&upstream, &winner)) in grad_c.iter().zip(argmax).enumerate() {
        let inner = winner as usize;
        if inner >= k {
            continue;
        }
        let row = pos / n;
        grad_a[row * k + inner] += upstream;
    }

    grad_a
}
586
/// Routes the upstream gradient of a tropical matmul back to the right
/// operand `B` (`k × n`, row-major).
///
/// For each output cell `(i, j)`, with flat position `i * n + j`, the whole
/// gradient `grad_c[i * n + j]` is added to `B[argmax[i * n + j]][j]`. Argmax
/// entries `>= k` contribute nothing. Contributions to the same element are
/// summed in row-major order of `C`.
///
/// # Panics
///
/// Panics if `grad_c` or `argmax` does not have exactly `m * n` elements.
pub fn tropical_backward_b<T: Copy + Default + std::ops::AddAssign>(
    grad_c: &[T],
    argmax: &[u32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    assert_eq!(grad_c.len(), m * n, "grad_c size mismatch");
    assert_eq!(argmax.len(), m * n, "argmax size mismatch");

    let mut grad_b = vec![T::default(); k * n];

    // Flat position `pos` in C corresponds to column `pos % n`. The loop body
    // only runs when C is non-empty, so `n > 0` there.
    for (pos, (&upstream, &winner)) in grad_c.iter().zip(argmax).enumerate() {
        let inner = winner as usize;
        if inner >= k {
            continue;
        }
        let col = pos % n;
        grad_b[inner * n + col] += upstream;
    }

    grad_b
}
650
651pub fn tropical_backward_a_batched<T: Copy + Default + std::ops::AddAssign + Send + Sync>(
667 grad_c_batch: &[Vec<T>],
668 argmax_batch: &[Vec<u32>],
669 m: usize,
670 k: usize,
671 n: usize,
672) -> Vec<Vec<T>> {
673 assert_eq!(
674 grad_c_batch.len(),
675 argmax_batch.len(),
676 "Batch sizes must match"
677 );
678
679 #[cfg(feature = "parallel")]
680 {
681 grad_c_batch
682 .par_iter()
683 .zip(argmax_batch.par_iter())
684 .map(|(grad_c, argmax)| tropical_backward_a(grad_c, argmax, m, k, n))
685 .collect()
686 }
687
688 #[cfg(not(feature = "parallel"))]
689 {
690 grad_c_batch
691 .iter()
692 .zip(argmax_batch.iter())
693 .map(|(grad_c, argmax)| tropical_backward_a(grad_c, argmax, m, k, n))
694 .collect()
695 }
696}
697
698pub fn tropical_backward_b_batched<T: Copy + Default + std::ops::AddAssign + Send + Sync>(
714 grad_c_batch: &[Vec<T>],
715 argmax_batch: &[Vec<u32>],
716 m: usize,
717 k: usize,
718 n: usize,
719) -> Vec<Vec<T>> {
720 assert_eq!(
721 grad_c_batch.len(),
722 argmax_batch.len(),
723 "Batch sizes must match"
724 );
725
726 #[cfg(feature = "parallel")]
727 {
728 grad_c_batch
729 .par_iter()
730 .zip(argmax_batch.par_iter())
731 .map(|(grad_c, argmax)| tropical_backward_b(grad_c, argmax, m, k, n))
732 .collect()
733 }
734
735 #[cfg(not(feature = "parallel"))]
736 {
737 grad_c_batch
738 .iter()
739 .zip(argmax_batch.iter())
740 .map(|(grad_c, argmax)| tropical_backward_b(grad_c, argmax, m, k, n))
741 .collect()
742 }
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    // Fixture shared by many tests:
    //   A = [[1, 2, 3], [4, 5, 6]]   (2 × 3)
    //   B = [[1, 2], [3, 4], [5, 6]] (3 × 2)
    // Max-plus product: C = [[8, 9], [11, 12]], every cell won at inner index 2.

    #[test]
    fn test_tropical_matmul() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_tropical_matmul_with_argmax() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

        // Each maximum comes from the last inner term (3+5, 3+6, 6+5, 6+6).
        for &(i, j, want) in &[(0, 0, 8.0), (0, 1, 9.0), (1, 0, 11.0), (1, 1, 12.0)] {
            assert_eq!(result.get(i, j).0, want);
            assert_eq!(result.get_argmax(i, j), 2);
        }
    }

    #[test]
    fn test_builder_api() {
        let a = vec![1.0f32; 6];
        let b = vec![1.0f32; 6];
        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];

        TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3).execute(&a, 3, &b, 2, &mut c, 2);

        let got: Vec<f32> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![2.0, 2.0, 2.0, 2.0]);
    }

    #[test]
    fn test_builder_api_trans_a() {
        // A stored transposed as 3 × 2: [[1, 4], [2, 5], [3, 6]], so op(A) = [[1, 2, 3], [4, 5, 6]].
        let a = vec![1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];

        TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3)
            .trans_a()
            .execute(&a, 2, &b, 2, &mut c, 2);

        let got: Vec<f32> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_builder_api_trans_b() {
        // B stored transposed as 2 × 3: [[1, 3, 5], [2, 4, 6]], so op(B) = [[1, 2], [3, 4], [5, 6]].
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f32, 3.0, 5.0, 2.0, 4.0, 6.0];
        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];

        TropicalGemm::<TropicalMaxPlus<f32>>::new(2, 2, 3)
            .trans_b()
            .execute(&a, 3, &b, 3, &mut c, 2);

        let got: Vec<f32> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_tropical_matmul_min_plus() {
        use crate::types::TropicalMinPlus;

        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = tropical_matmul::<TropicalMinPlus<f64>>(&a, 2, 3, &b, 2);

        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![2.0, 3.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tropical_matmul_max_mul() {
        use crate::types::TropicalMaxMul;

        let a = vec![2.0f64, 3.0, 4.0, 5.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0];

        let c = tropical_matmul::<TropicalMaxMul<f64>>(&a, 2, 2, &b, 2);

        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![9.0, 12.0, 15.0, 20.0]);
    }

    #[test]
    fn test_tropical_matmul_f32() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2);

        let expected = [8.0f32, 9.0, 11.0, 12.0];
        assert_eq!(c.len(), expected.len());
        for (got, want) in c.iter().zip(expected.iter()) {
            assert!((got.0 - *want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_non_square_matrices() {
        // A = [[1, 2], [3, 4], [5, 6]] (3 × 2), B = [[1, 2, 3, 4], [5, 6, 7, 8]] (2 × 4).
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 3, 2, &b, 4);

        assert_eq!(c.len(), 12);
        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(
            got,
            vec![7.0, 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 12.0, 11.0, 12.0, 13.0, 14.0]
        );
    }

    #[test]
    fn test_single_element() {
        let a = vec![5.0f64];
        let b = vec![3.0f64];

        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 1, 1, &b, 1);

        assert_eq!(c.len(), 1);
        assert_eq!(c[0].0, 8.0);
    }

    #[test]
    fn test_larger_matrix() {
        let n = 16;
        let a: Vec<f64> = (0..n * n).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..n * n).map(|i| (n * n - 1 - i) as f64).collect();

        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, n, n, &b, n);

        assert_eq!(c.len(), n * n);
        // A[i][p] + B[p][j] = i*n + n*n - 1 - j + p*(1 - n), maximised at p = 0.
        for i in 0..n {
            for j in 0..n {
                let want = (i * n + n * n - 1 - j) as f64;
                assert_eq!(c[i * n + j].0, want);
            }
        }
    }

    #[test]
    fn test_tropical_matmul_i32() {
        let a = vec![1i32, 2, 3, 4, 5, 6];
        let b = vec![1i32, 2, 3, 4, 5, 6];

        let c = tropical_matmul::<TropicalMaxPlus<i32>>(&a, 2, 3, &b, 2);

        let got: Vec<i32> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![8, 9, 11, 12]);
    }

    #[test]
    fn test_tropical_matmul_i64() {
        let a = vec![1i64, 2, 3, 4, 5, 6];
        let b = vec![1i64, 2, 3, 4, 5, 6];

        let c = tropical_matmul::<TropicalMaxPlus<i64>>(&a, 2, 3, &b, 2);

        let got: Vec<i64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![8, 9, 11, 12]);
    }

    #[test]
    fn test_tropical_matmul_minplus_i32() {
        use crate::types::TropicalMinPlus;

        let a = vec![1i32, 2, 3, 4, 5, 6];
        let b = vec![1i32, 2, 3, 4, 5, 6];

        let c = tropical_matmul::<TropicalMinPlus<i32>>(&a, 2, 3, &b, 2);

        let got: Vec<i32> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![2, 3, 5, 6]);
    }

    #[test]
    fn test_unsafe_tropical_gemm() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut c = vec![TropicalMaxPlus::tropical_zero(); 4];

        // SAFETY: a is 2 × 3 (lda 3), b is 3 × 2 (ldb 2), c is 2 × 2 (ldc 2);
        // all buffers are exactly that size and c does not alias a or b.
        unsafe {
            tropical_gemm::<TropicalMaxPlus<f64>>(
                2,
                2,
                3,
                a.as_ptr(),
                3,
                Transpose::NoTrans,
                b.as_ptr(),
                2,
                Transpose::NoTrans,
                c.as_mut_ptr(),
                2,
            );
        }

        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![8.0, 9.0, 11.0, 12.0]);
    }

    #[test]
    fn test_minplus_with_argmax() {
        use crate::types::TropicalMinPlus;

        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = tropical_matmul_with_argmax::<TropicalMinPlus<f64>>(&a, 2, 3, &b, 2);

        // Each minimum comes from the first inner term.
        for &(i, j, want) in &[(0, 0, 2.0), (0, 1, 3.0), (1, 0, 5.0), (1, 1, 6.0)] {
            assert_eq!(result.get(i, j).0, want);
            assert_eq!(result.get_argmax(i, j), 0);
        }
    }

    #[test]
    fn test_maxmul_with_argmax() {
        use crate::types::TropicalMaxMul;

        let a = vec![2.0f64, 3.0, 4.0, 5.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0];

        let result = tropical_matmul_with_argmax::<TropicalMaxMul<f64>>(&a, 2, 2, &b, 2);

        // Each maximum comes from the second inner term (3*3, 3*4, 5*3, 5*4).
        for &(i, j, want) in &[(0, 0, 9.0), (0, 1, 12.0), (1, 0, 15.0), (1, 1, 20.0)] {
            assert_eq!(result.get(i, j).0, want);
            assert_eq!(result.get_argmax(i, j), 1);
        }
    }

    #[test]
    fn test_gemmwithargmax_dimensions() {
        let a = vec![1.0f64; 12]; // 3 × 4
        let b = vec![1.0f64; 20]; // 4 × 5
        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 3, 4, &b, 5);

        assert_eq!(result.m, 3);
        assert_eq!(result.n, 5);
        assert_eq!(result.values.len(), 15);
        assert_eq!(result.argmax.len(), 15);
        assert!(result.values.iter().all(|v| v.0 == 2.0));
    }

    #[test]
    fn test_identity_like_matrix() {
        // Max-plus identity: 0 on the diagonal, -inf elsewhere.
        let a = vec![0.0f64, f64::NEG_INFINITY, f64::NEG_INFINITY, 0.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0];

        let c = tropical_matmul::<TropicalMaxPlus<f64>>(&a, 2, 2, &b, 2);

        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, b);
    }

    #[test]
    fn test_tropical_matmul_batched() {
        let a_batch = vec![
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2.0f64, 3.0, 4.0, 5.0, 6.0, 7.0],
        ];
        let b_batch = vec![
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
        ];

        let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f64>>(&a_batch, &b_batch, 2, 3, 2);

        assert_eq!(c_batch.len(), 2);

        let first: Vec<f64> = c_batch[0].iter().map(|x| x.0).collect();
        assert_eq!(first, vec![8.0, 9.0, 11.0, 12.0]);

        let second: Vec<f64> = c_batch[1].iter().map(|x| x.0).collect();
        assert_eq!(second, vec![9.0, 10.0, 12.0, 13.0]);
    }

    #[test]
    fn test_tropical_matmul_batched_empty() {
        let a_batch: Vec<Vec<f64>> = vec![];
        let b_batch: Vec<Vec<f64>> = vec![];

        let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f64>>(&a_batch, &b_batch, 2, 2, 2);

        assert!(c_batch.is_empty());
    }

    #[test]
    fn test_tropical_matmul_batched_with_argmax() {
        let a_batch = vec![
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
        ];
        let b_batch = vec![
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
            // B[0][0] = 10 makes the first inner term win in column 0.
            vec![10.0f64, 2.0, 3.0, 4.0, 5.0, 6.0],
        ];

        let results = tropical_matmul_batched_with_argmax::<TropicalMaxPlus<f64>>(
            &a_batch, &b_batch, 2, 3, 2,
        );

        assert_eq!(results.len(), 2);

        assert_eq!(results[0].get(0, 0).0, 8.0);
        assert_eq!(results[0].get_argmax(0, 0), 2);

        assert_eq!(results[1].get(0, 0).0, 11.0);
        assert_eq!(results[1].get_argmax(0, 0), 0);
        assert_eq!(results[1].get(0, 1).0, 9.0);
        assert_eq!(results[1].get_argmax(0, 1), 2);
        assert_eq!(results[1].get(1, 0).0, 14.0);
        assert_eq!(results[1].get_argmax(1, 0), 0);
        assert_eq!(results[1].get(1, 1).0, 12.0);
        assert_eq!(results[1].get_argmax(1, 1), 2);
    }

    #[test]
    fn test_tropical_matmul_batched_with_argmax_empty() {
        let a_batch: Vec<Vec<f64>> = vec![];
        let b_batch: Vec<Vec<f64>> = vec![];

        let results = tropical_matmul_batched_with_argmax::<TropicalMaxPlus<f64>>(
            &a_batch, &b_batch, 2, 2, 2,
        );

        assert!(results.is_empty());
    }

    #[test]
    fn test_tropical_matmul_strided_batched() {
        // Two 2 × 2 A matrices: [[1, 2], [3, 4]] and [[5, 6], [7, 8]].
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // Two copies of B = [[1, 2], [3, 4]].
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0];

        let c = tropical_matmul_strided_batched::<TropicalMaxPlus<f64>>(&a, &b, 2, 2, 2, 2);

        assert_eq!(c.len(), 8);
        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_tropical_matmul_strided_batched_empty() {
        let a: Vec<f64> = vec![];
        let b: Vec<f64> = vec![];

        let c = tropical_matmul_strided_batched::<TropicalMaxPlus<f64>>(&a, &b, 0, 2, 2, 2);

        assert!(c.is_empty());
    }

    #[test]
    fn test_tropical_matmul_strided_batched_minplus() {
        use crate::types::TropicalMinPlus;

        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0];

        let c = tropical_matmul_strided_batched::<TropicalMinPlus<f64>>(&a, &b, 2, 2, 2, 2);

        assert_eq!(c.len(), 8);
        let got: Vec<f64> = c.iter().map(|x| x.0).collect();
        assert_eq!(got, vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_tropical_matmul_batched_larger() {
        let batch_size = 10;
        let m = 8;
        let k = 6;
        let n = 4;

        let a_batch: Vec<Vec<f64>> = (0..batch_size)
            .map(|i| (0..m * k).map(|j| (i * m * k + j) as f64).collect())
            .collect();
        let b_batch: Vec<Vec<f64>> = (0..batch_size)
            .map(|_| (0..k * n).map(|j| j as f64).collect())
            .collect();

        let c_batch = tropical_matmul_batched::<TropicalMaxPlus<f64>>(&a_batch, &b_batch, m, k, n);

        assert_eq!(c_batch.len(), batch_size);
        // A_i[r][p] + B[p][col] = i*m*k + r*k + p*(n + 1) + col, maximised at p = k - 1.
        for (i, c) in c_batch.iter().enumerate() {
            assert_eq!(c.len(), m * n);
            for r in 0..m {
                for col in 0..n {
                    let want = (i * m * k + r * k + (k - 1) * (n + 1) + col) as f64;
                    assert_eq!(c[r * n + col].0, want);
                }
            }
        }
    }

    #[test]
    fn test_tropical_backward_a() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

        assert_eq!(result.get_argmax(0, 0), 2);
        assert_eq!(result.get_argmax(0, 1), 2);
        assert_eq!(result.get_argmax(1, 0), 2);
        assert_eq!(result.get_argmax(1, 1), 2);

        let grad_c = vec![1.0f64; 4];

        let grad_a = tropical_backward_a(&grad_c, result.argmax_slice(), 2, 3, 2);

        // Every cell of row i routes to A[i][2]; two cells per row.
        assert_eq!(grad_a, vec![0.0, 0.0, 2.0, 0.0, 0.0, 2.0]);
    }

    #[test]
    fn test_tropical_backward_b() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

        let grad_c = vec![1.0f64; 4];

        let grad_b = tropical_backward_b(&grad_c, result.argmax_slice(), 2, 3, 2);

        // Every cell of column j routes to B[2][j]; two cells per column.
        assert_eq!(grad_b, vec![0.0, 0.0, 0.0, 0.0, 2.0, 2.0]);
    }

    #[test]
    fn test_tropical_backward_varied_argmax() {
        let a = vec![10.0f64, 1.0, 1.0, 10.0];
        let b = vec![1.0f64, 10.0, 10.0, 1.0];

        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 2, &b, 2);

        let grad_c = vec![1.0f64; 4];
        // Ties exist here, so only conservation of total gradient is checked.
        let grad_a = tropical_backward_a(&grad_c, result.argmax_slice(), 2, 2, 2);
        let grad_b = tropical_backward_b(&grad_c, result.argmax_slice(), 2, 2, 2);

        assert_eq!(grad_a.len(), 4);
        assert_eq!(grad_b.len(), 4);

        let total_grad_a: f64 = grad_a.iter().sum();
        let total_grad_b: f64 = grad_b.iter().sum();
        assert_eq!(total_grad_a, 4.0);
        assert_eq!(total_grad_b, 4.0);
    }

    #[test]
    fn test_tropical_backward_batched() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

        let grad_c_batch = vec![vec![1.0f64; 4], vec![2.0f64; 4]];
        let argmax_batch = vec![
            result.argmax_slice().to_vec(),
            result.argmax_slice().to_vec(),
        ];

        let grad_a_batch = tropical_backward_a_batched(&grad_c_batch, &argmax_batch, 2, 3, 2);
        let grad_b_batch = tropical_backward_b_batched(&grad_c_batch, &argmax_batch, 2, 3, 2);

        assert_eq!(grad_a_batch.len(), 2);
        assert_eq!(grad_b_batch.len(), 2);

        assert_eq!(grad_a_batch[0], vec![0.0, 0.0, 2.0, 0.0, 0.0, 2.0]);
        assert_eq!(grad_b_batch[0], vec![0.0, 0.0, 0.0, 0.0, 2.0, 2.0]);

        assert_eq!(grad_a_batch[1], vec![0.0, 0.0, 4.0, 0.0, 0.0, 4.0]);
        assert_eq!(grad_b_batch[1], vec![0.0, 0.0, 0.0, 0.0, 4.0, 4.0]);
    }
}