einsum_with_grad

Function einsum_with_grad 

pub fn einsum_with_grad<A, T, B>(
    tensors: &[&Tensor<T, B>],
    ixs: &[&[usize]],
    iy: &[usize],
) -> (Tensor<T, B>, EinsumGradient<T, B>)
where
    A: Algebra<Scalar = T, Index = u32>,
    T: Scalar + BackendScalar<B>,
    B: Backend + Default,

Einsum with gradient computation.

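Returns a tuple (result, gradient_fn). Here result is the contracted output tensor and gradient_fn is an EinsumGradient<T, B> that can be called with the output gradient to compute the gradients of the input tensors.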
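
The (result, gradient_fn) pattern can be illustrated with a minimal, self-contained sketch for the single contraction "ij,jk->ik" (ordinary matrix multiplication). It does not use this crate's API: Mat and matmul_with_grad are hypothetical stand-ins for Tensor and einsum_with_grad, and a plain closure stands in for EinsumGradient. The point it shows is that, under standard arithmetic, each input gradient is itself a contraction of the output gradient with the other input.

```rust
/// Row-major dense matrix; a hypothetical stand-in for the crate's Tensor type.
#[derive(Clone, Debug, PartialEq)]
struct Mat {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
}

impl Mat {
    fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols);
        Mat { rows, cols, data }
    }

    fn at(&self, i: usize, j: usize) -> f64 {
        self.data[i * self.cols + j]
    }

    fn transpose(&self) -> Mat {
        let mut data = Vec::with_capacity(self.data.len());
        for j in 0..self.cols {
            for i in 0..self.rows {
                data.push(self.at(i, j));
            }
        }
        Mat::new(self.cols, self.rows, data)
    }

    /// Ordinary matrix product, i.e. the contraction "ij,jk->ik".
    fn matmul(&self, rhs: &Mat) -> Mat {
        assert_eq!(self.cols, rhs.rows);
        let mut data = vec![0.0; self.rows * rhs.cols];
        for i in 0..self.rows {
            for k in 0..rhs.cols {
                for j in 0..self.cols {
                    data[i * rhs.cols + k] += self.at(i, j) * rhs.at(j, k);
                }
            }
        }
        Mat::new(self.rows, rhs.cols, data)
    }
}

/// Hypothetical helper: returns the forward product together with a closure
/// that maps the output gradient to the gradients of both inputs.
fn matmul_with_grad(a: &Mat, b: &Mat) -> (Mat, impl Fn(&Mat) -> (Mat, Mat)) {
    let result = a.matmul(b);
    // The closure owns copies of the inputs so it can outlive the borrows.
    let (a, b) = (a.clone(), b.clone());
    let grad_fn = move |grad_out: &Mat| {
        let grad_a = grad_out.matmul(&b.transpose()); // "ik,jk->ij"
        let grad_b = a.transpose().matmul(grad_out); // "ij,ik->jk"
        (grad_a, grad_b)
    };
    (result, grad_fn)
}
```

Calling matmul_with_grad(&a, &b) yields the product and a closure; passing an output gradient of the same shape as the product to that closure returns one gradient per input, each with the shape of the corresponding input.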

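For the Standard algebra, gradients are themselves computed via einsum, so no argmax tracking is needed. For tropical algebras, the argmax of each reduction is tracked during the forward pass and used in the backward pass to route each output gradient to the input entries that produced the selected value.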
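
Argmax-based gradient routing can be sketched for a max-plus matrix product, c[i][k] = max over j of (a[i][j] + b[j][k]). This is an illustration only, not this crate's implementation: the function names, the flat row-major f64 slices, the usize indices, and the tie-breaking rule (first maximum wins) are all assumptions made for the example.

```rust
/// Max-plus product of row-major matrices: a is m x n, b is n x p.
/// Returns the m x p result and, for each output entry, the index j that
/// attained the maximum (ties go to the smallest j; an assumption here).
fn maxplus_matmul_with_argmax(
    a: &[f64],
    b: &[f64],
    m: usize,
    n: usize,
    p: usize,
) -> (Vec<f64>, Vec<usize>) {
    let mut c = vec![f64::NEG_INFINITY; m * p];
    let mut argmax = vec![0usize; m * p];
    for i in 0..m {
        for k in 0..p {
            for j in 0..n {
                let v = a[i * n + j] + b[j * p + k];
                if v > c[i * p + k] {
                    c[i * p + k] = v;
                    argmax[i * p + k] = j;
                }
            }
        }
    }
    (c, argmax)
}

/// Backward pass: each output gradient flows only to the a and b entries
/// that won the max for that output entry.
fn maxplus_backward(
    grad_c: &[f64],
    argmax: &[usize],
    m: usize,
    n: usize,
    p: usize,
) -> (Vec<f64>, Vec<f64>) {
    let mut grad_a = vec![0.0; m * n];
    let mut grad_b = vec![0.0; n * p];
    for i in 0..m {
        for k in 0..p {
            let j = argmax[i * p + k];
            grad_a[i * n + j] += grad_c[i * p + k];
            grad_b[j * p + k] += grad_c[i * p + k];
        }
    }
    (grad_a, grad_b)
}
```

Entries that never win a maximum receive a zero gradient, which is why the winning indices have to be recorded during the forward pass rather than recomputed from the result alone.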