第一章基本操作-线性拟合(GPU版本)

第一步:构造数据

import numpy as np
import os

x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)

y_values = [i * 2 + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)

第二步: 使用class LinearRegressionModel

class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        out = self.linear(x)
        return out

第三步: 实例化模型,初始化epochs, 学习率,定义SGD优化函数,以及定义mse优化损失函数,使用model.to(device) 将模型的参数更新放在GPU上

input_dim = 1
output_dim = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LinearRegressionModel(input_dim, output_dim)model.to(device)

epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
criterion = nn.MSELoss()

第四步: 如果模型存在就使用model.load_state_dict(torch.load("model.pkl")) 加载模型 参数,进行模型的参数优化,每50次,使用torch.save(model.state_dict)保存模型 ,使用to(device) 将训练样本和测试样本放在GPU上

if os.path.exists("model.pkl"):
    model.load_state_dict(torch.load("model.pkl"))

for epoch in range(epochs):

    inputs = torch.from_numpy(x_train).to(device)
    labels = torch.from_numpy(y_train).to(device)

    # 梯度每次清零
    optimizer.zero_grad()

    # 前向传播
    outputs = model(inputs)

    # 计算损失值
    loss = criterion(outputs, labels)

    #反向传播
    loss.backward()

    #更新权重参数
    optimizer.step()

    if epoch % 50 == 0:
        print("epoch:{},loss:{}".format(epoch, loss.item()))
        torch.save(model.state_dict(), "model.pkl")

原文地址:https://www.cnblogs.com/my-love-is-python/p/12650342.html

时间: 2024-08-30 03:03:08

第一章基本操作-线性拟合(GPU版本)的相关文章

第一章基本操作-线性回归模型

第一步:构造数据 import numpy as np import os x_values = [i for i in range(11)] x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1) y_values = [i * 2 + 1 for i in x_values] y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1) 第二步: 使用class

第一章 基本操作

创建数据库 mysql> create database example; Query OK, 1 row affected 查看已存在的数据库: mysql> show databases; +--------------------+ | Database | +--------------------+ | information_schema | | MyCloudDB | | example | | mydatabase | | mysql | | performance_schem

第一章基本操作-自动求导

使用目标对象的.backward()进行反向梯度求导 import torch x = torch.randn(3, 4, requires_grad=True) print(x) b = torch.randn(3, 4, requires_grad=True) t = x + b y = t.sum() y.backward() print(b.grad) x = torch.rand(1) b = torch.rand(1, requires_grad=True) w = torch.ra

知识图谱文献综述(第一章)

既然决定了以知识图谱作为研究方向,文献综述是必不可少的. 本文主要总结<知识图谱发展报告(2018)-中国中文信息学会> 1. 知识图谱的研究目标与意义 (略) 2. 知识工程的发展历程 3. 知识图谱技术 人们通过概念掌握对客观世界的理解,概念是对客观世界事物的抽象,是将 人们对世界认知联系在一起的纽带.知识图谱以结构化的形式描述客观世界中概 念.实体及其关系.实体是客观世界中的事物,概念是对具有相同属性的事物的 概括和抽象.本体是知识图谱的知识表示基础,可以形式化表示为,O={C,H, P

《Machine Learning》(第一章)序章

关键词:机器学习,基本术语,假设空间,归纳偏好,机器学习用途 一.机器学习概述 机器学习是一门从数据中,经过计算得到模型(Model)的一种过程,得到的模型不仅能反应出训练数据集中所蕴含的规律,并且能够运用在训练集之外的数据上.而机器学习研究的方向,就是解决:“我们为了得到这种模型,应该采用何种算法” 的问题. 如果说,训练集是我们的生活中的 “经验”,那么模型就是我们的 “经验性解决方法” ,训练集外的数据就是生活中的 “新问题” . 二.基本术语 在解释基本术语的同时,我们用生活中的例子 “

《Deep Learning》译文 第一章 前言(中) 神经网络的变迁与称谓的更迭

转载请注明出处! 第一章 前言(中) 1.1 本书适合哪些人阅读? 可以说本书的受众目标比较广泛,但是本书可能更适合于如下的两类人群,一类是学习过与机器学习相关课程的大学生们(本科生或者研究生),这包括了那些刚刚开始深度学习和AI研究的同学们:另一类是有机器学习或统计学背景的,想快速将深度学习应用在其产品或平台中的软件开发者们.深度学习早已被证实可以在许多软件应用中发挥光和热,比如:计算机视觉.语音与视频处理.自然语言理解.机器人学.生物学与化学.电视游戏.搜索引擎.在线广告与金融学等等. 为了

《机器学习》读书笔记-第一章 引言

<Machine Learning>,作者Tom Mitchell,卡内基梅隆大学. 第一章 引言 1.1 学习问题的标准描述: 机器学习的定义: 如果一个计算机程序针对某类任务T的用P衡量的性能根据经验E来自我完善, 那么我们称这个计算机程序在从E中学习,针对某类任务T,它的性能用P来衡量. 例子: 对于学习下西洋跳棋的计算机程序,它可以通过和自己下棋获取经验: 它的任务是参与西洋跳棋对弈: 它的性能用它赢棋的能力来衡量. 学习问题的三个特征: 任务的种类, 衡量性能提高的标准, 经验的来源

《软件工程概论》第一章核心内容

第一章  软件定义:是计算机系统中与硬件相互依存的另一部分,包括程序.数据和相关文档的完整集合. 软件特性:形态特性.智能特性.开发特性.质量特性.生产特性.管理特性.环境特性.维护特性.废弃特性.应用特性.  软件分类.  (1) 系统软件 (2) 应用软件 (3) 支撑软件 (4) 可复用软件   软件危机的原因:1)缺乏软件开发的经验和有关软件开发数据的积累,使得开发工作的计划很难制定.2)软件人员与用户的交流存在障碍,除了知识背景的差异,缺少合适的交流方法和需求描述工具也是重要的一个原因

数据结构期末复习第一章绪论

前言: 最近快期末了,复习下数据结构,下列习题和答案解析,大部分来源于网络,如有不对之处还请指出. 在这里,星云祝各位考生期末考试顺利,新年快乐! 第一章绪论 1. 数据结构是一门研究非数值计算的程序设计问题中计算机的(操作对象)以及它们之间的(关系)和(操作)的学科. 2. 下列关于数据结构的基本概念中,叙述正确的是( C ). A. 数据元素是数据的最小单位. B. 数据的逻辑结构是指数据的各数据项之间的逻辑关系. C. 任何一个算法的设计取决于选定逻辑结构,而算法的实现依赖于采用的存储结构