Einsum

Struct Einsum 

Source
pub struct Einsum<L: Label = usize> {
    pub ixs: Vec<Vec<L>>,
    pub iy: Vec<L>,
    pub size_dict: HashMap<L, usize>,
    /* private fields */
}
Expand description

Einsum specification and execution engine.

Supports contraction order optimization via omeco.

§Example

use omeinsum::{Einsum, Tensor, Cpu};
use omeinsum::algebra::MaxPlus;
use std::collections::HashMap;

// A[i,j] × B[j,k] → C[i,k]
// (executed below with the MaxPlus algebra, i.e. C[i,k] = max_j (A[i,j] + B[j,k]))
let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = Tensor::<f32, Cpu>::from_data(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);

let sizes: HashMap<usize, usize> = [(0, 2), (1, 2), (2, 2)].into();
let mut ein = Einsum::new(
    vec![vec![0, 1], vec![1, 2]],
    vec![0, 2],
    sizes,
);

ein.optimize_greedy();
let result = ein.execute::<MaxPlus<f32>, f32, Cpu>(&[&a, &b]);
assert_eq!(result.shape(), &[2, 2]);

Fields§

§ixs: Vec<Vec<L>>

Input index labels for each tensor

§iy: Vec<L>

Output index labels

§size_dict: HashMap<L, usize>

Dimension sizes for each index

Implementations§

Source§

impl<L: Label> Einsum<L>

Source

pub fn new(ixs: Vec<Vec<L>>, iy: Vec<L>, size_dict: HashMap<L, usize>) -> Self

Create a new einsum specification.

§Arguments
  • ixs - Index labels for each input tensor
  • iy - Output index labels
  • size_dict - Mapping from index labels to dimension sizes
Source

pub fn code(&self) -> EinCode<L>

Get the einsum code specification.

Source

pub fn optimize_greedy(&mut self) -> &mut Self

Optimize contraction order using a greedy algorithm.

Fast O(n²) algorithm, good for most cases.

Source

pub fn optimize_treesa(&mut self) -> &mut Self

Optimize contraction order using simulated annealing.

Slower but finds better orderings for complex networks.

Source

pub fn is_optimized(&self) -> bool

Check if optimization has been performed.

Source

pub fn contraction_tree(&self) -> Option<&NestedEinsum<L>>

Get the optimized contraction tree.

Source§

impl Einsum<usize>

Source

pub fn execute<A, T, B>(&self, tensors: &[&Tensor<T, B>]) -> Tensor<T, B>
where A: Algebra<Scalar = T, Index = u32>, T: Scalar + BackendScalar<B>, B: Backend + Default,

Execute the einsum contraction.

§Type Parameters
  • A - The algebra to use (e.g., Standard<f32>, MaxPlus<f32>)
  • T - The scalar type
  • B - The backend type
Source

pub fn execute_with_argmax<A, T, B>( &self, tensors: &[&Tensor<T, B>], ) -> (Tensor<T, B>, Vec<Tensor<u32, B>>)
where A: Algebra<Scalar = T, Index = u32>, T: Scalar + BackendScalar<B>, B: Backend + Default,

Execute with argmax tracking for backpropagation.

Returns (result, argmax_cache) where argmax_cache contains argmax tensors for each binary contraction in the execution tree.

Auto Trait Implementations§

§

impl<L> Freeze for Einsum<L>

§

impl<L> RefUnwindSafe for Einsum<L>
where L: RefUnwindSafe,

§

impl<L> Send for Einsum<L>

§

impl<L> Sync for Einsum<L>

§

impl<L> Unpin for Einsum<L>
where L: Unpin,

§

impl<L> UnwindSafe for Einsum<L>
where L: UnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
§

impl<T> ByRef<T> for T

§

fn by_ref(&self) -> &T

§

impl<T> DistributionExt for T
where T: ?Sized,

§

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
§

impl<T> Pointable for T

§

const ALIGN: usize

The alignment of the pointer.
§

type Init = T

The type for initializers.
§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

§

fn vzip(self) -> V