tropical_gemm/mat/
ref_.rs

1//! Immutable matrix reference type.
2
3use std::marker::PhantomData;
4
5use crate::core::Transpose;
6use crate::simd::{tropical_gemm_dispatch, KernelDispatch};
7use crate::types::{TropicalSemiring, TropicalWithArgmax};
8
9use super::{Mat, MatWithArgmax};
10
11/// Immutable view over scalar data interpreted as a tropical matrix.
12///
13/// This is a lightweight view type that can be copied freely.
14/// It references scalar data and interprets operations using the
15/// specified semiring type.
16///
17/// ```
18/// use tropical_gemm::{MatRef, MaxPlus};
19///
20/// let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
21/// let a = MatRef::<MaxPlus<f32>>::from_slice(&data, 2, 3);
22///
23/// assert_eq!(a.nrows(), 2);
24/// assert_eq!(a.ncols(), 3);
25/// assert_eq!(a.get(0, 0), 1.0);
26/// ```
27#[derive(Debug)]
28pub struct MatRef<'a, S: TropicalSemiring> {
29    data: &'a [S::Scalar],
30    nrows: usize,
31    ncols: usize,
32    _phantom: PhantomData<S>,
33}
34
35impl<'a, S: TropicalSemiring> Copy for MatRef<'a, S> {}
36
37impl<'a, S: TropicalSemiring> Clone for MatRef<'a, S> {
38    fn clone(&self) -> Self {
39        *self
40    }
41}
42
43impl<'a, S: TropicalSemiring> MatRef<'a, S> {
44    /// Create a matrix reference from a slice of scalars.
45    ///
46    /// The data must be in column-major order with length `nrows * ncols`.
47    pub fn from_slice(data: &'a [S::Scalar], nrows: usize, ncols: usize) -> Self {
48        assert_eq!(
49            data.len(),
50            nrows * ncols,
51            "data length {} != nrows {} * ncols {}",
52            data.len(),
53            nrows,
54            ncols
55        );
56        Self {
57            data,
58            nrows,
59            ncols,
60            _phantom: PhantomData,
61        }
62    }
63
64    /// Create a matrix reference from an owned Mat.
65    ///
66    /// This extracts the scalar values from the semiring wrapper.
67    pub(crate) fn from_mat(mat: &'a Mat<S>) -> Self
68    where
69        S::Scalar: Copy,
70    {
71        // We need to get scalars from the Mat<S> which stores S values
72        // Since S wraps Scalar, we can use value() to extract
73        // But MatRef needs &[Scalar], not &[S]
74        // This is a design tension - for now we'll use unsafe transmute
75        // since S is repr(transparent) over Scalar
76        //
77        // Safety: TropicalMaxPlus<T>, TropicalMinPlus<T>, etc. are all
78        // repr(transparent) newtype wrappers over T
79        let scalar_slice = unsafe {
80            std::slice::from_raw_parts(mat.data.as_ptr() as *const S::Scalar, mat.data.len())
81        };
82        Self {
83            data: scalar_slice,
84            nrows: mat.nrows,
85            ncols: mat.ncols,
86            _phantom: PhantomData,
87        }
88    }
89
90    /// Number of rows.
91    #[inline]
92    pub fn nrows(&self) -> usize {
93        self.nrows
94    }
95
96    /// Number of columns.
97    #[inline]
98    pub fn ncols(&self) -> usize {
99        self.ncols
100    }
101
102    /// Get the underlying scalar data.
103    #[inline]
104    pub fn as_slice(&self) -> &[S::Scalar] {
105        self.data
106    }
107
108    /// Get the scalar value at position (i, j).
109    #[inline]
110    pub fn get(&self, i: usize, j: usize) -> S::Scalar
111    where
112        S::Scalar: Copy,
113    {
114        debug_assert!(
115            i < self.nrows,
116            "row index {} out of bounds {}",
117            i,
118            self.nrows
119        );
120        debug_assert!(
121            j < self.ncols,
122            "col index {} out of bounds {}",
123            j,
124            self.ncols
125        );
126        // Column-major indexing
127        self.data[j * self.nrows + i]
128    }
129
130    /// Convert to an owned matrix.
131    pub fn to_owned(&self) -> Mat<S>
132    where
133        S::Scalar: Copy,
134    {
135        Mat::from_col_major(self.data, self.nrows, self.ncols)
136    }
137}
138
139// Matrix multiplication methods
140impl<'a, S: TropicalSemiring + KernelDispatch> MatRef<'a, S> {
141    /// Perform tropical matrix multiplication: C = A ⊗ B.
142    ///
143    /// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
144    ///
145    /// # Panics
146    ///
147    /// Panics if dimensions don't match (self.ncols != b.nrows).
148    pub fn matmul(&self, b: &MatRef<S>) -> Mat<S> {
149        assert_eq!(
150            self.ncols, b.nrows,
151            "dimension mismatch: A is {}x{}, B is {}x{}",
152            self.nrows, self.ncols, b.nrows, b.ncols
153        );
154
155        let m = self.nrows;
156        let n = b.ncols;
157        let k = self.ncols;
158
159        let mut c = Mat::<S>::zeros(m, n);
160
161        // Transpose trick for column-major: C = A * B becomes C^T = B^T * A^T
162        unsafe {
163            tropical_gemm_dispatch::<S>(
164                n,
165                m,
166                k,
167                b.data.as_ptr(),
168                k,
169                Transpose::NoTrans,
170                self.data.as_ptr(),
171                m,
172                Transpose::NoTrans,
173                c.data.as_mut_ptr(),
174                m,
175            );
176        }
177
178        c
179    }
180}
181
182// Argmax methods (separate impl block for different trait bounds)
183impl<'a, S> MatRef<'a, S>
184where
185    S: TropicalWithArgmax<Index = u32> + KernelDispatch,
186{
187    /// Perform tropical matrix multiplication with argmax tracking.
188    ///
189    /// Returns both the result matrix and the argmax indices indicating
190    /// which k-index produced each optimal value.
191    ///
192    /// # Panics
193    ///
194    /// Panics if dimensions don't match (self.ncols != b.nrows).
195    pub fn matmul_argmax(&self, b: &MatRef<S>) -> MatWithArgmax<S> {
196        assert_eq!(
197            self.ncols, b.nrows,
198            "dimension mismatch: A is {}x{}, B is {}x{}",
199            self.nrows, self.ncols, b.nrows, b.ncols
200        );
201
202        let m = self.nrows;
203        let n = b.ncols;
204        let k = self.ncols;
205
206        // Transpose trick for column-major: output (n×m) row-major = (m×n) col-major
207        let mut result = crate::core::GemmWithArgmax::<S>::new(n, m);
208
209        unsafe {
210            crate::core::tropical_gemm_with_argmax_portable::<S>(
211                n,
212                m,
213                k,
214                b.data.as_ptr(),
215                k,
216                Transpose::NoTrans,
217                self.data.as_ptr(),
218                m,
219                Transpose::NoTrans,
220                &mut result,
221            );
222        }
223
224        MatWithArgmax {
225            values: Mat {
226                data: result.values,
227                nrows: m,
228                ncols: n,
229            },
230            argmax: result.argmax,
231        }
232    }
233}