tropical_gemm/core/
tiling.rs

/// Tiling parameters for BLIS-style GEMM blocking.
///
/// These parameters control how matrices are partitioned to fit in
/// various levels of the cache hierarchy. `mc`, `nc` and `kc` are the outer
/// block sizes; `mr` and `nr` are the microkernel tile. Use
/// [`TilingParams::validate`] to check that a set of values is consistent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TilingParams {
    /// Block size for M dimension (L2 cache).
    pub mc: usize,
    /// Block size for N dimension (L2 cache).
    pub nc: usize,
    /// Block size for K dimension (L1 cache).
    pub kc: usize,
    /// Microkernel M dimension (registers).
    pub mr: usize,
    /// Microkernel N dimension (registers).
    pub nr: usize,
}

impl TilingParams {
    /// Default parameters for f32 with AVX2.
    pub const F32_AVX2: Self = Self {
        mc: 256,
        nc: 256,
        kc: 512,
        mr: 8,
        nr: 8,
    };

    /// Default parameters for f64 with AVX2.
    pub const F64_AVX2: Self = Self {
        mc: 128,
        nc: 128,
        kc: 256,
        mr: 4,
        nr: 4,
    };

    /// Default parameters for portable (non-SIMD) execution.
    pub const PORTABLE: Self = Self {
        mc: 64,
        nc: 64,
        kc: 256,
        mr: 4,
        nr: 4,
    };

    /// Create custom tiling parameters.
    ///
    /// No consistency checks are performed here. Call
    /// [`validate`](Self::validate) before using caller-supplied values.
    #[must_use]
    pub const fn new(mc: usize, nc: usize, kc: usize, mr: usize, nr: usize) -> Self {
        Self { mc, nc, kc, mr, nr }
    }

    /// Validate that tiling parameters are consistent.
    ///
    /// This is a `const fn`, so it can also be evaluated at compile time.
    ///
    /// # Errors
    ///
    /// Returns a static message for the first rule that is violated. Rules are
    /// checked in this order:
    /// 1. `mr` or `nr` is zero.
    /// 2. `mc`, `nc`, or `kc` is zero.
    /// 3. `mc` is not a multiple of `mr`.
    /// 4. `nc` is not a multiple of `nr`.
    pub const fn validate(&self) -> Result<(), &'static str> {
        // Zero checks run first so a zero microkernel size is reported as such,
        // rather than surfacing as a confusing divisibility error.
        if self.mr == 0 || self.nr == 0 {
            return Err("mr and nr must be non-zero");
        }
        if self.mc == 0 || self.nc == 0 || self.kc == 0 {
            return Err("mc, nc, and kc must be non-zero");
        }
        // Blocks must tile exactly into microkernel-sized panels.
        if !self.mc.is_multiple_of(self.mr) {
            return Err("mc must be divisible by mr");
        }
        if !self.nc.is_multiple_of(self.nr) {
            return Err("nc must be divisible by nr");
        }
        Ok(())
    }
}

impl Default for TilingParams {
    /// Returns [`TilingParams::PORTABLE`].
    fn default() -> Self {
        Self::PORTABLE
    }
}

// Compile-time guard: if a built-in preset ever violates `validate`, the build
// fails here instead of the inconsistency surfacing at runtime.
const _: () = {
    assert!(TilingParams::F32_AVX2.validate().is_ok());
    assert!(TilingParams::F64_AVX2.validate().is_ok());
    assert!(TilingParams::PORTABLE.validate().is_ok());
};
77
/// Iterator over blocks for the outer loop.
///
/// Splits the half-open range `0..total` into consecutive blocks of
/// `block_size` elements and yields `(start, length)` for each. Every block
/// is exactly `block_size` long except possibly the last, which holds the
/// remainder. Nothing is yielded when `total == 0`.
#[derive(Debug, Clone)]
pub struct BlockIterator {
    total: usize,
    block_size: usize,
    // Start of the next block; invariant: `current <= total`.
    current: usize,
}

impl BlockIterator {
    /// Create an iterator that partitions `0..total` into blocks of `block_size`.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero. A zero block size can never advance
    /// past the first start index, so the iterator would otherwise yield
    /// `(0, 0)` forever.
    #[must_use]
    pub fn new(total: usize, block_size: usize) -> Self {
        assert!(block_size > 0, "block_size must be non-zero");
        Self {
            total,
            block_size,
            current: 0,
        }
    }

    /// Number of blocks not yet yielded.
    fn remaining_blocks(&self) -> usize {
        // `current <= total` always holds, and `block_size > 0` is enforced by `new`.
        (self.total - self.current).div_ceil(self.block_size)
    }
}

