如上篇文章所讲,将我们需用的环境搭建完成以后,我们就可以开始AI之路了,下面就让我们来看看第一个网络框架结构——全连接吧。
import torch.nn as nn#导入所需库 class Net(nn.Module): #初始化网络结构(设计神经网络) def __init__(self): super().__init__() #设计一个多层结构的神经网络 self.layers = nn.Sequential( nn.Linear(28*28,512), #设计一层神经网络,有512个神经元,接受748个 nn.ReLU(), nn.Linear(512,256), nn.ReLU(), nn.Linear(256,128), nn.ReLU(), nn.Linear(128,10), nn.Softmax(dim=1) ) # 前向计算(使用神经网络),将数据x输入到网络中,返回结果 def forward(self, x): return self.layers(x) ***************************************************************************************************************************************
import torchimport torchvisionimport torch.nn as nnfrom PIL import Imageimport torch.utils.data as datafrom my_net import Netimport numpy as npimport ossave_path = "module/net_ps.pth" train_data = torchvision.datasets.MNIST( root="MNIST_data",#单通道28*28黑白图片(0-9数字) train=True, transform=torchvision.transforms.ToTensor(), download=True)test_data = torchvision.datasets.MNIST( root="MNIST_data", train=False, transform=torchvision.transforms.ToTensor(), download=False) if __name__ == ‘__main__‘: #创建数据加载器,每次从train_data里面取100张数据,打乱 train = data.DataLoader(dataset=train_data,batch_size=100,shuffle=True)#用数据加载器从train中每次加载100张图片并打乱 #实例化网络对象 net = Net() #判断本地是否已经有网络的参数,如果有,那就加载之前的参数 if os.path.isfile(save_path): net = torch.load(save_path) #定义损失函数 loss_fun = nn.MSELoss()#对(h-y)^2求平均 #定义优化器,用这个优化器来优化网络内部的参数 optimizer = torch.optim.Adam(net.parameters()) #取数据,训练网络 for epoch in range(1000000): for i,(x,y) in enumerate(train):#N C H W形状 #将图片变为100,784 x = x.reshape(-1,28*28) #将图片输入到网络,得到结果 out = net(x) #将标签y进行one-hot编码 target = torch.zeros(y.size()[0],10).scatter_(1,y.view(-1,1),1) #将网络的结果和标签拿来做损失 loss = loss_fun(target,out) #优化损失 optimizer.zero_grad()#清空梯度 loss.backward()#根据损失进行反向求导 optimizer.step()#更新梯度 #每训练10次,进行一次测试 if i%10 == 0: out_put = torch.argmax(out,dim=1) # print("target:",y) # print("out:",out_put) print("loss:",loss.item()) #计算准确度 acc = np.mean(np.array(out_put==y,dtype=np.float32)) print("精度:",acc) #保存网络参数 torch.save(net,save_path)
原文地址:https://www.cnblogs.com/wangyueyyy/p/11822340.html
时间: 2024-11-25 08:20:34