TrainerIP
- class Dock2D.Models.TrainerIP.TrainerIP(dockingFFT, cur_model, cur_optimizer, cur_experiment, BF_eval=False, MC_eval=False, MC_eval_num_epochs=10, debug=False, plotting=False, sample_buffer_length=1000, tiling=False)
- __init__(dockingFFT, cur_model, cur_optimizer, cur_experiment, BF_eval=False, MC_eval=False, MC_eval_num_epochs=10, debug=False, plotting=False, sample_buffer_length=1000, tiling=False)
Initialize the trainer for IP (interaction pose) task models, paths, and class instances.
- Parameters
dockingFFT – docking FFT instance initialized to match the dimensions of the current sampling scheme
cur_model – the current docking model initialized outside the trainer
cur_optimizer – the optimizer initialized outside the trainer
cur_experiment – current experiment name
BF_eval – set to True to run BruteForce evaluation of the trained model
MC_eval – set to True to run MonteCarlo evaluation of the trained model
MC_eval_num_epochs – number of epochs used in MonteCarlo evaluation
debug – set to True to check model parameter gradients
plotting – set to True to create plots
sample_buffer_length – number of keys in the SampleBuffer; must be greater than or equal to the number of training, validation, or testing examples.
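The sample_buffer_length constraint follows from storing samples under one key per example. The sketch below is a hypothetical, dictionary-backed stand-in (ToySampleBuffer is an illustrative name, not part of Dock2D, and Dock2D's own SampleBuffer may store different data) that shows why a buffer shorter than the dataset fails.

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional


class ToySampleBuffer:
    """Hypothetical per-example sample store keyed by position index."""

    def __init__(self, sample_buffer_length: int) -> None:
        self.sample_buffer_length = sample_buffer_length
        # One list of samples per example key (pos_idx).
        self.buffer: Dict[int, List[Any]] = defaultdict(list)

    def push(self, pos_idx: int, sample: Any) -> None:
        # A key outside [0, sample_buffer_length) means the buffer was sized
        # smaller than the number of examples in the stream.
        if not 0 <= pos_idx < self.sample_buffer_length:
            raise IndexError(
                f"pos_idx {pos_idx} is outside a buffer of length "
                f"{self.sample_buffer_length}"
            )
        self.buffer[pos_idx].append(sample)

    def get(self, pos_idx: int) -> Optional[Any]:
        # Return the most recent sample for this example, or None if empty.
        samples = self.buffer.get(pos_idx)
        return samples[-1] if samples else None
```

For a stream of 3 examples, ToySampleBuffer(3) accepts pos_idx 0 to 2, while push(3, ...) raises IndexError.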
- load_checkpoint(checkpoint_fpath)
Load saved checkpoint state dictionary.
- Parameters
checkpoint_fpath – path to saved model
- Returns
self.model, self.optimizer, checkpoint['epoch']
- resume_training_or_not(resume_training, resume_epoch)
Resume training the model at specified epoch or not.
- Parameters
resume_training – set to True to resume training, False to start fresh training.
resume_epoch – epoch number to resume from
- Returns
starting epoch number: 1 if resume_training is False, resume_epoch+1 otherwise.
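A minimal sketch of the start-epoch rule, assuming that resuming continues with the epoch after the checkpointed one. start_epoch is a hypothetical helper name; the documented method is also responsible for restoring the model saved at resume_epoch, which is omitted here.

```python
def start_epoch(resume_training: bool, resume_epoch: int) -> int:
    """Hypothetical helper mirroring the documented return value."""
    if not resume_training:
        # Fresh training always begins at epoch 1.
        return 1
    # resume_epoch is the last completed (checkpointed) epoch,
    # so training continues with the next one.
    return resume_epoch + 1
```

For example, start_epoch(False, 7) returns 1 and start_epoch(True, 7) returns 8.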
- run_epoch(data_stream, epoch, training=False, stream_name='train_stream')
Run the model over the examples in a data stream for one epoch.
- Parameters
data_stream – input data stream
epoch – current epoch number
training – set to True for training, False for evaluation.
stream_name – name of the data stream
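One plausible shape for the epoch loop is sketched below. It is hypothetical in several ways: run_model is passed in as a function rather than called as a method, pos_idx is assumed to be the example's position in the stream, and the sketch returns the mean loss and mean RMSD even though no return value is documented for run_epoch.

```python
from statistics import mean
from typing import Any, Callable, Iterable, Tuple


def run_epoch(data_stream: Iterable[Any],
              run_model: Callable[..., Tuple[float, float]],
              epoch: int,
              training: bool = False,
              stream_name: str = "train_stream") -> Tuple[float, float]:
    """Hypothetical epoch loop: one run_model call per example."""
    losses, rmsds = [], []
    # Assumption: pos_idx is the example's position in the stream.
    for pos_idx, data in enumerate(data_stream):
        loss, rmsd = run_model(data, pos_idx, training=training,
                               stream_name=stream_name, epoch=epoch)
        losses.append(loss)
        rmsds.append(rmsd)
    # Assumption: report mean loss and mean RMSD over the epoch
    # (raises statistics.StatisticsError on an empty stream).
    return mean(losses), mean(rmsds)
```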
- run_model(data, pos_idx, training=True, stream_name='trainset', epoch=0)
Run a model iteration on the current example.
- Parameters
data – current example (training, validation, or testing)
pos_idx – current example position index
training – set to True for training, False for evaluation.
stream_name – data stream name
epoch – epoch count used in plotting
- Returns
loss and RMSD
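To illustrate the kind of RMSD an interaction-pose model can report, the sketch below places the same ligand points with a predicted pose and with the ground-truth pose, then takes the root-mean-square deviation between the two placements. The pose parameterization (rotation angle in radians, x shift, y shift) and the helper names apply_pose and pose_rmsd are assumptions; Dock2D's exact RMSD definition may differ.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Pose = Tuple[float, float, float]  # (angle in radians, dx, dy), an assumption


def apply_pose(points: Sequence[Point], pose: Pose) -> List[Point]:
    """Rotate points about the origin by angle, then translate by (dx, dy)."""
    angle, dx, dy = pose
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y + dx, s * x + c * y + dy) for x, y in points]


def pose_rmsd(ligand: Sequence[Point], predicted: Pose, target: Pose) -> float:
    """Root-mean-square deviation between two placements of the same points."""
    if not ligand:
        raise ValueError("ligand must contain at least one point")
    pred = apply_pose(ligand, predicted)
    ref = apply_pose(ligand, target)
    sq_sum = sum((px - rx) ** 2 + (py - ry) ** 2
                 for (px, py), (rx, ry) in zip(pred, ref))
    return math.sqrt(sq_sum / len(ligand))
```

A pure translation error of (3, 4) moves every point by a distance of 5, so pose_rmsd gives 5.0 for that case regardless of the ligand shape.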
- run_trainer(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)
Helper function to run the trainer.
- Parameters
train_epochs – number of epochs to train
train_stream – training set data stream
valid_stream – validation set data stream
test_stream – test set data stream
resume_training – set to True to resume training from a saved model state, False to train a fresh model
resume_epoch – epoch whose saved model is loaded to resume training
- save_checkpoint(state, filename)
Save the current state of the model, given as a checkpoint dictionary, to a file.
- Parameters
state – checkpoint state dictionary
filename – name of saved file
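The save and load methods form a round trip: a state dictionary written by save_checkpoint is read back by load_checkpoint, which returns the restored model, optimizer, and stored epoch. The sketch below substitutes the standard-library pickle module for PyTorch's torch.save and torch.load so it runs without PyTorch, and uses plain dicts in place of the model and optimizer; the checkpoint keys "state_dict" and "optimizer" are assumptions, and only "epoch" is confirmed by the documented return value.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Tuple


def save_checkpoint(state: Dict[str, Any], filename: str) -> None:
    # Stand-in for torch.save(state, filename): serialize the whole dict.
    with open(filename, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(checkpoint_fpath: str,
                    model_state: Dict[str, Any],
                    optimizer_state: Dict[str, Any]) -> Tuple[Dict, Dict, int]:
    # Stand-in for torch.load(...) followed by model.load_state_dict(...)
    # and optimizer.load_state_dict(...): plain dicts play the role of
    # the model and optimizer here.
    with open(checkpoint_fpath, "rb") as f:
        checkpoint = pickle.load(f)
    model_state.update(checkpoint["state_dict"])
    optimizer_state.update(checkpoint["optimizer"])
    return model_state, optimizer_state, checkpoint["epoch"]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example_epoch5.th")  # hypothetical file name
        save_checkpoint({"epoch": 5, "state_dict": {"w": 1.0},
                         "optimizer": {"lr": 1e-3}}, path)
        model, optimizer, epoch = load_checkpoint(path, {}, {})
        print(epoch, model, optimizer)  # 5 {'w': 1.0} {'lr': 0.001}
```

Only load checkpoints from trusted sources: unpickling (which torch.load also relies on by default in many versions) can execute arbitrary code.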
- train_model(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)
Train model for specified number of epochs and data streams.
- Parameters
train_epochs – number of epochs to train
train_stream – training set data stream
valid_stream – validation set data stream
test_stream – test set data stream
resume_training – set to True to resume training from a saved model state, False to train a fresh model
resume_epoch – epoch whose saved model is loaded to resume training
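A hypothetical outline of how these parameters could fit together in a training loop: compute the start epoch, then for each epoch train on the training stream, save a checkpoint, and evaluate on the validation and test streams. The sketch assumes train_epochs counts epochs to run from the start epoch, passes run_epoch and save_checkpoint in as functions, and uses an illustrative checkpoint file name and a simplified checkpoint state; none of these details are confirmed by the documentation above.

```python
from typing import Any, Callable, Dict, Iterable, Optional


def train_model(train_epochs: int,
                run_epoch: Callable[..., Any],
                save_checkpoint: Callable[[Dict[str, Any], str], None],
                train_stream: Optional[Iterable] = None,
                valid_stream: Optional[Iterable] = None,
                test_stream: Optional[Iterable] = None,
                resume_training: bool = False,
                resume_epoch: int = 0) -> int:
    """Hypothetical training loop; returns the last epoch number run."""
    # Fresh training starts at 1; resuming continues after resume_epoch.
    start_epoch = resume_epoch + 1 if resume_training else 1
    last_epoch = start_epoch - 1
    # Assumption: train_epochs counts epochs to run from start_epoch.
    for epoch in range(start_epoch, start_epoch + train_epochs):
        if train_stream is not None:
            run_epoch(train_stream, epoch, training=True, stream_name="train_stream")
            # Simplified state; a real checkpoint would also hold the
            # model and optimizer state dictionaries.
            save_checkpoint({"epoch": epoch}, f"checkpoint_epoch{epoch}.th")
        if valid_stream is not None:
            run_epoch(valid_stream, epoch, training=False, stream_name="valid_stream")
        if test_stream is not None:
            run_epoch(test_stream, epoch, training=False, stream_name="test_stream")
        last_epoch = epoch
    return last_epoch
```

Keeping evaluation inside the epoch loop (rather than only after training) is one common design; it gives a per-epoch validation curve at the cost of extra passes over the held-out streams.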