omeinsum/backend/cpu/
mod.rs

1//! CPU backend implementation.
2
3mod contract;
4
5use super::traits::{Backend, BackendScalar, Storage};
6use crate::algebra::{Algebra, Scalar, Standard};
7use std::any::TypeId;
8
9/// CPU backend using Vec storage.
10#[derive(Clone, Debug, Default)]
11pub struct Cpu;
12
13impl Cpu {
14    /// General matrix multiplication (internal implementation).
15    ///
16    /// Computes C = A ⊗ B where ⊗ is the semiring multiplication
17    /// and the reduction uses semiring addition.
18    ///
19    /// This is an internal implementation detail used by the contract method.
20    /// Users should use `einsum()` or `contract_binary()` instead.
21    pub(crate) fn gemm_internal<A: Algebra>(
22        &self,
23        a: &[A::Scalar],
24        m: usize,
25        k: usize,
26        b: &[A::Scalar],
27        n: usize,
28    ) -> Vec<A::Scalar> {
29        // Fast path: faer for Standard f32/f64
30        if TypeId::of::<A>() == TypeId::of::<Standard<f32>>() {
31            // SAFETY: A::Scalar is f32 when A is Standard<f32>
32            let a_f32: &[f32] = unsafe { std::mem::transmute(a) };
33            let b_f32: &[f32] = unsafe { std::mem::transmute(b) };
34            let result = faer_gemm_f32(a_f32, m, k, b_f32, n);
35            return unsafe { std::mem::transmute::<Vec<f32>, Vec<A::Scalar>>(result) };
36        }
37        if TypeId::of::<A>() == TypeId::of::<Standard<f64>>() {
38            let a_f64: &[f64] = unsafe { std::mem::transmute(a) };
39            let b_f64: &[f64] = unsafe { std::mem::transmute(b) };
40            let result = faer_gemm_f64(a_f64, m, k, b_f64, n);
41            return unsafe { std::mem::transmute::<Vec<f64>, Vec<A::Scalar>>(result) };
42        }
43
44        // Try to use optimized tropical-gemm if available
45        #[cfg(feature = "tropical-kernels")]
46        {
47            if let Some(result) = try_tropical_gemm::<A>(a, m, k, b, n) {
48                return result;
49            }
50        }
51
52        // Fallback to generic loop implementation
53        generic_gemm::<A>(a, m, k, b, n)
54    }
55
56    /// GEMM with argmax tracking (internal implementation).
57    ///
58    /// Returns (result, argmax) where argmax[i, j] is the k index
59    /// that "won" the reduction for element [i, j].
60    pub(crate) fn gemm_with_argmax_internal<A: Algebra<Index = u32>>(
61        &self,
62        a: &[A::Scalar],
63        m: usize,
64        k: usize,
65        b: &[A::Scalar],
66        n: usize,
67    ) -> (Vec<A::Scalar>, Vec<u32>) {
68        // Try to use optimized tropical-gemm if available
69        #[cfg(feature = "tropical-kernels")]
70        {
71            if let Some(result) = try_tropical_gemm_with_argmax::<A>(a, m, k, b, n) {
72                return result;
73            }
74        }
75
76        // Fallback to generic loop implementation
77        generic_gemm_with_argmax::<A>(a, m, k, b, n)
78    }
79
80    /// Backward pass for GEMM w.r.t. A (internal implementation).
81    /// Used primarily for testing CPU-specific backward implementations.
82    #[allow(dead_code)]
83    pub(crate) fn gemm_backward_a_internal<A: Algebra>(
84        &self,
85        grad_c: &[A::Scalar],
86        argmax: &[u32],
87        _b: &[A::Scalar],
88        m: usize,
89        k: usize,
90        n: usize,
91    ) -> Vec<A::Scalar> {
92        let mut grad_a = vec![A::Scalar::default(); m * k];
93
94        // For tropical: grad_a[i, argmax[i,j]] += grad_c[i,j]
95        // For standard: grad_a = grad_c @ b.T
96        // Column-major: element (i, j) is at index j * nrows + i
97        if A::needs_argmax() {
98            for j in 0..n {
99                for i in 0..m {
100                    let idx = argmax[j * m + i] as usize; // argmax[i, j] in column-major
101                                                          // grad_a[i, idx] += grad_c[i, j]
102                    grad_a[idx * m + i] += grad_c[j * m + i];
103                }
104            }
105        }
106
107        grad_a
108    }
109
110    /// Backward pass for GEMM w.r.t. B (internal implementation).
111    /// Used primarily for testing CPU-specific backward implementations.
112    #[allow(dead_code)]
113    pub(crate) fn gemm_backward_b_internal<A: Algebra>(
114        &self,
115        grad_c: &[A::Scalar],
116        argmax: &[u32],
117        _a: &[A::Scalar],
118        m: usize,
119        k: usize,
120        n: usize,
121    ) -> Vec<A::Scalar> {
122        let mut grad_b = vec![A::Scalar::default(); k * n];
123
124        // Column-major: element (i, j) is at index j * nrows + i
125        if A::needs_argmax() {
126            for j in 0..n {
127                for i in 0..m {
128                    let idx = argmax[j * m + i] as usize; // argmax[i, j] in column-major
129                                                          // grad_b[idx, j] += grad_c[i, j]
130                    grad_b[j * k + idx] += grad_c[j * m + i];
131                }
132            }
133        }
134
135        grad_b
136    }
137
138    /// Batched GEMM (internal implementation).
139    pub(crate) fn gemm_batched_internal<A: Algebra>(
140        &self,
141        a: &[A::Scalar],
142        batch_size: usize,
143        m: usize,
144        k: usize,
145        b: &[A::Scalar],
146        n: usize,
147    ) -> Vec<A::Scalar> {
148        let a_batch_stride = m * k;
149        let b_batch_stride = k * n;
150        let c_batch_stride = m * n;
151
152        let mut c = vec![A::zero().to_scalar(); batch_size * m * n];
153
154        for batch in 0..batch_size {
155            let a_offset = batch * a_batch_stride;
156            let b_offset = batch * b_batch_stride;
157            let c_offset = batch * c_batch_stride;
158
159            let a_slice = &a[a_offset..a_offset + a_batch_stride];
160            let b_slice = &b[b_offset..b_offset + b_batch_stride];
161
162            let c_batch = generic_gemm::<A>(a_slice, m, k, b_slice, n);
163            c[c_offset..c_offset + c_batch_stride].copy_from_slice(&c_batch);
164        }
165
166        c
167    }
168
169    /// Batched GEMM with argmax tracking (internal implementation).
170    pub(crate) fn gemm_batched_with_argmax_internal<A: Algebra<Index = u32>>(
171        &self,
172        a: &[A::Scalar],
173        batch_size: usize,
174        m: usize,
175        k: usize,
176        b: &[A::Scalar],
177        n: usize,
178    ) -> (Vec<A::Scalar>, Vec<u32>) {
179        let a_batch_stride = m * k;
180        let b_batch_stride = k * n;
181        let c_batch_stride = m * n;
182
183        let mut c = vec![A::zero().to_scalar(); batch_size * m * n];
184        let mut argmax = vec![0u32; batch_size * m * n];
185
186        for batch in 0..batch_size {
187            let a_offset = batch * a_batch_stride;
188            let b_offset = batch * b_batch_stride;
189            let c_offset = batch * c_batch_stride;
190
191            let a_slice = &a[a_offset..a_offset + a_batch_stride];
192            let b_slice = &b[b_offset..b_offset + b_batch_stride];
193
194            let (c_batch, argmax_batch) = generic_gemm_with_argmax::<A>(a_slice, m, k, b_slice, n);
195            c[c_offset..c_offset + c_batch_stride].copy_from_slice(&c_batch);
196            argmax[c_offset..c_offset + c_batch_stride].copy_from_slice(&argmax_batch);
197        }
198
199        (c, argmax)
200    }
201}
202
203impl<T: Scalar> Storage<T> for Vec<T> {
204    #[inline]
205    fn len(&self) -> usize {
206        Vec::len(self)
207    }
208
209    #[inline]
210    fn get(&self, index: usize) -> T {
211        self[index]
212    }
213
214    #[inline]
215    fn set(&mut self, index: usize, value: T) {
216        self[index] = value;
217    }
218
219    #[inline]
220    fn to_vec(&self) -> Vec<T> {
221        self.clone()
222    }
223
224    #[inline]
225    fn from_slice(data: &[T]) -> Self {
226        data.to_vec()
227    }
228
229    #[inline]
230    fn zeros(len: usize) -> Self {
231        vec![T::default(); len]
232    }
233}
234
235impl Backend for Cpu {
236    type Storage<T: Scalar> = Vec<T>;
237
238    fn name() -> &'static str {
239        "cpu"
240    }
241
242    fn synchronize(&self) {
243        // No-op for CPU
244    }
245
246    fn alloc<T: Scalar>(&self, len: usize) -> Vec<T> {
247        vec![T::default(); len]
248    }
249
250    fn from_slice<T: Scalar>(&self, data: &[T]) -> Vec<T> {
251        data.to_vec()
252    }
253
254    fn contract<A: Algebra>(
255        &self,
256        a: &Self::Storage<A::Scalar>,
257        shape_a: &[usize],
258        strides_a: &[usize],
259        modes_a: &[i32],
260        b: &Self::Storage<A::Scalar>,
261        shape_b: &[usize],
262        strides_b: &[usize],
263        modes_b: &[i32],
264        shape_c: &[usize],
265        modes_c: &[i32],
266    ) -> Self::Storage<A::Scalar>
267    where
268        A::Scalar: BackendScalar<Self>,
269    {
270        contract::contract::<A>(
271            self, a, shape_a, strides_a, modes_a,
272            b, shape_b, strides_b, modes_b,
273            shape_c, modes_c,
274        )
275    }
276
277    fn contract_with_argmax<A: Algebra<Index = u32>>(
278        &self,
279        a: &Self::Storage<A::Scalar>,
280        shape_a: &[usize],
281        strides_a: &[usize],
282        modes_a: &[i32],
283        b: &Self::Storage<A::Scalar>,
284        shape_b: &[usize],
285        strides_b: &[usize],
286        modes_b: &[i32],
287        shape_c: &[usize],
288        modes_c: &[i32],
289    ) -> (Self::Storage<A::Scalar>, Self::Storage<u32>)
290    where
291        A::Scalar: BackendScalar<Self>,
292    {
293        contract::contract_with_argmax::<A>(
294            self, a, shape_a, strides_a, modes_a,
295            b, shape_b, strides_b, modes_b,
296            shape_c, modes_c,
297        )
298    }
299
300    fn copy_strided<T: Scalar>(
301        &self,
302        src: &Vec<T>,
303        shape: &[usize],
304        strides: &[usize],
305        offset: usize,
306    ) -> Vec<T> {
307        let numel: usize = shape.iter().product();
308        let mut dst = vec![T::default(); numel];
309
310        // Iterate over all indices and copy
311        let mut indices = vec![0usize; shape.len()];
312        for dst_elem in dst.iter_mut() {
313            // Compute source offset using strides
314            let src_offset: usize = offset
315                + indices
316                    .iter()
317                    .zip(strides.iter())
318                    .map(|(i, s)| i * s)
319                    .sum::<usize>();
320
321            *dst_elem = src[src_offset];
322
323            // Increment indices (column-major order: first dimension first)
324            for dim in 0..shape.len() {
325                indices[dim] += 1;
326                if indices[dim] < shape[dim] {
327                    break;
328                }
329                indices[dim] = 0;
330            }
331        }
332
333        dst
334    }
335}
336
337/// GEMM using faer for f32 (column-major layout).
338///
339/// Computes C = A @ B where A is m×k, B is k×n, C is m×n.
340fn faer_gemm_f32(a: &[f32], m: usize, k: usize, b: &[f32], n: usize) -> Vec<f32> {
341    use faer::Mat;
342
343    // Create matrices from column-major data
344    // Column-major: element (i, j) is at index j * nrows + i
345    let a_mat = Mat::from_fn(m, k, |i, j| a[j * m + i]);
346    let b_mat = Mat::from_fn(k, n, |i, j| b[j * k + i]);
347
348    // Multiply
349    let c_mat = &a_mat * &b_mat;
350
351    // Convert back to column-major Vec
352    let mut c = vec![0.0f32; m * n];
353    for j in 0..n {
354        for i in 0..m {
355            c[j * m + i] = c_mat[(i, j)];
356        }
357    }
358    c
359}
360
361/// GEMM using faer for f64 (column-major layout).
362fn faer_gemm_f64(a: &[f64], m: usize, k: usize, b: &[f64], n: usize) -> Vec<f64> {
363    use faer::Mat;
364
365    let a_mat = Mat::from_fn(m, k, |i, j| a[j * m + i]);
366    let b_mat = Mat::from_fn(k, n, |i, j| b[j * k + i]);
367
368    let c_mat = &a_mat * &b_mat;
369
370    let mut c = vec![0.0f64; m * n];
371    for j in 0..n {
372        for i in 0..m {
373            c[j * m + i] = c_mat[(i, j)];
374        }
375    }
376    c
377}
378
379/// Generic GEMM using semiring operations (column-major layout).
380fn generic_gemm<A: Algebra>(
381    a: &[A::Scalar],
382    m: usize,
383    k: usize,
384    b: &[A::Scalar],
385    n: usize,
386) -> Vec<A::Scalar> {
387    let mut c = vec![A::zero().to_scalar(); m * n];
388
389    // Column-major: element (i, j) is at index j * nrows + i
390    for j in 0..n {
391        for i in 0..m {
392            let mut acc = A::zero();
393            for kk in 0..k {
394                let a_val = A::from_scalar(a[kk * m + i]); // A[i, kk] in column-major
395                let b_val = A::from_scalar(b[j * k + kk]); // B[kk, j] in column-major
396                let prod = a_val.mul(b_val);
397                acc = acc.add(prod);
398            }
399            c[j * m + i] = acc.to_scalar();
400        }
401    }
402
403    c
404}
405
406/// Generic GEMM with argmax tracking (column-major layout).
407fn generic_gemm_with_argmax<A: Algebra<Index = u32>>(
408    a: &[A::Scalar],
409    m: usize,
410    k: usize,
411    b: &[A::Scalar],
412    n: usize,
413) -> (Vec<A::Scalar>, Vec<u32>) {
414    let mut c = vec![A::zero().to_scalar(); m * n];
415    let mut argmax = vec![0u32; m * n];
416
417    // Column-major: element (i, j) is at index j * nrows + i
418    for j in 0..n {
419        for i in 0..m {
420            let mut acc = A::zero();
421            let mut best_k = 0u32;
422
423            for kk in 0..k {
424                let a_val = A::from_scalar(a[kk * m + i]); // A[i, kk] in column-major
425                let b_val = A::from_scalar(b[j * k + kk]); // B[kk, j] in column-major
426                let prod = a_val.mul(b_val);
427                let (new_acc, winner) = acc.add_with_argmax(best_k, prod, kk as u32);
428                acc = new_acc;
429                best_k = winner;
430            }
431
432            c[j * m + i] = acc.to_scalar();
433            argmax[j * m + i] = best_k;
434        }
435    }
436
437    (c, argmax)
438}
439
// Optional: Use tropical-gemm for optimized kernels
#[cfg(feature = "tropical-kernels")]
fn try_tropical_gemm<A: Algebra>(
    a: &[A::Scalar],
    m: usize,
    k: usize,
    b: &[A::Scalar],
    n: usize,
) -> Option<Vec<A::Scalar>> {
    use crate::algebra::{MaxMul, MaxPlus, MinPlus};
    use std::any::TypeId;
    use tropical_gemm::{
        tropical_matmul, TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring,
    };

    // Dispatch based on algebra type using TypeId.
    // The tropical-gemm types have identical repr(transparent) layout to our types,
    // and both wrap the scalar directly, so we can safely transmute the output.
    //
    // The six supported algebra/scalar pairs only differ in the three types
    // involved, so the arm is generated by a local macro instead of being
    // hand-expanded six times (the previous copies were error-prone to edit).
    macro_rules! dispatch {
        ($alg:ty, $kernel:ty, $scalar:ty) => {
            if TypeId::of::<A>() == TypeId::of::<$alg>() {
                // SAFETY: the TypeId check proves A::Scalar is $scalar, and a
                // slice of A::Scalar has the same layout as a slice of $scalar.
                let lhs: &[$scalar] = unsafe { std::mem::transmute(a) };
                let rhs: &[$scalar] = unsafe { std::mem::transmute(b) };

                let result: Vec<$kernel> = tropical_matmul::<$kernel>(lhs, m, k, rhs, n);

                // Unwrap the repr(transparent) kernel type back to the scalar.
                let scalars: Vec<$scalar> = result.into_iter().map(|x| x.value()).collect();

                // SAFETY: A::Scalar is $scalar (checked above).
                return Some(unsafe { std::mem::transmute(scalars) });
            }
        };
    }

    dispatch!(MaxPlus<f32>, TropicalMaxPlus<f32>, f32);
    dispatch!(MaxPlus<f64>, TropicalMaxPlus<f64>, f64);
    dispatch!(MinPlus<f32>, TropicalMinPlus<f32>, f32);
    dispatch!(MinPlus<f64>, TropicalMinPlus<f64>, f64);
    dispatch!(MaxMul<f32>, TropicalMaxMul<f32>, f32);
    dispatch!(MaxMul<f64>, TropicalMaxMul<f64>, f64);

    // Unsupported algebra/scalar combination: fall back to the generic loop.
    None
}
522
#[cfg(feature = "tropical-kernels")]
fn try_tropical_gemm_with_argmax<A: Algebra<Index = u32>>(
    a: &[A::Scalar],
    m: usize,
    k: usize,
    b: &[A::Scalar],
    n: usize,
) -> Option<(Vec<A::Scalar>, Vec<u32>)> {
    use crate::algebra::{MaxMul, MaxPlus, MinPlus};
    use std::any::TypeId;
    use tropical_gemm::{
        tropical_matmul_with_argmax, TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus,
        TropicalSemiring,
    };

    // Dispatch based on algebra type using TypeId. As in `try_tropical_gemm`,
    // the six supported algebra/scalar pairs only differ in the types
    // involved, so the arm is generated by a local macro instead of being
    // hand-expanded six times.
    macro_rules! dispatch {
        ($alg:ty, $kernel:ty, $scalar:ty) => {
            if TypeId::of::<A>() == TypeId::of::<$alg>() {
                // SAFETY: the TypeId check proves A::Scalar is $scalar, and a
                // slice of A::Scalar has the same layout as a slice of $scalar.
                let lhs: &[$scalar] = unsafe { std::mem::transmute(a) };
                let rhs: &[$scalar] = unsafe { std::mem::transmute(b) };

                let result = tropical_matmul_with_argmax::<$kernel>(lhs, m, k, rhs, n);

                // Convert to column-major storage.
                // Note: tropical-gemm's accessor functions use (col, row) order internally
                let mut scalars = Vec::with_capacity(m * n);
                let mut argmax = Vec::with_capacity(m * n);
                for j in 0..n {
                    for i in 0..m {
                        scalars.push(result.get(j, i).value());
                        argmax.push(result.get_argmax(j, i));
                    }
                }

                // SAFETY: A::Scalar is $scalar (checked above).
                return Some((unsafe { std::mem::transmute(scalars) }, argmax));
            }
        };
    }

    dispatch!(MaxPlus<f32>, TropicalMaxPlus<f32>, f32);
    dispatch!(MaxPlus<f64>, TropicalMaxPlus<f64>, f64);
    dispatch!(MinPlus<f32>, TropicalMinPlus<f32>, f32);
    dispatch!(MinPlus<f64>, TropicalMinPlus<f64>, f64);
    dispatch!(MaxMul<f32>, TropicalMaxMul<f32>, f32);
    dispatch!(MaxMul<f64>, TropicalMaxMul<f64>, f64);

    // Unsupported algebra/scalar combination: fall back to the generic loop.
    None
}
647
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::Standard;

    #[cfg(feature = "tropical")]
    use crate::algebra::MaxPlus;

    // All buffers in these tests are column-major: element (i, j) of an
    // m-row matrix lives at index j * m + i.

    // Sanity check for the Standard (+, *) algebra path.
    #[test]
    fn test_cpu_gemm_standard() {
        let cpu = Cpu;
        let a = vec![1.0f32, 2.0, 3.0, 4.0]; // 2x2
        let b = vec![1.0f32, 2.0, 3.0, 4.0]; // 2x2

        let c = cpu.gemm_internal::<Standard<f32>>(&a, 2, 2, &b, 2);

        // [1 2] × [1 2] = [1*1+2*3  1*2+2*4] = [7  10]
        // [3 4]   [3 4]   [3*1+4*3  3*2+4*4]   [15 22]
        assert_eq!(c, vec![7.0, 10.0, 15.0, 22.0]);
    }

    // Max-plus semiring: add = max, mul = +.
    #[cfg(feature = "tropical")]
    #[test]
    fn test_cpu_gemm_maxplus() {
        let cpu = Cpu;
        let a = vec![1.0f32, 2.0, 3.0, 4.0]; // 2x2
        let b = vec![1.0f32, 2.0, 3.0, 4.0]; // 2x2

        let c = cpu.gemm_internal::<MaxPlus<f32>>(&a, 2, 2, &b, 2);

        // MaxPlus: C[i,j] = max_k(A[i,k] + B[k,j])
        // C[0,0] = max(1+1, 2+3) = max(2, 5) = 5
        // C[0,1] = max(1+2, 2+4) = max(3, 6) = 6
        // C[1,0] = max(3+1, 4+3) = max(4, 7) = 7
        // C[1,1] = max(3+2, 4+4) = max(5, 8) = 8
        assert_eq!(c, vec![5.0, 6.0, 7.0, 8.0]);
    }

    // The argmax buffer must record which k index won each max reduction.
    #[cfg(feature = "tropical")]
    #[test]
    fn test_cpu_gemm_with_argmax() {
        let cpu = Cpu;
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![1.0f32, 2.0, 3.0, 4.0];

        let (c, argmax) = cpu.gemm_with_argmax_internal::<MaxPlus<f32>>(&a, 2, 2, &b, 2);

        assert_eq!(c, vec![5.0, 6.0, 7.0, 8.0]);
        // All winners should be k=1 (second column of A, second row of B)
        assert_eq!(argmax, vec![1, 1, 1, 1]);
    }

    // copy_strided with permuted strides should materialize a transpose.
    #[test]
    fn test_copy_strided() {
        let cpu = Cpu;
        // Column-major: data [1,2,3,4,5,6] for shape [2,3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Transpose: shape [3, 2], strides [2, 1] (original col-major strides permuted)
        // This reads the original matrix as transposed
        let dst = cpu.copy_strided(&src, &[3, 2], &[2, 1], 0);

        // Transposed matrix in column-major:
        // [[1, 2],
        //  [3, 4],
        //  [5, 6]] -> column-major data: [1, 3, 5, 2, 4, 6]
        assert_eq!(dst, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    /// Test that optimized tropical-gemm kernels produce same results as generic implementation.
    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_optimized_maxplus() {
        use crate::algebra::MaxPlus;

        let cpu = Cpu;
        let m = 64;
        let k = 64;
        let n = 64;

        let a: Vec<f32> = (0..m * k).map(|i| (i % 100) as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i % 100) as f32).collect();

        // Test MaxPlus<f32>
        let c_opt = cpu.gemm_internal::<MaxPlus<f32>>(&a, m, k, &b, n);
        let c_generic = generic_gemm::<MaxPlus<f32>>(&a, m, k, &b, n);

        for (i, (opt, gen)) in c_opt.iter().zip(c_generic.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-6,
                "MaxPlus mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }
    }

    // Same optimized-vs-generic comparison for the min-plus semiring.
    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_optimized_minplus() {
        use crate::algebra::MinPlus;

        let cpu = Cpu;
        let m = 32;
        let k = 32;
        let n = 32;

        let a: Vec<f32> = (0..m * k).map(|i| (i % 50) as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i % 50) as f32).collect();

        // Test MinPlus<f32>
        let c_opt = cpu.gemm_internal::<MinPlus<f32>>(&a, m, k, &b, n);
        let c_generic = generic_gemm::<MinPlus<f32>>(&a, m, k, &b, n);

        for (i, (opt, gen)) in c_opt.iter().zip(c_generic.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-6,
                "MinPlus mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }
    }

    // Same comparison for max-mul (add = max, mul = *).
    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_optimized_maxmul() {
        use crate::algebra::MaxMul;

        let cpu = Cpu;
        let m = 16;
        let k = 16;
        let n = 16;

        // Use small values to avoid overflow in multiplication
        let a: Vec<f32> = (0..m * k).map(|i| ((i % 10) as f32) * 0.1 + 0.1).collect();
        let b: Vec<f32> = (0..k * n).map(|i| ((i % 10) as f32) * 0.1 + 0.1).collect();

        // Test MaxMul<f32>
        let c_opt = cpu.gemm_internal::<MaxMul<f32>>(&a, m, k, &b, n);
        let c_generic = generic_gemm::<MaxMul<f32>>(&a, m, k, &b, n);

        for (i, (opt, gen)) in c_opt.iter().zip(c_generic.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-5,
                "MaxMul mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }
    }

    // Optimized argmax kernels must agree with the generic path on both
    // the values and the winning indices.
    #[cfg(feature = "tropical-kernels")]
    #[test]
    fn test_tropical_gemm_with_argmax_optimized() {
        use crate::algebra::MaxPlus;

        let cpu = Cpu;
        let m = 32;
        let k = 32;
        let n = 32;

        let a: Vec<f32> = (0..m * k).map(|i| (i % 100) as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i % 100) as f32).collect();

        // Test MaxPlus<f32> with argmax
        let (c_opt, argmax_opt) = cpu.gemm_with_argmax_internal::<MaxPlus<f32>>(&a, m, k, &b, n);
        let (c_generic, argmax_generic) = generic_gemm_with_argmax::<MaxPlus<f32>>(&a, m, k, &b, n);

        for (i, (opt, gen)) in c_opt.iter().zip(c_generic.iter()).enumerate() {
            assert!(
                (opt - gen).abs() < 1e-6,
                "MaxPlus with argmax: value mismatch at index {}: opt={}, gen={}",
                i,
                opt,
                gen
            );
        }

        for (i, (opt, gen)) in argmax_opt.iter().zip(argmax_generic.iter()).enumerate() {
            assert_eq!(
                opt, gen,
                "MaxPlus with argmax: argmax mismatch at index {}: opt={}, gen={}",
                i, opt, gen
            );
        }
    }

    // Backward pass: each grad_c element is routed to exactly one slot of
    // grad_a and one slot of grad_b, so the totals must be conserved.
    #[cfg(feature = "tropical")]
    #[test]
    fn test_gemm_backward() {
        let cpu = Cpu;
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let b = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2

        let (_c, argmax) = cpu.gemm_with_argmax_internal::<MaxPlus<f32>>(&a, 2, 3, &b, 2);

        let grad_c = vec![1.0f32; 4];
        let grad_a = cpu.gemm_backward_a_internal::<MaxPlus<f32>>(&grad_c, &argmax, &b, 2, 3, 2);
        let grad_b = cpu.gemm_backward_b_internal::<MaxPlus<f32>>(&grad_c, &argmax, &a, 2, 3, 2);

        assert_eq!(grad_a.len(), 6);
        assert_eq!(grad_b.len(), 6);

        // Verify that gradients accumulated correctly (no unsafe transmute issues)
        // The sum of all gradients should equal the sum of all grad_c elements
        // since each grad_c element contributes exactly once to grad_a and grad_b
        let grad_a_sum: f32 = grad_a.iter().sum();
        let grad_b_sum: f32 = grad_b.iter().sum();
        let grad_c_sum: f32 = grad_c.iter().sum();

        assert_eq!(grad_a_sum, grad_c_sum, "grad_a sum should equal grad_c sum");
        assert_eq!(grad_b_sum, grad_c_sum, "grad_b sum should equal grad_c sum");
    }
}