from tensorflow.python.keras.preprocessing.image import load_img, img_to_array
from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np

# Directory and HDF5 file used to persist / restore the trained weights.
CKPT_DIR = "./ckpt"
WEIGHTS_PATH = "./ckpt/SingleNN.h5"


class SingleNN(object):
    """Single-hidden-layer fully connected classifier for Fashion-MNIST.

    The Keras model is a class attribute: it is built once, when the class
    body executes, and is shared by every instance. Callers (including the
    ``__main__`` driver below) access it as ``SingleNN.model``.
    """

    # Build the network: 28x28 input -> flatten -> 128 ReLU units -> 10-way softmax.
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(10, activation=tf.nn.softmax),
    ])

    def __init__(self):
        """Load the Fashion-MNIST train/test splits and normalize the images."""
        (self.x_train, self.y_train), (self.x_test, self.y_test) = \
            keras.datasets.fashion_mnist.load_data()
        # Normalize pixel values by dividing by 255.0.
        self.x_train = self.x_train / 255.0
        self.x_test = self.x_test / 255.0

    def singlenn_compile(self):
        """Compile the shared model.

        Uses SGD (learning rate 0.01), sparse categorical cross-entropy loss
        and an accuracy metric.
        """
        SingleNN.model.compile(
            # ``learning_rate`` is the supported keyword; ``lr`` is deprecated.
            optimizer=keras.optimizers.SGD(learning_rate=0.01),
            loss=keras.losses.sparse_categorical_crossentropy,
            metrics=['accuracy'],
        )

    def singlenn_fit(self):
        """Train the shared model on the training split for 5 epochs.

        Training is logged to TensorBoard under ``./graph`` (graph included).
        A ``keras.callbacks.ModelCheckpoint`` could be appended to
        ``callbacks`` to save weights after each epoch.
        """
        board = keras.callbacks.TensorBoard(log_dir="./graph", write_graph=True)
        SingleNN.model.fit(self.x_train, self.y_train, epochs=5, callbacks=[board])

    # NOTE: the name is misspelled ("evalute"); kept as-is for existing callers.
    def single_evalute(self):
        """Evaluate the shared model on the test split and print loss and accuracy."""
        test_loss, test_acc = SingleNN.model.evaluate(self.x_test, self.y_test)
        print(test_loss, test_acc)

    def single_predict(self):
        """Predict class probabilities for the test split.

        If the weights file ``./ckpt/SingleNN.h5`` exists, it is loaded into
        the shared model before predicting.

        :return: softmax outputs of the model, one row of 10 class
            probabilities per test image.
        """
        if os.path.exists(WEIGHTS_PATH):
            SingleNN.model.load_weights(WEIGHTS_PATH)
        return SingleNN.model.predict(self.x_test)


if __name__ == '__main__':
    snn = SingleNN()
    snn.singlenn_compile()
    snn.singlenn_fit()
    snn.single_evalute()
    # Writing an .h5 weights file fails if its parent directory is missing,
    # so make sure ./ckpt exists before saving.
    os.makedirs(CKPT_DIR, exist_ok=True)
    SingleNN.model.save_weights(WEIGHTS_PATH)
    # To inspect predictions after training:
    #   predictions = snn.single_predict()
    #   print(np.argmax(predictions, axis=1))
# Original article: https://www.cnblogs.com/LiuXinyu12378/p/12250739.html
# Retrieved: 2024-10-09 13:18:10