How to Subclass The nn.Module Class in PyTorch

Construct a custom PyTorch model by subclassing the PyTorch nn.Module class


Video Transcript


The recommended way to construct a custom model in PyTorch is to define your own subclass of PyTorch's nn.Module class.

In order to do this, a bit of knowledge of Python classes is necessary.
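The two class features this walkthrough relies on are inheritance and the super() call. As a quick plain-Python refresher, here is a minimal sketch. The Animal and Dog classes are purely illustrative and unrelated to PyTorch.

```python
# Plain-Python refresher: a subclass inherits from a parent class and
# calls the parent's __init__ via super() before adding its own state.
class Animal:
    def __init__(self, name):
        self.name = name

    def describe(self):
        return f"{self.name} is an animal"


class Dog(Animal):
    def __init__(self, name, breed):
        super().__init__(name)  # run Animal.__init__ first
        self.breed = breed

    def describe(self):
        # Override the parent method while reusing its result
        return super().describe() + f" ({self.breed})"


print(Dog("Rex", "beagle").describe())  # Rex is an animal (beagle)
```

A PyTorch model follows exactly this pattern, with nn.Module playing the role of the parent class.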


For this demonstration, we will need to import torch.nn as nn

import torch.nn as nn


and import OrderedDict from the collections module.

from collections import OrderedDict


First, we name our class Convolutional and make it a subclass of the nn.Module class.

class Convolutional(nn.Module):


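Next, we define the __init__ method, which accepts the number of input channels and the number of output classes (3 and 10 by default).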

class Convolutional(nn.Module):
    def __init__(self, input_channels=3, num_classes=10):


Here, we initialize our superclass, nn.Module,

class Convolutional(nn.Module):
    def __init__(self, input_channels=3, num_classes=10):
        super(Convolutional, self).__init__()


define our sequential containers,

class Convolutional(nn.Module):
    def __init__(self, input_channels=3, num_classes=10):
        super(Convolutional, self).__init__()
        self.layer1 = nn.Sequential()


and fill those containers with our convolutional and rectified linear unit layers as usual.

class Convolutional(nn.Module):
    def __init__(self, input_channels=3, num_classes=10):
        super(Convolutional, self).__init__()
        self.layer1 = nn.Sequential()
        self.layer1.add_module("Conv1", nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=3, padding=1))
        self.layer1.add_module("Relu1", nn.ReLU(inplace=False))

The first sequential container defines the first layer of our convolutional neural network.

In this case, a 2D convolutional layer and a ReLU activation are grouped together and treated as a single layer.
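The first layer can be checked on its own. The sketch below uses a hypothetical cut-down class, FirstLayerOnly, that contains only layer1. It assumes a single 3-channel 32x32 input image. With kernel_size=3 and padding=1, the convolution keeps the height and width unchanged, so only the channel count changes. The sketch also shows why the superclass must be initialized first: nn.Module.__init__ creates the internal registries that submodules are stored in, so assigning a layer before that call fails. (In Python 3, super().__init__() is equivalent to the super(Convolutional, self).__init__() form used above.) Broken is likewise a hypothetical counter-example.

```python
import torch
import torch.nn as nn


class FirstLayerOnly(nn.Module):
    """Hypothetical cut-down model containing only the article's layer1."""

    def __init__(self, input_channels=3):
        super().__init__()  # must run before any submodule is assigned
        self.layer1 = nn.Sequential()
        self.layer1.add_module("Conv1", nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=3, padding=1))
        self.layer1.add_module("Relu1", nn.ReLU(inplace=False))

    def forward(self, x):
        return self.layer1(x)


class Broken(nn.Module):
    """Hypothetical counter-example: assigns a layer before super().__init__()."""

    def __init__(self):
        self.layer1 = nn.Sequential()  # raises: nn.Module's registries don't exist yet
        super().__init__()


model = FirstLayerOnly()
x = torch.randn(1, 3, 32, 32)  # assumed input shape (N, C, H, W)
out = model(x)
print(out.shape)  # torch.Size([1, 16, 32, 32]): only the channel count changed
print([name for name, _ in model.layer1.named_children()])  # ['Conv1', 'Relu1']

try:
    Broken()
except AttributeError as err:
    print("AttributeError:", err)
```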


We can also define our second layer in an equivalent but slightly cleaner way. Instead of calling add_module repeatedly, we pass an OrderedDict of named layers to the nn.Sequential constructor.

class Convolutional(nn.Module):
    def __init__(self, input_channels=3, num_classes=10):
        super(Convolutional, self).__init__()
        self.layer1 = nn.Sequential()
        self.layer1.add_module("Conv1", nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=3, padding=1))
        self.layer1.add_module("Relu1", nn.ReLU(inplace=False))
        self.layer2 = nn.Sequential(OrderedDict([
            ('Conv2', nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)),
            ('Relu2', nn.ReLU(inplace=False))
        ]))

We pass an OrderedDict instead of simply passing the layers as positional arguments because it lets us name each layer. With positional arguments, the layers are named "0", "1", and so on. Meaningful names help a lot if the network ever needs debugging.
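To see the naming benefit concretely, the sketch below builds the same two layers both ways and prints the child names. The layer sizes mirror the article's layer2. The variable names unnamed and named are illustrative.

```python
from collections import OrderedDict

import torch.nn as nn

# Positional arguments: layers are named by their index
unnamed = nn.Sequential(
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(inplace=False),
)

# OrderedDict: layers keep the names we chose
named = nn.Sequential(OrderedDict([
    ("Conv2", nn.Conv2d(16, 32, kernel_size=3, padding=1)),
    ("Relu2", nn.ReLU(inplace=False)),
]))

print([name for name, _ in unnamed.named_children()])  # ['0', '1']
print([name for name, _ in named.named_children()])    # ['Conv2', 'Relu2']

# Named layers can also be reached as attributes when inspecting the model
print(named.Conv2.out_channels)  # 32
```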

As a reminder, the in_channels argument of the second layer (16) must match the out_channels of the first layer (16). Otherwise, PyTorch raises an error when data is passed through the network.
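The sketch below deliberately breaks that rule. The value of 8 input channels is an arbitrary wrong value chosen for illustration. Building the container succeeds because PyTorch does not compare neighboring layers at construction time. The mismatch only surfaces as a RuntimeError on the first forward pass.

```python
import torch
import torch.nn as nn

# Deliberately wrong: the second conv expects 8 input channels,
# but the first conv produces 16
mismatched = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, padding=1),
)

# Construction succeeded; the error appears only once data flows through
try:
    mismatched(torch.randn(1, 3, 32, 32))
except RuntimeError as err:
    print("RuntimeError:", err)
```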


Next, we need to define the forward method of our subclass.

import torch.nn as nn
from collections import OrderedDict

class Convolutional(nn.Module):
    def __init__(self, input_channels=3, num_classes=10):
        super(Convolutional, self).__init__()
        self.layer1 = nn.Sequential()
        self.layer1.add_module("Conv1", nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=3, padding=1))
        self.layer1.add_module("Relu1", nn.ReLU(inplace=False))
        self.layer2 = nn.Sequential(OrderedDict([
            ('Conv2', nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)),
            ('Relu2', nn.ReLU(inplace=False))
        ]))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

The forward method takes a single argument, x, which starts out as the input data.

We then define what happens to x as it passes through our network.

In our case, we simply pass x through layer1 and then layer2 and return the result.
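The same pattern applies to any stack of layers. In the minimal sketch below, the hypothetical TwoLayer model and its nn.Linear sizes are illustrative only. Its forward() chains two layers in the same way that Convolutional chains layer1 and layer2. The model is called directly, as model(batch), rather than through model.forward(batch), because nn.Module's __call__ runs forward() along with any registered hooks.

```python
import torch
import torch.nn as nn


class TwoLayer(nn.Module):
    """Hypothetical toy model: forward() chains two sub-layers."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 8)
        self.layer2 = nn.Linear(8, 2)

    def forward(self, x):
        x = self.layer1(x)  # x now has 8 features
        x = self.layer2(x)  # x now has 2 features
        return x


model = TwoLayer()
batch = torch.randn(5, 4)  # 5 samples with 4 features each

# Calling the module runs forward() (plus any hooks) for us
out = model(batch)
print(out.shape)  # torch.Size([5, 2])
```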


This network is still not fully functional. It needs a reshaping (flattening) step and an output layer, and it would benefit from some max pooling in the convolutional layers.


However, defining our network in this way makes these steps much easier to add.
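The sketch below shows one way to slot in those missing pieces. It adds a MaxPool2d to each layer, flattens the feature maps, and ends with an nn.Linear output layer. It assumes 32x32 input images, a size the original does not specify. That assumption matters because the input size of the final linear layer depends on it: two 2x2 poolings shrink 32x32 to 8x8, which gives 32 * 8 * 8 = 2048 features. The names Pool1, Pool2, and fc are illustrative choices.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class Convolutional(nn.Module):
    # Sketch only: assumes 32x32 input images, which fixes the
    # size of the flattened feature vector used by self.fc below.
    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(OrderedDict([
            ("Conv1", nn.Conv2d(input_channels, 16, kernel_size=3, padding=1)),
            ("Relu1", nn.ReLU(inplace=False)),
            ("Pool1", nn.MaxPool2d(kernel_size=2)),  # 32x32 -> 16x16
        ]))
        self.layer2 = nn.Sequential(OrderedDict([
            ("Conv2", nn.Conv2d(16, 32, kernel_size=3, padding=1)),
            ("Relu2", nn.ReLU(inplace=False)),
            ("Pool2", nn.MaxPool2d(kernel_size=2)),  # 16x16 -> 8x8
        ]))
        # Output layer: 32 channels * 8 * 8 positions -> one score per class
        self.fc = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = torch.flatten(x, start_dim=1)  # reshape (N, 32, 8, 8) -> (N, 2048)
        x = self.fc(x)
        return x


model = Convolutional()
logits = model(torch.randn(4, 3, 32, 32))  # batch of 4 dummy images
print(logits.shape)  # torch.Size([4, 10])
```

For images of a different size, only the 32 * 8 * 8 term in the linear layer needs to change.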
