1mod mut_;
30mod ops;
31mod owned;
32mod ref_;
33
34pub use mut_::MatMut;
35pub use owned::Mat;
36pub use ref_::MatRef;
37
/// A matrix of semiring values bundled with one `u32` index per entry.
///
/// The index buffer is flat and addressed in column-major order: the entry
/// for row `i`, column `j` lives at `j * values.nrows() + i`.
pub struct MatWithArgmax<S: crate::TropicalWithArgmax> {
    /// The matrix of semiring values.
    pub values: Mat<S>,
    /// Flat, column-major buffer holding the recorded index for each entry of `values`.
    pub argmax: Vec<u32>,
}
45
46impl<S: crate::TropicalWithArgmax<Index = u32>> MatWithArgmax<S> {
47 pub fn get(&self, i: usize, j: usize) -> S {
49 self.values[(i, j)]
50 }
51
52 #[inline]
57 pub fn get_value(&self, i: usize, j: usize) -> S::Scalar {
58 self.values[(i, j)].value()
59 }
60
61 pub fn get_argmax(&self, i: usize, j: usize) -> u32 {
63 self.argmax[j * self.values.nrows() + i]
65 }
66
67 pub fn nrows(&self) -> usize {
69 self.values.nrows()
70 }
71
72 pub fn ncols(&self) -> usize {
74 self.values.ncols()
75 }
76
77 #[inline]
81 pub fn argmax_slice(&self) -> &[u32] {
82 &self.argmax
83 }
84
85 pub fn backward_a<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
121 where
122 G: crate::TropicalSemiring,
123 G::Scalar: Copy + Default + std::ops::AddAssign,
124 {
125 let m = self.nrows();
126 let n = self.ncols();
127 assert_eq!(grad_c.nrows(), m, "grad_c rows mismatch");
128 assert_eq!(grad_c.ncols(), n, "grad_c cols mismatch");
129
130 let mut grad_a_data = vec![G::Scalar::default(); m * k];
132
133 for j in 0..n {
134 for i in 0..m {
135 let idx = self.argmax[j * m + i] as usize;
137 if idx < k {
138 grad_a_data[idx * m + i] += grad_c[(i, j)].value();
140 }
141 }
142 }
143
144 Mat::from_col_major(&grad_a_data, m, k)
145 }
146
147 pub fn backward_b<G>(&self, grad_c: &Mat<G>, k: usize) -> Mat<G>
183 where
184 G: crate::TropicalSemiring,
185 G::Scalar: Copy + Default + std::ops::AddAssign,
186 {
187 let m = self.nrows();
188 let n = self.ncols();
189 assert_eq!(grad_c.nrows(), m, "grad_c rows mismatch");
190 assert_eq!(grad_c.ncols(), n, "grad_c cols mismatch");
191
192 let mut grad_b_data = vec![G::Scalar::default(); k * n];
194
195 for j in 0..n {
196 for i in 0..m {
197 let idx = self.argmax[j * m + i] as usize;
199 if idx < k {
200 grad_b_data[j * k + idx] += grad_c[(i, j)].value();
202 }
203 }
204 }
205
206 Mat::from_col_major(&grad_b_data, k, n)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TropicalMaxPlus;

    // ---- Construction -----------------------------------------------------

    #[test]
    fn test_mat_zeros() {
        let m = Mat::<TropicalMaxPlus<f64>>::zeros(3, 4);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        // The max-plus "zero" entry is negative infinity.
        assert_eq!(m[(0, 0)].0, f64::NEG_INFINITY);
    }

    #[test]
    fn test_mat_identity() {
        let m = Mat::<TropicalMaxPlus<f64>>::identity(3);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        // Diagonal entries are 0.0, off-diagonal entries are negative infinity.
        assert_eq!(m[(0, 0)].0, 0.0);
        assert_eq!(m[(0, 1)].0, f64::NEG_INFINITY);
        assert_eq!(m[(1, 1)].0, 0.0);
        assert_eq!(m[(2, 2)].0, 0.0);
    }

    #[test]
    fn test_mat_from_fn() {
        let m =
            Mat::<TropicalMaxPlus<f64>>::from_fn(2, 3, |i, j| TropicalMaxPlus((i * 3 + j) as f64));
        assert_eq!(m[(0, 0)].0, 0.0);
        assert_eq!(m[(0, 2)].0, 2.0);
        assert_eq!(m[(1, 0)].0, 3.0);
        assert_eq!(m[(1, 2)].0, 5.0);
    }

    #[test]
    fn test_matref_from_slice() {
        let data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let m = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 3);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 6.0);
    }

    // ---- Matrix multiplication --------------------------------------------

    #[test]
    fn test_matmul() {
        // Column-major data: A = [[1, 2, 3], [4, 5, 6]].
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        // Column-major data: B = [[1, 2], [3, 4], [5, 6]].
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = max(1+1, 2+3, 3+5) = 8
        assert_eq!(c[(0, 0)].0, 8.0);
        // C[0,1] = max(1+2, 2+4, 3+6) = 9
        assert_eq!(c[(0, 1)].0, 9.0);
        // C[1,0] = max(4+1, 5+3, 6+5) = 11
        assert_eq!(c[(1, 0)].0, 11.0);
        // C[1,1] = max(4+2, 5+4, 6+6) = 12
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_matmul_operator() {
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        // `&a * &b` must agree with `a.matmul(&b)`.
        let c = &a * &b;

        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_matmul_argmax() {
        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get(0, 0).0, 8.0);
        // 8 = 3 + 5 comes from inner index 2.
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_minplus_matmul() {
        use crate::TropicalMinPlus;

        let a_data = [1.0f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];

        let a = MatRef::<TropicalMinPlus<f64>>::from_slice(&a_data, 2, 3);
        let b = MatRef::<TropicalMinPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul(&b);

        // C[0,0] = min(1+1, 2+3, 3+5) = 2
        assert_eq!(c[(0, 0)].0, 2.0);
        // C[1,1] = min(4+2, 5+4, 6+6) = 6
        assert_eq!(c[(1, 1)].0, 6.0);
    }

    #[test]
    fn test_mat_as_ref() {
        let m =
            Mat::<TropicalMaxPlus<f64>>::from_fn(2, 3, |i, j| TropicalMaxPlus((i * 3 + j) as f64));

        let r = m.as_ref();
        assert_eq!(r.nrows(), 2);
        assert_eq!(r.ncols(), 3);
        assert_eq!(r.get(0, 0), 0.0);
        assert_eq!(r.get(1, 2), 5.0);
    }

    #[test]
    fn test_mat_matmul_direct() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let c = a.matmul(&b);

        // Same matrices as `test_matmul`, built from row-major data.
        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    #[test]
    fn test_mat_matmul_argmax_direct() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get(0, 0).0, 8.0);
        assert_eq!(result.get_argmax(0, 0), 2);
    }

    #[test]
    fn test_mat_get_value() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        assert_eq!(m.get_value(0, 0), 1.0);
        assert_eq!(m.get_value(0, 1), 2.0);
        assert_eq!(m.get_value(1, 0), 3.0);
        assert_eq!(m.get_value(1, 1), 4.0);
    }

    #[test]
    fn test_minplus_mat_matmul_direct() {
        use crate::TropicalMinPlus;

        let a = Mat::<TropicalMinPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMinPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let c = a.matmul(&b);

        assert_eq!(c[(0, 0)].0, 2.0);
        assert_eq!(c[(1, 1)].0, 6.0);
    }

    // ---- Owned matrix accessors -------------------------------------------

    #[test]
    fn test_mat_from_vec() {
        let data = vec![
            TropicalMaxPlus(1.0f64),
            TropicalMaxPlus(2.0),
            TropicalMaxPlus(3.0),
            TropicalMaxPlus(4.0),
        ];
        let m = Mat::from_vec(data, 2, 2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
        assert_eq!(m[(0, 0)].0, 1.0);
        assert_eq!(m[(1, 1)].0, 4.0);
    }

    #[test]
    fn test_mat_as_slice() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let slice = m.as_slice();
        assert_eq!(slice.len(), 4);
        assert_eq!(slice[0].0, 1.0);
        assert_eq!(slice[3].0, 4.0);
    }

    #[test]
    fn test_mat_as_mut_slice() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let slice = m.as_mut_slice();
        slice[0] = TropicalMaxPlus(100.0);
        // A write through the slice is visible through indexing.
        assert_eq!(m[(0, 0)].0, 100.0);
    }

    #[test]
    fn test_mat_as_mut_ptr() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let ptr = m.as_mut_ptr();
        assert!(!ptr.is_null());
    }

    #[test]
    fn test_mat_index_mut() {
        let mut m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        m[(0, 0)] = TropicalMaxPlus(10.0);
        m[(1, 1)] = TropicalMaxPlus(40.0);
        assert_eq!(m[(0, 0)].0, 10.0);
        assert_eq!(m[(1, 1)].0, 40.0);
    }

    #[test]
    fn test_mat_matmul_ref() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        // Column-major data: B = [[1, 2], [3, 4], [5, 6]].
        let b_data = [1.0f64, 3.0, 5.0, 2.0, 4.0, 6.0];
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);

        let c = a.matmul_ref(&b);

        assert_eq!(c[(0, 0)].0, 8.0);
        assert_eq!(c[(1, 1)].0, 12.0);
    }

    // ---- Trait implementations --------------------------------------------

    #[test]
    fn test_matref_copy_clone() {
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
        // `a` stays usable after `let b = a`, which exercises `Copy`.
        let b = a;
        let c = a.clone();
        assert_eq!(a.get(0, 0), b.get(0, 0));
        assert_eq!(a.get(0, 0), c.get(0, 0));
    }

    #[test]
    fn test_matref_to_owned() {
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
        let owned = a.to_owned();
        assert_eq!(owned.nrows(), 2);
        assert_eq!(owned.ncols(), 2);
        assert_eq!(owned[(0, 0)].0, 1.0);
    }

    #[test]
    fn test_matref_debug() {
        let data = [1.0f64, 2.0];
        let m = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("MatRef"));
    }

    #[test]
    fn test_mat_clone() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let m2 = m.clone();
        assert_eq!(m2[(0, 0)].0, 1.0);
        assert_eq!(m2[(1, 1)].0, 4.0);
    }

    #[test]
    fn test_mat_debug() {
        let m = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0], 1, 2);
        let debug_str = format!("{:?}", m);
        assert!(debug_str.contains("Mat"));
    }

    // ---- MatWithArgmax accessors ------------------------------------------

    #[test]
    fn test_matwithargmax_get_value() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get_value(0, 0), 8.0);
        assert_eq!(result.get_value(1, 1), 12.0);
    }

    #[test]
    fn test_matwithargmax_nrows_ncols() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.nrows(), 2);
        assert_eq!(result.ncols(), 2);
    }

    // ---- Panics on malformed input ----------------------------------------

    #[test]
    #[should_panic(expected = "data length")]
    fn test_mat_from_row_major_size_mismatch() {
        // 2 elements cannot fill a 2x2 matrix.
        let _ = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0], 2, 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_mat_from_vec_size_mismatch() {
        let data = vec![TropicalMaxPlus(1.0f64), TropicalMaxPlus(2.0)];
        let _ = Mat::from_vec(data, 2, 2);
    }

    #[test]
    #[should_panic(expected = "data length")]
    fn test_matref_from_slice_size_mismatch() {
        let data = [1.0f64, 2.0];
        let _ = MatRef::<TropicalMaxPlus<f64>>::from_slice(&data, 2, 2);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matmul_dimension_mismatch() {
        // A is 2x2 but B is 3x2: the inner dimensions disagree.
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let _ = a.matmul(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matref_matmul_dimension_mismatch() {
        let a_data = [1.0f64, 2.0, 3.0, 4.0];
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matmul_argmax_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let _ = a.matmul_argmax(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_matref_matmul_argmax_dimension_mismatch() {
        let a_data = [1.0f64, 2.0, 3.0, 4.0];
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = MatRef::<TropicalMaxPlus<f64>>::from_slice(&a_data, 2, 2);
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul_argmax(&b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_mat_matmul_ref_dimension_mismatch() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b_data = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = MatRef::<TropicalMaxPlus<f64>>::from_slice(&b_data, 3, 2);
        let _ = a.matmul_ref(&b);
    }

    // ---- Batched multiplication -------------------------------------------

    #[test]
    fn test_mat_matmul_batched() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let a2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0], 2, 2);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        let results = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
        assert_eq!(results.len(), 2);

        // First pair: C[0,0] = max(1+1, 2+0) = 2
        assert!((results[0][(0, 0)].0 - 2.0).abs() < 1e-5);

        // Second pair: C[0,0] = max(5+1, 6+3) = 9
        assert!((results[1][(0, 0)].0 - 9.0).abs() < 1e-5);
    }

    #[test]
    fn test_mat_matmul_batched_empty() {
        let a_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];
        let b_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];

        let results = Mat::matmul_batched(&a_batch, &b_batch);
        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "batch sizes must match")]
    fn test_mat_matmul_batched_size_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        // One matrix on the left, two on the right.
        let _ = Mat::matmul_batched(&[a1], &[b1, b2]);
    }

    #[test]
    #[should_panic(expected = "has dimensions")]
    fn test_mat_matmul_batched_dimension_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        // a2 is 2x3 while every other matrix in the batch is 2x2.
        let a2 =
            Mat::<TropicalMaxPlus<f32>>::from_row_major(&[5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0], 2, 2);

        let _ = Mat::matmul_batched(&[a1, a2], &[b1, b2]);
    }

    #[test]
    fn test_mat_matmul_batched_with_argmax() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let a2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[6.0, 5.0, 4.0, 3.0, 2.0, 1.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let results = Mat::matmul_batched_with_argmax(&[a1, a2], &[b1, b2]);
        assert_eq!(results.len(), 2);

        // First pair: C[0,0] = max(1+1, 2+3, 3+5) = 8, reached at inner index 2.
        assert!((results[0].get(0, 0).0 - 8.0).abs() < 1e-5);
        assert_eq!(results[0].get_argmax(0, 0), 2);
    }

    #[test]
    fn test_mat_matmul_batched_with_argmax_empty() {
        let a_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];
        let b_batch: Vec<Mat<TropicalMaxPlus<f32>>> = vec![];

        let results = Mat::matmul_batched_with_argmax(&a_batch, &b_batch);
        assert!(results.is_empty());
    }

    #[test]
    #[should_panic(expected = "batch sizes must match")]
    fn test_mat_matmul_batched_with_argmax_size_mismatch() {
        let a1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b1 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let b2 = Mat::<TropicalMaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let _ = Mat::matmul_batched_with_argmax(&[a1], &[b1, b2]);
    }

    // ---- Backward passes --------------------------------------------------

    #[test]
    fn test_matwithargmax_backward_a() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        // Every output cell takes its maximum at inner index 2.
        assert_eq!(result.get_argmax(0, 0), 2);
        assert_eq!(result.get_argmax(0, 1), 2);
        assert_eq!(result.get_argmax(1, 0), 2);
        assert_eq!(result.get_argmax(1, 1), 2);

        // Upstream gradient of all ones.
        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_a = result.backward_a(&grad_c, 3);

        assert_eq!(grad_a.nrows(), 2);
        assert_eq!(grad_a.ncols(), 3);
        // Only column 2 receives gradient: one unit from each of the two output columns.
        assert_eq!(grad_a[(0, 0)].0, 0.0);
        assert_eq!(grad_a[(0, 1)].0, 0.0);
        assert_eq!(grad_a[(0, 2)].0, 2.0);
        assert_eq!(grad_a[(1, 0)].0, 0.0);
        assert_eq!(grad_a[(1, 1)].0, 0.0);
        assert_eq!(grad_a[(1, 2)].0, 2.0);
    }

    #[test]
    fn test_matwithargmax_backward_b() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);

        // Upstream gradient of all ones.
        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_b = result.backward_b(&grad_c, 3);

        assert_eq!(grad_b.nrows(), 3);
        assert_eq!(grad_b.ncols(), 2);
        // Only row 2 receives gradient: one unit from each of the two output rows.
        assert_eq!(grad_b[(0, 0)].0, 0.0);
        assert_eq!(grad_b[(0, 1)].0, 0.0);
        assert_eq!(grad_b[(1, 0)].0, 0.0);
        assert_eq!(grad_b[(1, 1)].0, 0.0);
        assert_eq!(grad_b[(2, 0)].0, 2.0);
        assert_eq!(grad_b[(2, 1)].0, 2.0);
    }

    #[test]
    fn test_matwithargmax_backward_varied_argmax() {
        // Inputs chosen so that different output cells record different inner indices.
        let a =
            Mat::<TropicalMaxPlus<f64>>::from_row_major(&[10.0, 1.0, 1.0, 1.0, 10.0, 1.0], 2, 3);
        let b =
            Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 1.0, 1.0, 1.0, 10.0, 10.0], 3, 2);

        let result = a.matmul_argmax(&b);

        assert_eq!(result.get_argmax(0, 0), 0);
        assert_eq!(result.get_argmax(1, 0), 1);

        let grad_c = Mat::<TropicalMaxPlus<f64>>::from_fn(2, 2, |_, _| TropicalMaxPlus(1.0));
        let grad_a = result.backward_a(&grad_c, 3);

        // The recorded entries of A must have received some gradient.
        assert!(grad_a[(0, 0)].0 > 0.0);
        assert!(grad_a[(1, 1)].0 > 0.0);
    }

    #[test]
    fn test_matwithargmax_argmax_slice() {
        let a = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let b = Mat::<TropicalMaxPlus<f64>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);

        let result = a.matmul_argmax(&b);
        let argmax_slice = result.argmax_slice();

        // The buffer is column-major (`j * nrows + i`), so for a 2x2 result the
        // flat order is (0,0), (1,0), (0,1), (1,1).
        assert_eq!(argmax_slice.len(), 4);
        assert_eq!(argmax_slice[0], result.get_argmax(0, 0));
        assert_eq!(argmax_slice[1], result.get_argmax(1, 0));
        assert_eq!(argmax_slice[2], result.get_argmax(0, 1));
        assert_eq!(argmax_slice[3], result.get_argmax(1, 1));
    }
}