tf.data Examples

Training a model with tf.data as the input pipeline


Data processing with tf.data

One-dimensional data

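import tensorflow as tf

# Build a dataset from a 1-D list; each element is a scalar tensor
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8])
print(dataset)
for ele in dataset:
    print(ele)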
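
When only the values matter, `as_numpy_iterator()` may be more convenient than printing `tf.Tensor` objects. The following is a minimal sketch of that idea; the variable names `values` and `dtype` are chosen here for illustration.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8])

# as_numpy_iterator() yields NumPy values instead of tf.Tensor objects
values = [int(v) for v in dataset.as_numpy_iterator()]

# element_spec records the dtype (and shape) of every element
dtype = dataset.element_spec.dtype

print(values)
print(dtype)
```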

Two-dimensional data

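import tensorflow as tf

# Build a dataset from a 2-D list; each element is one row of shape (2,)
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6], [7, 8]])
print(dataset)
for ele in dataset:
    print(ele.numpy())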
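
`from_tensor_slices` slices its input along the first axis, so a 4x2 input gives 4 elements of shape `(2,)`. The sketch below checks this with `element_spec` and `cardinality()`; the names `data`, `element_shape`, `num_elements` and `rows` are illustrative only.

```python
import tensorflow as tf

data = [[1, 2], [3, 4], [5, 6], [7, 8]]
dataset = tf.data.Dataset.from_tensor_slices(data)

# The first axis (4 rows) is sliced away; each element keeps the remaining shape
element_shape = dataset.element_spec.shape.as_list()

# cardinality() reports how many elements the dataset holds
num_elements = int(dataset.cardinality().numpy())

rows = [row.tolist() for row in dataset.as_numpy_iterator()]
print(element_shape, num_elements, rows)
```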

Dictionaries

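import tensorflow as tf

# Build a dataset from a dict; each element is a dict with the keys 'a' and 'b'
dataset = tf.data.Dataset.from_tensor_slices({
    'a': [1, 2, 3, 4],
    'b': [3, 4, 1, 2]
})
print(dataset)
for ele in dataset:
    print(ele.get("a"))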
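
Because every element is itself a dict, a function passed to `map` receives that dict and can combine its fields. A minimal sketch, assuming we want the sum of the two fields (the names `summed` and `totals` are illustrative):

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices({
    'a': [1, 2, 3, 4],
    'b': [3, 4, 1, 2],
})

# Each element is a dict of scalar tensors, so the lambda indexes it by key
summed = dataset.map(lambda ele: ele['a'] + ele['b'])

totals = [int(v) for v in summed.as_numpy_iterator()]
print(totals)
```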

Building a dataset from a NumPy array

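import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5]))
dataset = dataset.shuffle(5)      # shuffle the data with a buffer of 5 elements
dataset = dataset.repeat(5)       # repeat the data 5 times (25 elements in total)
dataset = dataset.batch(5)        # group the elements into batches of 5
dataset = dataset.map(tf.square)  # square every value in each batch
# Print each batch as a NumPy array
for ele in dataset:
    print(ele.numpy())
# Print each batch as a tf.Tensor
for ele in dataset:
    print(ele)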
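
The effect of chaining `shuffle`, `repeat`, `batch` and `map` can be checked by collecting the batches. In the sketch below, 5 elements repeated 5 times give 25 elements, which `batch(5)` groups into 5 batches. Because `shuffle` comes before `repeat`, each pass over the data is shuffled on its own, so every batch here should contain the squares of 1 to 5 in some order. The names `batches`, `num_batches` and `total` are illustrative.

```python
import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5]))
dataset = dataset.shuffle(5).repeat(5).batch(5).map(tf.square)

# Collect every batch as a NumPy array
batches = list(dataset.as_numpy_iterator())

num_batches = len(batches)
batch_shapes = {b.shape for b in batches}

# The order inside a batch is random, but the sum of all squares is fixed
total = int(sum(b.sum() for b in batches))
print(num_batches, batch_shapes, total)
```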

Training dataset

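The data is preprocessed with NumPy, and tf.data then handles shuffling and batching, so the batch size is defined on the dataset rather than passed to model.fit.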

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Scale the pixel values to the range [0, 1]
train_images = train_images / 255.0
test_images = test_images / 255.0

# Build the training dataset; zip() pairs each image with its label
ds_train_images = tf.data.Dataset.from_tensor_slices(train_images)
ds_train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
ds_train = tf.data.Dataset.zip((ds_train_images, ds_train_labels))
# repeat() without an argument repeats indefinitely, so fit() needs steps_per_epoch
ds_train = ds_train.shuffle(10000).repeat().batch(32)

# Build the test dataset
ds_test_images = tf.data.Dataset.from_tensor_slices(test_images)
ds_test_labels = tf.data.Dataset.from_tensor_slices(test_labels)
ds_test = tf.data.Dataset.zip((ds_test_images, ds_test_labels))
ds_test = ds_test.batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# Number of batches of 32 in one pass over each dataset
steps_per_epoch = train_images.shape[0] // 32
validation_steps = test_images.shape[0] // 32
history = model.fit(ds_train, epochs=5, steps_per_epoch=steps_per_epoch,
                    validation_data=ds_test, validation_steps=validation_steps)

# Plot the training accuracy
plt.plot(history.history['accuracy'])
plt.show()
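
As an alternative to building two datasets and zipping them, `from_tensor_slices` also accepts a tuple of arrays and pairs them element by element. The sketch below compares the two approaches on small synthetic arrays, which are stand-ins invented for this example and not the real MNIST data.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-ins for images and labels (hypothetical data, not MNIST)
images = np.zeros((6, 28, 28), dtype=np.float32)
labels = np.arange(6, dtype=np.int64)

# Approach used above: one dataset per array, combined with zip()
zipped = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices(images),
    tf.data.Dataset.from_tensor_slices(labels),
))

# Alternative: pass both arrays as a tuple in a single call
paired = tf.data.Dataset.from_tensor_slices((images, labels))

# Both datasets yield (image, label) pairs with the same shapes and dtypes
same_spec = zipped.element_spec == paired.element_spec
paired_labels = [int(label) for _, label in paired.as_numpy_iterator()]
print(same_spec, paired_labels)
```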

