# NOTE(review): the two stray words "Newer" / "Older" here were pagination
# artifacts from the web viewer this file was copied from — not code.
# standard library
import os
from functools import partial

# third-party
import numpy as np
import torch
from torch.utils.data import DataLoader

# NOTE(review): `esm` (fair-esm) and the `interaction` featurizer are also
# used below, but their import lines were lost from this copy of the file —
# restore them when merging.

# project-local
from PocketGen.models.PD import Pocket_Design_new
from PocketGen.utils.misc import seed_all, load_config
from PocketGen.utils.transforms import FeaturizeProteinAtom, FeaturizeLigandAtom
from PocketGen.utils.data import collate_mols_block

from eval.docking import docking
from eval.prepare import prepare
from eval.window import compute_box
from eval.chemutils import kd
from eval.mutations import mutations
def __init__(self, checkpoint_path: str, args):
    """
    The mutant generation model constructor. This method does the setup of
    torch and CUDA environment, loads the checkpoint and then returns a PocketGen
    instance using the weights from checkpoints and the parameters retrieved.
    @param checkpoint_path (str): Path to checkpoint (.pt) file for PocketGen.
    @param args (dict): run options; expects keys "verbose" (0 for quiet,
        1 for necessary information, 2 for debug), "device", "output"
        (output directory) and "number" (number of mutants to generate).
    """
    self.verbose = args["verbose"]
    self.device = args["device"]
    self.outputdir = args["output"]
    self.size = args["number"]
    # maps batch index -> [receptor_path, ligand_path] of the source inputs
    self.sources = {}
    self.config = load_config('./PocketGen/configs/train_model.yml')
    print('Flint setup started, please wait.')
    if self.verbose == 2:
        print('Now initializing pytorch and CUDA environment :')
    torch.cuda.empty_cache()
    # fixed seed so repeated runs are reproducible
    seed_all(2089)
    if self.verbose == 2:
        print('\tpytorch and CUDA initialized correctly.')
        print('Now retrieving alphabet from fair-ESM :')
    # sets ESM2 alphabet as the usual alphabet
    pretrained_model, self.alphabet = esm.pretrained.load_model_and_alphabet_hub('esm2_t33_650M_UR50D')
    del pretrained_model  # the ESM2 weights themselves are not needed here; free the memory
    print('\tESM alphabet successfully loaded.')
    print('Now building PocketGen model :')
    self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
    if self.verbose == 2:
        print('\tcheckpoint successfully created.')
    # NOTE(review): the head of this constructor call was lost in the
    # original copy; reconstructed from the keyword args that survived.
    self.model = Pocket_Design_new(
        self.config.model,
        protein_atom_feature_dim=FeaturizeProteinAtom().feature_dim,
        ligand_atom_feature_dim=FeaturizeLigandAtom().feature_dim,
        device=self.device
    )
    if self.verbose == 2:
        print("\tPocketGen model well instanciated.")
    self.model = self.model.to(self.device)
    if self.verbose == 2:
        print('\tPocketGen model sent to selected device.')
    self.model.load_state_dict(self.checkpoint['model'])
    if self.verbose == 2:
        print('\tcheckpoint loaded into PocketGen.')
    print('End of setup, model can now be used.\n\n')
