pub trait Backend: Clone + Send + Sync + 'static {
    type Storage<T: Scalar>: Storage<T>;

    // Required methods
    fn name() -> &'static str;
    fn synchronize(&self);
    fn alloc<T: Scalar>(&self, len: usize) -> Self::Storage<T>;
    fn from_slice<T: Scalar>(&self, data: &[T]) -> Self::Storage<T>;
    fn copy_strided<T: Scalar>(
        &self,
        src: &Self::Storage<T>,
        shape: &[usize],
        strides: &[usize],
        offset: usize,
    ) -> Self::Storage<T>;
    fn contract<A: Algebra>(
        &self,
        a: &Self::Storage<A::Scalar>,
        shape_a: &[usize],
        strides_a: &[usize],
        modes_a: &[i32],
        b: &Self::Storage<A::Scalar>,
        shape_b: &[usize],
        strides_b: &[usize],
        modes_b: &[i32],
        shape_c: &[usize],
        modes_c: &[i32],
    ) -> Self::Storage<A::Scalar>
    where
        A::Scalar: BackendScalar<Self>;
    fn contract_with_argmax<A: Algebra<Index = u32>>(
        &self,
        a: &Self::Storage<A::Scalar>,
        shape_a: &[usize],
        strides_a: &[usize],
        modes_a: &[i32],
        b: &Self::Storage<A::Scalar>,
        shape_b: &[usize],
        strides_b: &[usize],
        modes_b: &[i32],
        shape_c: &[usize],
        modes_c: &[i32],
    ) -> (Self::Storage<A::Scalar>, Self::Storage<u32>)
    where
        A::Scalar: BackendScalar<Self>;
}
Backend trait for tensor execution.
Defines how tensor operations are executed on different hardware.
Required Associated Types

type Storage<T: Scalar>: Storage<T>

The storage type this backend uses for scalars of type T.

Required Methods
fn synchronize(&self)

Synchronize all pending operations.
fn from_slice<T: Scalar>(&self, data: &[T]) -> Self::Storage<T>

Create storage from a slice.
fn copy_strided<T: Scalar>(
    &self,
    src: &Self::Storage<T>,
    shape: &[usize],
    strides: &[usize],
    offset: usize,
) -> Self::Storage<T>
Copy strided data to contiguous storage.
This is the core operation for making non-contiguous tensors contiguous.
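The gather that copy_strided performs can be sketched on plain slices. This is an illustrative reference version, not part of the Backend API; it assumes column-major strides, matching the contract examples below.

```rust
/// Reference semantics of copy_strided on a plain Vec<f32>: the element at
/// coordinates (i0, i1, ...) is read from src[offset + Σ i_d * strides[d]]
/// and written out in contiguous column-major order.
fn copy_strided_ref(src: &[f32], shape: &[usize], strides: &[usize], offset: usize) -> Vec<f32> {
    let len: usize = shape.iter().product();
    let mut out = Vec::with_capacity(len);
    for linear in 0..len {
        // Decompose the contiguous (column-major) output index into coordinates.
        let mut rem = linear;
        let mut src_idx = offset;
        for (d, &dim) in shape.iter().enumerate() {
            src_idx += (rem % dim) * strides[d];
            rem /= dim;
        }
        out.push(src[src_idx]);
    }
    out
}

fn main() {
    // A 2×3 column-major matrix [[1,3,5],[2,4,6]], stored as [1,2,3,4,5,6].
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // View its transpose as a 3×2 tensor by swapping the strides to [2, 1],
    // then materialize that non-contiguous view as contiguous storage.
    let t = copy_strided_ref(&a, &[3, 2], &[2, 1], 0);
    assert_eq!(t, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
}
```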
fn contract<A: Algebra>(
    &self,
    a: &Self::Storage<A::Scalar>,
    shape_a: &[usize],
    strides_a: &[usize],
    modes_a: &[i32],
    b: &Self::Storage<A::Scalar>,
    shape_b: &[usize],
    strides_b: &[usize],
    modes_b: &[i32],
    shape_c: &[usize],
    modes_c: &[i32],
) -> Self::Storage<A::Scalar>
where
    A::Scalar: BackendScalar<Self>
Binary tensor contraction.
Computes a generalized tensor contraction: C[modes_c] = Σ A[modes_a] ⊗ B[modes_b]
where the sum (using semiring addition) is over indices appearing in both A and B
but not in the output C.
Mode Labels

Each mode (dimension) of the input tensors is labeled with a unique integer identifier. These labels determine how the contraction is performed:

- Contracted indices: Labels appearing in both modes_a and modes_b but NOT in modes_c. These dimensions are summed over (reduced).
- Free indices from A: Labels appearing only in modes_a. These appear in the output.
- Free indices from B: Labels appearing only in modes_b. These appear in the output.
- Batch indices: Labels appearing in modes_a, modes_b, AND modes_c. These dimensions are preserved and processed in parallel.
Arguments

- a - Storage for first input tensor
- shape_a - Shape (dimensions) of tensor A
- strides_a - Strides for tensor A (column-major, supports non-contiguous tensors)
- modes_a - Mode labels for tensor A (length must equal shape_a.len())
- b - Storage for second input tensor
- shape_b - Shape of tensor B
- strides_b - Strides for tensor B
- modes_b - Mode labels for tensor B (length must equal shape_b.len())
- shape_c - Shape of output tensor C (must be consistent with modes_c)
- modes_c - Mode labels for output tensor C (determines output structure)
Returns
Contiguous storage containing the result tensor with shape shape_c.
Examples

Matrix multiplication: C[i,k] = Σⱼ A[i,j] ⊗ B[j,k]

// A is 2×3, B is 3×4 -> C is 2×4
let c = backend.contract::<Standard<f32>>(
    &a, &[2, 3], &[1, 2], &[0, 1], // A[i=0, j=1], shape 2×3
    &b, &[3, 4], &[1, 3], &[1, 2], // B[j=1, k=2], shape 3×4
    &[2, 4], &[0, 2],              // C[i=0, k=2], shape 2×4
);

Batched matrix multiplication: C[b,i,k] = Σⱼ A[b,i,j] ⊗ B[b,j,k]

// Batch size 8, A is 2×3, B is 3×4 -> C is 8×2×4
let c = backend.contract::<Standard<f32>>(
    &a, &[8, 2, 3], &[1, 8, 16], &[0, 1, 2], // A[b=0, i=1, j=2]
    &b, &[8, 3, 4], &[1, 8, 24], &[0, 2, 3], // B[b=0, j=2, k=3]
    &[8, 2, 4], &[0, 1, 3],                  // C[b=0, i=1, k=3]
);

Tropical shortest path (with min-plus semiring)

// Find shortest paths via matrix multiplication in the (min, +) semiring
let distances = backend.contract::<MinPlus<f32>>(
    &graph_a, &[n, n], &[1, n], &[0, 1],
    &graph_b, &[n, n], &[1, n], &[1, 2],
    &[n, n], &[0, 2],
);

Panics
Panics if:

- Mode labels have inconsistent sizes across tensors (e.g., if mode 1 has size 3 in A but size 4 in B)
- The scalar type is not supported by the backend (compile-time check via BackendScalar)
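The tropical example above can be spelled out as a plain-Rust reference: matrix multiplication where the semiring "add" is min and the semiring "multiply" is +. The function name is illustrative and not part of the Backend API; inputs are assumed column-major, as in the examples.

```rust
/// C[i,k] = min over j of (A[i,j] + B[j,k]), the (min, +) semiring matmul.
/// Column-major n×n layout: element (i, j) lives at index i + j * n.
fn minplus_matmul(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
    // The semiring's additive identity is +∞ (min(x, ∞) = x).
    let mut c = vec![f32::INFINITY; n * n];
    for k in 0..n {
        for i in 0..n {
            for j in 0..n {
                // Semiring "multiply" (ordinary +) along the contracted mode j,
                // combined with semiring "add" (min) into the output.
                let cand = a[i + j * n] + b[j + k * n];
                if cand < c[i + k * n] {
                    c[i + k * n] = cand;
                }
            }
        }
    }
    c
}

fn main() {
    // 2×2 edge-weight matrix G = [[0, 1], [3, 0]] in column-major order;
    // one min-plus product G·G gives shortest paths of up to two hops.
    let g = vec![0.0, 3.0, 1.0, 0.0];
    let d = minplus_matmul(&g, &g, 2);
    assert_eq!(d, vec![0.0, 3.0, 1.0, 0.0]);
}
```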
fn contract_with_argmax<A: Algebra<Index = u32>>(
    &self,
    a: &Self::Storage<A::Scalar>,
    shape_a: &[usize],
    strides_a: &[usize],
    modes_a: &[i32],
    b: &Self::Storage<A::Scalar>,
    shape_b: &[usize],
    strides_b: &[usize],
    modes_b: &[i32],
    shape_c: &[usize],
    modes_c: &[i32],
) -> (Self::Storage<A::Scalar>, Self::Storage<u32>)
where
    A::Scalar: BackendScalar<Self>
Contraction with argmax tracking for tropical backpropagation.
This is identical to Backend::contract but additionally returns an argmax
tensor that tracks which contracted index “won” the reduction at each output
position. This is essential for tropical algebra backward passes where gradients
are routed through the winning path only.
Returns

A tuple of:

- result: The contraction result (same as contract)
- argmax: Tensor of u32 indices indicating which contracted index won at each output position
Use Cases
- Tropical backpropagation (Viterbi, shortest path)
- Computing attention patterns in max-pooling operations
- Any semiring where addition is idempotent and gradient routing matters
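What the argmax output records can be sketched for a single output element: in a min-plus reduction, it is the contracted index j whose term produced the winning (minimal) value. This is an illustrative sketch with hypothetical names; the real method tracks a winner per output position of the full contraction.

```rust
/// Reduce one output element C[i,k] = min over j of (A[i,j] + B[j,k]),
/// remembering which j won so a backward pass can route the gradient
/// through that single path only.
fn minplus_with_argmax(row_a: &[f32], col_b: &[f32]) -> (f32, u32) {
    let mut best = f32::INFINITY;
    let mut winner = 0u32;
    for (j, (&x, &y)) in row_a.iter().zip(col_b).enumerate() {
        let cand = x + y; // semiring "multiply"
        if cand < best {
            best = cand;       // semiring "add" (min)
            winner = j as u32; // record the winning contracted index
        }
    }
    (best, winner)
}

fn main() {
    // Here j = 1 wins the reduction: 2.0 + 1.0 = 3.0 is the minimum.
    let (val, arg) = minplus_with_argmax(&[5.0, 2.0, 4.0], &[0.5, 1.0, 0.0]);
    assert_eq!(val, 3.0);
    assert_eq!(arg, 1);
}
```

Because min is idempotent, only the winning term contributes to the output, which is why routing the gradient exclusively through the recorded index is exact.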
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.