import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
# Shared preprocessing for both splits: ToTensor converts each image to a float
# tensor, then Normalize standardizes every pixel as (x - 0.1307) / 0.3081 (the
# mean/std values commonly used for MNIST). This is input normalization to help
# optimization; it is not regularization and does not change model complexity.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# download=False requires the MNIST files to already exist under ./data;
# pass download=True to have torchvision fetch them on first run.
train_dataset = datasets.MNIST('./data', train=True, download=False, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=False, transform=transform)
# Training hyper-parameters: mini-batch size, number of passes over the
# training set, and the SGD step size.
batch_size, epochs, learning_rate = 256, 10, 0.01
# The training set is reshuffled every epoch; the test set keeps a fixed order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
def plot_curve(data, name):
    """Plot *data* as a red line against its index, with a legend entry *name*.

    Draws into the current pyplot axes; creating the figure/subplot and calling
    ``plt.show()`` are left to the caller.

    Args:
        data: sequence of numeric values, one per step.
        name: legend label for the curve.
    """
    step_axis = range(len(data))
    plt.plot(step_axis, data, color='red')
    plt.legend([name], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
class Net(nn.Module):
    """Three-layer fully connected classifier for 28x28 images.

    Layers: 784 -> 256 (ReLU) -> 64 (ReLU) -> 10, followed by a softmax over
    the 10 outputs, so each output row sums to 1.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class probabilities of shape (N, 10).

        Args:
            x: batch of inputs, either already flattened to (N, 784) or
                image-shaped such as (N, 1, 28, 28).
        """
        # Flatten every dimension after the batch one so image-shaped batches
        # work directly; an (N, 784) input passes through unchanged.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc3(x), dim=1)
net = Net()
# net.parameters() yields the weights and biases of fc1, fc2 and fc3.
# The momentum coefficient must be passed as `momentum`; SGD has no
# `enumerate` argument and would raise TypeError.
opt = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
# Train with the MSE between the network's softmax output and one-hot targets.
train_loss = []      # mean training loss of each epoch
all_train_loss = []  # loss of every optimization step, in order
num = 0              # total number of optimization steps taken so far
for epoch in range(epochs):
    epoch_loss_sum = 0.0
    num_batches = 0
    for batch_idx, (x, y) in enumerate(train_dataloader):
        # Flatten each image to a 784-feature vector for the fully connected net.
        x = x.view(x.size(0), 28*28)
        out = net(x)
        # Fix the width at 10 classes so the target always matches out's
        # (N, 10) shape, even if a batch happens to miss the highest label.
        y_onehot = F.one_hot(y, num_classes=10).to(out.dtype)
        loss = F.mse_loss(out, y_onehot)
        opt.zero_grad()  # clear gradients left over from the previous step
        loss.backward()  # back-propagate to compute gradients
        opt.step()       # update the parameters using the gradients
        loss_value = loss.item()
        all_train_loss.append(loss_value)
        epoch_loss_sum += loss_value
        num_batches += 1
        num += 1
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss_value)
    # Average over this epoch's batches; max() guards an empty dataloader.
    train_loss.append(epoch_loss_sum / max(num_batches, 1))
# Side-by-side loss plots: `train_loss` on the left, `all_train_loss` on the right.
plt.figure(figsize=(12, 6))
for panel_idx, (series, series_name) in enumerate(
        [(train_loss, 'train loss'), (all_train_loss, 'all_train_loss')], start=1):
    plt.subplot(1, 2, panel_idx)
    plot_curve(series, series_name)
plt.show()
# Model evaluation on the whole test set.
y_all = []     # ground-truth labels, in dataloader order
pred_all = []  # predicted labels, aligned with y_all
total_correct = 0  # NOTE(review): never updated or read in this script
net.eval()  # inference mode (Net has only Linear layers, so outputs are unchanged)
with torch.no_grad():  # no gradients needed: skip building autograd graphs
    for x, y in test_dataloader:
        x = x.view(x.size(0), 28*28)
        out = net(x)
        pred = out.argmax(dim=1)  # most probable class per sample
        y_all.extend(y.numpy())
        pred_all.extend(pred.numpy())
# 'weighted' averages per-class scores weighted by each class's support.
y_accuracy = accuracy_score(y_all, pred_all)
y_precision = precision_score(y_all, pred_all, average='weighted')
y_recall = recall_score(y_all, pred_all, average='weighted')
y_f1 = f1_score(y_all, pred_all, average='weighted')
print("[test] accuracy:{:.4f} precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(y_accuracy, y_precision, y_recall, y_f1))
print(classification_report(y_all, pred_all))
def plot_image(img, label, name, mean=0.1307, std=0.3081):
    """Show up to six images in a 2x3 grid, each titled "<name>: <label>".

    Args:
        img: batch of normalized images; ``img[i][0]`` must be a 2-D image
            (presumably an (N, 1, H, W) tensor -- confirm at call sites).
        label: per-image labels; ``label[i].item()`` must work (e.g. a tensor).
        name: prefix shown in every subplot title.
        mean: normalization mean to add back before display.
        std: normalization std to multiply back before display. The defaults
            invert ``transforms.Normalize((0.1307,), (0.3081,))``.
    """
    # The grid has six slots; show fewer when the batch is smaller
    # instead of raising IndexError.
    shown = min(6, len(img), len(label))
    plt.figure()
    for i in range(shown):
        plt.subplot(2, 3, i + 1)
        # Undo the normalization so pixels are back on the original scale.
        plt.imshow(img[i][0] * std + mean, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot exists.
    plt.tight_layout()
    plt.show()
# Visualize predictions on the first test batch. The loader is
# `test_dataloader`; `test_loader` is never defined and raised NameError.
x, y = next(iter(test_dataloader))
with torch.no_grad():  # inference only
    out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)  # predicted digit per image
plot_image(x, pred, 'test')
# This code displays the images together with the model's predicted labels.
# 综上，可以尝试改变网络层数、激活函数、损失函数的计算方法等，设计自己的网络。
# 完