tropical_gemm/mat/
owned.rs

1//! Owned matrix type.
2
3use std::ops::{Index, IndexMut};
4
5use crate::core::Transpose;
6use crate::simd::{tropical_gemm_dispatch, KernelDispatch};
7use crate::types::{TropicalSemiring, TropicalWithArgmax};
8
9use super::{MatRef, MatWithArgmax};
10
11/// Owned matrix storing semiring values.
12///
13/// The matrix stores values in column-major order (Fortran/BLAS convention).
14/// Use factory methods to create matrices:
15///
16/// ```
17/// use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
18///
19/// let zeros = Mat::<MaxPlus<f32>>::zeros(3, 4);
20/// let identity = Mat::<MaxPlus<f32>>::identity(3);
21/// let custom = Mat::<MaxPlus<f32>>::from_fn(2, 2, |i, j| {
22///     MaxPlus::<f32>::from_scalar((i + j) as f32)
23/// });
24/// ```
25#[derive(Debug, Clone)]
26pub struct Mat<S: TropicalSemiring> {
27    pub(crate) data: Vec<S>,
28    pub(crate) nrows: usize,
29    pub(crate) ncols: usize,
30}
31
32impl<S: TropicalSemiring> Mat<S> {
33    /// Create a matrix filled with tropical zeros.
34    ///
35    /// For MaxPlus, this fills with -∞.
36    /// For MinPlus, this fills with +∞.
37    pub fn zeros(nrows: usize, ncols: usize) -> Self {
38        Self {
39            data: vec![S::tropical_zero(); nrows * ncols],
40            nrows,
41            ncols,
42        }
43    }
44
45    /// Create a tropical identity matrix.
46    ///
47    /// Diagonal elements are tropical one (0 for MaxPlus/MinPlus).
48    /// Off-diagonal elements are tropical zero (-∞ for MaxPlus, +∞ for MinPlus).
49    pub fn identity(n: usize) -> Self {
50        let mut mat = Self::zeros(n, n);
51        for i in 0..n {
52            // Column-major: diagonal element (i, i) at index i + i * n
53            mat.data[i + i * n] = S::tropical_one();
54        }
55        mat
56    }
57
58    /// Create a matrix from a function.
59    ///
60    /// The function is called with (row, col) indices.
61    /// Data is stored in column-major order internally.
62    pub fn from_fn<F>(nrows: usize, ncols: usize, mut f: F) -> Self
63    where
64        F: FnMut(usize, usize) -> S,
65    {
66        // Column-major: iterate column by column
67        let data = (0..nrows * ncols)
68            .map(|idx| f(idx % nrows, idx / nrows))
69            .collect();
70        Self { data, nrows, ncols }
71    }
72
73    /// Create a matrix from column-major scalar data.
74    ///
75    /// Each scalar is wrapped in the semiring type.
76    /// Data should be in column-major order: first column, then second column, etc.
77    pub fn from_col_major(data: &[S::Scalar], nrows: usize, ncols: usize) -> Self
78    where
79        S::Scalar: Copy,
80    {
81        assert_eq!(
82            data.len(),
83            nrows * ncols,
84            "data length {} != nrows {} * ncols {}",
85            data.len(),
86            nrows,
87            ncols
88        );
89        let data = data.iter().map(|&s| S::from_scalar(s)).collect();
90        Self { data, nrows, ncols }
91    }
92
93    /// Create a matrix from row-major scalar data.
94    ///
95    /// This is a convenience method that converts row-major input to column-major storage.
96    ///
97    /// # Performance Warning
98    ///
99    /// This method performs an O(m×n) transpose operation. For performance-critical code,
100    /// provide data in column-major order and use [`from_col_major`] instead.
101    #[deprecated(since = "0.4.0", note = "use from_col_major instead for direct column-major input; this method has O(m×n) transpose overhead")]
102    pub fn from_row_major(data: &[S::Scalar], nrows: usize, ncols: usize) -> Self
103    where
104        S::Scalar: Copy,
105    {
106        assert_eq!(
107            data.len(),
108            nrows * ncols,
109            "data length {} != nrows {} * ncols {}",
110            data.len(),
111            nrows,
112            ncols
113        );
114        // Convert row-major to column-major
115        let col_major: Vec<S> = (0..nrows * ncols)
116            .map(|idx| {
117                let i = idx % nrows;
118                let j = idx / nrows;
119                S::from_scalar(data[i * ncols + j])
120            })
121            .collect();
122        Self { data: col_major, nrows, ncols }
123    }
124
125    /// Create a matrix from a vector of semiring values.
126    pub fn from_vec(data: Vec<S>, nrows: usize, ncols: usize) -> Self {
127        assert_eq!(
128            data.len(),
129            nrows * ncols,
130            "data length {} != nrows {} * ncols {}",
131            data.len(),
132            nrows,
133            ncols
134        );
135        Self { data, nrows, ncols }
136    }
137
138    /// Number of rows.
139    #[inline]
140    pub fn nrows(&self) -> usize {
141        self.nrows
142    }
143
144    /// Number of columns.
145    #[inline]
146    pub fn ncols(&self) -> usize {
147        self.ncols
148    }
149
150    /// Get the underlying data as a slice.
151    #[inline]
152    pub fn as_slice(&self) -> &[S] {
153        &self.data
154    }
155
156    /// Get the underlying data as a mutable slice.
157    #[inline]
158    pub fn as_mut_slice(&mut self) -> &mut [S] {
159        &mut self.data
160    }
161
162    /// Get the scalar value at position (i, j).
163    ///
164    /// This is a convenience method that extracts the underlying scalar
165    /// without requiring a trait import.
166    ///
167    /// # Example
168    ///
169    /// ```
170    /// use tropical_gemm::{Mat, MaxPlus};
171    ///
172    /// let m = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
173    /// assert_eq!(m.get_value(0, 0), 1.0);
174    /// assert_eq!(m.get_value(1, 1), 4.0);
175    /// ```
176    #[inline]
177    pub fn get_value(&self, i: usize, j: usize) -> S::Scalar {
178        self[(i, j)].value()
179    }
180
181    /// Convert to an immutable matrix reference.
182    ///
183    /// The returned reference views the scalar values.
184    pub fn as_ref(&self) -> MatRef<'_, S>
185    where
186        S::Scalar: Copy,
187    {
188        // Extract scalars from semiring values
189        // This requires that the data is laid out such that we can get scalars
190        // For now, we create a view that extracts values on-the-fly
191        // This is a limitation - ideally we'd have a separate scalar buffer
192        MatRef::from_mat(self)
193    }
194
195    /// Get a mutable pointer to the data.
196    #[inline]
197    pub fn as_mut_ptr(&mut self) -> *mut S {
198        self.data.as_mut_ptr()
199    }
200}
201
202impl<S: TropicalSemiring> Index<(usize, usize)> for Mat<S> {
203    type Output = S;
204
205    #[inline]
206    fn index(&self, (i, j): (usize, usize)) -> &S {
207        debug_assert!(
208            i < self.nrows,
209            "row index {} out of bounds {}",
210            i,
211            self.nrows
212        );
213        debug_assert!(
214            j < self.ncols,
215            "col index {} out of bounds {}",
216            j,
217            self.ncols
218        );
219        // Column-major indexing
220        &self.data[j * self.nrows + i]
221    }
222}
223
224impl<S: TropicalSemiring> IndexMut<(usize, usize)> for Mat<S> {
225    #[inline]
226    fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut S {
227        debug_assert!(
228            i < self.nrows,
229            "row index {} out of bounds {}",
230            i,
231            self.nrows
232        );
233        debug_assert!(
234            j < self.ncols,
235            "col index {} out of bounds {}",
236            j,
237            self.ncols
238        );
239        // Column-major indexing
240        &mut self.data[j * self.nrows + i]
241    }
242}
243
244// Matrix multiplication methods directly on Mat
245impl<S> Mat<S>
246where
247    S: TropicalSemiring + KernelDispatch,
248    S::Scalar: Copy,
249{
250    /// Perform tropical matrix multiplication: C = A ⊗ B.
251    ///
252    /// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
253    ///
254    /// # Panics
255    ///
256    /// Panics if dimensions don't match (self.ncols != b.nrows).
257    ///
258    /// # Example
259    ///
260    /// ```
261    /// use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
262    ///
263    /// let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
264    /// let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
265    ///
266    /// let c = a.matmul(&b);
267    ///
268    /// // C[0,0] = max(1+1, 2+3, 3+5) = 8
269    /// assert_eq!(c[(0, 0)].value(), 8.0);
270    /// ```
271    pub fn matmul(&self, b: &Mat<S>) -> Mat<S> {
272        assert_eq!(
273            self.ncols, b.nrows,
274            "dimension mismatch: A is {}x{}, B is {}x{}",
275            self.nrows, self.ncols, b.nrows, b.ncols
276        );
277
278        let a_ref = self.as_ref();
279        let b_ref = b.as_ref();
280
281        let m = self.nrows;
282        let n = b.ncols;
283        let k = self.ncols;
284
285        let mut c = Mat::<S>::zeros(m, n);
286
287        // The kernel uses row-major convention. For column-major data,
288        // we use the transpose trick: C = A * B becomes C^T = B^T * A^T.
289        // Column-major A (m×k) viewed as row-major is A^T (k×m) with ld=m.
290        // So we swap A and B, swap m and n, and the result is written
291        // in the correct column-major layout.
292        unsafe {
293            tropical_gemm_dispatch::<S>(
294                n,                           // rows of C^T = cols of C
295                m,                           // cols of C^T = rows of C
296                k,
297                b_ref.as_slice().as_ptr(),   // B becomes first operand (B^T)
298                k,                           // lda = nrows of B in col-major
299                Transpose::NoTrans,
300                a_ref.as_slice().as_ptr(),   // A becomes second operand (A^T)
301                m,                           // ldb = nrows of A in col-major
302                Transpose::NoTrans,
303                c.data.as_mut_ptr(),
304                m,                           // ldc = nrows of C in col-major
305            );
306        }
307
308        c
309    }
310
311    /// Perform tropical matrix multiplication with a MatRef.
312    ///
313    /// This allows mixing owned and reference matrices.
314    pub fn matmul_ref(&self, b: &MatRef<S>) -> Mat<S> {
315        assert_eq!(
316            self.ncols,
317            b.nrows(),
318            "dimension mismatch: A is {}x{}, B is {}x{}",
319            self.nrows,
320            self.ncols,
321            b.nrows(),
322            b.ncols()
323        );
324
325        let a_ref = self.as_ref();
326
327        let m = self.nrows;
328        let n = b.ncols();
329        let k = self.ncols;
330
331        let mut c = Mat::<S>::zeros(m, n);
332
333        // Transpose trick for column-major: C = A * B becomes C^T = B^T * A^T
334        unsafe {
335            tropical_gemm_dispatch::<S>(
336                n,
337                m,
338                k,
339                b.as_slice().as_ptr(),
340                k,
341                Transpose::NoTrans,
342                a_ref.as_slice().as_ptr(),
343                m,
344                Transpose::NoTrans,
345                c.data.as_mut_ptr(),
346                m,
347            );
348        }
349
350        c
351    }
352}
353
354// Argmax methods on Mat
355impl<S> Mat<S>
356where
357    S: TropicalWithArgmax<Index = u32> + KernelDispatch,
358    S::Scalar: Copy,
359{
360    /// Perform tropical matrix multiplication with argmax tracking.
361    ///
362    /// Returns both the result matrix and the argmax indices indicating
363    /// which k-index produced each optimal value.
364    ///
365    /// # Example
366    ///
367    /// ```
368    /// use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
369    ///
370    /// let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
371    /// let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
372    ///
373    /// let result = a.matmul_argmax(&b);
374    ///
375    /// assert_eq!(result.get(0, 0).value(), 8.0);
376    /// assert_eq!(result.get_argmax(0, 0), 2); // k=2 gave max
377    /// ```
378    pub fn matmul_argmax(&self, b: &Mat<S>) -> MatWithArgmax<S> {
379        assert_eq!(
380            self.ncols, b.nrows,
381            "dimension mismatch: A is {}x{}, B is {}x{}",
382            self.nrows, self.ncols, b.nrows, b.ncols
383        );
384
385        let a_ref = self.as_ref();
386        let b_ref = b.as_ref();
387
388        let m = self.nrows;
389        let n = b.ncols;
390        let k = self.ncols;
391
392        // The kernel outputs row-major. We use the transpose trick:
393        // C = A * B becomes C^T = B^T * A^T.
394        // Create result with swapped dimensions (n×m) which the kernel fills
395        // in row-major, then we interpret as (m×n) column-major.
396        let mut result = crate::core::GemmWithArgmax::<S>::new(n, m);
397
398        unsafe {
399            crate::core::tropical_gemm_with_argmax_portable::<S>(
400                n,
401                m,
402                k,
403                b_ref.as_slice().as_ptr(),
404                k,
405                Transpose::NoTrans,
406                a_ref.as_slice().as_ptr(),
407                m,
408                Transpose::NoTrans,
409                &mut result,
410            );
411        }
412
413        // The result is stored as (n×m) row-major = (m×n) column-major
414        MatWithArgmax {
415            values: Mat {
416                data: result.values,
417                nrows: m,
418                ncols: n,
419            },
420            argmax: result.argmax,
421        }
422    }
423
424    /// Batched tropical matrix multiplication with argmax tracking.
425    ///
426    /// Computes C[i] = A[i] ⊗ B[i] for each pair of matrices in the batch,
427    /// tracking which k-index produced each optimal value.
428    ///
429    /// All matrices in `a_batch` must have the same dimensions, and all
430    /// matrices in `b_batch` must have the same dimensions.
431    ///
432    /// # Panics
433    ///
434    /// Panics if:
435    /// - `a_batch` and `b_batch` have different lengths
436    /// - Matrices in `a_batch` have different dimensions
437    /// - Matrices in `b_batch` have different dimensions
438    /// - Inner dimensions don't match (A.ncols != B.nrows)
439    ///
440    /// # Example
441    ///
442    /// ```
443    /// use tropical_gemm::{Mat, MaxPlus};
444    ///
445    /// let a1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
446    /// let a2 = Mat::<MaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2);
447    /// let b1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
448    /// let b2 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
449    ///
450    /// let results = Mat::matmul_batched_with_argmax(&[a1, a2], &[b1, b2]);
451    /// assert_eq!(results.len(), 2);
452    /// ```
453    pub fn matmul_batched_with_argmax(
454        a_batch: &[Mat<S>],
455        b_batch: &[Mat<S>],
456    ) -> Vec<MatWithArgmax<S>> {
457        assert_eq!(
458            a_batch.len(),
459            b_batch.len(),
460            "batch sizes must match: {} != {}",
461            a_batch.len(),
462            b_batch.len()
463        );
464
465        if a_batch.is_empty() {
466            return Vec::new();
467        }
468
469        // Validate dimensions
470        let (m, k) = (a_batch[0].nrows, a_batch[0].ncols);
471        let n = b_batch[0].ncols;
472
473        for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
474            assert_eq!(
475                (a.nrows, a.ncols),
476                (m, k),
477                "A[{}] has dimensions {}x{}, expected {}x{}",
478                i,
479                a.nrows,
480                a.ncols,
481                m,
482                k
483            );
484            assert_eq!(
485                (b.nrows, b.ncols),
486                (k, n),
487                "B[{}] has dimensions {}x{}, expected {}x{}",
488                i,
489                b.nrows,
490                b.ncols,
491                k,
492                n
493            );
494        }
495
496        a_batch
497            .iter()
498            .zip(b_batch.iter())
499            .map(|(a, b)| a.matmul_argmax(b))
500            .collect()
501    }
502}
503
504// Batched operations on Mat
505impl<S> Mat<S>
506where
507    S: TropicalSemiring + KernelDispatch,
508    S::Scalar: Copy,
509{
510    /// Batched tropical matrix multiplication.
511    ///
512    /// Computes C[i] = A[i] ⊗ B[i] for each pair of matrices in the batch.
513    /// All matrices in `a_batch` must have the same dimensions, and all
514    /// matrices in `b_batch` must have the same dimensions.
515    ///
516    /// # Panics
517    ///
518    /// Panics if:
519    /// - `a_batch` and `b_batch` have different lengths
520    /// - Matrices in `a_batch` have different dimensions
521    /// - Matrices in `b_batch` have different dimensions
522    /// - Inner dimensions don't match (A.ncols != B.nrows)
523    ///
524    /// # Example
525    ///
526    /// ```
527    /// use tropical_gemm::{Mat, MaxPlus};
528    ///
529    /// let a1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
530    /// let a2 = Mat::<MaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2);
531    /// let b1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
532    /// let b2 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
533    ///
534    /// let results = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
535    /// assert_eq!(results.len(), 2);
536    /// ```
537    pub fn matmul_batched(a_batch: &[Mat<S>], b_batch: &[Mat<S>]) -> Vec<Mat<S>> {
538        assert_eq!(
539            a_batch.len(),
540            b_batch.len(),
541            "batch sizes must match: {} != {}",
542            a_batch.len(),
543            b_batch.len()
544        );
545
546        if a_batch.is_empty() {
547            return Vec::new();
548        }
549
550        // Validate dimensions
551        let (m, k) = (a_batch[0].nrows, a_batch[0].ncols);
552        let n = b_batch[0].ncols;
553
554        for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
555            assert_eq!(
556                (a.nrows, a.ncols),
557                (m, k),
558                "A[{}] has dimensions {}x{}, expected {}x{}",
559                i,
560                a.nrows,
561                a.ncols,
562                m,
563                k
564            );
565            assert_eq!(
566                (b.nrows, b.ncols),
567                (k, n),
568                "B[{}] has dimensions {}x{}, expected {}x{}",
569                i,
570                b.nrows,
571                b.ncols,
572                k,
573                n
574            );
575        }
576
577        a_batch
578            .iter()
579            .zip(b_batch.iter())
580            .map(|(a, b)| a.matmul(b))
581            .collect()
582    }
583}