TensorFlow 2.0 Custom Callback in Practice: A Utility for Better Data Products
Callback Strategies to add incremental benefit and improve Neural Network training
As we know, neural networks are a series of algorithms that mimic the operations of a human brain to recognize relationships in vast amounts of data. When designing a neural network, we have countless choices to play with to make the model an optimal fit for the given data. During my preparation for the Google TensorFlow Developer certification exam, I learned a few cool techniques to improve deep learning model quality. Apart from the crucial and painstakingly time-consuming network design choices, such as the number of nodes, the number of layers, weight and bias initialization, and activation functions, there are a few low-hanging fruits, such as controlling the learning rate and the number of training epochs. When training a very deep neural network, one way to balance the training process is to stop it early at the right point. While serving a similar purpose in principle to early stopping in ML algorithms like XGBoost/LGBM, TensorFlow 2.0 callbacks give us full control over the training. In this blog, I am going to discuss different TensorFlow 2.0 callback methods, both standard and customized.
Callbacks and benefits
In layman’s terms, a callback is one of the controllers through which we can control neural network (NN) training. In particular, callbacks are function blocks through which the following can be achieved during training:
a) Early Stopping based on performance monitoring
b) Controlling the learning rate
c) Periodically archiving best available model weights
d) Conditional hard stop of the training
Before going into the different callback strategies, let’s briefly discuss the tangible benefits of callbacks in terms of both model quality and engineering.
i) Regularization: The central idea behind the callback concept is early stopping of training, which in the ML regime is a simple yet efficient regularization technique. Because it is a low-hanging fruit for overcoming possible overfitting, Geoffrey Hinton called it a “beautiful free lunch”, and such easy techniques often give us significant improvement.
ii) Resource Saving: Complex NN training involves tons of computation-heavy back-propagation calculations. In AI applications involving images and videos (e.g. self-driving cars, virtual reality), a training procedure with thousands of iterations (epochs) requires massive infrastructure. The following table illustrates typical training run-times for 3 standard CNN architectures on different hardware platforms, ranging from a decent CPU to a very high-end GPU. This is only a single training process using a standard architecture trained on static images. Training a complex AI system (involving videos) made up of many such networks, each with numerous model fit options (hyperparameter tuning), can easily take several days of GPU computation, or the equivalent in parallel operation.
To learn more about the engineering aspects of the network training process, I would encourage you to read this paper: https://www.researchgate.net/publication/328458615_Evaluating_Training_Time_of_Inception-v3_and_Resnet-50101_Models_using_TensorFlow_across_CPU_and_GPU
With the intelligence of callback methods, an optimal early stop of training can save a substantial number of GPU hours. TensorFlow 2.0 callbacks also give us the flexibility to periodically save updated model weights, so in case of cloud disconnection issues we avoid spending resources on rerunning the training from scratch.
For an approximate cost ($) saving estimate, for example on GCP, I would refer you to the following link: https://cloud.google.com/products/calculator/
Application to Business
Now we will briefly discuss these different callback strategies in light of a predictive methodology that is common across industries: time series forecasting. Here I will first generate a typical simulated time series and then train an NN in conjunction with different callback methods to influence the training process and finally achieve our goal. The detailed methodology, including data preparation code, can be found in this GitHub repository.
Step 1: Time Series preparation
Here we create an ideal time series of 4000 timestamps with trend, seasonality, and noise. The following are the components of the series.
time_period = 4000
baseline = 10
trend = trend(time,0.05)
amplitude = 35
slope = 0.004
noise_level = 3
We have synthesized training data of 3350 timestamps and a test/validation window of 650 timestamps. We choose mean squared error (MSE) as the single evaluation metric for this problem.
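The series described above could be generated roughly as follows. This is a minimal sketch in plain Python, not the code from the repository: the sinusoidal seasonal shape, the season length of 365, the Gaussian noise, and the fixed seed are all assumptions, since only the baseline, amplitude, slope, and noise level are listed here.

```python
import math
import random

TIME_PERIOD = 4000     # total number of timestamps
BASELINE = 10
AMPLITUDE = 35
SLOPE = 0.004
NOISE_LEVEL = 3
SEASON_LENGTH = 365    # assumption: the season length is not stated in the article
SPLIT_TIME = 3350      # first 3350 points for training, remaining 650 for validation


def make_series(seed=42):
    """Build baseline + linear trend + seasonality + noise, one value per timestamp."""
    rng = random.Random(seed)
    series = []
    for t in range(TIME_PERIOD):
        trend = SLOPE * t
        # Assumed seasonal shape: a plain sinusoid scaled by the amplitude.
        season = AMPLITUDE * math.sin(2 * math.pi * t / SEASON_LENGTH)
        noise = rng.gauss(0, NOISE_LEVEL)
        series.append(BASELINE + trend + season + noise)
    return series


series = make_series()
train, valid = series[:SPLIT_TIME], series[SPLIT_TIME:]
```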
Step 2: Setting Benchmark
As the series has a mathematically predictable pattern, let’s create a ballpark solution by smoothing the series: a benchmark moving average (MA) solution. With this smoothing, the MSE is 45.5676.
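As a rough illustration of such a benchmark, the sketch below forecasts each point as the mean of the values just before it and scores the forecast with MSE. The helper names and the toy data are hypothetical, and the window length is left as a free parameter because the one behind the 45.5676 figure is not stated.

```python
def moving_average_forecast(series, window):
    """Forecast every point from index `window` onward as the mean of the `window` values before it."""
    return [sum(series[t - window:t]) / window for t in range(window, len(series))]


def mse(actual, predicted):
    """Mean squared error between two equally long sequences."""
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)


toy = [1.0, 2.0, 3.0, 4.0, 5.0]
forecast = moving_average_forecast(toy, window=2)  # [1.5, 2.5, 3.5]
print(mse(toy[2:], forecast))                      # 2.25
```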
Can we improve the prediction using a more sophisticated statistical modeling technique? Let’s try a standard ARIMA (autoregressive integrated moving average) model to get a more accurate solution.
Now we have a superior benchmark that follows the pattern, and the validation MSE is 21.882, a reduction of roughly 52% with respect to MA.
As we already have a near-perfect forecasting pattern, let’s explore whether an NN can beat this benchmark.
Step 3: NN training with callback feature
First we have to prepare the data so that it is suitable to feed into a standard NN. As with any other ML problem, we have to divide our data into features and labels. In this case a feature is effectively a number of consecutive values in the series, and the label is the next value. This number of values is the window size: we take a window of the data and train the model to predict the next value. For example, if we take 20 timestamps as the window size, 20 values will be used as the features and the next value as the label, and this pattern traverses the series on a rolling window basis.
The following are the settings for data preparation and input to the network.
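The rolling-window idea can be sketched in a few lines of plain Python. The function name is hypothetical, and the actual input pipeline in the repository may be built differently (for example with tf.data); this only shows how features and labels line up.

```python
def make_windows(series, window_size):
    """Return (features, label) pairs: `window_size` consecutive values and the value that follows them."""
    return [
        (series[i:i + window_size], series[i + window_size])
        for i in range(len(series) - window_size)
    ]


pairs = make_windows([10, 11, 12, 13, 14, 15], window_size=3)
# pairs[0]  == ([10, 11, 12], 13)
# pairs[-1] == ([12, 13, 14], 15)
```

In this sketch, 3350 training points with a window size of 20 would give 3330 feature/label pairs.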
window_size = 20
batch_size = 32
shuffle_buffer_size = 1000 #To break Sequence bias
The architecture for now is a 2-layer bidirectional LSTM (Long Short-Term Memory) network. Although we will look at MSE for monitoring purposes, Huber loss will be used as the loss function for training optimization; it is a standard function in robust regression and is less sensitive to outliers in the data than the squared error loss.
- The first Lambda layer adds the flexibility to deal with dimensionality by expanding the array by one dimension. The windowed dataset is a two-dimensional batch of data, but an LSTM expects three dimensions: batch size, the number of timestamps, and the series dimensionality.
- The second Lambda layer scales up the outputs by 100. The default activation function in the LSTM layers is tanh (hyperbolic tangent), which outputs values between negative one and one. Since the time series values are in the order of tens and ramp up gradually over time, multiplying the raw output by a factor of 100 brings the network outputs into the same ballpark as the target and hence helps the learning process.
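Put together, the network described above might look like the following sketch. The 32 units per LSTM layer, the SGD optimizer with momentum 0.9, and the learning rate are assumptions made for illustration; only the two Lambda layers, the two bidirectional LSTM layers, the Huber loss, and MSE monitoring come from the description.

```python
import tensorflow as tf

WINDOW_SIZE = 20

model = tf.keras.Sequential([
    tf.keras.Input(shape=(WINDOW_SIZE,)),
    # (batch, timestamps) -> (batch, timestamps, 1), the 3-D shape an LSTM expects.
    tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1)),
    # Assumption: 32 units per direction; the article does not give the layer sizes.
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(1),
    # Scale the output by 100 to bring it closer to the range of the series.
    tf.keras.layers.Lambda(lambda x: x * 100.0),
])

model.compile(
    loss=tf.keras.losses.Huber(),  # robust loss used for optimization
    # Assumption: optimizer type and momentum are illustrative choices.
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9),
    metrics=["mse"],               # MSE is tracked for monitoring only
)
```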
When applying an NN with callbacks for ML-based prediction, we can broadly classify these methods into two categories: 1) standard TF methods and 2) customized callback subclasses.
1) Standard TF methods
a) ModelCheckpoint: As a standard TF 2.0 callback, ModelCheckpoint periodically saves the model as a checkpoint file (in HDF5 format) after certain epochs, based on the best value of the defined evaluation metric (e.g. accuracy, loss). Although it does not control the training, this method is useful for preserving important training information, especially in very long training runs, to avoid losing training updates in case of a system failure. For example, when training on cheaper AWS EC2 spot instances, even if the machine suddenly becomes unavailable, the training can be resumed from the last saved weights.
checkpoint_filepath = '/callback/model.h5'
model_checkpoint_callback = ModelCheckpoint(
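A complete call might look like the following sketch. The monitor, mode, and save_best_only values are assumptions, since the arguments actually used are not listed here; val_mse is only logged if the model is compiled with an MSE metric and trained with validation data.

```python
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_filepath = '/callback/model.h5'

model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_mse',     # assumption: validation MSE is the tracked metric
    mode='min',            # lower MSE is better
    save_best_only=True,   # assumption: keep only the best model seen so far
)
```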
b) EarlyStopping: This is the default TF2 callback for preventing overfitting of the model. EarlyStopping monitors the performance of the model on a held-out validation set at every epoch during training. Through this callback we can monitor any built-in TF metric (val_accuracy, val_mse) or any custom metric function (val_auc, val_r2). It terminates the training based on the following predefined convergence conditions:
i) patience: the number of epochs with no improvement in the evaluation metric after which training is stopped.
ii) min_delta: the minimum change in the monitored metric that qualifies as an improvement.
Although this is a decent starter solution, with too low a patience value there is always a possibility of getting stuck in a local minimum and not achieving the maximum performance. Conversely, a very high patience value may deteriorate performance (overfitting) and run unnecessary extra epochs when the model has already reached the global minimum.
To be precise, here the training stopped because the MSE didn’t improve beyond 25.2782 for 5 epochs, but we don’t know whether this is the best solution under this NN configuration. If we plot the evolution of the loss over epochs, it seems that both the training and validation curves flatten after approximately 75 epochs.
However, when we zoom in on that region, we find a decent declining trend for the training MSE and a fluctuating trail for the validation MSE. This means there might be scope for further improvement if we wait a few more epochs to get past possible local minima. But we don’t know how many more epochs we should give the training to improve.
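To see how patience and min_delta interact, here is a simplified plain-Python version of the stopping rule for a metric where lower is better. It is a sketch of the logic only, not the Keras implementation, and the validation values are made up.

```python
def early_stopping_epoch(val_metric, patience, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    An epoch counts as an improvement only if the metric drops below
    the best value so far by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, value in enumerate(val_metric):
        if value < best - min_delta:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_mse = [30.0, 28.0, 27.0, 27.1, 27.05, 26.99, 27.2, 27.0]
print(early_stopping_epoch(val_mse, patience=3, min_delta=0.1))  # 5
print(early_stopping_epoch(val_mse, patience=5, min_delta=0.1))  # 7
print(early_stopping_epoch(val_mse, patience=3, min_delta=0.0))  # None
```

With min_delta=0.1 the tiny gain at epoch 5 does not count, so low patience stops the run early; with min_delta=0.0 the same gain resets the counter and training continues.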
c) ReduceLROnPlateau: This is the standard TF2 callback for reducing the learning rate by a defined factor if the chosen evaluation metric shows no improvement, or an improvement smaller than a defined threshold (min_delta).
While it doesn’t have an early stop feature, models often benefit from dynamically reducing the learning rate by a factor of 2–10 once learning stagnates, which can help with overfitting (new learning rate (LR) = factor * LR).
There are other TF2 utility callbacks, such as BaseLogger, CSVLogger, CallbackList, ProgbarLogger, and TensorBoard, mainly for more detailed training information (logging).
With the standard callback methods there is no conditional stop based on a single target value of the evaluation metric. The best MSE we have so far is 25.2782, which is not even close to the benchmark figure of 21.882. Can we improve the model beyond this using the same network architecture? We will explore the answer in the following section.
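The update rule can be traced with a small plain-Python sketch. It is a simplification of what the callback does (for instance, it ignores any cooldown period), and the function name and numbers are made up for illustration.

```python
def reduce_lr_on_plateau(val_metric, lr, factor, patience, min_delta=0.0, min_lr=0.0):
    """Return the learning rate in effect after each epoch (lower metric is better)."""
    best = float("inf")
    wait = 0
    history = []
    for value in val_metric:
        if value < best - min_delta:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)  # new LR = factor * LR, floored at min_lr
                wait = 0
        history.append(lr)
    return history


lrs = reduce_lr_on_plateau([10.0, 9.0, 9.0, 9.0, 9.0, 9.0], lr=1e-3, factor=0.5, patience=2)
# lrs -> [0.001, 0.001, 0.001, 0.0005, 0.0005, 0.00025]
```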
2) Customized callback subclass
Custom callbacks are built on the base TF2 class ‘tf.keras.callbacks.Callback’. By subclassing it, we can bring more flexibility to the training process by adding a conditional stop and by performing certain functions when training, a batch, or an epoch begins or ends. Here the training can be stopped conditionally on reaching a single defined value of the evaluation metric. The names of these functions explain their purpose, such as ‘on_epoch_end’ or ‘on_epoch_begin’. Generally, for a very long epoch, such as a CNN application with very high-resolution images, it is better practice to wait until the epoch ends, as there might be significant fluctuation within an epoch.
Let’s apply this strategy to our case and see whether it works to beat the benchmark MSE of 21.882.
After waiting until the end of the 839th epoch, we see that the stopping condition is met as we reach a validation MSE of 21.8755, just below the defined benchmark MSE of 21.882. Can we save some computation resources by running fewer epochs to achieve the same target?
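A conditional stop of this kind can be sketched as a small subclass of tf.keras.callbacks.Callback. The class name StopAtTargetMSE and its arguments are hypothetical, and the sketch assumes the model is compiled with metrics=["mse"] and trained with validation data so that val_mse appears in the logs; it is not necessarily the exact implementation behind the results reported here.

```python
import tensorflow as tf


class StopAtTargetMSE(tf.keras.callbacks.Callback):
    """Stop training once the monitored metric falls below a target value."""

    def __init__(self, target=21.882, monitor="val_mse"):
        super().__init__()
        self.target = target
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        # Check at the end of the epoch, when the validation metric is available.
        current = (logs or {}).get(self.monitor)
        if current is not None and current < self.target:
            print(f"Epoch {epoch + 1}: {self.monitor}={current:.4f} "
                  f"is below {self.target}, stopping training.")
            self.model.stop_training = True
```

It would be passed to training like any other callback, for example model.fit(..., callbacks=[StopAtTargetMSE()]).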
LearningRateScheduler: We can, if we include the intelligence of adjusting the learning rate by searching over a wide range of learning rates (LR). This approach allows large weight changes at the beginning of the learning process and small changes, or fine-tuning, towards the end. Typically, training starts with a relatively large learning rate and decreases it in later epochs. To adjust the learning rate, we write a custom lambda function that takes the current epoch as an argument and returns the desired learning rate, and pass it as the schedule parameter to tf.keras.callbacks.LearningRateScheduler. By running a quick 100 epochs with LR exploration ranging from 1e-8 to 1e-4, we pick the optimal stable LR as 1e-5.
lr_schedule = tf.keras.callbacks.LearningRateScheduler(
lambda epoch: 1e-8 * 10**(epoch / 20))
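To get a feel for this schedule, the same expression can be evaluated on its own in plain Python; the function name lr_at is only for illustration. The learning rate grows by a factor of 10 every 20 epochs and passes through the chosen 1e-5 at epoch 60.

```python
def lr_at(epoch):
    """Same expression as the lambda above: one decade of growth every 20 epochs."""
    return 1e-8 * 10 ** (epoch / 20)


for epoch in (0, 20, 40, 60):
    print(epoch, f"{lr_at(epoch):.0e}")
# prints 1e-08, 1e-07, 1e-06, 1e-05 for epochs 0, 20, 40, 60
```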
Now, together with the custom callback subclass and LearningRateScheduler, by choosing the optimal LR of 1e-5 we can reduce the number of training epochs to as low as 115! Finally, by applying ModelCheckpoint with the callback subclass, we can periodically save the best model to avoid any loss of information.
With this we save significant computation resources while beating the target MSE of 21.882 by about 1.3% (improved validation MSE 21.589).
TensorFlow has tons of functionality for designing NNs, controlling the training process, and optimizing computation resources. This particular feature of TF2 not only helps us build amazing data products but also helps us keep a closer tab on model training. And while TF has many options for reducing overfitting, callbacks with early stopping are one of the easiest ways to add a guiding principle to learning, monitoring, logging, and resource optimization.