TrainerFI

class Dock2D.Models.TrainerFI.TrainerFI(docking_model, docking_optimizer, interaction_model, interaction_optimizer, experiment, training_case='scratch', path_pretrain=None, FI_MC=False, debug=False, plotting=False, sample_buffer_length=1000)
__init__(docking_model, docking_optimizer, interaction_model, interaction_optimizer, experiment, training_case='scratch', path_pretrain=None, FI_MC=False, debug=False, plotting=False, sample_buffer_length=1000)
Parameters
  • docking_model – the current docking model initialized outside the trainer

  • docking_optimizer – the docking optimizer initialized outside the trainer

  • interaction_model – the current interaction model initialized outside the trainer

  • interaction_optimizer – the interaction optimizer initialized outside the trainer

  • experiment – current experiment name

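  • training_case – docking model training case. It selects which docking weights are pretrained and which are frozen (cases A to D, see freeze_weights()). The default 'scratch' trains the docking model from scratch.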

  • path_pretrain – path to pretrained model

  • FI_MC – set to True to use Monte Carlo (MC) sampling for the FI task.

  • debug – set to True to show verbose debug output

  • plotting – set to True to create plots

  • sample_buffer_length – number of keys in the SampleBuffer; must be greater than or equal to the number of training, validation, or testing examples.
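The constraint on sample_buffer_length can be illustrated with a small sketch. It assumes the buffer holds one key per example, indexed by the example position index pos_idx that run_model() receives. The class MinimalSampleBuffer and its push/get methods are hypothetical and are not Dock2D's SampleBuffer API.

```python
class MinimalSampleBuffer:
    """Hypothetical buffer with one key per example position index (pos_idx)."""

    def __init__(self, sample_buffer_length):
        # Pre-allocate one empty slot for every key 0 .. sample_buffer_length - 1.
        self.length = sample_buffer_length
        self.slots = {key: [] for key in range(sample_buffer_length)}

    def push(self, pos_idx, sample):
        # An example whose index has no slot cannot be stored.
        if pos_idx not in self.slots:
            raise KeyError(
                f"pos_idx {pos_idx} has no slot; sample_buffer_length is {self.length}"
            )
        self.slots[pos_idx].append(sample)

    def get(self, pos_idx):
        # Return every sample stored so far for this example.
        return self.slots[pos_idx]


buffer = MinimalSampleBuffer(sample_buffer_length=3)
buffer.push(0, 0.5)
buffer.push(2, 1.25)
```

Under this assumption, a data stream of 1000 examples paired with sample_buffer_length=500 would fail as soon as pos_idx reaches 500.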

check_APR(check_epoch, datastream, stream_name=None, deltaF_logfile=None, experiment=None)

Compute accuracy, precision, recall, F1 score, and Matthews correlation coefficient (MCC).

Parameters
  • check_epoch – epoch to evaluate

  • datastream – data stream

  • stream_name – data stream name

  • deltaF_logfile – free energy log file name for evaluation set

  • experiment – current experiment name
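The five metrics reported by check_APR() follow from the confusion-matrix counts returned by classify(). The sketch below uses the standard definitions. The function name apr_metrics is hypothetical, and returning 0.0 when a denominator is zero is an assumed convention that may differ from how check_APR() handles empty classes.

```python
import math


def apr_metrics(TP, FP, TN, FN):
    """Accuracy, precision, recall, F1 score, and MCC from confusion-matrix counts."""
    total = TP + FP + TN + FN
    accuracy = (TP + TN) / total if total else 0.0
    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    # MCC uses all four counts, so it stays informative on imbalanced data.
    denom = math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    mcc = (TP * TN - FP * FN) / denom if denom else 0.0
    return accuracy, precision, recall, f1, mcc
```

For example, apr_metrics(40, 10, 45, 5) gives accuracy 0.85, precision 0.8, recall 8/9, and F1 16/19.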

static classify(pred_interact, gt_interact)

Confusion matrix values.

Parameters
  • pred_interact – predicted interaction

  • gt_interact – ground truth interaction

Returns

confusion matrix values TP, FP, TN, and FN
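A per-example version of this classification can be sketched as follows. It assumes pred_interact has already been thresholded to a 0/1 (or boolean) decision. The name confusion_counts is hypothetical, and the actual classify() may operate on tensors rather than plain Python values.

```python
def confusion_counts(pred_interact, gt_interact):
    """Return (TP, FP, TN, FN) for one binary prediction against its ground truth."""
    pred, gt = bool(pred_interact), bool(gt_interact)
    TP = int(pred and gt)          # predicted interaction, truly interacting
    FP = int(pred and not gt)      # predicted interaction, not interacting
    TN = int(not pred and not gt)  # predicted no interaction, not interacting
    FN = int(not pred and gt)      # predicted no interaction, truly interacting
    return TP, FP, TN, FN


# Summing per-example counts gives the totals used for the metrics.
preds = [1, 1, 0, 0, 1]
truths = [1, 0, 0, 1, 1]
totals = [sum(c) for c in zip(*(confusion_counts(p, g) for p, g in zip(preds, truths)))]
```

Exactly one of the four counts is 1 for each example, so the totals always add up to the number of examples.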

freeze_weights()

Freeze model weights depending on the experiment training case. The cases are: A) frozen pretrained docking model; B) unfrozen pretrained docking model; C) unfrozen pretrained docking model scoring coefficients, but frozen conv net; D) docking model trained from scratch.
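The four training cases can be summarized as a lookup of what is pretrained and what receives gradients. This is a sketch only: the labels "A", "B", and "C" are assumed (only 'scratch' appears in the constructor signature), and DockingCase and docking_case are hypothetical names.

```python
from typing import NamedTuple


class DockingCase(NamedTuple):
    load_pretrained: bool       # start from a pretrained docking model
    train_conv_net: bool        # convolutional feature network receives gradients
    train_scoring_coeffs: bool  # scoring coefficients receive gradients


# Case labels "A", "B", "C" are assumed; only "scratch" appears in the signature.
DOCKING_CASES = {
    "A": DockingCase(load_pretrained=True, train_conv_net=False, train_scoring_coeffs=False),
    "B": DockingCase(load_pretrained=True, train_conv_net=True, train_scoring_coeffs=True),
    "C": DockingCase(load_pretrained=True, train_conv_net=False, train_scoring_coeffs=True),
    "scratch": DockingCase(load_pretrained=False, train_conv_net=True, train_scoring_coeffs=True),
}


def docking_case(training_case):
    """Look up what is pretrained and trainable for a training case."""
    try:
        return DOCKING_CASES[training_case]
    except KeyError:
        raise ValueError(f"unknown training case: {training_case!r}") from None
```

In PyTorch, a part marked as not trainable would typically be frozen by setting requires_grad = False on each of its parameters, for example `for p in module.parameters(): p.requires_grad = False`.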

static load_checkpoint(checkpoint_fpath, model, optimizer, FI_MC=False)

Load saved checkpoint state dictionary.

Parameters
  • checkpoint_fpath – path to saved model

  • model – model to load, either the docking or the interaction model

  • optimizer – model optimizer

  • FI_MC – set to True to additionally return the alpha and free energy buffers saved in the state dict

Returns

model, optimizer, checkpoint['epoch'], and, if FI_MC==True, additionally checkpoint['alpha_buffer'] and checkpoint['free_energy_buffer']
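A checkpoint round trip with the optional FI_MC buffers can be sketched as below. The keys 'epoch', 'alpha_buffer', and 'free_energy_buffer' come from the documentation above; the keys 'state_dict' and 'optimizer' are assumptions. The sketch uses the standard-library pickle module in place of torch.save/torch.load so that it runs without PyTorch, and it returns the raw state dicts instead of calling model.load_state_dict() and optimizer.load_state_dict().

```python
import pickle


def save_checkpoint_sketch(state, filename):
    # torch.save also serializes with pickle; plain pickle keeps this dependency-free.
    with open(filename, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint_sketch(checkpoint_fpath, FI_MC=False):
    with open(checkpoint_fpath, "rb") as f:
        checkpoint = pickle.load(f)
    # Assumed keys; a real trainer would pass these to load_state_dict().
    model_state = checkpoint["state_dict"]
    optimizer_state = checkpoint["optimizer"]
    if FI_MC:
        # Monte Carlo runs also restore the saved buffers.
        return (model_state, optimizer_state, checkpoint["epoch"],
                checkpoint["alpha_buffer"], checkpoint["free_energy_buffer"])
    return model_state, optimizer_state, checkpoint["epoch"]
```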

resume_training_or_not(resume_training, resume_epoch)

Resume training the model at specified epoch or not.

Parameters
  • resume_training – set to True to resume training, False to start fresh training.

  • resume_epoch – epoch number to resume from

Returns

starting epoch number: resume_epoch+1 if resume_training is True, 1 otherwise.
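The starting-epoch rule can be written as a small free function. This sketch only computes the epoch number; any checkpoint loading the method performs when resuming is omitted, and the name starting_epoch is hypothetical.

```python
def starting_epoch(resume_training, resume_epoch):
    """Return the epoch number training should start from."""
    if resume_training:
        # Continue with the epoch after the one saved in the checkpoint.
        return resume_epoch + 1
    # A fresh run starts counting epochs at 1.
    return 1
```

For example, resuming from a checkpoint saved at epoch 10 starts at epoch 11, while a fresh run ignores resume_epoch.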

run_epoch(data_stream, epoch, training=False, stream_name='train_stream')

Run the model for an epoch.

Parameters
  • data_stream – input data stream

  • epoch – current epoch number

  • training – set to True for training, False for evaluation.

  • stream_name – data stream name
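One epoch can be sketched as a loop that calls run_model() once per example and averages the loss. It assumes that pos_idx is the example's position in the stream and, as documented for run_model(), that the loss is the first returned value. The function run_epoch_sketch is hypothetical and takes run_model as a callable so it can be exercised without a trainer.

```python
def run_epoch_sketch(data_stream, run_model, epoch, training=False,
                     stream_name="train_stream"):
    """Run run_model over every example and return the mean loss."""
    losses = []
    for pos_idx, data in enumerate(data_stream):
        outputs = run_model(data, pos_idx, stream_name, training=training, epoch=epoch)
        losses.append(outputs[0])  # loss is the first returned value
    return sum(losses) / len(losses) if losses else 0.0
```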

run_model(data, pos_idx, stream_name, training=True, epoch=0)

Run a model iteration on the current example.

Parameters
  • data – training example

  • pos_idx – current example position index

  • stream_name – data stream name

  • training – set to True for training, False for evaluation.

  • epoch – epoch count used in plotting

Returns

During training: loss, F, F_0, and gt_interact. During evaluation: TP, FP, TN, and FN in addition to those values.

run_trainer(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)

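Helper function to run the trainer. It takes the same parameters as train_model(), but the data streams are optional (default None).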

Parameters
  • train_epochs – number of epochs to train

  • train_stream – training set data stream

  • valid_stream – validation set data stream

  • test_stream – test set data stream

  • resume_training – set to True to resume training from a saved model state, False to train a fresh model

  • resume_epoch – epoch whose saved model state is loaded to resume training

static save_checkpoint(state, filename, model)

Save current state of the model to a checkpoint dictionary.

Parameters
  • state – checkpoint state dictionary

  • filename – name of saved file

  • model – model to save, either the docking or the interaction model

set_docking_model_state()

Initialize the docking model training case: A) frozen pretrained docking model; B) unfrozen pretrained docking model; C) frozen conv net with unfrozen pretrained docking model scoring coefficients; D) docking model trained from scratch.

train_model(train_epochs, train_stream, valid_stream, test_stream, resume_training=False, resume_epoch=0)

Train model for specified number of epochs and data streams.

Parameters
  • train_epochs – number of epochs to train

  • train_stream – training set data stream

  • valid_stream – validation set data stream

  • test_stream – test set data stream

  • resume_training – set to True to resume training from a saved model state, False to train a fresh model

  • resume_epoch – epoch whose saved model state is loaded to resume training
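The overall training loop can be sketched as follows: train on each epoch, then evaluate on the validation and test streams when they are given. This sketch assumes epochs are numbered from a start epoch (as computed by resume_training_or_not()) through train_epochs. The stream names "valid_stream" and "test_stream" are hypothetical (only "train_stream" appears in the signatures), and checkpoint saving and plotting are omitted.

```python
def train_model_sketch(train_epochs, run_epoch, train_stream=None,
                       valid_stream=None, test_stream=None, start_epoch=1):
    """Train for each epoch and evaluate on the optional streams; return a history."""
    history = []
    for epoch in range(start_epoch, train_epochs + 1):
        record = {"epoch": epoch,
                  "train": run_epoch(train_stream, epoch, training=True,
                                     stream_name="train_stream")}
        # Evaluation passes never update the model (training=False).
        if valid_stream is not None:
            record["valid"] = run_epoch(valid_stream, epoch, training=False,
                                        stream_name="valid_stream")
        if test_stream is not None:
            record["test"] = run_epoch(test_stream, epoch, training=False,
                                       stream_name="test_stream")
        history.append(record)
    return history
```

Passing run_epoch as a callable keeps the loop independent of any particular model, which makes the control flow easy to check in isolation.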