第一章基本操作-线性回归模型

第一步:构造数据

import numpy as np
import os

x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)

y_values = [i * 2 + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)

第二步: 使用class LinearRegressionModel

class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        out = self.linear(x)
        return out

第三步: 实例化模型,初始化epochs, 学习率,定义SGD优化函数,以及定义mse优化损失函数

input_dim = 1
output_dim = 1

model = LinearRegressionModel(input_dim, output_dim)

epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
criterion = nn.MSELoss()

第四步: 如果模型存在就使用model.load_state_dict(torch.load("model.pkl")) 加载模型 参数,进行模型的参数优化,每50次,使用torch.save(model.state_dict)保存模型

if os.path.exists("model.pkl"):
    model.load_state_dict(torch.load("model.pkl"))

for epoch in range(epochs):

    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)

    # 梯度每次清零
    optimizer.zero_grad()

    # 前向传播
    outputs = model(inputs)

    # 计算损失值
    loss = criterion(outputs, labels)

    #反向传播
    loss.backward()

    #更新权重参数
    optimizer.step()

    if epoch % 50 == 0:
        print("epoch:{},loss:{}".format(epoch, loss.item()))
        torch.save(model.state_dict(), "model.pkl")

原文地址:https://www.cnblogs.com/my-love-is-python/p/12650327.html

时间: 2024-08-30 03:03:08

第一章基本操作-线性回归模型的相关文章

第一章基本操作-线性拟合(GPU版本)

第一步:构造数据 import numpy as np import os x_values = [i for i in range(11)] x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1) y_values = [i * 2 + 1 for i in x_values] y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1) 第二步: 使用class

第一章 基本操作

创建数据库 mysql> create database example; Query OK, 1 row affected 查看已存在的数据库: mysql> show databases; +--------------------+ | Database | +--------------------+ | information_schema | | MyCloudDB | | example | | mydatabase | | mysql | | performance_schem

第一章 基础编程模型

课后习题1.1.14 编写一个静态方法lg(),接受一个整形参数N,返回不大于log2N的最大整数.不要使用Math库. public class Test{ public static void main(String[] args){ int i; i=lg(17,2); System.out.println(i); } public static int lg(int N,int M){ int a = 0; while(N>=M){ N=N/M; a++; } return a; } }

第一章基本操作-自动求导

使用目标对象的.backward()进行反向梯度求导 import torch x = torch.randn(3, 4, requires_grad=True) print(x) b = torch.randn(3, 4, requires_grad=True) t = x + b y = t.sum() y.backward() print(b.grad) x = torch.rand(1) b = torch.rand(1, requires_grad=True) w = torch.ra

CLR执行模型《CLR via c#》第一章

这是我看<CLR via c#>第四版的一些小笔记和总结,如有不对的地方,欢迎指出. <CLR via c#>第一章CLR的执行模型讲的是如何将源代码生成为一个应用程序,或者生成为一组可重新分发的组件(文件)- 这些组件(文件)包含类型(类和结构等),解释了应用程序如何执行. CLR(common language runtime ,公共语言运行时),顾名思义,它是一个可以支持多种语言的“运行时”. 通常我们c#程序的执行过程是 CLR的JIT(即时编译器)把IL代码编译成机器指令

javascript数据结构和算法 第一章(Javascript编程环境和模型) 一

这一章介绍了我们在这本书中使用的描述各种数据结构和算法的Javascript的编程环境和编程架构. Javascript 环境 Javascript 在很长一段时间都是被作为web浏览器内置脚本编程语言来使用. 然而,在过去几年里,javascript编程环境得到了极大的发展,他们可以使javascript在桌面或者服务端运行. 在我们这本书中,我们使用其中的一个javascript环境:javascript shell:是Mozilla公司的javascript环境,被称为SpiderMonk

2017上半年软考 第一章 重要知识点

第一章 信息化的知识,具体讲了:重要的知识点是: 融合,信息技术和工业制造深度融合.人和机器的融合.信息资源和材料资源的融合 :信息论奠基者:香农: 信息的传输技术是信息技术的核心: 恰当的冗余编码可以在信息收到噪声侵扰时被恢复: 信息系统的基本规律应包括信息的度量.信源特性饿信源编码.信道特性和新到编码.检测理论.估计理论以及密码学: 信息系统特性:目的性.可嵌套行性.稳定性.开放性.脆弱性.健壮性: 信息系统生命周期:立项[规划].开发[分析.设计.实施].运维.消亡: 信息化层次:产品信息

[MOOC笔记]第一章 绪论(数据结构)

1.  计算 学习DSA的目的是实现有效的和高效的计算,同时在资源消耗的方面做到足够的低廉. 计算 = 信息处理:借助某些工具,遵照一定规则,以明确而机械的形式进行. 计算模型 = 计算机 = 信息处理工具 算法:在特定的计算模型下,旨在解决特定问题的指令序列. 算法的要素: 输入 待处理的信息(问题) 输出 经处理的信息(答案) 正确性 的确可以解决指定的问题 确定性 任一算法都可以描述为一个由基本操作组成的序列 可行性 每一基本操作都可实现,且在常数时间内完成 有穷性 对于任何输入,经有穷次

统计学习方法 笔记&lt;第一章&gt;

第一章 统计学习方法概述 1.1 统计学习 统计学习(statistical learning)是关于计算机基于数据概率模型并运用模型进行预测和分析的学科.统计学习也称为统计机器学习,现在人们提及的机器学习一般都是指统计机器学习. 统计学习的对象是数据(data),关于数据的基本假设是同类数据具有一定的统计规律性(前提):比如可以用随机变量描述数据中的特征,用概率分布描述数据的统计规律等. 统计学习的目的:对现有的数据进行分析,构建概率统计模型,分析和预测未知新数据,同时也需要考虑模型的复杂度以