omeinsum/backend/traits.rs
//! Backend trait definitions.

use crate::algebra::{Algebra, Scalar};

/// Storage trait for tensor data.
///
/// Abstracts over different storage backends (CPU memory, GPU memory).
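///
/// # Example
///
/// A minimal usage sketch; `CpuStorage` below is a stand-in for any concrete
/// implementor of this trait:
///
/// ```ignore
/// let mut s = CpuStorage::<f32>::from_slice(&[1.0, 2.0, 3.0]);
/// assert_eq!(s.len(), 3);
/// s.set(1, 5.0);                               // element-wise write
/// assert_eq!(s.get(1), 5.0);                   // element-wise read
/// assert_eq!(s.to_vec(), vec![1.0, 5.0, 3.0]); // full copy-out
/// ```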
pub trait Storage<T: Scalar>: Clone + Send + Sync + Sized {
    /// Number of elements in storage.
    fn len(&self) -> usize;

    /// Check if storage is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get element at index (may be slow for GPU).
    fn get(&self, index: usize) -> T;

    /// Set element at index (may be slow for GPU).
    fn set(&mut self, index: usize, value: T);

    /// Copy all data to a Vec (downloads from GPU if needed).
    fn to_vec(&self) -> Vec<T>;

    /// Create storage from slice.
    fn from_slice(data: &[T]) -> Self;

    /// Create zero-initialized storage.
    fn zeros(len: usize) -> Self;
}

/// Marker trait for scalar types supported by a specific backend.
///
/// This enables compile-time checking that a scalar type is supported
/// by a particular backend (e.g., CUDA only supports f32/f64/complex).
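///
/// # Example
///
/// A sketch of using the bound in backend-generic code; the helper name
/// `upload` is illustrative, not part of this crate:
///
/// ```ignore
/// fn upload<B: Backend, T: BackendScalar<B>>(backend: &B, data: &[T]) -> B::Storage<T> {
///     // Only compiles for scalar types the backend actually supports.
///     backend.from_slice(data)
/// }
/// ```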
pub trait BackendScalar<B: Backend>: Scalar {}

/// Backend trait for tensor execution.
///
/// Defines how tensor operations are executed on different hardware.
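///
/// # Example
///
/// A sketch of backend-generic code; `Standard<f32>` is the conventional
/// (+, ×) algebra also used in the [`Backend::contract`] examples:
///
/// ```ignore
/// fn matmul2<B: Backend>(backend: &B, a: &[f32], b: &[f32]) -> Vec<f32>
/// where
///     f32: BackendScalar<B>,
/// {
///     let a = backend.from_slice(a);
///     let b = backend.from_slice(b);
///     // C[i, k] = Σⱼ A[i, j] · B[j, k] for 2×2 column-major inputs.
///     let c = backend.contract::<Standard<f32>>(
///         &a, &[2, 2], &[1, 2], &[0, 1],
///         &b, &[2, 2], &[1, 2], &[1, 2],
///         &[2, 2], &[0, 2],
///     );
///     backend.synchronize();
///     c.to_vec()
/// }
/// ```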
pub trait Backend: Clone + Send + Sync + 'static {
    /// Storage type for this backend.
    type Storage<T: Scalar>: Storage<T>;

    /// Backend name for debugging.
    fn name() -> &'static str;

    /// Synchronize all pending operations.
    fn synchronize(&self);

    /// Allocate storage.
    fn alloc<T: Scalar>(&self, len: usize) -> Self::Storage<T>;

    /// Create storage from slice.
    #[allow(clippy::wrong_self_convention)]
    fn from_slice<T: Scalar>(&self, data: &[T]) -> Self::Storage<T>;

    /// Copy strided data to contiguous storage.
    ///
    /// This is the core operation for making non-contiguous tensors contiguous.
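    ///
    /// # Example
    ///
    /// A sketch of extracting one column of a column-major matrix (values
    /// are illustrative):
    ///
    /// ```ignore
    /// // Column-major 2×3 matrix with columns [1, 2], [3, 4], [5, 6].
    /// let m = backend.from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// // Second column as a length-2 vector: shape [2], stride [1],
    /// // offset 2 (the linear index where column 1 starts).
    /// let col = backend.copy_strided(&m, &[2], &[1], 2);
    /// assert_eq!(col.to_vec(), vec![3.0, 4.0]);
    /// ```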
    fn copy_strided<T: Scalar>(
        &self,
        src: &Self::Storage<T>,
        shape: &[usize],
        strides: &[usize],
        offset: usize,
    ) -> Self::Storage<T>;

    /// Binary tensor contraction.
    ///
    /// Computes a generalized tensor contraction: `C[modes_c] = Σ A[modes_a] ⊗ B[modes_b]`
    /// where the sum (using semiring addition) is over indices appearing in both A and B
    /// but not in the output C.
    ///
    /// # Mode Labels
    ///
    /// Each mode (dimension) of the input tensors is labeled with a unique integer identifier.
    /// These labels determine how the contraction is performed:
    ///
    /// - **Contracted indices**: Labels appearing in both `modes_a` and `modes_b` but NOT in
    ///   `modes_c`. These dimensions are summed over (reduced).
    /// - **Free indices from A**: Labels appearing only in `modes_a`. These appear in the output.
    /// - **Free indices from B**: Labels appearing only in `modes_b`. These appear in the output.
    /// - **Batch indices**: Labels appearing in `modes_a`, `modes_b`, AND `modes_c`.
    ///   These dimensions are preserved and processed in parallel.
    ///
    /// # Arguments
    ///
    /// * `a` - Storage for first input tensor
    /// * `shape_a` - Shape (dimensions) of tensor A
    /// * `strides_a` - Strides for tensor A (column-major, supports non-contiguous tensors)
    /// * `modes_a` - Mode labels for tensor A (length must equal `shape_a.len()`)
    /// * `b` - Storage for second input tensor
    /// * `shape_b` - Shape of tensor B
    /// * `strides_b` - Strides for tensor B
    /// * `modes_b` - Mode labels for tensor B (length must equal `shape_b.len()`)
    /// * `shape_c` - Shape of output tensor C (must be consistent with `modes_c`)
    /// * `modes_c` - Mode labels for output tensor C (determines output structure)
    ///
    /// # Returns
    ///
    /// Contiguous storage containing the result tensor with shape `shape_c`.
    ///
    /// # Examples
    ///
    /// ## Matrix multiplication: `C[i,k] = Σⱼ A[i,j] ⊗ B[j,k]`
    ///
    /// ```ignore
    /// // A is 2×3, B is 3×4 -> C is 2×4
    /// let c = backend.contract::<Standard<f32>>(
    ///     &a, &[2, 3], &[1, 2], &[0, 1], // A[i=0, j=1], shape 2×3
    ///     &b, &[3, 4], &[1, 3], &[1, 2], // B[j=1, k=2], shape 3×4
    ///     &[2, 4], &[0, 2],              // C[i=0, k=2], shape 2×4
    /// );
    /// ```
    ///
    /// ## Batched matrix multiplication: `C[b,i,k] = Σⱼ A[b,i,j] ⊗ B[b,j,k]`
    ///
    /// ```ignore
    /// // Batch size 8, A is 2×3, B is 3×4 -> C is 8×2×4
    /// let c = backend.contract::<Standard<f32>>(
    ///     &a, &[8, 2, 3], &[1, 8, 16], &[0, 1, 2], // A[b=0, i=1, j=2]
    ///     &b, &[8, 3, 4], &[1, 8, 24], &[0, 2, 3], // B[b=0, j=2, k=3]
    ///     &[8, 2, 4], &[0, 1, 3],                  // C[b=0, i=1, k=3]
    /// );
    /// ```
    ///
    /// ## Tropical shortest path (with min-plus semiring)
    ///
    /// ```ignore
    /// // Find shortest paths via matrix multiplication in the (min, +) semiring
    /// let distances = backend.contract::<MinPlus<f32>>(
    ///     &graph_a, &[n, n], &[1, n], &[0, 1],
    ///     &graph_b, &[n, n], &[1, n], &[1, 2],
    ///     &[n, n], &[0, 2],
    /// );
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if mode labels have inconsistent sizes across tensors (e.g., if mode 1
    /// has size 3 in A but size 4 in B). Unsupported scalar types are rejected at
    /// compile time via the `BackendScalar` bound rather than causing a panic.
    #[allow(clippy::too_many_arguments)]
    fn contract<A: Algebra>(
        &self,
        a: &Self::Storage<A::Scalar>,
        shape_a: &[usize],
        strides_a: &[usize],
        modes_a: &[i32],
        b: &Self::Storage<A::Scalar>,
        shape_b: &[usize],
        strides_b: &[usize],
        modes_b: &[i32],
        shape_c: &[usize],
        modes_c: &[i32],
    ) -> Self::Storage<A::Scalar>
    where
        A::Scalar: BackendScalar<Self>;

    /// Contraction with argmax tracking for tropical backpropagation.
    ///
    /// This is identical to [`Backend::contract`] but additionally returns an argmax
    /// tensor that tracks which contracted index "won" the reduction at each output
    /// position. This is essential for tropical algebra backward passes, where gradients
    /// are routed through the winning path only.
    ///
    /// # Returns
    ///
    /// A tuple of:
    /// - `result`: The contraction result (same as `contract`)
    /// - `argmax`: Tensor of `u32` indices indicating which contracted index won
    ///   at each output position
    ///
    /// # Use Cases
    ///
    /// - Tropical backpropagation (Viterbi, shortest path)
    /// - Computing attention patterns in max-pooling operations
    /// - Any semiring where addition is idempotent and gradient routing matters
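    ///
    /// # Example
    ///
    /// A sketch of one (min, +) step with winner tracking, assuming
    /// `MinPlus<f32>` is an `Algebra` with `Index = u32`:
    ///
    /// ```ignore
    /// // distances[i, k] = min_j (a[i, j] + b[j, k]); argmax[i, k] records
    /// // the winning j, which the backward pass uses to route gradients.
    /// let (distances, argmax) = backend.contract_with_argmax::<MinPlus<f32>>(
    ///     &a, &[n, n], &[1, n], &[0, 1],
    ///     &b, &[n, n], &[1, n], &[1, 2],
    ///     &[n, n], &[0, 2],
    /// );
    /// ```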
    #[allow(clippy::too_many_arguments)]
    fn contract_with_argmax<A: Algebra<Index = u32>>(
        &self,
        a: &Self::Storage<A::Scalar>,
        shape_a: &[usize],
        strides_a: &[usize],
        modes_a: &[i32],
        b: &Self::Storage<A::Scalar>,
        shape_b: &[usize],
        strides_b: &[usize],
        modes_b: &[i32],
        shape_c: &[usize],
        modes_c: &[i32],
    ) -> (Self::Storage<A::Scalar>, Self::Storage<u32>)
    where
        A::Scalar: BackendScalar<Self>;
}

// CPU supports all Scalar types
impl<T: Scalar> BackendScalar<crate::backend::Cpu> for T {}

// CUDA supports f32, f64, and CudaComplex types
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for f32 {}
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for f64 {}
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for crate::backend::CudaComplex<f32> {}
#[cfg(feature = "cuda")]
impl BackendScalar<crate::backend::Cuda> for crate::backend::CudaComplex<f64> {}