// omeinsum/tensor/ops.rs

1//! Tensor operations for contraction.
2
3use super::Tensor;
4use crate::algebra::{Algebra, Scalar};
5use crate::backend::{Backend, BackendScalar};
6
/// Compute output shape from input shapes and modes.
///
/// Builds a mode → extent map from both inputs, then reads off the extent of
/// every output mode in `modes_c`, in order.
///
/// # Panics
///
/// Panics if a mode appearing in both `modes_a` and `modes_b` has mismatched
/// extents, or if a mode in `modes_c` does not appear in either input.
fn compute_output_shape(
    shape_a: &[usize],
    modes_a: &[i32],
    shape_b: &[usize],
    modes_b: &[i32],
    modes_c: &[i32],
) -> Vec<usize> {
    let mut extents = std::collections::HashMap::new();
    for (&m, &d) in modes_a.iter().zip(shape_a) {
        extents.insert(m, d);
    }
    for (&m, &d) in modes_b.iter().zip(shape_b) {
        // A mode shared by both inputs (batch or contracted) must agree in
        // extent; previously B's extent silently overwrote A's.
        if let Some(&prev) = extents.get(&m) {
            assert_eq!(
                prev, d,
                "extent mismatch for shared mode {}: {} vs {}",
                m, prev, d
            );
        }
        extents.insert(m, d);
    }
    modes_c
        .iter()
        .map(|m| {
            *extents
                .get(m)
                .unwrap_or_else(|| panic!("output mode {} does not appear in either input", m))
        })
        .collect()
}
24
25impl<T: Scalar, B: Backend> Tensor<T, B> {
26    /// Binary tensor contraction using reshape-to-GEMM strategy.
27    ///
28    /// # Arguments
29    ///
30    /// * `other` - The other tensor to contract with
31    /// * `ia` - Index labels for self
32    /// * `ib` - Index labels for other
33    /// * `iy` - Output index labels
34    ///
35    /// # Example
36    ///
37    /// ```rust
38    /// use omeinsum::{Tensor, Cpu};
39    /// use omeinsum::algebra::MaxPlus;
40    ///
41    /// // A[i,j,k] × B[j,k,l] → C[i,l]
42    /// let a = Tensor::<f32, Cpu>::from_data(&(0..24).map(|x| x as f32).collect::<Vec<_>>(), &[2, 3, 4]);
43    /// let b = Tensor::<f32, Cpu>::from_data(&(0..60).map(|x| x as f32).collect::<Vec<_>>(), &[3, 4, 5]);
44    /// let c = a.contract_binary::<MaxPlus<f32>>(&b, &[0, 1, 2], &[1, 2, 3], &[0, 3]);
45    /// assert_eq!(c.shape(), &[2, 5]);
46    /// ```
47    pub fn contract_binary<A: Algebra<Scalar = T, Index = u32>>(
48        &self,
49        other: &Self,
50        ia: &[usize],
51        ib: &[usize],
52        iy: &[usize],
53    ) -> Self
54    where
55        T: BackendScalar<B>,
56    {
57        let (result, _) = self.contract_binary_impl::<A>(other, ia, ib, iy, false);
58        result
59    }
60
61    /// Binary contraction with argmax tracking.
62    pub fn contract_binary_with_argmax<A: Algebra<Scalar = T, Index = u32>>(
63        &self,
64        other: &Self,
65        ia: &[usize],
66        ib: &[usize],
67        iy: &[usize],
68    ) -> (Self, Tensor<u32, B>)
69    where
70        T: BackendScalar<B>,
71    {
72        let (result, argmax) = self.contract_binary_impl::<A>(other, ia, ib, iy, true);
73        (result, argmax.expect("argmax requested but not returned"))
74    }
75
76    fn contract_binary_impl<A: Algebra<Scalar = T, Index = u32>>(
77        &self,
78        other: &Self,
79        ia: &[usize],
80        ib: &[usize],
81        iy: &[usize],
82        track_argmax: bool,
83    ) -> (Self, Option<Tensor<u32, B>>)
84    where
85        T: BackendScalar<B>,
86    {
87        assert_eq!(ia.len(), self.ndim(), "ia length must match self.ndim()");
88        assert_eq!(ib.len(), other.ndim(), "ib length must match other.ndim()");
89
90        // Convert usize indices to i32 modes
91        let modes_a: Vec<i32> = ia.iter().map(|&i| i as i32).collect();
92        let modes_b: Vec<i32> = ib.iter().map(|&i| i as i32).collect();
93        let modes_c: Vec<i32> = iy.iter().map(|&i| i as i32).collect();
94
95        // Compute output shape
96        let shape_c = compute_output_shape(
97            self.shape(), &modes_a,
98            other.shape(), &modes_b,
99            &modes_c,
100        );
101
102        if track_argmax {
103            let (c_storage, argmax_storage) = self.backend.contract_with_argmax::<A>(
104                self.storage.as_ref(),
105                self.shape(),
106                self.strides(),
107                &modes_a,
108                other.storage.as_ref(),
109                other.shape(),
110                other.strides(),
111                &modes_b,
112                &shape_c,
113                &modes_c,
114            );
115
116            let c = Self::from_storage(c_storage, &shape_c, self.backend.clone());
117            let argmax = Tensor::<u32, B>::from_storage(
118                argmax_storage,
119                &shape_c,
120                self.backend.clone(),
121            );
122            (c, Some(argmax))
123        } else {
124            let c_storage = self.backend.contract::<A>(
125                self.storage.as_ref(),
126                self.shape(),
127                self.strides(),
128                &modes_a,
129                other.storage.as_ref(),
130                other.shape(),
131                other.strides(),
132                &modes_b,
133                &shape_c,
134                &modes_c,
135            );
136
137            let c = Self::from_storage(c_storage, &shape_c, self.backend.clone());
138            (c, None)
139        }
140    }
141}
142
143#[cfg(test)]
144mod tests {
145    use super::*;
146    use crate::algebra::Standard;
147    use crate::backend::Cpu;
148
149    #[cfg(feature = "tropical")]
150    use crate::algebra::MaxPlus;
151
152    #[test]
153    fn test_contract_binary_matmul_standard() {
154        // A[i,j] × B[j,k] → C[i,k] (matrix multiplication)
155        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
156        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
157
158        let c = a.contract_binary::<Standard<f32>>(&b, &[0, 1], &[1, 2], &[0, 2]);
159
160        assert_eq!(c.shape(), &[2, 2]);
161        assert_eq!(c.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
162    }
163
    #[cfg(feature = "tropical")]
    #[test]
    fn test_contract_binary_matmul_maxplus() {
        // A[i,j] × B[j,k] → C[i,k] (tropical matrix multiplication)
        // In the max-plus semiring: C[i,k] = max_j (A[i,j] + B[j,k]).
        // Column-major [2,2]: data [1,2,3,4] is the matrix [[1,3],[2,4]].
        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let c = a.contract_binary::<MaxPlus<f32>>(&b, &[0, 1], &[1, 2], &[0, 2]);

        assert_eq!(c.shape(), &[2, 2]);
        // e.g. C[0,0] = max(1+1, 3+2) = 5; full result [[5,7],[6,8]],
        // stored column-major as [5, 6, 7, 8].
        assert_eq!(c.to_vec(), vec![5.0, 6.0, 7.0, 8.0]);
    }
176
177    #[test]
178    fn test_contract_binary() {
179        // A[i,j] × B[j,k] → C[i,k]
180        let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
181        let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
182
183        let c = a.contract_binary::<Standard<f32>>(&b, &[0, 1], &[1, 2], &[0, 2]);
184
185        assert_eq!(c.shape(), &[2, 2]);
186        assert_eq!(c.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
187    }
188
189    #[test]
190    fn test_contract_binary_batched() {
191        // A[b,i,j] × B[b,j,k] → C[b,i,k]
192        // 2 batches, 2x2 matrices
193        // Column-major layout: A[b,i,j] at position b + 2*i + 4*j
194        let a =
195            Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 2, 2]);
196        let b =
197            Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0], &[2, 2, 2]);
198
199        let c = a.contract_binary::<Standard<f32>>(&b, &[0, 1, 2], &[0, 2, 3], &[0, 1, 3]);
200
201        assert_eq!(c.shape(), &[2, 2, 2]);
202        // In column-major [2,2,2]:
203        // Batch 0 of A: [[1,5],[3,7]], Batch 1 of A: [[2,6],[4,8]]
204        // Batch 0 of B: [[1,1],[3,0]], Batch 1 of B: [[2,0],[4,1]]
205        // Batch 0 result: [[16,1],[24,3]], Batch 1 result: [[28,6],[40,8]]
206        // Column-major output: [16, 28, 24, 40, 1, 6, 3, 8]
207        assert_eq!(c.to_vec(), vec![16.0, 28.0, 24.0, 40.0, 1.0, 6.0, 3.0, 8.0]);
208    }
209}