# How to index intermediate dimension with an index tensor in pytorch?

How can I index a tensor `t` with n dimensions using an `index` tensor with m < n dimensions, such that the last dimensions of `t` are preserved? The `index` tensor has the same shape as `t` in all dimensions before dimension m (counting from 1), and its last dimension holds the positions to select along dimension m of `t`. In other words, I want to index an intermediate dimension of a tensor while preserving all the dimensions that follow it.

For example, let's say we have these two tensors:

``````python
t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3], [0, 4], [3, 2]]).long()
``````

where `t` is:

``````
tensor([[[ 15.2165,  -7.9702],
         [  0.6646,   5.2844],
         [-22.0657,  -5.9876],
         [ -9.7319,  11.7384],
         [  4.3985,  -6.7058]],

        [[-15.6854, -11.9362],
         [ 11.3054,   3.3068],
         [ -4.7756,  -7.4524],
         [  5.0977, -17.3831],
         [  3.9152, -11.5047]],

        [[ -5.4265, -22.6456],
         [  1.6639,  10.1483],
         [ 13.2129,   3.7850],
         [  3.8543,  -4.3496],
         [ -8.7577, -12.9722]]])
``````

The output I want has shape `(3, 2, 2)`, with `output[i] = t[i, index[i]]`:

``````
tensor([[[  0.6646,   5.2844],
         [ -9.7319,  11.7384]],

        [[-15.6854, -11.9362],
         [  3.9152, -11.5047]],

        [[  3.8543,  -4.3496],
         [ 13.2129,   3.7850]]])
``````

As another example, for a tensor `t` of shape `(40, 10, 6, 2)` and an index tensor of shape `(40, 10, 3)`, the index should select along dimension 3 of `t` (the one of size 6), and the expected output shape is `(40, 10, 3, 2)`.

How can I achieve this in a generic way, without using loops?

For the 3-D case in your example, you can use advanced indexing with a broadcast `arange` over the first dimension:

``````python
t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
``````
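Here `torch.arange(t.shape[0]).unsqueeze(1)` has shape `(3, 1)` and broadcasts against `index` of shape `(3, 2)`, so each row `i` of `index` is paired with batch `i`, and the trailing dimension of `t` is carried along. The same idea should extend to more leading dimensions by adding one broadcastable `arange` per dimension before the indexed one. A minimal sketch for the question's `(40, 10, 6, 2)` example, with random data standing in for real values (the names `i0` and `i1` are just illustrative):

``````python
import torch

# Shapes from the question's second example; values are random placeholders.
t = torch.randn(40, 10, 6, 2)
index = torch.randint(0, 6, (40, 10, 3))  # positions along dimension 3 (size 6)

# One arange per leading dimension, shaped to broadcast against `index`.
i0 = torch.arange(t.shape[0]).view(-1, 1, 1)  # shape (40, 1, 1)
i1 = torch.arange(t.shape[1]).view(1, -1, 1)  # shape (1, 10, 1)

# i0, i1 and index broadcast to (40, 10, 3); the last dimension (size 2) is kept.
out = t[i0, i1, index]
print(out.shape)  # torch.Size([40, 10, 3, 2])
``````

The full code for the original 3-D example follows.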

Full code:

``````python
import torch

t = torch.tensor([[[ 15.2165,  -7.9702],
                   [  0.6646,   5.2844],
                   [-22.0657,  -5.9876],
                   [ -9.7319,  11.7384],
                   [  4.3985,  -6.7058]],
                  [[-15.6854, -11.9362],
                   [ 11.3054,   3.3068],
                   [ -4.7756,  -7.4524],
                   [  5.0977, -17.3831],
                   [  3.9152, -11.5047]],
                  [[ -5.4265, -22.6456],
                   [  1.6639,  10.1483],
                   [ 13.2129,   3.7850],
                   [  3.8543,  -4.3496],
                   [ -8.7577, -12.9722]]])

index = torch.tensor([[1, 3], [0, 4], [3, 2]]).long()

output = t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]

# tensor([[[  0.6646,   5.2844],
#          [ -9.7319,  11.7384]],
#
#         [[-15.6854, -11.9362],
#          [  3.9152, -11.5047]],
#
#         [[  3.8543,  -4.3496],
#          [ 13.2129,   3.7850]]])
``````
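For a fully generic version that does not need one `arange` per leading dimension, one option is `torch.gather`: append singleton dimensions to `index` for every trailing dimension of `t`, `expand` it to those sizes, and gather along the indexed dimension. This is a sketch, not a built-in; `index_intermediate` is a hypothetical helper name, and it assumes `index.shape[:-1] == t.shape[:index.dim() - 1]` as described in the question:

``````python
import torch


def index_intermediate(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not a PyTorch API).

    Selects along dimension `index.dim() - 1` of `t` and keeps every
    dimension after it. Assumes index.shape[:-1] == t.shape[:index.dim() - 1].
    """
    dim = index.dim() - 1                 # 0-based dimension of t being indexed
    trailing = tuple(t.shape[dim + 1:])   # dimensions to preserve
    # Append singleton dims, then expand (a view, no copy) to the trailing sizes.
    idx = index.reshape(tuple(index.shape) + (1,) * len(trailing))
    idx = idx.expand(tuple(index.shape) + trailing)
    # gather: out[..., j, r...] = t[..., idx[..., j, r...], r...]
    return torch.gather(t, dim, idx)


# 3-D example from the question (random values, same shapes and index).
t = torch.randn(3, 5, 2) * 10
index = torch.tensor([[1, 3], [0, 4], [3, 2]])
out = index_intermediate(t, index)
print(torch.equal(out, t[torch.arange(t.shape[0]).unsqueeze(1), index]))  # True

# 4-D example: (40, 10, 6, 2) indexed by (40, 10, 3).
t4 = torch.randn(40, 10, 6, 2)
idx4 = torch.randint(0, 6, (40, 10, 3))
print(index_intermediate(t4, idx4).shape)  # torch.Size([40, 10, 3, 2])
``````

Because `expand` returns a view, the enlarged index is not copied before the `gather` call.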