tropical_gemm/mat/
mod.rs

1//! Matrix types for tropical algebra.
2//!
3//! This module provides faer-inspired matrix types:
4//! - [`Mat<S>`]: Owned matrix storing semiring values
5//! - [`MatRef<'a, S>`]: Immutable view over scalar data
6//! - [`MatMut<'a, S>`]: Mutable view over semiring data
7//!
8//! # Example
9//!
10//! ```
11//! use tropical_gemm::{Mat, MatRef, MaxPlus};
12//!
13//! // Create a view from raw data
14//! let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
15//! let a = MatRef::<MaxPlus<f32>>::from_slice(&data, 2, 3);
16//! let b = MatRef::<MaxPlus<f32>>::from_slice(&data, 3, 2);
17//!
18//! // Matrix multiplication using method
19//! let c = a.matmul(&b);
20//!
21//! // Or using operator syntax
22//! let c = &a * &b;
23//!
24//! // Factory methods
25//! let zeros = Mat::<MaxPlus<f32>>::zeros(3, 3);
26//! let identity = Mat::<MaxPlus<f32>>::identity(3);
27//! ```
28
29mod mut_;
30mod ops;
31mod owned;
32mod ref_;
33
34pub use mut_::MatMut;
35pub use owned::Mat;
36pub use ref_::MatRef;
37
38/// Result of matrix multiplication with argmax tracking.
39pub struct MatWithArgmax<S: crate::TropicalWithArgmax> {
40    /// The result matrix values.
41    pub values: Mat<S>,
42    /// The argmax indices (which k produced each C[i,j]).
43    pub argmax: Vec<u32>,
44}
45
46impl<S: crate::TropicalWithArgmax<Index = u32>> MatWithArgmax<S> {
47    /// Get the value at position (i, j).
48    pub fn get(&self, i: usize, j: usize) -> S {
49        self.values[(i, j)]
50    }
51
52    /// Get the scalar value at position (i, j).
53    ///
54    /// This is a convenience method that extracts the underlying scalar
55    /// without requiring a trait import.
56    #[inline]
57    pub fn get_value(&self, i: usize, j: usize) -> S::Scalar {
58        self.values[(i, j)].value()
59    }
60
61    /// Get the argmax index at position (i, j).
62    pub fn get_argmax(&self, i: usize, j: usize) -> u32 {
63        // Column-major indexing
64        self.argmax[j * self.values.nrows() + i]
65    }
66
67    /// Number of rows.
68    pub fn nrows(&self) -> usize {
69        self.values.nrows()
70    }
71
72    /// Number of columns.
73    pub fn ncols(&self) -> usize {
74        self.values.ncols()
75    }
76
77    /// Get the argmax indices as a slice.
78    ///
79    /// This is useful for backward pass computation.
80    #[inline]
81    pub fn argmax_slice(&self) -> &[u32] {
82        &self.argmax
83    }
84
85    /// Compute gradient with respect to matrix A.
86    ///
87    /// Given the upstream gradient dL/dC, computes dL/dA using the argmax
88    /// indices from the forward pass.
89    ///
90    /// For C = A ⊗ B where C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j]):
91    /// dL/dA[i,k] = Σ_j { dL/dC[i,j] if argmax[i,j] == k }
92    ///
93    /// # Arguments
94    ///
95    /// * `grad_c` - Gradient of the loss with respect to C, dimensions m×n
96    /// * `k` - Number of columns in A (the inner dimension)
97    ///
98    /// # Returns
99    ///
100    /// Gradient of the loss with respect to A, dimensions m×k
101    ///
102    /// # Example
103    ///
104    /// ```
105    /// use tropical_gemm::{Mat, MaxPlus, TropicalMaxPlus};
106    ///
107    /// let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
108    /// let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
109    ///
110    /// // Forward pass with argmax
111    /// let result = a.matmul_argmax(&b);
112    ///
113    /// // Backward pass: grad_c is upstream gradient (e.g., all ones)
114    /// let grad_c = Mat::<MaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
115    /// let grad_a = result.backward_a(&grad_c, 3); // k=3 (columns in A)
116    ///
117    /// assert_eq!(grad_a.nrows(), 2);
118    /// assert_eq!(grad_a.ncols(), 3);
119    /// ```
120    pub fn backward_a<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
121    where
122        G: crate::TropicalSemiring,
123        G::Scalar: Copy + Default + std::ops::AddAssign,
124    {
125        let m = self.nrows();
126        let n = self.ncols();
127        assert_eq!(grad_c.nrows(), m, "grad_c rows mismatch");
128        assert_eq!(grad_c.ncols(), n, "grad_c cols mismatch");
129
130        // Output is m×k in column-major
131        let mut grad_a_data = vec![G::Scalar::default(); m * k];
132
133        for j in 0..n {
134            for i in 0..m {
135                // Column-major indexing for argmax
136                let idx = self.argmax[j * m + i] as usize;
137                if idx < k {
138                    // Column-major indexing for grad_a: element (i, idx) at idx * m + i
139                    grad_a_data[idx * m + i] += grad_c[(i, j)].value();
140                }
141            }
142        }
143
144        Mat::from_col_major(&grad_a_data, m, k)
145    }
146
147    /// Compute gradient with respect to matrix B.
148    ///
149    /// Given the upstream gradient dL/dC, computes dL/dB using the argmax
150    /// indices from the forward pass.
151    ///
152    /// For C = A ⊗ B where C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j]):
153    /// dL/dB[k,j] = Σ_i { dL/dC[i,j] if argmax[i,j] == k }
154    ///
155    /// # Arguments
156    ///
157    /// * `grad_c` - Gradient of the loss with respect to C, dimensions m×n
158    /// * `k` - Number of rows in B (the inner dimension)
159    ///
160    /// # Returns
161    ///
162    /// Gradient of the loss with respect to B, dimensions k×n
163    ///
164    /// # Example
165    ///
166    /// ```
167    /// use tropical_gemm::{Mat, MaxPlus, TropicalMaxPlus};
168    ///
169    /// let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
170    /// let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
171    ///
172    /// // Forward pass with argmax
173    /// let result = a.matmul_argmax(&b);
174    ///
175    /// // Backward pass: grad_c is upstream gradient
176    /// let grad_c = Mat::<MaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
177    /// let grad_b = result.backward_b(&grad_c, 3); // k=3 (rows in B)
178    ///
179    /// assert_eq!(grad_b.nrows(), 3);
180    /// assert_eq!(grad_b.ncols(), 2);
181    /// ```
182    pub fn backward_b<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
183    where
184        G: crate::TropicalSemiring,
185        G::Scalar: Copy + Default + std::ops::AddAssign,
186    {
187        let m = self.nrows();
188        let n = self.ncols();
189        assert_eq!(grad_c.nrows(), m, "grad_c rows mismatch");
190        assert_eq!(grad_c.ncols(), n, "grad_c cols mismatch");
191
192        // Output is k×n in column-major
193        let mut grad_b_data = vec![G::Scalar::default(); k * n];
194
195        for j in 0..n {
196            for i in 0..m {
197                // Column-major indexing for argmax
198                let idx = self.argmax[j * m + i] as usize;
199                if idx < k {
200                    // Column-major indexing for grad_b: element (idx, j) at j * k + idx
201                    grad_b_data[j * k + idx] += grad_c[(i, j)].value();
202                }
203            }
204        }
205
206        Mat::from_col_major(&grad_b_data, k, n)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    #[test]
    fn test_mat_zeros() {
        let m = Mat::<TropicalMaxPlus<f64>>::zeros(3, 4);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        assert_eq!(m[(0, 0)].0, f64::NEG_INFINITY);
    }

    #[test]
    fn test_mat_identity() {
        let m = Mat::<TropicalMaxPlus<f64>>::identity(3);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m[(0, 0)].0, 0.0); // tropical one
        assert_eq!(m[(0, 1)].0, f64::NEG_INFINITY); // tropical zero
        assert_eq!(m[(1, 1)].0, 0.0);
        assert_eq!(m[(2, 2)].0, 0.0);
    }

    #[test]
    fn test_mat_from_fn() {
        let m =
            Mat::<TropicalMaxPlus<f64>>::from_fn(2, 3, |i, j| TropicalMaxPlus((i * 3 + j) as f64));
        assert_eq!(m[(0, 0)].0, 0.0);
        assert_eq!(m[(0, 2)].0, 2.0);
        assert_eq!(m[(1, 0)].0, 3.0);
        assert_eq!(m[(1, 2)].0, 5.0);
    }

    #[test]
    fn test_matref_from_slice() {
        // Column-major data: 2×3 matrix [[1,2,3],[4,5,6]] stored as [1,4,2,5,3,6]
        let data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let m = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 3);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 6.0);
    }

    #[test]
    fn test_matmul() {
        // Column-major data:
        // A: 2×3 matrix [[1,2,3],[4,5,6]] stored as [1,4,2,5,3,6]
        // B: 3×2 matrix [[1,2],[3,4],[5,6]] stored as [1,3,5,2,4,6]
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[(0, 0)].0, 8.0);
        // C[0,1] = max(1+2, 2+4, 3+6) = 9
        assert_eq!(c[(0, 1)].0, 9.0);
        // C[1,0] = max(4+1, 5+3, 6+5) = 11
        assert_eq!(c[(1, 0)].0, 11.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_matmul_operator() {
        // Column-major data
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = &a * &b;

        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_matmul_argmax() {
        // Column-major data
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2); // k=2 gave max
    }

    #[test]
    fn test_minplus_matmul() {
        use crate::TropicalMinPlus;

        // Column-major data
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMinPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMinPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = min(1+1, 2+3, 3+5) = 2
        assert_eq!(c[(0, 0)].0, 2.0);
        // C[1,1] = min(4+2, 5+4, 6+6) = 6
        assert_eq!(c[(1, 1)].0, 6.0);
    }

    #[test]
    fn test_mat_as_ref() {
        let m =
            Mat::<TropicalMaxPlus<f64>>::from_fn(2, 3, |i, j| TropicalMaxPlus((i * 3 + j) as f64));

        let r = m.as_ref();
        assert_eq!(r.nrows(), 2);
        assert_eq!(r.ncols(), 3);
        assert_eq!(r.get(0, 0), 0.0);
        assert_eq!(r.get(1, 2), 5.0);
    }

    #[test]
    fn test_mat_matmul_direct() {
        // Test Mat::matmul directly (no as_ref needed)
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[(0, 0)].0, 8.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_mat_matmul_argmax_direct() {
        // Test Mat::matmul_argmax directly
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2); // k=2 gave max
    }

    #[test]
    fn test_mat_get_value() {
        // Test get_value method - no trait import needed
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        assert_eq!(m.get_value(0, 0), 1.0);
        assert_eq!(m.get_value(0, 1), 2.0);
        assert_eq!(m.get_value(1, 0), 3.0);
        assert_eq!(m.get_value(1, 1), 4.0);
    }

    #[test]
    fn test_minplus_mat_matmul_direct() {
        use crate::TropicalMinPlus;

        let a = Mat::<TropicalMinPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMinPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = min(1+1, 2+3, 3+5) = 2
        assert_eq!(c[(0, 0)].0, 2.0);
        // C[1,1] = min(4+2, 5+4, 6+6) = 6
        assert_eq!(c[(1, 1)].0, 6.0);
    }

    #[test]
    fn test_mat_from_vec() {
        let data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let m = Mat::from_vec(data, 2, 2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
        assert_eq!(m[(0, 0)].0, 1.0);
        assert_eq!(m[(1, 1)].0, 4.0);
    }

    #[test]
    fn test_mat_as_slice() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let slice = m.as_slice();
        assert_eq!(slice.len(), 4);
        assert_eq!(slice[0].0, 1.0);
        assert_eq!(slice[3].0, 4.0);
    }

    #[test]
    fn test_mat_as_mut_slice() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let slice = m.as_mut_slice();
        slice[0] = TropicalMaxPlus(100.0);
        assert_eq!(m[(0, 0)].0, 100.0);
    }

    #[test]
    fn test_mat_as_mut_ptr() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let ptr = m.as_mut_ptr();
        assert!(!ptr.is_null());
    }

    #[test]
    fn test_mat_index_mut() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        m[(0, 0)] = TropicalMaxPlus(10.0);
        m[(1, 1)] = TropicalMaxPlus(40.0);
        assert_eq!(m[(0, 0)].0, 10.0);
        assert_eq!(m[(1, 1)].0, 40.0);
    }

    #[test]
    fn test_mat_matmul_ref() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        // Column-major data for B: 3×2 matrix [[1,2],[3,4],[5,6]] stored as [1,3,5,2,4,6]
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul_ref(&b);

        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[(0, 0)].0, 8.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    // The explicit `clone()` is the point of this test: it checks that the
    // view implements `Clone` as well as `Copy`.
    #[allow(clippy::clone_on_copy)]
    fn test_matref_copy_clone() {
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
        let b = a; // Copy
        let c = a.clone(); // Clone
        assert_eq!(a.get(0, 0), b.get(0, 0));
        assert_eq!(a.get(0, 0), c.get(0, 0));
    }

    #[test]
    fn test_matref_to_owned() {
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
        let owned = a.to_owned();
        assert_eq!(owned.nrows(), 2);
        assert_eq!(owned.ncols(), 2);
        assert_eq!(owned[(0, 0)].0, 1.0);
    }

    #[test]
    fn test_matref_debug() {
        let data = [1.0f64, 2.0];
        let m = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("MatRef"));
    }

    #[test]
    fn test_mat_clone() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let m2 = m.clone();
        assert_eq!(m2[(0, 0)].0, 1.0);
        assert_eq!(m2[(1, 1)].0, 4.0);
    }

    #[test]
    fn test_mat_debug() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0], 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("Mat"));
    }

    #[test]
    fn test_matwithargmax_get_value() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        // Test get_value (scalar extraction without trait import)
        assert_eq!(result.get_value(0, 0), 8.0);
        assert_eq!(result.get_value(1, 1), 12.0);
    }

    #[test]
    fn test_matwithargmax_nrows_ncols() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.nrows(), 2);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_mat_from_row_major_size_mismatch() {
        let _ = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0], 2, 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_mat_from_vec_size_mismatch() {
        let data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let _ = Mat::from_vec(data, 2, 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_matref_from_slice_size_mismatch() {
        let data = [1.0f64, 2.0];
        let _ = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matmul_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let _ = a.matmul(&b); // Should panic: A is 2x2, B is 3x2
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matref_matmul_dimension_mismatch() {
        let a_data = [1.0f64, 2.0, 3.0, 4.0];
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul(&b); // Should panic
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matmul_argmax_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let _ = a.matmul_argmax(&b); // Should panic
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matref_matmul_argmax_dimension_mismatch() {
        let a_data = [1.0f64, 2.0, 3.0, 4.0];
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul_argmax(&b); // Should panic
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_mat_matmul_ref_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul_ref(&b); // Should panic
    }

    // ========================================================================
    // Batched operation tests
    // ========================================================================

    #[test]
    fn test_mat_matmul_batched() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let a2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        let results = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
        assert_eq!(results.len(), 2);

        // C[0] = A[0] * B[0] (MaxPlus)
        // C[0,0] = max(1+1, 2+0) = 2
        assert!((results[0][(0, 0)].0 - 2.0).abs() < 1e-5);

        // C[1] = A[1] * B[1] (MaxPlus)
        // C[0,0] = max(5+1, 6+3) = 9
        assert!((results[1][(0, 0)].0 - 9.0).abs() < 1e-5);
    }

    #[test]
    fn test_mat_matmul_batched_empty() {
        let a_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];
        let b_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];

        let results = Mat::matmul_batched(&a_batch, &b_batch);
        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "batch sizes must match")]
    fn test_mat_matmul_batched_size_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        let _ = Mat::matmul_batched(&[a1], &[b1, b2]); // Should panic
    }

    #[test]
    #[should_panic(expected = "has dimensions")]
    fn test_mat_matmul_batched_dimension_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let a2 =
            Mat::<TropicalMaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 2, 3); // Different size
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        let _ = Mat::matmul_batched(&[a1, a2], &[b1, b2]); // Should panic
    }

    #[test]
    fn test_mat_matmul_batched_with_argmax() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let a2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[6.0, 5.0, 4.0, 3.0, 2.0, 1.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let results = Mat::matmul_batched_with_argmax(&[a1, a2], &[b1, b2]);
        assert_eq!(results.len(), 2);

        // C[0,0] = max(1+1, 2+3, 3+5) = 8, argmax=2
        assert!((results[0].get(0, 0).0 - 8.0).abs() < 1e-5);
        assert_eq!(results[0].get_argmax(0, 0), 2);
    }

    #[test]
    fn test_mat_matmul_batched_with_argmax_empty() {
        let a_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];
        let b_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];

        let results = Mat::matmul_batched_with_argmax(&a_batch, &b_batch);
        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "batch sizes must match")]
    fn test_mat_matmul_batched_with_argmax_size_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let _ = Mat::matmul_batched_with_argmax(&[a1], &[b1, b2]); // Should panic
    }

    // ========================================================================
    // Backward pass tests
    // ========================================================================

    #[test]
    fn test_matwithargmax_backward_a() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        // Forward pass
        let result = a.matmul_argmax(&b);

        // All argmax should be 2 (k=2 wins for all)
        assert_eq!(result.get_argmax(0, 0), 2);
        assert_eq!(result.get_argmax(0, 1), 2);
        assert_eq!(result.get_argmax(1, 0), 2);
        assert_eq!(result.get_argmax(1, 1), 2);

        // Backward pass with unit gradients
        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_a = result.backward_a(&grad_c, 3);

        // Only column 2 should have gradients
        assert_eq!(grad_a.nrows(), 2);
        assert_eq!(grad_a.ncols(), 3);
        assert_eq!(grad_a[(0, 0)].0, 0.0); // Not selected
        assert_eq!(grad_a[(0, 1)].0, 0.0); // Not selected
        assert_eq!(grad_a[(0, 2)].0, 2.0); // Selected for C[0,0] and C[0,1]
        assert_eq!(grad_a[(1, 0)].0, 0.0); // Not selected
        assert_eq!(grad_a[(1, 1)].0, 0.0); // Not selected
        assert_eq!(grad_a[(1, 2)].0, 2.0); // Selected for C[1,0] and C[1,1]
    }

    #[test]
    fn test_matwithargmax_backward_b() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        // Forward pass
        let result = a.matmul_argmax(&b);

        // Backward pass with unit gradients
        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_b = result.backward_b(&grad_c, 3);

        // Only row 2 should have gradients
        assert_eq!(grad_b.nrows(), 3);
        assert_eq!(grad_b.ncols(), 2);
        assert_eq!(grad_b[(0, 0)].0, 0.0); // Not selected
        assert_eq!(grad_b[(0, 1)].0, 0.0); // Not selected
        assert_eq!(grad_b[(1, 0)].0, 0.0); // Not selected
        assert_eq!(grad_b[(1, 1)].0, 0.0); // Not selected
        assert_eq!(grad_b[(2, 0)].0, 2.0); // Selected for C[0,0] and C[1,0]
        assert_eq!(grad_b[(2, 1)].0, 2.0); // Selected for C[0,1] and C[1,1]
    }

    #[test]
    fn test_matwithargmax_backward_varied_argmax() {
        // Design matrices where different k-indices win
        let a =
            Mat::<TropicalMaxPlus<f64>>::from_row_major(&[10.0, 1.0, 1.0, 1.0, 10.0, 1.0], 2, 3);
        let b =
            Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 1.0, 1.0, 1.0, 10.0, 10.0], 3, 2);

        let result = a.matmul_argmax(&b);

        // Check argmax patterns
        // C[0,0] = max(10+1=11, 1+1=2, 1+10=11), first wins -> k=0
        // C[1,0] = max(1+1=2, 10+1=11, 1+10=11), second wins -> k=1
        assert_eq!(result.get_argmax(0, 0), 0);
        assert_eq!(result.get_argmax(1, 0), 1);

        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_a = result.backward_a(&grad_c, 3);

        // grad_a[0,0] should get contributions from C[0,*] where argmax == 0
        // grad_a[1,1] should get contributions from C[1,*] where argmax == 1
        assert!(grad_a[(0, 0)].0 > 0.0); // k=0 selected for C[0,0] and C[0,1]
        assert!(grad_a[(1, 1)].0 > 0.0); // k=1 selected for C[1,0] and C[1,1]
    }

    #[test]
    fn test_matwithargmax_argmax_slice() {
        // Row 0 of A peaks at k=0 and row 1 at k=1, while B is constant, so
        // every winner is unambiguous and the two rows disagree. A uniform
        // argmax could not tell a column-major buffer from a row-major one.
        let a =
            Mat::<TropicalMaxPlus<f64>>::from_row_major(&[10.0, 1.0, 1.0, 1.0, 10.0, 1.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 3, 2);

        let result = a.matmul_argmax(&b);
        let argmax_slice = result.argmax_slice();

        assert_eq!(argmax_slice.len(), 4); // 2x2 output

        // C[0,j] = max(10+1, 1+1, 1+1) -> k=0; C[1,j] = max(1+1, 10+1, 1+1) -> k=1
        assert_eq!(result.get_argmax(0, 0), 0);
        assert_eq!(result.get_argmax(1, 0), 1);
        assert_eq!(result.get_argmax(0, 1), 0);
        assert_eq!(result.get_argmax(1, 1), 1);

        // The slice is column-major: element (i, j) lives at j * nrows + i.
        assert_eq!(argmax_slice[0], result.get_argmax(0, 0));
        assert_eq!(argmax_slice[1], result.get_argmax(1, 0));
        assert_eq!(argmax_slice[2], result.get_argmax(0, 1));
        assert_eq!(argmax_slice[3], result.get_argmax(1, 1));
    }
}