TrainerFI
- class Dock2D.Models.TrainerFI.TrainerFI(docking_model, docking_optimizer, interaction_model, interaction_optimizer, experiment, training_case='scratch', path_pretrain=None, FI_MC=False, debug=False, plotting=False, sample_buffer_length=1000)
- __init__(docking_model, docking_optimizer, interaction_model, interaction_optimizer, experiment, training_case='scratch', path_pretrain=None, FI_MC=False, debug=False, plotting=False, sample_buffer_length=1000)
- Parameters
docking_model – the current docking model initialized outside the trainer
docking_optimizer – the docking optimizer initialized outside the trainer
interaction_model – the current interaction model initialized outside the trainer
interaction_optimizer – the interaction optimizer initialized outside the trainer
experiment – current experiment name
training_case – docking model training case (see freeze_weights()); defaults to 'scratch'
path_pretrain – path to pretrained model
FI_MC – set to True to use Monte Carlo (MC) sampling for the fact-of-interaction (FI) task.
debug – set to True to show verbose debug output.
plotting – set to True to create plots.
sample_buffer_length – number of keys in the SampleBuffer; must be greater than or equal to the number of training, validation, or testing examples.
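The sample_buffer_length constraint follows from the buffer being keyed by example position. The sketch below uses a hypothetical SimpleSampleBuffer class (a stand-in, not the Dock2D SampleBuffer) that preallocates one slot per key, so an example whose position index is at or beyond the buffer length has nowhere to be stored.

```python
class SimpleSampleBuffer:
    """Hypothetical stand-in for a SampleBuffer: one slot per example index."""

    def __init__(self, num_keys):
        # Preallocate one (initially empty) list of samples per key.
        self.buffer = {key: [] for key in range(num_keys)}

    def push(self, pos_idx, sample):
        # Examples are addressed by their position index in the data stream.
        if pos_idx not in self.buffer:
            raise KeyError(f"pos_idx {pos_idx} exceeds buffer length {len(self.buffer)}")
        self.buffer[pos_idx].append(sample)

    def get(self, pos_idx):
        return self.buffer[pos_idx]


buffer = SimpleSampleBuffer(num_keys=3)
buffer.push(0, "sample for example 0")
buffer.push(2, "sample for example 2")
try:
    buffer.push(3, "sample for example 3")  # only keys 0, 1, 2 exist
except KeyError as err:
    print(err)
```

With three keys, a data stream of four or more examples would fail on the fourth example, which is why the buffer must be at least as long as the largest stream it serves.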
- check_APR(check_epoch, datastream, stream_name=None, deltaF_logfile=None, experiment=None)
Compute the accuracy, precision, recall, F1 score, and Matthews correlation coefficient (MCC).
- Parameters
check_epoch – epoch to evaluate
datastream – data stream
stream_name – data stream name
deltaF_logfile – free energy log file name for evaluation set
experiment – current experiment name
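All five metrics reported by check_APR can be derived from the confusion matrix counts. The following is a minimal sketch using the standard definitions; the helper name apr_metrics and the convention of returning 0.0 when a denominator is zero are assumptions, not necessarily what TrainerFI does.

```python
import math


def apr_metrics(TP, FP, TN, FN):
    """Accuracy, precision, recall, F1 score, and MCC from confusion matrix counts."""
    total = TP + FP + TN + FN
    accuracy = (TP + TN) / total if total else 0.0
    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    # Matthews correlation coefficient; set to 0 here when any marginal sum is 0.
    denom = math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    mcc = (TP * TN - FP * FN) / denom if denom else 0.0
    return accuracy, precision, recall, f1, mcc


accuracy, precision, recall, f1, mcc = apr_metrics(TP=40, FP=10, TN=45, FN=5)
print(f"acc={accuracy:.3f} prec={precision:.3f} rec={recall:.3f} f1={f1:.3f} mcc={mcc:.3f}")
```

Unlike accuracy, MCC uses all four counts symmetrically, which makes it more informative when interacting and non-interacting examples are imbalanced.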
- static classify(pred_interact, gt_interact)
Compute confusion matrix values from the predicted and ground-truth interactions.
- Parameters
pred_interact – predicted interaction
gt_interact – ground truth interaction
- Returns
confusion matrix values TP, FP, TN, and FN
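A confusion matrix tally for binary interaction labels can be sketched as below. This assumes predictions have already been thresholded to 0/1 and are passed as sequences; the actual static method may instead operate on a single example or on tensors. The return order matches the documented TP, FP, TN, FN.

```python
def classify(pred_interact, gt_interact):
    """Return (TP, FP, TN, FN) for binary predicted vs. ground-truth interaction labels."""
    TP = FP = TN = FN = 0
    for pred, gt in zip(pred_interact, gt_interact):
        if pred and gt:
            TP += 1          # predicted interaction, truly interacting
        elif pred and not gt:
            FP += 1          # predicted interaction, truly non-interacting
        elif not pred and not gt:
            TN += 1          # predicted no interaction, truly non-interacting
        else:
            FN += 1          # predicted no interaction, truly interacting
    return TP, FP, TN, FN


pred = [1, 0, 1, 1, 0]
gt = [1, 0, 0, 1, 1]
counts = classify(pred, gt)
print(counts)
```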
- freeze_weights()
Freeze model weights depending on the experiment training case. The cases are: A) frozen pretrained docking model, B) unfrozen pretrained docking model, C) unfrozen pretrained docking model scoring coefficients with a frozen conv net, and D) docking model trained from scratch.
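In PyTorch, freezing is usually done by toggling requires_grad on parameter groups. The sketch below assumes a hypothetical ToyDockingModel with a conv_net submodule and a scoring_coeffs parameter, and uses the hypothetical case labels "A" to "D"; the real docking model's attribute names and the training_case strings accepted by TrainerFI may differ.

```python
import torch
import torch.nn as nn


class ToyDockingModel(nn.Module):
    """Hypothetical docking model: a conv net plus learnable scoring coefficients."""

    def __init__(self):
        super().__init__()
        self.conv_net = nn.Conv2d(1, 4, kernel_size=3)
        self.scoring_coeffs = nn.Parameter(torch.ones(4))


def freeze_weights(model, case):
    """Toggle requires_grad by training case; the labels "A"-"D" are hypothetical."""
    if case == "A":    # frozen pretrained docking model
        conv_trainable, coeffs_trainable = False, False
    elif case == "C":  # frozen conv net, trainable scoring coefficients
        conv_trainable, coeffs_trainable = False, True
    else:              # "B" (unfrozen pretrained) or "D" (from scratch): train everything
        conv_trainable, coeffs_trainable = True, True
    for param in model.conv_net.parameters():
        param.requires_grad = conv_trainable
    model.scoring_coeffs.requires_grad = coeffs_trainable


model = ToyDockingModel()
freeze_weights(model, "C")
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
```

Frozen parameters receive no gradient, so a standard torch.optim optimizer leaves them unchanged even if they were passed to it.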
- static load_checkpoint(checkpoint_fpath, model, optimizer, FI_MC=False)
Load saved checkpoint state dictionary.
- Parameters
checkpoint_fpath – path to saved model
model – model to load, either docking or interaction models
optimizer – model optimizer
FI_MC – set to True to also return the additional values saved in the state dict
- Returns
model, optimizer, checkpoint['epoch'], and, if FI_MC is True, also checkpoint['alpha_buffer'] and checkpoint['free_energy_buffer']
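A checkpoint loader with this return behavior can be sketched with torch.load. Only the 'epoch', 'alpha_buffer', and 'free_energy_buffer' keys come from the documentation above; the 'state_dict' and 'optimizer' key names, the file name, and storing the buffers as tensors are assumptions made for this example.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(checkpoint_fpath, model, optimizer, FI_MC=False):
    """Restore model and optimizer state from a saved checkpoint dictionary."""
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint["state_dict"])        # assumed key name
    optimizer.load_state_dict(checkpoint["optimizer"])     # assumed key name
    if FI_MC:
        # Monte Carlo FI runs also restore the saved buffers.
        return (model, optimizer, checkpoint["epoch"],
                checkpoint["alpha_buffer"], checkpoint["free_energy_buffer"])
    return model, optimizer, checkpoint["epoch"]


# Usage: save a toy checkpoint, then load it back.
model = nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.th")
torch.save({"epoch": 5,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "alpha_buffer": torch.zeros(3),
            "free_energy_buffer": torch.zeros(3)}, path)

model, optimizer, epoch, alpha, free_energy = load_checkpoint(path, model, optimizer, FI_MC=True)
print(epoch, alpha.shape)
```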
- resume_training_or_not(resume_training, resume_epoch)
Determine whether to resume training from a specified epoch or start fresh.
- Parameters
resume_training – set to True to resume training, False to start fresh training.
resume_epoch – epoch number to resume from
- Returns
starting epoch number: resume_epoch+1 if resume_training is True, 1 otherwise.
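Resuming from resume_epoch means that epoch has already been completed, so the next epoch to run is resume_epoch + 1, while a fresh run starts at epoch 1. A minimal sketch of that logic, with the checkpoint loading step left out:

```python
def resume_training_or_not(resume_training, resume_epoch):
    """Return the epoch to start from (loading the saved state is omitted here)."""
    if resume_training:
        # A full trainer would load the checkpoint saved at resume_epoch here.
        return resume_epoch + 1
    return 1


start_epoch = resume_training_or_not(resume_training=True, resume_epoch=10)
print(start_epoch)
```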
- run_epoch(data_stream, epoch, training=False, stream_name='train_stream')
Run the model for an epoch.
- Parameters
data_stream – input data stream
epoch – current epoch number
training – set to True for training, False for evaluation.
stream_name – data stream name
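One epoch amounts to one pass over the data stream, calling the per-example routine once per example with its position index. The sketch below takes a hypothetical run_model callable (a stand-in for TrainerFI.run_model) and assumes its first return value is the loss; aggregating by the mean is likewise an assumption.

```python
def run_epoch(data_stream, epoch, run_model, training=False, stream_name="train_stream"):
    """Average the per-example loss over one pass through data_stream."""
    losses = []
    for pos_idx, data in enumerate(data_stream):
        # Same argument order as the documented run_model(data, pos_idx, stream_name, ...).
        outputs = run_model(data, pos_idx, stream_name, training=training, epoch=epoch)
        losses.append(outputs[0])  # loss assumed to be the first output
    return sum(losses) / len(losses) if losses else 0.0


def fake_run_model(data, pos_idx, stream_name, training=True, epoch=0):
    # Toy loss: the example's own value; the other outputs are placeholders.
    return float(data), None, None, None


mean_loss = run_epoch([1.0, 2.0, 3.0], epoch=1, run_model=fake_run_model, training=True)
print(mean_loss)
```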
- run_model(data, pos_idx, stream_name, training=True, epoch=0)
Run a model iteration on the current example.
- Parameters
data – training example
pos_idx – current example position index
training – set to True for training, False for evaluation.
stream_name – data stream name
epoch – epoch count used in plotting
- Returns
loss, F, F_0, and gt_interact; during evaluation, TP, FP, TN, and FN are returned in addition to these values.
- run_trainer(train_epochs, train_stream=None, valid_stream=None, test_stream=None, resume_training=False, resume_epoch=0)
Helper function to run the trainer.
- Parameters
train_epochs – number of epochs to train
train_stream – training set data stream
valid_stream – validation set data stream
test_stream – test set data stream
resume_training – set to True to resume training from a loaded model state, False to train a fresh model
resume_epoch – epoch of the saved model to load and resume training from
- static save_checkpoint(state, filename, model)
Save current state of the model to a checkpoint dictionary.
- Parameters
state – checkpoint state dictionary
filename – name of saved file
model – model to save, either docking or interaction models
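Saving a checkpoint dictionary is typically a single torch.save call. In this sketch the model argument is used only to log which model type is being saved; that use, the dictionary keys, and the file name are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(state, filename, model):
    """Write the checkpoint dictionary to disk."""
    # Log which model (e.g. docking or interaction) the checkpoint belongs to.
    print(f"Saving {type(model).__name__} checkpoint to {filename}")
    torch.save(state, filename)


model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
state = {"epoch": 3,
         "state_dict": model.state_dict(),      # assumed key name
         "optimizer": optimizer.state_dict()}   # assumed key name
path = os.path.join(tempfile.mkdtemp(), "toy_epoch3.th")
save_checkpoint(state, path, model)
```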
- set_docking_model_state()
Initialize the docking model training case: A) frozen pretrained docking model, B) unfrozen pretrained docking model, C) frozen conv net with unfrozen pretrained docking model scoring coefficients, and D) train the docking model from scratch.
- train_model(train_epochs, train_stream, valid_stream, test_stream, resume_training=False, resume_epoch=0)
Train the model for the specified number of epochs on the given data streams.
- Parameters
train_epochs – number of epochs to train
train_stream – training set data stream
valid_stream – validation set data stream
test_stream – test set data stream
resume_training – set to True to resume training from a loaded model state, False to train a fresh model
resume_epoch – epoch of the saved model to load and resume training from
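The overall training loop can be outlined as follows. This is a hypothetical sketch, not the TrainerFI implementation: run_epoch is passed in as a stand-in callable, train_epochs is treated as the last epoch number (so a resumed run continues up to it), the "valid_stream" and "test_stream" names are assumptions, and checkpointing and check_APR calls are omitted.

```python
def train_model(train_epochs, train_stream, valid_stream, test_stream,
                run_epoch, start_epoch=1, eval_every=1):
    """Train for epochs start_epoch..train_epochs, evaluating every eval_every epochs."""
    history = []
    for epoch in range(start_epoch, train_epochs + 1):
        train_loss = run_epoch(train_stream, epoch, training=True, stream_name="train_stream")
        record = {"epoch": epoch, "train_loss": train_loss}
        if epoch % eval_every == 0:
            # Evaluation passes run with training=False, so no weights are updated.
            if valid_stream is not None:
                record["valid_loss"] = run_epoch(valid_stream, epoch, training=False,
                                                 stream_name="valid_stream")
            if test_stream is not None:
                record["test_loss"] = run_epoch(test_stream, epoch, training=False,
                                                stream_name="test_stream")
        history.append(record)
    return history


def fake_run_epoch(stream, epoch, training=False, stream_name="train_stream"):
    # Toy loss that shrinks with the epoch number.
    return sum(stream) / len(stream) / epoch


history = train_model(3, [3.0, 3.0], [6.0], None, run_epoch=fake_run_epoch, start_epoch=2)
for record in history:
    print(record)
```

Passing start_epoch from a resume step lets the same loop serve both fresh and resumed runs.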