Crate tropical_gemm

High-performance tropical matrix multiplication.

This library provides BLAS-level performance for tropical matrix multiplication across multiple semiring types.

§GPU Acceleration

For GPU-accelerated operations, add the tropical-gemm-cuda crate:

[dependencies]
tropical-gemm = "0.1"
tropical-gemm-cuda = "0.1"

Then use the GPU API:

use tropical_gemm::TropicalMaxPlus;
use tropical_gemm_cuda::{tropical_matmul_gpu, CudaContext};

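// `a` and `b` are the input slices and `m`, `k`, `n` the matrix dimensions,
// passed in the same order as in `tropical_matmul` (see Quick Start).
let c = tropical_matmul_gpu::<TropicalMaxPlus<f32>>(&a, m, k, &b, n)?;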

§Tropical Semirings

Tropical algebra replaces standard arithmetic operations:

  • Standard addition → tropical addition (typically max or min)
  • Standard multiplication → tropical multiplication (typically + or ×)
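
| Type                  | ⊕ (add)   | ⊗ (mul) | Zero    | One   | Use Case               |
|-----------------------|-----------|---------|---------|-------|------------------------|
| TropicalMaxPlus<T>    | max       | +       | -∞      | 0     | Viterbi, longest path  |
| TropicalMinPlus<T>    | min       | +       | +∞      | 0     | Shortest path          |
| TropicalMaxMul<T>     | max       | ×       | 0       | 1     | Probability (non-log)  |
| TropicalAndOr         | OR        | AND     | false   | true  | Graph reachability     |
| CountingTropical<T,C> | max+count | +,×     | (-∞,0)  | (0,1) | Path counting          |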
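
The Zero and One columns are the identity elements of ⊕ and ⊗ respectively, and Zero also annihilates under ⊗. The following is a minimal standalone sketch of that structure. The trait `Semiring` and the types `MaxPlusF64`, `MinPlusF64`, and `AndOrBool` are hypothetical names made up for illustration; they are not the crate's `TropicalSemiring` trait or its types.

```rust
// Hypothetical trait for illustration only; not the crate's `TropicalSemiring`.
trait Semiring: Copy + PartialEq + std::fmt::Debug {
    fn zero() -> Self; // identity of ⊕
    fn one() -> Self; // identity of ⊗
    fn add(self, rhs: Self) -> Self; // ⊕
    fn mul(self, rhs: Self) -> Self; // ⊗
}

#[derive(Clone, Copy, PartialEq, Debug)]
struct MaxPlusF64(f64);

impl Semiring for MaxPlusF64 {
    fn zero() -> Self { MaxPlusF64(f64::NEG_INFINITY) }
    fn one() -> Self { MaxPlusF64(0.0) }
    fn add(self, rhs: Self) -> Self { MaxPlusF64(self.0.max(rhs.0)) }
    fn mul(self, rhs: Self) -> Self { MaxPlusF64(self.0 + rhs.0) }
}

#[derive(Clone, Copy, PartialEq, Debug)]
struct MinPlusF64(f64);

impl Semiring for MinPlusF64 {
    fn zero() -> Self { MinPlusF64(f64::INFINITY) }
    fn one() -> Self { MinPlusF64(0.0) }
    fn add(self, rhs: Self) -> Self { MinPlusF64(self.0.min(rhs.0)) }
    fn mul(self, rhs: Self) -> Self { MinPlusF64(self.0 + rhs.0) }
}

#[derive(Clone, Copy, PartialEq, Debug)]
struct AndOrBool(bool);

impl Semiring for AndOrBool {
    fn zero() -> Self { AndOrBool(false) }
    fn one() -> Self { AndOrBool(true) }
    fn add(self, rhs: Self) -> Self { AndOrBool(self.0 || rhs.0) }
    fn mul(self, rhs: Self) -> Self { AndOrBool(self.0 && rhs.0) }
}

/// Checks zero ⊕ x = x, one ⊗ x = x, and zero ⊗ x = zero for a single element.
/// Intended for finite `x`; infinities of the opposite sign would produce NaN.
fn identities_hold<S: Semiring>(x: S) -> bool {
    S::zero().add(x) == x && S::one().mul(x) == x && S::zero().mul(x) == S::zero()
}
```

Calling `identities_hold(MaxPlusF64(3.5))` or `identities_hold(AndOrBool(true))` exercises the three laws for one element of each semiring.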

§Quick Start

§Function-based API

use tropical_gemm::{tropical_matmul, TropicalMaxPlus, TropicalSemiring};

// Create 2x3 and 3x2 matrices
let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

// Compute C = A ⊗ B using TropicalMaxPlus semiring
let c = tropical_matmul::<TropicalMaxPlus<f32>>(&a, 2, 3, &b, 2);

// C[i,j] = max_k(A[i,k] + B[k,j])
assert_eq!(c[0].value(), 8.0); // max(1+1, 2+3, 3+5) = 8
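
To check results by hand, the formula C[i,j] = max_k(A[i,k] + B[k,j]) can be written as a plain triple loop. This is a standalone reference sketch, not the crate's implementation; `maxplus_matmul_naive` is a hypothetical helper, and the row-major layout is an assumption inferred from the terms `1+1, 2+3, 3+5` in the comment above.

```rust
/// Naive reference for C = A ⊗ B in the max-plus semiring.
/// `a` is m×k and `b` is k×n, both assumed to be row-major.
fn maxplus_matmul_naive(a: &[f32], m: usize, k: usize, b: &[f32], n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    // Start from the tropical zero (-∞), the identity of max.
    let mut c = vec![f32::NEG_INFINITY; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                // Tropical multiplication is ordinary addition.
                let cand = a[i * k + p] + b[p * n + j];
                // Tropical addition is max.
                if cand > c[i * n + j] {
                    c[i * n + j] = cand;
                }
            }
        }
    }
    c
}
```

For the 2x3 and 3x2 inputs above, this yields `[8.0, 9.0, 11.0, 12.0]`, which agrees with the `8.0` asserted for the first element.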

§Matrix-based API (faer-style)

use tropical_gemm::{Mat, MatRef, MaxPlus, TropicalSemiring};

// Create matrix views from raw data
let a_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

let a = MatRef::<MaxPlus<f32>>::from_slice(&a_data, 2, 3);
let b = MatRef::<MaxPlus<f32>>::from_slice(&b_data, 3, 2);

// Matrix multiplication using operators
let c = &a * &b;
assert_eq!(c[(0, 0)].value(), 8.0);

// Or using methods
let c = a.matmul(&b);

// Factory methods
let zeros = Mat::<MaxPlus<f32>>::zeros(3, 3);
let identity = Mat::<MaxPlus<f32>>::identity(3);

§Argmax Tracking (Backpropagation)

For gradient routing in neural networks, you can track which k index produced each optimal value:

use tropical_gemm::{tropical_matmul_with_argmax, TropicalMaxPlus, TropicalSemiring};

let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];

let result = tropical_matmul_with_argmax::<TropicalMaxPlus<f64>>(&a, 2, 3, &b, 2);

// Get the optimal value and which k produced it
let value = result.get(0, 0).value(); // 8.0
let k_idx = result.get_argmax(0, 0);  // 2 (k=2 gave max)
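
The idea behind gradient routing can be sketched without the crate: record the winning k for each output entry, then send each upstream gradient only to the A entry that produced the maximum. The functions below are hypothetical standalone helpers, not the crate's `tropical_matmul_with_argmax` or `tropical_backward_a`. Row-major layout and tie-breaking toward the smallest k are assumptions of this sketch.

```rust
/// Naive max-plus matmul that also records, for each C[i,j], the index k
/// that achieved the maximum. Ties keep the smallest k (sketch assumption).
fn maxplus_matmul_with_argmax_naive(
    a: &[f64], m: usize, k: usize, b: &[f64], n: usize,
) -> (Vec<f64>, Vec<usize>) {
    let mut c = vec![f64::NEG_INFINITY; m * n];
    let mut argmax = vec![0usize; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                let cand = a[i * k + p] + b[p * n + j];
                if cand > c[i * n + j] {
                    c[i * n + j] = cand;
                    argmax[i * n + j] = p;
                }
            }
        }
    }
    (c, argmax)
}

/// Routes each upstream gradient grad_c[i,j] to grad_a[i, argmax[i,j]].
/// All other entries of grad_a receive no gradient.
fn backward_a_naive(
    grad_c: &[f64], argmax: &[usize], m: usize, k: usize, n: usize,
) -> Vec<f64> {
    let mut grad_a = vec![0.0; m * k];
    for i in 0..m {
        for j in 0..n {
            let p = argmax[i * n + j];
            grad_a[i * k + p] += grad_c[i * n + j];
        }
    }
    grad_a
}
```

With the inputs from the example above, every output entry is won by k = 2, so an all-ones upstream gradient accumulates entirely in the last column of A.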

§Performance

The library uses:

  • BLIS-style cache blocking for memory efficiency
  • Runtime CPU feature detection for optimal SIMD kernels
  • AVX2/AVX-512 on x86-64, NEON on ARM

To print which backend was selected at runtime:

use tropical_gemm::Backend;

println!("Using: {}", Backend::description());
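
Runtime CPU feature detection of this kind can be done with the standard library's detection macros. The sketch below only illustrates the general technique; `detect_simd_level` and the strings it returns are hypothetical, and how the crate's `simd_level` actually selects a kernel is not shown here.

```rust
/// Returns a label for the widest SIMD extension detected at runtime.
/// Falls back to "scalar" when nothing is detected or the target is
/// neither x86-64 nor AArch64.
fn detect_simd_level() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    {
        // Check the wider extension first so it takes priority.
        if std::is_x86_feature_detected!("avx512f") {
            return "avx512";
        }
        if std::is_x86_feature_detected!("avx2") {
            return "avx2";
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return "neon";
        }
    }
    "scalar"
}
```

A dispatcher would typically call such a function once and cache the result, then pick the matching microkernel.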

§BLAS-style API

For fine-grained control:

use tropical_gemm::{TropicalGemm, TropicalMaxPlus, TropicalSemiring};

let a = vec![1.0f32; 64 * 64];
let b = vec![1.0f32; 64 * 64];
let mut c = vec![TropicalMaxPlus::tropical_zero(); 64 * 64];

TropicalGemm::<TropicalMaxPlus<f32>>::new(64, 64, 64)
    .execute(&a, 64, &b, 64, &mut c, 64);
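
By analogy with BLAS, the integer following each buffer in `execute` is presumably a leading dimension (the stride between consecutive rows), though the source does not state this. Under that assumption, and assuming row-major storage, the sketch below shows how leading dimensions let a kernel operate on a sub-block of a larger buffer. `maxplus_gemm_ld` and its argument order are hypothetical and do not reflect the crate's signature.

```rust
/// Naive max-plus GEMM on row-major buffers with explicit leading dimensions.
/// Element (i, p) of A is read from a[i * lda + p]; likewise for B and C.
fn maxplus_gemm_ld(
    m: usize, n: usize, k: usize,
    a: &[f32], lda: usize,
    b: &[f32], ldb: usize,
    c: &mut [f32], ldc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            // Accumulate from the tropical zero (-∞).
            let mut acc = f32::NEG_INFINITY;
            for p in 0..k {
                acc = acc.max(a[i * lda + p] + b[p * ldb + j]);
            }
            c[i * ldc + j] = acc;
        }
    }
}
```

With `lda` equal to the number of columns the buffer is treated as a dense matrix; a larger `lda` selects the top-left m×k block of a wider matrix without copying.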

§Re-exports

pub use core::GemmWithArgmax;
pub use core::Layout;
pub use core::Transpose;
pub use mat::Mat;
pub use mat::MatMut;
pub use mat::MatRef;
pub use mat::MatWithArgmax;
pub use simd::simd_level;
pub use simd::KernelDispatch;
pub use simd::SimdLevel;
pub use types::CountingTropical;
pub use types::SimdTropical;
pub use types::TropicalAndOr;
pub use types::TropicalMaxMul;
pub use types::TropicalMaxPlus;
pub use types::TropicalMinPlus;
pub use types::TropicalScalar;
pub use types::TropicalSemiring;
pub use types::TropicalWithArgmax;

§Modules

core
Core tropical GEMM algorithms using BLIS-style blocking.
mat
Matrix types for tropical algebra.
prelude
Prelude module for convenient imports.
simd
SIMD-optimized microkernels for tropical GEMM.
types
Tropical semiring type definitions.

§Structs

TropicalGemm
Builder for configuring tropical GEMM operations.

§Enums

Backend
Available backends for tropical GEMM.

§Functions

tropical_backward_a
Compute gradient with respect to matrix A in tropical matmul.
tropical_backward_a_batched
Batched backward pass for gradient with respect to A.
tropical_backward_b
Compute gradient with respect to matrix B in tropical matmul.
tropical_backward_b_batched
Batched backward pass for gradient with respect to B.
tropical_gemm
BLAS-style GEMM interface.
tropical_matmul
Simple tropical matrix multiplication: C = A ⊗ B
tropical_matmul_batched
Batched tropical matrix multiplication: C[i] = A[i] ⊗ B[i] for i = 0..batch_size
tropical_matmul_batched_with_argmax
Batched tropical matrix multiplication with argmax tracking.
tropical_matmul_strided_batched
Strided batched GEMM: computes C[i] = A[i] ⊗ B[i] from contiguous memory.
tropical_matmul_with_argmax
Tropical matrix multiplication with argmax tracking.
version_info
Get information about the library configuration.

§Type Aliases

AndOr
Alias for TropicalAndOr.
MaxMul
Alias for TropicalMaxMul.
MaxPlus
Alias for TropicalMaxPlus.
MinPlus
Alias for TropicalMinPlus.