Flatten A PyTorch Tensor

Flatten A PyTorch Tensor by using the PyTorch view operation

Video Transcript


This video will show you how to flatten a PyTorch tensor by using the PyTorch view operation.


First, we start by importing PyTorch.

import torch


Then we print the PyTorch version we are using.

print(torch.__version__)

We are using PyTorch 0.3.1.post2.


Let's now create an initial PyTorch tensor for our example.

pt_initial_tensor_ex = torch.Tensor(
[
    [
        [ 1,  2,  3,  4],
        [ 5,  6,  7,  8]
    ],
    [
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]
    ],
    [
        [17, 18, 19, 20],
        [21, 22, 23, 24]
    ]
])

We use torch.Tensor, and we pass in our data structure.

We can see that it has three matrices, and each matrix has two rows and four columns.

The numbers go from 1 to 24.

We assign that result to the Python variable pt_initial_tensor_ex.


Let's print the pt_initial_tensor_ex Python variable to see what we have.

print(pt_initial_tensor_ex)

We see that it's a PyTorch FloatTensor of size 3x2x4. We see the three matrices, each with two rows and four columns, and all the values are between 1 and 24, inclusive.
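
As a small aside, the same facts can be checked in code instead of reading them off the printed output. This is only a sketch: the variable names shape, num_dims, and num_elements are illustrative and not part of the video, and the tensor is written more compactly than in the walkthrough.

```python
import torch

# Same 3x2x4 tensor as in the walkthrough, written compactly.
pt_initial_tensor_ex = torch.Tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8]],
    [[9, 10, 11, 12], [13, 14, 15, 16]],
    [[17, 18, 19, 20], [21, 22, 23, 24]],
])

# Illustrative names (assumptions, not from the video).
shape = tuple(pt_initial_tensor_ex.size())    # sizes of each dimension
num_dims = pt_initial_tensor_ex.dim()         # number of dimensions
num_elements = pt_initial_tensor_ex.numel()   # total number of elements

print(shape, num_dims, num_elements)
```

The product of the three sizes (3 * 2 * 4) is the 24 elements we expect to see once the tensor is flattened.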


When we flatten this PyTorch tensor, we'd like to end up with a one-dimensional tensor of 24 elements that goes from 1 to 24.


To flatten our tensor, we're going to use the PyTorch view operation with the special value -1.

pt_flattened_tensor_ex = pt_initial_tensor_ex.view(-1)

Calling .view(-1) on a tensor tells PyTorch to infer the size of that single dimension from the number of elements, which flattens the tensor completely.

So we call .view(-1) on pt_initial_tensor_ex to get a flattened tensor, and we assign the result to the Python variable pt_flattened_tensor_ex.
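
To make the role of -1 a little more concrete, here is a hedged sketch showing that -1 asks view to work out one dimension from the total element count. It assumes a reasonably recent PyTorch release, builds the tensor with torch.arange for brevity (so the dtype may differ from the FloatTensor in the video), and uses illustrative names (t, flat, two_d, rows_of_four) that are not from the video.

```python
import torch

# 24 values, 1..24, arranged as 3 matrices of 2 rows x 4 columns.
t = torch.arange(1, 25).view(3, 2, 4)

# -1 alone: one dimension holding all 24 elements (fully flattened).
flat = t.view(-1)

# -1 alongside other sizes: PyTorch infers the missing size.
two_d = t.view(3, -1)          # 24 / 3 = 8 columns
rows_of_four = t.view(-1, 4)   # 24 / 4 = 6 rows

print(tuple(flat.size()), tuple(two_d.size()), tuple(rows_of_four.size()))
```

Only one dimension can be given as -1 in a single call, since PyTorch needs the other sizes to compute it.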


Let's print out the pt_flattened_tensor_ex Python variable to see what we have.

print(pt_flattened_tensor_ex)

We see that it's a PyTorch FloatTensor of size 24, and we see that it has all our numbers, 1 all the way to 24.

So before, it was 3x2x4, now it's just size 24.


Just to double check that our original tensor didn't change, we're going to print our original tensor to make sure that the .view(-1) didn't do an in-place reshaping of the original tensor.

print(pt_initial_tensor_ex)

When we print it, we see that pt_initial_tensor_ex is still a 3x2x4 PyTorch FloatTensor: three matrices, each with two rows and four columns, holding our original 24 numbers.

So its shape is unchanged after the .view(-1) operation.
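
One caveat worth knowing: although the shape of the original tensor is untouched, a view shares its underlying data with the original, so writing into the flattened tensor also changes the original. The sketch below illustrates this under the assumption of a reasonably recent PyTorch release (it uses .item(), which is not available in 0.3.1). The names original, flat, shape_unchanged, same_memory, and first_value are illustrative.

```python
import torch

original = torch.arange(1, 25).view(3, 2, 4)
flat = original.view(-1)

# The original keeps its 3x2x4 shape after taking the view.
shape_unchanged = tuple(original.size()) == (3, 2, 4)

# Both tensors start at the same memory address.
same_memory = flat.data_ptr() == original.data_ptr()

# Writing through the flattened view is visible in the original.
flat[0] = 100
first_value = original[0, 0, 0].item()

print(shape_unchanged, same_memory, first_value)
```

If an independent copy is needed, calling .clone() on the flattened tensor is one way to get one.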


Perfect! We were able to flatten a PyTorch tensor by using the PyTorch view operation with the special value -1.
