import os
from functools import partial

import esm
import numpy as np
import torch
from torch.utils.data import DataLoader

from PocketGen.models.PD import Pocket_Design_new
from PocketGen.utils.data import collate_mols_block
from PocketGen.utils.misc import seed_all, load_config
from PocketGen.utils.transforms import FeaturizeProteinAtom, FeaturizeLigandAtom

from .sampler import interaction
from eval.chemutils import kd
from eval.docking import docking
from eval.mutations import mutations
from eval.prepare import prepare
from eval.window import compute_box


class Model:
    """PocketGen-based mutant pocket generator.

    Pipeline: ``input()`` loads a receptor/ligand pair and builds a dataloader,
    ``generate()`` runs PocketGen inference and writes one folder of structures
    per batch, ``results()`` docks and scores every generated structure and
    writes a TSV summary. All public methods return ``self`` for chainability.
    """

    def __init__(self, checkpoint_path: str, args):
        """
        The mutant generation model constructor. This method does the setup of
        torch and CUDA environment, loads the checkpoint and then builds a
        PocketGen instance using the weights from the checkpoint and the
        parameters retrieved.

        @param checkpoint_path (str): Path to checkpoint (.pt) file for PocketGen.
        @param args (dict): run options; keys read here:
            "verbose" (int): 0 for quiet, 1 for necessary information, 2 for debug,
            "device" (str): torch device to run on,
            "output" (str): output directory path,
            "number" (int): number of mutants to generate per input.
        """
        # setup global class variables
        self.verbose = args["verbose"]
        self.device = args["device"]
        self.outputdir = args["output"]
        self.size = args["number"]
        self.sources = {}  # batch index -> [receptor_path, ligand_path]
        self.config = load_config('./PocketGen/configs/train_model.yml')

        if self.verbose > 0:
            print('Flint setup started, please wait.')
        if self.verbose == 2:
            print('Now initializing pytorch and CUDA environment :')

        # cleans cache and sets the libs seeds (fixed seed for reproducibility)
        torch.cuda.empty_cache()
        seed_all(2089)

        if self.verbose == 2:
            print('\tpytorch and CUDA initialized correctly.')
            print('Now retrieving alphabet from fair-ESM :')

        # only the alphabet (tokenizer) is needed: the ESM2 pretrained weights
        # are deleted from memory right away
        pretrained_model, self.alphabet = esm.pretrained.load_model_and_alphabet_hub('esm2_t33_650M_UR50D')
        del pretrained_model

        if self.verbose == 2:
            print('\tESM alphabet successfully loaded.')
            print('Now building PocketGen model :')

        # get the model checkpoint from .pt file
        self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if self.verbose == 2:
            print('\tcheckpoint successfully created.')

        # instanciate PocketGen model for pocket design
        self.model = Pocket_Design_new(
            self.config.model,
            protein_atom_feature_dim=FeaturizeProteinAtom().feature_dim,
            ligand_atom_feature_dim=FeaturizeLigandAtom().feature_dim,
            device=self.device
        )
        if self.verbose == 2:
            print("\tPocketGen model well instanciated.")

        # send model to selected device
        self.model = self.model.to(self.device)
        if self.verbose == 2:
            print('\tPocketGen model sent to selected device.')

        # load current saved checkpoint into model
        self.model.load_state_dict(self.checkpoint['model'])
        if self.verbose == 2:
            print('\tcheckpoint loaded into PocketGen.')
            print('End of setup, model can now be used.\n\n')

    def input(self, receptor_path: str, ligand_path: str) -> "Model":
        """
        Loads a protein receptor and a ligand from files and stores them in a
        data-loader, useable by the model when generating mutants.

        @param receptor_path (str): path to the receptor PDB file.
        @param ligand_path (str): path to the ligand SDF file.
        @return (Model): the instance of Model, for chainability purposes.
        """
        if self.verbose == 2:
            print('Now parsing data from receptor and ligand :')

        # get dense features from receptor-ligand interaction
        features = interaction(receptor_path, ligand_path)

        if self.verbose == 2:
            print('\tsuccessfully parsed interaction features.\n')
            print('Now building the pytorch dataloader :')

        # the same complex is replicated `size` times: each copy is one mutant
        # to generate; the batch converter tokenizes sequences for ESM
        self.loader = DataLoader(
            [features for _ in range(self.size)],
            batch_size=4,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            collate_fn=partial(
                collate_mols_block,
                batch_converter=self.alphabet.get_batch_converter()
            )
        )

        # stores the source input files to compare against generated mutants
        # NOTE(review): with size > 4 the loader yields several batches but only
        # the first upcoming batch index is registered here — results() looks up
        # self.sources[b] for every batch index; confirm intent.
        self.sources[self._nbatch()] = [receptor_path, ligand_path]

        if self.verbose == 2:
            print('\tpytorch dataloader built correctly.')
        return self

    def generate(self) -> "Model":
        """
        Generates mutants based on the input protein receptor.

        @return (Model): the instance of Model, for chainability purposes.
        """
        if self.verbose > 0:
            print("Now generating new mutant protein receptor :")

        # place it in eval mode
        self.model.eval()

        # snapshot the folder count ONCE: each generated batch creates a new
        # folder, so calling _nbatch() inside the loop would shift the count
        # mid-iteration and skip output indices
        offset = self._nbatch()

        # no need to compute gradients during inference
        with torch.no_grad():
            for b, batch in enumerate(self.loader):
                # move every tensor of the batch to the selected device
                batch = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }

                # aa_ratio: well-predicted AA ratio on the total residue mask
                # rmsd: root mean squared deviation of the generated pocket
                aa_ratio, rmsd, attend_logits = self.model.generate(
                    batch,
                    output_folder=os.path.join(self.outputdir, f"batch_{b + offset}")
                )

                if self.verbose > 0:
                    print("\tinference done on a batch.")
        return self

    def results(self) -> "Model":
        """
        Writes results in a summary file, along with all generated PDBs.

        @return (Model): the instance of Model, for chainability purposes.
        """
        # initialize the resulting summary TSV
        summary = "ID\tdelta_G\tKd\tmutations (AA)\n"

        if self.verbose > 0:
            print("Now writing output files :")

        for b in range(self._nbatch()):
            # assumes each batch folder holds `size` structures named {i}.pdb /
            # {i}.sdf — TODO confirm against model.generate's output layout
            for i in range(self.size):
                receptor_path = os.path.join(self.outputdir, f"batch_{b}", f"{i}.pdb")
                ligand_path = os.path.join(self.outputdir, f"batch_{b}", f"{i}.sdf")

                # compute the docking window around the ligand
                docking_box = compute_box(receptor_path, ligand_path)
                energies = docking(
                    receptor_file=prepare(receptor_path),
                    ligand_file=prepare(ligand_path),
                    center=docking_box["center"],
                    box_size=docking_box["size"]
                )

                # calculates the mean Kd and deltaG over all docking poses
                mean_kd = np.mean([kd(e) for e in energies])
                mean_dg = np.mean(energies)

                # find the number of mutations (AA-level) against the source
                n_mutations = mutations(
                    self.sources[b][0],
                    os.path.join(self.outputdir, f"batch_{b}", f"{i}_whole.pdb")
                )

                summary += f"batch_{b}/{i}\t{mean_dg}\t{mean_kd}\t{n_mutations}" + "\n"
                if self.verbose == 2:
                    print("\twrote one new entry in the summary file.")

        if self.verbose > 0:
            print("You can find the files and summary in your output folder.")

        # write summary to a local file
        with open(os.path.join(self.outputdir, "summary.tsv"), "w") as file:
            file.write(summary)
        return self

    def _nbatch(self) -> int:
        """
        Returns the number of batch folders currently stored in the output
        directory.

        @return (int): the number of sub-directories in the output folder.
        """
        os.makedirs(self.outputdir, exist_ok=True)
        # count directories ONLY: plain files such as summary.tsv written by
        # results() must not inflate the batch count on subsequent runs
        return sum(
            os.path.isdir(os.path.join(self.outputdir, entry))
            for entry in os.listdir(self.outputdir)
        )