引用库文件
from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np import pandas as pd import tensorflow as tf from tensorflow import feature_column from tensorflow.keras import layers from sklearn.model_selection import train_test_split
加载数据集,生成数据帧资源句柄
# 将heart.csv数据集下载并加载到数据帧中 path_data = "E:/pre_data/heart.csv" dataframe = pd.read_csv(path_data)
将pandas dataframe 数据格式转变为 tf.data 格式的数据集形式
# 拷贝数据帧,id(dataframe)!=id(dataframe_new) dataframe_new = dataframe.copy() # 从dataframe_new数据中获取target属性 labels = dataframe_new.pop(‘target‘) # 要构建Dataset内存中的数据 ds = tf.data.Dataset.from_tensor_slices((dict(dataframe_new), labels)) # 将数据打乱的混乱程度 ds = ds.shuffle(buffer_size=len(dataframe_new)) # 从数据集中取出数据集的个数 ds = ds.batch(100) # 指定数据集重复的次数 ds = ds.repeat(2)
ds 中有shuffle、batch、repeat三个方法;具体区别如下
shuffle:
tensorflow
中的数据集类Dataset
有一个shuffle
方法,用来打乱数据集中数据顺序,训练时非常常用。其中shuffle
方法有一个参数buffer_size
,非常令人费解,文档的解释如下:
原文地址:https://www.cnblogs.com/gengyi/p/11107492.html
时间: 2024-10-19 15:56:15