omeinsum/einsum/
builder.rs1use std::collections::HashMap;
4
5use omeco::Label;
6
7use super::Einsum;
8
9pub struct EinBuilder<L: Label = usize> {
26 ixs: Vec<Vec<L>>,
27 iy: Option<Vec<L>>,
28 size_dict: HashMap<L, usize>,
29}
30
31impl<L: Label> Default for EinBuilder<L> {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<L: Label> EinBuilder<L> {
38 pub fn new() -> Self {
40 Self {
41 ixs: Vec::new(),
42 iy: None,
43 size_dict: HashMap::new(),
44 }
45 }
46
47 pub fn input(mut self, indices: &[L]) -> Self {
49 self.ixs.push(indices.to_vec());
50 self
51 }
52
53 pub fn output(mut self, indices: &[L]) -> Self {
55 self.iy = Some(indices.to_vec());
56 self
57 }
58
59 pub fn size(mut self, index: L, size: usize) -> Self {
61 self.size_dict.insert(index, size);
62 self
63 }
64
65 pub fn sizes(mut self, sizes: impl IntoIterator<Item = (L, usize)>) -> Self {
67 self.size_dict.extend(sizes);
68 self
69 }
70
71 pub fn build(self) -> Einsum<L> {
77 let iy = self.iy.expect("Output indices not specified");
78
79 for ix in &self.ixs {
81 for i in ix {
82 assert!(
83 self.size_dict.contains_key(i),
84 "Size not specified for index {:?}",
85 i
86 );
87 }
88 }
89 for i in &iy {
90 assert!(
91 self.size_dict.contains_key(i),
92 "Size not specified for output index {:?}",
93 i
94 );
95 }
96
97 Einsum::new(self.ixs, iy, self.size_dict)
98 }
99}
100
/// Builds an `Einsum` from a compact inline specification.
///
/// Syntax: one or more bracketed input index lists, `->`, a bracketed
/// output index list, `;`, then zero or more `name = size` entries:
///
/// ```text
/// ein!([a, b], [b, c] -> [a, c]; a = 2, b = 3, c = 4)
/// ```
///
/// Index entries are arbitrary expressions. Each size entry's `name` must be
/// an identifier (typically a variable in scope at the call site holding the
/// label value), because it is passed as an expression to `EinBuilder::size`.
/// A trailing comma is accepted in every list.
///
/// The macro expands to `$crate::EinBuilder` calls followed by `build()`.
///
/// # Panics
///
/// Panics under the same conditions as `EinBuilder::build`: when any input
/// or output index is missing from the size entries.
#[macro_export]
macro_rules! ein {
    ($([$($ix:expr),* $(,)?]),+ -> [$($iy:expr),* $(,)?]; $($label:ident = $size:expr),* $(,)?) => {{
        let mut builder = $crate::EinBuilder::new();
        $(
            builder = builder.input(&[$($ix),*]);
        )+
        builder = builder.output(&[$($iy),*]);
        $(
            builder = builder.size($label, $size);
        )*
        builder.build()
    }};
}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder() {
        let ein: Einsum<usize> = EinBuilder::new()
            .input(&[0, 1])
            .input(&[1, 2])
            .output(&[0, 2])
            .size(0, 10)
            .size(1, 20)
            .size(2, 30)
            .build();

        assert_eq!(ein.ixs, vec![vec![0, 1], vec![1, 2]]);
        assert_eq!(ein.iy, vec![0, 2]);
        assert_eq!(ein.size_dict.get(&0), Some(&10));
        assert_eq!(ein.size_dict.get(&1), Some(&20));
        assert_eq!(ein.size_dict.get(&2), Some(&30));
    }

    #[test]
    fn test_builder_with_chars() {
        let ein: Einsum<char> = EinBuilder::new()
            .input(&['i', 'j'])
            .input(&['j', 'k'])
            .output(&['i', 'k'])
            .size('i', 10)
            .size('j', 20)
            .size('k', 30)
            .build();

        assert_eq!(ein.ixs, vec![vec!['i', 'j'], vec!['j', 'k']]);
        assert_eq!(ein.iy, vec!['i', 'k']);
    }

    /// `sizes` registers every pair, exactly like repeated `size` calls.
    #[test]
    fn test_builder_sizes_bulk() {
        let ein: Einsum<usize> = EinBuilder::new()
            .input(&[0, 1])
            .output(&[1])
            .sizes(vec![(0, 4), (1, 5)])
            .build();

        assert_eq!(ein.size_dict.get(&0), Some(&4));
        assert_eq!(ein.size_dict.get(&1), Some(&5));
    }

    /// `Default` yields a usable empty builder.
    #[test]
    fn test_builder_default() {
        let ein: Einsum<usize> = EinBuilder::default()
            .input(&[0])
            .output(&[0])
            .size(0, 3)
            .build();

        assert_eq!(ein.ixs, vec![vec![0]]);
        assert_eq!(ein.iy, vec![0]);
    }

    #[test]
    #[should_panic(expected = "Output indices not specified")]
    fn test_build_panics_without_output() {
        let _: Einsum<usize> = EinBuilder::new().input(&[0]).size(0, 2).build();
    }

    #[test]
    #[should_panic(expected = "Size not specified for index")]
    fn test_build_panics_on_missing_input_size() {
        let _: Einsum<usize> = EinBuilder::new()
            .input(&[0, 1])
            .output(&[0])
            .size(0, 2)
            .build();
    }

    #[test]
    #[should_panic(expected = "Size not specified for output index")]
    fn test_build_panics_on_missing_output_size() {
        let _: Einsum<usize> = EinBuilder::new()
            .input(&[0])
            .output(&[1])
            .size(0, 2)
            .build();
    }
}