omeinsum/einsum/engine.rs

//! Einsum execution engine with omeco integration.

use std::collections::{HashMap, HashSet};

use omeco::{optimize_code, EinCode, GreedyMethod, Label, NestedEinsum, TreeSA};

use crate::algebra::{Algebra, Scalar};
use crate::backend::{Backend, BackendScalar};
use crate::tensor::Tensor;

/// Einsum specification and execution engine.
///
/// Supports contraction order optimization via omeco.
///
/// # Example
///
/// ```rust
/// use omeinsum::{Einsum, Tensor, Cpu};
/// use omeinsum::algebra::MaxPlus;
/// use std::collections::HashMap;
///
/// // A[i,j] × B[j,k] → C[i,k]
/// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
/// let b = Tensor::<f32, Cpu>::from_data(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
///
/// let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
/// let mut ein = Einsum::new(
///     vec![vec![0, 1], vec![1, 2]],
///     vec![0, 2],
///     sizes,
/// );
///
/// ein.optimize_greedy();
/// let result = ein.execute::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);
/// assert_eq!(result.shape(), &[2, 2]);
/// ```
pub struct Einsum<L: Label = usize> {
    /// Input index labels for each tensor
    pub ixs: Vec<Vec<L>>,

    /// Output index labels
    pub iy: Vec<L>,

    /// Dimension sizes for each index
    pub size_dict: HashMap<L, usize>,

    /// Optimized contraction tree (after optimization)
    optimized: Option<NestedEinsum<L>>,
}

impl<L: Label> Einsum<L> {
    /// Create a new einsum specification.
    ///
    /// # Arguments
    ///
    /// * `ixs` - Index labels for each input tensor
    /// * `iy` - Output index labels
    /// * `size_dict` - Mapping from index labels to dimension sizes
    pub fn new(ixs: Vec<Vec<L>>, iy: Vec<L>, size_dict: HashMap<L, usize>) -> Self {
        Self {
            ixs,
            iy,
            size_dict,
            optimized: None,
        }
    }

    /// Get the einsum code specification.
    pub fn code(&self) -> EinCode<L> {
        EinCode::new(self.ixs.clone(), self.iy.clone())
    }

    /// Optimize contraction order using greedy algorithm.
    ///
    /// Fast O(n²) algorithm, good for most cases.
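    ///
    /// # Example
    ///
    /// A minimal sketch of the ordering step on the same network as the
    /// type-level example above (execution itself is unchanged; see
    /// [`Self::execute`]):
    ///
    /// ```rust
    /// use omeinsum::Einsum;
    /// use std::collections::HashMap;
    ///
    /// let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
    /// let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);
    ///
    /// assert!(!ein.is_optimized());
    /// ein.optimize_greedy();
    /// assert!(ein.is_optimized());
    /// ```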
    pub fn optimize_greedy(&mut self) -> &mut Self {
        let code = self.code();
        let optimizer = GreedyMethod::new(0.0, 0.0);
        self.optimized = optimize_code(&code, &self.size_dict, &optimizer);
        self
    }

    /// Optimize contraction order using simulated annealing.
    ///
    /// Slower but finds better orderings for complex networks.
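    ///
    /// Usage mirrors [`Self::optimize_greedy`]; only the optimizer differs. A
    /// minimal sketch (marked `no_run` since annealing time can vary):
    ///
    /// ```rust,no_run
    /// use omeinsum::Einsum;
    /// use std::collections::HashMap;
    ///
    /// let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
    /// let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);
    /// ein.optimize_treesa();
    /// ```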
    pub fn optimize_treesa(&mut self) -> &mut Self {
        let code = self.code();
        let optimizer = TreeSA::default();
        self.optimized = optimize_code(&code, &self.size_dict, &optimizer);
        self
    }

    /// Check if optimization has been performed.
    pub fn is_optimized(&self) -> bool {
        self.optimized.is_some()
    }

    /// Get the optimized contraction tree.
    pub fn contraction_tree(&self) -> Option<&NestedEinsum<L>> {
        self.optimized.as_ref()
    }
}

impl Einsum<usize> {
    /// Execute the einsum contraction.
    ///
    /// # Type Parameters
    ///
    /// * `A` - The algebra to use (e.g., `Standard<f32>`, `MaxPlus<f32>`)
    /// * `T` - The scalar type
    /// * `B` - The backend type
    pub fn execute<A, T, B>(&self, tensors: &[&Tensor<T, B>]) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        assert_eq!(
            tensors.len(),
            self.ixs.len(),
            "Number of tensors {} doesn't match number of index specs {}",
            tensors.len(),
            self.ixs.len()
        );

        match &self.optimized {
            Some(tree) => {
                // Handle top-level Leaf (single tensor) specially to apply unary transformations
                if let NestedEinsum::Leaf { tensor_index } = tree {
                    execute_unary_naive::<A, T, B>(
                        tensors[*tensor_index],
                        &self.ixs[*tensor_index],
                        &self.iy,
                        &self.size_dict,
                    )
                } else {
                    self.execute_tree::<A, T, B>(tree, tensors)
                }
            }
            None => self.execute_pairwise::<A, T, B>(tensors),
        }
    }

    /// Execute with argmax tracking for backpropagation.
    ///
    /// Returns `(result, argmax_cache)` where `argmax_cache` contains argmax
    /// tensors for each binary contraction in the execution tree.
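    ///
    /// # Example
    ///
    /// A minimal sketch using the max-plus algebra (the values mirror the
    /// tropical matrix-product test below; requires the `tropical` feature):
    ///
    /// ```rust
    /// use omeinsum::{Einsum, Tensor, Cpu};
    /// use omeinsum::algebra::MaxPlus;
    /// use std::collections::HashMap;
    ///
    /// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    ///
    /// let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
    /// let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);
    /// ein.optimize_greedy();
    ///
    /// // C[i,k] = max_j (A[i,j] + B[j,k]); MaxPlus records one argmax tensor
    /// // per binary contraction for later backpropagation.
    /// let (result, argmax_cache) =
    ///     ein.execute_with_argmax::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);
    /// assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
    /// assert!(!argmax_cache.is_empty());
    /// ```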
    pub fn execute_with_argmax<A, T, B>(
        &self,
        tensors: &[&Tensor<T, B>],
    ) -> (Tensor<T, B>, Vec<Tensor<u32, B>>)
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        assert_eq!(
            tensors.len(),
            self.ixs.len(),
            "Number of tensors {} doesn't match number of index specs {}",
            tensors.len(),
            self.ixs.len()
        );

        let mut argmax_cache = Vec::new();

        let result = match &self.optimized {
            Some(tree) => {
                // Handle top-level Leaf (single tensor) specially to apply unary transformations
                if let NestedEinsum::Leaf { tensor_index } = tree {
                    if A::needs_argmax() {
                        let (result, argmax) = execute_unary_with_argmax::<A, T, B>(
                            tensors[*tensor_index],
                            &self.ixs[*tensor_index],
                            &self.iy,
                            &self.size_dict,
                        );
                        argmax_cache.push(argmax);
                        result
                    } else {
                        execute_unary_naive::<A, T, B>(
                            tensors[*tensor_index],
                            &self.ixs[*tensor_index],
                            &self.iy,
                            &self.size_dict,
                        )
                    }
                } else {
                    self.execute_tree_with_argmax::<A, T, B>(tree, tensors, &mut argmax_cache)
                }
            }
            None => self.execute_pairwise_with_argmax::<A, T, B>(tensors, &mut argmax_cache),
        };

        (result, argmax_cache)
    }

    /// Execute an optimized contraction tree with argmax tracking.
    #[allow(clippy::only_used_in_recursion)]
    fn execute_tree_with_argmax<A, T, B>(
        &self,
        tree: &NestedEinsum<usize>,
        tensors: &[&Tensor<T, B>],
        argmax_cache: &mut Vec<Tensor<u32, B>>,
    ) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        match tree {
            NestedEinsum::Leaf { tensor_index } => tensors[*tensor_index].clone(),
            NestedEinsum::Node { args, eins } => {
                assert_eq!(args.len(), 2, "Expected binary contraction tree");

                let left =
                    self.execute_tree_with_argmax::<A, T, B>(&args[0], tensors, argmax_cache);
                let right =
                    self.execute_tree_with_argmax::<A, T, B>(&args[1], tensors, argmax_cache);

                let ia = &eins.ixs[0];
                let ib = &eins.ixs[1];
                let iy = &eins.iy;

                if A::needs_argmax() {
                    let (result, argmax) =
                        left.contract_binary_with_argmax::<A>(&right, ia, ib, iy);
                    argmax_cache.push(argmax);
                    result
                } else {
                    left.contract_binary::<A>(&right, ia, ib, iy)
                }
            }
        }
    }

    /// Execute pairwise contraction with argmax tracking.
    fn execute_pairwise_with_argmax<A, T, B>(
        &self,
        tensors: &[&Tensor<T, B>],
        argmax_cache: &mut Vec<Tensor<u32, B>>,
    ) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        if tensors.is_empty() {
            panic!("Cannot execute einsum with no tensors");
        }

        if tensors.len() == 1 {
            if A::needs_argmax() {
                let (result, argmax) = execute_unary_with_argmax::<A, T, B>(
                    tensors[0],
                    &self.ixs[0],
                    &self.iy,
                    &self.size_dict,
                );
                argmax_cache.push(argmax);
                return result;
            } else {
                return execute_unary_naive::<A, T, B>(
                    tensors[0],
                    &self.ixs[0],
                    &self.iy,
                    &self.size_dict,
                );
            }
        }

        // Contract left to right
        let mut result = tensors[0].clone();
        let mut current_indices = self.ixs[0].clone();

        for i in 1..tensors.len() {
            let other = tensors[i];
            let other_indices = &self.ixs[i];

            let intermediate_output = if i == tensors.len() - 1 {
                self.iy.clone()
            } else {
                compute_intermediate_output(&current_indices, other_indices, &self.iy)
            };

            if A::needs_argmax() {
                let (new_result, argmax) = result.contract_binary_with_argmax::<A>(
                    other,
                    &current_indices,
                    other_indices,
                    &intermediate_output,
                );
                argmax_cache.push(argmax);
                result = new_result;
            } else {
                result = result.contract_binary::<A>(
                    other,
                    &current_indices,
                    other_indices,
                    &intermediate_output,
                );
            }
            current_indices = intermediate_output;
        }

        result
    }

    /// Execute an optimized contraction tree.
    #[allow(clippy::only_used_in_recursion)]
    fn execute_tree<A, T, B>(
        &self,
        tree: &NestedEinsum<usize>,
        tensors: &[&Tensor<T, B>],
    ) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        match tree {
            NestedEinsum::Leaf { tensor_index } => tensors[*tensor_index].clone(),
            NestedEinsum::Node { args, eins } => {
                assert_eq!(args.len(), 2, "Expected binary contraction tree");

                let left = self.execute_tree::<A, T, B>(&args[0], tensors);
                let right = self.execute_tree::<A, T, B>(&args[1], tensors);

                let ia = &eins.ixs[0];
                let ib = &eins.ixs[1];
                let iy = &eins.iy;

                left.contract_binary::<A>(&right, ia, ib, iy)
            }
        }
    }

    /// Execute using simple pairwise contraction (no optimization).
    fn execute_pairwise<A, T, B>(&self, tensors: &[&Tensor<T, B>]) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        if tensors.is_empty() {
            panic!("Cannot execute einsum with no tensors");
        }

        if tensors.len() == 1 {
            // Single tensor: just trace/reduce if needed
            return execute_unary_naive::<A, T, B>(
                tensors[0],
                &self.ixs[0],
                &self.iy,
                &self.size_dict,
            );
        }

        // Contract left to right
        let mut result = tensors[0].clone();
        let mut current_indices = self.ixs[0].clone();

        for i in 1..tensors.len() {
            let other = tensors[i];
            let other_indices = &self.ixs[i];

            // Determine output indices for this contraction
            let intermediate_output = if i == tensors.len() - 1 {
                // Last contraction: use final output
                self.iy.clone()
            } else {
                // Intermediate: keep all non-contracted indices
                compute_intermediate_output(&current_indices, other_indices, &self.iy)
            };

            result = result.contract_binary::<A>(
                other,
                &current_indices,
                other_indices,
                &intermediate_output,
            );
            current_indices = intermediate_output;
        }

        result
    }
}

/// Compute intermediate output indices for pairwise contraction.
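///
/// For example, contracting `ia = [0, 1]` with `ib = [1, 2]` toward a final
/// output of `[0, 2]` yields `[0, 2]`: label `1` appears in both inputs and not
/// in the final output, so it is contracted away at this step.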
fn compute_intermediate_output(ia: &[usize], ib: &[usize], final_output: &[usize]) -> Vec<usize> {
    let final_set: HashSet<_> = final_output.iter().copied().collect();
    let ia_set: HashSet<_> = ia.iter().copied().collect();
    let ib_set: HashSet<_> = ib.iter().copied().collect();

    // Keep indices that are in the final output OR appear in only one input
    let mut output = Vec::new();

    for &i in ia {
        if (final_set.contains(&i) || !ib_set.contains(&i)) && !output.contains(&i) {
            output.push(i);
        }
    }

    for &i in ib {
        if (final_set.contains(&i) || !ia_set.contains(&i)) && !output.contains(&i) {
            output.push(i);
        }
    }

    output
}

/// Convert linear index to multi-dimensional index (column-major).
///
/// Given a flat/linear index and a shape, returns the multi-dimensional
/// coordinates for column-major storage order.
///
/// # Arguments
///
/// * `linear` - The flat index into the tensor
/// * `shape` - The shape of the tensor
///
/// # Returns
///
/// A vector of indices, one per dimension
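///
/// For example, with `shape = [2, 3]` (column-major strides `[1, 2]`),
/// `linear = 3` maps to `[1, 1]` and `linear = 4` maps to `[0, 2]`.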
fn linear_to_multi(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![];
    }
    let mut multi = vec![0; shape.len()];
    for i in 0..shape.len() {
        multi[i] = linear % shape[i];
        linear /= shape[i];
    }
    multi
}

/// Compute input tensor position from index values (column-major).
///
/// Given index labels and their current values, computes the flat position
/// in the input tensor using column-major ordering.
///
/// # Arguments
///
/// * `ix` - The index labels for the input tensor
/// * `idx_values` - Mapping from index label to current value
/// * `shape` - The shape of the input tensor
///
/// # Returns
///
/// The flat position in the tensor
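///
/// For example, with `ix = [0, 1]`, `shape = [2, 3]`, and index values
/// `{0: 1, 1: 2}`, the position is `1 * 1 + 2 * 2 = 5` (column-major strides `[1, 2]`).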
fn compute_input_position(
    ix: &[usize],
    idx_values: &HashMap<usize, usize>,
    shape: &[usize],
) -> usize {
    let mut pos = 0;
    let mut stride = 1;
    for (dim, &idx) in ix.iter().enumerate() {
        pos += idx_values[&idx] * stride;
        stride *= shape[dim];
    }
    pos
}

/// Execute a unary einsum operation using a naive loop.
/// Handles trace, diagonal extraction, axis sums, and permutation uniformly.
///
/// # Type Parameters
///
/// * `A` - The algebra to use for accumulation
/// * `T` - The scalar type
/// * `B` - The backend type
///
/// # Arguments
///
/// * `tensor` - The input tensor
/// * `ix` - Input index labels (may contain repeated indices for trace/diagonal)
/// * `iy` - Output index labels
/// * `size_dict` - Mapping from index labels to dimension sizes
///
/// # Key Insight
///
/// For repeated indices like `ix = [0, 1, 1, 2]` (ijjk), positions 1 and 2 both map
/// to index label `1`. This automatically handles diagonal extraction because
/// `compute_input_position` uses `idx_values[&idx]` - when the same index label
/// appears multiple times in `ix`, those positions will use the same value.
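///
/// For example, with `ix = [0, 0]` and `iy = [0]`, both input positions are
/// driven by the same label value, so only the diagonal entries `A[i, i]` are
/// ever read and the output is the diagonal (see `test_unary_naive_diagonal`).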
#[allow(clippy::needless_range_loop)]
pub(crate) fn execute_unary_naive<A, T, B>(
    tensor: &Tensor<T, B>,
    ix: &[usize],
    iy: &[usize],
    size_dict: &HashMap<usize, usize>,
) -> Tensor<T, B>
where
    A: Algebra<Scalar = T>,
    T: Scalar,
    B: Backend + Default,
{
    // 1. Classify indices
    // outer = output indices
    // inner = indices that appear in input but not in output (summed over)
    let outer: &[usize] = iy;
    let outer_set: HashSet<usize> = outer.iter().copied().collect();
    // Collect inner indices deterministically, preserving the order from `ix`
    let mut inner_vec: Vec<usize> = Vec::new();
    let mut seen: HashSet<usize> = HashSet::new();
    for i in ix.iter().copied().filter(|i| !outer_set.contains(i)) {
        if seen.insert(i) {
            inner_vec.push(i);
        }
    }

    // 2. Build output shape
    let out_shape: Vec<usize> = outer.iter().map(|&idx| size_dict[&idx]).collect();
    let out_size = out_shape.iter().product::<usize>().max(1);

    // 3. Build inner ranges (dimensions to sum over)
    let inner_ranges: Vec<usize> = inner_vec.iter().map(|&idx| size_dict[&idx]).collect();
    let inner_size = inner_ranges.iter().product::<usize>().max(1);

    // 4. Allocate output
    let mut out_data = vec![A::zero().to_scalar(); out_size];

    // 5. Loop over output positions
    for out_linear in 0..out_size {
        let out_multi = linear_to_multi(out_linear, &out_shape);

        // Map: outer index label -> value
        // For repeated output indices (like `ii`), check consistency
        let mut idx_values: HashMap<usize, usize> = HashMap::new();
        let mut skip_position = false;

        for (&idx, &val) in outer.iter().zip(out_multi.iter()) {
            if let Some(&existing) = idx_values.get(&idx) {
                // Repeated index label - values must match
                if existing != val {
                    skip_position = true;
                    break;
                }
            } else {
                idx_values.insert(idx, val);
            }
        }

        // Skip non-diagonal positions for repeated output indices
        if skip_position {
            // out_data[out_linear] is already zero
            continue;
        }

        // 6. Accumulate over inner indices
        let mut acc = A::zero();
        for inner_linear in 0..inner_size {
            let inner_multi = linear_to_multi(inner_linear, &inner_ranges);
            for (&idx, &val) in inner_vec.iter().zip(inner_multi.iter()) {
                idx_values.insert(idx, val);
            }

            // 7. Compute input position and accumulate
            let in_pos = compute_input_position(ix, &idx_values, tensor.shape());
            acc = acc.add(A::from_scalar(tensor.get(in_pos)));
        }

        out_data[out_linear] = acc.to_scalar();
    }

    if out_shape.is_empty() {
        Tensor::from_data(&out_data, &[])
    } else {
        Tensor::from_data(&out_data, &out_shape)
    }
}

/// Execute unary einsum with argmax tracking for tropical algebras.
///
/// Returns both the result tensor and an argmax tensor that tracks which
/// inner index position "won" for each output element.
///
/// The argmax tensor has the same shape as the output. Each element stores
/// the linear index into the input tensor that contributed to that output.
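///
/// For example, a max-plus trace of the column-major matrix `[[1, 3], [2, 4]]`
/// (`ix = [0, 0]`, `iy = []`) yields the value `4` with an argmax entry of `3`,
/// the flat column-major position of `A[1,1]` in the input.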
pub(crate) fn execute_unary_with_argmax<A, T, B>(
    tensor: &Tensor<T, B>,
    ix: &[usize],
    iy: &[usize],
    size_dict: &HashMap<usize, usize>,
) -> (Tensor<T, B>, Tensor<u32, B>)
where
    A: Algebra<Scalar = T, Index = u32>,
    T: Scalar,
    B: Backend + Default,
{
    // 1. Classify indices (same as execute_unary_naive)
    let outer: &[usize] = iy;
    let outer_set: HashSet<usize> = outer.iter().copied().collect();
    let mut inner_vec: Vec<usize> = Vec::new();
    let mut seen: HashSet<usize> = HashSet::new();
    for i in ix.iter().copied().filter(|i| !outer_set.contains(i)) {
        if seen.insert(i) {
            inner_vec.push(i);
        }
    }

    // 2. Build output shape
    let out_shape: Vec<usize> = outer.iter().map(|&idx| size_dict[&idx]).collect();
    let out_size = out_shape.iter().product::<usize>().max(1);

    // 3. Build inner ranges
    let inner_ranges: Vec<usize> = inner_vec.iter().map(|&idx| size_dict[&idx]).collect();
    let inner_size = inner_ranges.iter().product::<usize>().max(1);

    // 4. Allocate output and argmax
    let mut out_data = vec![A::zero().to_scalar(); out_size];
    let mut argmax_data = vec![0u32; out_size];

    // 5. Loop over output positions
    for out_linear in 0..out_size {
        let out_multi = linear_to_multi(out_linear, &out_shape);

        // Map: outer index label -> value
        let mut idx_values: HashMap<usize, usize> = HashMap::new();
        let mut skip_position = false;

        for (&idx, &val) in outer.iter().zip(out_multi.iter()) {
            if let Some(&existing) = idx_values.get(&idx) {
                if existing != val {
                    skip_position = true;
                    break;
                }
            } else {
                idx_values.insert(idx, val);
            }
        }

        if skip_position {
            continue;
        }

        // 6. Find max over inner indices (tropical-style)
        let mut best_val = A::zero();
        let mut best_in_pos = 0usize;

        for inner_linear in 0..inner_size {
            let inner_multi = linear_to_multi(inner_linear, &inner_ranges);
            for (&idx, &val) in inner_vec.iter().zip(inner_multi.iter()) {
                idx_values.insert(idx, val);
            }

            let in_pos = compute_input_position(ix, &idx_values, tensor.shape());
            let val = A::from_scalar(tensor.get(in_pos));

            // For first iteration or if this value is better
            if inner_linear == 0 || A::is_better(&val, &best_val) {
                best_val = val;
                best_in_pos = in_pos;
            }
        }

        out_data[out_linear] = best_val.to_scalar();
        argmax_data[out_linear] = best_in_pos as u32;
    }

    let result = if out_shape.is_empty() {
        Tensor::from_data(&out_data, &[])
    } else {
        Tensor::from_data(&out_data, &out_shape)
    };

    let argmax = if out_shape.is_empty() {
        Tensor::from_data(&argmax_data, &[])
    } else {
        Tensor::from_data(&argmax_data, &out_shape)
    };

    (result, argmax)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Standard;
    use crate::backend::Cpu;

    #[cfg(feature = "tropical")]
    use crate::algebra::MaxPlus;

    #[test]
    fn test_einsum_matmul() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        // Without optimization
        let c1 = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);
        assert_eq!(c1.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);

        // With optimization
        ein.optimize_greedy();
        let c2 = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);
        assert_eq!(c2.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_tropical() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        let c = ein.execute::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);
        assert_eq!(c.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_einsum_chain() {
        // A[i,j] × B[j,k] × C[k,l] → D[i,l]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let c = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2), (3, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2], vec![2, 3]], vec![0, 3], sizes);

        ein.optimize_greedy();
        let d = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b, &c]);

        assert_eq!(d.shape(), &[2, 2]);
    }

    #[test]
    fn test_einsum_trace() {
        // Trace: A[i,i] -> scalar (sum of diagonal)
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        // trace = 1 + 4 = 5
        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_einsum_diagonal() {
        // Diagonal: A[i,i] -> B[i] (extract diagonal)
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![0], sizes);

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        // diagonal = [1, 4]
        assert_eq!(result.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_einsum_sum_axis() {
        // Reduction: A[i,j] -> B[i] (sum over j)
        // Column-major: data [1,2,3,4] for shape [2,2] represents:
        // [[1, 3],
        //  [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1]], vec![0], sizes);

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        // sum over j: [1+3, 2+4] = [4, 6]
        assert_eq!(result.to_vec(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_einsum_sum_all() {
        // Sum all: A[i,j] -> scalar
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1]], vec![], sizes);

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        // sum = 1 + 2 + 3 + 4 = 10
        assert_eq!(result.to_vec()[0], 10.0);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_trace_tropical() {
        // Trace with max-plus algebra: A[i,i] -> scalar
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        let result = ein.execute::<MaxPlus<f32>, f32, Cpu>(&[&a]);
        // tropical trace = max(1, 4) = 4
        assert_eq!(result.to_vec()[0], 4.0);
    }

    // Tests for helper functions

    #[test]
    fn test_linear_to_multi_empty_shape() {
        // Empty shape should return empty multi-index
        let result = linear_to_multi(0, &[]);
        assert_eq!(result, vec![]);
    }

    #[test]
    fn test_linear_to_multi_1d() {
        // 1D array: linear index equals multi-index
        assert_eq!(linear_to_multi(0, &[5]), vec![0]);
        assert_eq!(linear_to_multi(3, &[5]), vec![3]);
        assert_eq!(linear_to_multi(4, &[5]), vec![4]);
    }

    #[test]
    fn test_linear_to_multi_2d() {
        // 2D array with shape [2, 3] (column-major)
        // Linear 0 -> (0, 0)
        // Linear 1 -> (1, 0)
        // Linear 2 -> (0, 1)
        // Linear 3 -> (1, 1)
        // Linear 4 -> (0, 2)
        // Linear 5 -> (1, 2)
        assert_eq!(linear_to_multi(0, &[2, 3]), vec![0, 0]);
        assert_eq!(linear_to_multi(1, &[2, 3]), vec![1, 0]);
        assert_eq!(linear_to_multi(2, &[2, 3]), vec![0, 1]);
        assert_eq!(linear_to_multi(3, &[2, 3]), vec![1, 1]);
        assert_eq!(linear_to_multi(4, &[2, 3]), vec![0, 2]);
        assert_eq!(linear_to_multi(5, &[2, 3]), vec![1, 2]);
    }

    #[test]
    fn test_linear_to_multi_3d() {
        // 3D array with shape [2, 3, 4] (column-major)
        // Strides: [1, 2, 6]
        // Linear 0 -> (0, 0, 0)
        // Linear 1 -> (1, 0, 0)
        // Linear 2 -> (0, 1, 0)
        // Linear 6 -> (0, 0, 1)
        // Linear 7 -> (1, 0, 1)
        assert_eq!(linear_to_multi(0, &[2, 3, 4]), vec![0, 0, 0]);
        assert_eq!(linear_to_multi(1, &[2, 3, 4]), vec![1, 0, 0]);
        assert_eq!(linear_to_multi(2, &[2, 3, 4]), vec![0, 1, 0]);
        assert_eq!(linear_to_multi(6, &[2, 3, 4]), vec![0, 0, 1]);
        assert_eq!(linear_to_multi(7, &[2, 3, 4]), vec![1, 0, 1]);
        // Last element: linear 23 -> (1, 2, 3)
        assert_eq!(linear_to_multi(23, &[2, 3, 4]), vec![1, 2, 3]);
    }

    #[test]
    fn test_compute_input_position_1d() {
        // 1D tensor with index label 0
        let ix = vec![0];
        let shape = vec![5];

        let mut idx_values = HashMap::new();
        idx_values.insert(0, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 0);

        idx_values.insert(0, 3);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 3);
    }

    #[test]
    fn test_compute_input_position_2d() {
        // 2D tensor with shape [2, 3], index labels (0, 1)
        // Column-major: position = i + j * 2
        let ix = vec![0, 1];
        let shape = vec![2, 3];

        let mut idx_values = HashMap::new();

        // (0, 0) -> position 0
        idx_values.insert(0, 0);
        idx_values.insert(1, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 0);

        // (1, 0) -> position 1
        idx_values.insert(0, 1);
        idx_values.insert(1, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 1);

        // (0, 1) -> position 2
        idx_values.insert(0, 0);
        idx_values.insert(1, 1);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 2);

        // (1, 2) -> position 1 + 2*2 = 5
        idx_values.insert(0, 1);
        idx_values.insert(1, 2);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 5);
    }

    #[test]
    fn test_compute_input_position_3d() {
        // 3D tensor with shape [2, 3, 4], index labels (0, 1, 2)
        // Column-major: position = i + j * 2 + k * 6
        let ix = vec![0, 1, 2];
        let shape = vec![2, 3, 4];

        let mut idx_values = HashMap::new();

        // (0, 0, 0) -> position 0
        idx_values.insert(0, 0);
        idx_values.insert(1, 0);
        idx_values.insert(2, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 0);

        // (1, 0, 0) -> position 1
        idx_values.insert(0, 1);
        idx_values.insert(1, 0);
        idx_values.insert(2, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 1);

        // (0, 1, 0) -> position 2
        idx_values.insert(0, 0);
        idx_values.insert(1, 1);
        idx_values.insert(2, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 2);

        // (0, 0, 1) -> position 6
        idx_values.insert(0, 0);
        idx_values.insert(1, 0);
        idx_values.insert(2, 1);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 6);

        // (1, 2, 3) -> position 1 + 2*2 + 3*6 = 1 + 4 + 18 = 23
        idx_values.insert(0, 1);
        idx_values.insert(1, 2);
        idx_values.insert(2, 3);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 23);
    }

    #[test]
    fn test_linear_to_multi_roundtrip() {
        // Verify that linear_to_multi and compute_input_position are consistent
        let shape = vec![2, 3, 4];
        let ix: Vec<usize> = (0..shape.len()).collect();
        let total_size: usize = shape.iter().product();

        for linear in 0..total_size {
            let multi = linear_to_multi(linear, &shape);

            // Build idx_values from multi
            let mut idx_values = HashMap::new();
            for (dim, &val) in multi.iter().enumerate() {
                idx_values.insert(dim, val);
            }

            let computed_pos = compute_input_position(&ix, &idx_values, &shape);
            assert_eq!(
                computed_pos, linear,
                "Roundtrip failed for linear={}, multi={:?}",
                linear, multi
            );
        }
    }

    // ========================================================================
    // Tests for execute_unary_naive
    // ========================================================================

    #[test]
    fn test_unary_naive_transpose() {
        // Transpose: A[i,j] -> B[j,i]
        // Input matrix (column-major): [[1, 3], [2, 4]]
        // data = [1, 2, 3, 4], shape = [2, 2]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1]; // A[i,j]
        let iy = vec![1, 0]; // B[j,i]

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // After transpose: [[1, 2], [3, 4]] in column-major = [1, 3, 2, 4]
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_unary_naive_trace() {
        // Trace: A[i,i] -> scalar (sum of diagonal)
        // Matrix (column-major): [[1, 3], [2, 4]]
        // data = [1, 2, 3, 4], shape = [2, 2]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2)].into();
        let ix = vec![0, 0]; // A[i,i] - repeated index means diagonal
        let iy = vec![]; // scalar output

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // trace = A[0,0] + A[1,1] = 1 + 4 = 5
        assert_eq!(result.shape(), &[]);
        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_unary_naive_diagonal() {
        // Diagonal extraction: A[i,i] -> B[i]
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2)].into();
        let ix = vec![0, 0]; // A[i,i] - repeated index
        let iy = vec![0]; // output B[i]

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // diagonal = [A[0,0], A[1,1]] = [1, 4]
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_unary_naive_sum_axis() {
        // Sum over axis: A[i,j] -> B[i] (sum over j)
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1]; // A[i,j]
        let iy = vec![0]; // B[i] - j is summed out

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // sum over j: B[i] = sum_j A[i,j]
        // B[0] = A[0,0] + A[0,1] = 1 + 3 = 4
        // B[1] = A[1,0] + A[1,1] = 2 + 4 = 6
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result.to_vec(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_unary_naive_sum_all() {
        // Sum all: A[i,j] -> scalar
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1]; // A[i,j]
        let iy = vec![]; // scalar output

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // sum all = 1 + 2 + 3 + 4 = 10
        assert_eq!(result.shape(), &[]);
        assert_eq!(result.to_vec()[0], 10.0);
    }

    #[test]
    fn test_unary_naive_partial_trace() {
        // Partial trace: A[i,j,i] -> B[j] (trace over i, keeping j)
        // 3D tensor with shape [2, 3, 2]
        // This is like having a batch of 2x2 matrices and taking the trace of each
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let a = Tensor::<f32, Cpu>::from_data(&data, &[2, 3, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let ix = vec![0, 1, 0]; // A[i,j,i] - i is repeated at positions 0 and 2
        let iy = vec![1]; // B[j] - output keeps only j

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // For each j, we sum A[0,j,0] + A[1,j,1]
        // Column-major layout: data[i + j*2 + k*6]
        // j=0: A[0,0,0] + A[1,0,1] = data[0] + data[1+0*2+1*6] = data[0] + data[7] = 1 + 8 = 9
        // j=1: A[0,1,0] + A[1,1,1] = data[0+1*2+0*6] + data[1+1*2+1*6] = data[2] + data[9] = 3 + 10 = 13
        // j=2: A[0,2,0] + A[1,2,1] = data[0+2*2+0*6] + data[1+2*2+1*6] = data[4] + data[11] = 5 + 12 = 17
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.to_vec(), vec![9.0, 13.0, 17.0]);
    }

    #[test]
    fn test_unary_naive_3d_transpose() {
        // 3D permutation: A[i,j,k] -> B[k,i,j]
        let data: Vec<f32> = (1..=8).map(|x| x as f32).collect();
        let a = Tensor::<f32, Cpu>::from_data(&data, &[2, 2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let ix = vec![0, 1, 2]; // A[i,j,k]
        let iy = vec![2, 0, 1]; // B[k,i,j]

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2, 2, 2]);

        // Verify by checking specific elements
        // B[k,i,j] = A[i,j,k]
        // Build expected output manually
        let mut expected = vec![0.0f32; 8];
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    // A[i,j,k] at position i + j*2 + k*4 in column-major
                    let a_pos = i + j * 2 + k * 4;
                    // B[k,i,j] at position k + i*2 + j*4 in column-major
                    let b_pos = k + i * 2 + j * 4;
                    expected[b_pos] = data[a_pos];
                }
            }
        }
        assert_eq!(result.to_vec(), expected);
    }

    #[test]
    fn test_unary_naive_identity() {
        // Identity: A[i,j] -> B[i,j] (no change)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1]; // A[i,j]
        let iy = vec![0, 1]; // B[i,j]

        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), a.to_vec());
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_unary_naive_trace_tropical() {
        // Trace with max-plus algebra: A[i,i] -> scalar
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2)].into();
        let ix = vec![0, 0]; // A[i,i]
        let iy = vec![]; // scalar output

        let result = execute_unary_naive::<MaxPlus<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        // tropical trace = max(A[0,0], A[1,1]) = max(1, 4) = 4
        assert_eq!(result.shape(), &[]);
        assert_eq!(result.to_vec()[0], 4.0);
    }

    #[test]
    fn test_einsum_trace_optimized() {
        // Test that the optimized path correctly handles unary trace operations
        // Matrix (column-major): [[1, 3], [2, 4]]
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        // Optimize and execute
        ein.optimize_greedy();
        assert!(ein.is_optimized());

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);

        // trace = A[0,0] + A[1,1] = 1 + 4 = 5
        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_einsum_unary_with_argmax_optimized() {
        // Test execute_with_argmax for unary operations (optimized path)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        ein.optimize_greedy();
        let (result, argmax_cache) = ein.execute_with_argmax::<Standard<f32>, f32, Cpu>(&[&a]);

        // trace = 1 + 4 = 5
        assert_eq!(result.to_vec()[0], 5.0);
        // Standard algebra does not need argmax tracking, so the cache stays empty
        assert!(argmax_cache.is_empty());
    }

    #[test]
    fn test_einsum_unary_pairwise_path() {
        // Test unary operation through pairwise path (no optimization)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        // Not optimized - uses pairwise path
        assert!(!ein.is_optimized());

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);

        // trace = 1 + 4 = 5
        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_einsum_unary_with_argmax_pairwise() {
        // Test execute_with_argmax for unary operations (pairwise path)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        // Not optimized - uses pairwise path
        let (result, argmax_cache) = ein.execute_with_argmax::<Standard<f32>, f32, Cpu>(&[&a]);

        // trace = 1 + 4 = 5
        assert_eq!(result.to_vec()[0], 5.0);
        // Standard algebra does not need argmax tracking, so the cache stays empty
        assert!(argmax_cache.is_empty());
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_with_argmax_tropical() {
        // Test execute_with_argmax for tropical algebra (needs argmax)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        ein.optimize_greedy();
        let (result, argmax_cache) = ein.execute_with_argmax::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);

        // MaxPlus matmul: C[i,k] = max_j(A[i,j] + B[j,k])
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);

        // Should have argmax tensors for binary contractions
        assert!(!argmax_cache.is_empty());
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_with_argmax_tropical_pairwise() {
        // Test execute_with_argmax for tropical algebra (pairwise path)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        // Not optimized - uses pairwise path
        let (result, argmax_cache) = ein.execute_with_argmax::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);

        // Should have argmax tensors
        assert!(!argmax_cache.is_empty());
    }

    #[test]
    fn test_einsum_transpose_optimized() {
        // Test transpose operation through optimized path
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let mut ein = Einsum::new(vec![vec![0, 1]], vec![1, 0], sizes);

        ein.optimize_greedy();
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);

        assert_eq!(result.shape(), &[3, 2]);
        // A (col-major) = [[1,3,5],[2,4,6]]
        // A^T = [[1,2],[3,4],[5,6]]
        // In col-major: [1, 3, 5, 2, 4, 6]
        assert_eq!(result.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_intermediate_output_computation() {
        // Test the compute_intermediate_output function
        // ij,jk->ik: j is contracted
        let output = compute_intermediate_output(&[0, 1], &[1, 2], &[0, 2]);
        assert!(output.contains(&0));
        assert!(output.contains(&2));
        assert!(!output.contains(&1));
    }

    #[test]
    fn test_outer_product_pairwise() {
        // Test outer product via pairwise path (no optimization)
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
        let b = Tensor::<f32, Cpu>::from_data(&[3.0, 4.0, 5.0], &[3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let ein = Einsum::new(vec![vec![0], vec![1]], vec![0, 1], sizes);

        // NOT optimized - uses pairwise path
        assert!(!ein.is_optimized());
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 3]);
        // Outer product: a ⊗ b = [[1*3, 1*4, 1*5], [2*3, 2*4, 2*5]]
        //                      = [[3, 4, 5], [6, 8, 10]]
        // In column-major: [3, 6, 4, 8, 5, 10]
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0]);
    }

    #[test]
    fn test_outer_product_optimized() {
        // Test outer product with optimization
        // The optimizer returns Leaf for outer products (no shared indices),
        // but we detect this and fall back to pairwise execution
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
        let b = Tensor::<f32, Cpu>::from_data(&[3.0, 4.0, 5.0], &[3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let mut ein = Einsum::new(vec![vec![0], vec![1]], vec![0, 1], sizes);

        ein.optimize_greedy();
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0]);
    }

    #[test]
    fn test_outer_product_with_argmax() {
        // Test outer product through execute_with_argmax path
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
        let b = Tensor::<f32, Cpu>::from_data(&[3.0, 4.0, 5.0], &[3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let mut ein = Einsum::new(vec![vec![0], vec![1]], vec![0, 1], sizes);

        ein.optimize_greedy();
        let (result, _argmax_cache) = ein.execute_with_argmax::<Standard<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0]);
    }
}