Trait Algebra

pub trait Algebra: Semiring {
    type Index: Copy + Clone + Send + Sync + Default + Debug + 'static;

    // Required methods
    fn add_with_argmax(
        self,
        self_idx: Self::Index,
        rhs: Self,
        rhs_idx: Self::Index,
    ) -> (Self, Self::Index);
    fn add_backward(
        self,
        rhs: Self,
        grad_out: Self::Scalar,
        winner_idx: Option<Self::Index>,
    ) -> (Self::Scalar, Self::Scalar);
    fn mul_backward(
        self,
        rhs: Self,
        grad_out: Self::Scalar,
    ) -> (Self::Scalar, Self::Scalar);
    fn is_better(&self, other: &Self) -> bool;

    // Provided method
    fn needs_argmax() -> bool { ... }
}

Extended semiring operations for automatic differentiation.

This trait adds argmax tracking needed for tropical backpropagation.
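To show why argmax tracking matters, here is a minimal sketch of a max-plus matrix-vector product and its backward pass. It uses free functions over f64 rather than this trait, and the function names are hypothetical, not part of the crate. The forward pass records the winning column of each row. The backward pass then routes each row's gradient only to that column, which is exactly the information a plain semiring result would lose.

```rust
/// Forward pass of a max-plus matrix-vector product:
/// y[i] = max_j (a[i][j] + x[j]), recording the winning j for each row.
pub fn maxplus_matvec(a: &[Vec<f64>], x: &[f64]) -> (Vec<f64>, Vec<usize>) {
    let mut y = Vec::with_capacity(a.len());
    let mut argmax = Vec::with_capacity(a.len());
    for row in a {
        // f64::NEG_INFINITY is the additive identity ("zero") of max-plus.
        let mut best = f64::NEG_INFINITY;
        let mut best_j = 0;
        for (j, (&aij, &xj)) in row.iter().zip(x).enumerate() {
            // Tropical multiplication is ordinary addition.
            let v = aij + xj;
            // Tropical addition is max; a strict > keeps the first winner on ties.
            if v > best {
                best = v;
                best_j = j;
            }
        }
        y.push(best);
        argmax.push(best_j);
    }
    (y, argmax)
}

/// Backward pass: each row's gradient flows only through its winning term.
pub fn maxplus_matvec_backward(
    (m, n): (usize, usize),
    argmax: &[usize],
    grad_y: &[f64],
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let mut grad_a = vec![vec![0.0; n]; m];
    let mut grad_x = vec![0.0; n];
    for (i, (&j, &g)) in argmax.iter().zip(grad_y).enumerate() {
        grad_a[i][j] += g;
        grad_x[j] += g;
    }
    (grad_a, grad_x)
}
```

Without the recorded argmax, the backward pass would have to recompute every row's maximum. Rows are assumed to be non-empty.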

Required Associated Types

type Index: Copy + Clone + Send + Sync + Default + Debug + 'static

Index type used for argmax tracking.
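The bounds on Index can be checked at compile time with a small generic helper. This is a sketch: the helper name is hypothetical, and the page does not say which concrete index type each implementor uses, so usize (a position) and () (no tracking at all) below are only assumptions. Note that a type such as Vec<usize> would be rejected, because it is not Copy.

```rust
use std::fmt::Debug;

/// Mirrors the bounds on `Algebra::Index`; compiles only for types that satisfy them.
pub fn check_index_bounds<I: Copy + Clone + Send + Sync + Default + Debug + 'static>() -> I {
    // `Default` supplies a placeholder index before any comparison has been made.
    I::default()
}
```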

Required Methods

fn add_with_argmax(self, self_idx: Self::Index, rhs: Self, rhs_idx: Self::Index) -> (Self, Self::Index)

Addition with argmax tracking.

Returns (result, winner_index), where winner_index is whichever of self_idx and rhs_idx belongs to the operand that “won” the addition. This matters for tropical algebras, where addition is max (MaxPlus) or min (MinPlus).
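A minimal sketch of add_with_argmax for a max-plus scalar, using a hypothetical MaxPlus newtype over f64 with usize indices (neither is taken from the crate). Folding it over a row yields both the maximum and its position. Ties keep the left operand; this is an assumption, since the docs do not specify tie-breaking.

```rust
/// Hypothetical max-plus scalar; not the crate's own type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MaxPlus(pub f64);

impl MaxPlus {
    /// Tropical "addition" is max; also return the index of the winning operand.
    /// Ties keep `self_idx` (assumed tie-breaking rule).
    pub fn add_with_argmax(self, self_idx: usize, rhs: Self, rhs_idx: usize) -> (Self, usize) {
        if rhs.0 > self.0 {
            (rhs, rhs_idx)
        } else {
            (self, self_idx)
        }
    }
}

/// Fold a row of values with `add_with_argmax`, returning the maximum and its position.
pub fn max_with_argmax(values: &[f64]) -> Option<(MaxPlus, usize)> {
    let mut iter = values.iter().copied().enumerate();
    let (i0, v0) = iter.next()?;
    let mut acc = (MaxPlus(v0), i0);
    for (i, v) in iter {
        acc = acc.0.add_with_argmax(acc.1, MaxPlus(v), i);
    }
    Some(acc)
}
```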

fn add_backward(self, rhs: Self, grad_out: Self::Scalar, winner_idx: Option<Self::Index>) -> (Self::Scalar, Self::Scalar)

Backward pass for addition.

Given the output gradient grad_out, computes the gradients for both inputs. For standard arithmetic, both inputs receive grad_out. For tropical algebras, only the winning operand (identified by winner_idx) receives grad_out; the other receives zero.
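The two gradient rules can be sketched as free functions over f64. These are hypothetical helpers, not the crate's API. In particular, the Option<bool> argument lhs_won stands in for the recorded winner_idx; when it is None, the winner is recomputed from the operand values, with ties going to the left operand by assumption.

```rust
/// Standard (+, ×): d(a + b)/da = d(a + b)/db = 1, so both inputs receive grad_out.
pub fn standard_add_backward(grad_out: f64) -> (f64, f64) {
    (grad_out, grad_out)
}

/// Max-plus: c = max(a, b), so only the winning operand receives grad_out.
/// `lhs_won` is a hypothetical encoding of the recorded winner.
pub fn maxplus_add_backward(a: f64, b: f64, grad_out: f64, lhs_won: Option<bool>) -> (f64, f64) {
    // If no winner was recorded, recompute it; ties go to the left operand.
    let lhs_won = lhs_won.unwrap_or(a >= b);
    if lhs_won {
        (grad_out, 0.0)
    } else {
        (0.0, grad_out)
    }
}
```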

fn mul_backward(self, rhs: Self, grad_out: Self::Scalar) -> (Self::Scalar, Self::Scalar)

Backward pass for multiplication.

Given the output gradient grad_out, computes the gradients for both inputs:

Standard: grad_a = grad_out × b, grad_b = grad_out × a
Tropical (where multiplication is ordinary addition): grad_a = grad_out, grad_b = grad_out
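Both rules can be written directly, together with a central finite-difference helper to sanity-check them. The function names are hypothetical, and the code works on plain f64 rather than the trait's Scalar type.

```rust
/// Standard product rule for c = a * b.
pub fn standard_mul_backward(a: f64, b: f64, grad_out: f64) -> (f64, f64) {
    (grad_out * b, grad_out * a)
}

/// Tropical multiplication is c = a + b, so the gradient passes through unchanged.
pub fn tropical_mul_backward(_a: f64, _b: f64, grad_out: f64) -> (f64, f64) {
    (grad_out, grad_out)
}

/// Central finite-difference estimate of (grad_out * df/da, grad_out * df/db).
pub fn fd_grad(f: impl Fn(f64, f64) -> f64, a: f64, b: f64, grad_out: f64) -> (f64, f64) {
    let h = 1e-6;
    let da = (f(a + h, b) - f(a - h, b)) / (2.0 * h);
    let db = (f(a, b + h) - f(a, b - h)) / (2.0 * h);
    (grad_out * da, grad_out * db)
}
```

With a = 3.0, b = 4.0 and grad_out = 2.0, standard_mul_backward gives (8.0, 6.0), and fd_grad(|a, b| a * b, 3.0, 4.0, 2.0) should agree up to rounding error.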

fn is_better(&self, other: &Self) -> bool

Check if self is “better” than other for tropical selection.

For MaxPlus: returns true if self > other.
For MinPlus: returns true if self < other.
For Standard: not meaningful (always returns false).
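A sketch of is_better for the three cases, using hypothetical f64 newtypes and a local IsBetter trait in place of the crate's types. The best_index helper shows how such a predicate can drive selection. With Standard, is_better never returns true, so the first element is always kept.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MaxPlus(pub f64);
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MinPlus(pub f64);
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Standard(pub f64);

/// Hypothetical stand-in for the `is_better` method of the trait.
pub trait IsBetter {
    fn is_better(&self, other: &Self) -> bool;
}

impl IsBetter for MaxPlus {
    fn is_better(&self, other: &Self) -> bool {
        self.0 > other.0
    }
}

impl IsBetter for MinPlus {
    fn is_better(&self, other: &Self) -> bool {
        self.0 < other.0
    }
}

impl IsBetter for Standard {
    // No notion of a winner in ordinary arithmetic.
    fn is_better(&self, _other: &Self) -> bool {
        false
    }
}

/// Index of the "best" element under `is_better`; the earliest one wins ties.
pub fn best_index<T: IsBetter>(xs: &[T]) -> Option<usize> {
    let mut best: Option<usize> = None;
    for (i, x) in xs.iter().enumerate() {
        let replace = match best {
            None => true,
            Some(b) => x.is_better(&xs[b]),
        };
        if replace {
            best = Some(i);
        }
    }
    best
}
```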

Provided Methods

fn needs_argmax() -> bool

Whether this algebra requires argmax tracking for backprop.
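The page does not show the default body of needs_argmax. A plausible pattern, assumed here rather than taken from the crate, is a default of false that tropical algebras override with true, so callers can skip allocating argmax storage for ordinary arithmetic. The trait and type names below are hypothetical.

```rust
/// Hypothetical trait mirroring the provided `needs_argmax` method.
pub trait NeedsArgmax {
    /// Assumed default: ordinary arithmetic needs no argmax bookkeeping.
    fn needs_argmax() -> bool {
        false
    }
}

pub struct StandardF64;
pub struct MaxPlusF64;

impl NeedsArgmax for StandardF64 {}

impl NeedsArgmax for MaxPlusF64 {
    fn needs_argmax() -> bool {
        true
    }
}

/// Allocate an argmax buffer for an m x n output only when the algebra needs one.
pub fn argmax_buffer<A: NeedsArgmax>(m: usize, n: usize) -> Option<Vec<usize>> {
    if A::needs_argmax() {
        Some(vec![0; m * n])
    } else {
        None
    }
}
```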

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors

impl<T: Scalar + Zero + One + PartialEq + Add<Output = T> + Mul<Output = T>> Algebra for Standard<T>