omeinsum/einsum/
builder.rs

1//! Builder pattern for einsum construction.
2
3use std::collections::HashMap;
4
5use omeco::Label;
6
7use super::Einsum;
8
9/// Builder for constructing einsum specifications.
10///
11/// # Example
12///
13/// ```rust
14/// use omeinsum::EinBuilder;
15///
16/// let ein = EinBuilder::new()
17///     .input(&[0, 1])      // A[i,j]
18///     .input(&[1, 2])      // B[j,k]
19///     .output(&[0, 2])     // C[i,k]
20///     .size(0, 10)         // i has size 10
21///     .size(1, 20)         // j has size 20
22///     .size(2, 30)         // k has size 30
23///     .build();
24/// ```
25pub struct EinBuilder<L: Label = usize> {
26    ixs: Vec<Vec<L>>,
27    iy: Option<Vec<L>>,
28    size_dict: HashMap<L, usize>,
29}
30
31impl<L: Label> Default for EinBuilder<L> {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<L: Label> EinBuilder<L> {
38    /// Create a new builder.
39    pub fn new() -> Self {
40        Self {
41            ixs: Vec::new(),
42            iy: None,
43            size_dict: HashMap::new(),
44        }
45    }
46
47    /// Add an input tensor specification.
48    pub fn input(mut self, indices: &[L]) -> Self {
49        self.ixs.push(indices.to_vec());
50        self
51    }
52
53    /// Set the output specification.
54    pub fn output(mut self, indices: &[L]) -> Self {
55        self.iy = Some(indices.to_vec());
56        self
57    }
58
59    /// Set the size for an index.
60    pub fn size(mut self, index: L, size: usize) -> Self {
61        self.size_dict.insert(index, size);
62        self
63    }
64
65    /// Set multiple sizes at once.
66    pub fn sizes(mut self, sizes: impl IntoIterator<Item = (L, usize)>) -> Self {
67        self.size_dict.extend(sizes);
68        self
69    }
70
71    /// Build the einsum specification.
72    ///
73    /// # Panics
74    ///
75    /// Panics if no output is specified or if sizes are missing.
76    pub fn build(self) -> Einsum<L> {
77        let iy = self.iy.expect("Output indices not specified");
78
79        // Validate all indices have sizes
80        for ix in &self.ixs {
81            for i in ix {
82                assert!(
83                    self.size_dict.contains_key(i),
84                    "Size not specified for index {:?}",
85                    i
86                );
87            }
88        }
89        for i in &iy {
90            assert!(
91                self.size_dict.contains_key(i),
92                "Size not specified for output index {:?}",
93                i
94            );
95        }
96
97        Einsum::new(self.ixs, iy, self.size_dict)
98    }
99}
100
/// Convenience macro for creating einsum specifications.
///
/// Syntax: `ein!([in0...], [in1...], ... -> [out...]; label = size, ...)`.
///
/// Index positions accept arbitrary expressions. Each size label must be a
/// single token: either an identifier bound in the caller's scope (such as
/// `i`) or a literal (such as `0` or `'i'`). A trailing comma after the last
/// size is accepted.
///
/// The macro expands to an [`EinBuilder`] chain ending in `build()`, so it
/// panics under the same conditions as [`EinBuilder::build`].
///
/// # Example
///
/// ```rust,no_run
/// use omeinsum::ein;
///
/// // A[i,j] × B[j,k] → C[i,k]
/// let (i, j, k) = (0, 1, 2);
/// let ein = ein!([i, j], [j, k] -> [i, k]; i=10, j=20, k=30);
///
/// // The same contraction written with literal labels.
/// let ein2 = ein!([0, 1], [1, 2] -> [0, 2]; 0 = 10, 1 = 20, 2 = 30,);
/// ```
#[macro_export]
macro_rules! ein {
    // Parse: [ix1], [ix2], ... -> [iy]; label = size, ... [,]
    // `tt` (rather than `ident`) lets labels be identifiers or literals;
    // `expr` cannot be used here because it may not be followed by `=`.
    ($([$($ix:expr),*]),+ -> [$($iy:expr),*]; $($label:tt = $size:expr),* $(,)?) => {{
        $crate::EinBuilder::new()
            $(.input(&[$($ix),*]))+
            .output(&[$($iy),*])
            $(.size($label, $size))*
            .build()
    }};
}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder() {
        let ein: Einsum<usize> = EinBuilder::new()
            .input(&[0, 1])
            .input(&[1, 2])
            .output(&[0, 2])
            .size(0, 10)
            .size(1, 20)
            .size(2, 30)
            .build();

        assert_eq!(ein.ixs, vec![vec![0, 1], vec![1, 2]]);
        assert_eq!(ein.iy, vec![0, 2]);
        assert_eq!(ein.size_dict.get(&0), Some(&10));
        assert_eq!(ein.size_dict.get(&1), Some(&20));
        assert_eq!(ein.size_dict.get(&2), Some(&30));
    }

    #[test]
    fn test_builder_with_chars() {
        let ein: Einsum<char> = EinBuilder::new()
            .input(&['i', 'j'])
            .input(&['j', 'k'])
            .output(&['i', 'k'])
            .size('i', 10)
            .size('j', 20)
            .size('k', 30)
            .build();

        assert_eq!(ein.ixs, vec![vec!['i', 'j'], vec!['j', 'k']]);
        assert_eq!(ein.iy, vec!['i', 'k']);
    }

    /// `sizes` records every pair it is given.
    #[test]
    fn test_builder_sizes_bulk() {
        let ein: Einsum<usize> = EinBuilder::new()
            .input(&[0, 1])
            .output(&[1])
            .sizes(vec![(0, 4), (1, 5)])
            .build();

        assert_eq!(ein.size_dict.get(&0), Some(&4));
        assert_eq!(ein.size_dict.get(&1), Some(&5));
    }

    /// `build` panics when `output` was never called.
    #[test]
    #[should_panic(expected = "Output indices not specified")]
    fn test_build_without_output_panics() {
        let _ = EinBuilder::<usize>::new().input(&[0]).size(0, 2).build();
    }

    /// `build` panics when an input label has no size.
    #[test]
    #[should_panic(expected = "Size not specified for index")]
    fn test_build_missing_input_size_panics() {
        let _ = EinBuilder::<usize>::new()
            .input(&[0, 1])
            .output(&[0])
            .size(0, 2)
            .build();
    }

    /// `build` panics when an output label has no size.
    #[test]
    #[should_panic(expected = "Size not specified for output index")]
    fn test_build_missing_output_size_panics() {
        let _ = EinBuilder::<usize>::new()
            .input(&[0])
            .output(&[0, 1])
            .size(0, 2)
            .build();
    }

    /// The `ein!` macro produces the same specification as the builder.
    #[test]
    fn test_ein_macro() {
        let (i, j, k) = (0usize, 1, 2);
        let ein: Einsum<usize> = ein!([i, j], [j, k] -> [i, k]; i=10, j=20, k=30);

        assert_eq!(ein.ixs, vec![vec![0, 1], vec![1, 2]]);
        assert_eq!(ein.iy, vec![0, 2]);
        assert_eq!(ein.size_dict.get(&0), Some(&10));
        assert_eq!(ein.size_dict.get(&1), Some(&20));
        assert_eq!(ein.size_dict.get(&2), Some(&30));
    }
}