Infer Dimensions While Reshaping A PyTorch Tensor

Infer dimensions while reshaping a PyTorch tensor by using the PyTorch view operation


Video Transcript


This video will show you how to infer dimensions while reshaping a PyTorch tensor by using the PyTorch view operation.


First, we import PyTorch.

import torch


Then we print the PyTorch version we are using.

print(torch.__version__)

We are using PyTorch 0.3.1.post2.


Let's now create a PyTorch tensor for our example.

pt_initial_tensor_ex = torch.Tensor(
[
    [
        [ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18]
    ]
    ,
    [
        [19, 20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29, 30],
        [31, 32, 33, 34, 35, 36]
    ]
])

We use torch.Tensor and pass in a nested list with shape 2x3x6, that is, two matrices with three rows and six columns each, and then assign the resulting PyTorch tensor to the Python variable pt_initial_tensor_ex.
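As an aside, the same 2x3x6 tensor can be built more compactly by generating the numbers 1 to 36 and reshaping them with view. This is a sketch that assumes a recent PyTorch release (torch.arange with a dtype argument is not available in the 0.3.1 version used in this video), and the variable name pt_arange_tensor_ex is our own.

```python
import torch

# Hand-written nested-list version, as in the video
pt_initial_tensor_ex = torch.Tensor([
    [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]],
    [[19, 20, 21, 22, 23, 24], [25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]],
])

# Same values: 1..36 as float32, laid out row by row into a 2x3x6 shape
pt_arange_tensor_ex = torch.arange(1, 37, dtype=torch.float32).view(2, 3, 6)

print(torch.equal(pt_initial_tensor_ex, pt_arange_tensor_ex))  # True
```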


Let's print the pt_initial_tensor_ex Python variable to see what we have.

print(pt_initial_tensor_ex)

We see that it's a PyTorch FloatTensor of size 2x3x6 containing all the numbers from 1 to 36, inclusive. We're now going to reshape this tensor in a few different ways and let PyTorch infer one of the dimensions each time.


For the first PyTorch tensor reshape with inferred dimension example, let's keep the rank of the tensor at 3 but change its shape from 2x3x6 to 2x9 by an unknown last dimension.

pt_reshaped_2_by_9_by_x_tensor_ex = pt_initial_tensor_ex.view(2, 9, -1)

To tell PyTorch that we want it to work out the last dimension for us, we pass in -1 in that position.
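The rule behind -1 is simple arithmetic: divide the total number of elements by the product of the dimensions you did specify. The pure-Python function below, infer_view_shape, is a hypothetical helper written only to illustrate that rule. It is not PyTorch's own implementation, although it mirrors two of its restrictions: the division has to be exact, and only one dimension may be -1.

```python
from math import prod


def infer_view_shape(numel, shape):
    """Hypothetical helper: resolve a single -1 in `shape` for `numel` elements."""
    if list(shape).count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    # Product of the dimensions we actually know
    known = prod(d for d in shape if d != -1)
    if -1 not in shape:
        if known != numel:
            raise ValueError(f"shape {shape} is invalid for {numel} elements")
        return tuple(shape)
    # The unknown dimension must divide the element count exactly
    if known == 0 or numel % known != 0:
        raise ValueError(f"shape {shape} is invalid for {numel} elements")
    return tuple(numel // known if d == -1 else d for d in shape)


print(infer_view_shape(36, (2, 9, -1)))     # (2, 9, 2)
print(infer_view_shape(36, (2, -1)))        # (2, 18)
print(infer_view_shape(36, (2, 2, 3, -1)))  # (2, 2, 3, 3)
```

The three calls at the end correspond to the three reshapes in this video.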

Then when the tensor is reshaped, we are going to assign it to the Python variable pt_reshaped_2_by_9_by_x_tensor_ex.


Let's print the pt_reshaped_2_by_9_by_x_tensor_ex Python variable to see what we have.

print(pt_reshaped_2_by_9_by_x_tensor_ex)

We see that it's a PyTorch FloatTensor of size 2x9x2.

So we have two matrices, each with nine rows and two columns, and all our numbers are still there.

You can see that this last dimension was inferred from the two dimensions we did pass in, 2 and 9.

2 times 9 is 18.

A reshape has to keep the same number of elements, which is 36, and 36 divided by 18 is 2.

So PyTorch inferred that the last dimension was 2.
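To see where each number lands, it helps to know that view fills the new shape in row-major order, with the last dimension varying fastest. Here is a short check, assuming a recent PyTorch release and using torch.arange to stand in for our example tensor; the name t is our own.

```python
import torch

# Numbers 1..36 reshaped to 2x9 by an inferred last dimension
t = torch.arange(1, 37).view(2, 9, -1)

print(t.shape)           # torch.Size([2, 9, 2])
print(t[0, 0].tolist())  # [1, 2]    first row of the first matrix
print(t[1, 0].tolist())  # [19, 20]  the second matrix starts at element 19
```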


Then, to double-check that the PyTorch view operation didn't reshape the original tensor in place, let's print our original tensor to see what we have.

print(pt_initial_tensor_ex)

We see that it's 2x3x6 and it's a PyTorch FloatTensor.

So our original tensor is still the same.
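The original tensor keeps its shape, but one more detail about view is worth knowing: the tensor it returns is not a copy. It shares the same underlying data as the original, so writing into the reshaped tensor also changes the original. Here is a small sketch, assuming a recent PyTorch release; base, reshaped, and independent are our own stand-in names, and torch.arange builds the same 1 to 36 values.

```python
import torch

base = torch.arange(1., 37.).view(2, 3, 6)  # stands in for pt_initial_tensor_ex
reshaped = base.view(2, 9, -1)

print(base.shape)                              # torch.Size([2, 3, 6]), unchanged
print(reshaped.data_ptr() == base.data_ptr())  # True, same underlying data

reshaped[0, 0, 0] = 100.0    # write through the view
print(base[0, 0, 0].item())  # 100.0, the original sees the change

# .clone() makes an independent copy when one is needed
independent = base.view(2, 9, -1).clone()
independent[0, 0, 0] = -1.0
print(base[0, 0, 0].item())  # still 100.0
```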


For the second PyTorch tensor reshape with inferred dimension example, let's decrease the rank of the tensor so that we go from 2x3x6 to 2 by an unknown dimension.

pt_reshaped_2_by_x_tensor_ex = pt_initial_tensor_ex.view(2, -1)

The result is then assigned to the Python variable pt_reshaped_2_by_x_tensor_ex.


Let's print the pt_reshaped_2_by_x_tensor_ex Python variable to see what we have.

print(pt_reshaped_2_by_x_tensor_ex)

We see that we have a rank two tensor.

So we went from a rank three tensor to a rank two tensor, and we see that PyTorch inferred the last dimension to be 18, which makes sense.

2 times 3 is 6, and 6 times 6 is 36.

36 divided by 2 is 18.

So we have two rows and 18 columns and all our numbers are still there.
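Keeping the first dimension and passing -1 for the rest is a common way to flatten everything after the first dimension. Writing the first dimension as x.size(0) instead of a literal 2 keeps the code working for any leading size. This sketch assumes a recent PyTorch release, where torch.flatten(x, start_dim=1) is an equivalent alternative; x and flat are our own names.

```python
import torch

x = torch.arange(1., 37.).view(2, 3, 6)  # stands in for pt_initial_tensor_ex

flat = x.view(x.size(0), -1)  # keep dim 0, infer the rest: 36 / 2 = 18
print(flat.shape)             # torch.Size([2, 18])

# torch.flatten gives the same result
print(torch.equal(flat, torch.flatten(x, start_dim=1)))  # True
```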


For the third and last PyTorch tensor reshape with inferred dimension example, let's increase the rank of the tensor so that we go from 2x3x6 to 2x2x3 by an unknown dimension.

pt_reshaped_2_by_2_by_3_by_x_tensor_ex = pt_initial_tensor_ex.view(2, 2, 3, -1)

So we call the PyTorch view operation on our original tensor, pass in the new dimensions, and use -1 to signify the unknown dimension that we want PyTorch to infer.

When this expression is evaluated, we're going to assign the result to the Python variable pt_reshaped_2_by_2_by_3_by_x_tensor_ex.


Let's print the result to see what we have.

print(pt_reshaped_2_by_2_by_3_by_x_tensor_ex)

We see that it's a PyTorch FloatTensor of size 2x2x3x3, so now it's a rank four tensor.

We see that the numbers between 1 and 36, inclusive, are still there, and we see that PyTorch inferred the last dimension to be 3.

So 2 times 2 is 4.

4 times 3 is 12.

We know that we have 36 elements, so 36 divided by 12 is 3.
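The inference only works when the division is exact and only one dimension is left unknown. If the known dimensions don't divide the element count, or if -1 appears more than once, view raises a RuntimeError instead of guessing. This sketch assumes a recent PyTorch release; the exact error messages vary between versions, so the code prints them rather than matching on them.

```python
import torch

x = torch.arange(1., 37.).view(2, 3, 6)  # 36 elements, stands in for our tensor

for shape in [(2, 2, 3, -1), (5, -1), (-1, -1)]:
    try:
        print(shape, "->", tuple(x.view(*shape).shape))
    except RuntimeError as err:
        # (5, -1): 36 is not divisible by 5; (-1, -1): two unknowns
        print(shape, "-> RuntimeError:", err)
```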


Perfect! We were able to infer dimensions while reshaping a PyTorch tensor by using the PyTorch view operation.
