Last edited by fantomas on 2018-10-13 16:15
After you train a model in TensorFlow:
- How do you save the trained model?
- How do you later restore this saved model?
Save the model:

```python
# TensorFlow 1.x API. Under TensorFlow 2.x, use instead:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()
import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, i.e. (w1 + w2) * b1 = (4 + 8) * 2

# Now, save the graph and the variable values
saver.save(sess, 'my_test_model', global_step=1000)
```
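The restore step in this answer loads `my_test_model-1000.meta`; that name comes from the `global_step=1000` argument, because `Saver.save` appends `-<global_step>` to the save path. The sketch below lists the files you should expect on disk after the save, assuming the default TF1 V2 checkpoint format with a single data shard. `expected_checkpoint_files` is a hypothetical helper for illustration, not part of TensorFlow.

```python
import os


def expected_checkpoint_files(save_path, global_step=None):
    """Hypothetical helper (not a TensorFlow API): list the files that
    tf.train.Saver.save is expected to write, assuming the default V2
    checkpoint format with a single data shard and write_meta_graph=True."""
    # With global_step, Saver.save appends "-<global_step>" to the prefix.
    prefix = save_path if global_step is None else f"{save_path}-{global_step}"
    return [
        prefix + ".meta",                 # serialized MetaGraphDef (the graph)
        prefix + ".index",                # index of the saved variables
        prefix + ".data-00000-of-00001",  # variable values, e.g. bias = 2.0
        # Text file recording the latest checkpoint prefix, written next to it
        os.path.join(os.path.dirname(save_path), "checkpoint"),
    ]


for name in expected_checkpoint_files("my_test_model", global_step=1000):
    print(name)
```

The `.meta` file alone only describes the graph; the saved variable values live in the `.index`/`.data` pair, so keep all of these files together when moving a model.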
Restore the model:

```python
# TensorFlow 1.x API. Under TensorFlow 2.x, use instead:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()
import tensorflow as tf

sess = tf.Session()
# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved variables directly
print(sess.run('bias:0'))
# This prints 2.0, the value of bias that we saved

# Now, access the placeholders and create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore, feed_dict))
# This prints 60.0, i.e. (13 + 17) * 2
```
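`tf.train.latest_checkpoint('./')` in the restore code works because `Saver.save` also writes a small text file named `checkpoint` that records the most recent checkpoint prefix. Below is a simplified pure-Python sketch of that lookup; `latest_checkpoint_prefix` is a hypothetical helper, and unlike the real TensorFlow function it does not check that the checkpoint files it points to actually exist.

```python
import os
import re
import tempfile


def latest_checkpoint_prefix(checkpoint_dir):
    """Hypothetical, simplified stand-in for tf.train.latest_checkpoint:
    read the `checkpoint` state file and return the latest prefix, or None."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # The state file is a text-format proto, e.g.
            #   model_checkpoint_path: "my_test_model-1000"
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                # Relative prefixes are resolved against checkpoint_dir;
                # os.path.join leaves an absolute prefix unchanged.
                return os.path.join(checkpoint_dir, match.group(1))
    return None


# Usage: fake the state file that saver.save(..., global_step=1000) writes.
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "my_test_model-1000"\n')
        f.write('all_model_checkpoint_paths: "my_test_model-1000"\n')
    prefix = latest_checkpoint_prefix(d)
    print(prefix)            # the temp directory joined with "my_test_model-1000"
    print(prefix + ".meta")  # the file that import_meta_graph loads
```

Because the returned value is a prefix rather than a file name, `saver.restore` can locate the `.index` and `.data` files from it, while `import_meta_graph` needs the explicit `.meta` suffix.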
Saving and restoring, along with some more advanced use cases, are explained very well here.