查看: 1911|回复: 0

使用Keras调用多GPU时出现无法保存模型的解决方法

[复制链接]

665

主题

1234

帖子

6684

积分

xdtech

Rank: 5Rank: 5

积分
6684
发表于 2019-4-20 15:41:15 | 显示全部楼层 |阅读模式
model = load_model('./model/RESNET50_model.h5')


导致错误:ValueError: axes don't match array


原因:使用多个GPU进行保存模型,才导致出现这样的错误


方法一(不建议使用)
使用低版本的keras比如2.1.5版本,它解决了我的问题。


pip install keras==2.1.5


方法二
在使用keras 的并行多路GPU时出现了模型无法载入的情况,在使用单个GPU时运行完全没有问题。


构建类


from keras.callbacks import Callback,ModelCheckpoint


class ParallelModelCheckpoint(ModelCheckpoint):
    def __init__(self,model,filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
                self.single_model = model
                super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

    def set_model(self, model):
        super(ParallelModelCheckpoint,self).set_model(self.single_model)
if __name__ == '__main__':
        single_model = cnn_model(kernel_size, nb_filters, channels, nb_classes)  #原始模型
        model = multi_gpu_model(single_model, 2) # 多GPU并行模型
        saveBestModel = ParallelModelCheckpoint(single_model,'./model/RESNET50_model.h5',monitor='val_acc',
                                                 verbose=1, save_best_only=True, mode='auto')  #一定要是第一次读取的原始模型single_model
       


在这里需要解释一下,这个single_model 模型是没有进行并行的model


single_model  = Model(inputs =inputs, outputs =outputs )
model = multi_gpu_model(single_model  , gpus=2)




方法三
original_model = ...
parallel_model = multi_gpu_model(original_model, gpus=n)

class MyCbk(keras.callbacks.Callback):

    def __init__(self, model):
         self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        self.model_to_save.save('model_at_epoch_%d.h5' % epoch)

cbk = MyCbk(original_model)
parallel_model.fit(..., callbacks=[cbk])




同理这里也是一样的。其实在上面两种方法中可以发现,基本都是在checkpoint 问题上都是使用了单个model进行运行的。
意思就是直接使用传入方法keras.utils.multi_gpu_model(model, gpus)中的model即可,而不要使用返回的parallel_model。


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表