# Source code for flax.training.early_stopping
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Early stopping."""
import math
from flax import struct
class EarlyStopping(struct.PyTreeNode):
    """Early stopping to avoid overfitting during training.

    The state tracks the lowest metric value accepted so far
    (``best_metric``). An update counts as an improvement when the new metric
    is lower than ``best_metric`` by strictly more than ``min_delta``.
    Otherwise ``patience_count`` grows, and ``should_stop`` is set on the
    ``patience + 1``-th consecutive update without improvement. ``update`` and
    ``reset`` return a new state, which the caller must keep (the example
    below reassigns ``early_stop``).

    The following example uses ``min_delta=1e-3`` and ``patience=2``. The loss
    last improves at epoch 4, so epochs 5, 6 and 7 are three consecutive
    non-improving updates, and training stops at epoch 7::

      >>> from flax.training.early_stopping import EarlyStopping
      >>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng):
      ...   ...
      ...   loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch]
      ...   return None, {'loss': loss}
      >>> early_stop = EarlyStopping(min_delta=1e-3, patience=2)
      >>> optimizer = None
      >>> for epoch in range(10):
      ...   optimizer, train_metrics = train_epoch(
      ...       optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None)
      ...   early_stop = early_stop.update(train_metrics['loss'])
      ...   if early_stop.should_stop:
      ...     print(f'Met early stopping criteria, breaking at epoch {epoch}')
      ...     break
      Met early stopping criteria, breaking at epoch 7

    Attributes:
      min_delta: Minimum decrease of the metric relative to ``best_metric``
        for an update to count as an improvement. The comparison is strict:
        the decrease must be greater than ``min_delta``.
      patience: Number of consecutive non-improving updates that are
        tolerated. ``should_stop`` is set on the ``patience + 1``-th one.
      best_metric: Lowest metric value accepted as an improvement so far.
        It is ``inf`` until the first non-NaN metric is seen.
      patience_count: Number of consecutive updates since the last improving
        update.
      should_stop: Whether the training loop should stop to avoid
        overfitting. Once set, it stays set until ``reset`` is called.
      has_improved: Whether the last ``.update`` call improved on
        ``best_metric`` by strictly more than ``min_delta``.
    """

    min_delta: float = 0
    patience: int = 0
    best_metric: float = float('inf')
    patience_count: int = 0
    should_stop: bool = False
    has_improved: bool = False

    def reset(self):
        """Return a copy with the tracking state cleared.

        ``min_delta`` and ``patience`` are kept. ``best_metric`` goes back to
        ``inf``, and ``patience_count``, ``should_stop`` and ``has_improved``
        go back to their defaults.
        """
        return self.replace(
            best_metric=float('inf'),
            patience_count=0,
            should_stop=False,
            has_improved=False,
        )

    def update(self, metric):
        """Update the state based on metric.

        Args:
          metric: The new value of the monitored quantity (lower is better).
            NaN values are never treated as an improvement.

        Returns:
          The updated EarlyStopping instance. The ``.has_improved`` attribute
          is True when ``metric`` is lower than the previous ``best_metric`` by
          strictly more than ``min_delta``. The first non-NaN metric is always
          accepted.
        """
        if math.isnan(metric):
            # NaN compares False against everything. Storing it as the best
            # value would make every later update a non-improvement.
            improved = False
        elif self.best_metric == math.inf:
            # Nothing recorded yet. Only +inf means "unset": a -inf best is a
            # real value that nothing can beat, so it must not reset the best.
            improved = True
        else:
            improved = self.best_metric - metric > self.min_delta

        if improved:
            # should_stop is deliberately left untouched, so it stays sticky.
            return self.replace(
                best_metric=metric, patience_count=0, has_improved=True
            )
        # patience_count is checked *before* it is incremented, so the stop
        # fires on the (patience + 1)-th consecutive non-improving update.
        should_stop = self.patience_count >= self.patience or self.should_stop
        return self.replace(
            patience_count=self.patience_count + 1,
            should_stop=should_stop,
            has_improved=False,
        )