您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

如何将Keras .h5导出到tensorflow .pb?

如何将Keras .h5导出到tensorflow .pb?

Keras本身不包括将TensorFlow图导出为协议缓冲区文件的任何方法,但是您可以使用常规TensorFlow实用程序来实现。是一篇博客文章,解释了如何使用freeze_graph.pyTensorFlow中包含的实用程序脚本执行此操作,这是完成操作的“典型”方式。

但是,我个人觉得必须创建一个检查点,然后运行一个外部脚本来获取模型,但我更喜欢从我自己的Python代码中执行此操作,因此我使用了这样的函数

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph deFinition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph

这是实施的启发freeze_graph.py。参数也类似于脚本。session是TensorFlow会话对象。keep_var_names仅在您希望不冻结某些变量时才需要(例如,对于有状态模型),通常不需要。output_names是包含产生所需输出的操作名称的列表。clear_devices只需删除任何设备指令即可使图形更具可移植性。因此,对于model具有一个输出的典型Keras ,您将执行以下操作:

from keras import backend as K

# Create, compile and train model...

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])

然后,您可以像往常一样使用tf.train.write_graph以下命令将图形写入文件

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)
其他 2022/1/1 18:43:01 有490人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