// tropical_gemm/simd/detect.rs

/// SIMD instruction-set level of the CPU the code is running on.
///
/// Variants are totally ordered (derived `PartialOrd`/`Ord`) by declaration order:
/// `Scalar < Sse2 < Avx < Avx2 < Avx512 < Neon`.
///
/// Within the x86 family that order matches increasing capability. Comparisons *across*
/// families are not meaningful: `Neon > Avx512` holds only because `Neon` is declared
/// last. Compare a level only against variants of the architecture family you are on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No SIMD; plain scalar code. Reported on architectures other than `x86_64`/`aarch64`.
    Scalar,
    /// x86_64 SSE2, 128-bit registers. Baseline fallback on every `x86_64` CPU.
    Sse2,
    /// x86_64 AVX, 256-bit registers.
    Avx,
    /// x86_64 AVX2, 256-bit registers.
    Avx2,
    /// x86_64 AVX-512 (Foundation subset), 512-bit registers.
    Avx512,
    /// AArch64 NEON, 128-bit registers.
    Neon,
}

19impl SimdLevel {
20 pub fn detect() -> Self {
22 #[cfg(target_arch = "x86_64")]
23 {
24 if is_x86_feature_detected!("avx512f") {
25 return SimdLevel::Avx512;
26 }
27 if is_x86_feature_detected!("avx2") {
28 return SimdLevel::Avx2;
29 }
30 if is_x86_feature_detected!("avx") {
31 return SimdLevel::Avx;
32 }
33 SimdLevel::Sse2
35 }
36
37 #[cfg(target_arch = "aarch64")]
38 {
39 SimdLevel::Neon
41 }
42
43 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
44 {
45 SimdLevel::Scalar
46 }
47 }
48
49 pub fn width_bytes(&self) -> usize {
51 match self {
52 SimdLevel::Scalar => 1,
53 SimdLevel::Sse2 | SimdLevel::Neon => 16,
54 SimdLevel::Avx | SimdLevel::Avx2 => 32,
55 SimdLevel::Avx512 => 64,
56 }
57 }
58
59 pub fn f32_width(&self) -> usize {
61 self.width_bytes() / 4
62 }
63
64 pub fn f64_width(&self) -> usize {
66 self.width_bytes() / 8
67 }
68}
69
70static SIMD_LEVEL: std::sync::OnceLock<SimdLevel> = std::sync::OnceLock::new();
72
73pub fn simd_level() -> SimdLevel {
75 *SIMD_LEVEL.get_or_init(SimdLevel::detect)
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect() {
        let level = SimdLevel::detect();
        println!("Detected SIMD level: {:?}", level);

        // Each architecture must report a level from its own family. A bare
        // `level >= SimdLevel::Scalar` would be vacuous, since `Scalar` is the minimum.
        #[cfg(target_arch = "x86_64")]
        assert!(
            (SimdLevel::Sse2..=SimdLevel::Avx512).contains(&level),
            "x86_64 must report an x86 SIMD level, got {:?}",
            level
        );

        #[cfg(target_arch = "aarch64")]
        assert_eq!(level, SimdLevel::Neon);

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        assert_eq!(level, SimdLevel::Scalar);
    }

    #[test]
    fn test_x86_levels_increase_in_capability() {
        assert!(SimdLevel::Scalar < SimdLevel::Sse2);
        assert!(SimdLevel::Sse2 < SimdLevel::Avx);
        assert!(SimdLevel::Avx < SimdLevel::Avx2);
        assert!(SimdLevel::Avx2 < SimdLevel::Avx512);
    }

    #[test]
    fn test_width() {
        assert_eq!(SimdLevel::Avx2.f32_width(), 8);
        assert_eq!(SimdLevel::Avx2.f64_width(), 4);
        assert_eq!(SimdLevel::Sse2.f32_width(), 4);
    }

    #[test]
    fn test_width_bytes() {
        assert_eq!(SimdLevel::Scalar.width_bytes(), 1);
        assert_eq!(SimdLevel::Sse2.width_bytes(), 16);
        assert_eq!(SimdLevel::Neon.width_bytes(), 16);
        assert_eq!(SimdLevel::Avx.width_bytes(), 32);
        assert_eq!(SimdLevel::Avx2.width_bytes(), 32);
        assert_eq!(SimdLevel::Avx512.width_bytes(), 64);
    }

    #[test]
    fn test_all_widths() {
        // (level, expected f32 lanes, expected f64 lanes). `Scalar` reports 0 lanes
        // because its `width_bytes()` is 1.
        let cases = [
            (SimdLevel::Scalar, 0, 0),
            (SimdLevel::Sse2, 4, 2),
            (SimdLevel::Neon, 4, 2),
            (SimdLevel::Avx, 8, 4),
            (SimdLevel::Avx2, 8, 4),
            (SimdLevel::Avx512, 16, 8),
        ];
        for &(level, f32_lanes, f64_lanes) in cases.iter() {
            assert_eq!(level.f32_width(), f32_lanes, "f32 lanes for {:?}", level);
            assert_eq!(level.f64_width(), f64_lanes, "f64 lanes for {:?}", level);
        }
    }

    #[test]
    fn test_simd_level_cached() {
        let level1 = simd_level();
        let level2 = simd_level();
        assert_eq!(level1, level2);
        // The cached value must agree with a fresh hardware probe.
        assert_eq!(level1, SimdLevel::detect());
    }
}