《用Python玩转数据》项目—线性回归分析入门之波士顿房价预测(二)

接上一部分,此篇将用tensorflow建立神经网络,对波士顿房价数据进行简单建模预测。

二、使用tensorflow拟合boston房价datasets

1、数据处理依然利用sklearn来分训练集和测试集。

2、使用一层隐藏层的简单网络,试下来用当前这组超参数收敛较快,准确率也可以。

3、激活函数使用relu来引入非线性因子。

4、原本想使用如下方式来动态更新lr,但是尝试下来效果不明显,就索性不要了。

def learning_rate(epoch):
    if epoch < 200:
        return 0.01
    if epoch < 400:
        return 0.001
    if epoch < 800:
        return 1e-4

好了,废话不多说了,看代码如下:

from sklearn import datasets
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

dataset = datasets.load_boston()
x = dataset.data
target = dataset.target
y = np.reshape(target,(len(target), 1))

x_train, x_verify, y_train, y_verify = train_test_split(x, y, random_state=1)
y_train = y_train.reshape(-1)
train_data = np.insert(x_train, 0, values=y_train, axis=1)

def r_square(y_verify, y_pred):
    var = np.var(y_verify)
    mse = np.sum(np.power((y_verify-y_pred.reshape(-1,1)), 2))/len(y_verify)
    res = 1 - mse/var
    print(‘var:‘, var)
    print(‘MSE-ljj:‘, mse)
    print(‘R2-ljj:‘, res)

EPOCH = 3000
lr = tf.placeholder(tf.float32, [], ‘lr‘)
x = tf.placeholder(tf.float32, shape=[None, 13], name=‘input_feature_x‘)
y = tf.placeholder(tf.float32, shape=[None, 1], name=‘input_feature_y‘)

W = tf.Variable(tf.truncated_normal(shape=[13, 10], stddev=0.1))
b = tf.Variable(tf.constant(0., shape=[10]))

W2 = tf.Variable(tf.truncated_normal(shape=[10, 1], stddev=0.1))
b2 = tf.Variable(tf.constant(0., shape=[1]))

with tf.Session() as sess:
    hidden1 = tf.nn.relu(tf.add(tf.matmul(x, W), b))

    y_predict = tf.add(tf.matmul(hidden1, W2), b2)
    loss = tf.reduce_mean(tf.reduce_sum(tf.pow(y-y_predict,2), reduction_indices=[1]))
    print(loss.shape)
    train = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    W_res = 0
    b_res = 0
    try:
        last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=‘/home/ljj/PycharmProjects/mooc/train_record‘)
        saver.restore(sess, save_path=last_chk_path)
    except:
        print(‘no save file to recover-----------start new train instead--------‘)

        loss_list = []
        over_flag = 0
        for i in range(EPOCH):
            if over_flag ==1:
                    break
            y_t = train_data[:, 0].reshape(-1, 1)
            _, W_res, b_res, loss_train = sess.run([train, W, b, loss],
                                                   feed_dict={x: train_data[:, 1:],
                                                              y: y_t,
                                                              lr: 0.01})

            checkpoint_file = os.path.join(‘/home/ljj/PycharmProjects/mooc/train_record‘, ‘checkpoint‘)
            saver.save(sess, checkpoint_file, global_step=i)
            loss_list.append(loss_train)
            if loss_train < 0.2:
                over_flag = 1
                break
            if i %500 == 0:
                print(‘EPOCH = {:}, train_loss ={:}‘.format(i, loss_train))
            if i % 500 == 0:
                r = loss.eval(session=sess, feed_dict={x: x_verify,
                                                       y: y_verify,
                                                       lr: 0.01})
                print(‘verify_loss = ‘,r)
            np.random.shuffle(train_data)

        plt.plot(range(len(loss_list)-1), loss_list[1:], ‘r‘)
        plt.show()

    print(‘final loss = ‘,loss.eval(session=sess, feed_dict={x: x_verify,
                                           y: y_verify,
                                           lr: 0.01}))

    y_pred = sess.run(y_predict, feed_dict={x: x_verify,
                                           y: y_verify,
                                           lr: 0.01})

    plt.subplot(2,1,1)
    plt.xlim([0,50])
    plt.plot(range(len(y_verify)), y_pred,‘b--‘)
    plt.plot(range(len(y_verify)), y_verify,‘r‘)
    plt.title(‘validation‘)

    y_ss = sess.run(y_predict, feed_dict={x: x_train,
                                           y: y_train.reshape(-1, 1),
                                           lr: 0.01})
    plt.subplot(2,1,2)
    plt.xlim([0,50])
    plt.plot(range(len(y_train)), y_ss,‘r--‘)
    plt.plot(range(len(y_train)), y_train,‘b‘)
    plt.title(‘train‘)

    plt.savefig(‘tf.png‘)
    plt.show()

    r_square(y_verify, y_pred)

训练了大概3000个epoch后,保存模型,之后可以多次训练,但是loss基本收敛了,没有太大变化。

输出结果如下:

final loss =  15.117827
var: 99.0584735569471
MSE-ljj: 15.11782691349897
R2-ljj: 0.8473848185757882

从图像上看,拟合效果也是一般,再拿一个放大版本的validation图,同样取前50个样本,这样方便和之前的线性回归模型对比。

最后我们还是用数据来说明:

tf模型结果中,

R2:0.847   > 0. 779

MSE:15.1  < 21.8

都比sklearn的线性回归结果要好。所以,此tf模型对波士顿房价数据的可解释性更强。

def learning_rate(epoch):    if epoch < 200:        return 0.01if epoch < 400:        return 0.001if epoch < 800:        return 1e-4

