
通过PyTorch构建LeNet网络,并使用MNIST数据集进行训练和测试。
5星
- 浏览量: 0
- 大小:None
- 文件类型:None
简介:
近期我正专注于pytorch的学习,并着手手动重构LeNet网络模型,同时为了方便理解和进一步研究,我提供了该网络的源代码,具体内容如下:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2) # 卷积层1
self.relu1 = nn.ReLU() # ReLU激活函数1
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 池化层1
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=2) # 卷积层2
self.relu2 = nn.ReLU() # ReLU激活函数2
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层1
self.relu3 = nn.ReLU() # ReLU激活函数3
self.fc2 = nn.Linear(120, 84) # 全连接层2
self.relu4 = nn.ReLU() # ReLU激活函数4
self.fc3 = nn.Linear(84, 10) # 全连接层3 (输出层)
def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x))) # 池化和卷积操作
x = x.view(-1, 16 * 5 * 5) # 将数据展平为适合全连接层的格式
x = self.relu3(self.fc1(x)) # 全连接层操作
x = self.relu4(self.fc2(x)) # 全连接层操作
x = self.fc3(x) # 输出层操作
return x
```
全部评论 (0)


