Add A New Dimension To The Middle Of A Tensor In PyTorch

Add a new dimension to the middle of a PyTorch tensor by using None-style indexing

Video Transcript


This video will show you how to add a new dimension to the middle of a PyTorch tensor by using None-style indexing.


First, we import PyTorch.

import torch


Then we print the PyTorch version we are using.

print(torch.__version__)

We are using PyTorch 0.4.0.


To start, let’s create an uninitialized PyTorch tensor of size 2x4x6x8 using the torch.Tensor constructor, and assign it to the Python variable pt_empty_tensor_ex. "Empty" here means that memory is allocated but the values are not set, so the contents are arbitrary.

pt_empty_tensor_ex = torch.Tensor(2,4,6,8)


Let’s check what dimensions our pt_empty_tensor_ex Python variable has by using the PyTorch size operation.

print(pt_empty_tensor_ex.size())

We see that it is a 2x4x6x8 tensor which is how we defined it.
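As a side note, here is a minimal sketch of the same setup using torch.empty, which also allocates a tensor without initializing its values. It is offered as an alternative spelling, not as what the video uses, and the extra dim() call is only there to show the number of axes.

```python
import torch

# torch.empty allocates memory without setting values,
# so the contents of this tensor are arbitrary.
pt_empty_tensor_ex = torch.empty(2, 4, 6, 8)

print(pt_empty_tensor_ex.size())  # torch.Size([2, 4, 6, 8])
print(pt_empty_tensor_ex.dim())   # 4
```

Only the shape matters for the rest of this example, so the uninitialized values are never read.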


For this example, we want to add a new dimension to the middle of the PyTorch tensor.

So we want to go from 2x4x6x8 to 2x4x1x6x8, with the new dimension sitting between the 4 and the 6.


The way we’ll do this is with None-style indexing.

pt_extend_middle_tensor_ex = pt_empty_tensor_ex[:,:,None,:]

So we use None.

Here, we have a capital N.

This is going to tell PyTorch that we want a new axis for the tensor assigned to the pt_empty_tensor_ex Python variable.

We also use the Python colon notation.

So for the first index, we use a colon to specify that we want everything in the already existing first dimension.

Then we use a colon as the second index to specify that we want everything in the already existing second axis.

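Then we use None to specify that we want to insert a new axis, then a comma, then a colon to specify that we want everything in the already existing third axis. We don’t index the fourth axis explicitly, so PyTorch keeps the rest of the tensor as it is.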

We assign this new tensor that’s going to be returned to the Python variable pt_extend_middle_tensor_ex.
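To make the indexing rule concrete, here is a small sketch comparing a few spellings that should all produce the same shape. The variable names x, a, b, c, and d are made up for this illustration, and unsqueeze is shown as an alternative that the video itself does not use.

```python
import torch

x = torch.empty(2, 4, 6, 8)

a = x[:, :, None, :]    # the form used in this video
b = x[:, :, None]       # trailing axes are kept implicitly
c = x[..., None, :, :]  # Ellipsis stands in for the leading axes
d = x.unsqueeze(2)      # insert a size-1 axis at position 2

for t in (a, b, c, d):
    print(t.size())     # each prints torch.Size([2, 4, 1, 6, 8])
```

In each case the new axis ends up at position 2, that is, between the 4 and the 6.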


It’s useful to check the size of pt_extend_middle_tensor_ex.

print(pt_extend_middle_tensor_ex.size())

So we use the PyTorch size operation and print the result.

What we see is that the torch size is now 2x4x1x6x8, whereas before, it was 2x4x6x8.

So we were able to insert a new dimension in the middle of the PyTorch tensor.
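One thing worth knowing is that indexing with None and slices returns a view rather than a copy, so the new tensor shares memory with the original. The sketch below illustrates this with a zero-filled tensor so the values are predictable; the names base and expanded are made up for this example.

```python
import torch

base = torch.zeros(2, 4, 6, 8)
expanded = base[:, :, None, :]      # shape 2x4x1x6x8, a view of base

# Writing through the view changes the original tensor as well.
expanded[0, 0, 0, 0, 0] = 1.0
print(base[0, 0, 0, 0].item())      # 1.0

# The size-1 axis adds no elements.
print(base.numel() == expanded.numel())  # True
```

If an independent copy is needed, calling clone() on the result is one way to get it.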


Perfect. We were able to add a new dimension to the middle of a PyTorch tensor by using None-style indexing.
