tropical_gemm/core/
argmax.rs

1use crate::types::TropicalWithArgmax;
2
3/// Result of GEMM with argmax tracking.
4///
5/// For each element C[i,j], tracks which k index produced the optimal value:
6/// C[i,j] = ⊕_{k} A[i,k] ⊗ B[k,j]
7/// argmax[i,j] = argmax_k (A[i,k] ⊗ B[k,j])
8#[derive(Debug, Clone)]
9pub struct GemmWithArgmax<T: TropicalWithArgmax<Index = u32>> {
10    /// The result matrix values.
11    pub values: Vec<T>,
12    /// The argmax indices for each element.
13    pub argmax: Vec<u32>,
14    /// Number of rows in the result.
15    pub m: usize,
16    /// Number of columns in the result.
17    pub n: usize,
18    /// Leading dimension (stride between rows).
19    pub ld: usize,
20}
21
22impl<T: TropicalWithArgmax<Index = u32>> GemmWithArgmax<T> {
23    /// Create a new result container with tropical zeros.
24    pub fn new(m: usize, n: usize) -> Self {
25        let size = m * n;
26        Self {
27            values: vec![T::tropical_zero(); size],
28            argmax: vec![0u32; size],
29            m,
30            n,
31            ld: n,
32        }
33    }
34
35    /// Create a new result container with specified leading dimension.
36    pub fn with_ld(m: usize, n: usize, ld: usize) -> Self {
37        assert!(ld >= n, "Leading dimension must be >= n");
38        let size = m * ld;
39        Self {
40            values: vec![T::tropical_zero(); size],
41            argmax: vec![0u32; size],
42            m,
43            n,
44            ld,
45        }
46    }
47
48    /// Get value at (i, j).
49    #[inline]
50    pub fn get(&self, i: usize, j: usize) -> T {
51        debug_assert!(i < self.m && j < self.n);
52        self.values[i * self.ld + j]
53    }
54
55    /// Get argmax at (i, j).
56    #[inline]
57    pub fn get_argmax(&self, i: usize, j: usize) -> u32 {
58        debug_assert!(i < self.m && j < self.n);
59        self.argmax[i * self.ld + j]
60    }
61
62    /// Get mutable reference to value at (i, j).
63    #[inline]
64    pub fn get_mut(&mut self, i: usize, j: usize) -> &mut T {
65        debug_assert!(i < self.m && j < self.n);
66        &mut self.values[i * self.ld + j]
67    }
68
69    /// Get mutable reference to argmax at (i, j).
70    #[inline]
71    pub fn get_argmax_mut(&mut self, i: usize, j: usize) -> &mut u32 {
72        debug_assert!(i < self.m && j < self.n);
73        &mut self.argmax[i * self.ld + j]
74    }
75
76    /// Get raw pointers to the data.
77    #[inline]
78    pub fn as_mut_ptrs(&mut self) -> (*mut T, *mut u32) {
79        (self.values.as_mut_ptr(), self.argmax.as_mut_ptr())
80    }
81
82    /// Get the argmax array as a slice.
83    ///
84    /// This is useful for backward pass computation.
85    #[inline]
86    pub fn argmax_slice(&self) -> &[u32] {
87        &self.argmax
88    }
89
90    /// Get the values array as a slice.
91    #[inline]
92    pub fn values_slice(&self) -> &[T] {
93        &self.values
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    /// The concrete result type exercised by every test in this module.
    type MaxPlusResult = GemmWithArgmax<TropicalMaxPlus<f64>>;

    #[test]
    fn test_gemm_with_argmax_new() {
        let result = MaxPlusResult::new(3, 4);

        // Tight layout: the stride equals the column count.
        assert_eq!((result.m, result.n, result.ld), (3, 4, 4));
        assert_eq!((result.values.len(), result.argmax.len()), (12, 12));

        // Every cell starts at the max-plus zero (-inf) with argmax 0.
        let cells = (0..3).flat_map(|i| (0..4).map(move |j| (i, j)));
        for (i, j) in cells {
            assert_eq!(result.get(i, j).0, f64::NEG_INFINITY);
            assert_eq!(result.get_argmax(i, j), 0);
        }
    }

    #[test]
    fn test_gemm_with_argmax_with_ld() {
        let result = MaxPlusResult::with_ld(3, 4, 8);

        assert_eq!((result.m, result.n, result.ld), (3, 4, 8));
        // Storage is sized by the stride, not the column count: 3 * 8 = 24.
        assert_eq!((result.values.len(), result.argmax.len()), (24, 24));
    }

    #[test]
    #[should_panic(expected = "Leading dimension must be >= n")]
    fn test_gemm_with_argmax_with_ld_invalid() {
        // ld = 2 is narrower than n = 4, so construction must be rejected.
        let _ = MaxPlusResult::with_ld(3, 4, 2);
    }

    #[test]
    fn test_gemm_with_argmax_get_mut() {
        let mut result = MaxPlusResult::new(2, 2);

        for (i, j, v) in [(0, 1, 5.0), (1, 0, 3.0)] {
            *result.get_mut(i, j) = TropicalMaxPlus(v);
        }

        assert_eq!(result.get(0, 1).0, 5.0);
        assert_eq!(result.get(1, 0).0, 3.0);
        // A cell that was never written keeps its initial value.
        assert_eq!(result.get(0, 0).0, f64::NEG_INFINITY);
    }

    #[test]
    fn test_gemm_with_argmax_get_argmax_mut() {
        let mut result = MaxPlusResult::new(2, 2);

        for (i, j, k) in [(0, 1, 42), (1, 0, 7)] {
            *result.get_argmax_mut(i, j) = k;
        }

        assert_eq!(result.get_argmax(0, 1), 42);
        assert_eq!(result.get_argmax(1, 0), 7);
        // A cell that was never written keeps its initial index.
        assert_eq!(result.get_argmax(0, 0), 0);
    }

    #[test]
    fn test_gemm_with_argmax_as_mut_ptrs() {
        let mut result = MaxPlusResult::new(2, 3);
        let (values_ptr, argmax_ptr) = result.as_mut_ptrs();

        // SAFETY: both buffers hold 2 * 3 = 6 elements, so offsets 0 and 5 are in
        // bounds, and `result` is not accessed again until these writes finish.
        unsafe {
            *values_ptr = TropicalMaxPlus(1.0);
            *values_ptr.add(5) = TropicalMaxPlus(6.0);
            *argmax_ptr = 10;
            *argmax_ptr.add(5) = 20;
        }

        // Flat offset 5 is (row 1, col 2) for a stride of 3.
        assert_eq!(result.get(0, 0).0, 1.0);
        assert_eq!(result.get(1, 2).0, 6.0);
        assert_eq!(result.get_argmax(0, 0), 10);
        assert_eq!(result.get_argmax(1, 2), 20);
    }

    #[test]
    fn test_gemm_with_argmax_slices() {
        let mut result = MaxPlusResult::new(2, 3);

        // Write the first and last logical cells through the element accessors.
        *result.get_mut(0, 0) = TropicalMaxPlus(1.0);
        *result.get_mut(1, 2) = TropicalMaxPlus(6.0);
        *result.get_argmax_mut(0, 0) = 10;
        *result.get_argmax_mut(1, 2) = 20;

        // The flat views expose the same data at offsets 0 and 5 (2 * 3 = 6 cells).
        let values = result.values_slice();
        assert_eq!(values.len(), 6);
        assert_eq!(values[0].0, 1.0);
        assert_eq!(values[5].0, 6.0);

        let argmax = result.argmax_slice();
        assert_eq!(argmax.len(), 6);
        assert_eq!(argmax[0], 10);
        assert_eq!(argmax[5], 20);
    }
}