1mod mut_;
30mod ops;
31mod owned;
32mod ref_;
33
34pub use mut_::MatMut;
35pub use owned::Mat;
36pub use ref_::MatRef;
37
38pub struct MatWithArgmax<S: crate::TropicalWithArgmax> {
40 pub values: Mat<S>,
42 pub argmax: Vec<u32>,
44}
45
46impl<S: crate::TropicalWithArgmax<Index = u32>> MatWithArgmax<S> {
47 pub fn get(&self, i: usize, j: usize) -> S {
49 self.values[(i, j)]
50 }
51
52 #[inline]
57 pub fn get_value(&self, i: usize, j: usize) -> S::Scalar {
58 self.values[(i, j)].value()
59 }
60
61 pub fn get_argmax(&self, i: usize, j: usize) -> u32 {
63 self.argmax[j * self.values.nrows() + i]
65 }
66
67 pub fn nrows(&self) -> usize {
69 self.values.nrows()
70 }
71
72 pub fn ncols(&self) -> usize {
74 self.values.ncols()
75 }
76
77 #[inline]
81 pub fn argmax_slice(&self) -> &[u32] {
82 &self.argmax
83 }
84
85 pub fn backward_a<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
121 where
122 G: crate::TropicalSemiring,
123 G::Scalar: Copy + Default + std::ops::AddAssign,
124 {
125 let m = self.nrows();
126 let n = self.ncols();
127 assert_eq!(grad_c.nrows(), m, "grad_c rows mismatch");
128 assert_eq!(grad_c.ncols(), n, "grad_c cols mismatch");
129
130 let mut grad_a_data = vec![G::Scalar::default(); m * k];
132
133 for j in 0..n {
134 for i in 0..m {
135 let idx = self.argmax[j * m + i] as usize;
137 if idx < k {
138 grad_a_data[idx * m + i] += grad_c[(i, j)].value();
140 }
141 }
142 }
143
144 Mat::from_col_major(&grad_a_data, m, k)
145 }
146
147 pub fn backward_b<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
183 where
184 G: crate::TropicalSemiring,
185 G::Scalar: Copy + Default + std::ops::AddAssign,
186 {
187 let m = self.nrows();
188 let n = self.ncols();
189 assert_eq!(grad_c.nrows(), m, "grad_c rows mismatch");
190 assert_eq!(grad_c.ncols(), n, "grad_c cols mismatch");
191
192 let mut grad_b_data = vec![G::Scalar::default(); k * n];
194
195 for j in 0..n {
196 for i in 0..m {
197 let idx = self.argmax[j * m + i] as usize;
199 if idx < k {
200 grad_b_data[j * k + idx] += grad_c[(i, j)].value();
202 }
203 }
204 }
205
206 Mat::from_col_major(&grad_b_data, k, n)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    #[test]
    fn test_mat_zeros() {
        // Every entry of `zeros` is the max-plus additive identity, -inf.
        let m = Mat::<TropicalMaxPlus<f64>>::zeros(3, 4);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        assert_eq!(m[(0, 0)].0, f64::NEG_INFINITY);
    }

    #[test]
    fn test_mat_identity() {
        // Max-plus identity: 0 on the diagonal, -inf off it.
        let m = Mat::<TropicalMaxPlus<f64>>::identity(3);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m[(0, 0)].0, 0.0);
        assert_eq!(m[(0, 1)].0, f64::NEG_INFINITY);
        assert_eq!(m[(1, 1)].0, 0.0);
        assert_eq!(m[(2, 2)].0, 0.0);
    }

    #[test]
    fn test_mat_from_fn() {
        let m =
            Mat::<TropicalMaxPlus<f64>>::from_fn(2, 3, |i, j| TropicalMaxPlus((i * 3 + j) as f64));
        assert_eq!(m[(0, 0)].0, 0.0);
        assert_eq!(m[(0, 2)].0, 2.0);
        assert_eq!(m[(1, 0)].0, 3.0);
        assert_eq!(m[(1, 2)].0, 5.0);
    }

    #[test]
    fn test_matref_from_slice() {
        // Column-major storage of [[1, 2, 3], [4, 5, 6]].
        let data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let m = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 3);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 6.0);
    }

    #[test]
    fn test_matmul() {
        // A = [[1, 2, 3], [4, 5, 6]] (2x3), column-major.
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        // B = [[1, 2], [3, 4], [5, 6]] (3x2), column-major.
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[(0, 0)].0, 8.0);
        // C[0,1] = max(1+2, 2+4, 3+6) = 9
        assert_eq!(c[(0, 1)].0, 9.0);
        // C[1,0] = max(4+1, 5+3, 6+5) = 11
        assert_eq!(c[(1, 0)].0, 11.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_matmul_operator() {
        // Same operands as `test_matmul`, multiplied through `&a * &b`.
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = &a * &b;

        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_matmul_argmax() {
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let result = a.matmul_argmax(&b);

        // C[0,0] = 8 is attained at inner index k = 2 (3 + 5).
        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_minplus_matmul() {
        use crate::TropicalMinPlus;

        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMinPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMinPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = min(1+1, 2+3, 3+5) = 2
        assert_eq!(c[(0, 0)].0, 2.0);
        // C[1,1] = min(4+2, 5+4, 6+6) = 6
        assert_eq!(c[(1, 1)].0, 6.0);
    }

    #[test]
    fn test_mat_as_ref() {
        let m =
            Mat::<TropicalMaxPlus<f64>>::from_fn(2, 3, |i, j| TropicalMaxPlus((i * 3 + j) as f64));

        let r = m.as_ref();
        assert_eq!(r.nrows(), 2);
        assert_eq!(r.ncols(), 3);
        assert_eq!(r.get(0, 0), 0.0);
        assert_eq!(r.get(1, 2), 5.0);
    }

    #[test]
    fn test_mat_matmul_direct() {
        // A = [[1, 2, 3], [4, 5, 6]]
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        // B = [[1, 2], [3, 4], [5, 6]]
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let c = a.matmul(&b);

        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_mat_matmul_argmax_direct() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_mat_get_value() {
        // Column-major storage of [[1, 2], [3, 4]].
        let m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);

        assert_eq!(m.get_value(0, 0), 1.0);
        assert_eq!(m.get_value(0, 1), 2.0);
        assert_eq!(m.get_value(1, 0), 3.0);
        assert_eq!(m.get_value(1, 1), 4.0);
    }

    #[test]
    fn test_minplus_mat_matmul_direct() {
        use crate::TropicalMinPlus;

        let a = Mat::<TropicalMinPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b = Mat::<TropicalMinPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let c = a.matmul(&b);

        assert_eq!(c[(0, 0)].0, 2.0);
        assert_eq!(c[(1, 1)].0, 6.0);
    }

    #[test]
    fn test_mat_from_vec() {
        let data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let m = Mat::from_vec(data, 2, 2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
        assert_eq!(m[(0, 0)].0, 1.0);
        assert_eq!(m[(1, 1)].0, 4.0);
    }

    #[test]
    fn test_mat_as_slice() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let slice = m.as_slice();
        assert_eq!(slice.len(), 4);
        assert_eq!(slice[0].0, 1.0);
        assert_eq!(slice[3].0, 4.0);
    }

    #[test]
    fn test_mat_as_mut_slice() {
        // Writes through the mutable slice must be visible through indexing.
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let slice = m.as_mut_slice();
        slice[0] = TropicalMaxPlus(100.0);
        assert_eq!(m[(0, 0)].0, 100.0);
    }

    #[test]
    fn test_mat_as_mut_ptr() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let ptr = m.as_mut_ptr();
        assert!(!ptr.is_null());
    }

    #[test]
    fn test_mat_index_mut() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        m[(0, 0)] = TropicalMaxPlus(10.0);
        m[(1, 1)] = TropicalMaxPlus(40.0);
        assert_eq!(m[(0, 0)].0, 10.0);
        assert_eq!(m[(1, 1)].0, 40.0);
    }

    #[test]
    fn test_mat_matmul_ref() {
        // Owned left operand times a borrowed right operand.
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul_ref(&b);

        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    // Deliberately exercises `Clone` on a `Copy` view.
    #[allow(clippy::clone_on_copy)]
    fn test_matref_copy_clone() {
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
        // `a` stays usable after the move below, so `MatRef` is `Copy`.
        let b = a;
        let c = a.clone();
        assert_eq!(a.get(0, 0), b.get(0, 0));
        assert_eq!(a.get(0, 0), c.get(0, 0));
    }

    #[test]
    fn test_matref_to_owned() {
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
        let owned = a.to_owned();
        assert_eq!(owned.nrows(), 2);
        assert_eq!(owned.ncols(), 2);
        assert_eq!(owned[(0, 0)].0, 1.0);
    }

    #[test]
    fn test_matref_debug() {
        let data = [1.0f64, 2.0];
        let m = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("MatRef"));
    }

    #[test]
    fn test_mat_clone() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let m2 = m.clone();
        assert_eq!(m2[(0, 0)].0, 1.0);
        assert_eq!(m2[(1, 1)].0, 4.0);
    }

    #[test]
    fn test_mat_debug() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 2.0], 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("Mat"));
    }

    #[test]
    fn test_matwithargmax_get_value() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get_value(0, 0), 8.0);
        assert_eq!(result.get_value(1, 1), 12.0);
    }

    #[test]
    fn test_matwithargmax_nrows_ncols() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.nrows(), 2);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    // `from_row_major` is deprecated but still covered until it is removed.
    #[allow(deprecated)]
    fn test_mat_from_row_major_size_mismatch() {
        let _ = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0], 2, 2);
    }

    #[test]
    #[allow(deprecated)]
    fn test_from_row_major_equals_col_major_transposed() {
        // Both constructors describe [[1, 2, 3], [4, 5, 6]], each in its own layout.
        let rm = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let cm = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        assert_eq!((rm.nrows(), rm.ncols()), (2, 3));
        for i in 0..2 {
            for j in 0..3 {
                assert_eq!(rm.get_value(i, j), cm.get_value(i, j));
            }
        }
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_mat_from_vec_size_mismatch() {
        let data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let _ = Mat::from_vec(data, 2, 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_matref_from_slice_size_mismatch() {
        let data = [1.0f64, 2.0];
        let _ = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matmul_dimension_mismatch() {
        // 2x2 times 3x2: inner dimensions 2 and 3 disagree.
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);
        let _ = a.matmul(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matref_matmul_dimension_mismatch() {
        let a_data = [1.0f64, 2.0, 3.0, 4.0];
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matmul_argmax_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);
        let _ = a.matmul_argmax(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matref_matmul_argmax_dimension_mismatch() {
        let a_data = [1.0f64, 2.0, 3.0, 4.0];
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul_argmax(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_mat_matmul_ref_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul_ref(&b);
    }

    #[test]
    fn test_mat_matmul_batched() {
        // a1 = [[1, 2], [3, 4]], a2 = [[5, 6], [7, 8]]
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let a2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[5.0, 7.0, 6.0, 8.0], 2, 2);
        // b1 = [[1, 0], [0, 1]], b2 = [[1, 2], [3, 4]]
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);

        let results = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
        assert_eq!(results.len(), 2);

        // (a1 * b1)[0,0] = max(1+1, 2+0) = 2
        assert!((results[0][(0, 0)].0 - 2.0).abs() < 1e-5);

        // (a2 * b2)[0,0] = max(5+1, 6+3) = 9
        assert!((results[1][(0, 0)].0 - 9.0).abs() < 1e-5);
    }

    #[test]
    fn test_mat_matmul_batched_empty() {
        let a_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];
        let b_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];

        let results = Mat::matmul_batched(&a_batch, &b_batch);
        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "batch sizes must match")]
    fn test_mat_matmul_batched_size_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);

        let _ = Mat::matmul_batched(&[a1], &[b1, b2]);
    }

    #[test]
    #[should_panic(expected = "has dimensions")]
    fn test_mat_matmul_batched_dimension_mismatch() {
        // The second pair is 2x3 times 2x2, so its inner dimensions disagree.
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);
        let a2 =
            Mat::<TropicalMaxPlus<f32>>::from_col_major(&[5.0, 8.0, 6.0, 9.0, 7.0, 10.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 2.0, 4.0], 2, 2);

        let _ = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
    }

    #[test]
    fn test_mat_matmul_batched_with_argmax() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let a2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[6.0, 3.0, 5.0, 2.0, 4.0, 1.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let results = Mat::matmul_batched_with_argmax(&[a1, a2], &[b1, b2]);
        assert_eq!(results.len(), 2);

        // First pair matches `test_matmul_argmax`: value 8 at k = 2.
        assert!((results[0].get(0, 0).0 - 8.0).abs() < 1e-5);
        assert_eq!(results[0].get_argmax(0, 0), 2);
    }

    #[test]
    fn test_mat_matmul_batched_with_argmax_empty() {
        let a_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];
        let b_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];

        let results = Mat::matmul_batched_with_argmax(&a_batch, &b_batch);
        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "batch sizes must match")]
    fn test_mat_matmul_batched_with_argmax_size_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let _ = Mat::matmul_batched_with_argmax(&[a1], &[b1, b2]);
    }

    #[test]
    fn test_matwithargmax_backward_a() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        // Every output entry is won by inner index k = 2.
        assert_eq!(result.get_argmax(0, 0), 2);
        assert_eq!(result.get_argmax(0, 1), 2);
        assert_eq!(result.get_argmax(1, 0), 2);
        assert_eq!(result.get_argmax(1, 1), 2);

        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_a = result.backward_a(&grad_c, 3);

        // All gradient flows into column 2 of A; each row receives one unit per
        // output column (2 columns), so the total is 2.
        assert_eq!(grad_a.nrows(), 2);
        assert_eq!(grad_a.ncols(), 3);
        assert_eq!(grad_a[(0, 0)].0, 0.0);
        assert_eq!(grad_a[(0, 1)].0, 0.0);
        assert_eq!(grad_a[(0, 2)].0, 2.0);
        assert_eq!(grad_a[(1, 0)].0, 0.0);
        assert_eq!(grad_a[(1, 1)].0, 0.0);
        assert_eq!(grad_a[(1, 2)].0, 2.0);
    }

    #[test]
    fn test_matwithargmax_backward_b() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_b = result.backward_b(&grad_c, 3);

        // All gradient flows into row 2 of B; each column receives one unit per
        // output row (2 rows), so the total is 2.
        assert_eq!(grad_b.nrows(), 3);
        assert_eq!(grad_b.ncols(), 2);
        assert_eq!(grad_b[(0, 0)].0, 0.0);
        assert_eq!(grad_b[(0, 1)].0, 0.0);
        assert_eq!(grad_b[(1, 0)].0, 0.0);
        assert_eq!(grad_b[(1, 1)].0, 0.0);
        assert_eq!(grad_b[(2, 0)].0, 2.0);
        assert_eq!(grad_b[(2, 1)].0, 2.0);
    }

    #[test]
    fn test_matwithargmax_backward_varied_argmax() {
        // A = [[10, 1, 1], [1, 10, 1]]
        let a =
            Mat::<TropicalMaxPlus<f64>>::from_col_major(&[10.0, 1.0, 1.0, 10.0, 1.0, 1.0], 2, 3);
        // B = [[1, 1], [1, 1], [10, 10]]
        let b =
            Mat::<TropicalMaxPlus<f64>>::from_col_major(&[1.0, 1.0, 10.0, 1.0, 1.0, 10.0], 3, 2);

        let result = a.matmul_argmax(&b);

        // C[0,0] ties at 11 between k = 0 and k = 2, and C[1,0] ties between
        // k = 1 and k = 2. These assertions expect the lower index to win a tie.
        assert_eq!(result.get_argmax(0, 0), 0);
        assert_eq!(result.get_argmax(1, 0), 1);

        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_a = result.backward_a(&grad_c, 3);

        assert!(grad_a[(0, 0)].0 > 0.0);
        assert!(grad_a[(1, 1)].0 > 0.0);
    }

    #[test]
    fn test_matwithargmax_argmax_slice() {
        // A = [[9, 0, 0], [0, 0, 9]] and B is all zeros, so argmax(i, j) is the
        // column of the 9 in row i of A: 0 for row 0 and 2 for row 1, with no ties.
        // Rows have distinct argmaxes, so the test distinguishes column-major from
        // row-major order of the flat slice.
        let a = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[9.0, 0.0, 0.0, 0.0, 0.0, 9.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_col_major(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 3, 2);

        let result = a.matmul_argmax(&b);
        let argmax_slice = result.argmax_slice();

        assert_eq!(argmax_slice.len(), 4);
        // Column-major: (0,0), (1,0), (0,1), (1,1).
        assert_eq!(argmax_slice.to_vec(), vec![0u32, 2, 0, 2]);
        assert_eq!(argmax_slice[0], result.get_argmax(0, 0));
        assert_eq!(argmax_slice[1], result.get_argmax(1, 0));
        assert_eq!(argmax_slice[2], result.get_argmax(0, 1));
        assert_eq!(argmax_slice[3], result.get_argmax(1, 1));
    }
}