回调函数
回调函数
大部分时候,
其中
内置回调函数
- BaseLogger:收集每个
epoch 上metrics 在各个batch 上的平均值,对stateful_metrics 参数中的带中间状态的指标直接拿最终值无需对各个batch 平均,指标均值结果将添加到logs 变量中。该回调函数被所有模型默认添加,且是第一个被添加的。 - History:将
BaseLogger 计算的各个epoch 的metrics 结果记录到history 这个dict 变量中,并作为model.fit 的返回值。该回调函数被所有模型默认添加,在BaseLogger 之后被添加。 - EarlyStopping:当被监控指标在设定的若干个
epoch 后没有提升,则提前终止训练。 - TensorBoard:为
Tensorboard 可视化保存日志信息。支持评估指标,计算图,模型参数等的可视化。 - ModelCheckpoint:在每个
epoch 后保存模型。 - ReduceLROnPlateau:如果监控指标在设定的若干个
epoch 后没有提升,则以一定的因子减少学习率。 - TerminateOnNaN:如果遇到
loss 为NaN ,提前终止训练。 - LearningRateScheduler:学习率控制器。给定学习率
lr 和epoch 的函数关系,根据该函数关系在每个epoch 前调整学习率。 - CSVLogger:将每个
epoch 后的logs 结果记录到CSV 文件中。 - ProgbarLogger:将每个
epoch 后的logs 结果打印到标准输出流中。
自定义回调函数
可以使用
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,metrics,callbacks
import tensorflow.keras.backend as K
# 示范使用LambdaCallback编写较为简单的回调函数
import json
json_log = open('./data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(
on_epoch_end=lambda epoch, logs: json_log.write(
json.dumps(dict(epoch = epoch,**logs)) + '\n'),
on_train_end=lambda logs: json_log.close()
)
# 示范通过Callback子类化编写回调函数(LearningRateScheduler的源代码)
class LearningRateScheduler(callbacks.Callback):
def __init__(self, schedule, verbose=0):
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have a "lr" attribute.')
try:
lr = float(K.get_value(self.model.optimizer.lr))
lr = self.schedule(epoch, lr)
except TypeError: # Support for old API for backward compatibility
lr = self.schedule(epoch)
if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
raise ValueError('The dtype of Tensor should be float')
K.set_value(self.model.optimizer.lr, K.get_value(lr))
if self.verbose > 0:
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['lr'] = K.get_value(self.model.optimizer.lr)