1use crate::types::TropicalWithArgmax;
2
#[derive(Debug, Clone)]
/// Output buffer for a tropical GEMM that keeps, for every element, both the
/// value and an associated `u32` argmax index.
///
/// Both buffers share one row-major layout: element `(i, j)` lives at offset
/// `i * ld + j`. When `ld > n`, columns `n..ld` of each row are padding.
///
/// The argmax entry is presumably the inner-dimension index that produced the
/// winning term; it is written by whatever kernel fills this buffer, not by
/// this type (TODO confirm against the GEMM kernel).
///
/// Constructors and accessors are provided for
/// `T: TropicalWithArgmax<Index = u32>`. The struct itself carries no bound, so
/// that the bound only has to be repeated where it is actually needed.
pub struct GemmWithArgmax<T> {
    /// Result values: `m * ld` elements when built by `new` / `with_ld`.
    pub values: Vec<T>,
    /// Argmax index for each entry of `values`, in the same layout.
    pub argmax: Vec<u32>,
    /// Number of rows.
    pub m: usize,
    /// Number of logical columns.
    pub n: usize,
    /// Leading dimension (row stride in elements); `>= n` when built by the
    /// provided constructors.
    pub ld: usize,
}
21
22impl<T: TropicalWithArgmax<Index = u32>> GemmWithArgmax<T> {
23 pub fn new(m: usize, n: usize) -> Self {
25 let size = m * n;
26 Self {
27 values: vec![T::tropical_zero(); size],
28 argmax: vec![0u32; size],
29 m,
30 n,
31 ld: n,
32 }
33 }
34
35 pub fn with_ld(m: usize, n: usize, ld: usize) -> Self {
37 assert!(ld >= n, "Leading dimension must be >= n");
38 let size = m * ld;
39 Self {
40 values: vec![T::tropical_zero(); size],
41 argmax: vec![0u32; size],
42 m,
43 n,
44 ld,
45 }
46 }
47
48 #[inline]
50 pub fn get(&self, i: usize, j: usize) -> T {
51 debug_assert!(i < self.m && j < self.n);
52 self.values[i * self.ld + j]
53 }
54
55 #[inline]
57 pub fn get_argmax(&self, i: usize, j: usize) -> u32 {
58 debug_assert!(i < self.m && j < self.n);
59 self.argmax[i * self.ld + j]
60 }
61
62 #[inline]
64 pub fn get_mut(&mut self, i: usize, j: usize) -> &mut T {
65 debug_assert!(i < self.m && j < self.n);
66 &mut self.values[i * self.ld + j]
67 }
68
69 #[inline]
71 pub fn get_argmax_mut(&mut self, i: usize, j: usize) -> &mut u32 {
72 debug_assert!(i < self.m && j < self.n);
73 &mut self.argmax[i * self.ld + j]
74 }
75
76 #[inline]
78 pub fn as_mut_ptrs(&mut self) -> (*mut T, *mut u32) {
79 (self.values.as_mut_ptr(), self.argmax.as_mut_ptr())
80 }
81
82 #[inline]
86 pub fn argmax_slice(&self) -> &[u32] {
87 &self.argmax
88 }
89
90 #[inline]
92 pub fn values_slice(&self) -> &[T] {
93 &self.values
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TropicalMaxPlus;

    #[test]
    fn test_gemm_with_argmax_new() {
        let result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(3, 4);

        assert_eq!(result.m, 3);
        assert_eq!(result.n, 4);
        assert_eq!(result.ld, 4);
        assert_eq!(result.values.len(), 12);
        assert_eq!(result.argmax.len(), 12);

        // Every cell starts at the tropical zero (-inf for max-plus) with argmax 0.
        for i in 0..3 {
            for j in 0..4 {
                assert_eq!(result.get(i, j).0, f64::NEG_INFINITY);
                assert_eq!(result.get_argmax(i, j), 0);
            }
        }
    }

    #[test]
    fn test_gemm_with_argmax_with_ld() {
        let result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::with_ld(3, 4, 8);

        assert_eq!(result.m, 3);
        assert_eq!(result.n, 4);
        assert_eq!(result.ld, 8);
        // Storage is m * ld, not m * n: padding columns are allocated too.
        assert_eq!(result.values.len(), 24);
        assert_eq!(result.argmax.len(), 24);
    }

    #[test]
    fn test_gemm_with_argmax_with_ld_accessors() {
        // With ld > n, (i, j) maps to offset i * ld + j, skipping the padding.
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::with_ld(2, 3, 5);

        *result.get_mut(1, 2) = TropicalMaxPlus(7.0);
        *result.get_argmax_mut(1, 2) = 9;

        assert_eq!(result.get(1, 2).0, 7.0);
        assert_eq!(result.get_argmax(1, 2), 9);
        assert_eq!(result.values_slice()[7].0, 7.0);
        assert_eq!(result.argmax_slice()[7], 9);

        // Padding slot (row 0, column 3) is left at its initial value.
        assert_eq!(result.values_slice()[3].0, f64::NEG_INFINITY);
        assert_eq!(result.argmax_slice()[3], 0);
    }

    #[test]
    #[should_panic(expected = "Leading dimension must be >= n")]
    fn test_gemm_with_argmax_with_ld_invalid() {
        let _: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::with_ld(3, 4, 2);
    }

    #[test]
    fn test_gemm_with_argmax_get_mut() {
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);

        *result.get_mut(0, 1) = TropicalMaxPlus(5.0);
        *result.get_mut(1, 0) = TropicalMaxPlus(3.0);

        assert_eq!(result.get(0, 1).0, 5.0);
        assert_eq!(result.get(1, 0).0, 3.0);
        assert_eq!(result.get(0, 0).0, f64::NEG_INFINITY);
    }

    #[test]
    fn test_gemm_with_argmax_get_argmax_mut() {
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 2);

        *result.get_argmax_mut(0, 1) = 42;
        *result.get_argmax_mut(1, 0) = 7;

        assert_eq!(result.get_argmax(0, 1), 42);
        assert_eq!(result.get_argmax(1, 0), 7);
        assert_eq!(result.get_argmax(0, 0), 0);
    }

    #[test]
    fn test_gemm_with_argmax_as_mut_ptrs() {
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 3);
        let (values_ptr, argmax_ptr) = result.as_mut_ptrs();

        // SAFETY: `new(2, 3)` allocates 2 * 3 = 6 elements in each buffer, so
        // offsets 0 and 5 are in bounds. `result` is not accessed (so neither
        // Vec can reallocate) until after these writes.
        unsafe {
            *values_ptr.add(0) = TropicalMaxPlus(1.0);
            *values_ptr.add(5) = TropicalMaxPlus(6.0);
            *argmax_ptr.add(0) = 10;
            *argmax_ptr.add(5) = 20;
        }

        // Offset 5 is (1, 2) because ld == n == 3.
        assert_eq!(result.get(0, 0).0, 1.0);
        assert_eq!(result.get(1, 2).0, 6.0);
        assert_eq!(result.get_argmax(0, 0), 10);
        assert_eq!(result.get_argmax(1, 2), 20);
    }

    #[test]
    fn test_gemm_with_argmax_slices() {
        let mut result: GemmWithArgmax<TropicalMaxPlus<f64>> = GemmWithArgmax::new(2, 3);

        *result.get_mut(0, 0) = TropicalMaxPlus(1.0);
        *result.get_mut(1, 2) = TropicalMaxPlus(6.0);
        *result.get_argmax_mut(0, 0) = 10;
        *result.get_argmax_mut(1, 2) = 20;

        let values = result.values_slice();
        assert_eq!(values.len(), 6);
        assert_eq!(values[0].0, 1.0);
        assert_eq!(values[5].0, 6.0);

        let argmax = result.argmax_slice();
        assert_eq!(argmax.len(), 6);
        assert_eq!(argmax[0], 10);
        assert_eq!(argmax[5], 20);
    }
}