Get The Shape Of A PyTorch Tensor As A List Of Integers

Get the shape of a PyTorch Tensor as a list of integers by using the PyTorch Shape operation and the Python List constructor

Video Transcript


This video will show you how to get the shape of a PyTorch tensor as a list of integers by using the PyTorch shape operation and the Python list constructor.


First, we import PyTorch.

import torch


Then we print the PyTorch version we are using.

print(torch.__version__)

We are using PyTorch 0.4.0.


Let's now manually create a PyTorch tensor.

tensor_one = torch.tensor(
    [
        [
            [1,2,3],
            [4,5,6]
        ]
        ,
        [
            [7,8,9],
            [10,11,12]
        ]
    ]
)

We use torch.tensor and pass in our nested list of numbers as the data structure.

Looking at the nesting, we can see that it's 2x2x3: two blocks, each containing two rows of three numbers.

We assign this to the Python variable tensor_one.
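
If you want to confirm the 2x2x3 reading without counting brackets by eye, here is a minimal sketch, assuming a PyTorch version that provides torch.tensor (0.4.0 or later). It uses the standard tensor methods dim(), which counts the axes, and numel(), which counts all the elements.

```python
import torch

# Same nested data as tensor_one: 2 blocks, each with 2 rows of 3 values
tensor_one = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]],
    ]
)

# dim() counts the axes: three levels of nesting give 3 dimensions
num_dims = tensor_one.dim()

# numel() counts every element: 2 * 2 * 3 = 12
num_elements = tensor_one.numel()

print(num_dims, num_elements)  # 3 12
```

Three dimensions holding twelve numbers is consistent with a 2x2x3 tensor.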


Let's print the tensor_one Python variable to see what we have.

print(tensor_one)

Visually, we can see that all of our numbers are there, and that it looks like a 2x2x3 PyTorch tensor.


Let's next check the dimensions of this tensor using the PyTorch shape operation.

tensor_one.shape

So we call tensor_one.shape.

We see that it returns torch.Size([2, 2, 3]), so the tensor is 2x2x3.
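
PyTorch tensors also have a size() method that gives the same information. As a small sketch: called with no argument, size() returns the same torch.Size as .shape, and called with a dimension index, it returns the length of that one dimension as a plain Python int.

```python
import torch

tensor_one = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

# size() with no argument returns the same torch.Size as .shape
same_shape = tensor_one.size() == tensor_one.shape

# size(dim) returns the length of a single dimension as a Python int
last_dim = tensor_one.size(2)

print(same_shape, last_dim)  # True 3
```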

Perfect.


Even though the printed output already says torch.Size, let's confirm the type of the object that the PyTorch shape operation returns.

type(tensor_one.shape)

We pass tensor_one.shape to the Python type function, and we see that it is of class torch.Size.
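
It helps to know that torch.Size is a subclass of Python's built-in tuple. The sketch below checks this and shows that a torch.Size compares equal to an ordinary tuple with the same values, which is why passing it to the list constructor works.

```python
import torch

tensor_one = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
shape = tensor_one.shape

# torch.Size inherits from tuple, so isinstance() reports True
is_tuple = isinstance(shape, tuple)

# Tuple equality applies, so it matches a plain tuple of the same ints
matches = shape == (2, 2, 3)

print(is_tuple, matches)  # True True
```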


Because we want to get the shape of the tensor as a list of integers, let's use the Python list constructor.

tensor_shape_list = list(tensor_one.shape)

We pass tensor_one.shape, which is our torch.Size object, to the Python list constructor.

The resulting list is going to be assigned to the Python variable tensor_shape_list.
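
If you need this conversion in several places, you could wrap it in a small function. The name shape_as_list below is hypothetical, chosen just for this sketch; it does nothing more than call list() on the tensor's shape, which iterates over the torch.Size and collects one int per dimension.

```python
import torch


def shape_as_list(t: torch.Tensor) -> list:
    """Hypothetical helper: return a tensor's shape as a list of Python ints."""
    # list() iterates over the torch.Size, yielding one int per dimension
    return list(t.shape)


tensor_one = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
tensor_shape_list = shape_as_list(tensor_one)

print(tensor_shape_list)  # [2, 2, 3]
```

The same helper works for a tensor of any rank, for example shape_as_list(torch.zeros(4, 5)) gives [4, 5].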


Let's now print out the tensor_shape_list Python variable to see what it looks like.

print(tensor_shape_list)

We can see that it does indeed look like a list: the values 2, 2, 3 inside square brackets.

This matches the 2x2x3 shape we saw earlier.


Let's check the type of the object returned by passing our torch.Size object to the list constructor.

type(tensor_shape_list)

We check the type of tensor_shape_list and see that it is of class list.
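
One practical difference between the list and the torch.Size it came from is mutability. Because torch.Size is a tuple, it cannot be changed in place, while the list can. This sketch edits the first entry of the list, as you might when building a new target shape, and shows that the same assignment on the torch.Size raises a TypeError.

```python
import torch

tensor_one = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
shape = tensor_one.shape
tensor_shape_list = list(shape)

# A list can be edited in place
tensor_shape_list[0] = 1

# torch.Size is an immutable tuple, so item assignment fails
try:
    shape[0] = 1
    size_was_mutable = True
except TypeError:
    size_was_mutable = False

print(tensor_shape_list, size_was_mutable)  # [1, 2, 3] False
```

The original shape is left untouched, so tensor_one.shape is still torch.Size([2, 2, 3]).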


Finally, let's check the type of the list's elements by looking at the first element.

type(tensor_shape_list[0])

We index into the list with tensor_shape_list[0].

Remember that Python uses zero-based indexing, so passing 0 asks for the first element of the list.

That first element is the number 2, which we pass to the type function, and it returns class int, for integer.
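
Checking only the first element is a reasonable spot check. If you want to confirm that every element is an integer, a minimal sketch is to run isinstance() over the whole list with Python's built-in all().

```python
import torch

tensor_one = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
tensor_shape_list = list(tensor_one.shape)

# Check every element, not just the first one
all_ints = all(isinstance(dim, int) for dim in tensor_shape_list)

# Collect the type name of each element for display
element_types = [type(dim).__name__ for dim in tensor_shape_list]

print(all_ints, element_types)  # True ['int', 'int', 'int']
```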


Perfect! We were able to get the shape of a PyTorch tensor as a list of integers by using the PyTorch shape operation and the Python list constructor.