Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

tropical-gemm

High-performance tropical matrix multiplication in Rust with SIMD and CUDA backends.

What is Tropical Algebra?

Tropical algebra (also called max-plus or min-plus algebra) replaces standard arithmetic operations with alternative ones:

StandardTropical (MaxPlus)Tropical (MinPlus)
a + bmax(a, b)min(a, b)
a × ba + ba + b
0-∞+∞
100

Applications

Tropical matrix multiplication appears in many algorithms:

  • Shortest/Longest Path: Computing all-pairs shortest paths via matrix powers
  • Viterbi Algorithm: Finding most likely sequences in HMMs
  • Dynamic Programming: Optimizing over sequence alignments
  • Neural Networks: Tropical neural networks with piecewise-linear activations
  • Combinatorics: Counting optimal solutions

Features

  • Multiple Semirings: MaxPlus, MinPlus, MaxMul
  • SIMD Acceleration: AVX-512, AVX2, SSE4.1, NEON auto-detection
  • CUDA Backend: GPU-accelerated kernels via runtime compilation
  • Argmax Tracking: For backpropagation in differentiable programs
  • Batched Operations: Efficient batch processing
  • Python Bindings: PyTorch integration via PyO3

Feature Matrix

Supported Operations by Semiring and Scalar Type

SemiringScalarCPU GEMMCPU BatchedCPU ArgmaxCPU BackwardGPU GEMMGPU BatchedGPU ArgmaxGPU Backward
MaxPlusf32SIMD
MaxPlusf64SIMD
MaxPlusi32N/A
MaxPlusi64N/A
MinPlusf32SIMD
MinPlusf64
MinPlusi32N/A
MinPlusi64N/A
MaxMulf32SIMD
MaxMulf64
MaxMuli32N/A
MaxMuli64N/A

Legend:

  • SIMD: Optimized with AVX2/AVX-512/NEON vectorization
  • : Supported with portable implementation
  • N/A: Not applicable (integers don’t have gradients)

Quick Example

#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};

// Create 2x3 and 3x2 matrices
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);

// C[i,j] = max_k(A[i,k] + B[k,j])
let c = &a * &b;
assert_eq!(c.get_value(0, 0), 8.0); // max(1+1, 2+3, 3+5) = 8
}