tropical_gemm/mat/owned.rs
1//! Owned matrix type.
2
3use std::ops::{Index, IndexMut};
4
5use crate::core::Transpose;
6use crate::simd::{tropical_gemm_dispatch, KernelDispatch};
7use crate::types::{TropicalSemiring, TropicalWithArgmax};
8
9use super::{MatRef, MatWithArgmax};
10
11/// Owned matrix storing semiring values.
12///
13/// The matrix stores values in column-major order (Fortran/BLAS convention).
14/// Use factory methods to create matrices:
15///
16/// ```
17/// use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
18///
19/// let zeros = Mat::<MaxPlus<f32>>::zeros(3, 4);
20/// let identity = Mat::<MaxPlus<f32>>::identity(3);
21/// let custom = Mat::<MaxPlus<f32>>::from_fn(2, 2, |i, j| {
22/// MaxPlus::<f32>::from_scalar((i + j) as f32)
23/// });
24/// ```
25#[derive(Debug, Clone)]
26pub struct Mat<S: TropicalSemiring> {
27 pub(crate) data: Vec<S>,
28 pub(crate) nrows: usize,
29 pub(crate) ncols: usize,
30}
31
32impl<S: TropicalSemiring> Mat<S> {
33 /// Create a matrix filled with tropical zeros.
34 ///
35 /// For MaxPlus, this fills with -∞.
36 /// For MinPlus, this fills with +∞.
37 pub fn zeros(nrows: usize, ncols: usize) -> Self {
38 Self {
39 data: vec![S::tropical_zero(); nrows * ncols],
40 nrows,
41 ncols,
42 }
43 }
44
45 /// Create a tropical identity matrix.
46 ///
47 /// Diagonal elements are tropical one (0 for MaxPlus/MinPlus).
48 /// Off-diagonal elements are tropical zero (-∞ for MaxPlus, +∞ for MinPlus).
49 pub fn identity(n: usize) -> Self {
50 let mut mat = Self::zeros(n, n);
51 for i in 0..n {
52 // Column-major: diagonal element (i, i) at index i + i * n
53 mat.data[i + i * n] = S::tropical_one();
54 }
55 mat
56 }
57
58 /// Create a matrix from a function.
59 ///
60 /// The function is called with (row, col) indices.
61 /// Data is stored in column-major order internally.
62 pub fn from_fn<F>(nrows: usize, ncols: usize, mut f: F) -> Self
63 where
64 F: FnMut(usize, usize) -> S,
65 {
66 // Column-major: iterate column by column
67 let data = (0..nrows * ncols)
68 .map(|idx| f(idx % nrows, idx / nrows))
69 .collect();
70 Self { data, nrows, ncols }
71 }
72
73 /// Create a matrix from column-major scalar data.
74 ///
75 /// Each scalar is wrapped in the semiring type.
76 /// Data should be in column-major order: first column, then second column, etc.
77 pub fn from_col_major(data: &[S::Scalar], nrows: usize, ncols: usize) -> Self
78 where
79 S::Scalar: Copy,
80 {
81 assert_eq!(
82 data.len(),
83 nrows * ncols,
84 "data length {} != nrows {} * ncols {}",
85 data.len(),
86 nrows,
87 ncols
88 );
89 let data = data.iter().map(|&s| S::from_scalar(s)).collect();
90 Self { data, nrows, ncols }
91 }
92
93 /// Create a matrix from row-major scalar data.
94 ///
95 /// This is a convenience method that converts row-major input to column-major storage.
96 ///
97 /// # Performance Warning
98 ///
99 /// This method performs an O(m×n) transpose operation. For performance-critical code,
100 /// provide data in column-major order and use [`from_col_major`] instead.
101 #[deprecated(since = "0.4.0", note = "use from_col_major instead for direct column-major input; this method has O(m×n) transpose overhead")]
102 pub fn from_row_major(data: &[S::Scalar], nrows: usize, ncols: usize) -> Self
103 where
104 S::Scalar: Copy,
105 {
106 assert_eq!(
107 data.len(),
108 nrows * ncols,
109 "data length {} != nrows {} * ncols {}",
110 data.len(),
111 nrows,
112 ncols
113 );
114 // Convert row-major to column-major
115 let col_major: Vec<S> = (0..nrows * ncols)
116 .map(|idx| {
117 let i = idx % nrows;
118 let j = idx / nrows;
119 S::from_scalar(data[i * ncols + j])
120 })
121 .collect();
122 Self { data: col_major, nrows, ncols }
123 }
124
125 /// Create a matrix from a vector of semiring values.
126 pub fn from_vec(data: Vec<S>, nrows: usize, ncols: usize) -> Self {
127 assert_eq!(
128 data.len(),
129 nrows * ncols,
130 "data length {} != nrows {} * ncols {}",
131 data.len(),
132 nrows,
133 ncols
134 );
135 Self { data, nrows, ncols }
136 }
137
138 /// Number of rows.
139 #[inline]
140 pub fn nrows(&self) -> usize {
141 self.nrows
142 }
143
144 /// Number of columns.
145 #[inline]
146 pub fn ncols(&self) -> usize {
147 self.ncols
148 }
149
150 /// Get the underlying data as a slice.
151 #[inline]
152 pub fn as_slice(&self) -> &[S] {
153 &self.data
154 }
155
156 /// Get the underlying data as a mutable slice.
157 #[inline]
158 pub fn as_mut_slice(&mut self) -> &mut [S] {
159 &mut self.data
160 }
161
162 /// Get the scalar value at position (i, j).
163 ///
164 /// This is a convenience method that extracts the underlying scalar
165 /// without requiring a trait import.
166 ///
167 /// # Example
168 ///
169 /// ```
170 /// use tropical_gemm::{Mat, MaxPlus};
171 ///
172 /// let m = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
173 /// assert_eq!(m.get_value(0, 0), 1.0);
174 /// assert_eq!(m.get_value(1, 1), 4.0);
175 /// ```
176 #[inline]
177 pub fn get_value(&self, i: usize, j: usize) -> S::Scalar {
178 self[(i, j)].value()
179 }
180
181 /// Convert to an immutable matrix reference.
182 ///
183 /// The returned reference views the scalar values.
184 pub fn as_ref(&self) -> MatRef<'_, S>
185 where
186 S::Scalar: Copy,
187 {
188 // Extract scalars from semiring values
189 // This requires that the data is laid out such that we can get scalars
190 // For now, we create a view that extracts values on-the-fly
191 // This is a limitation - ideally we'd have a separate scalar buffer
192 MatRef::from_mat(self)
193 }
194
195 /// Get a mutable pointer to the data.
196 #[inline]
197 pub fn as_mut_ptr(&mut self) -> *mut S {
198 self.data.as_mut_ptr()
199 }
200}
201
202impl<S: TropicalSemiring> Index<(usize, usize)> for Mat<S> {
203 type Output = S;
204
205 #[inline]
206 fn index(&self, (i, j): (usize, usize)) -> &S {
207 debug_assert!(
208 i < self.nrows,
209 "row index {} out of bounds {}",
210 i,
211 self.nrows
212 );
213 debug_assert!(
214 j < self.ncols,
215 "col index {} out of bounds {}",
216 j,
217 self.ncols
218 );
219 // Column-major indexing
220 &self.data[j * self.nrows + i]
221 }
222}
223
224impl<S: TropicalSemiring> IndexMut<(usize, usize)> for Mat<S> {
225 #[inline]
226 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut S {
227 debug_assert!(
228 i < self.nrows,
229 "row index {} out of bounds {}",
230 i,
231 self.nrows
232 );
233 debug_assert!(
234 j < self.ncols,
235 "col index {} out of bounds {}",
236 j,
237 self.ncols
238 );
239 // Column-major indexing
240 &mut self.data[j * self.nrows + i]
241 }
242}
243
244// Matrix multiplication methods directly on Mat
245impl<S> Mat<S>
246where
247 S: TropicalSemiring + KernelDispatch,
248 S::Scalar: Copy,
249{
250 /// Perform tropical matrix multiplication: C = A ⊗ B.
251 ///
252 /// Computes C[i,j] = ⊕_k (A[i,k] ⊗ B[k,j])
253 ///
254 /// # Panics
255 ///
256 /// Panics if dimensions don't match (self.ncols != b.nrows).
257 ///
258 /// # Example
259 ///
260 /// ```
261 /// use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
262 ///
263 /// let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
264 /// let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
265 ///
266 /// let c = a.matmul(&b);
267 ///
268 /// // C[0,0] = max(1+1, 2+3, 3+5) = 8
269 /// assert_eq!(c[(0, 0)].value(), 8.0);
270 /// ```
271 pub fn matmul(&self, b: &Mat<S>) -> Mat<S> {
272 assert_eq!(
273 self.ncols, b.nrows,
274 "dimension mismatch: A is {}x{}, B is {}x{}",
275 self.nrows, self.ncols, b.nrows, b.ncols
276 );
277
278 let a_ref = self.as_ref();
279 let b_ref = b.as_ref();
280
281 let m = self.nrows;
282 let n = b.ncols;
283 let k = self.ncols;
284
285 let mut c = Mat::<S>::zeros(m, n);
286
287 // The kernel uses row-major convention. For column-major data,
288 // we use the transpose trick: C = A * B becomes C^T = B^T * A^T.
289 // Column-major A (m×k) viewed as row-major is A^T (k×m) with ld=m.
290 // So we swap A and B, swap m and n, and the result is written
291 // in the correct column-major layout.
292 unsafe {
293 tropical_gemm_dispatch::<S>(
294 n, // rows of C^T = cols of C
295 m, // cols of C^T = rows of C
296 k,
297 b_ref.as_slice().as_ptr(), // B becomes first operand (B^T)
298 k, // lda = nrows of B in col-major
299 Transpose::NoTrans,
300 a_ref.as_slice().as_ptr(), // A becomes second operand (A^T)
301 m, // ldb = nrows of A in col-major
302 Transpose::NoTrans,
303 c.data.as_mut_ptr(),
304 m, // ldc = nrows of C in col-major
305 );
306 }
307
308 c
309 }
310
311 /// Perform tropical matrix multiplication with a MatRef.
312 ///
313 /// This allows mixing owned and reference matrices.
314 pub fn matmul_ref(&self, b: &MatRef<S>) -> Mat<S> {
315 assert_eq!(
316 self.ncols,
317 b.nrows(),
318 "dimension mismatch: A is {}x{}, B is {}x{}",
319 self.nrows,
320 self.ncols,
321 b.nrows(),
322 b.ncols()
323 );
324
325 let a_ref = self.as_ref();
326
327 let m = self.nrows;
328 let n = b.ncols();
329 let k = self.ncols;
330
331 let mut c = Mat::<S>::zeros(m, n);
332
333 // Transpose trick for column-major: C = A * B becomes C^T = B^T * A^T
334 unsafe {
335 tropical_gemm_dispatch::<S>(
336 n,
337 m,
338 k,
339 b.as_slice().as_ptr(),
340 k,
341 Transpose::NoTrans,
342 a_ref.as_slice().as_ptr(),
343 m,
344 Transpose::NoTrans,
345 c.data.as_mut_ptr(),
346 m,
347 );
348 }
349
350 c
351 }
352}
353
354// Argmax methods on Mat
355impl<S> Mat<S>
356where
357 S: TropicalWithArgmax<Index = u32> + KernelDispatch,
358 S::Scalar: Copy,
359{
360 /// Perform tropical matrix multiplication with argmax tracking.
361 ///
362 /// Returns both the result matrix and the argmax indices indicating
363 /// which k-index produced each optimal value.
364 ///
365 /// # Example
366 ///
367 /// ```
368 /// use tropical_gemm::{Mat, MaxPlus, TropicalSemiring};
369 ///
370 /// let a = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
371 /// let b = Mat::<MaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
372 ///
373 /// let result = a.matmul_argmax(&b);
374 ///
375 /// assert_eq!(result.get(0, 0).value(), 8.0);
376 /// assert_eq!(result.get_argmax(0, 0), 2); // k=2 gave max
377 /// ```
378 pub fn matmul_argmax(&self, b: &Mat<S>) -> MatWithArgmax<S> {
379 assert_eq!(
380 self.ncols, b.nrows,
381 "dimension mismatch: A is {}x{}, B is {}x{}",
382 self.nrows, self.ncols, b.nrows, b.ncols
383 );
384
385 let a_ref = self.as_ref();
386 let b_ref = b.as_ref();
387
388 let m = self.nrows;
389 let n = b.ncols;
390 let k = self.ncols;
391
392 // The kernel outputs row-major. We use the transpose trick:
393 // C = A * B becomes C^T = B^T * A^T.
394 // Create result with swapped dimensions (n×m) which the kernel fills
395 // in row-major, then we interpret as (m×n) column-major.
396 let mut result = crate::core::GemmWithArgmax::<S>::new(n, m);
397
398 unsafe {
399 crate::core::tropical_gemm_with_argmax_portable::<S>(
400 n,
401 m,
402 k,
403 b_ref.as_slice().as_ptr(),
404 k,
405 Transpose::NoTrans,
406 a_ref.as_slice().as_ptr(),
407 m,
408 Transpose::NoTrans,
409 &mut result,
410 );
411 }
412
413 // The result is stored as (n×m) row-major = (m×n) column-major
414 MatWithArgmax {
415 values: Mat {
416 data: result.values,
417 nrows: m,
418 ncols: n,
419 },
420 argmax: result.argmax,
421 }
422 }
423
424 /// Batched tropical matrix multiplication with argmax tracking.
425 ///
426 /// Computes C[i] = A[i] ⊗ B[i] for each pair of matrices in the batch,
427 /// tracking which k-index produced each optimal value.
428 ///
429 /// All matrices in `a_batch` must have the same dimensions, and all
430 /// matrices in `b_batch` must have the same dimensions.
431 ///
432 /// # Panics
433 ///
434 /// Panics if:
435 /// - `a_batch` and `b_batch` have different lengths
436 /// - Matrices in `a_batch` have different dimensions
437 /// - Matrices in `b_batch` have different dimensions
438 /// - Inner dimensions don't match (A.ncols != B.nrows)
439 ///
440 /// # Example
441 ///
442 /// ```
443 /// use tropical_gemm::{Mat, MaxPlus};
444 ///
445 /// let a1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
446 /// let a2 = Mat::<MaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2);
447 /// let b1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
448 /// let b2 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
449 ///
450 /// let results = Mat::matmul_batched_with_argmax(&[a1, a2], &[b1, b2]);
451 /// assert_eq!(results.len(), 2);
452 /// ```
453 pub fn matmul_batched_with_argmax(
454 a_batch: &[Mat<S>],
455 b_batch: &[Mat<S>],
456 ) -> Vec<MatWithArgmax<S>> {
457 assert_eq!(
458 a_batch.len(),
459 b_batch.len(),
460 "batch sizes must match: {} != {}",
461 a_batch.len(),
462 b_batch.len()
463 );
464
465 if a_batch.is_empty() {
466 return Vec::new();
467 }
468
469 // Validate dimensions
470 let (m, k) = (a_batch[0].nrows, a_batch[0].ncols);
471 let n = b_batch[0].ncols;
472
473 for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
474 assert_eq!(
475 (a.nrows, a.ncols),
476 (m, k),
477 "A[{}] has dimensions {}x{}, expected {}x{}",
478 i,
479 a.nrows,
480 a.ncols,
481 m,
482 k
483 );
484 assert_eq!(
485 (b.nrows, b.ncols),
486 (k, n),
487 "B[{}] has dimensions {}x{}, expected {}x{}",
488 i,
489 b.nrows,
490 b.ncols,
491 k,
492 n
493 );
494 }
495
496 a_batch
497 .iter()
498 .zip(b_batch.iter())
499 .map(|(a, b)| a.matmul_argmax(b))
500 .collect()
501 }
502}
503
504// Batched operations on Mat
505impl<S> Mat<S>
506where
507 S: TropicalSemiring + KernelDispatch,
508 S::Scalar: Copy,
509{
510 /// Batched tropical matrix multiplication.
511 ///
512 /// Computes C[i] = A[i] ⊗ B[i] for each pair of matrices in the batch.
513 /// All matrices in `a_batch` must have the same dimensions, and all
514 /// matrices in `b_batch` must have the same dimensions.
515 ///
516 /// # Panics
517 ///
518 /// Panics if:
519 /// - `a_batch` and `b_batch` have different lengths
520 /// - Matrices in `a_batch` have different dimensions
521 /// - Matrices in `b_batch` have different dimensions
522 /// - Inner dimensions don't match (A.ncols != B.nrows)
523 ///
524 /// # Example
525 ///
526 /// ```
527 /// use tropical_gemm::{Mat, MaxPlus};
528 ///
529 /// let a1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
530 /// let a2 = Mat::<MaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2);
531 /// let b1 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
532 /// let b2 = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
533 ///
534 /// let results = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
535 /// assert_eq!(results.len(), 2);
536 /// ```
537 pub fn matmul_batched(a_batch: &[Mat<S>], b_batch: &[Mat<S>]) -> Vec<Mat<S>> {
538 assert_eq!(
539 a_batch.len(),
540 b_batch.len(),
541 "batch sizes must match: {} != {}",
542 a_batch.len(),
543 b_batch.len()
544 );
545
546 if a_batch.is_empty() {
547 return Vec::new();
548 }
549
550 // Validate dimensions
551 let (m, k) = (a_batch[0].nrows, a_batch[0].ncols);
552 let n = b_batch[0].ncols;
553
554 for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
555 assert_eq!(
556 (a.nrows, a.ncols),
557 (m, k),
558 "A[{}] has dimensions {}x{}, expected {}x{}",
559 i,
560 a.nrows,
561 a.ncols,
562 m,
563 k
564 );
565 assert_eq!(
566 (b.nrows, b.ncols),
567 (k, n),
568 "B[{}] has dimensions {}x{}, expected {}x{}",
569 i,
570 b.nrows,
571 b.ncols,
572 k,
573 n
574 );
575 }
576
577 a_batch
578 .iter()
579 .zip(b_batch.iter())
580 .map(|(a, b)| a.matmul(b))
581 .collect()
582 }
583}