How to index intermediate dimension with an index tensor in pytorch?
How can I index a tensor t with n dimensions using an index tensor with m < n dimensions, such that the trailing dimensions of t are preserved? The index tensor has the same shape as t in its first m - 1 dimensions, and its last dimension holds the indices to select along dimension m of t. In other words, I want to index an intermediate dimension of a tensor while keeping all the dimensions that follow it.
For example, let's say we have the two tensors:
t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()
with t:
tensor([[[ 15.2165, 7.9702],
[ 0.6646, 5.2844],
[22.0657, 5.9876],
[ 9.7319, 11.7384],
[ 4.3985, 6.7058]],
[[15.6854, 11.9362],
[ 11.3054, 3.3068],
[ 4.7756, 7.4524],
[ 5.0977, 17.3831],
[ 3.9152, 11.5047]],
[[ 5.4265, 22.6456],
[ 1.6639, 10.1483],
[ 13.2129, 3.7850],
[ 3.8543, 4.3496],
[ 8.7577, 12.9722]]])
Then the output I would like to have has shape (3, 2, 2) and is:
tensor([[[ 0.6646, 5.2844],
[ 9.7319, 11.7384]],
[[15.6854, 11.9362],
[ 3.9152, 11.5047]],
[[ 3.8543, 4.3496],
[ 13.2129, 3.7850]]])
As another example, suppose I have a tensor t of shape (40, 10, 6, 2) and an index tensor of shape (40, 10, 3). This should query the third dimension of t (the one of size 6), and the expected output shape would be (40, 10, 3, 2).
How can I achieve this in a generic way, without using loops?
1 answer

In this case, where index is 2-D, you can pair it with a range over the first dimension:
t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
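This works because the two index tensors broadcast against each other: the row indices have shape (3, 1) and index has shape (3, 2), so together they broadcast to (3, 2) and pair each row i with the entries of index[i]. A small sketch that checks this against an explicit loop (the random t and the names rows and looped are only for illustration):

```python
import torch

t = torch.randn(3, 5, 2)
index = torch.tensor([[1, 3], [0, 4], [3, 2]])

# Row indices as a column vector, shape (3, 1): [[0], [1], [2]]
rows = torch.arange(t.shape[0]).unsqueeze(1)

# rows (3, 1) and index (3, 2) broadcast to (3, 2); the trailing
# dimension of t is kept by the ellipsis.
output = t[rows, index, ...]

# Reference result built with an explicit loop over the first dimension
looped = torch.stack([t[i, index[i]] for i in range(t.shape[0])])

print(output.shape)                 # torch.Size([3, 2, 2])
print(torch.equal(output, looped))  # True
```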
Full code:
import torch

t = torch.tensor([[[15.2165, 7.9702],
                   [0.6646, 5.2844],
                   [22.0657, 5.9876],
                   [9.7319, 11.7384],
                   [4.3985, 6.7058]],
                  [[15.6854, 11.9362],
                   [11.3054, 3.3068],
                   [4.7756, 7.4524],
                   [5.0977, 17.3831],
                   [3.9152, 11.5047]],
                  [[5.4265, 22.6456],
                   [1.6639, 10.1483],
                   [13.2129, 3.7850],
                   [3.8543, 4.3496],
                   [8.7577, 12.9722]]])
index = torch.tensor([[1, 3], [0, 4], [3, 2]]).long()

output = t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
# tensor([[[ 0.6646,  5.2844],
#          [ 9.7319, 11.7384]],
#
#         [[15.6854, 11.9362],
#          [ 3.9152, 11.5047]],
#
#         [[ 3.8543,  4.3496],
#          [13.2129,  3.7850]]])
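The indexing expression above is written for a 2-D index tensor. For the second example in the question, with an index of shape (40, 10, 3), one possible generalization is to expand the index over the trailing dimensions and use torch.gather. This is a sketch rather than the only way to do it; the helper name index_intermediate is hypothetical, and it assumes the index is an int64 tensor whose leading dimensions match those of t:

```python
import torch


def index_intermediate(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Select entries of t along dimension index.dim() - 1, keeping trailing dims.

    Hypothetical helper: assumes index is an int64 tensor whose leading
    dimensions match those of t.
    """
    dim = index.dim() - 1                # the dimension of t being indexed
    trailing = tuple(t.shape[dim + 1:])  # dimensions of t to preserve

    # Append one singleton dimension per trailing dimension, then expand so
    # the index has as many dimensions as t, as torch.gather requires.
    idx = index.reshape(tuple(index.shape) + (1,) * len(trailing))
    idx = idx.expand(tuple(index.shape) + trailing)

    return torch.gather(t, dim, idx)


t = torch.randn(40, 10, 6, 2)
index = torch.randint(0, 6, (40, 10, 3))
output = index_intermediate(t, index)
print(output.shape)  # torch.Size([40, 10, 3, 2])
```

Because expand returns a view, the enlarged index does not copy the original index data. For the (3, 5, 2) example, this helper should give the same result as the indexing expression shown earlier.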