def input(self, receptor_path: str, ligand_path: str) -> "Model":
    """
    Loads a protein receptor and a ligand from files and store it in
    a data-loader, useable by the model when generating mutants.
    @param receptor_path (str): path to the receptor PDB file.
    @param ligand_path (str): path to the ligand SDF file.
    @return (Model): the instance of Model, for chainability purposes.
    """
    if self.verbose == 2:
        print('Now parsing data from receptor and ligand :')
    # get dense features from receptor-ligand interaction
    features = interaction(receptor_path, ligand_path)
    if self.verbose == 2:
        print('\tsuccessfully parsed interaction features.\n')
        print('Now building the pytorch dataloader :')
    # initialize the data loader (including batch converter);
    # NOTE(review): the `self.loader = DataLoader(` head was lost in the
    # original copy — reconstructed here. The same features are repeated
    # self.size times so the model generates one mutant per copy.
    self.loader = DataLoader(
        [features for _ in range(self.size)],
        batch_size=4,
        shuffle=False,
        num_workers=self.config.train.num_workers,
        collate_fn=partial(
            collate_mols_block,
            batch_converter=self.alphabet.get_batch_converter()
        )
    )
    # stores the source input files so results() can diff mutants against them
    self.sources[self._nbatch()] = [receptor_path, ligand_path]
    if self.verbose == 2:
        print('\tpytorch dataloader built correctly.')
    # return self so calls can be chained (as the docstring promises)
    return self
def generate(self) -> "Model":
    """
    Generates mutants based on the input protein receptor.
    @return (Model): the instance of Model, for chainability purposes.
    """
    if self.verbose > 0:
        print("Now generating new mutant protein receptor :")
    # place it in eval mode
    self.model.eval()
    # Hoist the folder offset: model.generate() creates new batch folders,
    # so calling self._nbatch() inside the loop would inflate the index and
    # make output folder numbering skip (b + an ever-growing count).
    offset = self._nbatch()
    # no need to compute gradients during inference
    with torch.no_grad():
        for b, batch in enumerate(self.loader):
            # move batch tensors to selected device, leave other entries as-is
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            # well-predicted AA ratio on total masked residues and
            # root mean squared deviation (RMSD) of the generated pocket
            aa_ratio, rmsd, attend_logits = self.model.generate(
                batch, output_folder=os.path.join(self.outputdir, f"batch_{b + offset}")
            )
            if self.verbose > 0:
                print("\tinference done on a batch.")
    # return self for chainability, consistent with input() and results()
    return self
def results(self) -> "Model":
    """
    write results in a summary file, along with all generated PDBs.
    @return (Model): the instance of Model, for chainability purposes.
    """
    # accumulate TSV rows, starting with the header line
    rows = ["ID\tdelta_G\tKd\tmutations (AA)\n"]
    if self.verbose > 0:
        print(f"Now writing output files :")
    for batch_idx in range(self._nbatch()):
        batch_dir = os.path.join(self.outputdir, f"batch_{batch_idx}")
        for mutant_idx in range(self.size):
            receptor_file = os.path.join(batch_dir, f"{mutant_idx}.pdb")
            ligand_file = os.path.join(batch_dir, f"{mutant_idx}.sdf")
            # docking window centered on the ligand position
            box = compute_box(receptor_file, ligand_file)
            binding_energies = docking(
                receptor_file=prepare(receptor_file),
                ligand_file=prepare(ligand_file),
                center=box["center"],
                box_size=box["size"]
            )
            # average dissociation constant and free energy over all poses
            avg_kd = np.mean([kd(energy) for energy in binding_energies])
            avg_dg = np.mean(binding_energies)
            # amino-acid-level mutation count vs. the source receptor
            mutation_count = mutations(
                self.sources[batch_idx][0],
                os.path.join(batch_dir, f"{mutant_idx}_whole.pdb")
            )
            rows.append(f"batch_{batch_idx}/{mutant_idx}\t{avg_dg}\t{avg_kd}\t{mutation_count}" + "\n")
            if self.verbose == 2:
                print(f"\twrote one new entry in the summary file.")
    if self.verbose > 0:
        print(f"You can find the files and summary in your output folder.")
    # flush the accumulated rows to the summary TSV
    with open(os.path.join(self.outputdir, "summary.tsv"), "w") as file:
        file.write("".join(rows))
    return self
def _nbatch(self) -> int:
"""
returns the number of batches stored from now in the output directory
@return (int): the number of folders in dir
"""
os.makedirs(self.outputdir, exist_ok=True)
return len(os.listdir(self.outputdir))