How to index an intermediate dimension with an index tensor in PyTorch?

How can I index a tensor t with n dimensions using an index tensor with m < n dimensions, such that the trailing dimensions of t are preserved? The index tensor has the same shape as t for all dimensions before dimension m, and its last dimension holds the indices to select along dimension m of t. In other words, I want to index an intermediate dimension of a tensor while keeping all of the following dimensions of the selected entries intact.

For example, let's say we have the two tensors:

t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()

with t:

tensor([[[ 15.2165,  -7.9702],
         [  0.6646,   5.2844],
         [-22.0657,  -5.9876],
         [ -9.7319,  11.7384],
         [  4.3985,  -6.7058]],

        [[-15.6854, -11.9362],
         [ 11.3054,   3.3068],
         [ -4.7756,  -7.4524],
         [  5.0977, -17.3831],
         [  3.9152, -11.5047]],

        [[ -5.4265, -22.6456],
         [  1.6639,  10.1483],
         [ 13.2129,   3.7850],
         [  3.8543,  -4.3496],
         [ -8.7577, -12.9722]]])

Then the output I would like to have would have shape (3, 2, 2) and be:

tensor([[[  0.6646,   5.2844],
         [ -9.7319,  11.7384]],
        [[-15.6854, -11.9362],
         [  3.9152, -11.5047]],
        [[  3.8543,  -4.3496],
         [ 13.2129,   3.7850]]])

Another example: I have a tensor t of shape (40, 10, 6, 2) and an index tensor of shape (40, 10, 3). This should index the third dimension of t (the one of size 6), and the expected output shape would be (40, 10, 3, 2).

How can I achieve this in a generic way, without using loops?

1 answer

  • answered 2020-10-29 12:59 Berriel

    In this case, you can use advanced indexing, pairing a broadcastable index for the first dimension with your index tensor:

    t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
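
One way to see why this works, sketched under the assumption that t and index have the shapes from the question: the `arange` column of shape (3, 1) and `index` of shape (3, 2) broadcast to (3, 2), so entry [i, j] of the result is t[i, index[i, j]], while the last dimension is not indexed and is kept whole. The names `rows` and `expected` are only illustrative, and the loop is there for comparison, not as a recommended approach.

```python
import torch

t = torch.randn(3, 5, 2)
index = torch.tensor([[1, 3], [0, 4], [3, 2]])

# Column vector of positions along the first dimension, shape (3, 1).
rows = torch.arange(t.shape[0]).unsqueeze(1)

# rows (3, 1) and index (3, 2) broadcast to (3, 2); the trailing
# dimension of t is untouched, so the result has shape (3, 2, 2).
output = t[rows, index, ...]

# Reference result built with explicit loops, for comparison only.
expected = torch.stack([
    torch.stack([t[i, int(index[i, j])] for j in range(index.shape[1])])
    for i in range(index.shape[0])
])
```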
    

    Full code:

    import torch
    
    t = torch.tensor([[[ 15.2165,  -7.9702],
                       [  0.6646,   5.2844],
                       [-22.0657,  -5.9876],
                       [ -9.7319,  11.7384],
                       [  4.3985,  -6.7058]],
                      [[-15.6854, -11.9362],
                       [ 11.3054,   3.3068],
                       [ -4.7756,  -7.4524],
                       [  5.0977, -17.3831],
                       [  3.9152, -11.5047]],
                      [[ -5.4265, -22.6456],
                       [  1.6639,  10.1483],
                       [ 13.2129,   3.7850],
                       [  3.8543,  -4.3496],
                       [ -8.7577, -12.9722]]])
    
    index = torch.tensor([[1, 3],[0,4],[3,2]]).long()
    
    output = t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
    
    # tensor([[[  0.6646,   5.2844],
    #          [ -9.7319,  11.7384]],
    # 
    #         [[-15.6854, -11.9362],
    #          [  3.9152, -11.5047]],
    # 
    #         [[  3.8543,  -4.3496],
    #          [ 13.2129,   3.7850]]])
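
For the more general case in the question, such as t of shape (40, 10, 6, 2) with an index of shape (40, 10, 3), the same idea might be extended by building one broadcastable `arange` per leading dimension. This is a minimal sketch, not a definitive implementation: `index_intermediate` is a hypothetical helper name, and it assumes the leading dimensions of `index` match those of `t` and that both tensors live on the same device.

```python
import torch

def index_intermediate(t, index):
    """Select along dim (index.dim() - 1) of t, keeping all later dims.

    Hypothetical helper. Assumes index.shape[:-1] == t.shape[:index.dim() - 1].
    """
    m = index.dim()
    leading = []
    for d in range(m - 1):
        # arange over dimension d, reshaped to (1, ..., size_d, ..., 1)
        # with m axes so that it broadcasts against index.
        shape = [1] * m
        shape[d] = t.shape[d]
        leading.append(torch.arange(t.shape[d], device=t.device).view(shape))
    # Dimensions after the indexed one are not indexed, so they are preserved.
    return t[tuple(leading) + (index,)]

t = torch.randn(40, 10, 6, 2)
index = torch.randint(0, 6, (40, 10, 3))
out = index_intermediate(t, index)  # shape (40, 10, 3, 2)
```

An alternative with the same result for this shape would be `torch.gather` along dimension 2 with the index expanded over the trailing dimension; the advanced-indexing form is used here because it does not need the index to be expanded to the size of the trailing dimensions.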