1use std::ops::Mul;
4
5use crate::simd::KernelDispatch;
6use crate::types::TropicalSemiring;
7
8use super::{Mat, MatRef};
9
10impl<'a, 'b, S> Mul<&'b MatRef<'b, S>> for &'a MatRef<'a, S>
12where
13 S: TropicalSemiring + KernelDispatch,
14{
15 type Output = Mat<S>;
16
17 fn mul(self, rhs: &'b MatRef<'b, S>) -> Mat<S> {
18 self.matmul(rhs)
19 }
20}
21
22impl<'a, 'b, S> Mul<MatRef<'b, S>> for MatRef<'a, S>
24where
25 S: TropicalSemiring + KernelDispatch,
26{
27 type Output = Mat<S>;
28
29 fn mul(self, rhs: MatRef<'b, S>) -> Mat<S> {
30 self.matmul(&rhs)
31 }
32}
33
34impl<'a, S> Mul<&'a MatRef<'a, S>> for &Mat<S>
36where
37 S: TropicalSemiring + KernelDispatch,
38 S::Scalar: Copy,
39{
40 type Output = Mat<S>;
41
42 fn mul(self, rhs: &'a MatRef<'a, S>) -> Mat<S> {
43 self.as_ref().matmul(rhs)
44 }
45}
46
47impl<S> Mul<&Mat<S>> for &Mat<S>
49where
50 S: TropicalSemiring + KernelDispatch,
51 S::Scalar: Copy,
52{
53 type Output = Mat<S>;
54
55 fn mul(self, rhs: &Mat<S>) -> Mat<S> {
56 self.as_ref().matmul(&rhs.as_ref())
57 }
58}
59
60impl<S> Mul<Mat<S>> for Mat<S>
62where
63 S: TropicalSemiring + KernelDispatch,
64 S::Scalar: Copy,
65{
66 type Output = Mat<S>;
67
68 fn mul(self, rhs: Mat<S>) -> Mat<S> {
69 self.as_ref().matmul(&rhs.as_ref())
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    /// The matrix [[1, 2], [3, 4]] stored column by column, for `MatRef::from_slice`.
    ///
    /// NOTE: `from_slice` is treated as column-major here. The expected values in
    /// `test_mat_ref_mul_matref` (row-major `Mat` times a `from_slice` view) only
    /// hold under that layout.
    const COL_MAJOR: [f64; 4] = [1.0, 3.0, 2.0, 4.0];

    /// The same matrix [[1, 2], [3, 4]] stored row by row, for `Mat::from_row_major`.
    const ROW_MAJOR: [f64; 4] = [1.0, 2.0, 3.0, 4.0];

    /// Asserts every entry of the max-plus square of [[1, 2], [3, 4]].
    ///
    /// With `c[i][j] = max_k(a[i][k] + b[k][j])` the product is [[5, 6], [7, 8]].
    /// Checking all four entries, not just `(0, 0)`, catches transposed or
    /// mis-strided results. A transpose leaves the diagonal, and therefore the
    /// top-left entry, unchanged.
    fn assert_square_of_1234(c: &Mat<TropicalMaxPlus<f64>>) {
        let expected = [[5.0, 6.0], [7.0, 8.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(c[(i, j)].0, want, "entry ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_matref_mul_matref() {
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);

        let c = &a * &b;

        assert_square_of_1234(&c);
    }

    #[test]
    fn test_matref_mul_matref_by_value() {
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);

        let c = a * b;

        assert_square_of_1234(&c);
    }

    #[test]
    fn test_mat_ref_mul_matref() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);

        let c = &a * &b;

        assert_square_of_1234(&c);
    }

    #[test]
    fn test_mat_mul_mat() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);

        let c = &a * &b;

        assert_square_of_1234(&c);
    }

    #[test]
    fn test_mat_mul_consuming() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);

        let c = a * b;

        assert_square_of_1234(&c);
    }
}