1use super::Tensor;
4use crate::algebra::{Algebra, Scalar};
5use crate::backend::{Backend, BackendScalar};
6
/// Computes the shape of a binary contraction's output from its mode labels.
///
/// Every axis of `a` (resp. `b`) binds its mode label in `modes_a` (resp. `modes_b`)
/// to its extent in `shape_a` (resp. `shape_b`). The result lists, in order, the
/// extent bound to each mode in `modes_c`. An empty `modes_c` yields an empty shape.
///
/// # Panics
///
/// Panics if `modes_a.len() != shape_a.len()` or `modes_b.len() != shape_b.len()`,
/// if one mode is bound to two different extents (within one operand or across
/// both), or if a mode in `modes_c` appears in neither input.
fn compute_output_shape(
    shape_a: &[usize],
    modes_a: &[i32],
    shape_b: &[usize],
    modes_b: &[i32],
    modes_c: &[i32],
) -> Vec<usize> {
    assert_eq!(
        modes_a.len(),
        shape_a.len(),
        "modes_a length must match shape_a length"
    );
    assert_eq!(
        modes_b.len(),
        shape_b.len(),
        "modes_b length must match shape_b length"
    );

    let mut sizes: std::collections::HashMap<i32, usize> =
        std::collections::HashMap::with_capacity(modes_a.len() + modes_b.len());
    let bindings = modes_a
        .iter()
        .zip(shape_a)
        .chain(modes_b.iter().zip(shape_b));
    for (&mode, &size) in bindings {
        // A mode shared by several axes must have a single extent; reject conflicts
        // instead of silently keeping whichever binding happened to come last.
        let bound = *sizes.entry(mode).or_insert(size);
        assert_eq!(
            bound, size,
            "mode {} has inconsistent sizes: {} vs {}",
            mode, bound, size
        );
    }

    modes_c
        .iter()
        .map(|mode| match sizes.get(mode) {
            Some(&size) => size,
            None => panic!("output mode {} does not appear in either input", mode),
        })
        .collect()
}
24
25impl<T: Scalar, B: Backend> Tensor<T, B> {
26 pub fn contract_binary<A: Algebra<Scalar = T, Index = u32>>(
48 &self,
49 other: &Self,
50 ia: &[usize],
51 ib: &[usize],
52 iy: &[usize],
53 ) -> Self
54 where
55 T: BackendScalar<B>,
56 {
57 let (result, _) = self.contract_binary_impl::<A>(other, ia, ib, iy, false);
58 result
59 }
60
61 pub fn contract_binary_with_argmax<A: Algebra<Scalar = T, Index = u32>>(
63 &self,
64 other: &Self,
65 ia: &[usize],
66 ib: &[usize],
67 iy: &[usize],
68 ) -> (Self, Tensor<u32, B>)
69 where
70 T: BackendScalar<B>,
71 {
72 let (result, argmax) = self.contract_binary_impl::<A>(other, ia, ib, iy, true);
73 (result, argmax.expect("argmax requested but not returned"))
74 }
75
76 fn contract_binary_impl<A: Algebra<Scalar = T, Index = u32>>(
77 &self,
78 other: &Self,
79 ia: &[usize],
80 ib: &[usize],
81 iy: &[usize],
82 track_argmax: bool,
83 ) -> (Self, Option<Tensor<u32, B>>)
84 where
85 T: BackendScalar<B>,
86 {
87 assert_eq!(ia.len(), self.ndim(), "ia length must match self.ndim()");
88 assert_eq!(ib.len(), other.ndim(), "ib length must match other.ndim()");
89
90 let modes_a: Vec<i32> = ia.iter().map(|&i| i as i32).collect();
92 let modes_b: Vec<i32> = ib.iter().map(|&i| i as i32).collect();
93 let modes_c: Vec<i32> = iy.iter().map(|&i| i as i32).collect();
94
95 let shape_c = compute_output_shape(
97 self.shape(), &modes_a,
98 other.shape(), &modes_b,
99 &modes_c,
100 );
101
102 if track_argmax {
103 let (c_storage, argmax_storage) = self.backend.contract_with_argmax::<A>(
104 self.storage.as_ref(),
105 self.shape(),
106 self.strides(),
107 &modes_a,
108 other.storage.as_ref(),
109 other.shape(),
110 other.strides(),
111 &modes_b,
112 &shape_c,
113 &modes_c,
114 );
115
116 let c = Self::from_storage(c_storage, &shape_c, self.backend.clone());
117 let argmax = Tensor::<u32, B>::from_storage(
118 argmax_storage,
119 &shape_c,
120 self.backend.clone(),
121 );
122 (c, Some(argmax))
123 } else {
124 let c_storage = self.backend.contract::<A>(
125 self.storage.as_ref(),
126 self.shape(),
127 self.strides(),
128 &modes_a,
129 other.storage.as_ref(),
130 other.shape(),
131 other.strides(),
132 &modes_b,
133 &shape_c,
134 &modes_c,
135 );
136
137 let c = Self::from_storage(c_storage, &shape_c, self.backend.clone());
138 (c, None)
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Standard;
    use crate::backend::Cpu;

    #[cfg(feature = "tropical")]
    use crate::algebra::MaxPlus;

    // The expected values below are consistent with column-major storage
    // (first index varies fastest): `from_data(&[1, 2, 3, 4], &[2, 2])` is the
    // matrix [[1, 3], [2, 4]].

    #[test]
    fn test_contract_binary_matmul_standard() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let c = a.contract_binary::<Standard<f32>>(&b, &[0, 1], &[1, 2], &[0, 2]);

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_contract_binary_matmul_maxplus() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let c = a.contract_binary::<MaxPlus<f32>>(&b, &[0, 1], &[1, 2], &[0, 2]);

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
    }

    /// Covers the argmax path: its values must match the plain contraction, and
    /// the argmax tensor must share the result's shape.
    #[cfg(feature = "tropical")]
    #[test]
    fn test_contract_binary_with_argmax_maxplus() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let (c, argmax) =
            a.contract_binary_with_argmax::<MaxPlus<f32>>(&b, &[0, 1], &[1, 2], &[0, 2]);

        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
        assert_eq!(argmax.shape(), &[2, 2]);
    }

    #[test]
    fn test_contract_binary_batched() {
        let a =
            Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 2, 2]);
        let b =
            Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0], &[2, 2, 2]);

        // Mode 0 is a batch mode shared by both operands and kept in the output;
        // mode 2 is contracted.
        let c = a.contract_binary::<Standard<f32>>(&b, &[0, 1, 2], &[0, 2, 3], &[0, 1, 3]);

        assert_eq!(c.shape(), &[2, 2, 2]);
        assert_eq!(c.to_vec(), vec![16.0, 28.0, 24.0, 40.0, 1.0, 6.0, 3.0, 8.0]);
    }
}