文章:
http://python.jobbole.com/81215/
Python 的函数库好强大！看完这篇博客再也不会用 MATLAB 了~~
这篇文章使用【pandas】读取 CSV 数据，使用【sklearn】中的 linear_model 训练模型并进行线性预测，使用【matplotlib】将拟合的情况用图表示出来。
下面是用于训练模型的数据表（包含 square_feet 面积和 price 价格两列）：
代码如下:
# -*- coding: utf-8 -*-
"""
Created on 2016/11/26
@author: chensi

Fit a one-variable linear regression of ``price`` on ``square_feet``,
plot the fitted line, and predict the price for a given floor area.
"""
# Required Packages
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
from numpy.ma.core import getdata  # NOTE(review): unused in this script


# Function to get data
def get_data(file_name):
    """Read the ``square_feet`` and ``price`` columns from *file_name*.

    Files whose name ends in ``.csv`` (case-insensitive) are read with
    ``pd.read_csv``; anything else is read with ``pd.read_excel``.

    Returns:
        (X_parameter, Y_parameter): ``X_parameter`` is a list of one-element
        lists ``[[square_feet], ...]`` -- the 2-D "rows of features" shape
        scikit-learn expects; ``Y_parameter`` is a flat list of prices.
        All values are converted to ``float``.
    """
    # read_excel cannot parse a plain-text CSV file, and the driver below
    # passes a .csv path, so pick the reader from the extension.
    if str(file_name).lower().endswith(".csv"):
        data = pd.read_csv(file_name)
    else:
        data = pd.read_excel(file_name)
    X_parameter = []
    Y_parameter = []
    for single_square_feet, single_price_value in zip(data["square_feet"], data["price"]):
        X_parameter.append([float(single_square_feet)])
        Y_parameter.append(float(single_price_value))
    return X_parameter, Y_parameter


# Function for Fitting our data to Linear model
def linear_model_main(X_parameters, Y_parameters, predict_value):
    """Fit a linear regression and predict the target at *predict_value*.

    Args:
        X_parameters: feature rows, e.g. ``[[x1], [x2], ...]``.
        Y_parameters: one target value per row of ``X_parameters``.
        predict_value: a single number, a 1-D sequence of numbers (each a
            single-feature sample), or an already 2-D array of samples.

    Returns:
        dict with keys ``'intercept'``, ``'coefficient'`` and
        ``'predicted_value'`` (an array with one prediction per sample).
    """
    # Create linear regression object
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)
    # LinearRegression.predict requires a 2-D array; a bare scalar such as
    # 150 raises ValueError in current scikit-learn. Promote scalars and
    # 1-D inputs to a single-feature column, one row per value.
    query = np.atleast_1d(np.asarray(predict_value, dtype=float))
    if query.ndim == 1:
        query = query.reshape(-1, 1)
    predict_outcome = regr.predict(query)
    predictions = {}
    predictions["intercept"] = regr.intercept_
    predictions["coefficient"] = regr.coef_
    predictions["predicted_value"] = predict_outcome
    return predictions


# Function to show the results of linear fit model
def show_linear_line(X_parameters, Y_parameters):
    """Fit a linear regression and plot the data with the fitted line.

    The data points are drawn as a blue scatter and the model's predictions
    at the same X values as a red line; axis ticks are hidden.
    """
    # Create linear regression object
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)
    plt.scatter(X_parameters, Y_parameters, color="blue")
    plt.plot(X_parameters, regr.predict(X_parameters), color="red", linewidth=4)
    # Empty tuples remove all tick marks/labels for a cleaner figure.
    plt.xticks(())
    plt.yticks(())
    plt.show()


if __name__ == "__main__":
    #---------Test---------------
    x, y = get_data("g:/input_data.csv")
    show_linear_line(x, y)
    print(linear_model_main(x, y, 150))
    #----------------------------
输出的图:
例子二：分别用《闪电侠》(The Flash) 和《绿箭侠》(Arrow) 各集的美国观众人数拟合线性模型，预测下一集哪部剧的观众更多。
代码:
# -*- coding: utf-8 -*-
"""
Created on 2016/11/26
@author: chensi

Fit one linear regression per TV show (US viewers vs. episode number) and
report which show is predicted to have more viewers for the next episode.
"""
# Required Packages
import csv
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model


# Function to get data
def get_data(file_name):
    """Read per-episode viewer counts for The Flash and Arrow.

    Reads the columns ``flash_episode_number``, ``flash_us_viewers``,
    ``arrow_episode_number`` and ``arrow_us_viewers``. Files whose name ends
    in ``.csv`` (case-insensitive) are read with ``pd.read_csv``; anything
    else with ``pd.read_excel``.

    Returns:
        (flash_x, flash_y, arrow_x, arrow_y): the x lists hold one-element
        lists ``[[episode], ...]`` (the 2-D shape scikit-learn expects),
        the y lists are flat. All values are ``float``.
    """
    if str(file_name).lower().endswith(".csv"):
        data = pd.read_csv(file_name)
    else:
        data = pd.read_excel(file_name)
    flash_x_parameter = []
    flash_y_parameter = []
    arrow_x_parameter = []
    arrow_y_parameter = []
    for x1, y1, x2, y2 in zip(data["flash_episode_number"], data["flash_us_viewers"],
                              data["arrow_episode_number"], data["arrow_us_viewers"]):
        flash_x_parameter.append([float(x1)])
        flash_y_parameter.append(float(y1))
        arrow_x_parameter.append([float(x2)])
        arrow_y_parameter.append(float(y2))
    return flash_x_parameter, flash_y_parameter, arrow_x_parameter, arrow_y_parameter


# Function to know which Tv show will have more viewers
def more_viewers(x1, y1, x2, y2, next_episode=9):
    """Print which show is predicted to have more viewers at *next_episode*.

    Args:
        x1, y1: episode numbers (as ``[[n], ...]``) and viewers for The Flash.
        x2, y2: the same for Arrow.
        next_episode: episode number to predict for. The default 9
            presumably is the episode after the last one in the data --
            confirm against the input file.

    On a tie the Arrow message is printed (the comparison is strict ``>``).
    """
    # predict() needs a 2-D array: one sample, one feature. A bare scalar
    # raises ValueError in current scikit-learn.
    query = [[float(next_episode)]]

    regr1 = linear_model.LinearRegression()
    regr1.fit(x1, y1)
    predicted_value1 = regr1.predict(query)
    print(predicted_value1)

    regr2 = linear_model.LinearRegression()
    regr2.fit(x2, y2)
    predicted_value2 = regr2.predict(query)

    # Each prediction is a length-1 array; compare the scalar values.
    if predicted_value1[0] > predicted_value2[0]:
        print("The Flash Tv Show will have more viewers for next week")
    else:
        print("Arrow Tv Show will have more viewers for next week")


if __name__ == "__main__":
    x1, y1, x2, y2 = get_data("G:/input_data_2.xlsx")
    more_viewers(x1, y1, x2, y2)
输出:
时间: 2024-10-29 19:10:14