tropical-gemm
High-performance tropical matrix multiplication in Rust with SIMD and CUDA backends.
What is Tropical Algebra?
Tropical algebra (also called max-plus or min-plus algebra) replaces standard arithmetic operations with alternative ones:
| Standard | Tropical (MaxPlus) | Tropical (MinPlus) |
|---|---|---|
| a + b | max(a, b) | min(a, b) |
| a × b | a + b | a + b |
| 0 | -∞ | +∞ |
| 1 | 0 | 0 |
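For example, in MaxPlus the tropical sum of 2 and 3 is max(2, 3) = 3 and the tropical product is 2 + 3 = 5. A minimal sketch using the MaxPlus type described in the Semiring Types chapter:
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, TropicalSemiring};
let a = MaxPlus::from_scalar(2.0f32);
let b = MaxPlus::from_scalar(3.0f32);
// Tropical "addition" is max
assert_eq!(MaxPlus::tropical_add(a, b).value(), 3.0);
// Tropical "multiplication" is ordinary addition
assert_eq!(MaxPlus::tropical_mul(a, b).value(), 5.0);
}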
Applications
Tropical matrix multiplication appears in many algorithms:
- Shortest/Longest Path: Computing all-pairs shortest paths via matrix powers (see the sketch after this list)
- Viterbi Algorithm: Finding most likely sequences in HMMs
- Dynamic Programming: Optimizing over sequence alignments
- Neural Networks: Tropical neural networks with piecewise-linear activations
- Combinatorics: Counting optimal solutions
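For the shortest-path case, repeatedly squaring a MinPlus adjacency matrix doubles the allowed path length each time. A minimal sketch, assuming the Mat, MinPlus, and matmul APIs shown in the Quick Example and Matrix API chapters:
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MinPlus};
let inf = f32::INFINITY;
// adj[i][j] = edge weight from i to j; inf = no edge, 0 on the diagonal
let adj = Mat::<MinPlus<f32>>::from_row_major(&[
    0.0, 1.0, inf, 4.0,
    inf, 0.0, 2.0, inf,
    inf, inf, 0.0, 1.0,
    inf, inf, inf, 0.0,
], 4, 4);
// (A^2)[i,j] = min_k(A[i,k] + A[k,j]) = shortest path using at most 2 edges
let two_hop = adj.matmul(&adj);
let four_hop = two_hop.matmul(&two_hop); // at most 4 edges
assert_eq!(four_hop.get_value(0, 3), 4.0); // 0 -> 1 -> 2 -> 3 costs 1 + 2 + 1 = 4
}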
Features
- Multiple Semirings: MaxPlus, MinPlus, MaxMul
- SIMD Acceleration: AVX-512, AVX2, SSE4.1, NEON auto-detection
- CUDA Backend: GPU-accelerated kernels via runtime compilation
- Argmax Tracking: For backpropagation in differentiable programs
- Batched Operations: Efficient batch processing
- Python Bindings: PyTorch integration via PyO3
Feature Matrix
Supported Operations by Semiring and Scalar Type
| Semiring | Scalar | CPU GEMM | CPU Batched | CPU Argmax | CPU Backward | GPU GEMM | GPU Batched | GPU Argmax | GPU Backward |
|---|---|---|---|---|---|---|---|---|---|
| MaxPlus | f32 | SIMD | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| MaxPlus | f64 | SIMD | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| MaxPlus | i32 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | N/A |
| MaxPlus | i64 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | N/A |
| MinPlus | f32 | SIMD | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| MinPlus | f64 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| MinPlus | i32 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | N/A |
| MinPlus | i64 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | N/A |
| MaxMul | f32 | SIMD | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| MaxMul | f64 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| MaxMul | i32 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | N/A |
| MaxMul | i64 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | N/A |
Legend:
- SIMD: Optimized with AVX2/AVX-512/NEON vectorization
- ✅: Supported with portable implementation
- N/A: Not applicable (integers don’t have gradients)
Quick Example
#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};
// Create 2x3 and 3x2 matrices
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);
// C[i,j] = max_k(A[i,k] + B[k,j])
let c = &a * &b;
assert_eq!(c.get_value(0, 0), 8.0); // max(1+1, 2+3, 3+5) = 8
}
Getting Started
This section covers how to install and start using tropical-gemm.
Overview
tropical-gemm is organized as a Cargo workspace with three crates:
| Crate | Description |
|---|---|
tropical-gemm | Core library with CPU implementation |
tropical-gemm-cuda | Optional GPU acceleration via CUDA |
tropical-gemm-python | Python bindings for NumPy/PyTorch |
System Requirements
CPU
- Rust 1.70 or later
- x86-64 (AVX2/AVX-512) or ARM64 (NEON) for best performance
GPU (optional)
- NVIDIA GPU with compute capability 3.5+
- CUDA Toolkit 11.0 or later
- nvcc in PATH
Python (optional)
- Python 3.8+
- NumPy 1.20+
- PyTorch 2.0+ (for autograd integration)
Next Steps
- Installation - Detailed installation instructions
- Quick Start - Your first tropical matrix multiplication
Installation
Rust Crate
Add to your Cargo.toml:
[dependencies]
tropical-gemm = "0.1"
# For GPU acceleration (optional):
tropical-gemm-cuda = "0.1"
Python Package
From PyPI (Recommended)
# Basic installation
pip install tropical-gemm
# With PyTorch support for automatic differentiation
pip install tropical-gemm[torch]
# For development
pip install tropical-gemm[dev]
Optional Dependencies
The Python package has optional extras:
| Extra | Command | Description |
|---|---|---|
torch | pip install tropical-gemm[torch] | PyTorch integration with autograd support |
dev | pip install tropical-gemm[dev] | Development dependencies (pytest, torch) |
From Source
# Clone the repository
git clone https://github.com/TensorBFS/tropical-gemm
cd tropical-gemm/crates/tropical-gemm-python
# Create virtual environment
python -m venv .venv
source .venv/bin/activate # Linux/Mac
# .venv\Scripts\activate # Windows
# Install maturin and build
pip install maturin
maturin develop --release
# With CUDA support
maturin develop --release --features cuda
Verify Installation
import tropical_gemm
import numpy as np
a = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
b = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
c = tropical_gemm.maxplus_matmul(a, b)
print(c) # [[5. 6.] [7. 8.]]
# Check GPU availability
print(f"CUDA available: {tropical_gemm.cuda_available()}")
Verify PyTorch Integration
import torch
from tropical_gemm.pytorch import tropical_maxplus_matmul, GPU_AVAILABLE
print(f"GPU available: {GPU_AVAILABLE}")
a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
c = tropical_maxplus_matmul(a, b)
c.sum().backward()
print(f"grad_a: {a.grad.shape}") # (3, 4)
print(f"grad_b: {b.grad.shape}") # (4, 5)
CUDA Setup
For GPU acceleration, ensure CUDA is properly installed:
# Check CUDA installation
nvcc --version
# If not found, install CUDA toolkit
# Ubuntu:
sudo apt install nvidia-cuda-toolkit
# Or download from NVIDIA:
# https://developer.nvidia.com/cuda-downloads
The CUDA kernels are compiled at runtime using NVRTC, so you don’t need to compile the library with a specific CUDA version.
Building Python Package with CUDA
cd crates/tropical-gemm-python
# Build with CUDA feature
maturin develop --features cuda
# Or for release
maturin build --release --features cuda
Building from Source
# Clone
git clone https://github.com/TensorBFS/tropical-gemm
cd tropical-gemm
# Build all crates
cargo build --release --workspace
# Run tests
cargo test --workspace
# Build documentation
cargo doc --workspace --no-deps --open
Using the Makefile
A Makefile is provided for common tasks:
make help # Show all targets
make setup # Setup development environment
make build # Build in release mode
make test # Run all tests
make docs # Build documentation
make bench # Run benchmarks
Quick Start
This guide shows the basics of tropical matrix multiplication.
Basic Matrix Multiplication
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MatRef, MaxPlus};
// Create matrices from raw data (row-major order)
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 matrix
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2 matrix
// Create matrix views
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);
// Multiply using operator
let c = &a * &b;
// Or using method
let c = a.matmul(&b);
// Access result
println!("C[0,0] = {}", c.get_value(0, 0)); // 8.0 = max(1+1, 2+3, 3+5)
}
Understanding the Result
For MaxPlus semiring, the multiplication computes:
C[i,j] = max_k(A[i,k] + B[k,j])
For the example above:
- C[0,0] = max(1+1, 2+3, 3+5) = max(2, 5, 8) = 8
- C[0,1] = max(1+2, 2+4, 3+6) = max(3, 6, 9) = 9
- C[1,0] = max(4+1, 5+3, 6+5) = max(5, 8, 11) = 11
- C[1,1] = max(4+2, 5+4, 6+6) = max(6, 9, 12) = 12
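These values can be checked directly, reusing the example above:
#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = MatRef::<MaxPlus<f32>>::from_slice(&data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&data, 3, 2);
let c = &a * &b;
assert_eq!(c.get_value(0, 0), 8.0);
assert_eq!(c.get_value(0, 1), 9.0);
assert_eq!(c.get_value(1, 0), 11.0);
assert_eq!(c.get_value(1, 1), 12.0);
}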
Using Different Semirings
#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus, MinPlus, MaxMul};
let a_data = [1.0f32, 2.0, 3.0, 4.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0];
// MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 2);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 2, 2);
let c_maxplus = &a * &b;
// MinPlus: C[i,j] = min_k(A[i,k] + B[k,j])
let a = MatRef::<MinPlus<f32>>::from_slice(&a_data, 2, 2);
let b = MatRef::<MinPlus<f32>>::from_slice(&b_data, 2, 2);
let c_minplus = &a * &b;
// MaxMul: C[i,j] = max_k(A[i,k] * B[k,j])
let a = MatRef::<MaxMul<f32>>::from_slice(&a_data, 2, 2);
let b = MatRef::<MaxMul<f32>>::from_slice(&b_data, 2, 2);
let c_maxmul = &a * &b;
}
Factory Methods
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
// Create a zero matrix (all -∞ for MaxPlus)
let zeros = Mat::<MaxPlus<f32>>::zeros(3, 3);
// Create an identity matrix (0 on diagonal, -∞ elsewhere for MaxPlus)
let identity = Mat::<MaxPlus<f32>>::identity(3);
// Create from function
let mat = Mat::<MaxPlus<f32>>::from_fn(3, 3, |i, j| {
MaxPlus::from_scalar((i + j) as f32)
});
}
Next Steps
- Semiring Types - Learn about different tropical semirings
- Matrix API - Full matrix API reference
- GPU Acceleration - Using CUDA for large matrices
Semiring Types
A semiring is an algebraic structure with two operations that generalize addition and multiplication. Tropical semirings replace the standard operations with max/min and addition; matrix multiplication carries over unchanged because ⊗ distributes over ⊕ (in MaxPlus, c + max(a, b) = max(c + a, c + b)).
Available Semirings
| Type | ⊕ (add) | ⊗ (mul) | Zero | One | Use Case |
|---|---|---|---|---|---|
MaxPlus<T> | max | + | -∞ | 0 | Longest path, Viterbi |
MinPlus<T> | min | + | +∞ | 0 | Shortest path, Dijkstra |
MaxMul<T> | max | × | 0 | 1 | Maximum probability |
AndOr | OR | AND | false | true | Graph reachability |
MaxPlus Semiring
The MaxPlus (or max-plus) semiring uses:
- Addition: a ⊕ b = max(a, b)
- Multiplication: a ⊗ b = a + b
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, TropicalSemiring};
let a = MaxPlus::from_scalar(3.0f32);
let b = MaxPlus::from_scalar(5.0f32);
// Tropical add: max(3, 5) = 5
let sum = MaxPlus::tropical_add(a, b);
assert_eq!(sum.value(), 5.0);
// Tropical mul: 3 + 5 = 8
let product = MaxPlus::tropical_mul(a, b);
assert_eq!(product.value(), 8.0);
}
Applications:
- Longest path in graphs (Bellman-Ford with negated weights)
- Viterbi algorithm for HMM decoding
- Log-probability computations
MinPlus Semiring
The MinPlus (or min-plus) semiring uses:
- Addition: a ⊕ b = min(a, b)
- Multiplication: a ⊗ b = a + b
#![allow(unused)]
fn main() {
use tropical_gemm::{MinPlus, TropicalSemiring};
let a = MinPlus::from_scalar(3.0f32);
let b = MinPlus::from_scalar(5.0f32);
// Tropical add: min(3, 5) = 3
let sum = MinPlus::tropical_add(a, b);
assert_eq!(sum.value(), 3.0);
// Tropical mul: 3 + 5 = 8
let product = MinPlus::tropical_mul(a, b);
assert_eq!(product.value(), 8.0);
}
Applications:
- Shortest path (Floyd-Warshall, Dijkstra)
- Edit distance computation
- Resource allocation
MaxMul Semiring
The MaxMul semiring uses:
- Addition: a ⊕ b = max(a, b)
- Multiplication: a ⊗ b = a × b
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxMul, TropicalSemiring};
let a = MaxMul::from_scalar(3.0f32);
let b = MaxMul::from_scalar(5.0f32);
// Tropical add: max(3, 5) = 5
let sum = MaxMul::tropical_add(a, b);
assert_eq!(sum.value(), 5.0);
// Tropical mul: 3 × 5 = 15
let product = MaxMul::tropical_mul(a, b);
assert_eq!(product.value(), 15.0);
}
Applications:
- Maximum probability paths (non-log space)
- Fuzzy set operations
- Reliability analysis
Supported Scalar Types
Each semiring supports multiple scalar types:
| Scalar | MaxPlus | MinPlus | MaxMul | Notes |
|---|---|---|---|---|
f32 | ✅ SIMD | ✅ SIMD | ✅ SIMD | Best performance |
f64 | ✅ SIMD | ✅ | ✅ | Higher precision |
i32 | ✅ | ✅ | ✅ | Integer operations |
i64 | ✅ | ✅ | ✅ | Large integers |
Type Aliases
For convenience, shorter type aliases are provided:
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, MinPlus, MaxMul, AndOr};
// These are equivalent:
type A = tropical_gemm::TropicalMaxPlus<f32>;
type B = MaxPlus<f32>; // Preferred
}
Matrix API
tropical-gemm provides a matrix API inspired by faer.
Matrix Types
| Type | Description |
|---|---|
Mat<S> | Owned matrix with heap-allocated storage |
MatRef<'a, S> | Immutable view into matrix data |
MatMut<'a, S> | Mutable view into matrix data |
MatWithArgmax<S> | Matrix with argmax indices for backpropagation |
Creating Matrices
From Raw Data
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MatRef, MaxPlus};
// Create a view from a slice (no allocation)
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = MatRef::<MaxPlus<f32>>::from_slice(&data, 2, 3);
// Create an owned matrix from a slice (allocates)
let b = Mat::<MaxPlus<f32>>::from_row_major(&data, 2, 3);
}
Factory Methods
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
// Zero matrix (all elements = tropical zero)
let zeros = Mat::<MaxPlus<f32>>::zeros(3, 4);
// Identity matrix (diagonal = tropical one, off-diagonal = tropical zero)
let identity = Mat::<MaxPlus<f32>>::identity(3);
// From function
let mat = Mat::<MaxPlus<f32>>::from_fn(3, 3, |i, j| {
MaxPlus::from_scalar((i * 3 + j) as f32)
});
}
Matrix Multiplication
Operator Syntax
#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);
// Multiply using operators
let c = &a * &b; // Returns Mat<S>
}
Method Syntax
#![allow(unused)]
fn main() {
let c = a.matmul(&b);
}
Accessing Elements
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
let data = [1.0f32, 2.0, 3.0, 4.0];
let mat = Mat::<MaxPlus<f32>>::from_row_major(&data, 2, 2);
// Get the underlying scalar value
let value = mat.get_value(0, 1); // 2.0
// Get the tropical element
let elem = mat[(0, 1)]; // MaxPlus(2.0)
// Dimensions
let (rows, cols) = (mat.nrows(), mat.ncols());
}
Argmax Tracking
For backpropagation, track which k produced each output:
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
let a = Mat::<MaxPlus<f64>>::from_row_major(
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3
);
let b = Mat::<MaxPlus<f64>>::from_row_major(
&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2
);
let result = a.matmul_argmax(&b);
// Get value and argmax
let value = result.get_value(0, 0); // 8.0
let k_idx = result.get_argmax(0, 0); // 2
// Compute gradients
let grad_c = vec![1.0f64; 4]; // upstream gradient (m × n)
let grad_a = result.backward_a(&grad_c);
let grad_b = result.backward_b(&grad_c);
}
Batched Operations
Process multiple matrices in parallel:
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
let a_batch = vec![
Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2),
Mat::<MaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2),
];
let b_batch = vec![
Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2),
Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2),
];
// Batched matmul (parallel by default)
let c_batch = Mat::matmul_batched(&a_batch, &b_batch);
// With argmax
let results = Mat::matmul_batched_with_argmax(&a_batch, &b_batch);
}
GPU Acceleration
tropical-gemm-cuda provides NVIDIA GPU acceleration via CUDA.
Requirements
- NVIDIA GPU (compute capability 3.5+)
- CUDA Toolkit 11.0 or later
- nvcc in PATH
Basic Usage
use tropical_gemm::{MatRef, MaxPlus};
use tropical_gemm_cuda::{CudaContext, GpuMat};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create CUDA context (compiles kernels on first use)
let ctx = CudaContext::new()?;
// Prepare CPU data
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);
// Upload to GPU
let a_gpu = GpuMat::from_matref(&ctx, &a)?;
let b_gpu = GpuMat::from_matref(&ctx, &b)?;
// Compute on GPU
let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;
// Download result
let c = c_gpu.to_mat(&ctx)?;
println!("C[0,0] = {}", c.get_value(0, 0));
Ok(())
}
Context Reuse
The CudaContext compiles CUDA kernels on first use. Always reuse contexts
to avoid repeated compilation:
#![allow(unused)]
fn main() {
// GOOD: Reuse context
let ctx = CudaContext::new()?;
for _ in 0..100 {
let c = a_gpu.matmul(&ctx, &b_gpu)?;
}
// BAD: Creates new context each iteration
for _ in 0..100 {
let ctx = CudaContext::new()?; // Slow!
let c = a_gpu.matmul(&ctx, &b_gpu)?;
}
}
GPU Argmax
For backpropagation with GPU computation:
#![allow(unused)]
fn main() {
let ctx = CudaContext::new()?;
let a_gpu = GpuMat::from_matref(&ctx, &a)?;
let b_gpu = GpuMat::from_matref(&ctx, &b)?;
// Forward pass with argmax tracking
let result = a_gpu.matmul_argmax(&ctx, &b_gpu)?;
// Download values and argmax
let result_cpu = result.to_mat_with_argmax(&ctx)?;
let value = result_cpu.get_value(0, 0);
let k_idx = result_cpu.get_argmax(0, 0);
// Backward pass on GPU
let grad_c_gpu = GpuMat::from_matref(&ctx, &grad_c)?;
let grad_a_gpu = result.backward_a(&ctx, &grad_c_gpu)?;
let grad_b_gpu = result.backward_b(&ctx, &grad_c_gpu)?;
}
Batched GPU Operations
Process multiple matrices efficiently:
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
use tropical_gemm_cuda::{CudaContext, GpuMat};
let ctx = CudaContext::new()?;
// Upload batch to GPU
let a_batch: Vec<Mat<MaxPlus<f32>>> = /* ... */;
let b_batch: Vec<Mat<MaxPlus<f32>>> = /* ... */;
let a_gpu_batch = GpuMat::from_mats(&ctx, &a_batch)?;
let b_gpu_batch = GpuMat::from_mats(&ctx, &b_batch)?;
// Batched multiply
let c_gpu_batch = GpuMat::matmul_batched(&ctx, &a_gpu_batch, &b_gpu_batch)?;
// Download results
let c_batch = GpuMat::to_mats(&ctx, &c_gpu_batch)?;
}
One-Shot API
For simple cases without context reuse:
#![allow(unused)]
fn main() {
use tropical_gemm::TropicalMaxPlus;
use tropical_gemm_cuda::tropical_matmul_gpu;
let a = vec![1.0f32; 64 * 64];
let b = vec![1.0f32; 64 * 64];
// One-shot GPU multiplication (creates temporary context)
let c = tropical_matmul_gpu::<TropicalMaxPlus<f32>>(&a, 64, 64, &b, 64)?;
}
Performance Comparison
| Size | CPU SIMD | GPU | Speedup |
|---|---|---|---|
| 256 | 4.1 ms | 0.032 ms | 128x |
| 512 | 32.8 ms | 0.086 ms | 381x |
| 1024 | 262.3 ms | 0.358 ms | 733x |
| 2048 | 2091.6 ms | 2.510 ms | 833x |
GPU becomes advantageous for matrices larger than ~256×256.
PyTorch Integration
tropical-gemm provides Python bindings with full PyTorch autograd support.
Installation
# From PyPI
pip install tropical-gemm
# With PyTorch support (recommended)
pip install tropical-gemm[torch]
# For GPU support (requires CUDA toolkit)
pip install maturin
git clone https://github.com/TensorBFS/tropical-gemm
cd tropical-gemm/crates/tropical-gemm-python
maturin develop --features cuda
Basic NumPy Usage
import numpy as np
import tropical_gemm
a = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
b = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
# MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])
c = tropical_gemm.maxplus_matmul(a, b)
# MinPlus: C[i,j] = min_k(A[i,k] + B[k,j])
c = tropical_gemm.minplus_matmul(a, b)
# MaxMul: C[i,j] = max_k(A[i,k] * B[k,j])
c = tropical_gemm.maxmul_matmul(a, b)
# With argmax tracking for backpropagation
c, argmax = tropical_gemm.maxplus_matmul_with_argmax(a, b)
PyTorch Module (Recommended)
The tropical_gemm.pytorch module provides pre-built autograd functions:
import torch
from tropical_gemm.pytorch import (
# CPU operations
tropical_maxplus_matmul,
tropical_minplus_matmul,
tropical_maxmul_matmul,
# GPU operations (requires CUDA)
tropical_maxplus_matmul_gpu,
tropical_minplus_matmul_gpu,
tropical_maxmul_matmul_gpu,
# Check GPU availability
GPU_AVAILABLE,
)
# Create tensors with gradient tracking
a = torch.randn(100, 50, requires_grad=True)
b = torch.randn(50, 80, requires_grad=True)
# Forward pass - compute tropical matmul
c = tropical_maxplus_matmul(a, b)
# Backward pass - gradients computed automatically
loss = c.sum()
loss.backward()
print(f"grad_a shape: {a.grad.shape}") # (100, 50)
print(f"grad_b shape: {b.grad.shape}") # (50, 80)
GPU Acceleration
For larger matrices, use GPU-accelerated functions:
if GPU_AVAILABLE:
a = torch.randn(1024, 512, requires_grad=True)
b = torch.randn(512, 1024, requires_grad=True)
c = tropical_maxplus_matmul_gpu(a, b)
loss = c.sum()
loss.backward() # Gradients still work!
Available Functions
| CPU Function | GPU Function | Operation |
|---|---|---|
tropical_maxplus_matmul | tropical_maxplus_matmul_gpu | max_k(A[i,k] + B[k,j]) |
tropical_minplus_matmul | tropical_minplus_matmul_gpu | min_k(A[i,k] + B[k,j]) |
tropical_maxmul_matmul | tropical_maxmul_matmul_gpu | max_k(A[i,k] * B[k,j]) |
Training Example
import torch
from tropical_gemm.pytorch import tropical_maxplus_matmul
# Learnable parameters
a = torch.randn(64, 128, requires_grad=True)
b = torch.randn(128, 32, requires_grad=True)
target = torch.randn(64, 32)
optimizer = torch.optim.Adam([a, b], lr=0.1)
for step in range(100):
optimizer.zero_grad()
# Forward - tropical matmul
c = tropical_maxplus_matmul(a, b)
# Loss
loss = ((c - target) ** 2).mean()
# Backward - gradients flow through tropical operation
loss.backward()
# Update parameters
optimizer.step()
if step % 20 == 0:
print(f"Step {step}: loss = {loss.item():.4f}")
Gradient Semantics
The gradient computation depends on the semiring type:
MaxPlus / MinPlus (Additive Rule)
For C[i,j] = max_k(A[i,k] + B[k,j]), let k* = argmax_k(A[i,k] + B[k,j]):
- grad_A[i,k*] += grad_C[i,j]
- grad_B[k*,j] += grad_C[i,j]
The gradient is sparse - only the winning index contributes.
MaxMul (Multiplicative Rule)
For C[i,j] = max_k(A[i,k] * B[k,j]), let k* = argmax_k(A[i,k] * B[k,j]):
- grad_A[i,k*] += grad_C[i,j] * B[k*,j]
- grad_B[k*,j] += grad_C[i,j] * A[i,k*]
| Semiring | Forward | Backward Rule |
|---|---|---|
| MaxPlus | max_k(A + B) | ∂C/∂A = 1 at argmax |
| MinPlus | min_k(A + B) | ∂C/∂A = 1 at argmin |
| MaxMul | max_k(A × B) | ∂C/∂A = B at argmax |
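To make the additive rule concrete, here is a minimal CPU sketch (illustrative, not the library's internal code) that scatters an upstream gradient into grad_A using the saved argmax indices; grad_B is handled the same way with rows and columns swapped, and MaxMul additionally multiplies by the saved operand value:
// grad_c: m x n upstream gradient (row-major); argmax: winning k for each (i, j).
fn backward_a_additive(grad_c: &[f32], argmax: &[usize], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut grad_a = vec![0.0f32; m * k];
    for i in 0..m {
        for j in 0..n {
            let k_star = argmax[i * n + j];
            // Only the winning index k* receives the gradient; every other entry stays 0.
            grad_a[i * k + k_star] += grad_c[i * n + j];
        }
    }
    grad_a
}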
Graph Algorithms
Shortest Path (MinPlus)
import torch
from tropical_gemm.pytorch import tropical_minplus_matmul
# Adjacency matrix (inf = no edge)
inf = float("inf")
adj = torch.tensor([
[0.0, 1.0, inf, 4.0],
[inf, 0.0, 2.0, inf],
[inf, inf, 0.0, 1.0],
[inf, inf, inf, 0.0],
])
# 2-hop shortest paths
two_hop = tropical_minplus_matmul(adj, adj)
# 3-hop shortest paths
three_hop = tropical_minplus_matmul(two_hop, adj)
Longest Path (MaxPlus)
import torch
from tropical_gemm.pytorch import tropical_maxplus_matmul
# Edge weights for critical path analysis
neg_inf = float("-inf")
adj = torch.tensor([
[0.0, 3.0, 2.0, neg_inf],
[neg_inf, 0.0, neg_inf, 4.0],
[neg_inf, neg_inf, 0.0, 5.0],
[neg_inf, neg_inf, neg_inf, 0.0],
])
# 2-hop longest paths
two_hop = tropical_maxplus_matmul(adj, adj)
Custom Autograd Function (Advanced)
If you need custom behavior, you can still define your own autograd function:
import torch
import numpy as np
import tropical_gemm
class TropicalMaxPlusMatmul(torch.autograd.Function):
"""Custom differentiable MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])"""
@staticmethod
def forward(ctx, a, b):
m, k = a.shape
n = b.shape[1]
# Convert to NumPy
a_np = a.detach().cpu().numpy().astype(np.float32)
b_np = b.detach().cpu().numpy().astype(np.float32)
if not a_np.flags["C_CONTIGUOUS"]:
a_np = np.ascontiguousarray(a_np)
if not b_np.flags["C_CONTIGUOUS"]:
b_np = np.ascontiguousarray(b_np)
# Forward pass with argmax tracking
c_flat, argmax_flat = tropical_gemm.maxplus_matmul_with_argmax(a_np, b_np)
c_np = np.asarray(c_flat).reshape(m, n) # zero-copy when possible
argmax_np = np.asarray(argmax_flat).reshape(m, n)
# Save for backward
ctx.save_for_backward(torch.from_numpy(argmax_np))
ctx.k, ctx.m, ctx.n = k, m, n
return torch.from_numpy(c_np).to(a.device)
@staticmethod
def backward(ctx, grad_c):
argmax, = ctx.saved_tensors
k, m, n = ctx.k, ctx.m, ctx.n
grad_c_np = grad_c.cpu().numpy().astype(np.float32)
argmax_np = argmax.numpy().astype(np.int32)
if not grad_c_np.flags["C_CONTIGUOUS"]:
grad_c_np = np.ascontiguousarray(grad_c_np)
# Backward pass
grad_a_flat = tropical_gemm.backward_a(grad_c_np, argmax_np, k)
grad_b_flat = tropical_gemm.backward_b(grad_c_np, argmax_np, k)
grad_a = torch.from_numpy(np.asarray(grad_a_flat).reshape(m, k)).to(grad_c.device)
grad_b = torch.from_numpy(np.asarray(grad_b_flat).reshape(k, n)).to(grad_c.device)
return grad_a, grad_b
Complete Example
See crates/tropical-gemm-python/examples/pytorch_tropical.py for:
- Gradient verification tests
- Shortest/longest path examples
- Optimization demos
- GPU benchmarks
Architecture
This section describes the internal architecture of tropical-gemm.
Overview
tropical-gemm achieves high performance through:
- BLIS-style blocking for cache efficiency
- SIMD microkernels for vectorization
- Runtime dispatch for optimal kernel selection
- CUDA kernels for GPU acceleration
Crate Structure
tropical-gemm/
├── src/
│ ├── lib.rs # Public API
│ ├── api.rs # Function-based API
│ ├── types/ # Semiring definitions
│ │ ├── traits.rs # TropicalSemiring trait
│ │ ├── max_plus.rs
│ │ ├── min_plus.rs
│ │ └── max_mul.rs
│ ├── core/ # BLIS algorithm
│ │ ├── gemm.rs # 5-loop blocking
│ │ ├── kernel.rs # Microkernel trait
│ │ ├── packing.rs # Matrix packing
│ │ └── tiling.rs # Cache parameters
│ ├── simd/ # SIMD kernels
│ │ ├── dispatch.rs # Runtime selection
│ │ ├── detect.rs # CPU detection
│ │ └── kernels/ # Per-architecture
│ └── mat/ # Matrix types
tropical-gemm-cuda/
├── src/
│ ├── lib.rs # Public API
│ ├── context.rs # CUDA context
│ ├── kernels.rs # Kernel management
│ └── gpu_mat.rs # GPU matrix type
└── kernels/
└── tropical_gemm.cu # CUDA source
Performance Layers
┌─────────────────────────────────────────────────────────────┐
│ User API (Mat, MatRef) │
├─────────────────────────────────────────────────────────────┤
│ Function API (tropical_matmul) │
├─────────────────────────────────────────────────────────────┤
│ SIMD Dispatch (KernelDispatch) │
├─────────────────────────────────────────────────────────────┤
│ BLIS 5-Loop Blocking (tropical_gemm_inner) │
├─────────────────────────────────────────────────────────────┤
│ SIMD Microkernel │
│ (AVX2 / AVX-512 / NEON / Portable) │
└─────────────────────────────────────────────────────────────┘
Key Design Decisions
1. Semiring as Type Parameter
Operations are generic over the semiring type, enabling compile-time specialization:
#![allow(unused)]
fn main() {
pub fn tropical_matmul<S: TropicalSemiring>(
a: &[S::Scalar], m: usize, k: usize,
b: &[S::Scalar], n: usize
) -> Vec<S>
}
2. Scalar vs Semiring Types
- Input: Raw scalar data (&[f32], &[f64])
- Output: Semiring-wrapped values (Vec<MaxPlus<f32>>)
This avoids unnecessary wrapping in hot paths.
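A short usage sketch of this split, reusing the tropical_matmul signature above (the exact item paths are listed in the API Reference):
#![allow(unused)]
fn main() {
use tropical_gemm::{tropical_matmul, TropicalMaxPlus};
// Inputs are plain scalar slices in row-major order...
let a = [1.0f32, 2.0, 3.0, 4.0]; // 2x2
let b = [1.0f32, 2.0, 3.0, 4.0]; // 2x2
// ...while the output elements are semiring-wrapped values.
let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, 2, 2, &b, 2);
}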
3. Runtime SIMD Dispatch
CPU features are detected at runtime, not compile time:
#![allow(unused)]
fn main() {
match simd_level() {
SimdLevel::Avx512 => avx512_kernel(...),
SimdLevel::Avx2 => avx2_kernel(...),
_ => portable_kernel(...),
}
}
4. CUDA Runtime Compilation
Kernels are compiled from CUDA C source at runtime via NVRTC:
- No compile-time CUDA dependency
- Portability across CUDA versions
- Template-like specialization via macros
BLIS Algorithm
The CPU implementation uses BLIS-style cache blocking for optimal performance.
5-Loop Blocking
Matrix multiplication is blocked into tiles that fit in cache:
┌──────────────────────────────────────────────────────────────────────────┐
│ Loop 5: for jc in 0..N step NC (L3 cache - columns of B) │
│ Loop 4: for pc in 0..K step KC (L2 cache - depth) │
│ Pack B[pc:KC, jc:NC] → B̃ (contiguous in L3) │
│ Loop 3: for ic in 0..M step MC (L1 cache - rows of A) │
│ Pack A[ic:MC, pc:KC] → Ã (contiguous in L2) │
│ Loop 2: for jr in 0..NC step NR (register blocking) │
│ Loop 1: for ir in 0..MC step MR (microkernel) │
│ microkernel(Ã[ir], B̃[jr], C[ic+ir, jc+jr]) │
└──────────────────────────────────────────────────────────────────────────┘
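A minimal, portable sketch of the same loop structure for MaxPlus on f32 (block sizes are illustrative; the real implementation also packs the A/B panels and calls a SIMD microkernel for the innermost tile):
// The caller initialises C to the tropical zero (f32::NEG_INFINITY).
const MC: usize = 64;
const NC: usize = 64;
const KC: usize = 256;
fn maxplus_gemm_blocked(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    for jc in (0..n).step_by(NC) {                  // Loop 5: columns of B
        let nc = NC.min(n - jc);
        for pc in (0..k).step_by(KC) {              // Loop 4: depth
            let kc = KC.min(k - pc);
            for ic in (0..m).step_by(MC) {          // Loop 3: rows of A
                let mc = MC.min(m - ic);
                // Scalar stand-in for loops 2/1 plus the packed microkernel
                for i in ic..ic + mc {
                    for j in jc..jc + nc {
                        let mut acc = c[i * n + j];
                        for p in pc..pc + kc {
                            acc = acc.max(a[i * k + p] + b[p * n + j]);
                        }
                        c[i * n + j] = acc;
                    }
                }
            }
        }
    }
}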
Cache Tiling Parameters
| Parameter | Description | f32 AVX2 | f64 AVX2 | Portable |
|---|---|---|---|---|
| MC | Rows per L2 block | 256 | 128 | 64 |
| NC | Columns per L3 block | 256 | 128 | 64 |
| KC | Depth per block | 512 | 256 | 256 |
| MR | Microkernel rows | 8 | 4 | 4 |
| NR | Microkernel columns | 8 | 4 | 4 |
Parameters are tuned to fit in cache:
- MC × KC fits in L2 cache
- KC × NC fits in L3 cache
- MR × NR fits in registers
Packing
Before computation, matrices are packed into contiguous buffers:
Pack A (MC × KC block)
Original layout (row-major):
A[0,0] A[0,1] A[0,2] ...
A[1,0] A[1,1] A[1,2] ...
...
Packed layout (MR-contiguous panels):
A[0,0] A[1,0] ... A[MR-1,0] // First column of first panel
A[0,1] A[1,1] ... A[MR-1,1] // Second column of first panel
...
A[MR,0] A[MR+1,0] ... // First column of second panel
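A simplified sketch of the A-panel packing above, assuming a row-major input with leading dimension lda (the real pack functions in core/packing.rs also handle strides and alignment; padding the final partial panel with the tropical zero is an assumption made here so the microkernel can always read MR values):
fn pack_a_panels(a: &[f32], lda: usize, mc: usize, kc: usize, mr: usize) -> Vec<f32> {
    let panels = (mc + mr - 1) / mr;                 // number of MR-row panels
    let mut packed = Vec::with_capacity(panels * mr * kc);
    for panel in (0..mc).step_by(mr) {
        let rows = mr.min(mc - panel);
        for col in 0..kc {
            // One column of the panel becomes MR contiguous values
            for row in 0..rows {
                packed.push(a[(panel + row) * lda + col]);
            }
            for _ in rows..mr {
                packed.push(f32::NEG_INFINITY);      // pad short final panel (MaxPlus zero)
            }
        }
    }
    packed
}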
Pack B (KC × NC block)
Packed into NR-wide panels for broadcasting:
B[0,0] B[0,1] ... B[0,NR-1] // First row of first panel
B[1,0] B[1,1] ... B[1,NR-1] // Second row of first panel
...
Benefits
- Sequential access: Packed data is accessed linearly
- Cache reuse: Each block is loaded once, used many times
- TLB efficiency: Fewer page table lookups
- SIMD friendly: Contiguous data enables vectorization
Code Location
- core/gemm.rs: Main blocking loops
- core/packing.rs: Pack functions
- core/tiling.rs: TilingParams struct
- core/kernel.rs: Microkernel trait
SIMD Kernels
The microkernel is vectorized using SIMD instructions for maximum throughput.
Supported Architectures
| Architecture | Instruction Set | Vector Width | f32 MR×NR | f64 MR×NR |
|---|---|---|---|---|
| x86_64 | AVX-512 | 512-bit | 16×16 | 8×8 |
| x86_64 | AVX2 | 256-bit | 8×8 | 4×4 |
| x86_64 | SSE4.1 | 128-bit | 4×4 | 2×2 |
| aarch64 | NEON | 128-bit | 4×4 | 2×2 |
| Any | Portable | Scalar | 4×4 | 4×4 |
Runtime Detection
CPU features are detected at runtime:
#![allow(unused)]
fn main() {
use tropical_gemm::{simd_level, SimdLevel};
match simd_level() {
SimdLevel::Avx512 => println!("Using AVX-512"),
SimdLevel::Avx2 => println!("Using AVX2"),
SimdLevel::Sse41 => println!("Using SSE4.1"),
SimdLevel::Neon => println!("Using NEON"),
SimdLevel::None => println!("Using portable"),
}
}
Microkernel Design
For MaxPlus f32 with AVX2 (8-wide vectors):
#![allow(unused)]
fn main() {
// Pseudocode for 8×8 microkernel
for k in 0..KC {
// Load 8 elements from packed A
let a_vec = _mm256_loadu_ps(a_ptr);
// For each column in the 8-column output tile
for j in 0..8 {
// Broadcast scalar from packed B
let b_scalar = _mm256_broadcast_ss(b_ptr + j);
// Tropical multiply: a + b (element-wise)
let prod = _mm256_add_ps(a_vec, b_scalar);
// Tropical accumulate: max(c, prod)
c_vec[j] = _mm256_max_ps(c_vec[j], prod);
}
a_ptr += 8; // Next column in packed A
b_ptr += 8; // Next row in packed B
}
}
Semiring-Specific Operations
| Semiring | Tropical Mul | Tropical Add |
|---|---|---|
| MaxPlus | _mm256_add_ps | _mm256_max_ps |
| MinPlus | _mm256_add_ps | _mm256_min_ps |
| MaxMul | _mm256_mul_ps | _mm256_max_ps |
Dispatch Mechanism
The KernelDispatch trait routes to the appropriate implementation:
#![allow(unused)]
fn main() {
impl KernelDispatch for TropicalMaxPlus<f32> {
unsafe fn dispatch_gemm(...) {
match simd_level() {
SimdLevel::Avx2 | SimdLevel::Avx512 => {
tropical_gemm_inner::<Self, Avx2MaxPlusF32>(...);
}
_ => {
tropical_gemm_inner::<Self, PortableMicrokernel>(...);
}
}
}
}
}
Code Location
- simd/detect.rs: CPU feature detection
- simd/dispatch.rs: Runtime dispatch trait
- simd/kernels/avx2.rs: AVX2 implementations
- simd/kernels/neon.rs: NEON implementations
- simd/kernels/portable.rs: Fallback implementation
CUDA Implementation
The GPU backend uses CUDA with runtime kernel compilation.
Architecture
┌─────────────────────────────────────────────────────────────┐
│ User API │
│ (GpuMat::matmul, tropical_matmul_gpu) │
├─────────────────────────────────────────────────────────────┤
│ CudaContext │
│ (kernel compilation, device management) │
├─────────────────────────────────────────────────────────────┤
│ NVRTC │
│ (runtime kernel compilation) │
├─────────────────────────────────────────────────────────────┤
│ CUDA Kernels │
│ (tropical_gemm.cu, specialized per semiring) │
└─────────────────────────────────────────────────────────────┘
Runtime Compilation
Kernels are compiled from CUDA C source at runtime using NVRTC:
#![allow(unused)]
fn main() {
// On first CudaContext::new()
let ctx = CudaContext::new()?; // Compiles kernels (~1-2 seconds)
// Subsequent operations are fast
let c = a_gpu.matmul(&ctx, &b_gpu)?; // Just kernel launch
}
Benefits:
- No build-time CUDA dependency: Users don’t need nvcc at build time
- Portability: Works across CUDA versions
- Specialization: Kernels optimized for specific semirings
Kernel Design
Thread Block Organization
Block size: 16×16 threads (256 threads per block)
Grid: ceil(M/16) × ceil(N/16) blocks
Each thread computes one output element C[i,j]
Memory Access Pattern
__global__ void tropical_maxplus_gemm(
const float* A, const float* B, float* C,
int M, int N, int K
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float max_val = -INFINITY;
for (int k = 0; k < K; k++) {
float sum = A[row * K + k] + B[k * N + col];
max_val = fmaxf(max_val, sum);
}
C[row * N + col] = max_val;
}
}
Shared Memory Tiling
For larger matrices, shared memory is used:
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
// Load tiles cooperatively
As[ty][tx] = A[row * K + (tile * TILE_SIZE + tx)];
Bs[ty][tx] = B[(tile * TILE_SIZE + ty) * N + col];
__syncthreads();
// Compute partial result from tile
for (int k = 0; k < TILE_SIZE; k++) {
max_val = fmaxf(max_val, As[ty][k] + Bs[k][tx]);
}
Argmax Kernels
For backpropagation, kernels track which k index achieved the max:
__global__ void tropical_maxplus_gemm_argmax(
const float* A, const float* B,
float* C, int* argmax,
int M, int N, int K
) {
// ... setup ...
float max_val = -INFINITY;
int max_k = 0;
for (int k = 0; k < K; k++) {
float sum = A[row * K + k] + B[k * N + col];
if (sum > max_val) {
max_val = sum;
max_k = k;
}
}
C[row * N + col] = max_val;
argmax[row * N + col] = max_k;
}
Batched Kernels
For processing multiple matrices:
// Strided batched: matrices stored contiguously
__global__ void tropical_maxplus_gemm_batched(
const float* A, const float* B, float* C,
int M, int N, int K, int batch_count,
int stride_a, int stride_b, int stride_c
) {
int batch = blockIdx.z;
// ... standard GEMM with offset by batch * stride ...
}
Memory Management
Device Memory Allocation
#![allow(unused)]
fn main() {
// Allocate GPU memory
let d_ptr = cuda_malloc(size_bytes)?;
// Copy host → device
cuda_memcpy_h2d(d_ptr, h_data, size_bytes)?;
// Copy device → host
cuda_memcpy_d2h(h_data, d_ptr, size_bytes)?;
// Free
cuda_free(d_ptr)?;
}
Pinned Memory (for faster transfers)
#![allow(unused)]
fn main() {
// For frequent CPU↔GPU transfers, use pinned memory
let pinned = cuda_malloc_host(size_bytes)?;
// ... 2-3x faster transfers ...
cuda_free_host(pinned)?;
}
Error Handling
CUDA errors are wrapped in Rust Result types:
#![allow(unused)]
fn main() {
match CudaContext::new() {
Ok(ctx) => { /* use context */ }
Err(CudaError::NoDevice) => {
println!("No CUDA device found, using CPU");
}
Err(CudaError::CompilationFailed(msg)) => {
eprintln!("Kernel compilation failed: {}", msg);
}
Err(e) => return Err(e.into()),
}
}
Code Location
- tropical-gemm-cuda/src/context.rs: CUDA context and compilation
- tropical-gemm-cuda/src/gpu_mat.rs: GPU matrix type
- tropical-gemm-cuda/src/kernels.rs: Kernel management
- tropical-gemm-cuda/kernels/tropical_gemm.cu: CUDA kernel source
Performance Guide
This guide helps you get the best performance from tropical-gemm.
CPU vs GPU Selection
| Matrix Size | Recommendation | Reason |
|---|---|---|
| < 128×128 | CPU | GPU transfer overhead dominates |
| 128-256 | CPU or GPU | Similar performance |
| > 256×256 | GPU | GPU computation advantage |
| > 1024×1024 | GPU (strongly) | 100-800x speedup |
Benchmark Results (MaxPlus f32)
Tested on NVIDIA RTX A4500 (Ampere) with AMD Ryzen 9 5900X.
| Size | CPU AVX2 | GPU | GPU Speedup |
|---|---|---|---|
| 64 | 0.05 ms | 0.02 ms | 2.5x |
| 128 | 0.4 ms | 0.02 ms | 20x |
| 256 | 4.1 ms | 0.03 ms | 137x |
| 512 | 32.8 ms | 0.09 ms | 364x |
| 1024 | 262 ms | 0.36 ms | 728x |
| 2048 | 2092 ms | 2.5 ms | 837x |
Rust CUDA vs C Reference
Comparison with TropicalGemm_Cuda:
| Size | C Library (ms) | Rust CUDA (ms) | Ratio |
|---|---|---|---|
| 256 | 0.028 | 0.032 | 1.14x |
| 512 | 0.074 | 0.086 | 1.16x |
| 1024 | 0.315 | 0.358 | 1.14x |
| 2048 | 2.224 | 2.509 | 1.13x |
The C library is ~13-16% faster due to pre-compiled PTX vs runtime compilation.
GPU Backward Pass Performance
| Size | Forward (ms) | Backward A (ms) | Backward B (ms) |
|---|---|---|---|
| 256 | 0.032 | 0.018 | 0.018 |
| 512 | 0.086 | 0.052 | 0.052 |
| 1024 | 0.358 | 0.183 | 0.184 |
| 2048 | 2.510 | 1.312 | 1.315 |
CPU Optimization
SIMD Detection
Ensure optimal SIMD is being used:
#![allow(unused)]
fn main() {
use tropical_gemm::{simd_level, SimdLevel};
match simd_level() {
SimdLevel::Avx512 => println!("Best: AVX-512"),
SimdLevel::Avx2 => println!("Good: AVX2"),
SimdLevel::Sse41 => println!("Okay: SSE4.1"),
SimdLevel::Neon => println!("ARM: NEON"),
SimdLevel::None => println!("Slow: Portable fallback"),
}
}
Memory Layout
Row-major contiguous data is fastest:
#![allow(unused)]
fn main() {
// GOOD: Contiguous row-major
let a = Mat::<MaxPlus<f32>>::from_fn(m, k, |i, j| MaxPlus::from_scalar(data[i * k + j]));
// SLOWER: Non-contiguous requires packing overhead
let a_ref = MatRef::from_slice_strided(&data, m, k, stride);
}
Cache Efficiency
For best cache utilization:
- Square matrices: Optimal blocking
- Tall-skinny (M >> K): Good cache reuse for A
- Short-wide (K >> M): May have cache pressure
GPU Optimization
Context Reuse
Critical: Reuse CudaContext to avoid repeated kernel compilation:
#![allow(unused)]
fn main() {
// GOOD: Create once, reuse many times
let ctx = CudaContext::new()?; // ~1-2 seconds
for batch in batches {
let c = a.matmul(&ctx, &b)?; // Fast
}
// BAD: Creates new context each time
for batch in batches {
let ctx = CudaContext::new()?; // Slow!
let c = a.matmul(&ctx, &b)?;
}
}
Batched Operations
For multiple matrix multiplications, use batched API:
#![allow(unused)]
fn main() {
// GOOD: Single kernel launch for all matrices
let c_batch = GpuMat::matmul_batched(&ctx, &a_batch, &b_batch)?;
// SLOWER: Sequential kernel launches
let c_batch: Vec<_> = a_batch.iter()
.zip(&b_batch)
.map(|(a, b)| a.matmul(&ctx, b))
.collect();
}
Memory Transfer
Minimize CPU↔GPU transfers:
#![allow(unused)]
fn main() {
// GOOD: Keep data on GPU between operations
let a_gpu = GpuMat::from_matref(&ctx, &a)?;
let b_gpu = GpuMat::from_matref(&ctx, &b)?;
// Multiple operations without transfer
let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;
let d_gpu = c_gpu.matmul(&ctx, &b_gpu)?;
let e_gpu = d_gpu.matmul(&ctx, &b_gpu)?;
// Only transfer final result
let e = e_gpu.to_mat(&ctx)?;
// BAD: Transfer for each operation
for i in 0..3 {
let a_gpu = GpuMat::from_matref(&ctx, &a)?; // Upload
let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;
let c = c_gpu.to_mat(&ctx)?; // Download
a = c; // Use result for next iteration
}
}
PyTorch Training
Keep Context Alive
# Create context once at module initialization
import torch.nn as nn

class TropicalLayer(nn.Module):
def __init__(self):
super().__init__()
# Context created once
self.ctx = tropical_gemm.CudaContext()
def forward(self, a, b):
# Reuse context
return tropical_matmul_gpu(self.ctx, a, b)
Batch Your Data
# GOOD: Large batch, single kernel
output = tropical_matmul(large_batch_a, large_batch_b)
# SLOWER: Many small operations
outputs = [tropical_matmul(a, b) for a, b in zip(small_as, small_bs)]
Python Threading
GIL Release During Compute
All CPU functions release Python’s GIL during heavy computation, allowing other Python threads to run concurrently:
import threading
import tropical_gemm
import numpy as np
def background_task():
# This can run while tropical_gemm computes
print("Background task running")
a = np.random.randn(1000, 1000).astype(np.float32)
b = np.random.randn(1000, 1000).astype(np.float32)
# Start background thread
t = threading.Thread(target=background_task)
t.start()
# GIL is released during compute - background thread can run
c = tropical_gemm.maxplus_matmul(a, b)
t.join()
This is particularly useful in:
- Web servers (Flask, FastAPI) handling concurrent requests
- GUI applications that need to remain responsive
- Async applications using concurrent.futures
Zero-Copy with 2D Functions
The *_matmul_2d functions return properly shaped 2D arrays without reshaping overhead:
# Recommended: Use 2D functions for cleaner code
c = tropical_gemm.maxplus_matmul_2d(a, b) # shape: (m, n)
# Older pattern requiring reshape
c_flat = tropical_gemm.maxplus_matmul(a, b) # shape: (m*n,)
c = c_flat.reshape(m, n)
Memory Considerations
Argmax Memory
With argmax tracking, memory usage increases:
| Operation | Memory per element |
|---|---|
| Standard GEMM | 4 bytes (f32) |
| With argmax | 8 bytes (f32 + i32) |
For large matrices, this can be significant:
- 4096×4096 standard: 64 MB
- 4096×4096 with argmax: 128 MB
GPU Memory
Check available GPU memory:
#![allow(unused)]
fn main() {
let (free, total) = cuda_mem_info()?;
println!("GPU memory: {} MB free / {} MB total",
free / 1024 / 1024,
total / 1024 / 1024);
}
Profiling
CPU Profiling
# Linux perf
perf record --call-graph dwarf ./target/release/benchmark
perf report
# Flamegraph
cargo install flamegraph
cargo flamegraph --bin benchmark
GPU Profiling
# NVIDIA Nsight
nsys profile ./target/release/gpu_benchmark
nsys-ui report.nsys-rep
# nvprof (older)
nvprof ./target/release/gpu_benchmark
Troubleshooting Performance
Unexpectedly Slow CPU
- Check SIMD level (should be AVX2 or better on modern x86)
- Ensure data is contiguous (avoid strided access)
- Check for memory pressure (matrix too large for cache)
Unexpectedly Slow GPU
- Verify context reuse (compilation is slow)
- Check transfer overhead (small matrices dominated by transfer)
- Ensure sufficient GPU memory (avoid swapping)
- Use batched API for multiple matrices
Running Benchmarks
# CPU benchmark
cargo run --release --example bench_rust -p tropical-gemm
# CUDA vs CPU benchmark
cargo run --release --example bench_cuda_vs_cpu -p tropical-gemm-cuda
# GPU backward pass benchmark
cargo run --release --example bench_backward -p tropical-gemm-cuda
Or use the Makefile:
make bench # Run all benchmarks
make bench-cpu # CPU only
make bench-cuda # CUDA only
Troubleshooting
Common issues and solutions for tropical-gemm.
Installation Issues
Rust Compilation Errors
Error: “missing SIMD intrinsics”
error[E0433]: failed to resolve: use of undeclared crate or module `core_arch`
Solution: Update Rust to latest stable:
rustup update stable
Error: “target feature avx2 is not enabled”
This is expected on non-x86 platforms. The portable fallback will be used automatically.
CUDA Issues
Error: “CUDA driver not found”
CudaError: CUDA driver version is insufficient
Solution:
- Install/update NVIDIA drivers
- Verify with nvidia-smi
- Install CUDA Toolkit
Error: “nvcc not found”
CudaError: Failed to compile kernel: nvcc not found
Solution:
# Add CUDA to PATH
export PATH=/usr/local/cuda/bin:$PATH
# Verify
nvcc --version
Error: “Kernel compilation failed”
CudaError: CompilationFailed: ...
Solution:
- Check CUDA version compatibility (requires 11.0+)
- Ensure CUDA headers are installed
- Try reinstalling CUDA Toolkit
Python Binding Issues
Error: “module ‘tropical_gemm’ not found”
>>> import tropical_gemm
ModuleNotFoundError: No module named 'tropical_gemm'
Solution:
cd crates/tropical-gemm-python
pip install maturin
maturin develop --release
Error: “symbol not found in flat namespace” (macOS)
ImportError: dlopen(...): symbol not found in flat namespace
Solution: Rebuild with correct Python version:
# Ensure using correct Python
which python
python --version
# Rebuild
maturin develop --release
Error: “dtype mismatch”
TypeError: Expected float32 array, got float64
Solution: Explicitly cast to float32:
import numpy as np
a = a.astype(np.float32)
b = b.astype(np.float32)
c = tropical_gemm.maxplus_matmul(a, b)
Runtime Issues
Incorrect Results
Symptom: All outputs are -inf or inf
This typically means input contains NaN or inf values:
#![allow(unused)]
fn main() {
// Check for invalid values
for &x in data.iter() {
if x.is_nan() || x.is_infinite() {
panic!("Invalid input value: {}", x);
}
}
}
Symptom: Results differ between CPU and GPU
Small numerical differences are expected due to floating-point associativity. For MaxPlus/MinPlus the accumulation is a max/min reduction, which is order-independent, so CPU and GPU results should match exactly.
For MaxMul, small differences may occur:
#![allow(unused)]
fn main() {
// Allow small tolerance
let diff = (cpu_result - gpu_result).abs();
assert!(diff < 1e-5, "Results differ by {}", diff);
}
Performance Issues
Symptom: GPU slower than CPU
For small matrices, transfer overhead dominates:
#![allow(unused)]
fn main() {
// Rule of thumb: GPU beneficial for N > 256
if n < 256 {
// Use CPU
tropical_matmul::<MaxPlus<f32>>(&a, m, k, &b, n)
} else {
// Use GPU
tropical_matmul_gpu::<MaxPlus<f32>>(&a, m, k, &b, n)?
}
}
Symptom: CPU slower than expected
Check SIMD detection:
#![allow(unused)]
fn main() {
use tropical_gemm::simd_level;
println!("SIMD level: {:?}", simd_level());
// Should be Avx2 or Avx512 on modern x86
}
Memory Issues
Error: “out of memory” (GPU)
CudaError: Out of memory
Solution:
- Use smaller batch sizes
- Process matrices sequentially
- Free unused GPU memory
#![allow(unused)]
fn main() {
// Process in chunks
for chunk in matrices.chunks(batch_size) {
let result = process_batch(&ctx, chunk)?;
// Results are downloaded, GPU memory freed
}
}
Error: “allocation failed” (CPU)
Large matrices may exceed available RAM:
#![allow(unused)]
fn main() {
// Estimate memory needed
let bytes = m * n * std::mem::size_of::<f32>();
println!("Matrix requires {} MB", bytes / 1024 / 1024);
}
PyTorch Issues
Gradient Issues
Symptom: Gradients are all zeros
Check that tensors require gradients:
a = torch.randn(4, 5, requires_grad=True) # Must be True
b = torch.randn(5, 3, requires_grad=True)
c = TropicalMaxPlusMatmul.apply(a, b)
loss = c.sum()
loss.backward()
print(a.grad) # Should not be None
Symptom: “RuntimeError: element 0 of tensors does not require grad”
Ensure input tensors have requires_grad=True:
a = torch.tensor([[1.0, 2.0]], requires_grad=True)
# Not: a = torch.tensor([[1.0, 2.0]]) # No gradients!
Device Mismatch
Error: “Expected all tensors on same device”
# Ensure both inputs on same device
a = a.to('cuda')
b = b.to('cuda')
c = TropicalMaxPlusMatmul.apply(a, b)
Getting Help
If you encounter issues not covered here:
- Check GitHub issues: https://github.com/TensorBFS/tropical-gemm/issues
- Open a new issue with:
- Error message
- Rust/Python version
- OS and hardware
- Minimal reproduction code
Diagnostic Information
Include this in bug reports:
# Rust version
rustc --version
cargo --version
# CUDA (if applicable)
nvcc --version
nvidia-smi
# Python (if applicable)
python --version
pip show tropical_gemm
API Reference
This page provides quick reference to the main APIs.
For complete documentation, see the Rust API docs.
Crate Overview
| Crate | Purpose |
|---|---|
tropical-gemm | CPU implementation with SIMD |
tropical-gemm-cuda | GPU implementation with CUDA |
tropical-gemm-python | Python bindings |
Semiring Types
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, MinPlus, MaxMul};
use tropical_gemm::types::{TropicalMaxPlus, TropicalMinPlus, TropicalMaxMul};
// Wrapper types (for storage)
let a: MaxPlus<f32> = MaxPlus::new(3.0);
let b: MinPlus<f64> = MinPlus::new(5.0);
// Marker types (for generic functions)
type S = TropicalMaxPlus<f32>;
}
Matrix Types
Mat (Owned)
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
// Create from function
let a = Mat::<MaxPlus<f32>>::from_fn(m, k, |i, j| value);
// Create from scalar slice
let a = Mat::<MaxPlus<f32>>::from_scalar_slice(&data, m, k);
// Access
let val = a.get_value(i, j); // Returns f32
let dim = a.dim(); // Returns (rows, cols)
}
MatRef (Borrowed)
#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};
// From slice
let a = MatRef::<MaxPlus<f32>>::from_slice(&data, m, k);
// From Mat
let a_ref = a.as_ref();
}
MatMut (Mutable)
#![allow(unused)]
fn main() {
use tropical_gemm::MatMut;
let mut c = Mat::zeros(m, n);
let c_mut = c.as_mut();
}
Matrix Operations
High-Level API (Mat)
#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
let a = Mat::<MaxPlus<f32>>::from_scalar_slice(&a_data, m, k);
let b = Mat::<MaxPlus<f32>>::from_scalar_slice(&b_data, k, n);
// Standard multiply
let c = a.matmul(&b);
// With argmax tracking
let result = a.matmul_with_argmax(&b);
let value = result.get_value(i, j);
let argmax = result.get_argmax(i, j);
}
Low-Level API (Functions)
#![allow(unused)]
fn main() {
use tropical_gemm::{tropical_matmul, tropical_matmul_with_argmax, TropicalMaxPlus};
// Standard multiply
let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, m, k, &b, n);
// With argmax
let (values, argmax) = tropical_matmul_with_argmax::<TropicalMaxPlus<f32>>(&a, m, k, &b, n);
}
GPU API
CudaContext
#![allow(unused)]
fn main() {
use tropical_gemm_cuda::CudaContext;
let ctx = CudaContext::new()?; // Compiles kernels
}
GpuMat
#![allow(unused)]
fn main() {
use tropical_gemm_cuda::GpuMat;
use tropical_gemm::{MatRef, MaxPlus};
// Upload
let a_gpu = GpuMat::from_matref(&ctx, &a)?;
// Compute
let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;
// With argmax
let result = a_gpu.matmul_argmax(&ctx, &b_gpu)?;
// Download
let c = c_gpu.to_mat(&ctx)?;
}
Batched Operations
#![allow(unused)]
fn main() {
use tropical_gemm_cuda::GpuMat;
// Upload batch
let a_batch = GpuMat::from_mats(&ctx, &a_mats)?;
let b_batch = GpuMat::from_mats(&ctx, &b_mats)?;
// Batched multiply
let c_batch = GpuMat::matmul_batched(&ctx, &a_batch, &b_batch)?;
// Download batch
let c_mats = GpuMat::to_mats(&ctx, &c_batch)?;
}
Python API
NumPy Functions
import tropical_gemm
import numpy as np
a = np.array([[1, 2], [3, 4]], dtype=np.float32)
b = np.array([[5, 6], [7, 8]], dtype=np.float32)
# Basic operations (returns flattened 1D array)
c_flat = tropical_gemm.maxplus_matmul(a, b)
c = c_flat.reshape(a.shape[0], b.shape[1])
# 2D output (returns proper 2D array directly)
c = tropical_gemm.maxplus_matmul_2d(a, b) # shape: (m, n)
c = tropical_gemm.minplus_matmul_2d(a, b)
c = tropical_gemm.maxmul_matmul_2d(a, b)
# With argmax
values, argmax = tropical_gemm.maxplus_matmul_with_argmax(a, b)
2D Output Functions
The *_matmul_2d variants return properly shaped 2D NumPy arrays without manual reshaping:
| Type | MaxPlus | MinPlus | MaxMul |
|---|---|---|---|
| f32 | maxplus_matmul_2d | minplus_matmul_2d | maxmul_matmul_2d |
| f64 | maxplus_matmul_2d_f64 | minplus_matmul_2d_f64 | maxmul_matmul_2d_f64 |
| i32 | maxplus_matmul_2d_i32 | minplus_matmul_2d_i32 | maxmul_matmul_2d_i32 |
| i64 | maxplus_matmul_2d_i64 | minplus_matmul_2d_i64 | maxmul_matmul_2d_i64 |
# f64 example
a = np.array([[1, 2], [3, 4]], dtype=np.float64)
b = np.array([[5, 6], [7, 8]], dtype=np.float64)
c = tropical_gemm.maxplus_matmul_2d_f64(a, b) # shape: (2, 2)
# i32 example
a = np.array([[1, 2], [3, 4]], dtype=np.int32)
b = np.array([[5, 6], [7, 8]], dtype=np.int32)
c = tropical_gemm.maxplus_matmul_2d_i32(a, b) # shape: (2, 2)
Backward Pass
# Gradient computation
grad_a = tropical_gemm.backward_a(grad_c, argmax, k)
grad_b = tropical_gemm.backward_b(grad_c, argmax, k)
Utility Functions
SIMD Detection
#![allow(unused)]
fn main() {
use tropical_gemm::{simd_level, SimdLevel};
match simd_level() {
SimdLevel::Avx512 => { /* ... */ }
SimdLevel::Avx2 => { /* ... */ }
SimdLevel::Sse41 => { /* ... */ }
SimdLevel::Neon => { /* ... */ }
SimdLevel::None => { /* ... */ }
}
}
Type Aliases
For convenience:
#![allow(unused)]
fn main() {
// These paths refer to the same type:
use tropical_gemm::MaxPlus;
// use tropical_gemm::types::max_plus::MaxPlus; // same type via its module path
// Marker types for generics:
use tropical_gemm::TropicalMaxPlus; // = TropicalSemiringImpl<MaxPlusTag, T>
use tropical_gemm::TropicalMinPlus;
use tropical_gemm::TropicalMaxMul;
}
Changelog
All notable changes to tropical-gemm.
[0.2.0]
Added
- 2D output functions: New *_matmul_2d variants that return properly shaped 2D arrays instead of flattened 1D output. Available for all semirings (maxplus, minplus, maxmul) and data types (f32, f64, i32, i64):
  - maxplus_matmul_2d, minplus_matmul_2d, maxmul_matmul_2d (f32)
  - maxplus_matmul_2d_f64, minplus_matmul_2d_f64, maxmul_matmul_2d_f64
  - maxplus_matmul_2d_i32, minplus_matmul_2d_i32, maxmul_matmul_2d_i32
  - maxplus_matmul_2d_i64, minplus_matmul_2d_i64, maxmul_matmul_2d_i64
- mdBook documentation
- Comprehensive architecture documentation
- Performance tuning guide
- Troubleshooting guide
Changed
- GIL release during compute: All CPU functions now release Python’s GIL during heavy computation, allowing other Python threads to run concurrently. This improves performance in multi-threaded Python applications.
Fixed
- Batched CPU path copies: Fixed unnecessary memory copies in batched PyTorch operations by using np.asarray() instead of np.array() for zero-copy array creation when possible.
[0.1.0] - Initial Release
Features
- High-performance tropical matrix multiplication
- Support for three semirings: MaxPlus, MinPlus, MaxMul
- SIMD acceleration (AVX-512, AVX2, SSE4.1, NEON)
- CUDA GPU acceleration
- Argmax tracking for backpropagation
- Python bindings with NumPy support
- PyTorch autograd integration
Crates
- tropical-gemm: Core CPU implementation
- tropical-gemm-cuda: CUDA GPU backend
- tropical-gemm-python: Python bindings
Performance
- BLIS-style 5-loop cache blocking
- Runtime SIMD dispatch
- GPU speedup up to 800x for large matrices
Version History
| Version | Date | Highlights |
|---|---|---|
| 0.1.0 | 2024 | Initial release |
Migration Guides
From NumPy Implementation
If migrating from a pure NumPy tropical matrix multiplication:
# Before (NumPy)
def maxplus_matmul_numpy(a, b):
m, k = a.shape
n = b.shape[1]
c = np.full((m, n), -np.inf)
for i in range(m):
for j in range(n):
for kk in range(k):
c[i, j] = max(c[i, j], a[i, kk] + b[kk, j])
return c
# After (tropical-gemm)
import tropical_gemm
c = tropical_gemm.maxplus_matmul(a, b)
API Changes
No breaking changes yet (this is the first release).