如何在Tensorflow中仅使用Python自定义激活函数?
本帖最后由 fantomas 于 2018-10-13 16:37 编辑https://stackoverflow.com/questi ... ython-in-tensorflow
想要使用的激活函数为例:
构建一个函数
def spiky(x):
r = x % 1
if r <= 0.5:
return r
else:
return 0向量化
import numpy as np
np_spiky = np.vectorize(spiky)激活的梯度:在我们的例子中它很简单,如果x mod 1 <0.5则为1,否则为0。所以:
def d_spiky(x):
r = x % 1
if r <= 0.5:
return 1
else:
return 0
np_d_spiky = np.vectorize(d_spiky)使一个numpy fct成为tensorflow fct:我们首先将np_d_spiky变成tensorflow函数。 tensorflow中有一个函数tf.py_func(func,inp,Tout,stateful = stateful,name = name)将任何numpy函数转换为tensorflow函数,因此我们可以使用它:
import tensorflow as tf
from tensorflow.python.framework import ops
np_d_spiky_32 = lambda x: np_d_spiky(x).astype(np.float32)
def tf_d_spiky(x,name=None):
with tf.name_scope(name, "d_spiky", ) as name:
y = tf.py_func(np_d_spiky_32,
,
,
name=name,
stateful=False)
return ydef py_func(func, inp, Tout, stateful=True, name=None, grad=None):
# Need to generate a unique name to avoid duplicates:
rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
tf.RegisterGradient(rnd_name)(grad)# see _MySquareGrad for grad example
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": rnd_name}):
return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
现在我们差不多完成了,唯一的事情是我们需要传递给上面的py_func函数的grad函数需要采用一种特殊的形式。它需要在操作之前接受操作和先前的梯度,并在操作之后向后传播梯度。def spikygrad(op, grad):
x = op.inputs
n_gr = tf_d_spiky(x)
return grad * n_gr激活函数只有一个输入,这就是x = op.inputs 的原因。如果操作有很多输入,我们需要返回一个元组,每个输入一个梯度。例如,如果操作是a-b相对于a的梯度是+1并且相对于b是-1,那么我们将返回+ 1 * grad,-1 * grad。请注意,我们需要返回输入的tensorflow函数,这就是为什么需要tf_d_spiky,np_d_spiky不能工作,因为它不能作用于tensorflow。或者我们可以使用tensorflow函数编写导数:
def spikygrad2(op, grad):
x = op.inputs
r = tf.mod(x,1)
n_gr = tf.to_float(tf.less_equal(r, 0.5))
return grad * n_gr将它们结合在一起:
np_spiky_32 = lambda x: np_spiky(x).astype(np.float32)
def tf_spiky(x, name=None):
with tf.name_scope(name, "spiky", ) as name:
y = py_func(np_spiky_32,
,
,
name=name,
grad=spikygrad)# <-- here's the call to the gradient
return yTest:with tf.Session() as sess:
x = tf.constant()
y = tf_spiky(x)
tf.initialize_all_variables().run()
print(x.eval(), y.eval(), tf.gradients(y, ).eval())
页:
[1]