use std::collections::{HashMap, HashSet};

use omeco::{optimize_code, EinCode, GreedyMethod, Label, NestedEinsum, TreeSA};

use crate::algebra::{Algebra, Scalar};
use crate::backend::{Backend, BackendScalar};
use crate::tensor::Tensor;

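/// An einsum problem: one index list per input tensor (`ixs`), the output
/// index list (`iy`), and the extent of every index label (`size_dict`).
///
/// A minimal usage sketch, mirroring the `test_einsum_matmul` case below
/// (labels 0, 1, 2 play the roles of i, j, k in "ij,jk->ik"):
///
/// ```ignore
/// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
/// let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
/// let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
/// let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);
/// ein.optimize_greedy();
/// let c = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);
/// assert_eq!(c.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
/// ```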
pub struct Einsum<L: Label = usize> {
    /// Index labels of each input tensor.
    pub ixs: Vec<Vec<L>>,
    /// Index labels of the output tensor.
    pub iy: Vec<L>,
    /// Extent (dimension size) of each index label.
    pub size_dict: HashMap<L, usize>,
    /// Pairwise contraction tree produced by an optimizer, if any.
    optimized: Option<NestedEinsum<L>>,
}

impl<L: Label> Einsum<L> {
    /// Creates an unoptimized einsum specification.
    pub fn new(ixs: Vec<Vec<L>>, iy: Vec<L>, size_dict: HashMap<L, usize>) -> Self {
        Self {
            ixs,
            iy,
            size_dict,
            optimized: None,
        }
    }

    /// Returns the plain einsum code (inputs and output) consumed by the
    /// omeco optimizers.
    pub fn code(&self) -> EinCode<L> {
        EinCode::new(self.ixs.clone(), self.iy.clone())
    }

    /// Attaches a contraction tree found by omeco's greedy method.
    pub fn optimize_greedy(&mut self) -> &mut Self {
        let code = self.code();
        let optimizer = GreedyMethod::new(0.0, 0.0);
        self.optimized = optimize_code(&code, &self.size_dict, &optimizer);
        self
    }

    /// Attaches a contraction tree found by omeco's tree simulated annealing.
    pub fn optimize_treesa(&mut self) -> &mut Self {
        let code = self.code();
        let optimizer = TreeSA::default();
        self.optimized = optimize_code(&code, &self.size_dict, &optimizer);
        self
    }

    /// Returns `true` once a contraction tree has been attached.
    pub fn is_optimized(&self) -> bool {
        self.optimized.is_some()
    }

    /// Returns the attached contraction tree, if any.
    pub fn contraction_tree(&self) -> Option<&NestedEinsum<L>> {
        self.optimized.as_ref()
    }
}

impl Einsum<usize> {
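    /// Contracts `tensors` according to this einsum. With an attached
    /// contraction tree the tree is evaluated recursively; a tree that is a
    /// single leaf falls back to the naive unary kernel so diagonals and
    /// traces on one input are still applied. Without a tree, inputs are
    /// folded pairwise from left to right.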
    pub fn execute<A, T, B>(&self, tensors: &[&Tensor<T, B>]) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        assert_eq!(
            tensors.len(),
            self.ixs.len(),
            "Number of tensors {} doesn't match number of index specs {}",
            tensors.len(),
            self.ixs.len()
        );

        match &self.optimized {
            Some(tree) => {
                if let NestedEinsum::Leaf { tensor_index } = tree {
                    execute_unary_naive::<A, T, B>(
                        tensors[*tensor_index],
                        &self.ixs[*tensor_index],
                        &self.iy,
                        &self.size_dict,
                    )
                } else {
                    self.execute_tree::<A, T, B>(tree, tensors)
                }
            }
            None => self.execute_pairwise::<A, T, B>(tensors),
        }
    }

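    /// Like [`execute`](Self::execute), but when the algebra reports
    /// `A::needs_argmax()` (e.g. a tropical max-plus algebra) every
    /// contraction also yields a `u32` tensor of winning input positions,
    /// collected in `argmax_cache` in the order the contractions ran.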
    pub fn execute_with_argmax<A, T, B>(
        &self,
        tensors: &[&Tensor<T, B>],
    ) -> (Tensor<T, B>, Vec<Tensor<u32, B>>)
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        assert_eq!(
            tensors.len(),
            self.ixs.len(),
            "Number of tensors {} doesn't match number of index specs {}",
            tensors.len(),
            self.ixs.len()
        );

        let mut argmax_cache = Vec::new();

        let result = match &self.optimized {
            Some(tree) => {
                if let NestedEinsum::Leaf { tensor_index } = tree {
                    if A::needs_argmax() {
                        let (result, argmax) = execute_unary_with_argmax::<A, T, B>(
                            tensors[*tensor_index],
                            &self.ixs[*tensor_index],
                            &self.iy,
                            &self.size_dict,
                        );
                        argmax_cache.push(argmax);
                        result
                    } else {
                        execute_unary_naive::<A, T, B>(
                            tensors[*tensor_index],
                            &self.ixs[*tensor_index],
                            &self.iy,
                            &self.size_dict,
                        )
                    }
                } else {
                    self.execute_tree_with_argmax::<A, T, B>(tree, tensors, &mut argmax_cache)
                }
            }
            None => self.execute_pairwise_with_argmax::<A, T, B>(tensors, &mut argmax_cache),
        };

        (result, argmax_cache)
    }

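    /// Recursively evaluates an optimized contraction tree, pushing one
    /// argmax tensor per binary contraction whenever the algebra asks for it.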
    #[allow(clippy::only_used_in_recursion)]
    fn execute_tree_with_argmax<A, T, B>(
        &self,
        tree: &NestedEinsum<usize>,
        tensors: &[&Tensor<T, B>],
        argmax_cache: &mut Vec<Tensor<u32, B>>,
    ) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        match tree {
            NestedEinsum::Leaf { tensor_index } => tensors[*tensor_index].clone(),
            NestedEinsum::Node { args, eins } => {
                assert_eq!(args.len(), 2, "Expected binary contraction tree");

                let left =
                    self.execute_tree_with_argmax::<A, T, B>(&args[0], tensors, argmax_cache);
                let right =
                    self.execute_tree_with_argmax::<A, T, B>(&args[1], tensors, argmax_cache);

                let ia = &eins.ixs[0];
                let ib = &eins.ixs[1];
                let iy = &eins.iy;

                if A::needs_argmax() {
                    let (result, argmax) =
                        left.contract_binary_with_argmax::<A>(&right, ia, ib, iy);
                    argmax_cache.push(argmax);
                    result
                } else {
                    left.contract_binary::<A>(&right, ia, ib, iy)
                }
            }
        }
    }

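    /// Unoptimized fallback for [`execute_with_argmax`](Self::execute_with_argmax):
    /// folds the inputs pairwise from left to right, collecting argmax
    /// tensors along the way.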
    fn execute_pairwise_with_argmax<A, T, B>(
        &self,
        tensors: &[&Tensor<T, B>],
        argmax_cache: &mut Vec<Tensor<u32, B>>,
    ) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        if tensors.is_empty() {
            panic!("Cannot execute einsum with no tensors");
        }

        if tensors.len() == 1 {
            if A::needs_argmax() {
                let (result, argmax) = execute_unary_with_argmax::<A, T, B>(
                    tensors[0],
                    &self.ixs[0],
                    &self.iy,
                    &self.size_dict,
                );
                argmax_cache.push(argmax);
                return result;
            } else {
                return execute_unary_naive::<A, T, B>(
                    tensors[0],
                    &self.ixs[0],
                    &self.iy,
                    &self.size_dict,
                );
            }
        }

        let mut result = tensors[0].clone();
        let mut current_indices = self.ixs[0].clone();

        for i in 1..tensors.len() {
            let other = tensors[i];
            let other_indices = &self.ixs[i];

            let intermediate_output = if i == tensors.len() - 1 {
                self.iy.clone()
            } else {
                compute_intermediate_output(&current_indices, other_indices, &self.iy)
            };

            if A::needs_argmax() {
                let (new_result, argmax) = result.contract_binary_with_argmax::<A>(
                    other,
                    &current_indices,
                    other_indices,
                    &intermediate_output,
                );
                argmax_cache.push(argmax);
                result = new_result;
            } else {
                result = result.contract_binary::<A>(
                    other,
                    &current_indices,
                    other_indices,
                    &intermediate_output,
                );
            }
            current_indices = intermediate_output;
        }

        result
    }

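    /// Recursively evaluates an optimized contraction tree of binary nodes.
    /// Leaves are returned as-is; the single-input (diagonal/trace) case is
    /// handled at the top level in [`execute`](Self::execute).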
    #[allow(clippy::only_used_in_recursion)]
    fn execute_tree<A, T, B>(
        &self,
        tree: &NestedEinsum<usize>,
        tensors: &[&Tensor<T, B>],
    ) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        match tree {
            NestedEinsum::Leaf { tensor_index } => tensors[*tensor_index].clone(),
            NestedEinsum::Node { args, eins } => {
                assert_eq!(args.len(), 2, "Expected binary contraction tree");

                let left = self.execute_tree::<A, T, B>(&args[0], tensors);
                let right = self.execute_tree::<A, T, B>(&args[1], tensors);

                let ia = &eins.ixs[0];
                let ib = &eins.ixs[1];
                let iy = &eins.iy;

                left.contract_binary::<A>(&right, ia, ib, iy)
            }
        }
    }

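    /// Unoptimized fallback: contracts the running result with each
    /// remaining tensor in turn, keeping an index alive only while it
    /// appears in the final output or in just one of the two operands.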
    fn execute_pairwise<A, T, B>(&self, tensors: &[&Tensor<T, B>]) -> Tensor<T, B>
    where
        A: Algebra<Scalar = T, Index = u32>,
        T: Scalar + BackendScalar<B>,
        B: Backend + Default,
    {
        if tensors.is_empty() {
            panic!("Cannot execute einsum with no tensors");
        }

        if tensors.len() == 1 {
            return execute_unary_naive::<A, T, B>(
                tensors[0],
                &self.ixs[0],
                &self.iy,
                &self.size_dict,
            );
        }

        let mut result = tensors[0].clone();
        let mut current_indices = self.ixs[0].clone();

        for i in 1..tensors.len() {
            let other = tensors[i];
            let other_indices = &self.ixs[i];

            let intermediate_output = if i == tensors.len() - 1 {
                self.iy.clone()
            } else {
                compute_intermediate_output(&current_indices, other_indices, &self.iy)
            };

            result = result.contract_binary::<A>(
                other,
                &current_indices,
                other_indices,
                &intermediate_output,
            );
            current_indices = intermediate_output;
        }

        result
    }
}

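/// Index list of the intermediate tensor formed when `ia` and `ib` are
/// contracted on the way to `final_output`: an index survives if it appears
/// in the final output or in only one of the two operands; an index shared
/// by both operands and absent from the output is contracted away. For
/// example, `compute_intermediate_output(&[0, 1], &[1, 2], &[0, 2])`
/// returns `[0, 2]`.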
fn compute_intermediate_output(ia: &[usize], ib: &[usize], final_output: &[usize]) -> Vec<usize> {
    let final_set: HashSet<_> = final_output.iter().copied().collect();
    let ia_set: HashSet<_> = ia.iter().copied().collect();
    let ib_set: HashSet<_> = ib.iter().copied().collect();

    let mut output = Vec::new();

    for &i in ia {
        if (final_set.contains(&i) || !ib_set.contains(&i)) && !output.contains(&i) {
            output.push(i);
        }
    }

    for &i in ib {
        if (final_set.contains(&i) || !ia_set.contains(&i)) && !output.contains(&i) {
            output.push(i);
        }
    }

    output
}

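/// Unravels a linear position into per-dimension coordinates with the first
/// index varying fastest (column-major), the layout assumed throughout this
/// module; e.g. with shape `[2, 3]`, position 3 maps to `[1, 1]`.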
fn linear_to_multi(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![];
    }
    let mut multi = vec![0; shape.len()];
    for i in 0..shape.len() {
        multi[i] = linear % shape[i];
        linear /= shape[i];
    }
    multi
}

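/// Inverse of [`linear_to_multi`]: folds the coordinate chosen for each
/// label in `ix` back into a column-major linear position, the stride
/// growing by `shape[dim]` at every dimension. A label repeated in `ix`
/// (a diagonal) reuses the same coordinate from `idx_values`.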
fn compute_input_position(
    ix: &[usize],
    idx_values: &HashMap<usize, usize>,
    shape: &[usize],
) -> usize {
    let mut pos = 0;
    let mut stride = 1;
    for (dim, &idx) in ix.iter().enumerate() {
        pos += idx_values[&idx] * stride;
        stride *= shape[dim];
    }
    pos
}

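/// Naive single-tensor einsum covering transposes, diagonals, traces, and
/// axis sums: for each output position it pins the outer (output) indices,
/// then reduces over every assignment of the inner (summed-out) indices
/// with the algebra's `add`.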
#[allow(clippy::needless_range_loop)]
pub(crate) fn execute_unary_naive<A, T, B>(
    tensor: &Tensor<T, B>,
    ix: &[usize],
    iy: &[usize],
    size_dict: &HashMap<usize, usize>,
) -> Tensor<T, B>
where
    A: Algebra<Scalar = T>,
    T: Scalar,
    B: Backend + Default,
{
    // Split the labels into outer (kept) and inner (reduced) indices.
    let outer: &[usize] = iy;
    let outer_set: HashSet<usize> = outer.iter().copied().collect();
    let mut inner_vec: Vec<usize> = Vec::new();
    let mut seen: HashSet<usize> = HashSet::new();
    for i in ix.iter().copied().filter(|i| !outer_set.contains(i)) {
        if seen.insert(i) {
            inner_vec.push(i);
        }
    }

    let out_shape: Vec<usize> = outer.iter().map(|&idx| size_dict[&idx]).collect();
    let out_size = out_shape.iter().product::<usize>().max(1);

    let inner_ranges: Vec<usize> = inner_vec.iter().map(|&idx| size_dict[&idx]).collect();
    let inner_size = inner_ranges.iter().product::<usize>().max(1);

    let mut out_data = vec![A::zero().to_scalar(); out_size];

    for out_linear in 0..out_size {
        let out_multi = linear_to_multi(out_linear, &out_shape);

        let mut idx_values: HashMap<usize, usize> = HashMap::new();
        let mut skip_position = false;

        // A label repeated in the output must take the same value at every
        // occurrence; off-diagonal positions keep the algebra's zero.
        for (&idx, &val) in outer.iter().zip(out_multi.iter()) {
            if let Some(&existing) = idx_values.get(&idx) {
                if existing != val {
                    skip_position = true;
                    break;
                }
            } else {
                idx_values.insert(idx, val);
            }
        }

        if skip_position {
            continue;
        }

        // Reduce over every assignment of the inner indices.
        let mut acc = A::zero();
        for inner_linear in 0..inner_size {
            let inner_multi = linear_to_multi(inner_linear, &inner_ranges);
            for (&idx, &val) in inner_vec.iter().zip(inner_multi.iter()) {
                idx_values.insert(idx, val);
            }

            let in_pos = compute_input_position(ix, &idx_values, tensor.shape());
            acc = acc.add(A::from_scalar(tensor.get(in_pos)));
        }

        out_data[out_linear] = acc.to_scalar();
    }

    if out_shape.is_empty() {
        Tensor::from_data(&out_data, &[])
    } else {
        Tensor::from_data(&out_data, &out_shape)
    }
}

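/// Variant of [`execute_unary_naive`] for algebras with `Index = u32`:
/// alongside each reduced value it records, as a `u32`, the linear input
/// position that won the reduction according to `A::is_better`, returned as
/// a second tensor with the same shape as the result.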
pub(crate) fn execute_unary_with_argmax<A, T, B>(
    tensor: &Tensor<T, B>,
    ix: &[usize],
    iy: &[usize],
    size_dict: &HashMap<usize, usize>,
) -> (Tensor<T, B>, Tensor<u32, B>)
where
    A: Algebra<Scalar = T, Index = u32>,
    T: Scalar,
    B: Backend + Default,
{
    let outer: &[usize] = iy;
    let outer_set: HashSet<usize> = outer.iter().copied().collect();
    let mut inner_vec: Vec<usize> = Vec::new();
    let mut seen: HashSet<usize> = HashSet::new();
    for i in ix.iter().copied().filter(|i| !outer_set.contains(i)) {
        if seen.insert(i) {
            inner_vec.push(i);
        }
    }

    let out_shape: Vec<usize> = outer.iter().map(|&idx| size_dict[&idx]).collect();
    let out_size = out_shape.iter().product::<usize>().max(1);

    let inner_ranges: Vec<usize> = inner_vec.iter().map(|&idx| size_dict[&idx]).collect();
    let inner_size = inner_ranges.iter().product::<usize>().max(1);

    let mut out_data = vec![A::zero().to_scalar(); out_size];
    let mut argmax_data = vec![0u32; out_size];

    for out_linear in 0..out_size {
        let out_multi = linear_to_multi(out_linear, &out_shape);

        let mut idx_values: HashMap<usize, usize> = HashMap::new();
        let mut skip_position = false;

        for (&idx, &val) in outer.iter().zip(out_multi.iter()) {
            if let Some(&existing) = idx_values.get(&idx) {
                if existing != val {
                    skip_position = true;
                    break;
                }
            } else {
                idx_values.insert(idx, val);
            }
        }

        if skip_position {
            continue;
        }

        // Track the best value seen so far and where it came from; the first
        // candidate always wins, so `A::zero()` is never reported as a winner.
        let mut best_val = A::zero();
        let mut best_in_pos = 0usize;

        for inner_linear in 0..inner_size {
            let inner_multi = linear_to_multi(inner_linear, &inner_ranges);
            for (&idx, &val) in inner_vec.iter().zip(inner_multi.iter()) {
                idx_values.insert(idx, val);
            }

            let in_pos = compute_input_position(ix, &idx_values, tensor.shape());
            let val = A::from_scalar(tensor.get(in_pos));

            if inner_linear == 0 || A::is_better(&val, &best_val) {
                best_val = val;
                best_in_pos = in_pos;
            }
        }

        out_data[out_linear] = best_val.to_scalar();
        argmax_data[out_linear] = best_in_pos as u32;
    }

    let result = if out_shape.is_empty() {
        Tensor::from_data(&out_data, &[])
    } else {
        Tensor::from_data(&out_data, &out_shape)
    };

    let argmax = if out_shape.is_empty() {
        Tensor::from_data(&argmax_data, &[])
    } else {
        Tensor::from_data(&argmax_data, &out_shape)
    };

    (result, argmax)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Standard;
    use crate::backend::Cpu;

    #[cfg(feature = "tropical")]
    use crate::algebra::MaxPlus;

    #[test]
    fn test_einsum_matmul() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        // The unoptimized (pairwise) and optimized paths must agree.
        let c1 = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);
        assert_eq!(c1.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);

        ein.optimize_greedy();
        let c2 = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);
        assert_eq!(c2.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_tropical() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        let c = ein.execute::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);
        assert_eq!(c.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_einsum_chain() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let c = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2), (3, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2], vec![2, 3]], vec![0, 3], sizes);

        ein.optimize_greedy();
        let d = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b, &c]);

        assert_eq!(d.shape(), &[2, 2]);
    }

    #[test]
    fn test_einsum_trace() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        // Column-major [2, 2]: the diagonal holds 1 and 4, so the trace is 5.
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_einsum_diagonal() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![0], sizes);

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        assert_eq!(result.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_einsum_sum_axis() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1]], vec![0], sizes);

        // Summing out index 1 gives [1 + 3, 2 + 4] = [4, 6].
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        assert_eq!(result.to_vec(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_einsum_sum_all() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1]], vec![], sizes);

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);
        assert_eq!(result.to_vec()[0], 10.0);
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_trace_tropical() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        // Max-plus trace: max(1, 4) = 4.
        let result = ein.execute::<MaxPlus<f32>, f32, Cpu>(&[&a]);
        assert_eq!(result.to_vec()[0], 4.0);
    }

    #[test]
    fn test_linear_to_multi_empty_shape() {
        let result = linear_to_multi(0, &[]);
        assert_eq!(result, vec![]);
    }

    #[test]
    fn test_linear_to_multi_1d() {
        assert_eq!(linear_to_multi(0, &[5]), vec![0]);
        assert_eq!(linear_to_multi(3, &[5]), vec![3]);
        assert_eq!(linear_to_multi(4, &[5]), vec![4]);
    }

    #[test]
    fn test_linear_to_multi_2d() {
        // Column-major: the first coordinate varies fastest.
        assert_eq!(linear_to_multi(0, &[2, 3]), vec![0, 0]);
        assert_eq!(linear_to_multi(1, &[2, 3]), vec![1, 0]);
        assert_eq!(linear_to_multi(2, &[2, 3]), vec![0, 1]);
        assert_eq!(linear_to_multi(3, &[2, 3]), vec![1, 1]);
        assert_eq!(linear_to_multi(4, &[2, 3]), vec![0, 2]);
        assert_eq!(linear_to_multi(5, &[2, 3]), vec![1, 2]);
    }

    #[test]
    fn test_linear_to_multi_3d() {
        assert_eq!(linear_to_multi(0, &[2, 3, 4]), vec![0, 0, 0]);
        assert_eq!(linear_to_multi(1, &[2, 3, 4]), vec![1, 0, 0]);
        assert_eq!(linear_to_multi(2, &[2, 3, 4]), vec![0, 1, 0]);
        assert_eq!(linear_to_multi(6, &[2, 3, 4]), vec![0, 0, 1]);
        assert_eq!(linear_to_multi(7, &[2, 3, 4]), vec![1, 0, 1]);
        assert_eq!(linear_to_multi(23, &[2, 3, 4]), vec![1, 2, 3]);
    }

    #[test]
    fn test_compute_input_position_1d() {
        let ix = vec![0];
        let shape = vec![5];

        let mut idx_values = HashMap::new();
        idx_values.insert(0, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 0);

        idx_values.insert(0, 3);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 3);
    }

    #[test]
    fn test_compute_input_position_2d() {
        let ix = vec![0, 1];
        let shape = vec![2, 3];

        let mut idx_values = HashMap::new();

        idx_values.insert(0, 0);
        idx_values.insert(1, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 0);

        idx_values.insert(0, 1);
        idx_values.insert(1, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 1);

        idx_values.insert(0, 0);
        idx_values.insert(1, 1);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 2);

        idx_values.insert(0, 1);
        idx_values.insert(1, 2);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 5);
    }

    #[test]
    fn test_compute_input_position_3d() {
        let ix = vec![0, 1, 2];
        let shape = vec![2, 3, 4];

        let mut idx_values = HashMap::new();

        idx_values.insert(0, 0);
        idx_values.insert(1, 0);
        idx_values.insert(2, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 0);

        idx_values.insert(0, 1);
        idx_values.insert(1, 0);
        idx_values.insert(2, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 1);

        idx_values.insert(0, 0);
        idx_values.insert(1, 1);
        idx_values.insert(2, 0);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 2);

        idx_values.insert(0, 0);
        idx_values.insert(1, 0);
        idx_values.insert(2, 1);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 6);

        idx_values.insert(0, 1);
        idx_values.insert(1, 2);
        idx_values.insert(2, 3);
        assert_eq!(compute_input_position(&ix, &idx_values, &shape), 23);
    }

    #[test]
    fn test_linear_to_multi_roundtrip() {
        let shape = vec![2, 3, 4];
        let ix: Vec<usize> = (0..shape.len()).collect();
        let total_size: usize = shape.iter().product();

        for linear in 0..total_size {
            let multi = linear_to_multi(linear, &shape);

            let mut idx_values = HashMap::new();
            for (dim, &val) in multi.iter().enumerate() {
                idx_values.insert(dim, val);
            }

            let computed_pos = compute_input_position(&ix, &idx_values, &shape);
            assert_eq!(
                computed_pos, linear,
                "Roundtrip failed for linear={}, multi={:?}",
                linear, multi
            );
        }
    }

    #[test]
    fn test_unary_naive_transpose() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1];
        let iy = vec![1, 0];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_unary_naive_trace() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2)].into();
        let ix = vec![0, 0];
        let iy: Vec<usize> = vec![];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[]);
        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_unary_naive_diagonal() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2)].into();
        let ix = vec![0, 0];
        let iy = vec![0];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2]);
        assert_eq!(result.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_unary_naive_sum_axis() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1];
        let iy = vec![0];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2]);
        assert_eq!(result.to_vec(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_unary_naive_sum_all() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1];
        let iy: Vec<usize> = vec![];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[]);
        assert_eq!(result.to_vec()[0], 10.0);
    }

    #[test]
    fn test_unary_naive_partial_trace() {
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let a = Tensor::<f32, Cpu>::from_data(&data, &[2, 3, 2]);

        // "iji->j": partial trace over the repeated first/last index.
        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let ix = vec![0, 1, 0];
        let iy = vec![1];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.to_vec(), vec![9.0, 13.0, 17.0]);
    }

    #[test]
    fn test_unary_naive_3d_transpose() {
        let data: Vec<f32> = (1..=8).map(|x| x as f32).collect();
        let a = Tensor::<f32, Cpu>::from_data(&data, &[2, 2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let ix = vec![0, 1, 2];
        let iy = vec![2, 0, 1];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2, 2, 2]);

        // Column-major positions: a[i, j, k] sits at i + 2j + 4k, and the
        // permuted output [k, i, j] at k + 2i + 4j.
        let mut expected = vec![0.0f32; 8];
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    let a_pos = i + j * 2 + k * 4;
                    let b_pos = k + i * 2 + j * 4;
                    expected[b_pos] = data[a_pos];
                }
            }
        }
        assert_eq!(result.to_vec(), expected);
    }

    #[test]
    fn test_unary_naive_identity() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2), (1, 2)].into();
        let ix = vec![0, 1];
        let iy = vec![0, 1];
        let result = execute_unary_naive::<Standard<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), a.to_vec());
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_unary_naive_trace_tropical() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let size_dict: HashMap<usize, usize> = [(0, 2)].into();
        let ix = vec![0, 0];
        let iy: Vec<usize> = vec![];
        let result = execute_unary_naive::<MaxPlus<f32>, f32, Cpu>(&a, &ix, &iy, &size_dict);

        assert_eq!(result.shape(), &[]);
        assert_eq!(result.to_vec()[0], 4.0);
    }

    #[test]
    fn test_einsum_trace_optimized() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        ein.optimize_greedy();
        assert!(ein.is_optimized());

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);

        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_einsum_unary_with_argmax_optimized() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        ein.optimize_greedy();
        let (result, argmax_cache) = ein.execute_with_argmax::<Standard<f32>, f32, Cpu>(&[&a]);

        // The standard algebra never requests argmax, so the cache stays empty.
        assert_eq!(result.to_vec()[0], 5.0);
        assert!(argmax_cache.is_empty());
    }

    #[test]
    fn test_einsum_unary_pairwise_path() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        assert!(!ein.is_optimized());

        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);

        assert_eq!(result.to_vec()[0], 5.0);
    }

    #[test]
    fn test_einsum_unary_with_argmax_pairwise() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2)].into();
        let ein = Einsum::new(vec![vec![0, 0]], vec![], sizes);

        let (result, argmax_cache) = ein.execute_with_argmax::<Standard<f32>, f32, Cpu>(&[&a]);

        assert_eq!(result.to_vec()[0], 5.0);
        assert!(argmax_cache.is_empty());
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_with_argmax_tropical() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let mut ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        ein.optimize_greedy();
        let (result, argmax_cache) = ein.execute_with_argmax::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);

        // Max-plus requests argmax, so each contraction leaves a record.
        assert!(!argmax_cache.is_empty());
    }

    #[cfg(feature = "tropical")]
    #[test]
    fn test_einsum_with_argmax_tropical_pairwise() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
        let ein = Einsum::new(vec![vec![0, 1], vec![1, 2]], vec![0, 2], sizes);

        let (result, argmax_cache) = ein.execute_with_argmax::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);

        assert!(!argmax_cache.is_empty());
    }

    #[test]
    fn test_einsum_transpose_optimized() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let mut ein = Einsum::new(vec![vec![0, 1]], vec![1, 0], sizes);

        ein.optimize_greedy();
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a]);

        assert_eq!(result.shape(), &[3, 2]);
        assert_eq!(result.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_intermediate_output_computation() {
        // For "01,12->02" the shared index 1 is contracted away.
        let output = compute_intermediate_output(&[0, 1], &[1, 2], &[0, 2]);
        assert!(output.contains(&0));
        assert!(output.contains(&2));
        assert!(!output.contains(&1));
    }

    #[test]
    fn test_outer_product_pairwise() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
        let b = Tensor::<f32, Cpu>::from_data(&[3.0, 4.0, 5.0], &[3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let ein = Einsum::new(vec![vec![0], vec![1]], vec![0, 1], sizes);

        assert!(!ein.is_optimized());
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0]);
    }

    #[test]
    fn test_outer_product_optimized() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
        let b = Tensor::<f32, Cpu>::from_data(&[3.0, 4.0, 5.0], &[3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let mut ein = Einsum::new(vec![vec![0], vec![1]], vec![0, 1], sizes);

        ein.optimize_greedy();
        let result = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0]);
    }

    #[test]
    fn test_outer_product_with_argmax() {
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
        let b = Tensor::<f32, Cpu>::from_data(&[3.0, 4.0, 5.0], &[3]);

        let sizes: HashMap<usize, usize> = [(0, 2), (1, 3)].into();
        let mut ein = Einsum::new(vec![vec![0], vec![1]], vec![0, 1], sizes);

        ein.optimize_greedy();
        let (result, _argmax_cache) = ein.execute_with_argmax::<Standard<f32>, f32, Cpu>(&[&a, &b]);

        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0]);
    }
}