TrainerIP

class Dock2D.Models.TrainerIP.TrainerIP(dockingFFT, cur_model, cur_optimizer, cur_experiment, BF_eval=False, MC_eval=False, MC_eval_num_epochs=10, debug=False, plotting=False, sample_buffer_length=1000, tiling=False)
__init__(dockingFFT, cur_model, cur_optimizer, cur_experiment, BF_eval=False, MC_eval=False, MC_eval_num_epochs=10, debug=False, plotting=False, sample_buffer_length=1000, tiling=False)

Initialize trainer for IP task models, paths, and class instances.

Parameters
  • dockingFFT – dockingFFT initialized to match dimensions of current sampling scheme

  • cur_model – the current docking model initialized outside the trainer

  • cur_optimizer – the optimizer initialized outside the trainer

  • cur_experiment – current experiment name

  • BF_eval – set to True to run BruteForce evaluation of the trained model

  • MC_eval – set to True to run MonteCarlo evaluation of the trained model

  • MC_eval_num_epochs – number of epochs used in MonteCarlo evaluation

  • debug – set to True to check model parameter gradients

  • plotting – set to True to create plots

  • sample_buffer_length – number of keys in the SampleBuffer; must be greater than or equal to the number of training, validation, or testing examples.
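The constraint on sample_buffer_length can be checked before training starts. The following is a minimal sketch that assumes each data stream reports its number of examples through len(). The helper name check_sample_buffer_length is hypothetical and not part of Dock2D.

```python
def check_sample_buffer_length(sample_buffer_length, *streams):
    """Raise ValueError if the buffer cannot hold one key per example.

    Hypothetical helper: assumes len(stream) is the number of examples.
    """
    for stream in streams:
        if stream is None:
            continue  # unused streams (e.g. no test set) are skipped
        n_examples = len(stream)
        if sample_buffer_length < n_examples:
            raise ValueError(
                f"sample_buffer_length={sample_buffer_length} is smaller than "
                f"the {n_examples} examples in a data stream"
            )


# A buffer of 1000 keys covers 800 training and 100 validation examples.
check_sample_buffer_length(1000, range(800), range(100), None)
```

The check is applied per stream, matching the requirement that the buffer be at least as large as each of the training, validation, or testing sets.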

load_checkpoint(checkpoint_fpath)

Load saved checkpoint state dictionary.

Parameters

checkpoint_fpath – path to saved model

Returns

self.model, self.optimizer, checkpoint['epoch']

resume_training_or_not(resume_training, resume_epoch)

Resume training the model from a specified epoch, or start fresh training.

Parameters
  • resume_training – set to True to resume training, False to start fresh training.

  • resume_epoch – epoch number to resume from

Returns

starting epoch number: resume_epoch+1 if resume_training is True, 1 otherwise.
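Given the parameter descriptions above (True resumes training, False starts fresh), the returned starting epoch follows the logic below. This is a sketch with the hypothetical name starting_epoch; the real method presumably also loads the checkpoint saved at resume_epoch, which is omitted here.

```python
def starting_epoch(resume_training, resume_epoch):
    """Return the first epoch number to run.

    Fresh training starts at epoch 1; resumed training continues with the
    epoch after the one whose checkpoint was loaded.
    """
    if resume_training:
        return resume_epoch + 1
    return 1


print(starting_epoch(False, 0))   # fresh run: 1
print(starting_epoch(True, 30))   # resume after epoch 30: 31
```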

run_epoch(data_stream, epoch, training=False, stream_name='train_stream')

Run the model for an epoch.

Parameters
  • data_stream – input data stream

  • epoch – current epoch number

  • training – set to True for training, False for evaluation.

  • stream_name – name of the data stream
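An epoch of this shape applies run_model to each example in the stream. The following is a minimal sketch of such a loop that averages the returned loss and rmsd. The callable passed as run_model stands in for TrainerIP.run_model, and the averaging is an assumption, not Dock2D's implementation.

```python
from statistics import mean


def run_epoch_sketch(data_stream, run_model, training=False):
    """Apply run_model to every example and average the returned metrics.

    run_model is any callable (data, pos_idx, training) -> (loss, rmsd),
    standing in for TrainerIP.run_model.
    """
    losses, rmsds = [], []
    for pos_idx, data in enumerate(data_stream):
        loss, rmsd = run_model(data, pos_idx, training=training)
        losses.append(loss)
        rmsds.append(rmsd)
    return mean(losses), mean(rmsds)


# Toy stream and model: the "loss" is the example value, the "rmsd" is twice that.
toy_stream = [1.0, 2.0, 3.0]
toy_model = lambda data, pos_idx, training=False: (data, 2 * data)
print(run_epoch_sketch(toy_stream, toy_model))  # (2.0, 4.0)
```

The position index from enumerate plays the role of pos_idx, the per-example index that run_model receives.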

run_model(data, pos_idx, training=True, stream_name='trainset', epoch=0)

Run a model iteration on the current example.

Parameters
  • data – current example from the data stream

  • pos_idx – current example position index

  • training – set to True for training, False for evaluation.

  • stream_name – data stream name

  • epoch – epoch count used in plotting

Returns

loss and rmsd
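rmsd is presumably a root-mean-square deviation between a predicted and a ground-truth docking pose. This section does not state which quantities Dock2D compares. The function below is a sketch of the standard definition over N paired 2D points, sqrt((1/N) Σ ||p_i - t_i||²). It is not the library's own computation.

```python
import math


def rmsd(predicted, target):
    """Root-mean-square deviation between two equal-length lists of 2D points."""
    if len(predicted) != len(target) or not predicted:
        raise ValueError("point lists must be non-empty and the same length")
    squared = [
        (px - tx) ** 2 + (py - ty) ** 2
        for (px, py), (tx, ty) in zip(predicted, target)
    ]
    return math.sqrt(sum(squared) / len(squared))


# Every point is shifted by (3, 4), a distance of 5, so the RMSD is 5.0.
print(rmsd([(0, 0), (1, 1)], [(3, 4), (4, 5)]))
```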

run_trainer(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)

Helper method to run the trainer.

Parameters
  • train_epochs – number of epochs to train

  • train_stream – training set data stream

  • valid_stream – validation set data stream

  • test_stream – test set data stream

  • resume_training – set to True to resume training from a loaded model state, False to train a fresh model

  • resume_epoch – epoch whose saved checkpoint is loaded to resume training

save_checkpoint(state, filename)

Save the current state of the model as a checkpoint dictionary to a file.

Parameters
  • state – checkpoint state dictionary

  • filename – name of saved file
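save_checkpoint and load_checkpoint follow a save-a-dictionary, restore-from-a-dictionary pattern. The sketch below makes several substitutions so that it runs without PyTorch or Dock2D:

  • pickle replaces torch.save/torch.load.
  • ToyStateful objects with state_dict()/load_state_dict() stand in for the model and optimizer.
  • The model and optimizer are passed explicitly rather than read from self.

The dictionary keys 'epoch', 'state_dict', and 'optimizer' are assumptions, not confirmed Dock2D keys.

```python
import os
import pickle
import tempfile


class ToyStateful:
    """Stand-in for a model or optimizer exposing state_dict()/load_state_dict()."""

    def __init__(self, value=0):
        self.value = value

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, state):
        self.value = state["value"]


def save_checkpoint(state, filename):
    # pickle stands in for torch.save(state, filename)
    with open(filename, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(checkpoint_fpath, model, optimizer):
    # pickle stands in for torch.load(checkpoint_fpath)
    with open(checkpoint_fpath, "rb") as f:
        checkpoint = pickle.load(f)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer, checkpoint["epoch"]


model, optimizer = ToyStateful(7), ToyStateful(3)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
save_checkpoint(
    {"epoch": 5, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
    path,
)
restored_model, restored_optimizer, epoch = load_checkpoint(path, ToyStateful(), ToyStateful())
print(restored_model.value, restored_optimizer.value, epoch)  # 7 3 5
```

Saving the optimizer state alongside the model state lets a resumed run continue with the same optimizer state (e.g. momentum buffers) rather than starting the optimizer from scratch.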

train_model(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)

Train the model for the specified number of epochs on the given data streams.

Parameters
  • train_epochs – number of epochs to train

  • train_stream – training set data stream

  • valid_stream – validation set data stream

  • test_stream – test set data stream

  • resume_training – set to True to resume training from a loaded model state, False to train a fresh model

  • resume_epoch – epoch whose saved checkpoint is loaded to resume training
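The control flow of train_model can be outlined as an epoch loop over the three streams. This is a sketch under these assumptions:

  • run_epoch stands in for TrainerIP.run_epoch.
  • train_epochs counts the epochs run from the starting epoch onward.
  • Validation and test passes use training=False with the stream names shown.

None of these details are confirmed by the reference above.

```python
def train_model_sketch(run_epoch, train_epochs, train_stream=None, valid_stream=None,
                       test_stream=None, resume_training=False, resume_epoch=0):
    """Outline of a train/validate/test epoch loop.

    run_epoch stands in for TrainerIP.run_epoch(data_stream, epoch, training, stream_name).
    """
    # Fresh runs start at epoch 1; resumed runs continue after resume_epoch.
    start_epoch = resume_epoch + 1 if resume_training else 1
    for epoch in range(start_epoch, start_epoch + train_epochs):
        if train_stream is not None:
            run_epoch(train_stream, epoch, training=True, stream_name="train_stream")
        # Evaluation passes request evaluation mode via training=False.
        if valid_stream is not None:
            run_epoch(valid_stream, epoch, training=False, stream_name="valid_stream")
        if test_stream is not None:
            run_epoch(test_stream, epoch, training=False, stream_name="test_stream")


# Record the calls made by a toy run_epoch to inspect the control flow.
calls = []


def record_run_epoch(data_stream, epoch, training=False, stream_name="train_stream"):
    calls.append((epoch, stream_name, training))


train_model_sketch(record_run_epoch, 2, train_stream=[0], valid_stream=[0],
                   resume_training=True, resume_epoch=4)
print(calls)
# [(5, 'train_stream', True), (5, 'valid_stream', False), (6, 'train_stream', True), (6, 'valid_stream', False)]
```

Streams passed as None are skipped. For example, omitting test_stream leaves only training and validation passes in each epoch.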