omeinsum/backend/
traits.rs

1//! Backend trait definitions.
2
3use crate::algebra::{Algebra, Scalar};
4
5/// Storage trait for tensor data.
6///
7/// Abstracts over different storage backends (CPU memory, GPU memory).
8pub trait Storage<T: Scalar>: Clone + Send + Sync + Sized {
9    /// Number of elements in storage.
10    fn len(&self) -> usize;
11
12    /// Check if storage is empty.
13    fn is_empty(&self) -> bool {
14        self.len() == 0
15    }
16
17    /// Get element at index (may be slow for GPU).
18    fn get(&self, index: usize) -> T;
19
20    /// Set element at index (may be slow for GPU).
21    fn set(&mut self, index: usize, value: T);
22
23    /// Copy all data to a Vec (downloads from GPU if needed).
24    fn to_vec(&self) -> Vec<T>;
25
26    /// Create storage from slice.
27    fn from_slice(data: &[T]) -> Self;
28
29    /// Create zero-initialized storage.
30    fn zeros(len: usize) -> Self;
31}
32
/// Marker trait for scalar types supported by a specific backend.
///
/// This enables compile-time checking that a scalar type is supported
/// by a particular backend (e.g., CUDA only supports f32/f64/complex).
/// It appears as a `where` bound on [`Backend::contract`] and
/// [`Backend::contract_with_argmax`], so using an unsupported scalar
/// with a backend is rejected at compile time rather than at runtime.
pub trait BackendScalar<B: Backend>: Scalar {}
38
/// Backend trait for tensor execution.
///
/// Defines how tensor operations are executed on different hardware.
pub trait Backend: Clone + Send + Sync + 'static {
    /// Storage type for this backend.
    ///
    /// A generic associated type: each backend chooses its own container
    /// per scalar type (e.g. host memory vs. device buffers).
    type Storage<T: Scalar>: Storage<T>;

    /// Backend name for debugging.
    fn name() -> &'static str;

    /// Synchronize all pending operations.
    ///
    /// NOTE(review): presumably a no-op for synchronous (CPU) backends and
    /// a device/stream sync for GPU backends — confirm against the concrete
    /// implementations.
    fn synchronize(&self);

    /// Allocate storage for `len` elements.
    ///
    /// NOTE(review): this trait does not specify the contents of freshly
    /// allocated storage — confirm whether implementations zero-initialize.
    /// Use [`Storage::zeros`] when zeroed memory is required.
    fn alloc<T: Scalar>(&self, len: usize) -> Self::Storage<T>;

    /// Create storage from slice.
    // Takes `&self` because allocation may need backend state (e.g. a device
    // handle), which trips clippy's `from_*` self-convention lint.
    #[allow(clippy::wrong_self_convention)]
    fn from_slice<T: Scalar>(&self, data: &[T]) -> Self::Storage<T>;

    /// Copy strided data to contiguous storage.
    ///
    /// This is the core operation for making non-contiguous tensors contiguous.
    ///
    /// # Arguments
    ///
    /// * `src` - Source storage to read from
    /// * `shape` - Logical shape of the tensor view
    /// * `strides` - Element strides per dimension (length should match `shape`)
    /// * `offset` - Linear element offset of the view's first element in `src`
    fn copy_strided<T: Scalar>(
        &self,
        src: &Self::Storage<T>,
        shape: &[usize],
        strides: &[usize],
        offset: usize,
    ) -> Self::Storage<T>;

    /// Binary tensor contraction.
    ///
    /// Computes a generalized tensor contraction: `C[modes_c] = Σ A[modes_a] ⊗ B[modes_b]`
    /// where the sum (using semiring addition) is over indices appearing in both A and B
    /// but not in the output C.
    ///
    /// # Mode Labels
    ///
    /// Each mode (dimension) of the input tensors is labeled with a unique integer identifier.
    /// These labels determine how the contraction is performed:
    ///
    /// - **Contracted indices**: Labels appearing in both `modes_a` and `modes_b` but NOT in
    ///   `modes_c`. These dimensions are summed over (reduced).
    /// - **Free indices from A**: Labels appearing only in `modes_a`. These appear in the output.
    /// - **Free indices from B**: Labels appearing only in `modes_b`. These appear in the output.
    /// - **Batch indices**: Labels appearing in `modes_a`, `modes_b`, AND `modes_c`.
    ///   These dimensions are preserved and processed in parallel.
    ///
    /// # Arguments
    ///
    /// * `a` - Storage for first input tensor
    /// * `shape_a` - Shape (dimensions) of tensor A
    /// * `strides_a` - Strides for tensor A (column-major, supports non-contiguous tensors)
    /// * `modes_a` - Mode labels for tensor A (length must equal `shape_a.len()`)
    /// * `b` - Storage for second input tensor
    /// * `shape_b` - Shape of tensor B
    /// * `strides_b` - Strides for tensor B
    /// * `modes_b` - Mode labels for tensor B (length must equal `shape_b.len()`)
    /// * `shape_c` - Shape of output tensor C (must be consistent with `modes_c`)
    /// * `modes_c` - Mode labels for output tensor C (determines output structure)
    ///
    /// # Returns
    ///
    /// Contiguous storage containing the result tensor with shape `shape_c`.
    ///
    /// # Examples
    ///
    /// ## Matrix multiplication: `C[i,k] = Σⱼ A[i,j] ⊗ B[j,k]`
    ///
    /// ```ignore
    /// // A is 2×3, B is 3×4 -> C is 2×4
    /// let c = backend.contract::<Standard<f32>>(
    ///     &a, &[2, 3], &[1, 2], &[0, 1],  // A[i=0, j=1], shape 2×3
    ///     &b, &[3, 4], &[1, 3], &[1, 2],  // B[j=1, k=2], shape 3×4
    ///     &[2, 4], &[0, 2],               // C[i=0, k=2], shape 2×4
    /// );
    /// ```
    ///
    /// ## Batched matrix multiplication: `C[b,i,k] = Σⱼ A[b,i,j] ⊗ B[b,j,k]`
    ///
    /// ```ignore
    /// // Batch size 8, A is 2×3, B is 3×4 -> C is 8×2×4
    /// let c = backend.contract::<Standard<f32>>(
    ///     &a, &[8, 2, 3], &[1, 8, 16], &[0, 1, 2],  // A[b=0, i=1, j=2]
    ///     &b, &[8, 3, 4], &[1, 8, 24], &[0, 2, 3],  // B[b=0, j=2, k=3]
    ///     &[8, 2, 4], &[0, 1, 3],                    // C[b=0, i=1, k=3]
    /// );
    /// ```
    ///
    /// ## Tropical shortest path (with min-plus semiring)
    ///
    /// ```ignore
    /// // Find shortest paths via matrix multiplication in (min,+) semiring
    /// let distances = backend.contract::<MinPlus<f32>>(
    ///     &graph_a, &[n, n], &[1, n], &[0, 1],
    ///     &graph_b, &[n, n], &[1, n], &[1, 2],
    ///     &[n, n], &[0, 2],
    /// );
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - Mode labels have inconsistent sizes across tensors (e.g., if mode 1 has size 3
    ///   in A but size 4 in B)
    /// - The scalar type is not supported by the backend (compile-time check via `BackendScalar`)
    #[allow(clippy::too_many_arguments)]
    fn contract<A: Algebra>(
        &self,
        a: &Self::Storage<A::Scalar>,
        shape_a: &[usize],
        strides_a: &[usize],
        modes_a: &[i32],
        b: &Self::Storage<A::Scalar>,
        shape_b: &[usize],
        strides_b: &[usize],
        modes_b: &[i32],
        shape_c: &[usize],
        modes_c: &[i32],
    ) -> Self::Storage<A::Scalar>
    where
        A::Scalar: BackendScalar<Self>;

    /// Contraction with argmax tracking for tropical backpropagation.
    ///
    /// This is identical to [`Backend::contract`] but additionally returns an argmax
    /// tensor that tracks which contracted index "won" the reduction at each output
    /// position. This is essential for tropical algebra backward passes where gradients
    /// are routed through the winning path only.
    ///
    /// The algebra is constrained to `Algebra<Index = u32>` so the winner indices
    /// fit the returned `Storage<u32>` tensor.
    ///
    /// # Arguments
    ///
    /// Same as [`Backend::contract`].
    ///
    /// # Returns
    ///
    /// A tuple of:
    /// - `result`: The contraction result (same as `contract`)
    /// - `argmax`: Tensor of `u32` indices indicating which contracted index won
    ///   at each output position
    ///
    /// # Panics
    ///
    /// Same conditions as [`Backend::contract`].
    ///
    /// # Use Cases
    ///
    /// - Tropical backpropagation (Viterbi, shortest path)
    /// - Computing attention patterns in max-pooling operations
    /// - Any semiring where addition is idempotent and gradient routing matters
    #[allow(clippy::too_many_arguments)]
    fn contract_with_argmax<A: Algebra<Index = u32>>(
        &self,
        a: &Self::Storage<A::Scalar>,
        shape_a: &[usize],
        strides_a: &[usize],
        modes_a: &[i32],
        b: &Self::Storage<A::Scalar>,
        shape_b: &[usize],
        strides_b: &[usize],
        modes_b: &[i32],
        shape_c: &[usize],
        modes_c: &[i32],
    ) -> (Self::Storage<A::Scalar>, Self::Storage<u32>)
    where
        A::Scalar: BackendScalar<Self>;

}
200
// CPU supports all Scalar types: a blanket impl makes every `Scalar`
// automatically usable with the CPU backend.
impl<T: Scalar> BackendScalar<crate::backend::Cpu> for T {}

// CUDA supports only f32, f64, and CudaComplex types; any other `Scalar`
// is rejected at compile time by the `BackendScalar` bound on
// `Backend::contract` / `Backend::contract_with_argmax`.
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for f32 {}
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for f64 {}
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for crate::backend::CudaComplex<f32> {}
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for crate::backend::CudaComplex<f64> {}