tropical-gemm

High-performance tropical matrix multiplication in Rust with SIMD and CUDA backends.

What is Tropical Algebra?

Tropical algebra (also called max-plus or min-plus algebra) replaces standard arithmetic operations with alternative ones:

| Standard | Tropical (MaxPlus) | Tropical (MinPlus) |
|----------|--------------------|--------------------|
| a + b    | max(a, b)          | min(a, b)          |
| a × b    | a + b              | a + b              |
| 0        | -∞                 | +∞                 |
| 1        | 0                  | 0                  |
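
These definitions map directly onto ordinary floating-point operations. A minimal illustration in plain Rust, independent of the library:

fn main() {
    let (a, b) = (3.0f32, 5.0f32);

    // Tropical (MaxPlus): "addition" is max, "multiplication" is +
    let maxplus_sum = a.max(b);      // 5.0
    let maxplus_product = a + b;     // 8.0

    // Tropical (MinPlus): "addition" is min, "multiplication" is still +
    let minplus_sum = a.min(b);      // 3.0

    println!("{maxplus_sum} {maxplus_product} {minplus_sum}");
}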

Applications

Tropical matrix multiplication appears in many algorithms:

  • Shortest/Longest Path: Computing all-pairs shortest paths via matrix powers
  • Viterbi Algorithm: Finding most likely sequences in HMMs
  • Dynamic Programming: Optimizing over sequence alignments
  • Neural Networks: Tropical neural networks with piecewise-linear activations
  • Combinatorics: Counting optimal solutions

Features

  • Multiple Semirings: MaxPlus, MinPlus, MaxMul
  • SIMD Acceleration: AVX-512, AVX2, SSE4.1, NEON auto-detection
  • CUDA Backend: GPU-accelerated kernels via runtime compilation
  • Argmax Tracking: For backpropagation in differentiable programs
  • Batched Operations: Efficient batch processing
  • Python Bindings: PyTorch integration via PyO3

Feature Matrix

Supported Operations by Semiring and Scalar Type

| Semiring | Scalar | CPU GEMM | CPU Batched | CPU Argmax | CPU Backward | GPU GEMM | GPU Batched | GPU Argmax | GPU Backward |
|----------|--------|----------|-------------|------------|--------------|----------|-------------|------------|--------------|
| MaxPlus  | f32    | SIMD     | ✅          | ✅         | ✅           | ✅       | ✅          | ✅         | ✅           |
| MaxPlus  | f64    | SIMD     | ✅          | ✅         | ✅           | ✅       | ✅          | ✅         | ✅           |
| MaxPlus  | i32    | ✅       | ✅          | ✅         | N/A          | ✅       | ✅          | ✅         | N/A          |
| MaxPlus  | i64    | ✅       | ✅          | ✅         | N/A          | ✅       | ✅          | ✅         | N/A          |
| MinPlus  | f32    | SIMD     | ✅          | ✅         | ✅           | ✅       | ✅          | ✅         | ✅           |
| MinPlus  | f64    | ✅       | ✅          | ✅         | ✅           | ✅       | ✅          | ✅         | ✅           |
| MinPlus  | i32    | ✅       | ✅          | ✅         | N/A          | ✅       | ✅          | ✅         | N/A          |
| MinPlus  | i64    | ✅       | ✅          | ✅         | N/A          | ✅       | ✅          | ✅         | N/A          |
| MaxMul   | f32    | SIMD     | ✅          | ✅         | ✅           | ✅       | ✅          | ✅         | ✅           |
| MaxMul   | f64    | ✅       | ✅          | ✅         | ✅           | ✅       | ✅          | ✅         | ✅           |
| MaxMul   | i32    | ✅       | ✅          | ✅         | N/A          | ✅       | ✅          | ✅         | N/A          |
| MaxMul   | i64    | ✅       | ✅          | ✅         | N/A          | ✅       | ✅          | ✅         | N/A          |

Legend:

  • SIMD: Optimized with AVX2/AVX-512/NEON vectorization
  • ✅: Supported with portable implementation
  • N/A: Not applicable (integers don’t have gradients)

Quick Example

#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};

// Create 2x3 and 3x2 matrices
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);

// C[i,j] = max_k(A[i,k] + B[k,j])
let c = &a * &b;
assert_eq!(c.get_value(0, 0), 8.0); // max(1+1, 2+3, 3+5) = 8
}

Getting Started

This section covers how to install and start using tropical-gemm.

Overview

tropical-gemm is organized as a Cargo workspace with three crates:

| Crate                | Description                          |
|----------------------|--------------------------------------|
| tropical-gemm        | Core library with CPU implementation |
| tropical-gemm-cuda   | Optional GPU acceleration via CUDA   |
| tropical-gemm-python | Python bindings for NumPy/PyTorch    |

System Requirements

CPU

  • Rust 1.70 or later
  • x86-64 (AVX2/AVX-512) or ARM64 (NEON) for best performance

GPU (optional)

  • NVIDIA GPU with compute capability 3.5+
  • CUDA Toolkit 11.0 or later
  • nvcc in PATH

Python (optional)

  • Python 3.8+
  • NumPy 1.20+
  • PyTorch 2.0+ (for autograd integration)

Next Steps

Installation

Rust Crate

Add to your Cargo.toml:

[dependencies]
tropical-gemm = "0.1"

# For GPU acceleration (optional):
tropical-gemm-cuda = "0.1"

Python Package

# Basic installation
pip install tropical-gemm

# With PyTorch support for automatic differentiation
pip install tropical-gemm[torch]

# For development
pip install tropical-gemm[dev]

Optional Dependencies

The Python package has optional extras:

| Extra | Command                           | Description                               |
|-------|-----------------------------------|-------------------------------------------|
| torch | pip install tropical-gemm[torch]  | PyTorch integration with autograd support |
| dev   | pip install tropical-gemm[dev]    | Development dependencies (pytest, torch)  |

From Source

# Clone the repository
git clone https://github.com/TensorBFS/tropical-gemm
cd tropical-gemm/crates/tropical-gemm-python

# Create virtual environment
python -m venv .venv
source .venv/bin/activate  # Linux/Mac
# .venv\Scripts\activate   # Windows

# Install maturin and build
pip install maturin
maturin develop --release

# With CUDA support
maturin develop --release --features cuda

Verify Installation

import tropical_gemm
import numpy as np

a = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
b = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

c = tropical_gemm.maxplus_matmul(a, b)
print(c)  # [5. 6. 7. 8.]  (flattened; see maxplus_matmul_2d for a 2D result)

# Check GPU availability
print(f"CUDA available: {tropical_gemm.cuda_available()}")

Verify PyTorch Integration

import torch
from tropical_gemm.pytorch import tropical_maxplus_matmul, GPU_AVAILABLE

print(f"GPU available: {GPU_AVAILABLE}")

a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)

c = tropical_maxplus_matmul(a, b)
c.sum().backward()

print(f"grad_a: {a.grad.shape}")  # (3, 4)
print(f"grad_b: {b.grad.shape}")  # (4, 5)

CUDA Setup

For GPU acceleration, ensure CUDA is properly installed:

# Check CUDA installation
nvcc --version

# If not found, install CUDA toolkit
# Ubuntu:
sudo apt install nvidia-cuda-toolkit

# Or download from NVIDIA:
# https://developer.nvidia.com/cuda-downloads

The CUDA kernels are compiled at runtime using NVRTC, so you don’t need to compile the library with a specific CUDA version.

Building Python Package with CUDA

cd crates/tropical-gemm-python

# Build with CUDA feature
maturin develop --features cuda

# Or for release
maturin build --release --features cuda

Building from Source

# Clone
git clone https://github.com/TensorBFS/tropical-gemm
cd tropical-gemm

# Build all crates
cargo build --release --workspace

# Run tests
cargo test --workspace

# Build documentation
cargo doc --workspace --no-deps --open

Using the Makefile

A Makefile is provided for common tasks:

make help          # Show all targets
make setup         # Setup development environment
make build         # Build in release mode
make test          # Run all tests
make docs          # Build documentation
make bench         # Run benchmarks

Quick Start

This guide shows the basics of tropical matrix multiplication.

Basic Matrix Multiplication

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MatRef, MaxPlus};

// Create matrices from raw data (row-major order)
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 matrix
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2 matrix

// Create matrix views
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);

// Multiply using operator
let c = &a * &b;

// Or using method
let c = a.matmul(&b);

// Access result
println!("C[0,0] = {}", c.get_value(0, 0)); // 8.0 = max(1+1, 2+3, 3+5)
}

Understanding the Result

For MaxPlus semiring, the multiplication computes:

C[i,j] = max_k(A[i,k] + B[k,j])

For the example above:

  • C[0,0] = max(1+1, 2+3, 3+5) = max(2, 5, 8) = 8
  • C[0,1] = max(1+2, 2+4, 3+6) = max(3, 6, 9) = 9
  • C[1,0] = max(4+1, 5+3, 6+5) = max(5, 8, 11) = 11
  • C[1,1] = max(4+2, 5+4, 6+6) = max(6, 9, 12) = 12
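
These values can be checked against a naive triple-loop reference. A small sketch in plain Rust, using raw row-major slices rather than the library types:

fn main() {
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3, row-major
    let b = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2, row-major
    let (m, k, n) = (2, 3, 2);

    let mut c = vec![f32::NEG_INFINITY; m * n]; // tropical zero for MaxPlus
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                // C[i,j] = max_p(A[i,p] + B[p,j])
                c[i * n + j] = c[i * n + j].max(a[i * k + p] + b[p * n + j]);
            }
        }
    }
    assert_eq!(c, vec![8.0, 9.0, 11.0, 12.0]);
}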

Using Different Semirings

#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus, MinPlus, MaxMul};

let a_data = [1.0f32, 2.0, 3.0, 4.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0];

// MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])
let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 2);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 2, 2);
let c_maxplus = &a * &b;

// MinPlus: C[i,j] = min_k(A[i,k] + B[k,j])
let a = MatRef::<MinPlus<f32>>::from_slice(&a_data, 2, 2);
let b = MatRef::<MinPlus<f32>>::from_slice(&b_data, 2, 2);
let c_minplus = &a * &b;

// MaxMul: C[i,j] = max_k(A[i,k] * B[k,j])
let a = MatRef::<MaxMul<f32>>::from_slice(&a_data, 2, 2);
let b = MatRef::<MaxMul<f32>>::from_slice(&b_data, 2, 2);
let c_maxmul = &a * &b;
}

Factory Methods

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};

// Create a zero matrix (all -∞ for MaxPlus)
let zeros = Mat::<MaxPlus<f32>>::zeros(3, 3);

// Create an identity matrix (0 on diagonal, -∞ elsewhere for MaxPlus)
let identity = Mat::<MaxPlus<f32>>::identity(3);

// Create from function
let mat = Mat::<MaxPlus<f32>>::from_fn(3, 3, |i, j| {
    MaxPlus::from_scalar((i + j) as f32)
});
}

Next Steps

Semiring Types

A semiring is an algebraic structure with two operations that generalize addition and multiplication. Tropical semirings replace standard operations with max/min and addition.

Available Semirings

| Type        | ⊕ (add) | ⊗ (mul) | Zero  | One  | Use Case                |
|-------------|---------|---------|-------|------|-------------------------|
| MaxPlus<T>  | max     | +       | -∞    | 0    | Longest path, Viterbi   |
| MinPlus<T>  | min     | +       | +∞    | 0    | Shortest path, Dijkstra |
| MaxMul<T>   | max     | ×       | 0     | 1    | Maximum probability     |
| AndOr       | OR      | AND     | false | true | Graph reachability      |

MaxPlus Semiring

The MaxPlus (or max-plus) semiring uses:

  • Addition: a ⊕ b = max(a, b)
  • Multiplication: a ⊗ b = a + b
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, TropicalSemiring};

let a = MaxPlus::from_scalar(3.0f32);
let b = MaxPlus::from_scalar(5.0f32);

// Tropical add: max(3, 5) = 5
let sum = MaxPlus::tropical_add(a, b);
assert_eq!(sum.value(), 5.0);

// Tropical mul: 3 + 5 = 8
let product = MaxPlus::tropical_mul(a, b);
assert_eq!(product.value(), 8.0);
}

Applications:

  • Longest path in graphs (Bellman-Ford with negated weights)
  • Viterbi algorithm for HMM decoding
  • Log-probability computations

MinPlus Semiring

The MinPlus (or min-plus) semiring uses:

  • Addition: a ⊕ b = min(a, b)
  • Multiplication: a ⊗ b = a + b
#![allow(unused)]
fn main() {
use tropical_gemm::{MinPlus, TropicalSemiring};

let a = MinPlus::from_scalar(3.0f32);
let b = MinPlus::from_scalar(5.0f32);

// Tropical add: min(3, 5) = 3
let sum = MinPlus::tropical_add(a, b);
assert_eq!(sum.value(), 3.0);

// Tropical mul: 3 + 5 = 8
let product = MinPlus::tropical_mul(a, b);
assert_eq!(product.value(), 8.0);
}

Applications:

  • Shortest path (Floyd-Warshall, Dijkstra)
  • Edit distance computation
  • Resource allocation
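
For example, repeatedly squaring an adjacency matrix in the MinPlus semiring yields all-pairs shortest paths. A sketch using the Mat API described in the Matrix API chapter (illustrative; for n nodes, ceil(log2(n-1)) squarings cover all simple paths):

use tropical_gemm::{Mat, MinPlus};

fn main() {
    let inf = f64::INFINITY; // MinPlus tropical zero = no edge
    // 4-node adjacency matrix, row-major
    let adj = Mat::<MinPlus<f64>>::from_row_major(&[
        0.0, 1.0, inf, 10.0,
        inf, 0.0, 2.0, inf,
        inf, inf, 0.0, 1.0,
        inf, inf, inf, 0.0,
    ], 4, 4);

    // Repeated squaring: dist = adj^2, then adj^4, covering paths of up to 4 edges
    let mut dist = adj;
    for _ in 0..2 {
        dist = dist.matmul(&dist); // dist[i,j] = min_k(dist[i,k] + dist[k,j])
    }

    assert_eq!(dist.get_value(0, 3), 4.0); // 0 -> 1 -> 2 -> 3 costs 1 + 2 + 1
}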

MaxMul Semiring

The MaxMul semiring uses:

  • Addition: a ⊕ b = max(a, b)
  • Multiplication: a ⊗ b = a × b
#![allow(unused)]
fn main() {
use tropical_gemm::{MaxMul, TropicalSemiring};

let a = MaxMul::from_scalar(3.0f32);
let b = MaxMul::from_scalar(5.0f32);

// Tropical add: max(3, 5) = 5
let sum = MaxMul::tropical_add(a, b);
assert_eq!(sum.value(), 5.0);

// Tropical mul: 3 × 5 = 15
let product = MaxMul::tropical_mul(a, b);
assert_eq!(product.value(), 15.0);
}

Applications:

  • Maximum probability paths (non-log space)
  • Fuzzy set operations
  • Reliability analysis

Supported Scalar Types

Each semiring supports multiple scalar types:

| Scalar | MaxPlus | MinPlus | MaxMul  | Notes              |
|--------|---------|---------|---------|--------------------|
| f32    | ✅ SIMD | ✅ SIMD | ✅ SIMD | Best performance   |
| f64    | ✅ SIMD | ✅      | ✅      | Higher precision   |
| i32    | ✅      | ✅      | ✅      | Integer operations |
| i64    | ✅      | ✅      | ✅      | Large integers     |
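
Integer scalars go through the same matrix API; a brief sketch, assuming MatRef::from_slice and the * operator behave for i32 exactly as shown for f32 elsewhere in this book:

use tropical_gemm::{MatRef, MaxPlus};

fn main() {
    let a_data = [1i32, 2, 3, 4];
    let b_data = [5i32, 6, 7, 8];

    let a = MatRef::<MaxPlus<i32>>::from_slice(&a_data, 2, 2);
    let b = MatRef::<MaxPlus<i32>>::from_slice(&b_data, 2, 2);

    // C[i,j] = max_k(A[i,k] + B[k,j]), computed entirely in integer arithmetic
    let c = &a * &b;
    assert_eq!(c.get_value(0, 0), 9); // max(1+5, 2+7) = 9
}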

Type Aliases

For convenience, shorter type aliases are provided:

#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, MinPlus, MaxMul, AndOr};

// These are equivalent:
type A = tropical_gemm::TropicalMaxPlus<f32>;
type B = MaxPlus<f32>;  // Preferred
}

Matrix API

tropical-gemm provides a matrix API inspired by faer.

Matrix Types

| Type              | Description                                     |
|-------------------|-------------------------------------------------|
| Mat<S>            | Owned matrix with heap-allocated storage        |
| MatRef<'a, S>     | Immutable view into matrix data                 |
| MatMut<'a, S>     | Mutable view into matrix data                   |
| MatWithArgmax<S>  | Matrix with argmax indices for backpropagation  |

Creating Matrices

From Raw Data

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MatRef, MaxPlus};

// Create a view from a slice (no allocation)
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = MatRef::<MaxPlus<f32>>::from_slice(&data, 2, 3);

// Create an owned matrix from a slice (allocates)
let b = Mat::<MaxPlus<f32>>::from_row_major(&data, 2, 3);
}

Factory Methods

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};

// Zero matrix (all elements = tropical zero)
let zeros = Mat::<MaxPlus<f32>>::zeros(3, 4);

// Identity matrix (diagonal = tropical one, off-diagonal = tropical zero)
let identity = Mat::<MaxPlus<f32>>::identity(3);

// From function
let mat = Mat::<MaxPlus<f32>>::from_fn(3, 3, |i, j| {
    MaxPlus::from_scalar((i * 3 + j) as f32)
});
}

Matrix Multiplication

Operator Syntax

#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};

let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);

// Multiply using operators
let c = &a * &b;  // Returns Mat<S>
}

Method Syntax

#![allow(unused)]
fn main() {
let c = a.matmul(&b);  // `a` and `b` as constructed in the previous example
}

Accessing Elements

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};

let data = [1.0f32, 2.0, 3.0, 4.0];
let mat = Mat::<MaxPlus<f32>>::from_row_major(&data, 2, 2);

// Get the underlying scalar value
let value = mat.get_value(0, 1);  // 2.0

// Get the tropical element
let elem = mat[(0, 1)];  // MaxPlus(2.0)

// Dimensions
let (rows, cols) = (mat.nrows(), mat.ncols());
}

Argmax Tracking

For backpropagation, track which k produced each output:

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};

let a = Mat::<MaxPlus<f64>>::from_row_major(
    &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3
);
let b = Mat::<MaxPlus<f64>>::from_row_major(
    &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2
);

let result = a.matmul_argmax(&b);

// Get value and argmax
let value = result.get_value(0, 0);    // 8.0
let k_idx = result.get_argmax(0, 0);   // 2

// Compute gradients
let grad_c = vec![1.0f64; 4];  // upstream gradient (m × n)
let grad_a = result.backward_a(&grad_c);
let grad_b = result.backward_b(&grad_c);
}

Batched Operations

Process multiple matrices in parallel:

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};

let a_batch = vec![
    Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2),
    Mat::<MaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2),
];
let b_batch = vec![
    Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2),
    Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2),
];

// Batched matmul (parallel by default)
let c_batch = Mat::matmul_batched(&a_batch, &b_batch);

// With argmax
let results = Mat::matmul_batched_with_argmax(&a_batch, &b_batch);
}

GPU Acceleration

tropical-gemm-cuda provides NVIDIA GPU acceleration via CUDA.

Requirements

  • NVIDIA GPU (compute capability 3.5+)
  • CUDA Toolkit 11.0 or later
  • nvcc in PATH

Basic Usage

use tropical_gemm::{MatRef, MaxPlus};
use tropical_gemm_cuda::{CudaContext, GpuMat};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create CUDA context (compiles kernels on first use)
    let ctx = CudaContext::new()?;

    // Prepare CPU data
    let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

    let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
    let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);

    // Upload to GPU
    let a_gpu = GpuMat::from_matref(&ctx, &a)?;
    let b_gpu = GpuMat::from_matref(&ctx, &b)?;

    // Compute on GPU
    let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;

    // Download result
    let c = c_gpu.to_mat(&ctx)?;

    println!("C[0,0] = {}", c.get_value(0, 0));
    Ok(())
}

Context Reuse

The CudaContext compiles CUDA kernels on first use. Always reuse contexts to avoid repeated compilation:

#![allow(unused)]
fn main() {
// GOOD: Reuse context
let ctx = CudaContext::new()?;
for _ in 0..100 {
    let c = a_gpu.matmul(&ctx, &b_gpu)?;
}

// BAD: Creates new context each iteration
for _ in 0..100 {
    let ctx = CudaContext::new()?;  // Slow!
    let c = a_gpu.matmul(&ctx, &b_gpu)?;
}
}

GPU Argmax

For backpropagation with GPU computation:

#![allow(unused)]
fn main() {
let ctx = CudaContext::new()?;
let a_gpu = GpuMat::from_matref(&ctx, &a)?;
let b_gpu = GpuMat::from_matref(&ctx, &b)?;

// Forward pass with argmax tracking
let result = a_gpu.matmul_argmax(&ctx, &b_gpu)?;

// Download values and argmax
let result_cpu = result.to_mat_with_argmax(&ctx)?;
let value = result_cpu.get_value(0, 0);
let k_idx = result_cpu.get_argmax(0, 0);

// Backward pass on GPU
let grad_c_gpu = GpuMat::from_matref(&ctx, &grad_c)?;
let grad_a_gpu = result.backward_a(&ctx, &grad_c_gpu)?;
let grad_b_gpu = result.backward_b(&ctx, &grad_c_gpu)?;
}

Batched GPU Operations

Process multiple matrices efficiently:

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};
use tropical_gemm_cuda::{CudaContext, GpuMat};

let ctx = CudaContext::new()?;

// Upload batch to GPU
let a_batch: Vec<Mat<MaxPlus<f32>>> = /* ... */;
let b_batch: Vec<Mat<MaxPlus<f32>>> = /* ... */;

let a_gpu_batch = GpuMat::from_mats(&ctx, &a_batch)?;
let b_gpu_batch = GpuMat::from_mats(&ctx, &b_batch)?;

// Batched multiply
let c_gpu_batch = GpuMat::matmul_batched(&ctx, &a_gpu_batch, &b_gpu_batch)?;

// Download results
let c_batch = GpuMat::to_mats(&ctx, &c_gpu_batch)?;
}

One-Shot API

For simple cases without context reuse:

#![allow(unused)]
fn main() {
use tropical_gemm::TropicalMaxPlus;
use tropical_gemm_cuda::tropical_matmul_gpu;

let a = vec![1.0f32; 64 * 64];
let b = vec![1.0f32; 64 * 64];

// One-shot GPU multiplication (creates temporary context)
let c = tropical_matmul_gpu::<TropicalMaxPlus<f32>>(&a, 64, 64, &b, 64)?;
}

Performance Comparison

| Size | CPU SIMD  | GPU      | Speedup |
|------|-----------|----------|---------|
| 256  | 4.1 ms    | 0.032 ms | 128x    |
| 512  | 32.8 ms   | 0.086 ms | 381x    |
| 1024 | 262.3 ms  | 0.358 ms | 733x    |
| 2048 | 2091.6 ms | 2.510 ms | 833x    |

GPU becomes advantageous for matrices larger than ~256×256.

PyTorch Integration

tropical-gemm provides Python bindings with full PyTorch autograd support.

Installation

# From PyPI
pip install tropical-gemm

# With PyTorch support (recommended)
pip install tropical-gemm[torch]

# For GPU support (requires CUDA toolkit)
pip install maturin
git clone https://github.com/TensorBFS/tropical-gemm
cd tropical-gemm/crates/tropical-gemm-python
maturin develop --features cuda

Basic NumPy Usage

import numpy as np
import tropical_gemm

a = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
b = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)

# MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])
c = tropical_gemm.maxplus_matmul(a, b)

# MinPlus: C[i,j] = min_k(A[i,k] + B[k,j])
c = tropical_gemm.minplus_matmul(a, b)

# MaxMul: C[i,j] = max_k(A[i,k] * B[k,j])
c = tropical_gemm.maxmul_matmul(a, b)

# With argmax tracking for backpropagation
c, argmax = tropical_gemm.maxplus_matmul_with_argmax(a, b)

The tropical_gemm.pytorch module provides pre-built autograd functions:

import torch
from tropical_gemm.pytorch import (
    # CPU operations
    tropical_maxplus_matmul,
    tropical_minplus_matmul,
    tropical_maxmul_matmul,
    # GPU operations (requires CUDA)
    tropical_maxplus_matmul_gpu,
    tropical_minplus_matmul_gpu,
    tropical_maxmul_matmul_gpu,
    # Check GPU availability
    GPU_AVAILABLE,
)

# Create tensors with gradient tracking
a = torch.randn(100, 50, requires_grad=True)
b = torch.randn(50, 80, requires_grad=True)

# Forward pass - compute tropical matmul
c = tropical_maxplus_matmul(a, b)

# Backward pass - gradients computed automatically
loss = c.sum()
loss.backward()

print(f"grad_a shape: {a.grad.shape}")  # (100, 50)
print(f"grad_b shape: {b.grad.shape}")  # (50, 80)

GPU Acceleration

For larger matrices, use GPU-accelerated functions:

if GPU_AVAILABLE:
    a = torch.randn(1024, 512, requires_grad=True)
    b = torch.randn(512, 1024, requires_grad=True)

    c = tropical_maxplus_matmul_gpu(a, b)
    loss = c.sum()
    loss.backward()  # Gradients still work!

Available Functions

| CPU Function            | GPU Function                | Operation              |
|-------------------------|-----------------------------|------------------------|
| tropical_maxplus_matmul | tropical_maxplus_matmul_gpu | max_k(A[i,k] + B[k,j]) |
| tropical_minplus_matmul | tropical_minplus_matmul_gpu | min_k(A[i,k] + B[k,j]) |
| tropical_maxmul_matmul  | tropical_maxmul_matmul_gpu  | max_k(A[i,k] * B[k,j]) |

Training Example

import torch
from tropical_gemm.pytorch import tropical_maxplus_matmul

# Learnable parameters
a = torch.randn(64, 128, requires_grad=True)
b = torch.randn(128, 32, requires_grad=True)
target = torch.randn(64, 32)

optimizer = torch.optim.Adam([a, b], lr=0.1)

for step in range(100):
    optimizer.zero_grad()

    # Forward - tropical matmul
    c = tropical_maxplus_matmul(a, b)

    # Loss
    loss = ((c - target) ** 2).mean()

    # Backward - gradients flow through tropical operation
    loss.backward()

    # Update parameters
    optimizer.step()

    if step % 20 == 0:
        print(f"Step {step}: loss = {loss.item():.4f}")

Gradient Semantics

The gradient computation depends on the semiring type:

MaxPlus / MinPlus (Additive Rule)

For C[i,j] = max_k(A[i,k] + B[k,j]), let k* = argmax_k(A[i,k] + B[k,j]):

  • grad_A[i,k*] += grad_C[i,j]
  • grad_B[k*,j] += grad_C[i,j]

The gradient is sparse - only the winning index contributes.

MaxMul (Multiplicative Rule)

For C[i,j] = max_k(A[i,k] * B[k,j]), let k* = argmax_k(A[i,k] * B[k,j]):

  • grad_A[i,k*] += grad_C[i,j] * B[k*,j]
  • grad_B[k*,j] += grad_C[i,j] * A[i,k*]

| Semiring | Forward       | Backward Rule         |
|----------|---------------|-----------------------|
| MaxPlus  | max_k(A + B)  | ∂C/∂A = 1 at argmax   |
| MinPlus  | min_k(A + B)  | ∂C/∂A = 1 at argmin   |
| MaxMul   | max_k(A × B)  | ∂C/∂A = B at argmax   |
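
The additive rule can be written as a short, library-independent reference. A sketch in plain Rust that scatters the upstream gradient according to the argmax indices recorded in the forward pass (for MaxMul, the scattered values would additionally be scaled by B[k*,j] and A[i,k*]):

// Additive rule for MaxPlus/MinPlus: scatter grad_C into grad_A and grad_B
// at the winning index k* recorded during the forward pass.
fn backward_additive(grad_c: &[f64], argmax: &[usize], m: usize, n: usize, k: usize)
    -> (Vec<f64>, Vec<f64>)
{
    let mut grad_a = vec![0.0; m * k];
    let mut grad_b = vec![0.0; k * n];
    for i in 0..m {
        for j in 0..n {
            let ks = argmax[i * n + j];   // k* for C[i,j]
            let g = grad_c[i * n + j];
            grad_a[i * k + ks] += g;      // grad_A[i,k*] += grad_C[i,j]
            grad_b[ks * n + j] += g;      // grad_B[k*,j] += grad_C[i,j]
        }
    }
    (grad_a, grad_b)
}

fn main() {
    // A 1x1 output whose winning index was k* = 2 (out of k = 3)
    let (grad_a, grad_b) = backward_additive(&[1.0], &[2], 1, 1, 3);
    assert_eq!(grad_a, vec![0.0, 0.0, 1.0]);
    assert_eq!(grad_b, vec![0.0, 0.0, 1.0]);
}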

Graph Algorithms

Shortest Path (MinPlus)

import torch
from tropical_gemm.pytorch import tropical_minplus_matmul

# Adjacency matrix (inf = no edge)
inf = float("inf")
adj = torch.tensor([
    [0.0, 1.0, inf, 4.0],
    [inf, 0.0, 2.0, inf],
    [inf, inf, 0.0, 1.0],
    [inf, inf, inf, 0.0],
])

# 2-hop shortest paths
two_hop = tropical_minplus_matmul(adj, adj)

# 3-hop shortest paths
three_hop = tropical_minplus_matmul(two_hop, adj)

Longest Path (MaxPlus)

import torch
from tropical_gemm.pytorch import tropical_maxplus_matmul

# Edge weights for critical path analysis
neg_inf = float("-inf")
adj = torch.tensor([
    [0.0, 3.0, 2.0, neg_inf],
    [neg_inf, 0.0, neg_inf, 4.0],
    [neg_inf, neg_inf, 0.0, 5.0],
    [neg_inf, neg_inf, neg_inf, 0.0],
])

# 2-hop longest paths
two_hop = tropical_maxplus_matmul(adj, adj)

Custom Autograd Function (Advanced)

If you need custom behavior, you can still define your own autograd function:

import torch
import numpy as np
import tropical_gemm

class TropicalMaxPlusMatmul(torch.autograd.Function):
    """Custom differentiable MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])"""

    @staticmethod
    def forward(ctx, a, b):
        m, k = a.shape
        n = b.shape[1]

        # Convert to NumPy
        a_np = a.detach().cpu().numpy().astype(np.float32)
        b_np = b.detach().cpu().numpy().astype(np.float32)

        if not a_np.flags["C_CONTIGUOUS"]:
            a_np = np.ascontiguousarray(a_np)
        if not b_np.flags["C_CONTIGUOUS"]:
            b_np = np.ascontiguousarray(b_np)

        # Forward pass with argmax tracking
        c_flat, argmax_flat = tropical_gemm.maxplus_matmul_with_argmax(a_np, b_np)
        c_np = np.asarray(c_flat).reshape(m, n)  # zero-copy when possible
        argmax_np = np.asarray(argmax_flat).reshape(m, n)

        # Save for backward
        ctx.save_for_backward(torch.from_numpy(argmax_np))
        ctx.k, ctx.m, ctx.n = k, m, n

        return torch.from_numpy(c_np).to(a.device)

    @staticmethod
    def backward(ctx, grad_c):
        argmax, = ctx.saved_tensors
        k, m, n = ctx.k, ctx.m, ctx.n

        grad_c_np = grad_c.cpu().numpy().astype(np.float32)
        argmax_np = argmax.numpy().astype(np.int32)

        if not grad_c_np.flags["C_CONTIGUOUS"]:
            grad_c_np = np.ascontiguousarray(grad_c_np)

        # Backward pass
        grad_a_flat = tropical_gemm.backward_a(grad_c_np, argmax_np, k)
        grad_b_flat = tropical_gemm.backward_b(grad_c_np, argmax_np, k)

        grad_a = torch.from_numpy(np.asarray(grad_a_flat).reshape(m, k)).to(grad_c.device)
        grad_b = torch.from_numpy(np.asarray(grad_b_flat).reshape(k, n)).to(grad_c.device)

        return grad_a, grad_b

Complete Example

See crates/tropical-gemm-python/examples/pytorch_tropical.py for:

  • Gradient verification tests
  • Shortest/longest path examples
  • Optimization demos
  • GPU benchmarks

Architecture

This section describes the internal architecture of tropical-gemm.

Overview

tropical-gemm achieves high performance through:

  1. BLIS-style blocking for cache efficiency
  2. SIMD microkernels for vectorization
  3. Runtime dispatch for optimal kernel selection
  4. CUDA kernels for GPU acceleration

Crate Structure

tropical-gemm/
├── src/
│   ├── lib.rs          # Public API
│   ├── api.rs          # Function-based API
│   ├── types/          # Semiring definitions
│   │   ├── traits.rs   # TropicalSemiring trait
│   │   ├── max_plus.rs
│   │   ├── min_plus.rs
│   │   └── max_mul.rs
│   ├── core/           # BLIS algorithm
│   │   ├── gemm.rs     # 5-loop blocking
│   │   ├── kernel.rs   # Microkernel trait
│   │   ├── packing.rs  # Matrix packing
│   │   └── tiling.rs   # Cache parameters
│   ├── simd/           # SIMD kernels
│   │   ├── dispatch.rs # Runtime selection
│   │   ├── detect.rs   # CPU detection
│   │   └── kernels/    # Per-architecture
│   └── mat/            # Matrix types

tropical-gemm-cuda/
├── src/
│   ├── lib.rs          # Public API
│   ├── context.rs      # CUDA context
│   ├── kernels.rs      # Kernel management
│   └── gpu_mat.rs      # GPU matrix type
└── kernels/
    └── tropical_gemm.cu  # CUDA source

Performance Layers

┌─────────────────────────────────────────────────────────────┐
│                    User API (Mat, MatRef)                   │
├─────────────────────────────────────────────────────────────┤
│                  Function API (tropical_matmul)             │
├─────────────────────────────────────────────────────────────┤
│               SIMD Dispatch (KernelDispatch)                │
├─────────────────────────────────────────────────────────────┤
│           BLIS 5-Loop Blocking (tropical_gemm_inner)        │
├─────────────────────────────────────────────────────────────┤
│                    SIMD Microkernel                         │
│            (AVX2 / AVX-512 / NEON / Portable)               │
└─────────────────────────────────────────────────────────────┘

Key Design Decisions

1. Semiring as Type Parameter

Operations are generic over the semiring type, enabling compile-time specialization:

#![allow(unused)]
fn main() {
pub fn tropical_matmul<S: TropicalSemiring>(
    a: &[S::Scalar], m: usize, k: usize,
    b: &[S::Scalar], n: usize
) -> Vec<S>
}

2. Scalar vs Semiring Types

  • Input: Raw scalar data (&[f32], &[f64])
  • Output: Semiring-wrapped values (Vec<MaxPlus<f32>>)

This avoids unnecessary wrapping in hot paths.
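
As an illustration, the function-level entry point accepts raw f32 slices and returns semiring-wrapped values; a sketch based on the signature above and the API Reference chapter:

use tropical_gemm::{tropical_matmul, TropicalMaxPlus};

fn main() {
    let (m, k, n) = (2, 3, 2);
    // Inputs are raw scalar slices (no wrapping required)
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

    // Output elements are wrapped in the semiring type (Vec<S>)
    let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, m, k, &b, n);
    assert_eq!(c.len(), m * n);
}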

3. Runtime SIMD Dispatch

CPU features are detected at runtime, not compile time:

#![allow(unused)]
fn main() {
match simd_level() {
    SimdLevel::Avx512 => avx512_kernel(...),
    SimdLevel::Avx2   => avx2_kernel(...),
    _                 => portable_kernel(...),
}
}

4. CUDA Runtime Compilation

Kernels are compiled from CUDA C source at runtime via NVRTC:

  • No compile-time CUDA dependency
  • Portability across CUDA versions
  • Template-like specialization via macros

BLIS Algorithm

The CPU implementation uses BLIS-style cache blocking for optimal performance.

5-Loop Blocking

Matrix multiplication is blocked into tiles that fit in cache:

┌──────────────────────────────────────────────────────────────────────────┐
│ Loop 5: for jc in 0..N step NC    (L3 cache - columns of B)             │
│   Loop 4: for pc in 0..K step KC  (L2 cache - depth)                    │
│     Pack B[pc:KC, jc:NC] → B̃  (contiguous in L3)                        │
│     Loop 3: for ic in 0..M step MC  (L1 cache - rows of A)              │
│       Pack A[ic:MC, pc:KC] → Ã  (contiguous in L2)                      │
│       Loop 2: for jr in 0..NC step NR  (register blocking)              │
│         Loop 1: for ir in 0..MC step MR  (microkernel)                  │
│           microkernel(Ã[ir], B̃[jr], C[ic+ir, jc+jr])                    │
└──────────────────────────────────────────────────────────────────────────┘
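
A heavily simplified version of this loop nest, with packing and register blocking omitted and a scalar max-plus "microkernel" inlined, looks roughly like this (illustrative only, not the library's actual code):

// Simplified blocked max-plus GEMM; the real implementation adds packing of
// A/B and an MR x NR register-blocked SIMD microkernel.
fn blocked_maxplus_gemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    const MC: usize = 64;
    const NC: usize = 64;
    const KC: usize = 128;

    for jc in (0..n).step_by(NC) {
        for pc in (0..k).step_by(KC) {
            for ic in (0..m).step_by(MC) {
                // Scalar "microkernel" over one MC x NC tile of C
                for i in ic..(ic + MC).min(m) {
                    for j in jc..(jc + NC).min(n) {
                        let mut acc = c[i * n + j];
                        for p in pc..(pc + KC).min(k) {
                            acc = acc.max(a[i * k + p] + b[p * n + j]);
                        }
                        c[i * n + j] = acc;
                    }
                }
            }
        }
    }
}

fn main() {
    let (m, n, k) = (2, 2, 3);
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut c = vec![f32::NEG_INFINITY; m * n]; // tropical zero
    blocked_maxplus_gemm(&a, &b, &mut c, m, n, k);
    assert_eq!(c, vec![8.0, 9.0, 11.0, 12.0]);
}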

Cache Tiling Parameters

| Parameter | Description           | f32 AVX2 | f64 AVX2 | Portable |
|-----------|-----------------------|----------|----------|----------|
| MC        | Rows per L2 block     | 256      | 128      | 64       |
| NC        | Columns per L3 block  | 256      | 128      | 64       |
| KC        | Depth per block       | 512      | 256      | 256      |
| MR        | Microkernel rows      | 8        | 4        | 4        |
| NR        | Microkernel columns   | 8        | 4        | 4        |

Parameters are tuned to fit in cache:

  • MC × KC fits in L2 cache
  • KC × NC fits in L3 cache
  • MR × NR fits in registers
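
As a back-of-the-envelope check with the f32 AVX2 parameters above:

fn main() {
    let elem = std::mem::size_of::<f32>();                  // 4 bytes
    let (mc, nc, kc, mr, nr) = (256usize, 256, 512, 8, 8);  // f32 AVX2 parameters

    println!("packed A block: {} KiB", mc * kc * elem / 1024); // 512 KiB -> L2
    println!("packed B block: {} KiB", kc * nc * elem / 1024); // 512 KiB -> L3
    println!("register tile:  {} B",   mr * nr * elem);        // 256 B   -> registers
}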

Packing

Before computation, matrices are packed into contiguous buffers:

Pack A (MC × KC block)

Original layout (row-major):

A[0,0] A[0,1] A[0,2] ...
A[1,0] A[1,1] A[1,2] ...
...

Packed layout (MR-contiguous panels):

A[0,0] A[1,0] ... A[MR-1,0]   // First column of first panel
A[0,1] A[1,1] ... A[MR-1,1]   // Second column of first panel
...
A[MR,0] A[MR+1,0] ...         // First column of second panel
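
A simplified sketch of packing an MC × KC block of A into MR-tall panels (ignoring the remainder handling the real implementation needs when MC is not a multiple of MR):

// Pack an MC x KC block of row-major A into MR-tall, column-ordered panels.
// Assumes mc is a multiple of mr for brevity.
fn pack_a(a: &[f32], lda: usize, mc: usize, kc: usize, mr: usize) -> Vec<f32> {
    let mut packed = Vec::with_capacity(mc * kc);
    for panel in (0..mc).step_by(mr) {
        for col in 0..kc {
            for row in panel..panel + mr {
                packed.push(a[row * lda + col]); // one column of an MR-tall panel
            }
        }
    }
    packed
}

fn main() {
    // 4x3 block of A, packed with MR = 2 -> two panels of 2 rows each
    let a = [
        0.0f32, 1.0, 2.0,
        3.0, 4.0, 5.0,
        6.0, 7.0, 8.0,
        9.0, 10.0, 11.0,
    ];
    let packed = pack_a(&a, 3, 4, 3, 2);
    assert_eq!(&packed[..6], &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]); // first panel
}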

Pack B (KC × NC block)

Packed into NR-wide panels for broadcasting:

B[0,0] B[0,1] ... B[0,NR-1]   // First row of first panel
B[1,0] B[1,1] ... B[1,NR-1]   // Second row of first panel
...

Benefits

  1. Sequential access: Packed data is accessed linearly
  2. Cache reuse: Each block is loaded once, used many times
  3. TLB efficiency: Fewer page table lookups
  4. SIMD friendly: Contiguous data enables vectorization

Code Location

  • core/gemm.rs: Main blocking loops
  • core/packing.rs: Pack functions
  • core/tiling.rs: TilingParams struct
  • core/kernel.rs: Microkernel trait

SIMD Kernels

The microkernel is vectorized using SIMD instructions for maximum throughput.

Supported Architectures

| Architecture | Instruction Set | Vector Width | f32 MR×NR | f64 MR×NR |
|--------------|-----------------|--------------|-----------|-----------|
| x86_64       | AVX-512         | 512-bit      | 16×16     | 8×8       |
| x86_64       | AVX2            | 256-bit      | 8×8       | 4×4       |
| x86_64       | SSE4.1          | 128-bit      | 4×4       | 2×2       |
| aarch64      | NEON            | 128-bit      | 4×4       | 2×2       |
| Any          | Portable        | Scalar       | 4×4       | 4×4       |

Runtime Detection

CPU features are detected at runtime:

#![allow(unused)]
fn main() {
use tropical_gemm::{simd_level, SimdLevel};

match simd_level() {
    SimdLevel::Avx512 => println!("Using AVX-512"),
    SimdLevel::Avx2   => println!("Using AVX2"),
    SimdLevel::Sse41  => println!("Using SSE4.1"),
    SimdLevel::Neon   => println!("Using NEON"),
    SimdLevel::None   => println!("Using portable"),
}
}

Microkernel Design

For MaxPlus f32 with AVX2 (8-wide vectors):

#![allow(unused)]
fn main() {
// Pseudocode for 8×8 microkernel
for k in 0..KC {
    // Load 8 elements from packed A
    let a_vec = _mm256_loadu_ps(a_ptr);

    // For each column in the 8-column output tile
    for j in 0..8 {
        // Broadcast scalar from packed B
        let b_scalar = _mm256_broadcast_ss(b_ptr + j);

        // Tropical multiply: a + b (element-wise)
        let prod = _mm256_add_ps(a_vec, b_scalar);

        // Tropical accumulate: max(c, prod)
        c_vec[j] = _mm256_max_ps(c_vec[j], prod);
    }

    a_ptr += 8;  // Next column in packed A
    b_ptr += 8;  // Next row in packed B
}
}

Semiring-Specific Operations

| Semiring | Tropical Mul   | Tropical Add   |
|----------|----------------|----------------|
| MaxPlus  | _mm256_add_ps  | _mm256_max_ps  |
| MinPlus  | _mm256_add_ps  | _mm256_min_ps  |
| MaxMul   | _mm256_mul_ps  | _mm256_max_ps  |

Dispatch Mechanism

The KernelDispatch trait routes to the appropriate implementation:

#![allow(unused)]
fn main() {
impl KernelDispatch for TropicalMaxPlus<f32> {
    unsafe fn dispatch_gemm(...) {
        match simd_level() {
            SimdLevel::Avx2 | SimdLevel::Avx512 => {
                tropical_gemm_inner::<Self, Avx2MaxPlusF32>(...);
            }
            _ => {
                tropical_gemm_inner::<Self, PortableMicrokernel>(...);
            }
        }
    }
}
}

Code Location

  • simd/detect.rs: CPU feature detection
  • simd/dispatch.rs: Runtime dispatch trait
  • simd/kernels/avx2.rs: AVX2 implementations
  • simd/kernels/neon.rs: NEON implementations
  • simd/kernels/portable.rs: Fallback implementation

CUDA Implementation

The GPU backend uses CUDA with runtime kernel compilation.

Architecture

┌─────────────────────────────────────────────────────────────┐
│                      User API                               │
│           (GpuMat::matmul, tropical_matmul_gpu)             │
├─────────────────────────────────────────────────────────────┤
│                    CudaContext                              │
│         (kernel compilation, device management)             │
├─────────────────────────────────────────────────────────────┤
│                      NVRTC                                  │
│           (runtime kernel compilation)                      │
├─────────────────────────────────────────────────────────────┤
│                   CUDA Kernels                              │
│       (tropical_gemm.cu, specialized per semiring)          │
└─────────────────────────────────────────────────────────────┘

Runtime Compilation

Kernels are compiled from CUDA C source at runtime using NVRTC:

#![allow(unused)]
fn main() {
// On first CudaContext::new()
let ctx = CudaContext::new()?;  // Compiles kernels (~1-2 seconds)

// Subsequent operations are fast
let c = a_gpu.matmul(&ctx, &b_gpu)?;  // Just kernel launch
}

Benefits:

  • No build-time CUDA dependency: Users don’t need nvcc at build time
  • Portability: Works across CUDA versions
  • Specialization: Kernels optimized for specific semirings

Kernel Design

Thread Block Organization

Block size: 16×16 threads (256 threads per block)
Grid: ceil(M/16) × ceil(N/16) blocks

Each thread computes one output element C[i,j]

Memory Access Pattern

__global__ void tropical_maxplus_gemm(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        float max_val = -INFINITY;
        for (int k = 0; k < K; k++) {
            float sum = A[row * K + k] + B[k * N + col];
            max_val = fmaxf(max_val, sum);
        }
        C[row * N + col] = max_val;
    }
}

Shared Memory Tiling

For larger matrices, shared memory is used:

__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];

// Load tiles cooperatively
As[ty][tx] = A[row * K + (tile * TILE_SIZE + tx)];
Bs[ty][tx] = B[(tile * TILE_SIZE + ty) * N + col];
__syncthreads();

// Compute partial result from tile
for (int k = 0; k < TILE_SIZE; k++) {
    max_val = fmaxf(max_val, As[ty][k] + Bs[k][tx]);
}

Argmax Kernels

For backpropagation, kernels track which k index achieved the max:

__global__ void tropical_maxplus_gemm_argmax(
    const float* A, const float* B,
    float* C, int* argmax,
    int M, int N, int K
) {
    // ... setup ...

    float max_val = -INFINITY;
    int max_k = 0;

    for (int k = 0; k < K; k++) {
        float sum = A[row * K + k] + B[k * N + col];
        if (sum > max_val) {
            max_val = sum;
            max_k = k;
        }
    }

    C[row * N + col] = max_val;
    argmax[row * N + col] = max_k;
}

Batched Kernels

For processing multiple matrices:

// Strided batched: matrices stored contiguously
__global__ void tropical_maxplus_gemm_batched(
    const float* A, const float* B, float* C,
    int M, int N, int K, int batch_count,
    int stride_a, int stride_b, int stride_c
) {
    int batch = blockIdx.z;
    // ... standard GEMM with offset by batch * stride ...
}

Memory Management

Device Memory Allocation

#![allow(unused)]
fn main() {
// Allocate GPU memory
let d_ptr = cuda_malloc(size_bytes)?;

// Copy host → device
cuda_memcpy_h2d(d_ptr, h_data, size_bytes)?;

// Copy device → host
cuda_memcpy_d2h(h_data, d_ptr, size_bytes)?;

// Free
cuda_free(d_ptr)?;
}

Pinned Memory (for faster transfers)

#![allow(unused)]
fn main() {
// For frequent CPU↔GPU transfers, use pinned memory
let pinned = cuda_malloc_host(size_bytes)?;
// ... 2-3x faster transfers ...
cuda_free_host(pinned)?;
}

Error Handling

CUDA errors are wrapped in Rust Result types:

#![allow(unused)]
fn main() {
match CudaContext::new() {
    Ok(ctx) => { /* use context */ }
    Err(CudaError::NoDevice) => {
        println!("No CUDA device found, using CPU");
    }
    Err(CudaError::CompilationFailed(msg)) => {
        eprintln!("Kernel compilation failed: {}", msg);
    }
    Err(e) => return Err(e.into()),
}
}

Code Location

  • tropical-gemm-cuda/src/context.rs: CUDA context and compilation
  • tropical-gemm-cuda/src/gpu_mat.rs: GPU matrix type
  • tropical-gemm-cuda/src/kernels.rs: Kernel management
  • tropical-gemm-cuda/kernels/tropical_gemm.cu: CUDA kernel source

Performance Guide

This guide helps you get the best performance from tropical-gemm.

CPU vs GPU Selection

| Matrix Size  | Recommendation | Reason                          |
|--------------|----------------|---------------------------------|
| < 128×128    | CPU            | GPU transfer overhead dominates |
| 128–256      | CPU or GPU     | Similar performance             |
| > 256×256    | GPU            | GPU computation advantage       |
| > 1024×1024  | GPU (strongly) | 100-800x speedup                |

Benchmark Results (MaxPlus f32)

Tested on NVIDIA RTX A4500 (Ampere) with AMD Ryzen 9 5900X.

| Size | CPU AVX2 | GPU     | GPU Speedup |
|------|----------|---------|-------------|
| 64   | 0.05 ms  | 0.02 ms | 2.5x        |
| 128  | 0.4 ms   | 0.02 ms | 20x         |
| 256  | 4.1 ms   | 0.03 ms | 137x        |
| 512  | 32.8 ms  | 0.09 ms | 364x        |
| 1024 | 262 ms   | 0.36 ms | 728x        |
| 2048 | 2092 ms  | 2.5 ms  | 837x        |

Rust CUDA vs C Reference

Comparison with TropicalGemm_Cuda:

| Size | C Library (ms) | Rust CUDA (ms) | Ratio |
|------|----------------|----------------|-------|
| 256  | 0.028          | 0.032          | 1.14x |
| 512  | 0.074          | 0.086          | 1.16x |
| 1024 | 0.315          | 0.358          | 1.14x |
| 2048 | 2.224          | 2.509          | 1.13x |

The C library is ~13-16% faster due to pre-compiled PTX vs runtime compilation.

GPU Backward Pass Performance

| Size | Forward (ms) | Backward A (ms) | Backward B (ms) |
|------|--------------|-----------------|-----------------|
| 256  | 0.032        | 0.018           | 0.018           |
| 512  | 0.086        | 0.052           | 0.052           |
| 1024 | 0.358        | 0.183           | 0.184           |
| 2048 | 2.510        | 1.312           | 1.315           |

CPU Optimization

SIMD Detection

Ensure optimal SIMD is being used:

#![allow(unused)]
fn main() {
use tropical_gemm::{simd_level, SimdLevel};

match simd_level() {
    SimdLevel::Avx512 => println!("Best: AVX-512"),
    SimdLevel::Avx2 => println!("Good: AVX2"),
    SimdLevel::Sse41 => println!("Okay: SSE4.1"),
    SimdLevel::Neon => println!("ARM: NEON"),
    SimdLevel::None => println!("Slow: Portable fallback"),
}
}

Memory Layout

Row-major contiguous data is fastest:

#![allow(unused)]
fn main() {
// GOOD: Contiguous row-major
let a = Mat::<MaxPlus<f32>>::from_fn(m, k, |i, j| MaxPlus::from_scalar(data[i * k + j]));

// SLOWER: Non-contiguous requires packing overhead
let a_ref = MatRef::from_slice_strided(&data, m, k, stride);
}

Cache Efficiency

For best cache utilization:

  • Square matrices: Optimal blocking
  • Tall-skinny (M >> K): Good cache reuse for A
  • Short-wide (K >> M): May have cache pressure

GPU Optimization

Context Reuse

Critical: Reuse CudaContext to avoid repeated kernel compilation:

#![allow(unused)]
fn main() {
// GOOD: Create once, reuse many times
let ctx = CudaContext::new()?;  // ~1-2 seconds
for batch in batches {
    let c = a.matmul(&ctx, &b)?;  // Fast
}

// BAD: Creates new context each time
for batch in batches {
    let ctx = CudaContext::new()?;  // Slow!
    let c = a.matmul(&ctx, &b)?;
}
}

Batched Operations

For multiple matrix multiplications, use batched API:

#![allow(unused)]
fn main() {
// GOOD: Single kernel launch for all matrices
let c_batch = GpuMat::matmul_batched(&ctx, &a_batch, &b_batch)?;

// SLOWER: Sequential kernel launches
let c_batch: Vec<_> = a_batch.iter()
    .zip(&b_batch)
    .map(|(a, b)| a.matmul(&ctx, b))
    .collect();
}

Memory Transfer

Minimize CPU↔GPU transfers:

#![allow(unused)]
fn main() {
// GOOD: Keep data on GPU between operations
let a_gpu = GpuMat::from_matref(&ctx, &a)?;
let b_gpu = GpuMat::from_matref(&ctx, &b)?;

// Multiple operations without transfer
let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;
let d_gpu = c_gpu.matmul(&ctx, &b_gpu)?;
let e_gpu = d_gpu.matmul(&ctx, &b_gpu)?;

// Only transfer final result
let e = e_gpu.to_mat(&ctx)?;

// BAD: Transfer for each operation
for i in 0..3 {
    let a_gpu = GpuMat::from_matref(&ctx, &a)?;  // Upload
    let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;
    let c = c_gpu.to_mat(&ctx)?;  // Download
    a = c;  // Use result for next iteration
}
}

PyTorch Training

Keep Context Alive

# Create context once at module initialization
class TropicalLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Context created once
        self.ctx = tropical_gemm.CudaContext()

    def forward(self, a, b):
        # Reuse context
        return tropical_matmul_gpu(self.ctx, a, b)

Batch Your Data

# GOOD: Large batch, single kernel
output = tropical_matmul(large_batch_a, large_batch_b)

# SLOWER: Many small operations
outputs = [tropical_matmul(a, b) for a, b in zip(small_as, small_bs)]

Python Threading

GIL Release During Compute

All CPU functions release Python’s GIL during heavy computation, allowing other Python threads to run concurrently:

import threading
import tropical_gemm
import numpy as np

def background_task():
    # This can run while tropical_gemm computes
    print("Background task running")

a = np.random.randn(1000, 1000).astype(np.float32)
b = np.random.randn(1000, 1000).astype(np.float32)

# Start background thread
t = threading.Thread(target=background_task)
t.start()

# GIL is released during compute - background thread can run
c = tropical_gemm.maxplus_matmul(a, b)

t.join()

This is particularly useful in:

  • Web servers (Flask, FastAPI) handling concurrent requests
  • GUI applications that need to remain responsive
  • Async applications using concurrent.futures

Zero-Copy with 2D Functions

The *_matmul_2d functions return properly shaped 2D arrays without reshaping overhead:

# Recommended: Use 2D functions for cleaner code
c = tropical_gemm.maxplus_matmul_2d(a, b)  # shape: (m, n)

# Older pattern requiring reshape
c_flat = tropical_gemm.maxplus_matmul(a, b)  # shape: (m*n,)
c = c_flat.reshape(m, n)

Memory Considerations

Argmax Memory

With argmax tracking, memory usage increases:

| Operation     | Memory per element   |
|---------------|----------------------|
| Standard GEMM | 4 bytes (f32)        |
| With argmax   | 8 bytes (f32 + i32)  |

For large matrices, this can be significant:

  • 4096×4096 standard: 64 MB
  • 4096×4096 with argmax: 128 MB
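
These figures follow directly from the element sizes (4 bytes per f32 value plus 4 bytes per i32 argmax index):

fn main() {
    let (m, n) = (4096usize, 4096);
    let values = m * n * std::mem::size_of::<f32>();  // output values
    let argmax = m * n * std::mem::size_of::<i32>();  // argmax indices
    println!("standard:    {} MB", values / (1024 * 1024));             // 64 MB
    println!("with argmax: {} MB", (values + argmax) / (1024 * 1024));  // 128 MB
}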

GPU Memory

Check available GPU memory:

#![allow(unused)]
fn main() {
let (free, total) = cuda_mem_info()?;
println!("GPU memory: {} MB free / {} MB total",
    free / 1024 / 1024,
    total / 1024 / 1024);
}

Profiling

CPU Profiling

# Linux perf
perf record --call-graph dwarf ./target/release/benchmark
perf report

# Flamegraph
cargo install flamegraph
cargo flamegraph --bin benchmark

GPU Profiling

# NVIDIA Nsight
nsys profile ./target/release/gpu_benchmark
nsys-ui report.nsys-rep

# nvprof (older)
nvprof ./target/release/gpu_benchmark

Troubleshooting Performance

Unexpectedly Slow CPU

  1. Check SIMD level (should be AVX2 or better on modern x86)
  2. Ensure data is contiguous (avoid strided access)
  3. Check for memory pressure (matrix too large for cache)

Unexpectedly Slow GPU

  1. Verify context reuse (compilation is slow)
  2. Check transfer overhead (small matrices dominated by transfer)
  3. Ensure sufficient GPU memory (avoid swapping)
  4. Use batched API for multiple matrices

Running Benchmarks

# CPU benchmark
cargo run --release --example bench_rust -p tropical-gemm

# CUDA vs CPU benchmark
cargo run --release --example bench_cuda_vs_cpu -p tropical-gemm-cuda

# GPU backward pass benchmark
cargo run --release --example bench_backward -p tropical-gemm-cuda

Or use the Makefile:

make bench          # Run all benchmarks
make bench-cpu      # CPU only
make bench-cuda     # CUDA only

Troubleshooting

Common issues and solutions for tropical-gemm.

Installation Issues

Rust Compilation Errors

Error: “missing SIMD intrinsics”

error[E0433]: failed to resolve: use of undeclared crate or module `core_arch`

Solution: Update Rust to latest stable:

rustup update stable

Error: “target feature avx2 is not enabled”

This is expected on non-x86 platforms. The portable fallback will be used automatically.

CUDA Issues

Error: “CUDA driver not found”

CudaError: CUDA driver version is insufficient

Solution:

  1. Install/update NVIDIA drivers
  2. Verify with nvidia-smi
  3. Install CUDA Toolkit

Error: “nvcc not found”

CudaError: Failed to compile kernel: nvcc not found

Solution:

# Add CUDA to PATH
export PATH=/usr/local/cuda/bin:$PATH

# Verify
nvcc --version

Error: “Kernel compilation failed”

CudaError: CompilationFailed: ...

Solution:

  1. Check CUDA version compatibility (requires 11.0+)
  2. Ensure CUDA headers are installed
  3. Try reinstalling CUDA Toolkit

Python Binding Issues

Error: “module ‘tropical_gemm’ not found”

>>> import tropical_gemm
ModuleNotFoundError: No module named 'tropical_gemm'

Solution:

cd crates/tropical-gemm-python
pip install maturin
maturin develop --release

Error: “symbol not found in flat namespace” (macOS)

ImportError: dlopen(...): symbol not found in flat namespace

Solution: Rebuild with correct Python version:

# Ensure using correct Python
which python
python --version

# Rebuild
maturin develop --release

Error: “dtype mismatch”

TypeError: Expected float32 array, got float64

Solution: Explicitly cast to float32:

import numpy as np
a = a.astype(np.float32)
b = b.astype(np.float32)
c = tropical_gemm.maxplus_matmul(a, b)

Runtime Issues

Incorrect Results

Symptom: All outputs are -inf or inf

This typically means input contains NaN or inf values:

#![allow(unused)]
fn main() {
// Check for invalid values
for &x in data.iter() {
    if x.is_nan() || x.is_infinite() {
        panic!("Invalid input value: {}", x);
    }
}
}

Symptom: Results differ between CPU and GPU

Small numerical differences are expected due to floating-point associativity. For MaxPlus/MinPlus, results should be identical (only comparisons).

For MaxMul, small differences may occur:

#![allow(unused)]
fn main() {
// Allow small tolerance
let diff = (cpu_result - gpu_result).abs();
assert!(diff < 1e-5, "Results differ by {}", diff);
}

Performance Issues

Symptom: GPU slower than CPU

For small matrices, transfer overhead dominates:

#![allow(unused)]
fn main() {
// Rule of thumb: GPU beneficial for N > 256
if n < 256 {
    // Use CPU
    tropical_matmul::<MaxPlus<f32>>(&a, m, k, &b, n)
} else {
    // Use GPU
    tropical_matmul_gpu::<MaxPlus<f32>>(&a, m, k, &b, n)?
}
}

Symptom: CPU slower than expected

Check SIMD detection:

#![allow(unused)]
fn main() {
use tropical_gemm::simd_level;
println!("SIMD level: {:?}", simd_level());
// Should be Avx2 or Avx512 on modern x86
}

Memory Issues

Error: “out of memory” (GPU)

CudaError: Out of memory

Solution:

  1. Use smaller batch sizes
  2. Process matrices sequentially
  3. Free unused GPU memory
#![allow(unused)]
fn main() {
// Process in chunks
for chunk in matrices.chunks(batch_size) {
    let result = process_batch(&ctx, chunk)?;
    // Results are downloaded, GPU memory freed
}
}

Error: “allocation failed” (CPU)

Large matrices may exceed available RAM:

#![allow(unused)]
fn main() {
// Estimate memory needed
let bytes = m * n * std::mem::size_of::<f32>();
println!("Matrix requires {} MB", bytes / 1024 / 1024);
}

PyTorch Issues

Gradient Issues

Symptom: Gradients are all zeros

Check that tensors require gradients:

a = torch.randn(4, 5, requires_grad=True)  # Must be True
b = torch.randn(5, 3, requires_grad=True)

c = TropicalMaxPlusMatmul.apply(a, b)
loss = c.sum()
loss.backward()

print(a.grad)  # Should not be None

Symptom: “RuntimeError: element 0 of tensors does not require grad”

Ensure input tensors have requires_grad=True:

a = torch.tensor([[1.0, 2.0]], requires_grad=True)
# Not: a = torch.tensor([[1.0, 2.0]])  # No gradients!

Device Mismatch

Error: “Expected all tensors on same device”

# Ensure both inputs on same device
a = a.to('cuda')
b = b.to('cuda')
c = TropicalMaxPlusMatmul.apply(a, b)

Getting Help

If you encounter issues not covered here:

  1. Check GitHub issues: https://github.com/TensorBFS/tropical-gemm/issues
  2. Open a new issue with:
    • Error message
    • Rust/Python version
    • OS and hardware
    • Minimal reproduction code

Diagnostic Information

Include this in bug reports:

# Rust version
rustc --version
cargo --version

# CUDA (if applicable)
nvcc --version
nvidia-smi

# Python (if applicable)
python --version
pip show tropical_gemm

API Reference

This page provides quick reference to the main APIs.

For complete documentation, see the Rust API docs.

Crate Overview

| Crate                | Purpose                          |
|----------------------|----------------------------------|
| tropical-gemm        | CPU implementation with SIMD     |
| tropical-gemm-cuda   | GPU implementation with CUDA     |
| tropical-gemm-python | Python bindings                  |

Semiring Types

#![allow(unused)]
fn main() {
use tropical_gemm::{MaxPlus, MinPlus, MaxMul};
use tropical_gemm::types::{TropicalMaxPlus, TropicalMinPlus, TropicalMaxMul};

// Wrapper types (for storage)
let a: MaxPlus<f32> = MaxPlus::new(3.0);
let b: MinPlus<f64> = MinPlus::new(5.0);

// Marker types (for generic functions)
type S = TropicalMaxPlus<f32>;
}

Matrix Types

Mat (Owned)

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};

// Create from function
let a = Mat::<MaxPlus<f32>>::from_fn(m, k, |i, j| value);

// Create from scalar slice
let a = Mat::<MaxPlus<f32>>::from_scalar_slice(&data, m, k);

// Access
let val = a.get_value(i, j);  // Returns f32
let dim = a.dim();            // Returns (rows, cols)
}

MatRef (Borrowed)

#![allow(unused)]
fn main() {
use tropical_gemm::{MatRef, MaxPlus};

// From slice
let a = MatRef::<MaxPlus<f32>>::from_slice(&data, m, k);

// From Mat
let a_ref = a.as_ref();
}

MatMut (Mutable)

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MatMut, MaxPlus};

let (m, n) = (3, 4);
let mut c = Mat::<MaxPlus<f32>>::zeros(m, n);
let c_mut: MatMut<'_, MaxPlus<f32>> = c.as_mut();
}

Matrix Operations

High-Level API (Mat)

#![allow(unused)]
fn main() {
use tropical_gemm::{Mat, MaxPlus};

let a = Mat::<MaxPlus<f32>>::from_scalar_slice(&a_data, m, k);
let b = Mat::<MaxPlus<f32>>::from_scalar_slice(&b_data, k, n);

// Standard multiply
let c = a.matmul(&b);

// With argmax tracking
let result = a.matmul_with_argmax(&b);
let value = result.get_value(i, j);
let argmax = result.get_argmax(i, j);
}

Low-Level API (Functions)

#![allow(unused)]
fn main() {
use tropical_gemm::{tropical_matmul, tropical_matmul_with_argmax, TropicalMaxPlus};

// Standard multiply
let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, m, k, &b, n);

// With argmax
let (values, argmax) = tropical_matmul_with_argmax::<TropicalMaxPlus<f32>>(&a, m, k, &b, n);
}

GPU API

CudaContext

#![allow(unused)]
fn main() {
use tropical_gemm_cuda::CudaContext;

let ctx = CudaContext::new()?;  // Compiles kernels
}

GpuMat

#![allow(unused)]
fn main() {
use tropical_gemm_cuda::GpuMat;
use tropical_gemm::{MatRef, MaxPlus};

// Upload
let a_gpu = GpuMat::from_matref(&ctx, &a)?;

// Compute
let c_gpu = a_gpu.matmul(&ctx, &b_gpu)?;

// With argmax
let result = a_gpu.matmul_argmax(&ctx, &b_gpu)?;

// Download
let c = c_gpu.to_mat(&ctx)?;
}

Batched Operations

#![allow(unused)]
fn main() {
use tropical_gemm_cuda::GpuMat;

// Upload batch
let a_batch = GpuMat::from_mats(&ctx, &a_mats)?;
let b_batch = GpuMat::from_mats(&ctx, &b_mats)?;

// Batched multiply
let c_batch = GpuMat::matmul_batched(&ctx, &a_batch, &b_batch)?;

// Download batch
let c_mats = GpuMat::to_mats(&ctx, &c_batch)?;
}

Python API

NumPy Functions

import tropical_gemm
import numpy as np

a = np.array([[1, 2], [3, 4]], dtype=np.float32)
b = np.array([[5, 6], [7, 8]], dtype=np.float32)

# Basic operations (returns flattened 1D array)
c_flat = tropical_gemm.maxplus_matmul(a, b)
c = c_flat.reshape(a.shape[0], b.shape[1])

# 2D output (returns proper 2D array directly)
c = tropical_gemm.maxplus_matmul_2d(a, b)    # shape: (m, n)
c = tropical_gemm.minplus_matmul_2d(a, b)
c = tropical_gemm.maxmul_matmul_2d(a, b)

# With argmax
values, argmax = tropical_gemm.maxplus_matmul_with_argmax(a, b)

2D Output Functions

The *_matmul_2d variants return properly shaped 2D NumPy arrays without manual reshaping:

| Type | MaxPlus                 | MinPlus                 | MaxMul                 |
|------|-------------------------|-------------------------|------------------------|
| f32  | maxplus_matmul_2d       | minplus_matmul_2d       | maxmul_matmul_2d       |
| f64  | maxplus_matmul_2d_f64   | minplus_matmul_2d_f64   | maxmul_matmul_2d_f64   |
| i32  | maxplus_matmul_2d_i32   | minplus_matmul_2d_i32   | maxmul_matmul_2d_i32   |
| i64  | maxplus_matmul_2d_i64   | minplus_matmul_2d_i64   | maxmul_matmul_2d_i64   |

# f64 example
a = np.array([[1, 2], [3, 4]], dtype=np.float64)
b = np.array([[5, 6], [7, 8]], dtype=np.float64)
c = tropical_gemm.maxplus_matmul_2d_f64(a, b)  # shape: (2, 2)

# i32 example
a = np.array([[1, 2], [3, 4]], dtype=np.int32)
b = np.array([[5, 6], [7, 8]], dtype=np.int32)
c = tropical_gemm.maxplus_matmul_2d_i32(a, b)  # shape: (2, 2)

Backward Pass

# Gradient computation
grad_a = tropical_gemm.backward_a(grad_c, argmax, k)
grad_b = tropical_gemm.backward_b(grad_c, argmax, k)

Utility Functions

SIMD Detection

#![allow(unused)]
fn main() {
use tropical_gemm::{simd_level, SimdLevel};

match simd_level() {
    SimdLevel::Avx512 => { /* ... */ }
    SimdLevel::Avx2 => { /* ... */ }
    SimdLevel::Sse41 => { /* ... */ }
    SimdLevel::Neon => { /* ... */ }
    SimdLevel::None => { /* ... */ }
}
}

Type Aliases

For convenience:

#![allow(unused)]
fn main() {
// These are equivalent:
use tropical_gemm::MaxPlus;
use tropical_gemm::types::max_plus::MaxPlus;

// Marker types for generics:
use tropical_gemm::TropicalMaxPlus;  // = TropicalSemiringImpl<MaxPlusTag, T>
use tropical_gemm::TropicalMinPlus;
use tropical_gemm::TropicalMaxMul;
}

Changelog

All notable changes to tropical-gemm.

[0.2.0]

Added

  • 2D output functions: New *_matmul_2d variants that return properly shaped 2D arrays instead of flattened 1D output. Available for all semirings (maxplus, minplus, maxmul) and data types (f32, f64, i32, i64):
    • maxplus_matmul_2d, minplus_matmul_2d, maxmul_matmul_2d (f32)
    • maxplus_matmul_2d_f64, minplus_matmul_2d_f64, maxmul_matmul_2d_f64
    • maxplus_matmul_2d_i32, minplus_matmul_2d_i32, maxmul_matmul_2d_i32
    • maxplus_matmul_2d_i64, minplus_matmul_2d_i64, maxmul_matmul_2d_i64
  • mdBook documentation
  • Comprehensive architecture documentation
  • Performance tuning guide
  • Troubleshooting guide

Changed

  • GIL release during compute: All CPU functions now release Python’s GIL during heavy computation, allowing other Python threads to run concurrently. This improves performance in multi-threaded Python applications.

Fixed

  • Batched CPU path copies: Fixed unnecessary memory copies in batched PyTorch operations by using np.asarray() instead of np.array() for zero-copy array creation when possible.

[0.1.0] - Initial Release

Features

  • High-performance tropical matrix multiplication
  • Support for three semirings: MaxPlus, MinPlus, MaxMul
  • SIMD acceleration (AVX-512, AVX2, SSE4.1, NEON)
  • CUDA GPU acceleration
  • Argmax tracking for backpropagation
  • Python bindings with NumPy support
  • PyTorch autograd integration

Crates

  • tropical-gemm: Core CPU implementation
  • tropical-gemm-cuda: CUDA GPU backend
  • tropical-gemm-python: Python bindings

Performance

  • BLIS-style 5-loop cache blocking
  • Runtime SIMD dispatch
  • GPU speedup up to 800x for large matrices

Version History

| Version | Date | Highlights      |
|---------|------|-----------------|
| 0.1.0   | 2024 | Initial release |

Migration Guides

From NumPy Implementation

If migrating from a pure NumPy tropical matrix multiplication:

# Before (NumPy)
def maxplus_matmul_numpy(a, b):
    m, k = a.shape
    n = b.shape[1]
    c = np.full((m, n), -np.inf)
    for i in range(m):
        for j in range(n):
            for kk in range(k):
                c[i, j] = max(c[i, j], a[i, kk] + b[kk, j])
    return c

# After (tropical-gemm)
import tropical_gemm
c = tropical_gemm.maxplus_matmul(a, b)

API Changes

No breaking changes yet (this is the first release).