查看: 1437|回复: 1

PyTorch模型

[复制链接]

9

主题

47

帖子

123

积分

注册会员

Rank: 2

积分
123
发表于 2018-10-15 11:46:51 | 显示全部楼层 |阅读模式
有没有办法,可以像model.summary()方法在Keras中所做的那样,在PyTorch中打印一个模型的摘要?
  1. Model Summary:
  2. ____________________________________________________________________________________________________
  3. Layer (type)                     Output Shape          Param #     Connected to                     
  4. ====================================================================================================
  5. input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
  6. ____________________________________________________________________________________________________
  7. convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
  8. ____________________________________________________________________________________________________
  9. maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
  10. ____________________________________________________________________________________________________
  11. flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]            
  12. ____________________________________________________________________________________________________
  13. dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
  14. ====================================================================================================
  15. Total params: 2,385
  16. Trainable params: 2,385
  17. Non-trainable params: 0
复制代码

回复

使用道具 举报

7

主题

28

帖子

79

积分

注册会员

Rank: 2

积分
79
发表于 2018-10-15 11:49:47 | 显示全部楼层
本帖最后由 神龙教 于 2018-10-15 11:56 编辑

在Keras的模型中,不会得到关于模型的详细信息。简单地打印模型可以让对涉及的不同层及其规范有一些了解。比如说:
  1. from torchvision import models
  2. model = models.vgg16()
  3. print(model)
复制代码
输出如下:
  1. VGG (
  2.   (features): Sequential (
  3.     (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  4.     (1): ReLU (inplace)
  5.     (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  6.     (3): ReLU (inplace)
  7.     (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  8.     (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  9.     (6): ReLU (inplace)
  10.     (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  11.     (8): ReLU (inplace)
  12.     (9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  13.     (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  14.     (11): ReLU (inplace)
  15.     (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  16.     (13): ReLU (inplace)
  17.     (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  18.     (15): ReLU (inplace)
  19.     (16): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  20.     (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  21.     (18): ReLU (inplace)
  22.     (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  23.     (20): ReLU (inplace)
  24.     (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  25.     (22): ReLU (inplace)
  26.     (23): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  27.     (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  28.     (25): ReLU (inplace)
  29.     (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  30.     (27): ReLU (inplace)
  31.     (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  32.     (29): ReLU (inplace)
  33.     (30): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  34.   )
  35.   (classifier): Sequential (
  36.     (0): Dropout (p = 0.5)
  37.     (1): Linear (25088 -> 4096)
  38.     (2): ReLU (inplace)
  39.     (3): Dropout (p = 0.5)
  40.     (4): Linear (4096 -> 4096)
  41.     (5): ReLU (inplace)
  42.     (6): Linear (4096 -> 1000)
  43.   )
  44. )
复制代码


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表