tropical_gemm/mat/
mut_.rs1use crate::types::TropicalSemiring;
4
5#[derive(Debug)]
10pub struct MatMut<'a, S: TropicalSemiring> {
11 data: &'a mut [S],
12 nrows: usize,
13 ncols: usize,
14}
15
16impl<'a, S: TropicalSemiring> MatMut<'a, S> {
17 pub fn from_slice(data: &'a mut [S], nrows: usize, ncols: usize) -> Self {
19 assert_eq!(
20 data.len(),
21 nrows * ncols,
22 "data length {} != nrows {} * ncols {}",
23 data.len(),
24 nrows,
25 ncols
26 );
27 Self { data, nrows, ncols }
28 }
29
30 #[inline]
32 pub fn nrows(&self) -> usize {
33 self.nrows
34 }
35
36 #[inline]
38 pub fn ncols(&self) -> usize {
39 self.ncols
40 }
41
42 #[inline]
44 pub fn as_mut_slice(&mut self) -> &mut [S] {
45 self.data
46 }
47
48 #[inline]
50 pub fn as_mut_ptr(&mut self) -> *mut S {
51 self.data.as_mut_ptr()
52 }
53
54 #[inline]
56 pub fn get(&self, i: usize, j: usize) -> &S {
57 debug_assert!(
58 i < self.nrows,
59 "row index {} out of bounds {}",
60 i,
61 self.nrows
62 );
63 debug_assert!(
64 j < self.ncols,
65 "col index {} out of bounds {}",
66 j,
67 self.ncols
68 );
69 &self.data[j * self.nrows + i]
71 }
72
73 #[inline]
75 pub fn get_mut(&mut self, i: usize, j: usize) -> &mut S {
76 debug_assert!(
77 i < self.nrows,
78 "row index {} out of bounds {}",
79 i,
80 self.nrows
81 );
82 debug_assert!(
83 j < self.ncols,
84 "col index {} out of bounds {}",
85 j,
86 self.ncols
87 );
88 &mut self.data[j * self.nrows + i]
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    #[test]
    fn test_matmut_from_slice() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let m = MatMut::from_slice(&mut data, 2, 2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
    }

    #[test]
    fn test_matmut_get() {
        // Column-major storage of [[1, 2], [3, 4]].
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(4.0),
        ];
        let m = MatMut::from_slice(&mut data, 2, 2);
        assert_eq!(m.get(0, 0).0, 1.0);
        assert_eq!(m.get(0, 1).0, 2.0);
        assert_eq!(m.get(1, 0).0, 3.0);
        assert_eq!(m.get(1, 1).0, 4.0);
    }

    #[test]
    fn test_matmut_get_non_square() {
        // Column-major storage of the 2x3 matrix [[1, 2, 3], [4, 5, 6]].
        // A non-square shape is needed to tell an `nrows` column stride apart
        // from an `ncols` one.
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(4.0),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(5.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(6.0),
        ];
        let m = MatMut::from_slice(&mut data, 2, 3);
        assert_eq!(m.get(0, 2).0, 3.0);
        assert_eq!(m.get(1, 0).0, 4.0);
        assert_eq!(m.get(1, 2).0, 6.0);
    }

    #[test]
    fn test_matmut_get_mut() {
        // Column-major storage of [[1, 2], [3, 4]].
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(4.0),
        ];
        let mut m = MatMut::from_slice(&mut data, 2, 2);
        *m.get_mut(0, 0) = TropicalMaxPlus(10.0);
        *m.get_mut(0, 1) = TropicalMaxPlus(20.0);
        assert_eq!(m.get(0, 0).0, 10.0);
        assert_eq!(m.get(0, 1).0, 20.0);
        // Entries that were not written keep their values.
        assert_eq!(m.get(1, 0).0, 3.0);
        assert_eq!(m.get(1, 1).0, 4.0);
        // Element (0, 1) sits at flat index 2 in column-major order.
        assert_eq!(data[2].0, 20.0);
    }

    #[test]
    fn test_matmut_as_mut_slice() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let mut m = MatMut::from_slice(&mut data, 2, 2);
        let slice = m.as_mut_slice();
        slice[0] = TropicalMaxPlus(100.0);
        assert_eq!(data[0].0, 100.0);
    }

    #[test]
    fn test_matmut_as_mut_ptr() {
        let mut data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let expected = data.as_mut_ptr();
        let mut m = MatMut::from_slice(&mut data, 2, 2);
        let ptr = m.as_mut_ptr();
        assert!(!ptr.is_null());
        // The view must expose the caller's buffer, not a copy.
        assert_eq!(ptr, expected);
    }

    #[test]
    fn test_matmut_debug() {
        let mut data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let m = MatMut::from_slice(&mut data, 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("MatMut"));
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_matmut_size_mismatch() {
        let mut data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let _ = MatMut::from_slice(&mut data, 2, 2);
    }
}