原文地址:https://www.cnblogs.com/lingjiajun/p/10015933.html

时间: 2024-10-06 16:30:08

《用Python玩转数据》项目—线性回归分析入门之波士顿房价预测(二)的相关文章

《用python 玩转数据》项目——B站弹幕数据分析

1. 背景 在视频网站上,一边看视频一边发弹幕已经是网友的习惯.在B站上有很多种类的视频,也聚集了各种爱好的网友.本项目,就是对B站弹幕数据进行分析.选取分析的对象是B站上点播量过1.4亿的一部剧<Re:从零开始的异世界生活>. 2.       算法 分两部分:  第一部分: 2.1     在<Re:从零开始的异世界生活>的首页面,找到共25集的所有对应播放链接和剧名的格式,获取每一集的播放链接,并保存. 2.2     从每一集的播放页面中,通过正则re获取它的cid号,获得

用Python玩转数据:python的函数、模块和包

Python函数 函数可以看成类似于数学中的函数,完成一个特定功能的一段代码. -绝对值函数 abs() -类型函数 type() -四舍五入函数 round() Python中有很多内建函数,即不需要另外导入的函数. -cmp(), str() 和 type()适用于所有标准类型.以下是数值型内建函数和实用内建函数. >>> dir(_builtins_) 命令可以看到Python中的内建变量和内建函数. >>> help(abs) 命令用于查看abs函数的帮助信息.

《用Python玩转数据》学习笔记

1.Python中运行程序的方式: 1.shell方式 2.IDE中建立一个.py文件,然后在shell中用解释器执行该方式 一般是代码段比较短的时候优先考虑用shell方式,如果代码段比较长的话,优先选用文件方式. 选用既可以shell有可以文件执行的执行环境,本课程使用Python(x,y). 2.python(x,y),是一个基于Python的科学计算软件包.Python兼顾了编写效率和执行效率,所以成为了非常受欢迎的科学计算语言.有很多类似的软件包.python的主要优点是其中包含了非常

微软数据挖掘算法:Microsoft 线性回归分析算法(11)

前言 此篇为微软系列挖掘算法的最后一篇了,完整该篇之后,微软在商业智能这块提供的一系列挖掘算法我们就算总结完成了,在此系列中涵盖了微软在商业智能(BI)模块系统所能提供的所有挖掘算法,当然此框架完全可以自己扩充,可以自定义挖掘算法,不过目前此系列中还不涉及,只涉及微软提供的算法,当然这些算法已经基本涵盖大部分的商业数据挖掘的应用场景,也就是说熟练了这些算法大部分的应用场景都能游刃有余的解决,每篇算法总结包含:算法原理.算法特点.应用场景以及具体的操作详细步骤.为了方便阅读,我还特定整理一篇目录:

阿里,腾讯内部十二个大数据项目,你都有做过吗?

随着社会的进步,大数据的高需求,高薪资,高待遇,促使很多人都来学习和转行到大数据这个行业.学习大数据是为了什么?成为一名大数据高级工程师.而大数据工程师能得到高薪.高待遇的能力在哪?自然是项目经验.下面给大家大概介绍一下在阿里的"双11"."双12"."双旦"即将到来的"618"与腾讯大数据都用上的十二个大数据项目:阿里,腾讯内部十二个大数据项目,你都有做过吗?一个大数据分析项目关键构成如下: 信息采集组.数据清洗组.数据融合

python之简单线性回归分析

使用sklearn库的linear_model.LinearRegression(),可以非常简单的进行线性回归分析 以下为代码: 1 # 导入sklearn库下的linear_model类 2 from sklearn import linear_model 3 # 导入pandas库,别名为pd 4 import pandas as pd 5 6 filename = r'D:\test.xlsx' 7 # 读取数据文件 8 data = pd.read_excel(filename) 9

Python即时网络爬虫项目启动说明

作为酷爱编程的老程序员,实在按耐不下这个冲动,Python真的是太火了,不断撩拨我的心. 我是对Python存有戒备之心的,想当年我基于Drupal做的系统,使用php语言,当语言升级了,推翻了老版本很多东西,不得不花费很多时间和精力去移植和升级,至今还有一些隐藏在某处的代码埋着雷.我估计Python也避免不了这个问题(其实这种声音已经不少,比如Python 3 正在毁灭 Python). 但是,我还是启动了这个Python即时网络爬虫项目.我用C++.Java和Javascript编写爬虫相关

程序员带你十天快速入门Python,玩转电脑软件开发(二)

关注今日头条-做全栈攻城狮,学代码也要读书,爱全栈,更爱生活.提供程序员技术及生活指导干货. 如果你真想学习,请评论学过的每篇文章,记录学习的痕迹. 请把所有教程文章中所提及的代码,最少敲写三遍,达到熟悉的效果. 声明:本次教程主要适用于已经习得一门编程语言的程序员.想要学习第二门语言.有梦想,立志做全栈攻城狮的你 如果是小白,也可以学习本教程.不过可能有些困难.如有问题在文章下方进行讨论.或者添加QQ群538742639.群马上就满了,名额不多. 上节课主要讲解了以下内容: 为什么学习Pyth

使用Python玩转WMI

最近在网上搜索Python和WMI相关资料时,发现大部分文章都千篇一律,并且基本上只说了很基础的使用,并未深入说明如何使用WMI.本文打算更进一步,让我们使用Python玩转WMI. 1 什么是WMI 具体请看微软官网对WMI的介绍.这里简单说明下,WMI的全称是Windows Management Instrumentation,即Windows管理规范.它是Windows操作系统上管理数据和操作的基础设施.我们可以使用WMI脚本或者应用自动化管理任务等. 从Using WMI可以知道WMI支