PyTorch Change Tensor Type: Cast A PyTorch Tensor To Another Type

PyTorch change Tensor type - convert and change a PyTorch tensor to another type


Video Transcript


We import PyTorch.

import torch


We check what PyTorch version we are using.

print(torch.__version__)

We are using 0.2.0_4.


We start by generating a 3x3x3 PyTorch tensor using the PyTorch random function, torch.rand.

x = torch.rand(3, 3, 3)
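torch.rand draws its values uniformly from the half-open interval [0, 1). Each argument gives the size of one dimension. The following is a minimal sketch that checks both properties. It assumes a recent PyTorch release rather than the 0.2.0 build used in the video.

```python
import torch

# torch.rand samples uniformly from the half-open interval [0, 1);
# each positional argument is the size of one dimension
x = torch.rand(3, 3, 3)

print(x.shape)   # torch.Size([3, 3, 3])
print(x.dim())   # 3

# Every element lies in [0, 1)
print(x.min().item() >= 0.0, x.max().item() < 1.0)  # True True
```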


We can check the type of this variable with Python's built-in type function.

type(x)

We see that it is a FloatTensor.
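The output of type(x) depends on the PyTorch version. In PyTorch 0.4 and later, newer than the 0.2.0 used here, Python's built-in type() reports the generic torch.Tensor class instead of FloatTensor. The following minimal sketch assumes such a newer version and shows the two ways to see the element type there.

```python
import torch

x = torch.rand(3, 3, 3)

# In PyTorch 0.4+, the built-in type() returns the generic tensor class
print(type(x))    # <class 'torch.Tensor'>

# The tensor's own type() method still returns the legacy type name
print(x.type())   # torch.FloatTensor

# The dtype attribute gives the element type directly
print(x.dtype)    # torch.float32
```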


To convert this FloatTensor to a DoubleTensor, we define the variable double_x = x.double().

double_x = x.double()


We check the type again and see that, whereas before this PyTorch tensor was a FloatTensor, we now have a DoubleTensor.

type(double_x)
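x.double() is one of several equivalent ways to request a 64-bit float tensor. This minimal sketch assumes a recent PyTorch release and compares it with the .to() and .type() forms. It also shows that casting returns a new tensor and leaves the original unchanged.

```python
import torch

x = torch.rand(3, 3, 3)

# Three equivalent ways to cast to 64-bit floating point
double_a = x.double()
double_b = x.to(torch.float64)
double_c = x.type(torch.DoubleTensor)

print(double_a.dtype)    # torch.float64
print(double_a.type())   # torch.DoubleTensor

# The cast produces a new tensor; x itself is still float32
print(x.dtype)           # torch.float32
```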


We can convert it back.

We define a variable float_x and set it to double_x.float().

float_x = double_x.float()

So we're casting this DoubleTensor back to a FloatTensor.
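This particular round trip loses nothing. Every 32-bit float value is exactly representable as a 64-bit float, so widening to double and then narrowing back returns the original values. The following is a minimal sketch assuming a recent PyTorch release.

```python
import torch

x = torch.rand(3, 3, 3)

# float32 -> float64 is exact, and narrowing the result back to
# float32 recovers the original values bit for bit
float_x = x.double().float()

print(float_x.dtype)            # torch.float32
print(torch.equal(x, float_x))  # True
```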


This time, we print the FloatTensor.

print(float_x)


Next, we define a float_ten_x variable which is equal to float_x * 10.

float_ten_x = float_x * 10


We print this new variable.

print(float_ten_x)

If we scroll back up, we can see the first number was 0.6096 and now the first number is 6.0964.

So every element has been multiplied by 10, and the result is still a FloatTensor: multiplying by the integer 10 did not change the tensor's type.
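The same rule applies more generally: a Python scalar does not change a float tensor's dtype. The int/float distinction of the scalar does matter for integer tensors, though. The following minimal sketch assumes PyTorch 1.5 or later, where the current type-promotion rules apply, and the default float dtype of float32. Older versions such as 0.2.0 may behave differently.

```python
import torch

float_x = torch.rand(3, 3, 3)

# A Python int scalar leaves a float32 tensor as float32
print((float_x * 10).dtype)     # torch.float32

# So does a Python float scalar
print((float_x * 10.0).dtype)   # torch.float32

# An integer tensor times a Python int stays an integer tensor,
# but times a Python float it is promoted to the default float dtype
int_x = torch.tensor([1, 2, 3], dtype=torch.int32)
print((int_x * 10).dtype)       # torch.int32
print((int_x * 2.5).dtype)      # torch.float32
```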


Next, we define a variable int_ten_x and cast our FloatTensor float_ten_x to integers.

int_ten_x = float_ten_x.int()


We print this new variable and see that it does indeed contain integers.

print(int_ten_x)

The first row is now 6, 2, 8, where before it was 6.09, 2.04, 8.3. The conversion keeps only the whole-number part of each value, truncating toward zero rather than rounding. So now we have a PyTorch IntTensor.
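The values in this example cannot tell truncation apart from rounding, because 6.09, 2.04, and 8.3 all round down anyway. The following minimal sketch assumes a recent PyTorch release and uses hand-picked values such as 8.9 and -1.7 to show the difference. It then shows how to round explicitly before casting.

```python
import torch

values = torch.tensor([6.0964, 2.04, 8.3, 8.9, -1.7])

# .int() casts to int32 by dropping the fractional part,
# i.e. truncation toward zero, not rounding to nearest
truncated = values.int()
print(truncated.dtype)     # torch.int32
print(truncated.tolist())  # [6, 2, 8, 8, -1]

# To round to the nearest integer instead, round first, then cast
rounded = values.round().int()
print(rounded.tolist())    # [6, 2, 8, 9, -2]
```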


The last thing we do is cast this IntTensor back to a float.

We define a variable float_after_int_ten_x = int_ten_x.float().

float_after_int_ten_x = int_ten_x.float()


And we print this new variable.

print(float_after_int_ten_x)

And we see that it is now a PyTorch FloatTensor.
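The dtype is float again, but the values are not the originals. The following minimal sketch assumes a recent PyTorch release. It checks that the float -> int -> float round trip is equivalent to applying torch.trunc to the original tensor.

```python
import torch

float_ten_x = torch.rand(3, 3, 3) * 10

int_ten_x = float_ten_x.int()
float_after_int_ten_x = int_ten_x.float()

print(float_after_int_ten_x.dtype)  # torch.float32

# Only the whole-number part survived: the result matches
# truncating the original values toward zero
print(torch.equal(float_after_int_ten_x, torch.trunc(float_ten_x)))  # True
```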


One thing to notice, however, is that the first row is still 6, 2, 8. When we cast the IntTensor back to a FloatTensor, the digits past the decimal point did not come back: the integer conversion had already discarded them, and an IntTensor has nowhere to store them.

So when you're casting or converting between PyTorch tensor types, it's always important to remember how much precision the conversion can lose.
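Precision loss is not limited to float-to-int casts. The following minimal sketch assumes a recent PyTorch release and measures two kinds of loss. The first is the fractional part dropped by a float -> int -> float round trip. The second is the mantissa bits dropped when a float64 value such as 0.1 is narrowed to float32.

```python
import torch

# Float -> int -> float: each element loses its fractional part,
# so the error is always less than 1 for values in [0, 10)
float_ten_x = torch.rand(3, 3, 3) * 10
round_trip = float_ten_x.int().float()
max_error = (float_ten_x - round_trip).abs().max().item()
print(max_error < 1.0)  # True

# Double -> float -> double: float32 keeps fewer significant bits,
# so 0.1 does not survive the round trip exactly
d = torch.tensor([0.1], dtype=torch.float64)
d_round_trip = d.float().double()
print(torch.equal(d, d_round_trip))         # False
print((d - d_round_trip).abs().item() > 0)  # True
```

Unlike the double -> float -> double case, the float -> double -> float round trip earlier in the transcript is exact, because widening never discards bits.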
