omeinsum/tensor/
mod.rs

1//! Stride-based tensor type with zero-copy views.
2//!
3//! The [`Tensor`] type supports:
4//! - Zero-copy `permute` and `reshape` operations
5//! - Automatic contiguous copy when needed for GEMM
6//! - Generic over algebra and backend
7
8mod ops;
9mod view;
10
11use std::sync::Arc;
12
13use crate::algebra::{Algebra, Scalar};
14use crate::backend::{Backend, Storage};
15
16pub use view::TensorView;
17
18/// A multi-dimensional tensor with stride-based layout.
19///
20/// Tensors support zero-copy view operations (permute, reshape) and
21/// automatically make data contiguous when needed for operations like GEMM.
22///
23/// # Type Parameters
24///
25/// * `T` - The scalar element type (f32, f64, etc.)
26/// * `B` - The backend type (Cpu, Cuda)
27///
28/// # Example
29///
30/// ```rust
31/// use omeinsum::{Tensor, Cpu};
32///
33/// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
34/// let b = a.permute(&[1, 0]);  // Zero-copy transpose
35/// let c = b.contiguous();      // Make contiguous copy
36/// ```
#[derive(Clone)]
pub struct Tensor<T: Scalar, B: Backend> {
    /// Shared storage (reference counted). Cloning the tensor or taking a
    /// view (`permute`, contiguous `reshape`) bumps the refcount instead of
    /// copying the data.
    storage: Arc<B::Storage<T>>,

    /// Shape of this view (one extent per dimension).
    shape: Vec<usize>,

    /// Strides for each dimension (in elements, not bytes). Column-major
    /// contiguous strides are produced by `compute_contiguous_strides`.
    strides: Vec<usize>,

    /// Offset into storage (in elements) where this view begins.
    offset: usize,

