Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Tensor API

The Tensor<T, B> type provides a flexible, stride-based tensor implementation.

Creating Tensors

From Data

#![allow(unused)]
fn main() {
use omeinsum::{Tensor, Cpu};

// Create from slice with shape
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Row-major layout:
// [[1, 2, 3],
//  [4, 5, 6]]
}

Zeros and Ones

#![allow(unused)]
fn main() {
let zeros = Tensor::<f32, Cpu>::zeros(&[3, 4]);
let ones = Tensor::<f32, Cpu>::ones(&[3, 4]);
}

Properties

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

t.shape()    // &[2, 3]
t.strides()  // &[3, 1] for row-major
t.ndim()     // 2
t.numel()    // 6
}

Views and Transformations

Permute (Transpose)

Zero-copy axis reordering:

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Transpose: swap axes 0 and 1
let t_t = t.permute(&[1, 0]);
assert_eq!(t_t.shape(), &[3, 2]);

// 3D example: (batch, height, width) -> (batch, width, height)
let img = Tensor::<f32, Cpu>::zeros(&[10, 28, 28]);
let img_t = img.permute(&[0, 2, 1]);
assert_eq!(img_t.shape(), &[10, 28, 28]);
}

Reshape

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

// Flatten
let flat = t.reshape(&[6]);

// Reshape to different dimensions
let reshaped = t.reshape(&[3, 2]);
}

Contiguous

Convert non-contiguous views to contiguous storage:

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let t_t = t.permute(&[1, 0]);  // Non-contiguous after transpose

let t_contig = t_t.contiguous();  // Copy to contiguous memory
assert!(t_contig.is_contiguous());
}

Matrix Operations

GEMM (General Matrix Multiplication)

#![allow(unused)]
fn main() {
use omeinsum::algebra::{Standard, MaxPlus};

let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Standard matrix multiplication
let c = a.gemm::<Standard<f32>>(&b);

// Tropical matrix multiplication
let c_trop = a.gemm::<MaxPlus<f32>>(&b);
}

Binary Contraction

General tensor contraction:

#![allow(unused)]
fn main() {
// A[i,j,k] × B[j,k,l] -> C[i,l]
let a = Tensor::<f32, Cpu>::zeros(&[2, 3, 4]);
let b = Tensor::<f32, Cpu>::zeros(&[3, 4, 5]);

let c = a.contract_binary::<Standard<f32>>(
    &b,
    &[0, 1, 2],  // A's indices: i, j, k
    &[1, 2, 3],  // B's indices: j, k, l
    &[0, 3],     // Output: i, l
);

assert_eq!(c.shape(), &[2, 5]);
}

Data Access

#![allow(unused)]
fn main() {
let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

// Convert to Vec
let data = t.to_vec();  // [1.0, 2.0, 3.0, 4.0]
}