【tf.keras】AdamW: Adam with Weight decay

论文 Decoupled Weight Decay Regularization 中提到,Adam 在使用时,L2 regularization 与 weight decay 并不等价,并提出了 AdamW,在神经网络需要正则项时,用 AdamW 替换 Adam+L2 会得到更好的性能。

TensorFlow 2.0 在 tensorflow_addons 库里面实现了 AdamW,目前在 Mac 和 Linux 上可以直接pip install tensorflow_addons进行安装,在 windows 上还不支持,但也可以直接把这个仓库下载下来使用。

下面是一个利用 AdamW 的示例程序(TF 2.0, tf.keras),在使用 AdamW 的同时,使用 learning rate decay:(以下程序中,AdamW 的结果不如 Adam,这是因为模型比较简单,加入 regularization 反而影响性能)

import tensorflow as tf
import os
from tensorflow_addons.optimizers import AdamW

import numpy as np

from tensorflow.python.keras import backend as K
from tensorflow.python.util.tf_export import keras_export
from tensorflow.keras.callbacks import Callback

def lr_schedule(epoch):
    """Learning Rate Schedule
    Learning rate is scheduled to be reduced after 20, 30 epochs.
    Called automatically every epoch as part of callbacks during training.
    # Arguments
        epoch (int): The number of epochs
    # Returns
        lr (float32): learning rate
    """
    lr = 1e-3

    if epoch >= 30:
        lr *= 1e-2
    elif epoch >= 20:
        lr *= 1e-1
    print('Learning rate: ', lr)
    return lr

def wd_schedule(epoch):
    """Weight Decay Schedule
    Weight decay is scheduled to be reduced after 20, 30 epochs.
    Called automatically every epoch as part of callbacks during training.
    # Arguments
        epoch (int): The number of epochs
    # Returns
        wd (float32): weight decay
    """
    wd = 1e-4

    if epoch >= 30:
        wd *= 1e-2
    elif epoch >= 20:
        wd *= 1e-1
    print('Weight decay: ', wd)
    return wd

# just copy the implement of LearningRateScheduler, and then change the lr with weight_decay
@keras_export('keras.callbacks.WeightDecayScheduler')
class WeightDecayScheduler(Callback):
    """Weight Decay Scheduler.

    Arguments:
        schedule: a function that takes an epoch index as input
            (integer, indexed from 0) and returns a new
            weight decay as output (float).
        verbose: int. 0: quiet, 1: update messages.

    ```python
    # This function keeps the weight decay at 0.001 for the first ten epochs
    # and decreases it exponentially after that.
    def scheduler(epoch):
      if epoch < 10:
        return 0.001
      else:
        return 0.001 * tf.math.exp(0.1 * (10 - epoch))

    callback = WeightDecayScheduler(scheduler)
    model.fit(data, labels, epochs=100, callbacks=[callback],
              validation_data=(val_data, val_labels))
    ```
    """

    def __init__(self, schedule, verbose=0):
        super(WeightDecayScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'weight_decay'):
            raise ValueError('Optimizer must have a "weight_decay" attribute.')
        try:  # new API
            weight_decay = float(K.get_value(self.model.optimizer.weight_decay))
            weight_decay = self.schedule(epoch, weight_decay)
        except TypeError:  # Support for old API for backward compatibility
            weight_decay = self.schedule(epoch)
        if not isinstance(weight_decay, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.weight_decay, weight_decay)
        if self.verbose > 0:
            print('\nEpoch %05d: WeightDecayScheduler reducing weight '
                  'decay to %s.' % (epoch + 1, weight_decay))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['weight_decay'] = K.get_value(self.model.optimizer.weight_decay)

if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'

    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
    print(gpus)
    cifar10 = tf.keras.datasets.cifar10

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    optimizer = AdamW(learning_rate=lr_schedule(0), weight_decay=wd_schedule(0))
    # optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    tb_callback = tf.keras.callbacks.TensorBoard(os.path.join('logs', 'adamw'),
                                                 profile_batch=0)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
    wd_callback = WeightDecayScheduler(wd_schedule)

    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(x_train, y_train, epochs=40, validation_split=0.1,
              callbacks=[tb_callback, lr_callback, wd_callback])

    model.evaluate(x_test, y_test, verbose=2)

以上代码实现了在 learning rate decay 时使用 AdamW,虽然只能是在 epoch 层面进行学习率衰减。

在使用 AdamW 时,如果要使用 learning rate decay,那么对 weight_decay 的值要进行同样的学习率衰减,不然训练会崩掉。

References

How to use AdamW correctly? -- wuliytTaotao
Loshchilov, I., & Hutter, F. Decoupled Weight Decay Regularization. ICLR 2019. Retrieved from http://arxiv.org/abs/1711.05101

原文地址:https://www.cnblogs.com/wuliytTaotao/p/12178778.html

时间: 2024-08-03 23:19:35

