// tropical_gemm/mat/ops.rs

//! Operator implementations for matrix types.
2
3use std::ops::Mul;
4
5use crate::simd::KernelDispatch;
6use crate::types::TropicalSemiring;
7
8use super::{Mat, MatRef};
9
10// MatRef * MatRef
11impl<'a, 'b, S> Mul<&'b MatRef<'b, S>> for &'a MatRef<'a, S>
12where
13    S: TropicalSemiring + KernelDispatch,
14{
15    type Output = Mat<S>;
16
17    fn mul(self, rhs: &'b MatRef<'b, S>) -> Mat<S> {
18        self.matmul(rhs)
19    }
20}
21
22// MatRef * MatRef (by value, since MatRef is Copy)
23impl<'a, 'b, S> Mul<MatRef<'b, S>> for MatRef<'a, S>
24where
25    S: TropicalSemiring + KernelDispatch,
26{
27    type Output = Mat<S>;
28
29    fn mul(self, rhs: MatRef<'b, S>) -> Mat<S> {
30        self.matmul(&rhs)
31    }
32}
33
34// &Mat * &MatRef
35impl<'a, S> Mul<&'a MatRef<'a, S>> for &Mat<S>
36where
37    S: TropicalSemiring + KernelDispatch,
38    S::Scalar: Copy,
39{
40    type Output = Mat<S>;
41
42    fn mul(self, rhs: &'a MatRef<'a, S>) -> Mat<S> {
43        self.as_ref().matmul(rhs)
44    }
45}
46
47// &Mat * &Mat
48impl<S> Mul<&Mat<S>> for &Mat<S>
49where
50    S: TropicalSemiring + KernelDispatch,
51    S::Scalar: Copy,
52{
53    type Output = Mat<S>;
54
55    fn mul(self, rhs: &Mat<S>) -> Mat<S> {
56        self.as_ref().matmul(&rhs.as_ref())
57    }
58}
59
60// Mat * Mat (consuming)
61impl<S> Mul<Mat<S>> for Mat<S>
62where
63    S: TropicalSemiring + KernelDispatch,
64    S::Scalar: Copy,
65{
66    type Output = Mat<S>;
67
68    fn mul(self, rhs: Mat<S>) -> Mat<S> {
69        self.as_ref().matmul(&rhs.as_ref())
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    /// The 2×2 matrix [[1, 2], [3, 4]] stored column-major.
    const COL_MAJOR: [f64; 4] = [1.0, 3.0, 2.0, 4.0];

    /// The same matrix [[1, 2], [3, 4]] stored row-major.
    const ROW_MAJOR: [f64; 4] = [1.0, 2.0, 3.0, 4.0];

    /// Max-plus square of [[1, 2], [3, 4]]: C[i,j] = max_k (A[i,k] + A[k,j]).
    ///
    /// The off-diagonal entries differ (6 vs 7). Any row/column-major mix-up
    /// would produce the transpose [[5, 7], [6, 8]] and fail. Checking only
    /// C[0,0] (= 5 in both cases) cannot detect this.
    const EXPECTED: [[f64; 2]; 2] = [[5.0, 6.0], [7.0, 8.0]];

    /// Asserts every entry of `c` against [`EXPECTED`].
    fn assert_is_expected_square(c: &Mat<TropicalMaxPlus<f64>>) {
        for (i, row) in EXPECTED.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(c[(i, j)].0, want, "mismatch at C[{}, {}]", i, j);
            }
        }
    }

    #[test]
    fn test_matref_mul_matref() {
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);

        let c = &a * &b;

        assert_is_expected_square(&c);
    }

    #[test]
    fn test_matref_mul_matref_by_value() {
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);

        // MatRef is Copy, so by-value multiplication must leave `a` and `b`
        // usable. The second product would not compile if they were moved.
        let c = a * b;
        let again = a * b;

        assert_is_expected_square(&c);
        assert_is_expected_square(&again);
    }

    #[test]
    fn test_mat_ref_mul_matref() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&COL_MAJOR, 2, 2);

        let c = &a * &b;

        assert_is_expected_square(&c);
    }

    #[test]
    fn test_mat_mul_mat() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);

        let c = &a * &b;

        assert_is_expected_square(&c);
    }

    #[test]
    fn test_mat_mul_consuming() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&ROW_MAJOR, 2, 2);

        let c = a * b;

        assert_is_expected_square(&c);
    }
}