Source code for nnodely.support.earlystopping

# Early stopping functions:
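# Each early-stopping function returns True if the training should stop;
# select_best_model instead returns True when the latest model is the best so far.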

# "Classical" early stopping based on the validation loss:
# Stop if the validation loss has not improved for a certain number of epochs
def early_stop_patience(train_losses, val_losses, params):
    """
    Decide whether to stop training because the monitored loss stopped improving.

    The monitored history comes from ``val_losses`` when it is non-empty,
    otherwise from ``train_losses``. Training is stopped once the lowest value
    of that history is at least ``patience`` entries away from the end.

    Parameters
    ----------
    train_losses : dict
        Maps each loss name to the list of its recorded training values
        (one entry per epoch).
    val_losses : dict
        Same layout as ``train_losses`` for the validation set. When empty or
        ``None`` the training losses are monitored instead.
    params : dict
        Optional settings:

        * ``'patience'`` -- number of epochs to wait for an improvement
          (default 50).
        * ``'error'`` -- name of the single loss to monitor. When absent, the
          per-epoch mean over all the losses is monitored.

    Returns
    -------
    bool
        True if training should be stopped early, False otherwise.
    """
    patience = params.get('patience', 50)

    # Without a validation set, fall back to the training losses.
    losses = val_losses if val_losses else train_losses
    if not losses:
        # Nothing has been recorded yet: there is no evidence to stop on.
        return False

    if 'error' in params:
        # The user selected which loss has to be monitored.
        history = losses[params['error']]
    else:
        # Monitor the mean of all the losses, epoch by epoch. The number of
        # epochs is read from the first series, as every series is indexed
        # with the same positions.
        import numpy as np
        series = list(losses.values())
        history = [
            np.mean([values[epoch] for values in series])
            for epoch in range(len(series[0]))
        ]

    # An empty history has no minimum; a short one cannot exceed the patience.
    if len(history) == 0 or len(history) <= patience:
        return False

    # Position of the (first) best value of the monitored history.
    best_epoch = history.index(min(history))
    # Stop when the best value is older than the allowed patience.
    return best_epoch < len(history) - patience
def select_best_model(train_losses, val_losses, params):
    """
    Tell whether the most recent entry of the loss history is the best one.

    The per-epoch mean over all the losses is computed from ``val_losses``
    when it is non-empty, otherwise from ``train_losses``. The current model
    is the best one when the first occurrence of the minimum of that mean is
    the last recorded entry.

    Parameters
    ----------
    train_losses : dict
        Maps each loss name to the list of its recorded training values.
    val_losses : dict
        Same layout as ``train_losses`` for the validation set.
    params : dict
        A dictionary of parameters. It is not read by this function.

    Returns
    -------
    bool
        True if the current model is the best model, False otherwise.
    """
    import numpy as np

    # Validation losses take precedence; training losses are the fallback.
    monitored = val_losses or train_losses

    histories = list(monitored.values())
    # The first series dictates how many entries are averaged.
    epoch_count = len(histories[0])

    averaged = []
    for epoch in range(epoch_count):
        averaged.append(np.mean([history[epoch] for history in histories]))

    lowest_at = averaged.index(min(averaged))
    return lowest_at == epoch_count - 1
def mean_stopping(train_losses, val_losses, params):
    """
    Stop training as soon as one loss falls within the tolerance.

    With validation losses, the training and validation series are paired in
    dictionary order and training stops when, for any pair, the absolute
    difference between their latest values is below ``tol``. Without
    validation losses, training stops when the latest value of any training
    loss is below ``tol``.

    Parameters
    ----------
    train_losses : dict
        Maps each loss name to the list of its recorded training values.
    val_losses : dict
        Same layout as ``train_losses`` for the validation set. May be empty.
    params : dict
        Optional settings: ``'tol'`` is the tolerance (default 0.001).

    Returns
    -------
    bool
        True if training should be stopped early, False otherwise.
    """
    tol = params.get('tol', 0.001)

    if val_losses:
        # Compare the latest training value with the latest validation value
        # of each positionally paired loss.
        paired = zip(train_losses.values(), val_losses.values())
        return any(
            abs(train_series[-1] - val_series[-1]) < tol
            for train_series, val_series in paired
        )

    # No validation set: look at the latest training values only.
    return any(series[-1] < tol for series in train_losses.values())
def standard_early_stopping(train_losses, val_losses, params):
    """
    Stop training when the recent loss history shows no sign of improvement.

    With validation losses, the training and validation series are paired in
    dictionary order. Training stops only if, for every pair, the absolute
    train/validation gap is strictly increasing over the last ``n`` entries
    (the first gap of the window must be greater than zero).

    Without validation losses, training stops only if, for every training
    loss, none of the values following the reference value (the one ``n``
    entries from the end) is lower than that reference.

    In both cases the answer is False while the history is not longer than
    ``n`` entries, and when no training loss has been recorded at all.

    Parameters
    ----------
    train_losses : dict
        Maps each loss name to the list of its recorded training values.
    val_losses : dict
        Same layout as ``train_losses`` for the validation set. May be empty.
    params : dict
        Optional settings: ``'tol'`` -- despite its name, it is used as the
        number ``n`` of most recent entries to inspect (default 10).

    Returns
    -------
    bool
        True if training should be stopped early, False otherwise.
    """
    n = params.get('tol', 10)

    if not train_losses:
        # No recorded loss: the loops below would be skipped and the function
        # would ask to stop without any evidence.
        return False

    if val_losses:
        for train_series, val_series in zip(train_losses.values(), val_losses.values()):
            if len(train_series) <= n and len(val_series) <= n:
                return False
            widest_gap = 0.0
            for train_loss, val_loss in zip(train_series[-n:], val_series[-n:]):
                gap = abs(train_loss - val_loss)
                if gap > widest_gap:
                    widest_gap = gap
                else:
                    # The gap did not grow in this epoch: keep training.
                    return False
    else:
        for series in train_losses.values():
            if len(series) <= n:
                return False
            # Slice the window once and split it into reference and tail.
            # Writing the tail as series[-n + 1:] breaks for n == 1, where the
            # start index becomes 0 and the whole history is selected.
            window = series[-n:]
            reference = window[0]
            if any(loss < reference for loss in window[1:]):
                # The loss improved after the reference: keep training.
                return False
    return True