    /// Backend instance used to allocate and copy storage.
    backend: B,
}
54
55impl<T: Scalar, B: Backend> Tensor<T, B> {
56    // ========================================================================
57    // Constructors
58    // ========================================================================
59
60    /// Create a tensor from data with the given shape.
61    ///
62    /// Data is assumed to be in column-major (Fortran) order.
63    pub fn from_data(data: &[T], shape: &[usize]) -> Self
64    where
65        B: Default,
66    {
67        Self::from_data_with_backend(data, shape, B::default())
68    }
69
70    /// Create a tensor from data with explicit backend.
71    pub fn from_data_with_backend(data: &[T], shape: &[usize], backend: B) -> Self {
72        let numel: usize = shape.iter().product();
73        assert_eq!(
74            data.len(),
75            numel,
76            "Data length {} doesn't match shape {:?} (expected {})",
77            data.len(),
78            shape,
79            numel
80        );
81
82        let storage = backend.from_slice(data);
83        let strides = compute_contiguous_strides(shape);
84
85        Self {
86            storage: Arc::new(storage),
87            shape: shape.to_vec(),
88            strides,
89            offset: 0,
90            backend,
91        }
92    }
93
94    /// Create a zero-filled tensor.
95    pub fn zeros(shape: &[usize]) -> Self
96    where
97        B: Default,
98    {
99        Self::zeros_with_backend(shape, B::default())
100    }
101
102    /// Create a zero-filled tensor with explicit backend.
103    pub fn zeros_with_backend(shape: &[usize], backend: B) -> Self {
104        let numel: usize = shape.iter().product();
105        let storage = backend.alloc(numel);
106        let strides = compute_contiguous_strides(shape);
107
108        Self {
109            storage: Arc::new(storage),
110            shape: shape.to_vec(),
111            strides,
112            offset: 0,
113            backend,
114        }
115    }
116
117    /// Create a tensor from storage with given shape.
118    ///
119    /// The storage must be contiguous and have exactly `shape.iter().product()` elements.
120    pub fn from_storage(storage: B::Storage<T>, shape: &[usize], backend: B) -> Self {
121        let numel: usize = shape.iter().product();
122        assert_eq!(
123            storage.len(),
124            numel,
125            "Storage length {} doesn't match shape {:?} (expected {})",
126            storage.len(),
127            shape,
128            numel
129        );
130
131        let strides = compute_contiguous_strides(shape);
132
133        Self {
134            storage: Arc::new(storage),
135            shape: shape.to_vec(),
136            strides,
137            offset: 0,
138            backend,
139        }
140    }
141
142    /// Get a reference to the underlying storage.
143    ///
144    /// Returns `Some(&storage)` only if the tensor is contiguous and has no offset.
145    /// For non-contiguous tensors, call `contiguous()` first.
146    pub fn storage(&self) -> Option<&B::Storage<T>> {
147        if self.is_contiguous() {
148            Some(self.storage.as_ref())
149        } else {
150            None
151        }
152    }
153
154    // ========================================================================
155    // Metadata
156    // ========================================================================
157
158    /// Get the shape of the tensor.
159    #[inline]
160    pub fn shape(&self) -> &[usize] {
161        &self.shape
162    }
163
164    /// Get the strides of the tensor.
165    #[inline]
166    pub fn strides(&self) -> &[usize] {
167        &self.strides
168    }
169
170    /// Get the number of dimensions.
171    #[inline]
172    pub fn ndim(&self) -> usize {
173        self.shape.len()
174    }
175
176    /// Get the total number of elements.
177    #[inline]
178    pub fn numel(&self) -> usize {
179        self.shape.iter().product()
180    }
181
182    /// Get the backend.
183    #[inline]
184    pub fn backend(&self) -> &B {
185        &self.backend
186    }
187
    /// Check if the tensor is contiguous in memory (column-major / Fortran
    /// order, matching `compute_contiguous_strides`).
    ///
    /// Note: a view with a nonzero storage offset is reported as
    /// non-contiguous even if its strides are dense.
    pub fn is_contiguous(&self) -> bool {
        if self.offset != 0 {
            return false;
        }
        let expected = compute_contiguous_strides(&self.shape);
        self.strides == expected
    }
196
197    // ========================================================================
198    // Data Access
199    // ========================================================================
200
201    /// Copy all data to a Vec.
202    pub fn to_vec(&self) -> Vec<T> {
203        if self.is_contiguous() {
204            self.storage.to_vec()
205        } else {
206            self.contiguous().storage.to_vec()
207        }
208    }
209
210    /// Get underlying storage (only if contiguous).
211    pub fn as_slice(&self) -> Option<&[T]>
212    where
213        B::Storage<T>: AsRef<[T]>,
214    {
215        if self.is_contiguous() {
216            Some(self.storage.as_ref().as_ref())
217        } else {
218            None
219        }
220    }
221
222    /// Get element at linear index (column-major).
223    ///
224    /// This is an O(ndim) operation that directly accesses storage without
225    /// allocating memory. The linear index is interpreted in column-major order.
226    ///
227    /// # Arguments
228    ///
229    /// * `index` - Linear index into the flattened tensor (column-major order)
230    ///
231    /// # Panics
232    ///
233    /// Panics if index is out of bounds.
234    ///
235    /// # Example
236    ///
237    /// ```rust
238    /// use omeinsum::{Tensor, Cpu};
239    ///
240    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
241    /// assert_eq!(t.get(0), 1.0);
242    /// assert_eq!(t.get(3), 4.0);
243    /// ```
244    pub fn get(&self, index: usize) -> T {
245        let numel = self.numel();
246        assert!(
247            index < numel,
248            "Index {} out of bounds for tensor with {} elements (shape {:?})",
249            index,
250            numel,
251            self.shape
252        );
253
254        // Convert linear index to multi-dimensional coordinates (column-major)
255        // Column-major: first dimension varies fastest
256        let mut remaining = index;
257        let mut storage_offset = self.offset;
258
259        for dim in 0..self.ndim() {
260            let coord = remaining % self.shape[dim];
261            remaining /= self.shape[dim];
262            storage_offset += coord * self.strides[dim];
263        }
264
265        self.storage.get(storage_offset)
266    }
267
268    // ========================================================================
269    // View Operations (zero-copy)
270    // ========================================================================
271
272    /// Permute dimensions (zero-copy).
273    ///
274    /// # Example
275    ///
276    /// ```rust
277    /// use omeinsum::{Tensor, Cpu};
278    ///
279    /// let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
280    /// let a = Tensor::<f32, Cpu>::from_data(&data, &[2, 3, 4]);
281    /// let b = a.permute(&[2, 0, 1]);  // Shape becomes [4, 2, 3]
282    /// assert_eq!(b.shape(), &[4, 2, 3]);
283    /// ```
284    pub fn permute(&self, axes: &[usize]) -> Self {
285        assert_eq!(
286            axes.len(),
287            self.ndim(),
288            "Permutation axes length {} doesn't match ndim {}",
289            axes.len(),
290            self.ndim()
291        );
292
293        // Check axes are valid and unique
294        let mut seen = vec![false; self.ndim()];
295        for &ax in axes {
296            assert!(
297                ax < self.ndim(),
298                "Axis {} out of range for ndim {}",
299                ax,
300                self.ndim()
301            );
302            assert!(!seen[ax], "Duplicate axis {} in permutation", ax);
303            seen[ax] = true;
304        }
305
306        let new_shape: Vec<usize> = axes.iter().map(|&i| self.shape[i]).collect();
307        let new_strides: Vec<usize> = axes.iter().map(|&i| self.strides[i]).collect();
308
309        Self {
310            storage: Arc::clone(&self.storage),
311            shape: new_shape,
312            strides: new_strides,
313            offset: self.offset,
314            backend: self.backend.clone(),
315        }
316    }
317
318    /// Transpose (2D shorthand for permute).
319    pub fn t(&self) -> Self {
320        assert_eq!(
321            self.ndim(),
322            2,
323            "transpose requires 2D tensor, got {}D",
324            self.ndim()
325        );
326        self.permute(&[1, 0])
327    }
328
329    /// Reshape to a new shape (zero-copy if contiguous).
330    ///
331    /// # Example
332    ///
333    /// ```rust
334    /// use omeinsum::{Tensor, Cpu};
335    ///
336    /// let a = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
337    /// let b = a.reshape(&[6]);      // Flatten
338    /// let c = a.reshape(&[3, 2]);   // Different shape, same data
339    /// assert_eq!(b.shape(), &[6]);
340    /// assert_eq!(c.shape(), &[3, 2]);
341    /// ```
342    pub fn reshape(&self, new_shape: &[usize]) -> Self {
343        let old_numel: usize = self.shape.iter().product();
344        let new_numel: usize = new_shape.iter().product();
345        assert_eq!(
346            old_numel, new_numel,
347            "Cannot reshape from {:?} ({} elements) to {:?} ({} elements)",
348            self.shape, old_numel, new_shape, new_numel
349        );
350
351        if self.is_contiguous() {
352            // Fast path: just update shape and strides
353            Self {
354                storage: Arc::clone(&self.storage),
355                shape: new_shape.to_vec(),
356                strides: compute_contiguous_strides(new_shape),
357                offset: self.offset,
358                backend: self.backend.clone(),
359            }
360        } else {
361            // Must make contiguous first
362            self.contiguous().reshape(new_shape)
363        }
364    }
365
366    /// Make tensor contiguous in memory.
367    ///
368    /// If already contiguous, returns a clone (shared storage).
369    /// Otherwise, copies data to a new contiguous buffer.
370    pub fn contiguous(&self) -> Self {
371        if self.is_contiguous() {
372            self.clone()
373        } else {
374            let storage =
375                self.backend
376                    .copy_strided(&self.storage, &self.shape, &self.strides, self.offset);
377            Self {
378                storage: Arc::new(storage),
379                shape: self.shape.clone(),
380                strides: compute_contiguous_strides(&self.shape),
381                offset: 0,
382                backend: self.backend.clone(),
383            }
384        }
385    }
386
387    // ========================================================================
388    // Reduction Operations
389    // ========================================================================
390
391    /// Sum all elements using the algebra's addition.
392    ///
393    /// # Type Parameters
394    ///
395    /// * `A` - The algebra to use for summation
396    ///
397    /// # Example
398    ///
399    /// ```rust
400    /// use omeinsum::{Tensor, Cpu, Standard};
401    ///
402    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
403    /// let sum = t.sum::<Standard<f32>>();
404    /// assert_eq!(sum, 10.0);
405    /// ```
406    pub fn sum<A: Algebra<Scalar = T>>(&self) -> T {
407        let data = self.to_vec();
408        let mut acc = A::zero();
409        for val in data {
410            acc = acc.add(A::from_scalar(val));
411        }
412        acc.to_scalar()
413    }
414
415    /// Sum along a specific axis using the algebra's addition.
416    ///
417    /// The result has one fewer dimension than the input.
418    ///
419    /// # Arguments
420    ///
421    /// * `axis` - The axis to sum over
422    ///
423    /// # Panics
424    ///
425    /// Panics if axis is out of bounds.
426    ///
427    /// # Example
428    ///
429    /// ```rust
430    /// use omeinsum::{Tensor, Cpu, Standard};
431    ///
432    /// // Column-major: data [1, 2, 3, 4] with shape [2, 2] represents:
433    /// // [[1, 3],
434    /// //  [2, 4]]
435    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
436    /// // Sum over axis 1 (columns): [1+3, 2+4] = [4, 6]
437    /// let result = t.sum_axis::<Standard<f32>>(1);
438    /// assert_eq!(result.to_vec(), vec![4.0, 6.0]);
439    /// ```
440    pub fn sum_axis<A: Algebra<Scalar = T>>(&self, axis: usize) -> Self
441    where
442        B: Default,
443    {
444        assert!(
445            axis < self.ndim(),
446            "Axis {} out of bounds for {}D tensor",
447            axis,
448            self.ndim()
449        );
450
451        let mut new_shape: Vec<usize> = self.shape.clone();
452        new_shape.remove(axis);
453
454        // Handle reduction to scalar
455        if new_shape.is_empty() {
456            let sum = self.sum::<A>();
457            return Self::from_data(&[sum], &[1]);
458        }
459
460        let data = self.to_vec();
461        let output_strides = compute_contiguous_strides(&new_shape);
462        let _axis_size = self.shape[axis];
463
464        // Compute output size
465        let output_numel: usize = new_shape.iter().product();
466        let mut result = vec![A::zero(); output_numel];
467
468        // Iterate over all elements in the input
469        for (flat_idx, &val) in data.iter().enumerate() {
470            // Convert flat index to multi-dimensional coordinates (column-major)
471            let mut coords: Vec<usize> = vec![0; self.ndim()];
472            let mut remaining = flat_idx;
473            for (dim, coord) in coords.iter_mut().enumerate() {
474                *coord = remaining % self.shape[dim];
475                remaining /= self.shape[dim];
476            }
477
478            // Build output coordinates by removing the summed axis
479            let out_coords: Vec<usize> = coords
480                .iter()
481                .enumerate()
482                .filter(|(i, _)| *i != axis)
483                .map(|(_, &c)| c)
484                .collect();
485
486            // Convert output coordinates to flat index (column-major)
487            let mut out_flat_idx = 0;
488            for (i, &coord) in out_coords.iter().enumerate() {
489                out_flat_idx += coord * output_strides[i];
490            }
491
492            result[out_flat_idx] = result[out_flat_idx].add(A::from_scalar(val));
493        }
494
495        let result_data: Vec<T> = result.into_iter().map(|v| v.to_scalar()).collect();
496        Self::from_data(&result_data, &new_shape)
497    }
498
499    /// Extract diagonal elements from a 2D tensor.
500    ///
501    /// # Panics
502    ///
503    /// Panics if the tensor is not 2D or not square.
504    ///
505    /// # Example
506    ///
507    /// ```rust
508    /// use omeinsum::{Tensor, Cpu};
509    ///
510    /// let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
511    /// let diag = t.diagonal();
512    /// assert_eq!(diag.to_vec(), vec![1.0, 4.0]);
513    /// ```
514    pub fn diagonal(&self) -> Self
515    where
516        B: Default,
517    {
518        assert_eq!(
519            self.ndim(),
520            2,
521            "diagonal requires 2D tensor, got {}D",
522            self.ndim()
523        );
524        assert_eq!(
525            self.shape[0], self.shape[1],
526            "diagonal requires square tensor, got {:?}",
527            self.shape
528        );
529
530        let n = self.shape[0];
531        let data = self.to_vec();
532        let diag: Vec<T> = (0..n).map(|i| data[i * n + i]).collect();
533
534        Self::from_data(&diag, &[n])
535    }
536}
537
/// Compute contiguous strides for column-major (Fortran) layout.
///
/// For shape `[m, n]`, returns strides `[1, m]` (the first dimension is the
/// fastest-varying, i.e. contiguous, one). An empty shape yields an empty
/// stride vector.
pub fn compute_contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut step = 1;
    for &extent in shape {
        strides.push(step);
        step *= extent;
    }
    strides
}
552
// Manual Debug impl: the storage contents are intentionally omitted (they may
// be large or live on a device); only layout metadata and the backend name
// are printed.
impl<T: Scalar, B: Backend> std::fmt::Debug for Tensor<T, B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tensor")
            .field("shape", &self.shape)
            .field("strides", &self.strides)
            .field("offset", &self.offset)
            .field("contiguous", &self.is_contiguous())
            .field("backend", &B::name())
            .finish()
    }
}
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::Cpu;

    // All tests assume the column-major (Fortran) element order used
    // throughout this module: the first dimension varies fastest.

    #[test]
    fn test_tensor_creation() {
        // Column-major: data [1,2,3,4,5,6] for shape [2,3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        // Strides for column-major [2, 3] are [1, 2]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]); // Column-major strides
        assert!(t.is_contiguous());
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn test_permute() {
        // Column-major: data [1,2,3,4,5,6] for shape [2,3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let p = t.permute(&[1, 0]);

        assert_eq!(p.shape(), &[3, 2]);
        assert_eq!(p.strides(), &[2, 1]); // Permuted strides
        assert!(!p.is_contiguous());

        // After making contiguous, data should be transposed
        // Transposed matrix in column-major:
        // [[1, 2],
        //  [3, 4],
        //  [5, 6]] -> column-major data: [1, 3, 5, 2, 4, 6]
        let c = p.contiguous();
        assert!(c.is_contiguous());
        assert_eq!(c.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_reshape() {
        // Reshape of a contiguous tensor keeps the flat data order unchanged.
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let r = t.reshape(&[3, 2]);

        assert_eq!(r.shape(), &[3, 2]);
        assert!(r.is_contiguous());
        assert_eq!(r.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_permute_then_reshape() {
        // Column-major: data [1,2,3,4,5,6] for shape [2,3]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let p = t.permute(&[1, 0]); // [3, 2], non-contiguous
        let r = p.reshape(&[6]); // Must make contiguous first

        assert_eq!(r.shape(), &[6]);
        assert!(r.is_contiguous());
        // Transposed and flattened in column-major: [1, 3, 5, 2, 4, 6]
        assert_eq!(r.to_vec(), vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_sum() {
        use crate::algebra::Standard;

        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let sum = t.sum::<Standard<f32>>();
        assert_eq!(sum, 10.0);
    }

    #[test]
    fn test_sum_axis() {
        use crate::algebra::Standard;

        // Column-major: data [1, 2, 3, 4] for shape [2, 2] represents:
        // [[1, 3],
        //  [2, 4]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        // Sum over axis 1 (columns): [1+3, 2+4] = [4, 6]
        let sum_cols = t.sum_axis::<Standard<f32>>(1);
        assert_eq!(sum_cols.shape(), &[2]);
        assert_eq!(sum_cols.to_vec(), vec![4.0, 6.0]);

        // Sum over axis 0 (rows): [1+2, 3+4] = [3, 7]
        let sum_rows = t.sum_axis::<Standard<f32>>(0);
        assert_eq!(sum_rows.shape(), &[2]);
        assert_eq!(sum_rows.to_vec(), vec![3.0, 7.0]);
    }

    #[test]
    fn test_diagonal() {
        // Column-major: data [1, 2, 3, 4] for shape [2, 2] is the matrix
        // [[1, 3],
        //  [2, 4]] — its diagonal is [1, 4].
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let diag = t.diagonal();

        assert_eq!(diag.shape(), &[2]);
        assert_eq!(diag.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_diagonal_3x3() {
        // Column-major: data 1..=9 for shape [3, 3] is the matrix
        // [[1, 4, 7],
        //  [2, 5, 8],
        //  [3, 6, 9]] — its diagonal is [1, 5, 9].
        let t =
            Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], &[3, 3]);
        let diag = t.diagonal();

        assert_eq!(diag.shape(), &[3]);
        assert_eq!(diag.to_vec(), vec![1.0, 5.0, 9.0]);
    }

    #[test]
    fn test_get() {
        // Column-major: data [1, 2, 3, 4, 5, 6] for shape [2, 3] represents:
        // [[1, 3, 5],
        //  [2, 4, 6]]
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        // Test accessing each element by linear index
        assert_eq!(t.get(0), 1.0);
        assert_eq!(t.get(1), 2.0);
        assert_eq!(t.get(2), 3.0);
        assert_eq!(t.get(3), 4.0);
        assert_eq!(t.get(4), 5.0);
        assert_eq!(t.get(5), 6.0);
    }

    #[test]
    fn test_get_permuted() {
        // Test that get works correctly on permuted (non-contiguous) tensors
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let p = t.permute(&[1, 0]); // Shape becomes [3, 2], non-contiguous

        // After transpose, column-major data should be [1, 3, 5, 2, 4, 6]
        assert_eq!(p.get(0), 1.0);
        assert_eq!(p.get(1), 3.0);
        assert_eq!(p.get(2), 5.0);
        assert_eq!(p.get(3), 2.0);
        assert_eq!(p.get(4), 4.0);
        assert_eq!(p.get(5), 6.0);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_get_out_of_bounds() {
        let t = Tensor::<f32, Cpu>::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = t.get(4); // Index 4 is out of bounds for 4-element tensor
    }

    #[test]
    fn test_get_3d_tensor() {
        // Test get on a 3D tensor to ensure multi-dimensional indexing works
        // Shape [2, 3, 2], 12 elements
        let data: Vec<f32> = (1..=12).map(|x| x as f32).collect();
        let t = Tensor::<f32, Cpu>::from_data(&data, &[2, 3, 2]);

        // In column-major, elements are ordered by first dim varying fastest
        for i in 0..12 {
            assert_eq!(t.get(i), data[i]);
        }
    }
}