"""
train_brutesimplified_IP.py

Functioning example of the script used to train the BruteSimplified model for Interaction Pose (IP).
"""

from Dock2D.Models.TrainerIP import *
import random
from Dock2D.Utility.TorchDataLoader import get_docking_stream
from torch import optim
from Dock2D.Utility.PlotterIP import PlotterIP
from Dock2D.Utility.TorchDockingFFT import TorchDockingFFT
from Dock2D.Models.model_sampling import SamplingModel

if __name__ == '__main__':
    # Driver script: seeds the RNGs, builds the docking data streams, sets up
    # the sampling model and its trainer, then runs a brute-force evaluation
    # pass on the validation/test streams and plots the logged results.
    #
    # NOTE(review): `np` and `torch` are not imported by name in this file;
    # presumably they come in through the star import from
    # Dock2D.Models.TrainerIP -- confirm.

    # ------------------------------------------------------------- datasets
    # Relative paths to the train / validation / test dataset files.
    trainset = '../../Datasets/docking_train_400pool.pkl'
    validset = '../../Datasets/docking_valid_400pool.pkl'
    testset = '../../Datasets/docking_test_400pool.pkl'

    # --------------------------------------------------------- random seeds
    # Apply one shared seed to NumPy, torch (CPU), Python's `random` and
    # torch (CUDA), in that order, before anything else consumes randomness.
    random_seed = 42
    for seed_rng in (np.random.seed, torch.manual_seed, random.seed, torch.cuda.manual_seed):
        seed_rng(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.cuda.set_device(0)
    # torch.autograd.set_detect_anomaly(True)

    # --------------------------------------------------------- data streams
    # `max_size` is forwarded to every stream; presumably it caps the number
    # of examples loaded per dataset -- confirm in get_docking_stream.
    max_size = 1000
    train_stream, valid_stream, test_stream = [
        get_docking_stream(dataset_path, max_size=max_size)
        for dataset_path in (trainset, validset, testset)
    ]
    # Length of the longest of the three streams; handed to the evaluation trainer below.
    sample_buffer_length = max(
        len(stream) for stream in (train_stream, valid_stream, test_stream)
    )

    # ------------------------------------------------- experiment settings
    experiment = 'BS_check_code_consolidated_10ep'
    train_epochs = 10
    lr = 10 ** -2
    plotting = False
    show = True

    # ------------------------------------------------ training-time objects
    padded_dim = 100
    num_angles = 1
    sampledFFT = TorchDockingFFT(padded_dim=padded_dim, num_angles=num_angles)
    model = SamplingModel(sampledFFT, IP=True)
    model = model.to(device=0)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    Trainer = TrainerIP(sampledFFT, model, optimizer, experiment)

    # ------------------------------------------------------------ run modes
    # Uncomment exactly the mode you want to run.

    ### Train model from beginning
    # Trainer.run_trainer(train_epochs, train_stream=train_stream)

    ### Resume training model at chosen epoch
    # Trainer.run_trainer(
    #     train_epochs=1, train_stream=train_stream, valid_stream=None, test_stream=None,
    #     resume_training=True, resume_epoch=train_epochs)

    ### Resume training for validation sets
    # Trainer.run_trainer(
    #     train_epochs=1, train_stream=None, valid_stream=valid_stream, #test_stream=valid_stream,
    #     resume_training=True, resume_epoch=train_epochs)

    # ------------------------------------ brute force evaluation and plotting
    # A second FFT/model/trainer trio is built with 360 angles (vs. 1 above)
    # and with BF_eval=True.
    eval_angles = 360
    first_epoch = train_epochs - 1
    end_epoch = train_epochs
    evalFFT = TorchDockingFFT(padded_dim=padded_dim, num_angles=eval_angles)
    eval_model = SamplingModel(evalFFT, IP=True).to(device=0)
    # NOTE(review): `optimizer` was built from `model.parameters()`, not from
    # `eval_model.parameters()` -- confirm this is intended for an
    # evaluation-only run.
    EvalTrainer = TrainerIP(
        evalFFT, eval_model, optimizer, experiment,
        BF_eval=True, plotting=plotting, sample_buffer_length=sample_buffer_length,
    )
    for epoch in range(first_epoch, end_epoch):
        # Only the final epoch of the range is evaluated; with the values
        # above the range holds exactly that one epoch.
        if epoch != end_epoch - 1:
            continue
        # NOTE(review): rebinding `plotting` here cannot affect EvalTrainer,
        # which already received the value at construction, and nothing
        # later in this script reads it -- confirm whether it can be dropped.
        plotting = False
        EvalTrainer.run_trainer(
            train_epochs=1,
            train_stream=None,
            valid_stream=valid_stream,
            test_stream=test_stream,
            resume_training=True,
            resume_epoch=epoch,
        )

    # ------------------------ plot loss and RMSDs from the current experiment
    # A fresh PlotterIP is created for each plot, as in the original flow.
    loss_plotter = PlotterIP(experiment)
    loss_plotter.plot_loss(ylim=None)
    rmsd_plotter = PlotterIP(experiment)
    rmsd_plotter.plot_rmsd_distribution(plot_epoch=train_epochs, show=show)