1mod backward;
7mod builder;
8mod engine;
9
10pub use builder::EinBuilder;
11pub use engine::Einsum;
12
13use crate::algebra::{Algebra, Scalar};
14use crate::backend::{Backend, BackendScalar};
15use crate::tensor::Tensor;
16
17pub fn einsum<A, T, B>(tensors: &[&Tensor<T, B>], ixs: &[&[usize]], iy: &[usize]) -> Tensor<T, B>
39where
40 A: Algebra<Scalar = T, Index = u32>,
41 T: Scalar + BackendScalar<B>,
42 B: Backend + Default,
43{
44 let size_dict = infer_size_dict(tensors, ixs);
45 let ixs_owned: Vec<Vec<usize>> = ixs.iter().map(|ix| ix.to_vec()).collect();
46
47 let mut ein = Einsum::new(ixs_owned, iy.to_vec(), size_dict);
48 ein.optimize_greedy();
49 ein.execute::<A, T, B>(tensors)
50}
51
52pub fn einsum_with_grad<A, T, B>(
60 tensors: &[&Tensor<T, B>],
61 ixs: &[&[usize]],
62 iy: &[usize],
63) -> (Tensor<T, B>, EinsumGradient<T, B>)
64where
65 A: Algebra<Scalar = T, Index = u32>,
66 T: Scalar + BackendScalar<B>,
67 B: Backend + Default,
68{
69 let size_dict = infer_size_dict(tensors, ixs);
70 let ixs_owned: Vec<Vec<usize>> = ixs.iter().map(|ix| ix.to_vec()).collect();
71
72 let mut ein = Einsum::new(ixs_owned.clone(), iy.to_vec(), size_dict.clone());
73 ein.optimize_greedy();
74
75 let (result, argmax_cache) = if A::needs_argmax() {
78 ein.execute_with_argmax::<A, T, B>(tensors)
79 } else {
80 (ein.execute::<A, T, B>(tensors), Vec::new())
81 };
82
83 let gradient = EinsumGradient {
84 ixs: ixs_owned,
85 iy: iy.to_vec(),
86 size_dict,
87 argmax_cache,
88 _phantom: std::marker::PhantomData,
89 };
90
91 (result, gradient)
92}
93
/// Forward-pass state saved by [`einsum_with_grad`] and consumed by
/// [`EinsumGradient::backward`] to compute input gradients.
pub struct EinsumGradient<T: Scalar, B: Backend> {
    // Index labels for each input tensor, as given to the forward call.
    ixs: Vec<Vec<usize>>,
    // Index labels of the output tensor.
    iy: Vec<usize>,
    // Label -> dimension size, inferred from the input shapes.
    size_dict: std::collections::HashMap<usize, usize>,
    // Argmax positions recorded during the forward pass when the algebra
    // requires them (tropical-style backward); empty otherwise.
    argmax_cache: Vec<Tensor<u32, B>>,
    // `T` is the scalar type of the contraction; it is not stored directly.
    _phantom: std::marker::PhantomData<T>,
}
102
103impl<T: Scalar + BackendScalar<B>, B: Backend + Default> EinsumGradient<T, B> {
104 pub fn backward<A: Algebra<Scalar = T, Index = u32>>(
115 &self,
116 grad_output: &Tensor<T, B>,
117 inputs: &[&Tensor<T, B>],
118 ) -> Vec<Tensor<T, B>> {
119 assert_eq!(
120 inputs.len(),
121 self.ixs.len(),
122 "Number of inputs {} doesn't match stored indices {}",
123 inputs.len(),
124 self.ixs.len()
125 );
126
127 if inputs.len() == 1 {
129 let grad_x = if A::needs_argmax() {
130 let argmax = self
132 .argmax_cache
133 .first()
134 .expect("Tropical unary backward requires argmax from forward pass");
135 backward::tropical_unary_backward::<T, B>(grad_output, argmax, inputs[0].shape())
136 } else {
137 backward::contract_unary_backward::<A, T, B>(
141 grad_output,
142 &self.ixs[0],
143 &self.iy,
144 &self.size_dict,
145 )
146 };
147 return vec![grad_x];
148 }
149
150 if inputs.len() == 2 {
152 let argmax = if A::needs_argmax() && !self.argmax_cache.is_empty() {
153 Some(&self.argmax_cache[0])
154 } else {
155 None
156 };
157
158 let (grad_a, grad_b) = backward::contract_binary_backward::<A, T, B>(
159 grad_output,
160 inputs[0],
161 inputs[1],
162 argmax,
163 &self.ixs[0],
164 &self.ixs[1],
165 &self.iy,
166 );
167
168 return vec![grad_a, grad_b];
169 }
170
171 unimplemented!(
181 "Backward pass for {} inputs not yet implemented. \
182 Currently only 2-input contractions are supported.",
183 inputs.len()
184 )
185 }
186}
187
188fn infer_size_dict<T: Scalar, B: Backend>(
190 tensors: &[&Tensor<T, B>],
191 ixs: &[&[usize]],
192) -> std::collections::HashMap<usize, usize> {
193 let mut size_dict = std::collections::HashMap::new();
194
195 for (tensor, ix) in tensors.iter().zip(ixs.iter()) {
196 assert_eq!(
197 tensor.ndim(),
198 ix.len(),
199 "Index count {} doesn't match tensor ndim {}",
200 ix.len(),
201 tensor.ndim()
202 );
203
204 for (dim, &label) in ix.iter().enumerate() {
205 let size = tensor.shape()[dim];
206 if let Some(&existing) = size_dict.get(&label) {
207 assert_eq!(
208 existing, size,
209 "Inconsistent size for index {}: {} vs {}",
210 label, existing, size
211 );
212 } else {
213 size_dict.insert(label, size);
214 }
215 }
216 }
217
218 size_dict
219}