dataset.py
"""Prepare the MNIST dataset and data loaders."""
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Normalize
import torchvision

import config

# Normalization mean/std applied to the single-channel MNIST images.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)


def mnist_dataset(train):
    """Return the MNIST train or test split as normalized tensors.

    Args:
        train: True for the training split, False for the test split.

    Returns:
        A torchvision ``MNIST`` dataset rooted at ``../mnist``. Because
        ``download=False``, the files must already exist there.
    """
    transform = Compose([
        ToTensor(),
        Normalize(mean=MNIST_MEAN, std=MNIST_STD),
    ])
    return MNIST(root="../mnist", train=train, download=False, transform=transform)


def get_dataloader(train=True, shuffle=None):
    """Build a DataLoader over the MNIST train or test split.

    Args:
        train: True for the training split, False for the test split.
            Also selects ``config.train_batch_size`` or ``config.test_batch_size``.
        shuffle: Whether to shuffle each epoch. Defaults to ``train``, so
            training batches are shuffled and test order stays deterministic.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    dataset = mnist_dataset(train)
    batch_size = config.train_batch_size if train else config.test_batch_size
    if shuffle is None:
        shuffle = train
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    # Smoke test: print the shape and labels of one training batch.
    for images, labels in get_dataloader():
        print(images.size())
        print(labels)
        break
models.py
"""Define the model."""
import torch.nn as nn
import torch.nn.functional as F


class MnistModel(nn.Module):
    """Two-layer fully connected classifier for 28x28 images.

    The output is log-probabilities over ``num_classes`` classes (10 for
    MNIST digits), suitable for ``F.nll_loss``.
    """

    def __init__(self, hidden_size=100, num_classes=10):
        """Create the two linear layers.

        Args:
            hidden_size: Width of the hidden layer (default 100).
            num_classes: Number of output classes (default 10).
        """
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, image):
        """Compute per-class log-probabilities.

        Args:
            image: Tensor whose elements are grouped into rows of 28*28 pixels,
                e.g. an (N, 1, 28, 28) batch from the MNIST loader.

        Returns:
            Tensor of shape (N, num_classes) holding log-softmax scores.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        flattened = image.reshape(-1, 28 * 28)
        hidden = F.relu(self.fc1(flattened))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)
config.py
"""Project configuration."""
import torch

# Batch sizes used by dataset.get_dataloader for the train and test splits.
train_batch_size = 128
test_batch_size = 128

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train.py
"""Train the model."""
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm

import config
from dataset import get_dataloader
from eval import eval as evaluate
from models import MnistModel

MODEL_PATH = "./model/mnist_net.pt"
OPTIMIZER_PATH = "./model/mnist_optimizer.pt"

# Instantiate the model and optimizer; the loss (F.nll_loss) is applied in train().
model = MnistModel().to(config.device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Resume from a previous run only when both checkpoints exist, so a missing
# optimizer file cannot crash startup.
if os.path.exists(MODEL_PATH) and os.path.exists(OPTIMIZER_PATH):
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=config.device))
    optimizer.load_state_dict(torch.load(OPTIMIZER_PATH, map_location=config.device))


def train(epoch):
    """Train the module-level ``model`` for one pass over the training set.

    Updates the tqdm description every 10 batches with the running mean loss.
    Writes the model and optimizer state dicts to disk after the last batch.

    Args:
        epoch: Epoch index, used only in the progress-bar description.
    """
    model.train()
    train_dataloader = get_dataloader(train=True)
    bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    total_loss = []
    for idx, (images, target) in bar:
        images = images.to(config.device)
        target = target.to(config.device)
        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()
        output = model(images)
        loss = F.nll_loss(output, target)
        total_loss.append(loss.item())
        # Back-propagate and update the parameters.
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            bar.set_description(
                "epoch:{} idx:{},loss:{}".format(epoch, idx, np.mean(total_loss))
            )
    # Checkpoint after the final batch so evaluate(), which reloads from disk,
    # sees the fully trained weights. Create the directory on the first run.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(model.state_dict(), MODEL_PATH)
    torch.save(optimizer.state_dict(), OPTIMIZER_PATH)


if __name__ == "__main__":
    for i in range(10):
        train(i)
        evaluate()
eval.py
"""Evaluate the trained model on the test split."""
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim

import config
from dataset import get_dataloader
from models import MnistModel

MODEL_PATH = "./model/mnist_net.pt"


def eval():
    """Evaluate the saved model on the test split and print loss and accuracy.

    Loads weights from ``MODEL_PATH`` when it exists. Otherwise it prints a
    warning and evaluates a freshly initialized model.

    Returns:
        ``(mean_loss, accuracy)`` averaged over individual test samples
        (``nan`` for both if the loader yields no samples).
    """
    model = MnistModel().to(config.device)
    if os.path.exists(MODEL_PATH):
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=config.device))
    else:
        print("warning: {} not found, evaluating an untrained model".format(MODEL_PATH))
    model.eval()

    test_dataloader = get_dataloader(train=False)
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, target in test_dataloader:
            images = images.to(config.device)
            target = target.to(config.device)
            output = model(images)
            # Sum per-sample losses and correct counts rather than averaging
            # per-batch means, so a smaller final batch is not over-weighted.
            loss_sum += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=-1)
            correct += pred.eq(target).sum().item()
            seen += target.size(0)

    if seen == 0:
        mean_loss = accuracy = float("nan")
    else:
        mean_loss = loss_sum / seen
        accuracy = correct / seen
    print("test loss:{},test acc:{}".format(mean_loss, accuracy))
    return mean_loss, accuracy


if __name__ == "__main__":
    eval()
D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day3/手写数字识别/train.py
epoch:0 idx:460,loss:0.32289110562095413: 100%|██████████| 469/469 [00:24<00:00, 19.05it/s]
test loss:0.17968503131142147,test acc:0.9453125
epoch:1 idx:460,loss:0.15012750004513145: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
test loss:0.12370304338916947,test acc:0.9624208860759493
epoch:2 idx:460,loss:0.10398845713577534: 100%|██████████| 469/469 [00:21<00:00, 21.82it/s]
test loss:0.10385569722592077,test acc:0.9697389240506329
epoch:3 idx:460,loss:0.07973297938720653: 100%|██████████| 469/469 [00:22<00:00, 20.84it/s]
test loss:0.08691684670652015,test acc:0.9754746835443038
epoch:4 idx:460,loss:0.0650228117158285: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
test loss:0.0803159438309413,test acc:0.9760680379746836
epoch:5 idx:460,loss:0.05270117848966101: 100%|██████████| 469/469 [00:21<00:00, 21.92it/s]
test loss:0.08102699166423158,test acc:0.9759691455696202
epoch:6 idx:460,loss:0.04386751471317642: 100%|██████████| 469/469 [00:19<00:00, 24.58it/s]
test loss:0.07991968260347089,test acc:0.9769580696202531
epoch:7 idx:460,loss:0.03656852366544161: 100%|██████████| 469/469 [00:15<00:00, 31.20it/s]
test loss:0.07767781678917288,test acc:0.9774525316455697
epoch:8 idx:460,loss:0.03112584312896925: 100%|██████████| 469/469 [00:14<00:00, 32.41it/s]
test loss:0.07755146227494071,test acc:0.9773536392405063
epoch:9 idx:460,loss:0.025217091969725495: 100%|██████████| 469/469 [00:14<00:00, 31.53it/s]
test loss:0.07112929566845863,test acc:0.9802215189873418
原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12314982.html
时间: 2024-09-30 03:31:58