How to define a complex PyTorch model?
Published on Aug. 22, 2023, 12:18 p.m.
Defining a complex PyTorch model typically involves defining multiple layers and composing them using a Sequential or a custom module. Here is an example of how to define a complex PyTorch model:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(MyModel, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5)
)
self.layer2 = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5)
)
self.layer3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
In this example, MyModel is a custom PyTorch module that takes as input a tensor of shape (batch_size, input_dim) and outputs a tensor of shape (batch_size, output_dim). The model contains three layers: layer1 is a fully connected layer followed by batch normalization, ReLU activation, and dropout; layer2 is a second fully connected layer followed by batch normalization, ReLU activation, and dropout; and layer3 is a final fully connected layer.
The forward method specifies how the input tensor x is processed by each layer in sequence. Note that since we used nn.Sequential to define the layers, we can simply pass the input tensor x to each layer in succession using the dot . notation.
This is just one example of a complex PyTorch model - the structure, number of layers, and activation functions used will depend on the specific task and the nature of the data being used.