/// A set of five `usize` tiling parameters: three block sizes (`mc`, `nc`,
/// `kc`) and two micro-tile sizes (`mr`, `nr`).
///
/// [`TilingParams::validate`] checks that every field is non-zero and that
/// `mc` / `nc` are whole multiples of `mr` / `nr` respectively.
///
/// NOTE(review): the field names follow the usual cache-blocked
/// matrix-multiply convention (`mc`/`nc`/`kc` = M/N/K block sizes,
/// `mr` x `nr` = register tile) — confirm against the kernels that consume
/// these values.
///
/// `PartialEq`, `Eq` and `Hash` are derived so parameter sets can be compared
/// with `==` / `assert_eq!` and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TilingParams {
    /// Block size that must be a non-zero multiple of `mr`.
    pub mc: usize,
    /// Block size that must be a non-zero multiple of `nr`.
    pub nc: usize,
    /// Block size that must be non-zero; no divisibility constraint is checked.
    pub kc: usize,
    /// Micro-tile size paired with `mc`; must be non-zero.
    pub mr: usize,
    /// Micro-tile size paired with `nc`; must be non-zero.
    pub nr: usize,
}
18
19impl TilingParams {
20 pub const F32_AVX2: Self = Self {
22 mc: 256,
23 nc: 256,
24 kc: 512,
25 mr: 8,
26 nr: 8,
27 };
28
29 pub const F64_AVX2: Self = Self {
31 mc: 128,
32 nc: 128,
33 kc: 256,
34 mr: 4,
35 nr: 4,
36 };
37
38 pub const PORTABLE: Self = Self {
40 mc: 64,
41 nc: 64,
42 kc: 256,
43 mr: 4,
44 nr: 4,
45 };
46
47 pub const fn new(mc: usize, nc: usize, kc: usize, mr: usize, nr: usize) -> Self {
49 Self { mc, nc, kc, mr, nr }
50 }
51
52 pub fn validate(&self) -> Result<(), &'static str> {
54 if self.mr == 0 || self.nr == 0 {
56 return Err("mr and nr must be non-zero");
57 }
58 if self.mc == 0 || self.nc == 0 || self.kc == 0 {
59 return Err("mc, nc, and kc must be non-zero");
60 }
61 if !self.mc.is_multiple_of(self.mr) {
63 return Err("mc must be divisible by mr");
64 }
65 if !self.nc.is_multiple_of(self.nr) {
66 return Err("nc must be divisible by nr");
67 }
68 Ok(())
69 }
70}
71
72impl Default for TilingParams {
73 fn default() -> Self {
74 Self::PORTABLE
75 }
76}
77
/// Iterator that walks the range `0..total` in consecutive blocks of at most
/// `block_size` elements, yielding `(start, len)` pairs.
///
/// `Debug` is derived so the iterator can appear in diagnostics, and `Clone`
/// so a traversal can be restarted from a saved position.
#[derive(Debug, Clone)]
pub struct BlockIterator {
    /// Exclusive upper bound of the range being split.
    total: usize,
    /// Maximum length of each yielded block.
    block_size: usize,
    /// Start offset of the next block to yield.
    current: usize,
}
84
85impl BlockIterator {
86 pub fn new(total: usize, block_size: usize) -> Self {
87 Self {
88 total,
89 block_size,
90 current: 0,
91 }
92 }
93}
94
95impl Iterator for BlockIterator {
96 type Item = (usize, usize);
98
99 fn next(&mut self) -> Option<Self::Item> {
100 if self.current >= self.total {
101 return None;
102 }
103
104 let start = self.current;
105 let len = (self.total - start).min(self.block_size);
106 self.current += len;
107
108 Some((start, len))
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// A total that is not a multiple of the block size ends with a short block.
    #[test]
    fn test_block_iterator() {
        let iter = BlockIterator::new(10, 3);
        let blocks: Vec<_> = iter.collect();

        assert_eq!(blocks, vec![(0, 3), (3, 3), (6, 3), (9, 1)]);
    }

    /// A total that is an exact multiple yields only full blocks.
    #[test]
    fn test_block_iterator_exact() {
        let iter = BlockIterator::new(9, 3);
        let blocks: Vec<_> = iter.collect();

        assert_eq!(blocks, vec![(0, 3), (3, 3), (6, 3)]);
    }

    /// An empty range yields nothing.
    #[test]
    fn test_block_iterator_empty() {
        let iter = BlockIterator::new(0, 3);
        let blocks: Vec<_> = iter.collect();

        assert!(blocks.is_empty());
    }

    /// All presets validate; an `mc` that is not a multiple of `mr` is rejected.
    #[test]
    fn test_validate_params() {
        assert!(TilingParams::F32_AVX2.validate().is_ok());
        assert!(TilingParams::F64_AVX2.validate().is_ok());
        assert!(TilingParams::PORTABLE.validate().is_ok());

        // Pin the exact message, like the sibling tests, so this cannot pass
        // on an unrelated validation failure.
        let bad = TilingParams::new(100, 64, 256, 8, 8);
        assert_eq!(bad.validate(), Err("mc must be divisible by mr"));
    }

    #[test]
    fn test_validate_nc_not_divisible() {
        let bad = TilingParams::new(64, 100, 256, 8, 8);
        assert_eq!(bad.validate(), Err("nc must be divisible by nr"));
    }

    #[test]
    fn test_validate_mr_zero() {
        let bad = TilingParams::new(64, 64, 256, 0, 8);
        assert_eq!(bad.validate(), Err("mr and nr must be non-zero"));
    }

    #[test]
    fn test_validate_nr_zero() {
        let bad = TilingParams::new(64, 64, 256, 8, 0);
        assert_eq!(bad.validate(), Err("mr and nr must be non-zero"));
    }

    #[test]
    fn test_validate_mc_zero() {
        let bad = TilingParams::new(0, 64, 256, 8, 8);
        assert_eq!(bad.validate(), Err("mc, nc, and kc must be non-zero"));
    }

    #[test]
    fn test_validate_nc_zero() {
        let bad = TilingParams::new(64, 0, 256, 8, 8);
        assert_eq!(bad.validate(), Err("mc, nc, and kc must be non-zero"));
    }

    #[test]
    fn test_validate_kc_zero() {
        let bad = TilingParams::new(64, 64, 0, 8, 8);
        assert_eq!(bad.validate(), Err("mc, nc, and kc must be non-zero"));
    }

    /// `Default` matches the portable preset; compared field by field.
    #[test]
    fn test_default() {
        let default = TilingParams::default();
        assert_eq!(default.mc, TilingParams::PORTABLE.mc);
        assert_eq!(default.nc, TilingParams::PORTABLE.nc);
        assert_eq!(default.kc, TilingParams::PORTABLE.kc);
        assert_eq!(default.mr, TilingParams::PORTABLE.mr);
        assert_eq!(default.nr, TilingParams::PORTABLE.nr);
    }

    #[test]
    fn test_debug() {
        let params = TilingParams::PORTABLE;
        let debug_str = format!("{:?}", params);
        assert!(debug_str.contains("TilingParams"));
    }

    /// Exercises both the `Copy` and the `Clone` implementations.
    #[test]
    fn test_clone() {
        let params = TilingParams::F32_AVX2;

        // Plain assignment copies; `params` stays usable afterwards.
        let copied = params;
        // Deliberately calling `clone` on a `Copy` type to exercise the
        // derived `Clone` impl.
        #[allow(clippy::clone_on_copy)]
        let cloned = params.clone();

        assert_eq!(params.mc, copied.mc);
        assert_eq!(params.mc, cloned.mc);
    }
}