【tf.keras】AdamW: Adam with Weight decay的相关文章

【tf.keras】tf.keras使用tensorflow中定义的optimizer

我的 tensorflow+keras 版本: print(tf.VERSION) # '1.10.0' print(tf.keras.__version__) # '2.1.6-tf' tf.keras 没有实现 AdamW,即 Adam with Weight decay.论文<DECOUPLED WEIGHT DECAY REGULARIZATION>提出,在使用 Adam 时,weight decay 不等于 L2 regularization.具体可以参见 当前训练神经网络最快的方式

【tf.keras】在 cifar 上训练 AlexNet,数据集过大导致 OOM

cifar-10 每张图片的大小为 32×32,而 AlexNet 要求图片的输入是 224×224(也有说 227×227 的,这是 224×224 的图片进行大小为 2 的 zero padding 的结果),所以一种做法是将 cifar-10 数据集的图片 resize 到 224×224. 此时遇到的问题是,cifar-10 resize 到 224×224 时,32G 内存都将无法完全加载所有数据,在归一化那一步(即每个像素点除以 255)就将发生 OOM(out of memory)

【tf.keras】tf.keras模型复现

keras 构建模型很简单,上手很方便,同时又是 tensorflow 的高级 API,所以学学也挺好. 模型复现在我们的实验中也挺重要的,跑出了一个模型,虽然我们可以将模型的 checkpoint 保存,但再跑一遍,怎么都得不到相同的结果,对我而言这是不能接受的. 用 keras 实现模型,想要能够复现,需要将设置各个可能的随机过程的 seed:而且,代码不要在 GPU 上跑,而是在 CPU 上跑.(也就是说,GPU 上得到的 keras 模型没办法再复现.) 我的 tensorflow+ke

【tensorflow2.0】高阶api--主要为tf.keras.models提供的模型的类接口

下面的范例使用TensorFlow的高阶API实现线性回归模型. TensorFlow的高阶API主要为tf.keras.models提供的模型的类接口. 使用Keras接口有以下3种方式构建模型:使用Sequential按层顺序构建模型,使用函数式API构建任意结构模型,继承Model基类构建自定义模型. 此处分别演示使用Sequential按层顺序构建模型以及继承Model基类构建自定义模型. 一,使用Sequential按层顺序构建模型[面向新手] import tensorflow as

【tensorflow2.0】处理结构化数据-titanic生存预测

1.准备数据 import numpy as np import pandas as pd import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras import models,layers dftrain_raw = pd.read_csv('./data/titanic/train.csv') dftest_raw = pd.read_csv('./data/titanic/test.csv')

zz【清华NLP】图神经网络GNN论文分门别类,16大应用200+篇论文最新推荐

[清华NLP]图神经网络GNN论文分门别类,16大应用200+篇论文最新推荐 图神经网络研究成为当前深度学习领域的热点.最近,清华大学NLP课题组Jie Zhou, Ganqu Cui, Zhengyan Zhang and Yushi Bai同学对 GNN 相关的综述论文.模型与应用进行了综述,并发布在 GitHub 上.16大应用包含物理.知识图谱等最新论文整理推荐. GitHub 链接: https://github.com/thunlp/GNNPapers 目录            

【MVC 4】6.SportsSore:导航

 作者:[美]Adam Freeman      来源:<精通ASP.NET MVC 4> 前面的文章[MVC 4]5.SportsSore —— 一个真实的应用程序 建立了 SportsStore 应用程序的核心基础框架.本文将利用这一基础框架,将一些关键特性添加到该应用程序上. 1.添加导航控件 如果让客户通过产品分类(Category)对象产品进行导航,SportsStore 应用程序会更加适用.这需要从三个方面着手. * 增强 ProductController 类中的 List 动作

【MVC 4】5.SportsSore —— 一个真实的应用程序

 作者:[美]Adam Freeman      来源:<精通ASP.NET MVC 4> 前面建立的都是简单的MVC程序,现在到了吧所有事情综合在一起,以建立一个简单但真实的电子商务应用程序的时候了. 在此打算建立的应用程序 — SportsStore (体育用品商店),将遵循随处可见的在线商店所采取的经典方式.将创建一个客户可以通过分类和页面进行浏览的在线产品分类,一个客户可以添加和删除商品的购物车,和一个客户能够输入其右击地址细节的结算页面.另外,还将创建一个包含创建.读取.更新和删除功

Python全栈开发【第一篇】:初识Python

Python全栈开发[第一篇] 本节内容: Python 的种类 Python 的环境 Python 入门(解释器.编码.变量.input输入.if流程控制与缩进.while循环) if流程控制与while循环练习题 基本数据类型前引 Python 的种类 Cpython Python的官方版本,使用C语言实现,使用最为广泛,CPython实现会将源文件(py文件)转换成字节码文件(pyc文件),然后运行在Python虚拟机上. Jyhton Python的Java实现,Jython会将Pyth