Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Einsum API

The einsum API provides a high-level interface for tensor network contractions.

Quick Einsum

The simplest way to perform einsum:

#![allow(unused)]
fn main() {
use omeinsum::{einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;

let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Matrix multiplication: ij,jk->ik
let c = einsum::<Standard<f32>, _, _>(
    &[&a, &b],           // Input tensors
    &[&[0, 1], &[1, 2]], // Index labels: A[0,1], B[1,2]
    &[0, 2],             // Output labels: C[0,2]
);
}

Index Labels

Indices are represented as usize values. Matching indices indicate contraction:

OperationInputsLabelsOutput
Matrix multiplyA[m,k], B[k,n][[0,1], [1,2]][0,2] → C[m,n]
Batch matmulA[b,m,k], B[b,k,n][[0,1,2], [0,2,3]][0,1,3] → C[b,m,n]
Outer productA[m], B[n][[0], [1]][0,1] → C[m,n]
TraceA[n,n][[0,0]][] → scalar
SumA[m,n][[0,1]][] → scalar

Einsum Struct

For more control, use the Einsum struct directly:

#![allow(unused)]
fn main() {
use omeinsum::{Einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;
use std::collections::HashMap;

// Define size dictionary
let sizes: HashMap<usize, usize> = [
    (0, 10),  // i: 10
    (1, 20),  // j: 20
    (2, 30),  // k: 30
].into();

// Create einsum specification
let mut ein = Einsum::new(
    vec![vec![0, 1], vec![1, 2]],  // A[i,j], B[j,k]
    vec![0, 2],                     // -> C[i,k]
    sizes,
);

// Check the einsum code
let code = ein.code();
println!("Einsum: {:?}", code);
}

Contraction Optimization

Greedy Algorithm

Fast O(n²) algorithm, good for most cases:

#![allow(unused)]
fn main() {
let mut ein = Einsum::new(/* ... */);
ein.optimize_greedy();

assert!(ein.is_optimized());
}

Simulated Annealing

Slower but finds better orderings for complex networks:

#![allow(unused)]
fn main() {
let mut ein = Einsum::new(/* ... */);
ein.optimize_treesa();
}

Inspect Contraction Tree

#![allow(unused)]
fn main() {
if let Some(tree) = ein.contraction_tree() {
    println!("Contraction tree: {:?}", tree);
}
}

Chain Contraction Example

Contracting a chain of matrices:

#![allow(unused)]
fn main() {
use omeinsum::{Einsum, Tensor, Cpu};
use omeinsum::algebra::Standard;
use std::collections::HashMap;

// A[i,j] × B[j,k] × C[k,l] → D[i,l]
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let c = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

let sizes: HashMap<usize, usize> = [
    (0, 2), (1, 2), (2, 2), (3, 2)
].into();

let mut ein = Einsum::new(
    vec![vec![0, 1], vec![1, 2], vec![2, 3]],
    vec![0, 3],
    sizes,
);

ein.optimize_greedy();
let d = ein.execute::<Standard<f32>, f32, Cpu>(&[&a, &b, &c]);

assert_eq!(d.shape(), &[2, 2]);
}

Einsum with Gradients

For backpropagation support:

#![allow(unused)]
fn main() {
use omeinsum::einsum_with_grad;
use omeinsum::algebra::MaxPlus;

let (result, gradient) = einsum_with_grad::<MaxPlus<f32>, _, _>(
    &[&a, &b],
    &[&[0, 1], &[1, 2]],
    &[0, 2],
);

// gradient can be used for backpropagation
// (full backward pass implementation in progress)
}