tropical_gemm/mat/
mut_.rs

//! Mutable matrix reference type.
2
3use crate::types::TropicalSemiring;
4
5/// Mutable view over semiring data.
6///
7/// Unlike `MatRef`, this holds mutable references to semiring values,
8/// not scalars. This is used for in-place operations.
9#[derive(Debug)]
10pub struct MatMut<'a, S: TropicalSemiring> {
11    data: &'a mut [S],
12    nrows: usize,
13    ncols: usize,
14}
15
16impl<'a, S: TropicalSemiring> MatMut<'a, S> {
17    /// Create a mutable matrix reference from a slice.
18    pub fn from_slice(data: &'a mut [S], nrows: usize, ncols: usize) -> Self {
19        assert_eq!(
20            data.len(),
21            nrows * ncols,
22            "data length {} != nrows {} * ncols {}",
23            data.len(),
24            nrows,
25            ncols
26        );
27        Self { data, nrows, ncols }
28    }
29
30    /// Number of rows.
31    #[inline]
32    pub fn nrows(&self) -> usize {
33        self.nrows
34    }
35
36    /// Number of columns.
37    #[inline]
38    pub fn ncols(&self) -> usize {
39        self.ncols
40    }
41
42    /// Get the underlying data as a mutable slice.
43    #[inline]
44    pub fn as_mut_slice(&mut self) -> &mut [S] {
45        self.data
46    }
47
48    /// Get a mutable pointer to the data.
49    #[inline]
50    pub fn as_mut_ptr(&mut self) -> *mut S {
51        self.data.as_mut_ptr()
52    }
53
54    /// Get a reference to the value at position (i, j).
55    #[inline]
56    pub fn get(&self, i: usize, j: usize) -> &S {
57        debug_assert!(
58            i < self.nrows,
59            "row index {} out of bounds {}",
60            i,
61            self.nrows
62        );
63        debug_assert!(
64            j < self.ncols,
65            "col index {} out of bounds {}",
66            j,
67            self.ncols
68        );
69        // Column-major indexing
70        &self.data[j * self.nrows + i]
71    }
72
73    /// Get a mutable reference to the value at position (i, j).
74    #[inline]
75    pub fn get_mut(&mut self, i: usize, j: usize) -> &mut S {
76        debug_assert!(
77            i < self.nrows,
78            "row index {} out of bounds {}",
79            i,
80            self.nrows
81        );
82        debug_assert!(
83            j < self.ncols,
84            "col index {} out of bounds {}",
85            j,
86            self.ncols
87        );
88        // Column-major indexing
89        &mut self.data[j * self.nrows + i]
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    #[test]
    fn test_matmut_from_slice() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let m = MatMut::from_slice(&mut data, 2, 2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
    }

    #[test]
    fn test_matmut_non_square() {
        // The same 3 values viewed as a 3x1 column and as a 1x3 row; a
        // swapped nrows/ncols mapping would break one of these lookups.
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
        ];
        let m = MatMut::from_slice(&mut data, 3, 1);
        assert_eq!((m.nrows(), m.ncols()), (3, 1));
        assert_eq!(m.get(2, 0).0, 3.0);
        let m = MatMut::from_slice(&mut data, 1, 3);
        assert_eq!((m.nrows(), m.ncols()), (1, 3));
        assert_eq!(m.get(0, 2).0, 3.0);
    }

    #[test]
    fn test_matmut_get() {
        // Column-major: data stored column-by-column
        // For 2×2 matrix [[1,2],[3,4]], col-major is [1,3,2,4]
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(4.0),
        ];
        let m = MatMut::from_slice(&mut data, 2, 2);
        assert_eq!(m.get(0, 0).0, 1.0);
        assert_eq!(m.get(0, 1).0, 2.0);
        assert_eq!(m.get(1, 0).0, 3.0);
        assert_eq!(m.get(1, 1).0, 4.0);
    }

    #[test]
    fn test_matmut_get_mut() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let mut m = MatMut::from_slice(&mut data, 2, 2);
        *m.get_mut(0, 0) = TropicalMaxPlus(10.0);
        assert_eq!(m.get(0, 0).0, 10.0);
        // Off-diagonal writes must land at column-major offsets j * nrows + i.
        *m.get_mut(1, 0) = TropicalMaxPlus(30.0);
        *m.get_mut(0, 1) = TropicalMaxPlus(20.0);
        assert_eq!(data[0].0, 10.0);
        assert_eq!(data[1].0, 30.0);
        assert_eq!(data[2].0, 20.0);
        assert_eq!(data[3].0, 4.0);
    }

    #[test]
    fn test_matmut_as_mut_slice() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let mut m = MatMut::from_slice(&mut data, 2, 2);
        let slice = m.as_mut_slice();
        assert_eq!(slice.len(), 4);
        slice[0] = TropicalMaxPlus(100.0);
        assert_eq!(data[0].0, 100.0);
    }

    #[test]
    fn test_matmut_as_mut_ptr() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        // Non-null alone would also hold for a dangling pointer; check that the
        // pointer actually addresses the start of the borrowed buffer.
        let expected = data.as_ptr() as usize;
        let mut m = MatMut::from_slice(&mut data, 2, 2);
        let ptr = m.as_mut_ptr();
        assert!(!ptr.is_null());
        assert_eq!(ptr as usize, expected);
    }

    #[test]
    fn test_matmut_debug() {
        let mut data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let m = MatMut::from_slice(&mut data, 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("MatMut"));
        assert!(debug_str.contains("nrows: 1"));
        assert!(debug_str.contains("ncols: 2"));
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_matmut_size_mismatch() {
        let mut data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let _ = MatMut::from_slice(&mut data, 2, 2); // Should panic
    }
}