File size: 1,557 Bytes
1c72248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import random
import numpy as np

def shuffle_tensor_along_axis(tensor, axis=0, seed=None):
    """
    Shuffle a tensor along a specified axis without affecting any global random state.

    The permutation is drawn from a private ``torch.Generator`` instead of the
    global one. The default CPU and CUDA generators, and the NumPy and Python
    ``random`` states, are never read or modified.

    Args:
        tensor (torch.Tensor): The input tensor to shuffle. It is not modified.
        axis (int, optional): The axis along which to shuffle; negative values
            count from the end. Defaults to 0.
        seed (int, optional): Random seed for reproducibility. When None, a
            fresh non-deterministic seed is used, so repeated calls produce
            different permutations. Defaults to None.

    Returns:
        torch.Tensor: A new tensor whose slices along ``axis`` are permuted.

    Raises:
        IndexError: If ``axis`` is out of range for ``tensor`` (including
            0-dim tensors).
    """
    # Get the size of the dimension to shuffle (raises IndexError if invalid).
    dim_size = tensor.shape[axis]

    # A private generator isolates this call from global RNG state. Two
    # alternatives both fail:
    # - torch.manual_seed would also reseed every CUDA device.
    # - Saving and restoring the global CPU state made every unseeded call
    #   return the same permutation.
    generator = torch.Generator()
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)

    # For a given seed, this matches torch.manual_seed(seed) followed by
    # torch.randperm(dim_size).
    indices = torch.randperm(dim_size, generator=generator)

    # index_select always returns a copy, so no explicit clone is needed.
    # The index tensor must live on the same device as the input.
    return tensor.index_select(axis, indices.to(tensor.device))