"""
train_bruteforce_FI.py

Working example of the script used to train the brute-force (BF) fact-of-interaction (FI) model.
"""
from Dock2D.Models.TrainerFI import *
import random
from Dock2D.Utility.TorchDataLoader import get_interaction_stream
from torch import optim
from Dock2D.Utility.PlotterFI import PlotterFI
from Dock2D.Models.model_interaction import Interaction
from Dock2D.Models.model_sampling import SamplingModel
from Dock2D.Utility.TorchDockingFFT import TorchDockingFFT
if __name__ == '__main__':
    # Train, then evaluate, the brute-force fact-of-interaction (FI) model and plot the results.
    # NOTE: every path below (datasets, pretrained model) is relative, so it is resolved
    # against the current working directory, not against this file's location.
    #################################################################################
    ## Datasets
    trainset = '../../Datasets/interaction_train_400pool.pkl'
    validset = '../../Datasets/interaction_valid_400pool.pkl'
    ### testing set
    testset = '../../Datasets/interaction_test_400pool.pkl'
    #########################
    #### Device selection and random-seed initialization (for reproducibility)
    device = 0  # CUDA device index used for both models
    torch.cuda.set_device(device)
    random_seed = 42
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    # Use deterministic cuDNN kernels and disable the cuDNN autotuner, which could
    # otherwise select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.autograd.set_detect_anomaly(True)
    #########################
    ## number_of_pairs provides max_size of interactions: max_size = (number_of_pairs**2 + number_of_pairs)/2
    number_of_pairs = 100
    train_stream = get_interaction_stream(trainset, number_of_pairs=number_of_pairs)
    valid_stream = get_interaction_stream(validset, number_of_pairs=number_of_pairs)
    test_stream = get_interaction_stream(testset, number_of_pairs=number_of_pairs)
    ######################
    experiment = 'BF_FI_check_consolidated'
    ##################### Load and freeze/unfreeze params (training, no eval)
    ### path to pretrained docking model (presumably not loaded when training_case == 'scratch' -- confirm in TrainerFI)
    # path_pretrain = 'Log/RECODE_CHECK_BFDOCKING_30epochsend.th'
    path_pretrain = 'Log/FINAL_CHECK_DOCKING30.th'
    # training_case = 'A' # CaseA: train with docking model frozen
    # training_case = 'B' # CaseB: train with docking model unfrozen
    # training_case = 'C' # CaseC: train with docking model SE2 CNN frozen and scoring ("a") coeffs unfrozen
    training_case = 'scratch'  # Case scratch: train everything from scratch
    # The training case is prefixed so runs of different cases get distinct experiment names.
    experiment = training_case + '_' + experiment
    #####################
    train_epochs = 20
    lr_interaction = 10 ** -1
    lr_docking = 10 ** -4
    sample_steps = 10
    show = True  # display the loss and deltaF plots produced at the end of the run
    # Model construction order is kept fixed so seeded parameter initialization is reproducible.
    interaction_model = Interaction().to(device=device)
    interaction_optimizer = optim.Adam(interaction_model.parameters(), lr=lr_interaction)
    padded_dim = 100
    num_angles = 360
    dockingFFT = TorchDockingFFT(padded_dim=padded_dim, num_angles=num_angles)
    docking_model = SamplingModel(dockingFFT, sample_steps=sample_steps, FI_BF=True).to(device=device)
    docking_optimizer = optim.Adam(docking_model.parameters(), lr=lr_docking)
    trainer = TrainerFI(docking_model, docking_optimizer, interaction_model, interaction_optimizer, experiment,
                        training_case, path_pretrain,
                        FI_MC=False)
    ######################
    ### Train model from the beginning (training stream only)
    trainer.run_trainer(train_epochs, train_stream=train_stream, valid_stream=None, test_stream=None)
    ## Resume training model at chosen epoch
    # trainer.run_trainer(resume_training=True, resume_epoch=14, train_epochs=6, train_stream=train_stream, valid_stream=None, test_stream=None)
    ### Validate/test the model at the last training epoch (resume_epoch=train_epochs; no training stream)
    trainer.run_trainer(train_epochs=1, train_stream=None, valid_stream=valid_stream, test_stream=test_stream,
                        resume_training=True, resume_epoch=train_epochs)
    ### Plot loss and free energy distributions with learned F_0 decision threshold
    PlotterFI(experiment).plot_loss(show=show)
    PlotterFI(experiment).plot_deltaF_distribution(plot_epoch=train_epochs, show=show)