您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

tf.data.Dataset:如何获取数据集大小(历元中的元素数)?

tf.data.Dataset:如何获取数据集大小(历元中的元素数)?

tf.data.Dataset.list_files创建一个张量称为MatchingFiles:0(如果适用,带有适当的前缀)。

您可以评估

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

获取文件数。

当然,这仅在简单情况下有效,特别是如果每??个图像只有一个样本(或已知数量的样本)。

在更复杂的情况下,例如,当您不知道每个文件中的样本数时,您只能在时期结束时观察样本数。

为此,您可以观看计数的纪元数Datasetrepeat()创建一个名为的成员_count,该成员对纪元数进行计数。通过在迭代过程中观察它,您可以发现它何时发生变化,并从那里计算出数据集的大小。

该计数器可能埋在Datasets的层次结构中,该层次结构是在连续调用成员函数时创建的,因此我们必须像这样挖掘它。

d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround 
RepeatDataset = type(tf.data.Dataset().repeat())
try:
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count

请注意,使用此技术时,数据集大小的计算并不准确,因为在此期间epoch_counter递增的批次通常会将来自两个连续时期的样本混合在一起。因此,此计算精确到您的批生产长度。

其他 2022/1/1 18:41:32 有494人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