PyTorch学习笔记之DataLoaders

A DataLoader wraps a Dataset and provides minibatching, shuffling, multithreading, for you。

 1 import torch
 2 from torch.autograd import Variable
 3 import torch.nn as nn
 4 from torch.utils.data import TensorDataset, DataLoader
 5
 6 # define our whole model as a single Module
 7 class TwoLayerNet(nn.Module):
 8     # Initializer sets up two children (Modules can contain modules)
 9     def _init_(self, D_in, H, D_out):
10         super(TwoLayerNet, self)._init_()
11         self.linear1 = torch.nn.Linear(D_in, H)
12         self.linear2 = torch.nn.Linear(H, D_out)
13
14     # Define forward pass using child modules and autograd ops on Variables
15     # No need to define backward - autograd will handle it
16     def forward(self, x):
17         h_relu = self.linear1(x).clamp(min=0)
18         y_pred = self.linear2(h_relu)
19         return y_pred
20
21 N, D_in, H, D_out = 64, 1000, 100, 10
22 x = Variable(torch.randn(N, D_in))
23 y = Variable(torch.randn(N, D_out))
24
25 # When you need to load custom data, just write your own Dataset class
26 loader = DataLoader(TensorDataset(x, y), batch_size=8)
27
28 model = TwoLayerNet(D_in, H, D_out)
29
30 criterion = torch.nn.MSELoss(size_average=False)
31 optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
32 for epoch in range(10):
33     # Iterate(遍历) over loader to form minibatches
34     for x_batch, y_batch in loader:
35         # Loader gives Tensors so you need to wrap in Variables
36         x_var, y_var = Variable(x), Variable(y)
37         y_pred = model(x_var)
38         loss = criterion(y_pred, y_var)
39
40         optimizer.zero_grad()
41         loss.backward()
42         optimizer.step()
时间: 2024-08-25 21:10:06

PyTorch学习笔记之DataLoaders的相关文章

pytorch 学习笔记之编写 C 扩展,又涨姿势了

pytorch利用CFFI 进行 C 语言扩展.包括两个基本的步骤(docs): 编写 C 代码: python 调用 C 代码,实现相应的 Function 或 Module. 在之前的文章中,我们已经了解了如何自定义 Module.至于 [py]torch 的 C 代码库的结构,我们留待之后讨论: 这里,重点关注,如何在 pytorch C 代码库高层接口的基础上,编写 C 代码,以及如何调用自己编写的 C 代码. 官方示例了如何定义一个加法运算(见 repo).这里我们定义ReLU函数(见

PyTorch学习笔记之nn的简单实例

method 1 1 import torch 2 from torch.autograd import Variable 3 4 N, D_in, H, D_out = 64, 1000, 100, 10 5 x = Variable(torch.randn(N, D_in)) 6 y = Variable(torch.randn(N, D_out), requires_grad=False) 7 8 # define our model as a sequence of layers 9 m

20170721 PyTorch学习笔记之计算图

1. **args, **kwargs的区别 1 def build_vocab(self, *args, **kwargs): 2 counter = Counter() 3 sources = [] 4 for arg in args: 5 if isinstance(arg, Dataset): 6 sources += [getattr(arg, name) for name, field in 7 arg.fields.items() if field is self] 8 else:

PyTorch学习笔记之Tensors 2

Tensors的一些应用 1 ''' 2 Tensors和numpy中的ndarrays较为相似, 因此Tensor也能够使用GPU来加速运算 3 ''' 4 # from _future_ import print_function 5 import torch 6 x = torch.Tensor(5, 3) # 构造一个未初始化的5*3的矩阵 7 8 x2 = torch.rand(5, 3) # 构造一个随机初始化的矩阵 the same as 9 10 # print(x.size()

PyTorch学习笔记之初识word_embedding

1 import torch 2 import torch.nn as nn 3 from torch.autograd import Variable 4 5 word2id = {'hello': 0, 'world': 1} 6 # you have 2 words, and then need 5 dim each word 7 embeds = nn.Embedding(2, 5) 8 # we need variable, because we need use element of

PyTorch学习笔记之Variable

application 1 1 from torch.autograd import Variable 2 import torch 3 b = Variable(torch.FloatTensor([64, 100, 43])) 4 print(b) 5 ''' 6 Variable containing: 7 64 8 100 9 43 10 [torch.FloatTensor of size 3] 11 ''' application 2 1 from torch.autograd im

vector 学习笔记

vector 使用练习: /**************************************** * File Name: vector.cpp * Author: sky0917 * Created Time: 2014年04月27日 11:07:33 ****************************************/ #include <iostream> #include <vector> using namespace std; int main

Caliburn.Micro学习笔记(一)----引导类和命名匹配规则

Caliburn.Micro学习笔记(一)----引导类和命名匹配规则 用了几天时间看了一下开源框架Caliburn.Micro 这是他源码的地址http://caliburnmicro.codeplex.com/ 文档也写的很详细,自己在看它的文档和代码时写了一些demo和笔记,还有它实现的原理记录一下 学习Caliburn.Micro要有MEF和MVVM的基础 先说一下他的命名规则和引导类 以后我会把Caliburn.Micro的 Actions IResult,IHandle ICondu

jQuery学习笔记(一):入门

jQuery学习笔记(一):入门 一.JQuery是什么 JQuery是什么?始终是萦绕在我心中的一个问题: 借鉴网上同学们的总结,可以从以下几个方面观察. 不使用JQuery时获取DOM文本的操作如下: 1 document.getElementById('info').value = 'Hello World!'; 使用JQuery时获取DOM文本操作如下: 1 $('#info').val('Hello World!'); 嗯,可以看出,使用JQuery的优势之一是可以使代码更加简练,使开