tf.data.Dataset.list_files
创建一个张量称为MatchingFiles:0
(如果适用,带有适当的前缀)。
您可以评估
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
当然,这仅在简单情况下有效,特别是如果每??个图像只有一个样本(或已知数量的样本)。
在更复杂的情况下,例如,当您不知道每个文件中的样本数时,您只能在时期结束时观察样本数。
为此,您可以观看计数的纪元数Dataset
。repeat()
创建一个名为的成员_count
,该成员对纪元数进行计数。通过在迭代过程中观察它,您可以发现它何时发生变化,并从那里计算出数据集的大小。
该计数器可能埋在Dataset
s的层次结构中,该层次结构是在连续调用成员函数时创建的,因此我们必须像这样挖掘它。
d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround
RepeatDataset = type(tf.data.Dataset().repeat())
try:
while not isinstance(d, RepeatDataset):
d = d._input_dataset
except AttributeError:
warnings.warn('no epoch counter found')
epoch_counter = None
else:
epoch_counter = d._count
请注意,使用此技术时,数据集大小的计算并不准确,因为在此期间epoch_counter
递增的批次通常会将来自两个连续时期的样本混合在一起。因此,此计算精确到您的批生产长度。