您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

如何在Tensorflow中使用自定义python函数预取数据

如何在Tensorflow中使用自定义python函数预取数据

这是一个常见的用例,大多数实现都使用TensorFlow的 队列 将预处理代码与训练代码分离。有一个有关如何使用队列的教程,但是主要步骤如下:

定义一个队列,q它将缓冲预处理的数据。TensorFlow支持tf.FIFOQueue排队顺序生成元素的简单方法,以及tf.RandomShuffleQueue随机顺序生成元素的高级方法。队列元素是一个或多个张量(可以具有不同的类型和形状)的元组。所有队列都支持单元素(enqueuedequeue)和批处理(enqueue_manydequeue_many)操作,但是要使用批处理操作,必须在构造队列时在队列元素中指定每个张量的形状。

构建一个子图,将预处理的元素排入队列。一种方法tf.placeholder()为张量定义一些与单个输入示例对应的操作,然后将它们传递给q.enqueue()。(如果预处理一次生成一批,则应q.enqueue_many()改用。)您也可以在此子图中包括TensorFlow op。

建立执行训练的子图。这看起来像一个普通的TensorFlow图,但是会通过调用获取其输入q.dequeue_many(BATCH_SIZE)

开始会话。

创建一个或多个执行预处理逻辑的线程,然后执行入队操作,并输入预处理后的数据。您可能会发现tf.train.Coordinatortf.train.QueueRunner实用程序类对此有用。

正常运行训练图(优化器等)。

这是一个简单的load_and_enqueue()功能代码片段,可以帮助您入门:

    # Features are length-100 vectors of floats
    feature_input = tf.placeholder(tf.float32, shape=[100])
    # Labels are scalar integers.
    label_input = tf.placeholder(tf.int32, shape=[])

    # Alternatively, Could do:
    # feature_batch_input = tf.placeholder(tf.float32, shape=[None, 100])
    # label_batch_input = tf.placeholder(tf.int32, shape=[None])

    q = tf.FIFOQueue(100, [tf.float32, tf.int32], shapes=[[100], []])
    enqueue_op = q.enqueue([feature_input, label_input])

    # For batch input, do:
    # enqueue_op = q.enqueue_many([feature_batch_input, label_batch_input])

    feature_batch, label_batch = q.dequeue_many(BATCH_SIZE)
    # Build rest of model taking label_batch, feature_batch as input.
    # [...]
    train_op = ...

    sess = tf.Session()

    def load_and_enqueue():
      with open(...) as feature_file, open(...) as label_file:
        while True:
          feature_array = numpy.fromfile(feature_file, numpy.float32, 100)
          if not feature_array:
            return
          label_value = numpy.fromfile(feature_file, numpy.int32, 1)[0]

          sess.run(enqueue_op, Feed_dict={feature_input: feature_array,
                                          label_input: label_value})

    # Start a thread to enqueue data asynchronously, and hide I/O latency.
    t = threading.Thread(target=load_and_enqueue)
    t.start()

    for _ in range(TRAINING_EPOCHS):
      sess.run(train_op)
python 2022/1/1 18:42:31 有509人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