简单介绍TensorFlow中关于tf.app.flags命令行参数解析模块

Crq
Crq
管理员
1422
文章
0
粉丝
Linux教程评论3字数 622阅读2分4秒阅读模式
摘要这篇文章主要介绍了TensorFlow中关于tf.app.flags命令行参数解析模块,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
tf.app.flags命令行参数解析模块

说道命令行参数解析,就不得不提到 python 的 argparse 模块,详情可参考我之前的一篇文章:python argparse 模块命令行参数用法及说明。

在阅读相关工程的源码时,很容易发现 tf.app.flags 模块的身影。其作用与 python 的 argparse 类似。

直接上代码实例,新建一个名为 test_flags.py 的文件,内容如下:

#coding:utf-8
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/feige", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("train_batch_size", 128, "batch size of train data")
tf.app.flags.DEFINE_integer("test_batch_size", 64, "batch size of test data")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")
def main(unused_argv):
    train_data_path = FLAGS.train_data_path
    print("train_data_path", train_data_path)
    train_batch_size = FLAGS.train_batch_size
    print("train_batch_size", train_batch_size)
    test_batch_size = FLAGS.test_batch_size
    print("test_batch_size", test_batch_size)
    size_sum = tf.add(train_batch_size, test_batch_size)
    with tf.Session() as sess:
        sum_result = sess.run(size_sum)
        print("sum_result", sum_result)
# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
    tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)

上述代码已给出较为详细的注释,在此不再赘述。

该文件的调用示例以及运行结果如下所示

简单介绍TensorFlow中关于tf.app.flags命令行参数解析模块

如果需要修改默认参数的值,则在命令行传入自定义参数值即可,若全部使用默认参数值,则可直接在命令行运行该 python 文件。

读者可能会对 tf.app.run() 有些疑问,在上述注释中也有所解释,但要真正弄清楚其运行原理

还需查阅其源代码
def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS
  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None
  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access
  main = main or sys.modules['__main__'].main
  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

flags_passthrough=f._parse_flags(args=args)这里的_parse_flags就是我们tf.app.flags源码中用来解析命令行参数的函数。

所以这一行就是解析参数的功能;

下面两行代码也就是 tf.app.run 的核心意思:执行程序中 main 函数,并解析命令行参数!

weinxin
我的微信
微信号已复制
我的微信
这是我的微信扫一扫
 
Crq
  • 本文由 Crq 发表于2025年1月16日 18:08:15
  • 转载请注明:https://www.cncrq.com/12633.html
linux下生成高强度密码的四大神器 Linux教程

linux下生成高强度密码的四大神器

安全是一个大的话题,给服务器设置一个高强度的密码是非常重要的。你可能会疑惑一个高强度的密码究竟是什么样的呢?怎么才能生成一个那样的密码呢?不用担心下面我们将介绍 4 种简单方法让你...
mysql中null与“空值”的坑 Linux教程

mysql中null与“空值”的坑

数据库在企业环境中是非常常用的,不仅仅是DBA,运维人员和开发人员都要熟悉数据库的使用,增删改查等操作。而对于使用数据库的人员来说,对于字段、属性的熟悉是相当重要的。今天就给大家分...
匿名

发表评论

匿名网友
:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:
确定

拖动滑块以完成验证