import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report


# Shared preprocessing for both splits: ToTensor converts each image to a float
# tensor in [0, 1]; Normalize then standardises it with mean 0.1307 and std
# 0.3081 (the values commonly quoted for the MNIST training set).
# Note: this is input normalisation, not regularisation. It helps optimisation
# but does not reduce model complexity or prevent overfitting.
mnist_transform = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# download=False: the MNIST files must already exist under ./data.
train_dataset = datasets.MNIST('./data', train=True, download=False, transform=mnist_transform)
test_dataset = datasets.MNIST('./data', train=False, download=False, transform=mnist_transform)

batch_size = 256
epochs = 10
learning_rate = 0.01

# Shuffle only the training data; the test order stays fixed so evaluation is reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


def plot_curve(data, name):
    """Plot a sequence of values against its step index on the current axes.

    Args:
        data: sequence of scalar values (e.g. per-step or per-epoch losses).
        name: legend label for the curve.
    """
    ax = plt.gca()
    ax.plot(range(len(data)), data, color='red')
    ax.legend([name], loc='upper right')
    ax.set_xlabel('step')
    ax.set_ylabel('value')


class Net(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 64 -> 10.

    Both hidden layers use ReLU. The output layer is passed through softmax
    over dim 1, so each row of the result sums to 1 across the 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Registered in the order fc1, fc2, fc3 so parameter order and
        # state_dict keys stay unchanged.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Map flattened images (batch, 784) to class probabilities (batch, 10)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return F.softmax(self.fc3(hidden), dim=1)


net = Net()
# SGD with momentum. net.parameters() yields the weights and biases of
# fc1, fc2 and fc3 (w1, b1, w2, b2, w3, b3).
opt = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
train_loss = []      # mean training loss of each epoch
all_train_loss = []  # training loss of every individual step (batch)
num = 0              # total number of optimisation steps across all epochs
for epoch in range(epochs):
    epoch_loss_sum = 0.0
    epoch_batches = 0
    for batch_idx, (x, y) in enumerate(train_dataloader):
        # Flatten each image into a 784-dim vector for the fully connected net.
        x = x.view(x.size(0), 28*28)
        out = net(x)
        # Fix the number of classes to the output width so the target shape always
        # matches `out`, even if a batch happens not to contain the highest label.
        y_onehot = F.one_hot(y, num_classes=out.size(1)).to(out.dtype)
        loss = F.mse_loss(out, y_onehot)

        opt.zero_grad()  # clear gradients from the previous step
        loss.backward()  # compute gradients
        opt.step()  # update the parameters

        step_loss = loss.item()
        all_train_loss.append(step_loss)
        epoch_loss_sum += step_loss
        epoch_batches += 1
        num += 1
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, step_loss)

    # Mean loss over this epoch's batches.
    train_loss.append(epoch_loss_sum / epoch_batches)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plot_curve(train_loss, 'train loss')
plt.subplot(1, 2, 2)
plot_curve(all_train_loss, 'all_train_loss')
plt.show()

# Model evaluation on the test set
net.eval()  # Net has no dropout/batch-norm, but this makes inference mode explicit
y_all = []
pred_all = []
total_correct = 0
with torch.no_grad():  # inference only: don't build the autograd graph
    for x, y in test_dataloader:
        x = x.view(x.size(0), 28*28)
        out = net(x)
        pred = out.argmax(dim=1)
        total_correct += (pred == y).sum().item()
        y_all.extend(y.numpy())
        pred_all.extend(pred.numpy())

# Compute the metrics once over the whole test set, rather than after every batch.
y_accuracy = accuracy_score(y_all, pred_all)
y_precision = precision_score(y_all, pred_all, average='weighted')
y_recall = recall_score(y_all, pred_all, average='weighted')
y_f1 = f1_score(y_all, pred_all, average='weighted')

print("[test] accuracy:{:.4f} precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(y_accuracy, y_precision, y_recall, y_f1))
print(classification_report(y_all, pred_all))

def plot_image(img, label, name):
    """Show up to six images in a 2x3 grid, each titled with its label.

    Args:
        img: batch of images; img[i][0] (the first channel of image i) is drawn.
            Assumed to have been normalised with mean 0.1307 / std 0.3081, which
            is undone here for display.
        label: per-image labels; label[i].item() is shown in each title.
        name: prefix shown before each label in the subplot titles.
    """
    plt.figure()
    # Handle batches with fewer than six images instead of raising IndexError.
    count = min(6, len(img), len(label))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        # Undo Normalize((0.1307,), (0.3081,)) to get back the ToTensor pixel scale.
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()


# Show the first test images together with the network's predicted labels.
x, y = next(iter(test_dataloader))
with torch.no_grad():  # inference only
    out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

# The code above displays sample test images together with their predicted labels.

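# In summary, you can design your own network by experimenting with the number of layers, the activation functions, the loss function, and so on.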