Introduction

omeinsum-rs is a Rust library for efficient tensor network contractions supporting both standard and tropical (semiring) algebras. It provides a unified interface for einsum operations with automatic contraction order optimization.

What is Einsum?

Einstein summation (einsum) is a compact notation for expressing tensor operations. Instead of writing explicit loops, you specify index labels:

C[i,k] = Σ_j A[i,j] × B[j,k]    # Matrix multiplication

In einsum notation: ij,jk->ik

What are Tropical Algebras?

Tropical algebras replace standard arithmetic with alternative operations:

Algebra    Addition (⊕)   Multiplication (⊗)   Use Case
Standard   +              ×                    Normal arithmetic
MaxPlus    max            +                    Longest path, Viterbi
MinPlus    min            +                    Shortest path
MaxMul     max            ×                    Max probability
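
To make the table concrete, here is a 2×2 MaxPlus matrix product worked out in plain Rust (no library calls): the only change from ordinary matrix multiplication is that + becomes max and × becomes +.

fn main() {
    // MaxPlus "matrix multiplication": C[i][k] = max_j (A[i][j] + B[j][k])
    let a = [[1.0_f64, 2.0], [3.0, 4.0]];
    let b = [[1.0_f64, 2.0], [3.0, 4.0]];
    let mut c = [[f64::NEG_INFINITY; 2]; 2]; // tropical zero is -∞
    for i in 0..2 {
        for k in 0..2 {
            for j in 0..2 {
                c[i][k] = c[i][k].max(a[i][j] + b[j][k]); // ⊕ = max, ⊗ = +
            }
        }
    }
    assert_eq!(c, [[5.0, 6.0], [7.0, 8.0]]);
}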

Key Features

  • Multiple Algebras: Standard arithmetic, MaxPlus, MinPlus, MaxMul
  • Contraction Optimization: Uses omeco for optimal contraction order
  • Backpropagation Support: Argmax tracking for tropical gradient computation
  • Flexible Tensors: Stride-based views with zero-copy permute/reshape
  • Backend Abstraction: CPU now, GPU planned

Example

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Tensor, Cpu};
use omeinsum::algebra::MaxPlus;

// Create tensors
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Tropical matrix multiplication: C[i,k] = max_j (A[i,j] + B[j,k])
let c = einsum::<MaxPlus<f32>, _, _>(&[&a, &b], &[&[0, 1], &[1, 2]], &[0, 2]);
}

Relationship to OMEinsum.jl

This library is inspired by OMEinsum.jl, bringing its powerful tensor contraction capabilities to Rust with support for tropical algebras from tropical-gemm.

Getting Started

This chapter covers installation and basic usage of omeinsum-rs.

Prerequisites

  • Rust 1.70 or later
  • Cargo package manager

Quick Example

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;

// Matrix multiplication: C = A × B
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

let c = einsum::<Standard<f32>, _, _>(
    &[&a, &b],
    &[&[0, 1], &[1, 2]],  // A[i,j], B[j,k]
    &[0, 2],               // -> C[i,k]
);

assert_eq!(c.to_vec(), vec![7.0, 10.0, 15.0, 22.0]);
}

Continue to Installation for detailed setup instructions.

Installation

From crates.io

Add to your Cargo.toml:

[dependencies]
omeinsum = "0.1"

From Git

For the latest development version:

[dependencies]
omeinsum = { git = "https://github.com/TensorBFS/omeinsum-rs" }

Features

Feature            Default   Description
parallel           Yes       Enable parallel execution with rayon
tropical-kernels   Yes       Use optimized tropical-gemm kernels
cuda               No        Enable CUDA GPU support

Minimal Build

For a minimal build without optional dependencies:

[dependencies]
omeinsum = { version = "0.1", default-features = false }
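
You can also combine default-features = false with an explicit feature list, for example keeping only the optimized tropical kernels (feature names taken from the table above):

[dependencies]
omeinsum = { version = "0.1", default-features = false, features = ["tropical-kernels"] }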

With CUDA

[dependencies]
omeinsum = { version = "0.1", features = ["cuda"] }

Verification

Verify the installation:

use omeinsum::{Tensor, Cpu};

fn main() {
    let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0], &[3]);
    println!("omeinsum installed successfully!");
    println!("Tensor shape: {:?}", t.shape());
}

Quick Start

Basic Tensor Operations

Creating Tensors

#![allow(unused)]
fn main() {
use omeinsum::{Tensor, Cpu};

// From data with shape
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Check properties
assert_eq!(a.shape(), &[2, 3]);
assert_eq!(a.ndim(), 2);
assert_eq!(a.numel(), 6);
}

Tensor Views

#![allow(unused)]
fn main() {
// Zero-copy transpose
let a_t = a.permute(&[1, 0]);
assert_eq!(a_t.shape(), &[3, 2]);

// Reshape (zero-copy when contiguous)
let a_flat = a.reshape(&[6]);
assert_eq!(a_flat.shape(), &[6]);
}

Einsum Operations

Matrix Multiplication

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;

let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// C[i,k] = Σ_j A[i,j] × B[j,k]
let c = einsum::<Standard<f32>, _, _>(
    &[&a, &b],
    &[&[0, 1], &[1, 2]],
    &[0, 2],
);
}

Tropical Operations

#![allow(unused)]
fn main() {
use omeinsum::algebra::MaxPlus;

// C[i,k] = max_j (A[i,j] + B[j,k])
let c = einsum::<MaxPlus<f32>, _, _>(
    &[&a, &b],
    &[&[0, 1], &[1, 2]],
    &[0, 2],
);
}

Using the Einsum Builder

For more control over contraction:

#![allow(unused)]
fn main() {
use omeinsum::{Einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;
use std::collections::HashMap;

let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();

let mut ein = Einsum::new(
    vec![vec![0, 1], vec![1, 2]],  // ij, jk
    vec![0, 2],                     // -> ik
    sizes,
);

// Optimize contraction order
ein.optimize_greedy();

// Execute
let c = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b]);
}

Next Steps

Continue to Algebra Types for the supported semirings, the Tensor API for tensor operations, and the Einsum API for contraction details.

Algebra Types

omeinsum-rs supports multiple algebraic structures (semirings) for tensor operations.

Semiring Abstraction

A semiring has two operations:

  • Addition (⊕): Associative, commutative, with identity (zero)
  • Multiplication (⊗): Associative, with identity (one)

The Algebra trait extends Semiring with backpropagation support.
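
For intuition, the semiring identities can be checked with plain floating-point arithmetic. Taking MaxPlus as an example, the additive identity (tropical zero) is -∞ and the multiplicative identity (tropical one) is 0:

fn main() {
    let x = 3.5_f64;
    // Tropical zero (-∞) is the identity for ⊕ = max
    assert_eq!(f64::NEG_INFINITY.max(x), x);
    // Tropical one (0) is the identity for ⊗ = +
    assert_eq!(0.0 + x, x);
    // ⊗ distributes over ⊕: a + max(b, c) == max(a + b, a + c)
    let (a, b, c) = (1.0_f64, 2.0, 5.0);
    assert_eq!(a + b.max(c), (a + b).max(a + c));
}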

Available Algebras

Standard Arithmetic

#![allow(unused)]
fn main() {
use omeinsum::algebra::Standard;

// Standard: ⊕ = +, ⊗ = ×
// C[i,j] = Σ_k A[i,k] × B[k,j]
let c = einsum::<Standard<f32>, _, _>(&[&a, &b], &[&[0, 1], &[1, 2]], &[0, 2]);
}

MaxPlus (Tropical)

#![allow(unused)]
fn main() {
use omeinsum::algebra::MaxPlus;

// MaxPlus: ⊕ = max, ⊗ = +
// C[i,j] = max_k (A[i,k] + B[k,j])
// Use case: Longest path, Viterbi algorithm
let c = einsum::<MaxPlus<f32>, _, _>(&[&a, &b], &[&[0, 1], &[1, 2]], &[0, 2]);
}

MinPlus

#![allow(unused)]
fn main() {
use omeinsum::algebra::MinPlus;

// MinPlus: ⊕ = min, ⊗ = +
// C[i,j] = min_k (A[i,k] + B[k,j])
// Use case: Shortest path (Dijkstra, Floyd-Warshall)
let c = einsum::<MinPlus<f32>, _, _>(&[&a, &b], &[&[0, 1], &[1, 2]], &[0, 2]);
}

MaxMul

#![allow(unused)]
fn main() {
use omeinsum::algebra::MaxMul;

// MaxMul: ⊕ = max, ⊗ = ×
// C[i,j] = max_k (A[i,k] × B[k,j])
// Use case: Maximum probability paths
let c = einsum::<MaxMul<f32>, _, _>(&[&a, &b], &[&[0, 1], &[1, 2]], &[0, 2]);
}

Summary Table

Algebra       ⊕     ⊗   Zero   One   Use Case
Standard<T>   +     ×   0      1     Normal arithmetic
MaxPlus<T>    max   +   -∞     0     Longest path
MinPlus<T>    min   +   +∞     0     Shortest path
MaxMul<T>     max   ×   0      1     Max probability

Implementing Custom Algebras

You can implement the Semiring and Algebra traits for custom algebras:

#![allow(unused)]
fn main() {
use omeinsum::algebra::{Semiring, Algebra, Scalar};

#[derive(Copy, Clone)]
pub struct MyAlgebra<T>(T);

impl<T: Scalar> Semiring for MyAlgebra<T> {
    type Scalar = T;

    fn zero() -> Self { /* ... */ }
    fn one() -> Self { /* ... */ }
    fn add(self, rhs: Self) -> Self { /* ... */ }
    fn mul(self, rhs: Self) -> Self { /* ... */ }
    fn from_scalar(s: T) -> Self { /* ... */ }
    fn to_scalar(self) -> T { /* ... */ }
    fn is_zero(&self) -> bool { /* ... */ }
}
}
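
As a concrete, hypothetical sketch of filling in those method bodies, here is a MinMax semiring (⊕ = min with identity +∞, ⊗ = max with identity -∞), useful for bottleneck-path style problems. The extra PartialOrd bound is an added assumption, the trait surface is taken from the skeleton above, and gradient support would additionally require implementing Algebra:

#![allow(unused)]
fn main() {
use omeinsum::algebra::{Semiring, Scalar};

// Hypothetical MinMax semiring: ⊕ = min, ⊗ = max.
#[derive(Copy, Clone)]
pub struct MinMax<T>(T);

impl<T: Scalar + PartialOrd> Semiring for MinMax<T> {
    type Scalar = T;

    fn zero() -> Self { MinMax(T::infinity()) }      // identity for ⊕ = min
    fn one() -> Self { MinMax(T::neg_infinity()) }   // identity for ⊗ = max
    fn add(self, rhs: Self) -> Self {                // ⊕ = min
        if self.0 < rhs.0 { self } else { rhs }
    }
    fn mul(self, rhs: Self) -> Self {                // ⊗ = max
        if self.0 > rhs.0 { self } else { rhs }
    }
    fn from_scalar(s: T) -> Self { MinMax(s) }
    fn to_scalar(self) -> T { self.0 }
    fn is_zero(&self) -> bool { !(self.0 < T::infinity()) }
}
}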

Tensor API

The Tensor<T, B> type provides a flexible, stride-based tensor implementation.

Creating Tensors

From Data

#![allow(unused)]
fn main() {
use omeinsum::{Tensor, Cpu};

// Create from slice with shape
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Row-major layout:
// [[1, 2, 3],
//  [4, 5, 6]]
}

Zeros and Ones

#![allow(unused)]
fn main() {
let zeros = Tensor::<f32, Cpu>::zeros(&[3, 4]);
let ones = Tensor::<f32, Cpu>::ones(&[3, 4]);
}

Properties

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

t.shape()    // &[2, 3]
t.strides()  // &[3, 1] for row-major
t.ndim()     // 2
t.numel()    // 6
}

Views and Transformations

Permute (Transpose)

Zero-copy axis reordering:

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Transpose: swap axes 0 and 1
let t_t = t.permute(&[1, 0]);
assert_eq!(t_t.shape(), &[3, 2]);

// 3D example: (batch, height, width) -> (batch, width, height)
let img = Tensor::<f32, Cpu>::zeros(&[10, 32, 28]);
let img_t = img.permute(&[0, 2, 1]);
assert_eq!(img_t.shape(), &[10, 28, 32]);
}

Reshape

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Flatten
let flat = t.reshape(&[6]);

// Reshape to different dimensions
let reshaped = t.reshape(&[3, 2]);
}

Contiguous

Convert non-contiguous views to contiguous storage:

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let t_t = t.permute(&[1, 0]);  // Non-contiguous after transpose

let t_contig = t_t.contiguous();  // Copy to contiguous memory
assert!(t_contig.is_contiguous());
}

Matrix Operations

GEMM (General Matrix Multiplication)

#![allow(unused)]
fn main() {
use omeinsum::algebra::{Standard, MaxPlus};

let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Standard matrix multiplication
let c = a.gemm::<Standard<f32>>(&b);

// Tropical matrix multiplication
let c_trop = a.gemm::<MaxPlus<f32>>(&b);
}

Binary Contraction

General tensor contraction:

#![allow(unused)]
fn main() {
// A[i,j,k] × B[j,k,l] -> C[i,l]
let a = Tensor::<f32, Cpu>::zeros(&[2, 3, 4]);
let b = Tensor::<f32, Cpu>::zeros(&[3, 4, 5]);

let c = a.contract_binary::<Standard<f32>>(
    &b,
    &[0, 1, 2],  // A's indices: i, j, k
    &[1, 2, 3],  // B's indices: j, k, l
    &[0, 3],     // Output: i, l
);

assert_eq!(c.shape(), &[2, 5]);
}

Data Access

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Convert to Vec
let data = t.to_vec();  // [1.0, 2.0, 3.0, 4.0]
}

Einsum API

The einsum API provides a high-level interface for tensor network contractions.

Quick Einsum

The simplest way to perform einsum:

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;

let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Matrix multiplication: ij,jk->ik
let c = einsum::<Standard<f32>, _, _>(
    &[&a, &b],           // Input tensors
    &[&[0, 1], &[1, 2]], // Index labels: A[0,1], B[1,2]
    &[0, 2],             // Output labels: C[0,2]
);
}

Index Labels

Indices are represented as usize values. Matching indices indicate contraction:

Operation         Inputs               Labels               Output
Matrix multiply   A[m,k], B[k,n]       [[0,1], [1,2]]       [0,2] → C[m,n]
Batch matmul      A[b,m,k], B[b,k,n]   [[0,1,2], [0,2,3]]   [0,1,3] → C[b,m,n]
Outer product     A[m], B[n]           [[0], [1]]           [0,1] → C[m,n]
Trace             A[n,n]               [[0,0]]              [] → scalar
Sum               A[m,n]               [[0,1]]              [] → scalar
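
For example, the outer-product row of the table translates directly into code (a sketch using the same einsum entry point shown above):

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;

// Outer product: C[m,n] = A[m] × B[n] (no shared labels, so nothing is contracted)
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0], &[2]);
let b = Tensor::<f32, Cpu>::from_data(&[10.0, 20.0, 30.0], &[3]);

let c = einsum::<Standard<f32>, _, _>(&[&a, &b], &[&[0], &[1]], &[0, 1]);
assert_eq!(c.shape(), &[2, 3]);
}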

Einsum Struct

For more control, use the Einsum struct directly:

#![allow(unused)]
fn main() {
use omeinsum::{Einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;
use std::collections::HashMap;

// Define size dictionary
let sizes: HashMap<usize, usize> = [
    (0, 10),  // i: 10
    (1, 20),  // j: 20
    (2, 30),  // k: 30
].into();

// Create einsum specification
let mut ein = Einsum::new(
    vec![vec![0, 1], vec![1, 2]],  // A[i,j], B[j,k]
    vec![0, 2],                     // -> C[i,k]
    sizes,
);

// Check the einsum code
let code = ein.code();
println!("Einsum: {:?}", code);
}

Contraction Optimization

Greedy Algorithm

Fast O(n²) algorithm, good for most cases:

#![allow(unused)]
fn main() {
let mut ein = Einsum::new(/* ... */);
ein.optimize_greedy();

assert!(ein.is_optimized());
}

Simulated Annealing

Slower but finds better orderings for complex networks:

#![allow(unused)]
fn main() {
let mut ein = Einsum::new(/* ... */);
ein.optimize_treesa();
}

Inspect Contraction Tree

#![allow(unused)]
fn main() {
if let Some(tree) = ein.contraction_tree() {
    println!("Contraction tree: {:?}", tree);
}
}

Chain Contraction Example

Contracting a chain of matrices:

#![allow(unused)]
fn main() {
use omeinsum::{Einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;
use std::collections::HashMap;

// A[i,j] × B[j,k] × C[k,l] → D[i,l]
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let c = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

let sizes: HashMap<usize, usize> = [
    (0, 2), (1, 2), (2, 2), (3, 2)
].into();

let mut ein = Einsum::new(
    vec![vec![0, 1], vec![1, 2], vec![2, 3]],
    vec![0, 3],
    sizes,
);

ein.optimize_greedy();
let d = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b, &c]);

assert_eq!(d.shape(), &[2, 2]);
}

Einsum with Gradients

For backpropagation support:

#![allow(unused)]
fn main() {
use omeinsum::einsum_with_grad;
use omeinsum::algebra::MaxPlus;

let (result, gradient) = einsum_with_grad::<MaxPlus<f32>, _, _>(
    &[&a, &b],
    &[&[0, 1], &[1, 2]],
    &[0, 2],
);

// gradient can be used for backpropagation
// (full backward pass implementation in progress)
}

Contraction Optimization

Finding the optimal contraction order is critical for tensor network performance.

The Problem

Consider contracting tensors A, B, C, D with different contraction orders:

((A × B) × C) × D  vs  (A × B) × (C × D)  vs  A × ((B × C) × D)

Different orders have vastly different computational costs. For large networks, the difference can be exponential.
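
A small worked example with illustrative sizes: take A[i,j], B[j,k], C[k,l] with i = 1000, j = 2, k = 1000, l = 2.

((A × B) × C):  A×B costs i × j × k = 2,000,000 ops and builds a 1000×1000 intermediate;
                (AB)×C costs i × k × l = 2,000,000 ops — about 4,000,000 ops total.

(A × (B × C)):  B×C costs j × k × l = 4,000 ops and builds a 2×2 intermediate;
                A×(BC) costs i × j × l = 4,000 ops — about 8,000 ops total.

Both orders produce the same result, but the second is roughly 500× cheaper, and the gap grows with network size.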

Optimization Algorithms

omeinsum-rs uses omeco for contraction order optimization.

Greedy Method

The greedy algorithm iteratively contracts the pair with minimum cost:

#![allow(unused)]
fn main() {
use omeinsum::Einsum;

let mut ein = Einsum::new(ixs, iy, sizes);
ein.optimize_greedy();
}

  • Complexity: O(n²), where n is the number of tensors
  • Quality: Good for most practical cases
  • Speed: Fast

Tree Simulated Annealing (TreeSA)

TreeSA uses simulated annealing to search for better contraction trees:

#![allow(unused)]
fn main() {
let mut ein = Einsum::new(ixs, iy, sizes);
ein.optimize_treesa();
}

  • Complexity: O(iterations × n)
  • Quality: Often finds optimal or near-optimal solutions
  • Speed: Slower, but worthwhile for large networks

When to Optimize

Network Size           Recommendation
2-3 tensors            No optimization needed
4-10 tensors           Greedy is usually sufficient
10+ tensors            Consider TreeSA
Performance-critical   Always optimize, benchmark both

Inspecting Results

#![allow(unused)]
fn main() {
let mut ein = Einsum::new(ixs, iy, sizes);
ein.optimize_greedy();

// Check if optimized
if ein.is_optimized() {
    // Get the contraction tree
    if let Some(tree) = ein.contraction_tree() {
        println!("Optimized tree: {:?}", tree);
    }
}
}

Cost Model

The optimization minimizes total FLOP count, considering:

  • Tensor dimensions from size dictionary
  • Intermediate tensor sizes
  • Number of operations per contraction

No Optimization

For simple cases, you can skip optimization:

#![allow(unused)]
fn main() {
// Without optimization: contracts left-to-right
let ein = Einsum::new(ixs, iy, sizes);
let result = ein.execute::<Standard<f32>, f32, Cpu>(&tensors);
}

This uses simple pairwise contraction from left to right, which may be suboptimal for complex networks.

Showcase Examples

This chapter demonstrates practical applications of einsum with gradients across three different algebras: real numbers, complex numbers, and tropical numbers.

Each example shows a real-world use case where differentiation through tensor networks provides meaningful results.

Example            Algebra               Application                  Gradient Meaning
Bayesian Network   Standard<f64>         Probabilistic inference      Marginal probability
Tensor Train       Standard<Complex64>   Quantum simulation           Energy optimization direction
MPS Ground State   Standard<f64>         Quantum many-body            Variational optimization
Independent Set    MaxPlus<f64>          Combinatorial optimization   Optimal vertex selection

Bayesian Network Marginals

Key insight: Differentiation = Marginalization

Problem

Given a chain-structured Bayesian network with 3 binary variables X₀ - X₁ - X₂, compute:

  1. The partition function Z (sum over all configurations)
  2. The marginal probability P(X₁ = 1)

Mathematical Setup

Vertex potentials (unnormalized probabilities):

φ₀ = [1, 2]   → P(X₀=1) ∝ 2
φ₁ = [1, 3]   → P(X₁=1) ∝ 3
φ₂ = [1, 1]   → uniform

Edge potentials (encourage agreement):

ψ = [[2, 1],
     [1, 2]]

Partition function as einsum:

Z = Σ_{x₀,x₁,x₂} φ₀(x₀) × ψ₀₁(x₀,x₁) × φ₁(x₁) × ψ₁₂(x₁,x₂) × φ₂(x₂)
  = einsum("i,ij,j,jk,k->", φ₀, ψ₀₁, φ₁, ψ₁₂, φ₂)

The Gradient-Marginal Connection

The beautiful insight from probabilistic graphical models:

∂Z/∂θᵥ = Σ_{configurations where xᵥ=1} (product of all other factors)
       = Z × P(xᵥ = 1) / θᵥ

Therefore, writing θᵥ = φᵥ(1) for the vertex potential of xᵥ = 1:

P(xᵥ = 1) = (θᵥ/Z) × ∂Z/∂θᵥ = ∂log(Z)/∂log(θᵥ)

Differentiation through the tensor network gives marginal probabilities!

Manual Verification

All 8 configurations:

X₀  X₁  X₂   φ₀  ψ₀₁  φ₁  ψ₁₂  φ₂   Product
0   0   0    1   2    1   2    1    4
0   0   1    1   2    1   1    1    2
0   1   0    1   1    3   1    1    3
0   1   1    1   1    3   2    1    6
1   0   0    2   1    1   2    1    4
1   0   1    2   1    1   1    1    2
1   1   0    2   2    3   1    1    12
1   1   1    2   2    3   2    1    24

Results:

  • Z = 57
  • P(X₁=1) = (3+6+12+24)/57 = 45/57 ≈ 0.789

Code

#![allow(unused)]
fn main() {
use omeinsum::{einsum, einsum_with_grad, Standard, Tensor, Cpu};

// Vertex potentials
let phi0 = Tensor::<f64, Cpu>::from_data(&[1.0, 2.0], &[2]);
let phi1 = Tensor::<f64, Cpu>::from_data(&[1.0, 3.0], &[2]);
let phi2 = Tensor::<f64, Cpu>::from_data(&[1.0, 1.0], &[2]);

// Edge potentials (column-major)
let psi = Tensor::<f64, Cpu>::from_data(&[2.0, 1.0, 1.0, 2.0], &[2, 2]);

// Contract step by step
let t1 = einsum::<Standard<f64>, _, _>(&[&phi0, &psi], &[&[0], &[0, 1]], &[1]);
// ... continue contracting to compute Z

// Use einsum_with_grad to get gradients
let (result, grad_fn) = einsum_with_grad::<Standard<f64>, _, _>(
    &[&phi0, &psi], &[&[0], &[0, 1]], &[1]
);
let grads = grad_fn.backward::<Standard<f64>>(&grad_output, &[&phi0, &psi]);
}
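
As a sketch of the full contraction (assuming the einsum entry point accepts more than two inputs, as its signature suggests), the whole chain can be contracted to the scalar Z = 57 in one call:

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Standard, Tensor, Cpu};

let phi0 = Tensor::<f64, Cpu>::from_data(&[1.0, 2.0], &[2]);
let phi1 = Tensor::<f64, Cpu>::from_data(&[1.0, 3.0], &[2]);
let phi2 = Tensor::<f64, Cpu>::from_data(&[1.0, 1.0], &[2]);
let psi  = Tensor::<f64, Cpu>::from_data(&[2.0, 1.0, 1.0, 2.0], &[2, 2]);

// Z = Σ_{x₀,x₁,x₂} φ₀(x₀) ψ(x₀,x₁) φ₁(x₁) ψ(x₁,x₂) φ₂(x₂)   — "i,ij,j,jk,k->"
let z = einsum::<Standard<f64>, _, _>(
    &[&phi0, &psi, &phi1, &psi, &phi2],
    &[&[0], &[0, 1], &[1], &[1, 2], &[2]],
    &[],
);
assert_eq!(z.to_vec(), vec![57.0]);
}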

Tensor Train (Quantum States)

Key insight: Gradients enable variational optimization of quantum states

Problem

Represent a quantum state using a Matrix Product State (MPS) and compute contractions with complex numbers. Gradients ∂E/∂A give the optimization direction for finding ground states.

Mathematical Setup

An MPS represents a quantum state as:

|ψ⟩ = Σ_{s₁,s₂,...} A¹[s₁] · A²[s₂] · ... |s₁s₂...⟩

Where each Aⁱ[sᵢ] is a complex matrix.

Example: Two-Site Contraction

A1 = [[1+i,  0  ],      A2 = [[2,  i ],
      [0,   1-i]]            [-i, 3 ]]

Contraction: result[s1,s2] = Σ_b A1[s1,b] × A2[b,s2]

Manual calculation:

result[0,0] = (1+i)×2 + 0×(-i) = 2+2i
result[0,1] = (1+i)×i + 0×3 = -1+i
result[1,0] = 0×2 + (1-i)×(-i) = -1-i
result[1,1] = 0×i + (1-i)×3 = 3-3i

Norm: ⟨ψ|ψ⟩ = |2+2i|² + |-1+i|² + |-1-i|² + |3-3i|² = 8+2+2+18 = 30

Code

#![allow(unused)]
fn main() {
use num_complex::Complex64 as C64;
use omeinsum::{einsum, einsum_with_grad, Standard, Tensor, Cpu};

let a1 = Tensor::<C64, Cpu>::from_data(&[
    C64::new(1.0, 1.0),   // 1+i
    C64::new(0.0, 0.0),   // 0
    C64::new(0.0, 0.0),   // 0
    C64::new(1.0, -1.0),  // 1-i
], &[2, 2]);

let a2 = Tensor::<C64, Cpu>::from_data(&[
    C64::new(2.0, 0.0),   // 2
    C64::new(0.0, 1.0),   // i
    C64::new(0.0, -1.0),  // -i
    C64::new(3.0, 0.0),   // 3
], &[2, 2]);

// Contract: result[s1,s2] = Σ_b A1[s1,b] × A2[b,s2]
let result = einsum::<Standard<C64>, _, _>(
    &[&a1, &a2],
    &[&[0, 1], &[1, 2]],  // contract over index 1
    &[0, 2]
);

// Compute gradients for optimization
let (result, grad_fn) = einsum_with_grad::<Standard<C64>, _, _>(
    &[&a1, &a2], &[&[0, 1], &[1, 2]], &[0, 2]
);
}

Application: Variational Ground State

For a Heisenberg spin chain, the energy expectation ⟨ψ|H|ψ⟩ can be computed via tensor network contraction. The gradient ∂E/∂Aⁱ tells us how to update each tensor to lower the energy, converging to the ground state.


MPS Heisenberg Ground State

Key insight: Autodiff gradients enable variational optimization of quantum many-body states

Problem

Find the ground state of a 5-site Heisenberg spin chain using the Matrix Product State (MPS) variational ansatz. Compare with exact diagonalization to verify convergence.

The Heisenberg Hamiltonian

The antiferromagnetic Heisenberg model:

H = Σᵢ (Sˣᵢ Sˣᵢ₊₁ + Sʸᵢ Sʸᵢ₊₁ + Sᶻᵢ Sᶻᵢ₊₁)

For 5 sites with open boundary conditions, the exact ground state energy is E₀ ≈ -1.928.

MPS Ansatz

The MPS represents the quantum state as:

|ψ⟩ = Σ_{s₁...s₅} A¹[s₁] A²[s₂] A³[s₃] A⁴[s₄] A⁵[s₅] |s₁...s₅⟩

Where each Aⁱ is a tensor with:

  • Physical index sᵢ ∈ {0, 1} (spin up/down)
  • Bond indices with dimension χ (bond dimension)

With bond dimension χ=4, the MPS can accurately represent the ground state.

Variational Optimization

The energy functional:

E[A] = ⟨ψ|H|ψ⟩ / ⟨ψ|ψ⟩

Gradient descent: Update each tensor A using ∂E/∂A computed via einsum autodiff.

Results

Method                     Energy    Relative Error
Exact diagonalization      -1.9279   —
MPS (χ=4, 80 iterations)   -1.9205   0.38%

The MPS optimization converges to within 0.4% of the exact ground state energy, demonstrating that einsum autodiff correctly computes gradients for quantum many-body optimization.

Code Outline

#![allow(unused)]
fn main() {
use omeinsum::{einsum, einsum_with_grad, Standard, Tensor, Cpu};

// Initialize MPS tensors with bond dimension χ=4
let chi = 4;
let mut a1 = init_tensor(1, 2, chi);   // [1, 2, χ]
let mut a2 = init_tensor(chi, 2, chi); // [χ, 2, χ]
// ... a3, a4, a5

// Contract MPS to get state vector |ψ⟩
fn contract_mps(a1, a2, a3, a4, a5) -> Vec<f64> {
    // Contract bond indices: A1·A2·A3·A4·A5
    // Returns 2^5 = 32 dimensional state vector
}

// Compute energy E = ⟨ψ|H|ψ⟩ / ⟨ψ|ψ⟩
fn compute_energy(tensors, hamiltonian) -> f64;

// Gradient descent loop
for iter in 0..80 {
    // Compute gradients via finite differences or autodiff
    let grads = compute_gradients(&tensors, &hamiltonian);

    // Update: A ← A - lr * ∂E/∂A
    for (a, g) in tensors.iter_mut().zip(grads.iter()) {
        *a -= learning_rate * g;
    }

    // Normalize to prevent blow-up
    normalize_mps(&mut tensors);
}
}

Einsum Gradient Verification

The gradients for MPS tensor contractions can be verified against finite differences:

#![allow(unused)]
fn main() {
// Contract two adjacent MPS tensors
let (result, grad_fn) = einsum_with_grad::<Standard<f64>, _, _>(
    &[&a2, &a3],
    &[&[0, 1, 2], &[2, 3, 4]],  // b1,s2,b2 × b2,s3,b3
    &[0, 1, 3, 4],               // → b1,s2,s3,b3
);

let grads = grad_fn.backward::<Standard<f64>>(&grad_output, &[&a2, &a3]);

// Verify: autodiff gradient matches finite difference with error < 1e-10
}

Maximum Weight Independent Set

Key insight: Tropical gradients give optimal vertex selection

Problem

Find the maximum weight independent set on a pentagon graph. An independent set contains no adjacent vertices.

Graph

      0 (w=3)
     / \
    4   1
   (2) (5)
    |   |
    3---2
   (4) (1)

Edges: (0,1), (1,2), (2,3), (3,4), (4,0)

Tropical Tensor Network

Vertex tensor for vertex v with weight wᵥ:

W[s] = [0, wᵥ]  where s ∈ {0, 1}
  • s=0: vertex not selected, contributes 0 (tropical multiplicative identity)
  • s=1: vertex selected, contributes wᵥ

Edge tensor enforcing independence constraint:

B[sᵤ, sᵥ] = [[0,   0 ],
             [0,  -∞ ]]
  • B[1,1] = -∞ forbids selecting both endpoints (tropical zero)

Tropical contraction (MaxPlus: ⊕=max, ⊗=+):

result = max over all valid configurations of Σ(selected weights)

Gradient = Selection Mask

From tropical autodiff theory:

∂(max_weight)/∂(wᵥ) = 1 if vertex v is in optimal set
                    = 0 otherwise

The tropical gradient directly reveals the optimal selection!

Manual Verification

All independent sets of the pentagon:

Set     Weight
{0}     3
{1}     5
{2}     1
{3}     4
{4}     2
{0,2}   4
{0,3}   7
{1,3}   9  ← maximum
{1,4}   7
{2,4}   3

Optimal: {1, 3} with weight 9

Code

#![allow(unused)]
fn main() {
use omeinsum::{einsum, MaxPlus, Tensor, Cpu};

// Vertex tensors: W[s] = [0, weight]
let w0 = Tensor::<f64, Cpu>::from_data(&[0.0, 3.0], &[2]);
let w1 = Tensor::<f64, Cpu>::from_data(&[0.0, 5.0], &[2]);

// Edge constraint: B[1,1] = -∞ forbids both selected
let neg_inf = f64::NEG_INFINITY;
let edge = Tensor::<f64, Cpu>::from_data(&[0.0, 0.0, 0.0, neg_inf], &[2, 2]);

// Contract two vertices with edge constraint
// max_{s0,s1} (W0[s0] + B[s0,s1] + W1[s1])
let t0e = einsum::<MaxPlus<f64>, _, _>(&[&w0, &edge], &[&[0], &[0, 1]], &[1]);
let result = einsum::<MaxPlus<f64>, _, _>(&[&t0e, &w1], &[&[0], &[0]], &[]);

// Result: 5.0 (select vertex 1 only, since selecting both gives -∞)
}
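
A sketch of the full pentagon network (assuming, as above, that einsum accepts an arbitrary number of inputs): five vertex tensors labelled 0-4 plus one edge tensor per pentagon edge, contracted to a scalar that should equal the optimal weight 9:

#![allow(unused)]
fn main() {
use omeinsum::{einsum, MaxPlus, Tensor, Cpu};

let neg_inf = f64::NEG_INFINITY;
let edge = Tensor::<f64, Cpu>::from_data(&[0.0, 0.0, 0.0, neg_inf], &[2, 2]);

// Vertex weights 3, 5, 1, 4, 2 for vertices 0..4: W[s] = [0, w]
let ws: Vec<_> = [3.0, 5.0, 1.0, 4.0, 2.0]
    .iter()
    .map(|&w| Tensor::<f64, Cpu>::from_data(&[0.0, w], &[2]))
    .collect();

// max over all independent sets of the total selected weight
let result = einsum::<MaxPlus<f64>, _, _>(
    &[&ws[0], &ws[1], &ws[2], &ws[3], &ws[4],
      &edge, &edge, &edge, &edge, &edge],
    &[&[0], &[1], &[2], &[3], &[4],
      &[0, 1], &[1, 2], &[2, 3], &[3, 4], &[4, 0]],  // pentagon edges
    &[],
);
assert_eq!(result.to_vec(), vec![9.0]);  // {1, 3} with weight 5 + 4 = 9
}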

Summary

Algebra              Operation                      Gradient Meaning
Standard (real)      Σ (sum), × (multiply)          Sensitivity / marginal probability
Standard (complex)   Σ, × with complex arithmetic   Optimization direction
MaxPlus (tropical)   max, +                         Binary selection mask (argmax routing)
MinPlus (tropical)   min, +                         Binary selection mask (argmin routing)

These examples demonstrate that einsum with automatic differentiation is a powerful tool for:

  • Probabilistic inference (belief propagation)
  • Quantum simulation (variational ground state methods)
  • Combinatorial optimization (finding optimal configurations)

Architecture

This chapter describes the internal architecture of omeinsum-rs.

Module Structure

omeinsum/
├── algebra/          # Semiring and algebra traits
│   ├── mod.rs
│   ├── semiring.rs   # Semiring, Algebra, Scalar traits
│   ├── standard.rs   # Standard<T> implementation
│   └── tropical.rs   # MaxPlus, MinPlus, MaxMul
├── backend/          # Execution backends
│   ├── mod.rs
│   ├── traits.rs     # Backend, Storage traits
│   └── cpu.rs        # CPU backend implementation
├── tensor/           # Tensor type
│   ├── mod.rs        # Tensor<T, B> definition
│   ├── view.rs       # View operations (permute, reshape)
│   └── ops.rs        # GEMM, contract_binary
├── einsum/           # Einsum engine
│   ├── mod.rs        # einsum(), einsum_with_grad()
│   ├── engine.rs     # Einsum struct, optimization
│   └── builder.rs    # EinBuilder (planned)
└── lib.rs            # Public API exports

Core Abstractions

Scalar Trait

Base trait for numeric types:

#![allow(unused)]
fn main() {
pub trait Scalar: Copy + Clone + Default + Send + Sync + 'static {
    fn neg_infinity() -> Self;
    fn infinity() -> Self;
}
}

Semiring Trait

Defines the algebraic structure:

#![allow(unused)]
fn main() {
pub trait Semiring: Copy + Clone + Send + Sync + 'static {
    type Scalar: Scalar;
    fn zero() -> Self;
    fn one() -> Self;
    fn add(self, rhs: Self) -> Self;
    fn mul(self, rhs: Self) -> Self;
    fn from_scalar(s: Self::Scalar) -> Self;
    fn to_scalar(self) -> Self::Scalar;
}
}

Algebra Trait

Extends Semiring with backpropagation support:

#![allow(unused)]
fn main() {
pub trait Algebra: Semiring {
    type Index: Copy + Clone + Default + Send + Sync;

    fn add_with_argmax(self, self_idx: Self::Index, rhs: Self, rhs_idx: Self::Index)
        -> (Self, Self::Index);

    fn add_backward(self, rhs: Self, grad_out: Self::Scalar, winner_idx: Option<Self::Index>)
        -> (Self::Scalar, Self::Scalar);

    fn mul_backward(self, rhs: Self, grad_out: Self::Scalar)
        -> (Self::Scalar, Self::Scalar);

    fn needs_argmax() -> bool;
}
}

Backend Trait

Abstracts execution hardware:

#![allow(unused)]
fn main() {
pub trait Backend: Clone + Send + Sync + 'static {
    type Storage<T: Scalar>: Storage<T>;

    fn gemm<A: Algebra>(&self, a: &Self::Storage<A::Scalar>, m: usize, k: usize,
                        b: &Self::Storage<A::Scalar>, n: usize) -> Self::Storage<A::Scalar>;

    fn gemm_with_argmax<A: Algebra<Index = u32>>(&self, ...)
        -> (Self::Storage<A::Scalar>, Self::Storage<u32>);
}
}

Tensor Implementation

Stride-Based Storage

Tensors use stride-based views for efficient transformations:

#![allow(unused)]
fn main() {
pub struct Tensor<T: Scalar, B: Backend> {
    storage: Arc<B::Storage<T>>,  // Shared storage
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
    backend: B,
}
}

Zero-Copy Operations

  • Permute: Reorders strides, no data copy
  • Reshape: Updates shape if contiguous, otherwise copies
  • Contiguous: Copies if non-contiguous

Contraction Strategy

Binary contraction uses reshape-to-GEMM:

  1. Classify indices: batch, left-only, right-only, contracted
  2. Permute tensors to [batch, left/right, contracted]
  3. Reshape to 2D matrices
  4. Execute GEMM
  5. Reshape and permute to output

This leverages optimized GEMM implementations for all algebras.
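
For the A[i,j,k] × B[j,k,l] → C[i,l] contraction from the Tensor API chapter (shapes [2,3,4] and [3,4,5]), this strategy would play out roughly as follows:

  1. Classify: no batch indices, left-only = {i}, right-only = {l}, contracted = {j, k}
  2. Permute A to [i, j, k] and B to [j, k, l] (already in that order here)
  3. Reshape A to a 2×12 matrix and B to a 12×5 matrix (12 = 3 × 4)
  4. Run GEMM: (2×12) · (12×5) → 2×5
  5. Reshape/permute the 2×5 result to C[i,l]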

Optimization Integration

Uses omeco for contraction order:

#![allow(unused)]
fn main() {
use omeco::{EinCode, GreedyMethod, TreeSA, optimize_code};

let code = EinCode::new(ixs, iy);
let tree = optimize_code(&code, &size_dict, &GreedyMethod::new(0.0, 0.0));
}

The resulting NestedEinsum tree is executed recursively.

Backpropagation

omeinsum-rs supports gradient computation for both standard and tropical algebras.

Standard Backpropagation

For standard arithmetic, gradients follow the chain rule:

C = A × B
∂L/∂A = ∂L/∂C × B^T
∂L/∂B = A^T × ∂L/∂C

Tropical Backpropagation

Tropical algebras use argmax tracking for gradient routing.

The Challenge

In tropical algebra:

C[i,j] = max_k (A[i,k] + B[k,j])

The gradient only flows through the winning path (the k that achieved the max).

Argmax Tracking

During forward pass, we track which index “won”:

#![allow(unused)]
fn main() {
let (c, argmax) = a.gemm_with_argmax::<MaxPlus<f32>>(&b);
// argmax[i,j] = the k that maximized A[i,k] + B[k,j]
}

Backward Pass

Gradients are routed using the argmax:

#![allow(unused)]
fn main() {
// For each output element [i,j]:
// k* = argmax[i,j]
// ∂L/∂A[i,k*] += ∂L/∂C[i,j]
// ∂L/∂B[k*,j] += ∂L/∂C[i,j]
}
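
A tiny worked example of this routing under MaxPlus:

A = [[1, 5]],   B = [[2],
                     [0]]

C[0,0] = max(1 + 2, 5 + 0) = 5,   argmax[0,0] = 1

∂L/∂A[0,1] += ∂L/∂C[0,0]      ∂L/∂B[1,0] += ∂L/∂C[0,0]
∂L/∂A[0,0] = 0                ∂L/∂B[0,0] = 0

The gradient flows only along the winning path (k = 1); the losing entries receive zero.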

API Usage

With Argmax Tracking

#![allow(unused)]
fn main() {
use omeinsum::algebra::MaxPlus;

// GEMM with argmax
let (c, argmax) = a.gemm_with_argmax::<MaxPlus<f32>>(&b);

// Contract with argmax
let (c, argmax) = a.contract_binary_with_argmax::<MaxPlus<f32>>(
    &b, &[0, 1], &[1, 2], &[0, 2]
);
}

Einsum with Gradients

#![allow(unused)]
fn main() {
use omeinsum::einsum_with_grad;

let (result, gradient) = einsum_with_grad::<MaxPlus<f32>, _, _>(
    &[&a, &b],
    &[&[0, 1], &[1, 2]],
    &[0, 2],
);

// Use gradient.backward() for gradient computation
// (Implementation in progress)
}

Implementation Status

Feature                Status
Forward pass           Complete
Argmax tracking        Complete
GEMM backward          Implemented
Full einsum backward   In progress

Tie-Breaking

When multiple indices achieve the same maximum, the implementation uses a deterministic tie-breaking rule (first winning index). This ensures reproducible gradients.

References

  • Zhang et al., “Tropical Geometry of Deep Neural Networks” (2018)
  • tropical-gemm gradient implementation

Performance Guide

Tips for getting the best performance from omeinsum-rs.

Contraction Order

The most important optimization is contraction order:

#![allow(unused)]
fn main() {
// Always optimize for networks with 3+ tensors
let mut ein = Einsum::new(ixs, iy, sizes);
ein.optimize_greedy();  // or optimize_treesa() for large networks
}

Bad contraction order can be exponentially slower.

Memory Layout

Keep Tensors Contiguous

Non-contiguous tensors require copies before GEMM:

#![allow(unused)]
fn main() {
// After permute, tensor may be non-contiguous
let t_permuted = t.permute(&[1, 0]);

// Make contiguous if you'll use it multiple times
let t_contig = t_permuted.contiguous();
}

Avoid Unnecessary Copies

#![allow(unused)]
fn main() {
// Good: zero-copy view
let view = t.permute(&[1, 0]);

// Avoid: unnecessary explicit copy
let bad = t.permute(&[1, 0]).contiguous();  // Only if needed
}

Parallelization

Enable the parallel feature (default):

[dependencies]
omeinsum = "0.1"  # parallel enabled by default

Disable for single-threaded workloads:

[dependencies]
omeinsum = { version = "0.1", default-features = false }

Data Types

Use f32 When Possible

f32 is typically faster than f64 due to:

  • Lower memory bandwidth requirements
  • Better SIMD utilization

#![allow(unused)]
fn main() {
// Prefer f32
let t = Tensor::<f32, Cpu>::from_data(&data, &shape);

// Use f64 only when precision is critical
let t = Tensor::<f64, Cpu>::from_data(&data, &shape);
}

Benchmarking

Use release mode for benchmarks:

cargo run --release --example basic_einsum

Profile with:

cargo build --release
perf record ./target/release/examples/basic_einsum
perf report

Common Pitfalls

1. Forgetting to Optimize

#![allow(unused)]
fn main() {
// Bad: no optimization
let ein = Einsum::new(ixs, iy, sizes);
let result = ein.execute::<A, T, B>(&tensors);

// Good: with optimization
let mut ein = Einsum::new(ixs, iy, sizes);
ein.optimize_greedy();
let result = ein.execute::<A, T, B>(&tensors);
}

2. Redundant Contiguous Calls

#![allow(unused)]
fn main() {
// Bad: unnecessary copy
let c = a.contiguous().gemm::<Standard<f32>>(&b.contiguous());

// Good: gemm handles this internally
let c = a.gemm::<Standard<f32>>(&b);
}

3. Debug Mode

Debug builds are ~10-50x slower:

# Bad: debug mode
cargo run --example benchmark

# Good: release mode
cargo run --release --example benchmark

Future Optimizations

Planned performance improvements:

  • CUDA backend for GPU acceleration
  • Optimized tropical-gemm kernel integration
  • Batched GEMM support
  • Cache-aware blocking

API Reference

Full API documentation is available at:

https://tensorbfs.github.io/omeinsum-rs/api/omeinsum/

Quick Reference

Main Types

Type           Description
Tensor<T, B>   N-dimensional tensor with backend B
Einsum<L>      Einsum specification and executor
Cpu            CPU backend

Algebra Types

Type          Addition   Multiplication
Standard<T>   +          ×
MaxPlus<T>    max        +
MinPlus<T>    min        +
MaxMul<T>     max        ×

Key Functions

#![allow(unused)]
fn main() {
// Quick einsum
fn einsum<A, T, B>(tensors: &[&Tensor<T, B>], ixs: &[&[usize]], iy: &[usize]) -> Tensor<T, B>

// Einsum with gradient support
fn einsum_with_grad<A, T, B>(...) -> (Tensor<T, B>, EinsumGradient<T, B>)
}

Tensor Methods

#![allow(unused)]
fn main() {
impl<T, B> Tensor<T, B> {
    // Creation
    fn from_data(data: &[T], shape: &[usize]) -> Self
    fn zeros(shape: &[usize]) -> Self
    fn ones(shape: &[usize]) -> Self

    // Properties
    fn shape(&self) -> &[usize]
    fn strides(&self) -> &[usize]
    fn ndim(&self) -> usize
    fn numel(&self) -> usize
    fn is_contiguous(&self) -> bool

    // Transformations
    fn permute(&self, order: &[usize]) -> Self
    fn reshape(&self, new_shape: &[usize]) -> Self
    fn contiguous(&self) -> Self

    // Operations
    fn gemm<A: Algebra>(&self, other: &Self) -> Self
    fn contract_binary<A>(&self, other: &Self, ia: &[usize], ib: &[usize], iy: &[usize]) -> Self

    // Data
    fn to_vec(&self) -> Vec<T>
}
}

Einsum Methods

#![allow(unused)]
fn main() {
impl<L> Einsum<L> {
    fn new(ixs: Vec<Vec<L>>, iy: Vec<L>, size_dict: HashMap<L, usize>) -> Self
    fn code(&self) -> EinCode<L>
    fn optimize_greedy(&mut self) -> &mut Self
    fn optimize_treesa(&mut self) -> &mut Self
    fn is_optimized(&self) -> bool
    fn contraction_tree(&self) -> Option<&NestedEinsum<L>>
}

impl Einsum<usize> {
    fn execute<A, T, B>(&self, tensors: &[&Tensor<T, B>]) -> Tensor<T, B>
    fn execute_with_argmax<A, T, B>(&self, tensors: &[&Tensor<T, B>])
        -> (Tensor<T, B>, Vec<Tensor<u32, B>>)
}
}

Building Documentation Locally

make docs-build   # Rust API docs
make docs-serve   # Serve at localhost:8000

Changelog

All notable changes to this project will be documented in this file.

The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.

[Unreleased]

Added

  • Initial release
  • Tensor<T, B> type with stride-based views
  • Algebra traits: Semiring, Algebra, Scalar
  • Algebra implementations: Standard, MaxPlus, MinPlus, MaxMul
  • Backend trait with Cpu implementation
  • Einsum struct with omeco optimization integration
  • einsum() and einsum_with_grad() functions
  • Greedy and TreeSA contraction order optimization
  • Argmax tracking for tropical backpropagation
  • mdBook documentation
  • CI/CD with GitHub Actions

Changed

  • N/A

Deprecated

  • N/A

Removed

  • N/A

Fixed

  • N/A

Security

  • N/A

[0.1.0] - TBD

Initial public release.