How to define a simple PyTorch model?

Updated: August 22, 2023

To define a simple PyTorch model, you can create a new class that inherits from torch.nn.Module. In this class, you define the layers of your model in the __init__ method and specify the forward computation of the model in the forward method. Here is an example of a simple network with two linear layers and a ReLU activation function:

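import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module internals before adding layers
        self.fc1 = nn.Linear(10, 20)  # 10 input features -> 20 hidden units
        self.fc2 = nn.Linear(20, 1)   # 20 hidden units -> 1 output
        self.relu = nn.ReLU()         # non-linearity between the two layers

    def forward(self, x):
        x = self.fc1(x)   # linear projection to 20 features
        x = self.relu(x)  # element-wise max(0, x)
        x = self.fc2(x)   # linear projection to a single output
        return x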

MyModel subclasses nn.Module. Its __init__ method first calls the parent constructor and then registers the layers as attributes: nn.Linear(10, 20) maps 10 input features to 20 hidden units, and nn.Linear(20, 1) maps those 20 hidden units to a single output. The forward method defines how an input flows through those layers: the first linear layer, then the ReLU activation, then the second linear layer.
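To make the layer sizes concrete, here is a small sketch that rebuilds the same network and counts its trainable parameters. An nn.Linear(in, out) layer holds an out×in weight matrix plus out bias values, so fc1 contributes 10·20 + 20 = 220 parameters and fc2 contributes 20·1 + 1 = 21, for 241 in total. The count_parameters helper and the nn.Sequential variant are illustrative additions, not part of the original example; the Sequential version simply stacks the same layers in the same order.

```python
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # 10 inputs -> 20 hidden units
        self.fc2 = nn.Linear(20, 1)   # 20 hidden units -> 1 output
        self.relu = nn.ReLU()         # has no parameters of its own

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def count_parameters(model: nn.Module) -> int:
    """Hypothetical helper: total number of trainable parameter elements."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Equivalent stack built with nn.Sequential (same layers, same order).
seq_model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

model = MyModel()
for name, param in model.named_parameters():
    # Prints fc1.weight (20, 10), fc1.bias (20,), fc2.weight (1, 20), fc2.bias (1,)
    print(name, tuple(param.shape))

print(count_parameters(model))      # 241
print(count_parameters(seq_model))  # 241
```

Note that nn.Linear stores its weight as (out_features, in_features), which is why fc1.weight has shape (20, 10) rather than (10, 20).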

Once you have defined your model, you can create an instance of it and use it to make predictions on your data.

import torch

model = MyModel()
input_data = torch.randn(1, 10)  # one sample with 10 features
output = model(input_data)       # calls forward(); output has shape (1, 1)

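This code creates an instance of MyModel, generates a random input tensor of shape (1, 10), that is, a batch containing one sample with 10 features, and passes it through the model. Calling model(input_data) invokes the forward method for you, and the resulting prediction has shape (1, 1).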
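When you only need predictions, a common pattern is to put the model in evaluation mode and disable gradient tracking. The sketch below assumes a batch of 4 random samples (the batch size is arbitrary). model.eval() has no visible effect on this particular network because it contains no dropout or batch-norm layers, but it is a good habit; torch.no_grad() skips recording the autograd graph, which saves memory during inference.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


torch.manual_seed(0)  # make the random weights and input reproducible

model = MyModel()
model.eval()  # evaluation mode; matters for dropout/batch norm, harmless here

batch = torch.randn(4, 10)  # 4 samples, 10 features each

with torch.no_grad():  # no autograd graph is recorded for these predictions
    predictions = model(batch)

print(predictions.shape)          # torch.Size([4, 1])
print(predictions.requires_grad)  # False
```

The first dimension is the batch dimension, so the same model handles any number of samples as long as each has 10 features.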


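About the author

Passionate about technology, enjoys sharing, and keeps learning. Focused on web development, system architecture design, and artificial intelligence.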