omeinsum/tensor/mod.rs
//! Stride-based tensor type with zero-copy views.
//!
//! The [`Tensor`] type supports:
//! - Zero-copy `permute` and `reshape` operations
//! - Automatic contiguous copy when needed for GEMM
//! - Generic over algebra and backend

mod ops;
mod view;

use std::sync::Arc;

use crate::algebra::{Algebra, Scalar};
use crate::backend::{Backend, Storage};

pub use view::TensorView;

/// A multi-dimensional tensor with stride-based layout.
///
/// Tensors support zero-copy view operations (permute, reshape) and
/// automatically make data contiguous when needed for operations like GEMM.
///
/// # Type Parameters
///
/// * `T` - The scalar element type (`f32`, `f64`, etc.)
/// * `B` - The backend type (`Cpu`, `Cuda`)
///
/// # Example
///
/// ```rust
/// use omeinsum::{Tensor, Cpu};
///
/// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
/// let b = a.permute(&[1, 0]); // Zero-copy transpose
/// let c = b.contiguous(); // Make contiguous copy
/// ```
#[derive(Clone)]
pub struct Tensor<T: Scalar, B: Backend> {
    /// Shared storage (reference counted)
    storage: Arc<B::Storage<T>>,

    /// Shape of this view
    shape: Vec<usize>,

    /// Strides for each dimension (in elements)
    strides: Vec<usize>,

    /// Offset into storage
    offset: usize,

    /// Backend instance
    backend: B,
}

impl<T: Scalar, B: Backend> Tensor<T, B> {
    // ========================================================================
    // Constructors
    // ========================================================================

    /// Create a tensor from data with the given shape.
    ///
    /// Data is assumed to be in column-major (Fortran) order.
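    ///
    /// # Example
    ///
    /// In column-major order the first axis varies fastest, so the data
    /// `[1, 2, 3, 4, 5, 6]` with shape `[2, 3]` is the matrix
    /// `[[1, 3, 5], [2, 4, 6]]`:
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// assert_eq!(t.shape(), &[2, 3]);
    /// assert_eq!(t.strides(), &[1, 2]); // column-major strides
    /// assert_eq!(t.get(1), 2.0); // linear index 1 = row 1, column 0
    /// ```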
    pub fn from_data(data: &[T], shape: &[usize]) -> Self
    where
        B: Default,
    {
        Self::from_data_with_backend(data, shape, B::default())
    }

    /// Create a tensor from data with an explicit backend.
    pub fn from_data_with_backend(data: &[T], shape: &[usize], backend: B) -> Self {
        let numel: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            numel,
            "Data length {} doesn't match shape {:?} (expected {})",
            data.len(),
            shape,
            numel
        );

        let storage = backend.from_slice(data);
        let strides = compute_contiguous_strides(shape);

        Self {
            storage: Arc::new(storage),
            shape: shape.to_vec(),
            strides,
            offset: 0,
            backend,
        }
    }

    /// Create a zero-filled tensor.
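    ///
    /// # Example
    ///
    /// All elements start at zero:
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::zeros(&[2, 3]);
    /// assert_eq!(t.numel(), 6);
    /// assert!(t.to_vec().iter().all(|&x| x == 0.0));
    /// ```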
    pub fn zeros(shape: &[usize]) -> Self
    where
        B: Default,
    {
        Self::zeros_with_backend(shape, B::default())
    }

    /// Create a zero-filled tensor with an explicit backend.
    pub fn zeros_with_backend(shape: &[usize], backend: B) -> Self {
        let numel: usize = shape.iter().product();
        let storage = backend.alloc(numel);
        let strides = compute_contiguous_strides(shape);

        Self {
            storage: Arc::new(storage),
            shape: shape.to_vec(),
            strides,
            offset: 0,
            backend,
        }
    }

    /// Create a tensor from existing storage with the given shape.
    ///
    /// The storage must be contiguous and have exactly `shape.iter().product()` elements.
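    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `Backend` trait (which provides
    /// `from_slice`) is publicly importable from the crate's `backend` module:
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    /// use omeinsum::backend::Backend;
    ///
    /// let backend = Cpu::default();
    /// let storage = backend.from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
    /// let t = Tensor::from_storage(storage, &[2, 2], backend);
    /// assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    /// ```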
    pub fn from_storage(storage: B::Storage<T>, shape: &[usize], backend: B) -> Self {
        let numel: usize = shape.iter().product();
        assert_eq!(
            storage.len(),
            numel,
            "Storage length {} doesn't match shape {:?} (expected {})",
            storage.len(),
            shape,
            numel
        );

        let strides = compute_contiguous_strides(shape);

        Self {
            storage: Arc::new(storage),
            shape: shape.to_vec(),
            strides,
            offset: 0,
            backend,
        }
    }

    /// Get a reference to the underlying storage.
    ///
    /// Returns `Some(&storage)` only if the tensor is contiguous and has no offset.
    /// For non-contiguous tensors, call `contiguous()` first.
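    ///
    /// # Example
    ///
    /// `Some` for a freshly created tensor, `None` for a transposed view:
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// assert!(t.storage().is_some());
    /// assert!(t.permute(&[1, 0]).storage().is_none());
    /// ```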
    pub fn storage(&self) -> Option<&B::Storage<T>> {
        if self.is_contiguous() {
            Some(self.storage.as_ref())
        } else {
            None
        }
    }

    // ========================================================================
    // Metadata
    // ========================================================================

    /// Get the shape of the tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Get the strides of the tensor.
    #[inline]
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Get the number of dimensions.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Get the total number of elements.
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Get the backend.
    #[inline]
    pub fn backend(&self) -> &B {
        &self.backend
    }

    /// Check if the tensor is contiguous in memory (column-major).
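    ///
    /// # Example
    ///
    /// A permuted view keeps the original strides, so it is no longer contiguous:
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// assert!(t.is_contiguous());
    /// let p = t.permute(&[1, 0]);
    /// assert!(!p.is_contiguous());
    /// assert!(p.contiguous().is_contiguous());
    /// ```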
    pub fn is_contiguous(&self) -> bool {
        if self.offset != 0 {
            return false;
        }
        let expected = compute_contiguous_strides(&self.shape);
        self.strides == expected
    }

    // ========================================================================
    // Data Access
    // ========================================================================

    /// Copy all elements into a `Vec`, in column-major order.
    pub fn to_vec(&self) -> Vec<T> {
        if self.is_contiguous() {
            self.storage.to_vec()
        } else {
            self.contiguous().storage.to_vec()
        }
    }

    /// Borrow the underlying data as a slice (only if contiguous).
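    ///
    /// # Example
    ///
    /// A sketch, assuming the CPU storage implements `AsRef<[T]>` as the
    /// bound below requires:
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// assert_eq!(t.as_slice(), Some(&[1.0, 2.0, 3.0, 4.0][..]));
    /// assert_eq!(t.permute(&[1, 0]).as_slice(), None);
    /// ```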
    pub fn as_slice(&self) -> Option<&[T]>
    where
        B::Storage<T>: AsRef<[T]>,
    {
        if self.is_contiguous() {
            Some(self.storage.as_ref().as_ref())
        } else {
            None
        }
    }

    /// Get element at linear index (column-major).
    ///
    /// This is an O(ndim) operation that directly accesses storage without
    /// allocating memory. The linear index is interpreted in column-major order.
    ///
    /// # Arguments
    ///
    /// * `index` - Linear index into the flattened tensor (column-major order)
    ///
    /// # Panics
    ///
    /// Panics if index is out of bounds.
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// assert_eq!(t.get(0), 1.0);
    /// assert_eq!(t.get(3), 4.0);
    /// ```
    pub fn get(&self, index: usize) -> T {
        let numel = self.numel();
        assert!(
            index < numel,
            "Index {} out of bounds for tensor with {} elements (shape {:?})",
            index,
            numel,
            self.shape
        );

        // Convert linear index to multi-dimensional coordinates (column-major)
        // Column-major: first dimension varies fastest
        let mut remaining = index;
        let mut storage_offset = self.offset;

        for dim in 0..self.ndim() {
            let coord = remaining % self.shape[dim];
            remaining /= self.shape[dim];
            storage_offset += coord * self.strides[dim];
        }

        self.storage.get(storage_offset)
    }

    // ========================================================================
    // View Operations (zero-copy)
    // ========================================================================

    /// Permute dimensions (zero-copy).
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
    /// let a = Tensor::<f32, Cpu>::from_data(&data, &[2, 3, 4]);
    /// let b = a.permute(&[2, 0, 1]); // Shape becomes [4, 2, 3]
    /// assert_eq!(b.shape(), &[4, 2, 3]);
    /// ```
    pub fn permute(&self, axes: &[usize]) -> Self {
        assert_eq!(
            axes.len(),
            self.ndim(),
            "Permutation axes length {} doesn't match ndim {}",
            axes.len(),
            self.ndim()
        );

        // Check axes are valid and unique
        let mut seen = vec![false; self.ndim()];
        for &ax in axes {
            assert!(
                ax < self.ndim(),
                "Axis {} out of range for ndim {}",
                ax,
                self.ndim()
            );
            assert!(!seen[ax], "Duplicate axis {} in permutation", ax);
            seen[ax] = true;
        }

        let new_shape: Vec<usize> = axes.iter().map(|&i| self.shape[i]).collect();
        let new_strides: Vec<usize> = axes.iter().map(|&i| self.strides[i]).collect();

        Self {
            storage: Arc::clone(&self.storage),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
            backend: self.backend.clone(),
        }
    }

    /// Transpose (2D shorthand for permute).
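    ///
    /// # Panics
    ///
    /// Panics if the tensor is not 2D.
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// let b = a.t(); // same as a.permute(&[1, 0])
    /// assert_eq!(b.shape(), &[3, 2]);
    /// ```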
    pub fn t(&self) -> Self {
        assert_eq!(
            self.ndim(),
            2,
            "transpose requires 2D tensor, got {}D",
            self.ndim()
        );
        self.permute(&[1, 0])
    }

    /// Reshape to a new shape (zero-copy if contiguous).
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// let b = a.reshape(&[6]); // Flatten
    /// let c = a.reshape(&[3, 2]); // Different shape, same data
    /// assert_eq!(b.shape(), &[6]);
    /// assert_eq!(c.shape(), &[3, 2]);
    /// ```
    pub fn reshape(&self, new_shape: &[usize]) -> Self {
        let old_numel: usize = self.shape.iter().product();
        let new_numel: usize = new_shape.iter().product();
        assert_eq!(
            old_numel, new_numel,
            "Cannot reshape from {:?} ({} elements) to {:?} ({} elements)",
            self.shape, old_numel, new_shape, new_numel
        );

        if self.is_contiguous() {
            // Fast path: just update shape and strides
            Self {
                storage: Arc::clone(&self.storage),
                shape: new_shape.to_vec(),
                strides: compute_contiguous_strides(new_shape),
                offset: self.offset,
                backend: self.backend.clone(),
            }
        } else {
            // Must make contiguous first
            self.contiguous().reshape(new_shape)
        }
    }

    /// Make tensor contiguous in memory.
    ///
    /// If already contiguous, returns a clone (shared storage).
    /// Otherwise, copies data to a new contiguous buffer.
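    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// let p = t.permute(&[1, 0]); // non-contiguous view
    /// let c = p.contiguous();
    /// assert!(c.is_contiguous());
    /// // Transposed data, re-laid out in column-major order:
    /// assert_eq!(c.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    /// ```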
    pub fn contiguous(&self) -> Self {
        if self.is_contiguous() {
            self.clone()
        } else {
            let storage =
                self.backend
                    .copy_strided(&self.storage, &self.shape, &self.strides, self.offset);
            Self {
                storage: Arc::new(storage),
                shape: self.shape.clone(),
                strides: compute_contiguous_strides(&self.shape),
                offset: 0,
                backend: self.backend.clone(),
            }
        }
    }

    // ========================================================================
    // Reduction Operations
    // ========================================================================

    /// Sum all elements using the algebra's addition.
    ///
    /// # Type Parameters
    ///
    /// * `A` - The algebra to use for summation
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu, Standard};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// let sum = t.sum::<Standard<f32>>();
    /// assert_eq!(sum, 10.0);
    /// ```
    pub fn sum<A: Algebra<Scalar = T>>(&self) -> T {
        let data = self.to_vec();
        let mut acc = A::zero();
        for val in data {
            acc = acc.add(A::from_scalar(val));
        }
        acc.to_scalar()
    }

    /// Sum along a specific axis using the algebra's addition.
    ///
    /// The result has one fewer dimension than the input.
    ///
    /// # Arguments
    ///
    /// * `axis` - The axis to sum over
    ///
    /// # Panics
    ///
    /// Panics if axis is out of bounds.
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu, Standard};
    ///
    /// // Column-major: data [1, 2, 3, 4] with shape [2, 2] represents:
    /// // [[1, 3],
    /// //  [2, 4]]
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// // Sum over axis 1 (columns): [1+3, 2+4] = [4, 6]
    /// let result = t.sum_axis::<Standard<f32>>(1);
    /// assert_eq!(result.to_vec(), vec![4.0, 6.0]);
    /// ```
    pub fn sum_axis<A: Algebra<Scalar = T>>(&self, axis: usize) -> Self
    where
        B: Default,
    {
        assert!(
            axis < self.ndim(),
            "Axis {} out of bounds for {}D tensor",
            axis,
            self.ndim()
        );

        let mut new_shape: Vec<usize> = self.shape.clone();
        new_shape.remove(axis);

        // Handle reduction to scalar
        if new_shape.is_empty() {
            let sum = self.sum::<A>();
            return Self::from_data(&[sum], &[1]);
        }

        let data = self.to_vec();
        let output_strides = compute_contiguous_strides(&new_shape);

        // Compute output size
        let output_numel: usize = new_shape.iter().product();
        let mut result = vec![A::zero(); output_numel];

        // Iterate over all elements in the input
        for (flat_idx, &val) in data.iter().enumerate() {
            // Convert flat index to multi-dimensional coordinates (column-major)
            let mut coords: Vec<usize> = vec![0; self.ndim()];
            let mut remaining = flat_idx;
            for (dim, coord) in coords.iter_mut().enumerate() {
                *coord = remaining % self.shape[dim];
                remaining /= self.shape[dim];
            }

            // Build output coordinates by removing the summed axis
            let out_coords: Vec<usize> = coords
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != axis)
                .map(|(_, &c)| c)
                .collect();

            // Convert output coordinates to flat index (column-major)
            let mut out_flat_idx = 0;
            for (i, &coord) in out_coords.iter().enumerate() {
                out_flat_idx += coord * output_strides[i];
            }
            result[out_flat_idx] = result[out_flat_idx].add(A::from_scalar(val));
        }

        let result_data: Vec<T> = result.into_iter().map(|v| v.to_scalar()).collect();
        Self::from_data(&result_data, &new_shape)
    }

    /// Extract diagonal elements from a 2D tensor.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is not 2D or not square.
    ///
    /// # Example
    ///
    /// ```rust
    /// use omeinsum::{Tensor, Cpu};
    ///
    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// let diag = t.diagonal();
    /// assert_eq!(diag.to_vec(), vec![1.0, 4.0]);
    /// ```
    pub fn diagonal(&self) -> Self
    where
        B: Default,
    {
        assert_eq!(
            self.ndim(),
            2,
            "diagonal requires 2D tensor, got {}D",
            self.ndim()
        );
        assert_eq!(
            self.shape[0], self.shape[1],
            "diagonal requires square tensor, got {:?}",
            self.shape
        );

        let n = self.shape[0];
        let data = self.to_vec();
        let diag: Vec<T> = (0..n).map(|i| data[i * n + i]).collect();

        Self::from_data(&diag, &[n])
    }
}

/// Compute contiguous strides for column-major (Fortran) layout.
///
/// For shape `[m, n]`, returns strides `[1, m]` (the first dimension is contiguous).
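///
/// # Example
///
/// A quick check, assuming this module is publicly reachable as `omeinsum::tensor`:
///
/// ```rust
/// use omeinsum::tensor::compute_contiguous_strides;
///
/// // Each stride is the product of all earlier dimension sizes.
/// assert_eq!(compute_contiguous_strides(&[2, 3, 4]), vec![1, 2, 6]);
/// assert!(compute_contiguous_strides(&[]).is_empty());
/// ```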
pub fn compute_contiguous_strides(shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![];
    }

    let mut strides = vec![1; shape.len()];
    for i in 1..shape.len() {
        strides[i] = strides[i - 1] * shape[i - 1];
    }
    strides
}

impl<T: Scalar, B: Backend> std::fmt::Debug for Tensor<T, B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tensor")
            .field("shape", &self.shape)
            .field("strides", &self.strides)
            .field("offset", &self.offset)
            .field("contiguous", &self.is_contiguous())
            .field("backend", &B::name())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::Cpu;

    #[test]
    fn test_tensor_creation() {
        // Column-major: data [1,2,3,4,5,6] for shape [2,3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        // Strides for column-major [2, 3] are [1, 2]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]); // Column-major strides
        assert!(t.is_contiguous());
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn test_permute() {
        // Column-major: data [1,2,3,4,5,6] for shape [2,3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let p = t.permute(&[1, 0]);

        assert_eq!(p.shape(), &[3, 2]);
        assert_eq!(p.strides(), &[2, 1]); // Permuted strides
        assert!(!p.is_contiguous());

        // After making contiguous, data should be transposed
        // Transposed matrix in column-major:
        // [[1, 2],
        //  [3, 4],
        //  [5, 6]] -> column-major data: [1, 3, 5, 2, 4, 6]
        let c = p.contiguous();
        assert!(c.is_contiguous());
        assert_eq!(c.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_reshape() {
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let r = t.reshape(&[3, 2]);

        assert_eq!(r.shape(), &[3, 2]);
        assert!(r.is_contiguous());
        assert_eq!(r.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_permute_then_reshape() {
        // Column-major: data [1,2,3,4,5,6] for shape [2,3]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let p = t.permute(&[1, 0]); // [3, 2], non-contiguous
        let r = p.reshape(&[6]); // Must make contiguous first

        assert_eq!(r.shape(), &[6]);
        assert!(r.is_contiguous());
        // Transposed and flattened in column-major: [1, 3, 5, 2, 4, 6]
        assert_eq!(r.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_sum() {
        use crate::algebra::Standard;

        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let sum = t.sum::<Standard<f32>>();
        assert_eq!(sum, 10.0);
    }

    #[test]
    fn test_sum_axis() {
        use crate::algebra::Standard;

        // Column-major: data [1, 2, 3, 4] for shape [2, 2] represents:
        // [[1, 3],
        //  [2, 4]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        // Sum over axis 1 (columns): [1+3, 2+4] = [4, 6]
        let sum_cols = t.sum_axis::<Standard<f32>>(1);
        assert_eq!(sum_cols.shape(), &[2]);
        assert_eq!(sum_cols.to_vec(), vec![4.0, 6.0]);

        // Sum over axis 0 (rows): [1+2, 3+4] = [3, 7]
        let sum_rows = t.sum_axis::<Standard<f32>>(0);
        assert_eq!(sum_rows.shape(), &[2]);
        assert_eq!(sum_rows.to_vec(), vec![3.0, 7.0]);
    }

    #[test]
    fn test_diagonal() {
        // Column-major: data [1, 2, 3, 4] for shape [2, 2] represents:
        // [[1, 3],
        //  [2, 4]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let diag = t.diagonal();

        assert_eq!(diag.shape(), &[2]);
        assert_eq!(diag.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_diagonal_3x3() {
        // Column-major: data [1..=9] for shape [3, 3] represents:
        // [[1, 4, 7],
        //  [2, 5, 8],
        //  [3, 6, 9]]
        let t =
            Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], &[3, 3]);
        let diag = t.diagonal();

        assert_eq!(diag.shape(), &[3]);
        assert_eq!(diag.to_vec(), vec![1.0, 5.0, 9.0]);
    }

    #[test]
    fn test_get() {
        // Column-major: data [1, 2, 3, 4, 5, 6] for shape [2, 3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        // Test accessing each element by linear index
        assert_eq!(t.get(0), 1.0);
        assert_eq!(t.get(1), 2.0);
        assert_eq!(t.get(2), 3.0);
        assert_eq!(t.get(3), 4.0);
        assert_eq!(t.get(4), 5.0);
        assert_eq!(t.get(5), 6.0);
    }

    #[test]
    fn test_get_permuted() {
        // Test that get works correctly on permuted (non-contiguous) tensors
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let p = t.permute(&[1, 0]); // Shape becomes [3, 2], non-contiguous

        // After transpose, column-major data should be [1, 3, 5, 2, 4, 6]
        assert_eq!(p.get(0), 1.0);
        assert_eq!(p.get(1), 3.0);
        assert_eq!(p.get(2), 5.0);
        assert_eq!(p.get(3), 2.0);
        assert_eq!(p.get(4), 4.0);
        assert_eq!(p.get(5), 6.0);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_out_of_bounds() {
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = t.get(4); // Index 4 is out of bounds for 4-element tensor
    }
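
    // Directly exercise the stride helper: in column-major layout, each
    // stride is the product of all earlier dimension sizes.
    #[test]
    fn test_compute_contiguous_strides() {
        assert_eq!(compute_contiguous_strides(&[2, 3, 4]), vec![1, 2, 6]);
        assert_eq!(compute_contiguous_strides(&[5]), vec![1]);
        assert!(compute_contiguous_strides(&[]).is_empty());
    }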

    #[test]
    fn test_get_3d_tensor() {
        // Test get on a 3D tensor to ensure multi-dimensional indexing works
        // Shape [2, 3, 2], 12 elements
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let t = Tensor::<f32, Cpu>::from_data(&data, &[2, 3, 2]);

        // In column-major, elements are ordered by first dim varying fastest
        for i in 0..12 {
            assert_eq!(t.get(i), data[i]);
        }
    }
}
728}