// tropical_gemm/simd/detect.rs

/// SIMD instruction-set tiers used for runtime SIMD dispatch.
///
/// A value is normally obtained from CPU feature detection
/// ([`SimdLevel::detect`]) and then used to pick a matching kernel.
///
/// # Ordering
///
/// The derived `PartialOrd`/`Ord` follow declaration order. Among the x86
/// variants (`Scalar` through `Avx512`) a greater value means a newer or wider
/// instruction set. `Neon` is declared last, so it compares greater than
/// `Avx512` even though it is only 128 bits wide. Do not rely on ordering
/// across architectures; compare [`SimdLevel::width_bytes`] instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No SIMD, use scalar code.
    Scalar,
    /// SSE2 (128-bit, available on all x86-64).
    Sse2,
    /// AVX (256-bit float).
    Avx,
    /// AVX2 (256-bit integer).
    ///
    /// FMA is a separate CPU feature. [`SimdLevel::detect`] probes only
    /// `avx2`, so this variant alone does not establish that FMA is usable.
    Avx2,
    /// AVX-512 (512-bit), detected through the `avx512f` feature.
    Avx512,
    /// ARM NEON (128-bit).
    Neon,
}

impl SimdLevel {
    /// Detect the best available SIMD level at runtime.
    ///
    /// * `x86_64`: returns the first match of `Avx512` (`avx512f`), `Avx2`
    ///   (`avx2`), `Avx` (`avx`), and otherwise `Sse2`.
    /// * `aarch64`: always `Neon`.
    /// * any other target: `Scalar`.
    ///
    /// Exactly one of the `cfg`-gated blocks below survives compilation and
    /// becomes the function's tail expression.
    #[must_use]
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // Probe from the widest set down so the first hit is the best one.
            if is_x86_feature_detected!("avx512f") {
                Self::Avx512
            } else if is_x86_feature_detected!("avx2") {
                Self::Avx2
            } else if is_x86_feature_detected!("avx") {
                Self::Avx
            } else {
                // SSE2 is always available on x86-64.
                Self::Sse2
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // NEON is always available on AArch64.
            Self::Neon
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::Scalar
        }
    }

    /// Get the SIMD width in bytes.
    ///
    /// `Scalar` reports `1`; the vector levels report their register size
    /// (16, 32 or 64 bytes).
    #[must_use]
    pub const fn width_bytes(&self) -> usize {
        match *self {
            Self::Scalar => 1,
            Self::Sse2 | Self::Neon => 16,
            Self::Avx | Self::Avx2 => 32,
            Self::Avx512 => 64,
        }
    }

    /// Get the number of f32 elements that fit in one SIMD register.
    ///
    /// This is `width_bytes() / 4` with integer division, so `Scalar`
    /// yields `0`. Guard against zero before using the result as a stride or
    /// divisor.
    #[must_use]
    pub const fn f32_width(&self) -> usize {
        self.width_bytes() / std::mem::size_of::<f32>()
    }

    /// Get the number of f64 elements that fit in one SIMD register.
    ///
    /// This is `width_bytes() / 8` with integer division, so `Scalar`
    /// yields `0`. Guard against zero before using the result as a stride or
    /// divisor.
    #[must_use]
    pub const fn f64_width(&self) -> usize {
        self.width_bytes() / std::mem::size_of::<f64>()
    }
}

70/// Global cached SIMD level.
71static SIMD_LEVEL: std::sync::OnceLock<SimdLevel> = std::sync::OnceLock::new();
72
73/// Get the detected SIMD level (cached).
74pub fn simd_level() -> SimdLevel {
75    *SIMD_LEVEL.get_or_init(SimdLevel::detect)
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant with its expected `(bytes, f32 lanes, f64 lanes)`.
    const EXPECTED: [(SimdLevel, usize, usize, usize); 6] = [
        (SimdLevel::Scalar, 1, 0, 0), // 1/4 and 1/8 truncate to 0
        (SimdLevel::Sse2, 16, 4, 2),
        (SimdLevel::Neon, 16, 4, 2),
        (SimdLevel::Avx, 32, 8, 4),
        (SimdLevel::Avx2, 32, 8, 4),
        (SimdLevel::Avx512, 64, 16, 8),
    ];

    #[test]
    fn test_detect() {
        let level = SimdLevel::detect();
        println!("Detected SIMD level: {:?}", level);

        // On x86-64 the result must be one of the x86 levels, SSE2 at minimum.
        // A plain `>= Sse2` check would also accept `Neon`, because `Neon` is
        // declared after `Avx512` and the derived ordering follows declaration
        // order.
        #[cfg(target_arch = "x86_64")]
        assert!((SimdLevel::Sse2..=SimdLevel::Avx512).contains(&level));

        // On AArch64, should detect NEON.
        #[cfg(target_arch = "aarch64")]
        assert_eq!(level, SimdLevel::Neon);

        // Everywhere else detection falls back to scalar code.
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        assert_eq!(level, SimdLevel::Scalar);
    }

    #[test]
    fn test_width() {
        assert_eq!(SimdLevel::Avx2.f32_width(), 8);
        assert_eq!(SimdLevel::Avx2.f64_width(), 4);
        assert_eq!(SimdLevel::Sse2.f32_width(), 4);
    }

    #[test]
    fn test_width_bytes() {
        for (level, bytes, _, _) in EXPECTED {
            assert_eq!(level.width_bytes(), bytes, "width_bytes of {:?}", level);
        }
    }

    #[test]
    fn test_all_widths() {
        for (level, _, f32_lanes, f64_lanes) in EXPECTED {
            assert_eq!(level.f32_width(), f32_lanes, "f32_width of {:?}", level);
            assert_eq!(level.f64_width(), f64_lanes, "f64_width of {:?}", level);
        }
    }

    #[test]
    fn test_simd_level_cached() {
        // Calling simd_level() multiple times should return same value...
        let level1 = simd_level();
        let level2 = simd_level();
        assert_eq!(level1, level2);

        // ...and that value must agree with a fresh detection.
        assert_eq!(level1, SimdLevel::detect());
    }
}