pgmuvi.trainers
- class pgmuvi.trainers.Trainer
Bases: object
- pgmuvi.trainers.train(lightcurve=None, model=None, likelihood=None, train_x=None, train_y=None, maxiter=100, miniter=10, stop=None, lr=0.0001, lossfn='mll', optim='SGD', eps=1e-08, stopavg=9, **kwargs)
Given a GP model, a likelihood, and some training data, optimise a loss function to fit the training data.
- Parameters:
model (an instance of gpytorch.models.gp.GP or a subclass thereof) – The GP model whose (hyper-)parameters will be optimised.
likelihood (an instance of gpytorch.likelihoods.likelihood.Likelihood) – The likelihood function for the Gaussian Process.
train_x (torch.Tensor or array-like) – The values of the independent variables for training.
train_y (torch.Tensor or array-like) – The values of the dependent variables for training.
maxiter (int, default 100) – The maximum number of training iterations to use. If stop is not a positive number, this will be the number of iterations used to train.
miniter (int, default 10) – The minimum number of training iterations to use. This parameter is only used if stop is a positive real number, in which case it is used to ensure that a sufficient number of iterations have been performed before terminating training.
stop (float, default None) – The fractional change in the loss function below which training will be terminated. If set to None, a negative value, NaN (not a number) or a value of a non-numerical type, training will continue until maxiter is reached.
lr (float, default 1e-4) – The learning rate for the optimiser. Increasing this number will result in larger steps in the parameters each iteration. This will make it easier to escape local minima, but may also result in instability.
lossfn (string or instance of gpytorch.mlls.marginal_log_likelihood.MarginalLogLikelihood, default ‘mll’) – The loss function that will be used to evaluate the training. If a string, it must take one of the values ‘mll’ or ‘elbo’.
optim (string or instance of torch.optim.optimizer.Optimizer, default ‘SGD’) – The optimiser that will be used to train the model. If a string, it must take one of the values ‘SGD’, ‘Adam’, ‘AdamW’ or ‘NUTS’. Otherwise, it may be any torch or pyro optimiser; in that case it should already have been initialised with all of its arguments set.
eps (float, default 1e-8) – Term added to the denominator to improve numerical stability in some optimisers (e.g. AdamW).
Examples
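The interaction between maxiter, miniter and stop described above can be illustrated with a minimal plain-Python sketch. The helpers early_stopping_enabled and should_stop are hypothetical and are not part of pgmuvi; they assume that the fractional change is measured between consecutive loss values, and they ignore the stopavg argument, whose exact role is not documented here. pgmuvi's own rule may differ.

```python
import math
from numbers import Real


def early_stopping_enabled(stop):
    """Hypothetical helper: early stopping is active only for a positive real `stop`."""
    # None, strings and other non-numerical types disable early stopping.
    if not isinstance(stop, Real):
        return False
    # NaN, zero and negative values also disable it.
    return not math.isnan(stop) and stop > 0


def should_stop(losses, stop, miniter, maxiter):
    """Hypothetical helper: decide whether to end training after len(losses) iterations."""
    n_iter = len(losses)
    if n_iter >= maxiter:
        return True  # the hard cap on iterations always applies
    if not early_stopping_enabled(stop) or n_iter < max(miniter, 2):
        return False  # early stopping disabled, or too few iterations so far
    prev, curr = losses[-2], losses[-1]
    # Fractional change relative to the previous loss; fall back to the
    # absolute change if the previous loss is exactly zero.
    denom = abs(prev) if prev != 0 else 1.0
    return abs(curr - prev) / denom < stop
```

In a training loop, should_stop would be called after appending each new loss value, and the loop would break as soon as it returns True.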
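To show how lr sets the step size, and how too large a value causes instability, here is a toy sketch using plain gradient descent on the one-dimensional loss f(x) = x**2 rather than a GP likelihood. The function gradient_descent is hypothetical and only illustrates the general behaviour. Each step multiplies x by (1 - 2 * lr), so the iteration diverges once lr exceeds 1.

```python
def gradient_descent(x0, lr, n_steps):
    """Minimise the toy loss f(x) = x**2 by plain gradient descent; return the final x."""
    x = x0
    for _ in range(n_steps):
        grad = 2.0 * x       # derivative of x**2
        x = x - lr * grad    # each step is proportional to lr
    return x


print(gradient_descent(1.0, 1e-4, 100))  # tiny lr (the default): barely moves from 1.0
print(gradient_descent(1.0, 0.1, 100))   # moderate lr: converges towards 0
print(gradient_descent(1.0, 1.1, 100))   # lr too large: |1 - 2*lr| > 1, so x diverges
```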
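The string values accepted by lossfn and optim can be summarised with a hypothetical validation helper, check_option, which is not part of pgmuvi. It assumes exact, case-sensitive matching of the documented strings. Any non-string value, such as a MarginalLogLikelihood instance or an already-initialised torch or pyro optimiser, is passed through unchanged.

```python
VALID_LOSSFNS = ("mll", "elbo")
VALID_OPTIMS = ("SGD", "Adam", "AdamW", "NUTS")


def check_option(value, allowed, name):
    """Hypothetical check: strings must be one of `allowed`; other objects pass through."""
    if isinstance(value, str) and value not in allowed:
        # Exact, case-sensitive comparison (an assumption for this sketch).
        raise ValueError(f"{name} must be one of {allowed}, got {value!r}")
    return value


check_option("elbo", VALID_LOSSFNS, "lossfn")  # accepted
check_option("Adam", VALID_OPTIMS, "optim")    # accepted
```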
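To show where lr and eps enter an adaptive optimiser, the sketch below writes out the standard Adam update for a single scalar parameter in plain Python. It follows the update rule used by torch.optim.Adam. AdamW applies the same step plus decoupled weight decay, which is omitted here. The function adam_step is hypothetical and is not part of pgmuvi. Its defaults for lr and eps match those of train; beta1 and beta2 use the common PyTorch defaults.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one scalar Adam update at step t >= 1; returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of the gradient
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of the squared gradient
    m_hat = m / (1 - beta1 ** t)                # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)                # bias-corrected second moment
    # eps keeps the denominator strictly positive even when v_hat is zero.
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

With a zero gradient on the first step, math.sqrt(v_hat) is 0, so without eps the update would divide zero by zero. With eps the update is simply zero and the parameter is left unchanged.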
- pgmuvi.trainers.train_mll()
- pgmuvi.trainers.train_variational()
- pgmuvi.trainers.train_variational_uncertain()