请选择 进入手机版 | 继续访问电脑版
查看: 24|回复: 1

tensorflow2.0中Layer的__init__(),build(), call()函数

[复制链接]

570

主题

1066

帖子

5827

积分

xdtech

Rank: 5Rank: 5

积分
5827
发表于 2020-6-23 12:34:01 | 显示全部楼层 |阅读模式
先看官方手册中使用了Layer中的这三个关键函数的一个简单的实例:

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_variable("kernel",
                                    shape=[int(input_shape[-1]),
                                           self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

layer = MyDenseLayer(10)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
从直观上理解,似乎__init__()和build()函数都在对Layer进行初始化,都初始化了一些成员函数,而call()函数则是在该layer被调用时执行。

显然,这三个函数都是从tf.keras.layers.Layer处继承而来的,那么不妨看一下官方对这几个函数作何解释。
下图为tf.keras.layers.Layer的官方文档

简单翻译,就是说官方推荐凡是tf.keras.layers.Layer的派生类都要实现__init__(),build(), call()这三个方法
__init__():保存成员变量的设置
build():在call()函数第一次执行时会被调用一次,这时候可以知道输入数据的shape。返回去看一看,果然是__init__()函数中只初始化了输出数据的shape,而输入数据的shape需要在build()函数中动态获取,这也解释了为什么在有__init__()函数时还需要使用build()函数
call(): call()函数就很简单了,即当其被调用时会被执行。


回复

使用道具 举报

570

主题

1066

帖子

5827

积分

xdtech

Rank: 5Rank: 5

积分
5827
 楼主| 发表于 2020-6-23 12:34:04 | 显示全部楼层
————————————————

版权声明:本文为CSDN博主「_吟游诗人」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:https://blog.csdn.net/qq_32623363/java/article/details/104128497
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表