import tensorflow as tf
from tensorflow import keras

# MNIST split sizes -- train: 60k | test: 10k
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
x.shape
y.shape

# Pixel intensity: 0 is pure black, 255 is pure white.
x.min(), x.max(), x.mean()

x_test.shape, y_test.shape

# There are 10 possible classes (digits 0-9), so one-hot encode with depth 10
# and peek at the first two encoded labels.
y_onehot = tf.one_hot(y, depth=10)
y_onehot[:2]
# CIFAR-10 split sizes -- train: 50k | test: 10k
# Note: this rebinds x, y, x_test, y_test, replacing the MNIST arrays above.
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Inspect the array shapes of both splits.
x.shape, y.shape, x_test.shape, y_test.shape

# Inspect the pixel value range.
x.min(), x.max()
# A Dataset built from a single array yields one slice of x_test per element.
db = tf.data.Dataset.from_tensor_slices(x_test)
next(iter(db)).shape

# Built from a tuple, each element is an (image, label) pair; [0] is the image.
db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
next(iter(db))[0].shape

# Shuffle the data: a 10000-element buffer matches the 10k test split.
db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(10000)
# Data preprocessing
def preprocess(x, y):
    """Turn one (image, label) pair into model-ready tensors.

    Args:
        x: image tensor; cast to float32 and divided by 255 (presumably raw
            0-255 pixel values, giving a [0, 1] range).
        y: integer class id(s); cast to int32, then one-hot encoded over
            10 classes.

    Returns:
        ``(image, label)``: the scaled float32 image and the one-hot label.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.
    # NOTE(review): Keras documents CIFAR-10 labels as shape (N, 1), so each
    # y here is presumably shape (1,) and its one-hot label (1, 10) -- confirm,
    # and squeeze y first if a flat (10,) label is needed for training.
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return image, label


db2 = db.map(preprocess)
res = next(iter(db2))
res[0].shape, res[1].shape
# Get several images at once: group 32 preprocessed (image, label) pairs per batch.
db3 = db2.batch(32)
res = next(iter(db3))
res[0].shape, res[1].shape

# Pulling past the end of a finite dataset raises StopIteration. Catch it so
# that exhausting one pass is demonstrated without aborting the rest of the
# script (the original unbounded loop let the exception escape).
db_iter = iter(db3)
try:
    while True:
        next(db_iter)
except StopIteration:
    # Expected: every batch of db3 has been consumed.
    pass

# repeat()
# With no argument, iteration never terminates.
db4 = db3.repeat()
# With repeat(2), iteration passes over the data twice and then stops.
db3 = db3.repeat(2)
def prepare_mnist_features_and_labels(x, y):
    """Scale an image to float32 in [0, 1] and cast its label to int64.

    Args:
        x: image tensor (presumably integer pixels in 0-255).
        y: label tensor; in ``mnist_dataset`` this is a one-hot float vector.

    Returns:
        ``(x, y)`` with ``x`` as float32 divided by 255 and ``y`` as int64.
    """
    x = tf.cast(x, tf.float32) / 255.
    y = tf.cast(y, tf.int64)
    return x, y


def mnist_dataset(batch_size=100):
    """Build shuffled, batched train and validation pipelines.

    Despite its name, this loads **Fashion-MNIST**
    (``keras.datasets.fashion_mnist``), not the handwritten-digit MNIST set.

    Args:
        batch_size: number of examples per batch for both pipelines
            (defaults to 100, the previously hard-coded value).

    Returns:
        ``(ds, ds_val)``: training and validation ``tf.data.Dataset`` objects
        yielding ``(image, label)`` batches, with images scaled by
        ``prepare_mnist_features_and_labels`` and labels one-hot encoded over
        10 classes and cast to int64.
    """
    # Fixed: was the undefined name `datasets`; keras is imported at file top.
    (x, y), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()
    y = tf.one_hot(y, depth=10)
    y_val = tf.one_hot(y_val, depth=10)

    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(prepare_mnist_features_and_labels)
    # Fixed typo `shffle`. Buffer = full split size, so the shuffle is uniform.
    ds = ds.shuffle(len(x)).batch(batch_size)

    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(len(x_val)).batch(batch_size)
    return ds, ds_val
原文地址:https://www.cnblogs.com/tszr/p/12141969.html
时间: 2024-10-09 19:33:20