tropical_gemm/mat/ref_.rs
1//! Immutable matrix reference type.
2
3use std::marker::PhantomData;
4
5use crate::core::Transpose;
6use crate::simd::{tropical_gemm_dispatch, KernelDispatch};
7use crate::types::{TropicalSemiring, TropicalWithArgmax};
8
9use super::{Mat, MatWithArgmax};
10
11/// Immutable view over scalar data interpreted as a tropical matrix.
12///
13/// This is a lightweight view type that can be copied freely.
14/// It references scalar data and interprets operations using the
15/// specified semiring type.
16///
17/// ```
18/// use tropical_gemm::{MatRef, MaxPlus};
19///
20/// let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
21/// let a = MatRef::<MaxPlus<f32>>::from_slice(&data, 2, 3);
22///
23/// assert_eq!(a.nrows(), 2);
24/// assert_eq!(a.ncols(), 3);
25/// assert_eq!(a.get(0, 0), 1.0);
26/// ```
27#[derive(Debug)]
28pub struct MatRef<'a, S: TropicalSemiring> {
29 data: &'a [S::Scalar],
30 nrows: usize,
31 ncols: usize,
32 _phantom: PhantomData<S>,
33}
34
35impl<'a, S: TropicalSemiring> Copy for MatRef<'a, S> {}
36
37impl<'a, S: TropicalSemiring> Clone for MatRef<'a, S> {
38 fn clone(&self) -> Self {
39 *self
40 }
41}
42
43impl<'a, S: TropicalSemiring> MatRef<'a, S> {
44 /// Create a matrix reference from a slice of scalars.
45 ///
46 /// The data must be in column-major order with length `nrows * ncols`.
47 pub fn from_slice(data: &'a [S::Scalar], nrows: usize, ncols: usize) -> Self {
48 assert_eq!(
49 data.len(),
50 nrows * ncols,
51 "data length {} != nrows {} * ncols {}",
52 data.len(),
53 nrows,
54 ncols
55 );
56 Self {
57 data,
58 nrows,
59 ncols,
60 _phantom: PhantomData,
61 }
62 }
63
64 /// Create a matrix reference from an owned Mat.
65 ///
66 /// This extracts the scalar values from the semiring wrapper.
67 pub(crate) fn from_mat(mat: &'a Mat<S>) -> Self
68 where
69 S::Scalar: Copy,
70 {
71 // We need to get scalars from the Mat<S> which stores S values
72 // Since S wraps Scalar, we can use value() to extract
73 // But MatRef needs &[Scalar], not &[S]
74 // This is a design tension - for now we'll use unsafe transmute
75 // since S is repr(transparent) over Scalar
76 //
77 // Safety: TropicalMaxPlus<T>, TropicalMinPlus<T>, etc. are all
78 // repr(transparent) newtype wrappers over T
79 let scalar_slice = unsafe {
80 std::slice::from_raw_parts(mat.data.as_ptr() as *const S::Scalar, mat.data.len())
81 };
82 Self {
83 data: scalar_slice,
84 nrows: mat.nrows,
85 ncols: mat.ncols,
86 _phantom: PhantomData,
87 }
88 }
89
90 /// Number of rows.
91 #[inline]
92 pub fn nrows(&self) -> usize {
93 self.nrows
94 }
95
96 /// Number of columns.
97 #[inline]
98 pub fn ncols(&self) -> usize {
99 self.ncols
100 }
101
102 /// Get the underlying scalar data.
103 #[inline]
104 pub fn as_slice(&self) -> &[S::Scalar] {
105 self.data
106 }
107
108 /// Get the scalar value at position (i, j).
109 #[inline]
110 pub fn get(&self, i: usize, j: usize) -> S::Scalar
111 where
112 S::Scalar: Copy,
113 {
114 debug_assert!(
115 i < self.nrows,
116 "row index {} out of bounds {}",
117 i,
118 self.nrows
119 );
120 debug_assert!(
121 j < self.ncols,
122 "col index {} out of bounds {}",
123 j,
124 self.ncols
125 );
126 // Column-major indexing
127 self.data[j * self.nrows + i]
128 }
129
130 /// Convert to an owned matrix.
131 pub fn to_owned(&self) -> Mat<S>
132 where
133 S::Scalar: Copy,
134 {
135 Mat::from_col_major(self.data, self.nrows, self.ncols)
136 }
137}
138
139// Matrix multiplication methods
140impl<'a, S: TropicalSemiring + KernelDispatch> MatRef<'a, S> {
141 /// Perform tropical matrix multiplication: C = A ⊗ B.
142 ///
143 /// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
144 ///
145 /// # Panics
146 ///
147 /// Panics if dimensions don't match (self.ncols != b.nrows).
148 pub fn matmul(&self, b: &MatRef<S>) -> Mat<S> {
149 assert_eq!(
150 self.ncols, b.nrows,
151 "dimension mismatch: A is {}x{}, B is {}x{}",
152 self.nrows, self.ncols, b.nrows, b.ncols
153 );
154
155 let m = self.nrows;
156 let n = b.ncols;
157 let k = self.ncols;
158
159 let mut c = Mat::<S>::zeros(m, n);
160
161 // Transpose trick for column-major: C = A * B becomes C^T = B^T * A^T
162 unsafe {
163 tropical_gemm_dispatch::<S>(
164 n,
165 m,
166 k,
167 b.data.as_ptr(),
168 k,
169 Transpose::NoTrans,
170 self.data.as_ptr(),
171 m,
172 Transpose::NoTrans,
173 c.data.as_mut_ptr(),
174 m,
175 );
176 }
177
178 c
179 }
180}
181
182// Argmax methods (separate impl block for different trait bounds)
183impl<'a, S> MatRef<'a, S>
184where
185 S: TropicalWithArgmax<Index = u32> + KernelDispatch,
186{
187 /// Perform tropical matrix multiplication with argmax tracking.
188 ///
189 /// Returns both the result matrix and the argmax indices indicating
190 /// which k-index produced each optimal value.
191 ///
192 /// # Panics
193 ///
194 /// Panics if dimensions don't match (self.ncols != b.nrows).
195 pub fn matmul_argmax(&self, b: &MatRef<S>) -> MatWithArgmax<S> {
196 assert_eq!(
197 self.ncols, b.nrows,
198 "dimension mismatch: A is {}x{}, B is {}x{}",
199 self.nrows, self.ncols, b.nrows, b.ncols
200 );
201
202 let m = self.nrows;
203 let n = b.ncols;
204 let k = self.ncols;
205
206 // Transpose trick for column-major: output (n×m) row-major = (m×n) col-major
207 let mut result = crate::core::GemmWithArgmax::<S>::new(n, m);
208
209 unsafe {
210 crate::core::tropical_gemm_with_argmax_portable::<S>(
211 n,
212 m,
213 k,
214 b.data.as_ptr(),
215 k,
216 Transpose::NoTrans,
217 self.data.as_ptr(),
218 m,
219 Transpose::NoTrans,
220 &mut result,
221 );
222 }
223
224 MatWithArgmax {
225 values: Mat {
226 data: result.values,
227 nrows: m,
228 ncols: n,
229 },
230 argmax: result.argmax,
231 }
232 }
233}