# NOTE(review): stray "Newer"/"Older" lines (diff-viewer residue) removed —
# as bare names they would raise NameError when the module is imported.
from PocketGen.models.PD import Pocket_Design_new
from PocketGen.utils.misc import seed_all, load_config
from PocketGen.utils.transforms import FeaturizeProteinAtom, FeaturizeLigandAtom
from PocketGen.utils.data import collate_mols_block
from eval.docking import docking
from eval.prepare import prepare
from eval.window import compute_box
from eval.chemutils import kd
from eval.mutations import mutations
def __init__(self, checkpoint_path: str, args):
    """
    The mutant generation model constructor. Sets up the torch and CUDA
    environment, loads the checkpoint and builds a ready-to-use PocketGen
    instance with the weights from the checkpoint.

    @param checkpoint_path (str): path to the checkpoint (.pt) file for PocketGen.
    @param args (dict): runtime options; expects the keys
        "verbose" (int: 0 for quiet, 1 for necessary information, 2 for debug),
        "device" (torch device string, e.g. "cuda" or "cpu") and
        "output" (path of the output directory).
    """
    self.verbose = args["verbose"]
    self.device = args["device"]
    self.outputdir = args["output"]
    self.config = load_config('./PocketGen/configs/train_model.yml')
    if self.verbose > 0:
        print('Flint setup started, please wait.')
    if self.verbose == 2:
        print('Now initializing pytorch and CUDA environment :')
    # free any cached GPU memory and fix the RNG seed for reproducibility
    torch.cuda.empty_cache()
    seed_all(2089)
    if self.verbose == 2:
        print('\tpytorch and CUDA initialized correctly.')
        print('Now retrieving alphabet from fair-ESM :')
    # only the ESM-2 alphabet (batch converter) is needed here;
    # the heavy pretrained weights are deleted right away to save memory
    pretrained_model, self.alphabet = esm.pretrained.load_model_and_alphabet_hub('esm2_t33_650M_UR50D')
    del pretrained_model
    if self.verbose == 2:
        print('\tESM alphabet successfully loaded.')
        print('Now building PocketGen model :')
    # load the checkpoint directly onto the selected device
    self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
    if self.verbose == 2:
        print('\tcheckpoint successfully created.')
    # build PocketGen with the same featurization dims used at training time
    # (NOTE(review): the constructor call line was missing in the pasted
    # source and has been reconstructed — confirm the first positional
    # argument against the PocketGen API)
    self.model = Pocket_Design_new(
        self.config.model,
        protein_atom_feature_dim=FeaturizeProteinAtom().feature_dim,
        ligand_atom_feature_dim=FeaturizeLigandAtom().feature_dim,
        device=self.device
    )
    if self.verbose == 2:
        print("\tPocketGen model well instanciated.")
    self.model = self.model.to(self.device)
    if self.verbose == 2:
        print('\tPocketGen model sent to selected device.')
    # restore the trained weights into the freshly built model
    self.model.load_state_dict(self.checkpoint['model'])
    if self.verbose == 2:
        print('\tcheckpoint loaded into PocketGen.')
    if self.verbose > 0:
        print('End of setup, model can now be used.\n\n')
def input(self, receptor_path: str, ligand_path: str) -> "Model":
    """
    Loads a protein receptor and a ligand from files and stores them in
    a data-loader, usable by the model when generating mutants.

    @param receptor_path (str): path to the receptor PDB file.
    @param ligand_path (str): path to the ligand SDF file.
    @return (Model): the instance of Model, for chainability purposes.
    """
    if self.verbose == 2:
        print('Now parsing data from receptor and ligand :')
    # get dense features from the receptor-ligand interaction
    features = interaction(receptor_path, ligand_path)
    if self.verbose == 2:
        print('\tsuccessfully parsed interaction features.\n')
        print('Now building the pytorch dataloader :')
    # replicate the same complex self.size times so that one pass over the
    # loader yields self.size mutant candidates
    # (NOTE(review): the loader assignment line was missing in the pasted
    # source and has been reconstructed; self.size must be set before
    # calling input() — confirm against the caller)
    self.loader = torch.utils.data.DataLoader(
        [features for _ in range(self.size)],
        batch_size=4,
        shuffle=False,
        num_workers=self.config.train.num_workers,
        collate_fn=partial(
            collate_mols_block,
            batch_converter=self.alphabet.get_batch_converter()
        )
    )
    # store the source input files so results can be compared later
    self.sources = [receptor_path, ligand_path]
    if self.verbose == 2:
        print('\tpytorch dataloader built correctly.')
    # return self so calls can be chained, as documented
    return self
"""
Generates mutants based on the input protein receptor.
@return (Model): the instance of Model, for chainability purposes.
print("Now generating new mutant protein receptors :")
# place it in eval mode
self.model.eval()
# no need to compute gradients during inference
with torch.no_grad():
for b, batch in enumerate(self.loader):
# move batch to selected device
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
# well-predicted AA on total mask redisue
# root mean squared deviation (RMSD)
aa_ratio, rmsd, attend_logits = self.model.generate(
batch, target_path=os.path.join(self.outputdir, f"batch_{b}")
shutil.copyfile(self.sources[0], os.path.join(self.outputdir, f"batch_{b}", f"{b}_orig_whole.pdb"))
if self.verbose > 0:
print(f"\tinference done on a batch.")
def results(self) -> "Model":
"""
write results in a summary file, along with all generated PDBs.
@return (Model): the instance of Model, for chainability purposes.
"""
# initialize the resulting summary TSV
summary = "ID\tdelta_G\tKd\tmutations (AA)\n"
if self.verbose > 0:
print(f"Now writing output files :")
for b in range(self._nbatch()):
for i in range(self.size):
receptor_path = os.path.join(self.outputdir, f"batch_{b}", f"{i}_whole.pdb")
ligand_path = os.path.join(self.outputdir, f"batch_{b}", f"{i}.sdf")
# compute the docking window around ligand
docking_box = compute_box(receptor_path, ligand_path)
try:
energies = docking(
receptor_file=prepare(receptor_path),
ligand_file=prepare(ligand_path),
center=docking_box["center"],
box_size=docking_box["size"]
)
except Exception as e:
print(f"\t\terror simulating docking:{e}")
energies = np.zeros(1)
# calculates the mean Kd and deltaG
mean_kd = np.mean([kd(e) for e in energies])
mean_dg = np.mean(energies)
# find the number of mutations (AA-level)
n_mutations = mutations(
os.path.join(self.outputdir, f"batch_{b}", f"{b}_orig_whole.pdb"),
os.path.join(self.outputdir, f"batch_{b}", f"{i}_whole.pdb")
)
summary += f"batch_{b}/{i}\t{mean_dg}\t{mean_kd}\t{n_mutations}" + "\n"
if self.verbose == 2:
print(f"\twrote one new entry in the summary file.")
if self.verbose > 0:
print(f"You can find the files and summary in your output folder.")
# write summary to a local file
with open(os.path.join(self.outputdir, "summary.tsv"), "w") as file:
file.write(summary)
return self
def _nbatch(self) -> int:
    """
    Returns the number of batch folders currently stored in the output
    directory.

    @return (int): the number of sub-directories in the output directory.
    """
    # the output directory may not exist yet on a fresh run
    os.makedirs(self.outputdir, exist_ok=True)
    return sum(
        1
        for entry in os.listdir(self.outputdir)
        if os.path.isdir(os.path.join(self.outputdir, entry))
    )