TrainerIP
- class Dock2D.Models.TrainerIP.TrainerIP(dockingFFT, cur_model, cur_optimizer, cur_experiment, BF_eval=False, MC_eval=False, MC_eval_num_epochs=10, debug=False, plotting=False, sample_buffer_length=1000, tiling=False)
- __init__(dockingFFT, cur_model, cur_optimizer, cur_experiment, BF_eval=False, MC_eval=False, MC_eval_num_epochs=10, debug=False, plotting=False, sample_buffer_length=1000, tiling=False)
- Initialize trainer for IP task models, paths, and class instances. - Parameters
- dockingFFT – dockingFFT initialized to match dimensions of current sampling scheme 
- cur_model – the current docking model initialized outside the trainer 
- cur_optimizer – the optimizer initialized outside the trainer 
- cur_experiment – current experiment name 
- BF_eval – BruteForce evaluation of trained model 
- MC_eval – MonteCarlo evaluation of trained model 
- MC_eval_num_epochs – number of epochs used in MonteCarlo evaluation 
- debug – set to True to check model parameter gradients 
- plotting – set to True to create plots 
- sample_buffer_length – number of keys in the SampleBuffer; must be greater than or equal to the number of training, validation, or testing examples. 
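The constraint on sample_buffer_length can be checked before the trainer is constructed. The following is a minimal sketch, assuming the dataset sizes are known up front; `check_buffer_length` and the example sizes are hypothetical and not part of Dock2D.

```python
def check_buffer_length(sample_buffer_length, *stream_sizes):
    """Return the buffer length if it covers the largest stream, else raise.

    Hypothetical helper: it only mirrors the documented rule that the buffer
    needs at least as many keys as any single data stream has examples.
    """
    largest = max(stream_sizes)
    if sample_buffer_length < largest:
        raise ValueError(
            f"sample_buffer_length={sample_buffer_length} is smaller than "
            f"the largest stream ({largest} examples)"
        )
    return sample_buffer_length


# Hypothetical dataset sizes for the training, validation, and test streams.
train_size, valid_size, test_size = 800, 100, 100
buffer_length = check_buffer_length(1000, train_size, valid_size, test_size)
```

The value returned here could then be passed as sample_buffer_length to the constructor.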
 
 
 - load_checkpoint(checkpoint_fpath)
- Load saved checkpoint state dictionary. - Parameters
- checkpoint_fpath – path to saved model 
- Returns
- self.model, self.optimizer, checkpoint['epoch'] 
 
 - resume_training_or_not(resume_training, resume_epoch)
- Resume training the model at specified epoch or not. - Parameters
- resume_training – set to True to resume training, False to start fresh training. 
- resume_epoch – epoch number to resume from 
 
- Returns
- starting epoch number: resume_epoch+1 if resume_training is True, 1 otherwise. 
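The starting-epoch rule can be sketched as follows, assuming a resumed run continues from the epoch after the saved one and a fresh run starts at epoch 1. `starting_epoch` is a hypothetical stand-alone helper, not the trainer method itself, and it skips the checkpoint loading that resuming would also involve.

```python
def starting_epoch(resume_training, resume_epoch):
    """Return the first epoch to run (sketch of the assumed rule)."""
    if resume_training:
        # Epoch `resume_epoch` is already saved, so continue with the next one.
        return resume_epoch + 1
    # Fresh training starts counting at 1.
    return 1


fresh_start = starting_epoch(resume_training=False, resume_epoch=0)
resumed_start = starting_epoch(resume_training=True, resume_epoch=5)
```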
 
 - run_epoch(data_stream, epoch, training=False, stream_name='train_stream')
- Run the model for an epoch. - Parameters
- data_stream – input data stream 
- epoch – current epoch number 
- training – set to True for training, False for evaluation. 
- stream_name – name of the data stream 
 
 
 - run_model(data, pos_idx, training=True, stream_name='trainset', epoch=0)
- Run a model iteration on the current example. - Parameters
- data – training example 
- pos_idx – current example position index 
- training – set to True for training, False for evaluation. 
- stream_name – data stream name 
- epoch – epoch count used in plotting 
 
- Returns
- loss and rmsd 
 
 - run_trainer(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)
- Helper function to run trainer. - Parameters
- train_epochs – number of epochs to train 
- train_stream – training set data stream 
- valid_stream – validation set data stream 
- test_stream – test set data stream 
- resume_training – set to True to resume training from a loaded model state, False to train a fresh model 
- resume_epoch – epoch whose saved model is loaded to resume training 
 
 
 - save_checkpoint(state, filename)
- Save current state of the model to a checkpoint dictionary. - Parameters
- state – checkpoint state dictionary 
- filename – name of saved file 
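A save and load round trip might look like the sketch below. It uses the standard-library `pickle` module as a stand-in so that it runs without a deep learning framework; the trainer itself may serialize differently. Only the 'epoch' key comes from the documentation above; the 'state_dict' and 'optimizer' keys, the plain-dict contents, and the file name are assumptions for illustration.

```python
import os
import pickle
import tempfile


def save_checkpoint(state, filename):
    """Write the checkpoint state dictionary to disk (pickle stand-in)."""
    with open(filename, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(checkpoint_fpath):
    """Read a checkpoint and return model state, optimizer state, and epoch."""
    with open(checkpoint_fpath, "rb") as f:
        checkpoint = pickle.load(f)
    # 'state_dict' and 'optimizer' are assumed key names.
    return checkpoint["state_dict"], checkpoint["optimizer"], checkpoint["epoch"]


# Hypothetical checkpoint contents; plain dicts stand in for real model
# and optimizer states.
state = {
    "epoch": 7,
    "state_dict": {"weight": [0.1, 0.2]},
    "optimizer": {"lr": 1e-3},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint_epoch7.th")  # hypothetical file name
    save_checkpoint(state, path)
    model_state, optimizer_state, epoch = load_checkpoint(path)
```

Storing the epoch alongside the model and optimizer states is what lets a later run pick up at the following epoch.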
 
 
 - train_model(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)
- Train model for specified number of epochs and data streams. - Parameters
- train_epochs – number of epochs to train 
- train_stream – training set data stream 
- valid_stream – validation set data stream 
- test_stream – test set data stream 
- resume_training – set to True to resume training from a loaded model state, False to train a fresh model 
- resume_epoch – epoch whose saved model is loaded to resume training
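How the data streams might be consumed per epoch can be sketched as below. This is not the Dock2D implementation: it assumes every supplied stream is run once per epoch, that only the training stream is run with training=True, and that train_epochs counts epochs from the starting epoch onward. `train_model_sketch`, `record_epoch`, and the stream names other than 'train_stream' are hypothetical.

```python
calls = []


def record_epoch(data_stream, epoch, training=False, stream_name="train_stream"):
    """Stand-in for run_epoch that only records how it was called."""
    calls.append((stream_name, epoch, training))


def train_model_sketch(run_epoch, train_epochs, train_stream=None,
                       valid_stream=None, test_stream=None, start_epoch=1):
    """Run each supplied stream once per epoch (assumed schedule)."""
    for epoch in range(start_epoch, start_epoch + train_epochs):
        if train_stream is not None:
            run_epoch(train_stream, epoch, training=True,
                      stream_name="train_stream")
        if valid_stream is not None:
            run_epoch(valid_stream, epoch, training=False,
                      stream_name="valid_stream")
        if test_stream is not None:
            run_epoch(test_stream, epoch, training=False,
                      stream_name="test_stream")


# Two epochs with a training and a validation stream, no test stream.
train_model_sketch(record_epoch, train_epochs=2,
                   train_stream=[1, 2], valid_stream=[3])
```

Streams left as None are skipped, which matches the optional keyword arguments in the signature.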