// omeinsum/einsum/mod.rs

1//! Einstein summation engine with contraction order optimization.
2//!
3//! This module provides the [`Einsum`] type for specifying and executing
4//! tensor network contractions, with optional optimization via omeco.
5
6mod backward;
7mod builder;
8mod engine;
9
10pub use builder::EinBuilder;
11pub use engine::Einsum;
12
13use crate::algebra::{Algebra, Scalar};
14use crate::backend::{Backend, BackendScalar};
15use crate::tensor::Tensor;
16
17/// One-shot einsum with automatic optimization.
18///
19/// # Arguments
20///
21/// * `tensors` - Input tensors
22/// * `ixs` - Index labels for each input tensor
23/// * `iy` - Output index labels
24///
25/// # Example
26///
27/// ```rust
28/// use omeinsum::{einsum, Tensor, Cpu};
29/// use omeinsum::algebra::MaxPlus;
30///
31/// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
32/// let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 4]);
33///
34/// // C[i,k] = max_j (A[i,j] + B[j,k])
35/// let c = einsum::<MaxPlus<f32>, _, _>(&[&a, &b], &[&[0, 1], &[1, 2]], &[0, 2]);
36/// assert_eq!(c.shape(), &[2, 4]);
37/// ```
38pub fn einsum<A, T, B>(tensors: &[&Tensor<T, B>], ixs: &[&[usize]], iy: &[usize]) -> Tensor<T, B>
39where
40    A: Algebra<Scalar = T, Index = u32>,
41    T: Scalar + BackendScalar<B>,
42    B: Backend + Default,
43{
44    let size_dict = infer_size_dict(tensors, ixs);
45    let ixs_owned: Vec<Vec<usize>> = ixs.iter().map(|ix| ix.to_vec()).collect();
46
47    let mut ein = Einsum::new(ixs_owned, iy.to_vec(), size_dict);
48    ein.optimize_greedy();
49    ein.execute::<A, T, B>(tensors)
50}
51
52/// Einsum with gradient computation.
53///
54/// Returns `(result, gradient_fn)` where `gradient_fn` can be called
55/// with the output gradient to compute input gradients.
56///
57/// For Standard algebra, gradients are computed via einsum (no argmax tracking needed).
58/// For tropical algebras, argmax is tracked during forward pass for gradient routing.
59pub fn einsum_with_grad<A, T, B>(
60    tensors: &[&Tensor<T, B>],
61    ixs: &[&[usize]],
62    iy: &[usize],
63) -> (Tensor<T, B>, EinsumGradient<T, B>)
64where
65    A: Algebra<Scalar = T, Index = u32>,
66    T: Scalar + BackendScalar<B>,
67    B: Backend + Default,
68{
69    let size_dict = infer_size_dict(tensors, ixs);
70    let ixs_owned: Vec<Vec<usize>> = ixs.iter().map(|ix| ix.to_vec()).collect();
71
72    let mut ein = Einsum::new(ixs_owned.clone(), iy.to_vec(), size_dict.clone());
73    ein.optimize_greedy();
74
75    // Only track argmax for algebras that need it (tropical algebras)
76    // Standard algebra computes gradients via einsum, no argmax needed
77    let (result, argmax_cache) = if A::needs_argmax() {
78        ein.execute_with_argmax::<A, T, B>(tensors)
79    } else {
80        (ein.execute::<A, T, B>(tensors), Vec::new())
81    };
82
83    let gradient = EinsumGradient {
84        ixs: ixs_owned,
85        iy: iy.to_vec(),
86        size_dict,
87        argmax_cache,
88        _phantom: std::marker::PhantomData,
89    };
90
91    (result, gradient)
92}
93
/// Gradient computation helper for einsum.
pub struct EinsumGradient<T: Scalar, B: Backend> {
    // Index labels for each input tensor, as passed to the forward call.
    ixs: Vec<Vec<usize>>,
    // Output index labels of the forward contraction.
    iy: Vec<usize>,
    // Maps each index label to its dimension size (inferred in forward pass).
    size_dict: std::collections::HashMap<usize, usize>,
    // Argmax positions recorded during forward execution; empty when the
    // algebra does not need argmax tracking (e.g. Standard algebra).
    argmax_cache: Vec<Tensor<u32, B>>,
    // Ties the scalar type `T` to this struct without storing a `T` value.
    _phantom: std::marker::PhantomData<T>,
}
102
103impl<T: Scalar + BackendScalar<B>, B: Backend + Default> EinsumGradient<T, B> {
104    /// Compute gradients for all inputs given the output gradient.
105    ///
106    /// # Arguments
107    ///
108    /// * `grad_output` - Gradient of the einsum output
109    /// * `inputs` - Original input tensors (same as passed to forward)
110    ///
111    /// # Returns
112    ///
113    /// Vector of gradients, one for each input tensor.
114    pub fn backward<A: Algebra<Scalar = T, Index = u32>>(
115        &self,
116        grad_output: &Tensor<T, B>,
117        inputs: &[&Tensor<T, B>],
118    ) -> Vec<Tensor<T, B>> {
119        assert_eq!(
120            inputs.len(),
121            self.ixs.len(),
122            "Number of inputs {} doesn't match stored indices {}",
123            inputs.len(),
124            self.ixs.len()
125        );
126
127        // Handle single input case
128        if inputs.len() == 1 {
129            let grad_x = if A::needs_argmax() {
130                // Tropical algebras: route gradients through argmax
131                let argmax = self
132                    .argmax_cache
133                    .first()
134                    .expect("Tropical unary backward requires argmax from forward pass");
135                backward::tropical_unary_backward::<T, B>(grad_output, argmax, inputs[0].shape())
136            } else {
137                // Standard algebra: use index-exchange trick
138                // Forward: y = einsum(ix -> iy, x)
139                // Backward: grad_x = einsum(iy -> ix, grad_y)
140                backward::contract_unary_backward::<A, T, B>(
141                    grad_output,
142                    &self.ixs[0],
143                    &self.iy,
144                    &self.size_dict,
145                )
146            };
147            return vec![grad_x];
148        }
149
150        // For a single binary contraction (2 inputs), we can directly compute gradients
151        if inputs.len() == 2 {
152            let argmax = if A::needs_argmax() && !self.argmax_cache.is_empty() {
153                Some(&self.argmax_cache[0])
154            } else {
155                None
156            };
157
158            let (grad_a, grad_b) = backward::contract_binary_backward::<A, T, B>(
159                grad_output,
160                inputs[0],
161                inputs[1],
162                argmax,
163                &self.ixs[0],
164                &self.ixs[1],
165                &self.iy,
166            );
167
168            return vec![grad_a, grad_b];
169        }
170
171        // For more complex contractions with >2 tensors, we need to reverse through
172        // the contraction tree. This requires storing intermediate results from forward pass.
173        // For now, implement the simple case.
174        //
175        // TODO: Implement full backward pass for multi-tensor contractions
176        // This would require:
177        // 1. Storing intermediate results during forward pass
178        // 2. Reversing through the contraction tree
179        // 3. Accumulating gradients for each input
180        unimplemented!(
181            "Backward pass for {} inputs not yet implemented. \
182             Currently only 2-input contractions are supported.",
183            inputs.len()
184        )
185    }
186}
187
188/// Infer size dictionary from tensors and their index labels.
189fn infer_size_dict<T: Scalar, B: Backend>(
190    tensors: &[&Tensor<T, B>],
191    ixs: &[&[usize]],
192) -> std::collections::HashMap<usize, usize> {
193    let mut size_dict = std::collections::HashMap::new();
194
195    for (tensor, ix) in tensors.iter().zip(ixs.iter()) {
196        assert_eq!(
197            tensor.ndim(),
198            ix.len(),
199            "Index count {} doesn't match tensor ndim {}",
200            ix.len(),
201            tensor.ndim()
202        );
203
204        for (dim, &label) in ix.iter().enumerate() {
205            let size = tensor.shape()[dim];
206            if let Some(&existing) = size_dict.get(&label) {
207                assert_eq!(
208                    existing, size,
209                    "Inconsistent size for index {}: {} vs {}",
210                    label, existing, size
211                );
212            } else {
213                size_dict.insert(label, size);
214            }
215        }
216    }
217
218    size_dict
219}