
TensorFlow: how to save/restore a model?


Posted on 2018-10-13 11:59:04
Last edited by fantomas on 2018-10-13 16:15

After you train a model in TensorFlow:
  • How do you save the trained model?
  • How do you later restore this saved model?
Save the model:
    # TensorFlow 1.x API (tf.placeholder, tf.Session, tf.train.Saver).
    # On TensorFlow 2.x, use: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
    import tensorflow as tf

    # Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1 = tf.Variable(2.0, name="bias")
    feed_dict = {w1: 4, w2: 8}

    # Define a test operation that we will restore
    w3 = tf.add(w1, w2)
    w4 = tf.multiply(w3, b1, name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Create a saver object which will save all the variables
    saver = tf.train.Saver()

    # Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    # Prints 24.0, i.e. (w1 + w2) * b1 = (4 + 8) * 2

    # Now, save the graph and the variable values
    saver.save(sess, 'my_test_model', global_step=1000)
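For reference, `saver.save(sess, 'my_test_model', global_step=1000)` does not write a single file. Assuming TensorFlow 1.x defaults (V2 checkpoint format, one shard, `write_meta_graph=True`), it appends `-1000` to the prefix and writes the graph definition, an index, the variable data, and a small `checkpoint` state file in the same directory; that is why the restore code refers to `my_test_model-1000.meta`. The helper `checkpoint_files` below is hypothetical and only computes the expected names; it does not call TensorFlow.

```python
import os


def checkpoint_files(prefix, global_step=None):
    """Hypothetical helper: list the files a TF 1.x tf.train.Saver.save() call
    is expected to write for a single-shard V2 checkpoint (default settings)."""
    # save() appends "-<global_step>" to the prefix when global_step is given
    path = f"{prefix}-{global_step}" if global_step is not None else prefix
    return [
        path + ".meta",                 # graph structure (MetaGraphDef)
        path + ".index",                # index of the saved tensors
        path + ".data-00000-of-00001",  # variable values (one shard)
        # text file recording the latest checkpoint prefix in that directory
        os.path.join(os.path.dirname(path), "checkpoint"),
    ]


for name in checkpoint_files("my_test_model", global_step=1000):
    print(name)
```

With the prefix used in the post, this prints `my_test_model-1000.meta`, `my_test_model-1000.index`, `my_test_model-1000.data-00000-of-00001` and `checkpoint`.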

Restore the model:
    # TensorFlow 1.x API (see the note in the saving code above)
    import tensorflow as tf

    sess = tf.Session()
    # First, load the meta graph and restore the weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Access saved variables directly
    print(sess.run('bias:0'))
    # This will print 2.0, which is the value of bias that we saved

    # Now, access the placeholder tensors and
    # create a feed_dict to feed new data
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 17.0}

    # Now, access the op that you want to run
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

    print(sess.run(op_to_restore, feed_dict))
    # This will print 60.0, i.e. (13 + 17) * 2
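Two details in the restore code can be puzzling. `tf.train.latest_checkpoint('./')` reads the `checkpoint` state file in that directory and returns the most recent checkpoint prefix (here `./my_test_model-1000`), resolving a relative path against the directory. Names such as `'w1:0'` mean output 0 of the operation named `w1`; the index is needed because one operation can have several outputs. The sketch below mimics both in plain Python. `latest_checkpoint_prefix` and `split_tensor_name` are hypothetical helpers, and the real `latest_checkpoint` additionally checks that the checkpoint files exist.

```python
import os
import tempfile


def latest_checkpoint_prefix(checkpoint_dir):
    """Hypothetical, simplified stand-in for tf.train.latest_checkpoint():
    read the 'checkpoint' state file and return the newest checkpoint prefix,
    or None if there is no state file. (The real function also verifies that
    the checkpoint files exist.)"""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            key, _, value = line.partition(":")
            if key.strip() == "model_checkpoint_path":
                path = value.strip().strip('"')
                # Relative paths are resolved against the checkpoint directory
                if os.path.isabs(path):
                    return path
                return os.path.join(checkpoint_dir, path)
    return None


def split_tensor_name(name):
    """Split a tensor name such as 'op_to_restore:0' into (op name, output index)."""
    op_name, _, index = name.rpartition(":")
    return op_name, int(index)


# Example: a state file like the one saver.save() writes next to the checkpoint
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "my_test_model-1000"\n')
        f.write('all_model_checkpoint_paths: "my_test_model-1000"\n')
    # prints the temporary directory joined with "my_test_model-1000"
    print(latest_checkpoint_prefix(d))

print(split_tensor_name("op_to_restore:0"))  # ('op_to_restore', 0)
```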



This and some more advanced use-cases have been explained very well here.





Reply by xdtech, posted on 2018-11-13 20:38:58:
Useful, it helped a lot.