"""
train_bruteforce_IP.py

Working example script that trains the brute-force (BF) Interaction Pose (IP) model
and then plots the experiment's loss and RMSD distribution.
"""

# Wildcard import kept first so the explicit imports below take precedence over
# anything it re-exports (same shadowing order as before); TrainerIP comes from here.
from Dock2D.Models.TrainerIP import *

import random

import numpy as np
import torch
from torch import optim

from Dock2D.Models.model_sampling import SamplingModel
from Dock2D.Utility.PlotterIP import PlotterIP
from Dock2D.Utility.TorchDataLoader import get_docking_stream
from Dock2D.Utility.TorchDockingFFT import TorchDockingFFT

if __name__ == '__main__':
    # Script entry point: seed RNGs, build the train/valid/test docking streams,
    # train the brute-force IP model with TrainerIP, then plot loss and RMSDs.
    #################################################################################
    # Datasets
    # NOTE(review): these paths are relative to the current working directory, so the
    # script must be launched from the directory they assume.
    trainset = '../../Datasets/docking_train_400pool.pkl'
    validset = '../../Datasets/docking_valid_400pool.pkl'
    ### testing set
    testset = '../../Datasets/docking_test_400pool.pkl'
    #########################
    #### device selection and initialization of random seeds
    # Single source of truth for the GPU index; used by set_device and model.to() below.
    device = 0
    # Select the device before seeding so torch.cuda.manual_seed seeds the chosen GPU.
    torch.cuda.set_device(device)
    random_seed = 42
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    # Reproducibility: deterministic cuDNN kernels and no benchmark-driven algorithm choice.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.autograd.set_detect_anomaly(True)
    ######################
    max_size = 1000
    train_stream = get_docking_stream(trainset, max_size=max_size)
    valid_stream = get_docking_stream(validset, max_size=max_size)
    test_stream = get_docking_stream(testset, max_size=max_size)
    ######################
    train_epochs = 10
    learning_rate = 1e-4
    # The experiment name embeds the epoch count; derive it so the two cannot drift apart.
    experiment = f'BF_check_code_consolidated_{train_epochs}ep'
    ######################
    padded_dim = 100
    num_angles = 360
    BFdockingFFT = TorchDockingFFT(padded_dim=padded_dim, num_angles=num_angles)
    model = SamplingModel(BFdockingFFT, IP=True).to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    Trainer = TrainerIP(BFdockingFFT, model, optimizer, experiment, BF_eval=True)
    ######################
    ### Train model from beginning, evaluate if valid_stream and/or test_stream passed in
    Trainer.run_trainer(train_epochs=train_epochs, train_stream=train_stream, valid_stream=valid_stream, test_stream=test_stream)

    ### Resume training model at chosen epoch
    # Trainer.run_trainer(train_stream=None, valid_stream=valid_stream, test_stream=test_stream,
    #                     resume_training=True, resume_epoch=13, train_epochs=17)

    # ### Evaluate model on chosen dataset only and plot at chosen epoch and dataset frequency
    # Trainer.run_trainer(train_stream=None, valid_stream=valid_stream, test_stream=test_stream,
    #                     resume_training=True, resume_epoch=15, train_epochs=1)

    ## Plot loss and RMSDs from current experiment
    plotter = PlotterIP(experiment)
    plotter.plot_loss(show=True)
    # NOTE(review): assumes epochs are numbered 1..train_epochs, so train_epochs is the
    # last saved epoch when training from scratch — confirm against TrainerIP.
    plotter.plot_rmsd_distribution(plot_epoch=train_epochs, show=True)