impl Iterator for BlockIterator {
    /// (start, length) of each block
    type Item = (usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.total {
            return None;
        }

        let start = self.current;
        // The final block may be shorter than `block_size`.
        let len = (self.total - start).min(self.block_size);
        self.current += len;

        Some((start, len))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining_blocks();
        (n, Some(n))
    }
}

impl ExactSizeIterator for BlockIterator {}

// Once exhausted, `current >= total` stays true, so `next` keeps returning `None`.
impl std::iter::FusedIterator for BlockIterator {}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect every `(start, len)` pair produced for `total` / `block_size`.
    fn blocks(total: usize, block_size: usize) -> Vec<(usize, usize)> {
        BlockIterator::new(total, block_size).collect()
    }

    #[test]
    fn test_block_iterator() {
        // 10 = 3 + 3 + 3 + 1: the final block carries the remainder.
        assert_eq!(blocks(10, 3), vec![(0, 3), (3, 3), (6, 3), (9, 1)]);
    }

    #[test]
    fn test_block_iterator_exact() {
        // When total is exactly divisible by block_size there is no short tail.
        assert_eq!(blocks(9, 3), vec![(0, 3), (3, 3), (6, 3)]);
    }

    #[test]
    fn test_block_iterator_empty() {
        assert!(blocks(0, 3).is_empty());
    }

    #[test]
    fn test_block_iterator_block_larger_than_total() {
        // A block size exceeding the total yields a single short block.
        assert_eq!(blocks(5, 64), vec![(0, 5)]);
    }

    #[test]
    fn test_block_iterator_stays_exhausted() {
        let mut iter = BlockIterator::new(2, 2);
        assert_eq!(iter.next(), Some((0, 2)));
        assert_eq!(iter.next(), None);
        // Calling `next` again after exhaustion must keep returning `None`.
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_validate_params() {
        for preset in [
            TilingParams::F32_AVX2,
            TilingParams::F64_AVX2,
            TilingParams::PORTABLE,
        ] {
            assert!(preset.validate().is_ok(), "preset failed validation: {preset:?}");
        }

        // 100 % 8 != 0
        let bad = TilingParams::new(100, 64, 256, 8, 8);
        assert_eq!(bad.validate(), Err("mc must be divisible by mr"));
    }

    #[test]
    fn test_validate_nc_not_divisible() {
        // nc not divisible by nr
        let bad = TilingParams::new(64, 100, 256, 8, 8);
        assert_eq!(bad.validate(), Err("nc must be divisible by nr"));
    }

    #[test]
    fn test_validate_zero_checked_before_divisibility() {
        // A zero micro-tile size must be reported as a zero error even when
        // the block size would also fail the divisibility check.
        let bad = TilingParams::new(100, 64, 256, 0, 8);
        assert_eq!(bad.validate(), Err("mr and nr must be non-zero"));
        // Micro-tile zeros take precedence over block-size zeros.
        let all_zero = TilingParams::new(0, 0, 0, 0, 0);
        assert_eq!(all_zero.validate(), Err("mr and nr must be non-zero"));
    }

    #[test]
    fn test_validate_mr_zero() {
        let bad = TilingParams::new(64, 64, 256, 0, 8);
        assert_eq!(bad.validate(), Err("mr and nr must be non-zero"));
    }

    #[test]
    fn test_validate_nr_zero() {
        let bad = TilingParams::new(64, 64, 256, 8, 0);
        assert_eq!(bad.validate(), Err("mr and nr must be non-zero"));
    }

    #[test]
    fn test_validate_mc_zero() {
        let bad = TilingParams::new(0, 64, 256, 8, 8);
        assert_eq!(bad.validate(), Err("mc, nc, and kc must be non-zero"));
    }

    #[test]
    fn test_validate_nc_zero() {
        let bad = TilingParams::new(64, 0, 256, 8, 8);
        assert_eq!(bad.validate(), Err("mc, nc, and kc must be non-zero"));
    }

    #[test]
    fn test_validate_kc_zero() {
        let bad = TilingParams::new(64, 64, 0, 8, 8);
        assert_eq!(bad.validate(), Err("mc, nc, and kc must be non-zero"));
    }

    #[test]
    fn test_default() {
        // Compare all five fields at once as a tuple.
        let fields = |p: TilingParams| (p.mc, p.nc, p.kc, p.mr, p.nr);
        assert_eq!(fields(TilingParams::default()), fields(TilingParams::PORTABLE));
    }

    #[test]
    fn test_debug() {
        let params = TilingParams::PORTABLE;
        let debug_str = format!("{:?}", params);
        assert!(debug_str.contains("TilingParams"));
    }

    #[test]
    fn test_clone() {
        // `TilingParams` is `Copy`, so plain assignment duplicates it.
        let params = TilingParams::F32_AVX2;
        let cloned = params;
        assert_eq!(params.mc, cloned.mc);
    }
}