import spaces import gradio as gr from gradio_molecule3d import Molecule3D import os import numpy as np import torch from rdkit import Chem import argparse import random from tqdm import tqdm from vina import Vina import esm from utils.relax import openmm_relax, relax_sdf from utils.protein_ligand import PDBProtein, parse_sdf_file from utils.data import torchify_dict from torch_geometric.transforms import Compose from utils.datasets import * from utils.transforms import * from utils.misc import * from utils.data import * from torch.utils.data import DataLoader from models.PD import Pocket_Design_new from functools import partial import pickle import yaml from easydict import EasyDict import uuid from datetime import datetime import tempfile import shutil from Bio import PDB from Bio.PDB import MMCIFParser, PDBIO import logging import zipfile # 配置日志 logger = logging.getLogger(__name__) LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s" logging.basicConfig( format=LOG_FORMAT, level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S", filemode="w", ) # 确保目录存在 os.makedirs("./generate/upload", exist_ok=True) os.makedirs("./tmp", exist_ok=True) # 自定义CSS样式 custom_css = """ .title { font-size: 32px; font-weight: bold; color: #4CAF50; display: flex; align-items: center; } .subtitle { font-size: 20px; color: #666; margin-bottom: 20px; } .footer { margin-top: 20px; text-align: center; color: #666; } """ # 3D显示表示设置 - 默认配置 default_reps = [ { "model": 0, "chain": "", "resname": "", "style": "cartoon", "color": "whiteCarbon", "residue_range": "", "around": 0, "byres": False, "visible": True, "opacity": 1.0 }, { "model": 0, "chain": "", "resname": "", "style": "stick", "color": "greenCarbon", "around": 5, # 显示配体周围5Å的残基 "byres": True, "visible": True, "opacity": 0.8 } ] def create_zip_file(directory_path, zip_filename): """将指定目录压缩为zip文件""" try: with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf: for root, dirs, files in os.walk(directory_path): for file in files: file_path = os.path.join(root, file) arcname = os.path.relpath(file_path, directory_path) zipf.write(file_path, arcname) logger.info(f"成功创建压缩文件: {zip_filename}") return zip_filename except Exception as e: logger.error(f"创建压缩文件时出错: {str(e)}") return None def load_config(config_path): """加载配置文件""" with open(config_path, 'r') as f: config_dict = yaml.load(f, Loader=yaml.FullLoader) return EasyDict(config_dict) # 删除了Vina相关的计算函数,因为只需要RMSD结果 def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, residue_dict=None, seq=None, full_seq_idx=None, r10_idx=None): """从蛋白质和配体字典创建数据实例""" instance = {} if protein_dict is not None: for key, item in protein_dict.items(): instance['protein_' + key] = item if ligand_dict is not None: for key, item in ligand_dict.items(): instance['ligand_' + key] = item if residue_dict is not None: for key, item in residue_dict.items(): instance[key] = item if seq is not None: instance['seq'] = seq if full_seq_idx is not None: instance['full_seq_idx'] = full_seq_idx if r10_idx is not None: instance['r10_idx'] = r10_idx return instance def ith_true_index(tensor, i): """找到张量中第i个为真的元素的索引""" true_indices = torch.nonzero(tensor).squeeze() return true_indices[i].item() def name2data(pdb_path, lig_path): """从PDB和SDF文件生成数据""" name = os.path.basename(pdb_path).split('.')[0] dir_name = os.path.dirname(pdb_path) pocket_path = os.path.join(dir_name, f"{name}_pocket.pdb") try: with open(pdb_path, 'r') as f: pdb_block = f.read() protein = PDBProtein(pdb_block) seq = ''.join(protein.to_dict_residue()['seq']) ligand = parse_sdf_file(lig_path, feat=False) if ligand is None: raise ValueError(f"无法从{lig_path}解析配体") r10_idx, r10_residues = protein.query_residues_ligand(ligand, radius=10, selected_residue=None, return_mask=False) full_seq_idx, _ = protein.query_residues_ligand(ligand, radius=3.5, selected_residue=r10_residues, return_mask=False) if not r10_residues: raise ValueError("在配体10Å范围内未找到任何残基") assert len(r10_idx) == len(r10_residues) pdb_block_pocket = protein.residues_to_pdb_block(r10_residues) with open(pocket_path, 'w') as f: f.write(pdb_block_pocket) with open(pocket_path, 'r') as f: pdb_block = f.read() pocket = PDBProtein(pdb_block) pocket_dict = pocket.to_dict_atom() residue_dict = pocket.to_dict_residue() _, residue_dict['protein_edit_residue'] = pocket.query_residues_ligand(ligand) if residue_dict['protein_edit_residue'].sum() == 0: raise ValueError("在口袋内未找到可编辑残基") assert residue_dict['protein_edit_residue'].sum() > 0 and residue_dict['protein_edit_residue'].sum() == len(full_seq_idx) assert len(residue_dict['protein_edit_residue']) == len(r10_idx) full_seq_idx.sort() r10_idx.sort() data = from_protein_ligand_dicts( protein_dict=torchify_dict(pocket_dict), ligand_dict=torchify_dict(ligand), residue_dict=torchify_dict(residue_dict), seq=seq, full_seq_idx=torch.tensor(full_seq_idx), r10_idx=torch.tensor(r10_idx) ) data['protein_filename'] = pocket_path data['ligand_filename'] = lig_path data['whole_protein_name'] = pdb_path return transform(data) except Exception as e: logger.error(f"name2data中出错: {str(e)}") raise def convert_cif_to_pdb(cif_path): """将CIF文件转换为PDB文件并保存为临时文件""" try: parser = MMCIFParser() structure = parser.get_structure("protein", cif_path) with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file: temp_pdb_path = temp_file.name io = PDBIO() io.set_structure(structure) io.save(temp_pdb_path) return temp_pdb_path except Exception as e: logger.error(f"将CIF转换为PDB时出错: {str(e)}") raise def align_pdb_files(pdb_file_1, pdb_file_2): """将两个PDB文件对齐,将第二个结构对齐到第一个结构上""" try: parser = PDB.PPBuilder() io = PDB.PDBIO() structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1) structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2) super_imposer = PDB.Superimposer() model_1 = structure_1[0] model_2 = structure_2[0] atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"] atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"] if not atoms_1 or not atoms_2: logger.warning("未找到用于对齐的CA原子") return min_length = min(len(atoms_1), len(atoms_2)) if min_length == 0: logger.warning("没有可用于对齐的原子") return super_imposer.set_atoms(atoms_1[:min_length], atoms_2[:min_length]) super_imposer.apply(model_2) io.set_structure(structure_2) io.save(pdb_file_2) except Exception as e: logger.error(f"对齐PDB文件时出错: {str(e)}") raise def create_combined_structure(protein_path, ligand_path, output_path): """将蛋白质和配体合并为一个PDB文件以便可视化""" try: # 读取蛋白质PDB文件 with open(protein_path, 'r') as f: protein_content = f.read() # 读取配体SDF文件并转换为PDB格式的字符串 mol = Chem.MolFromMolFile(ligand_path) if mol is None: logger.error(f"无法读取配体文件: {ligand_path}") return protein_path # 将配体转换为PDB格式 ligand_pdb_block = Chem.MolToPDBBlock(mol) # 合并蛋白质和配体 combined_content = protein_content.rstrip() + "\n" + ligand_pdb_block # 保存合并后的文件 with open(output_path, 'w') as f: f.write(combined_content) return output_path except Exception as e: logger.error(f"创建合并结构时出错: {str(e)}") return protein_path # 如果失败,返回原始蛋白质文件 @spaces.GPU(duration=500) def process_files(pdb_file, sdf_file, config_path): """处理上传的PDB和SDF文件""" try: unique_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" upload_dir = os.path.join("./generate/upload", unique_id) os.makedirs(upload_dir, exist_ok=True) logger.info(f"使用ID处理文件: {unique_id}") config = load_config(config_path) pdb_save_path = os.path.join(upload_dir, "protein.pdb") sdf_save_path = os.path.join(upload_dir, "ligand.sdf") shutil.copy(pdb_file, pdb_save_path) shutil.copy(sdf_file, sdf_save_path) logger.info(f"文件已保存到 {upload_dir}") device = "cuda:0" if torch.cuda.is_available() else "cpu" logger.info(f"使用设备: {device}") protein_featurizer = FeaturizeProteinAtom() ligand_featurizer = FeaturizeLigandAtom() global transform transform = Compose([ protein_featurizer, ligand_featurizer, ]) logger.info("加载ESM模型...") name = 'esm2_t33_650M_UR50D' pretrained_model, alphabet = esm.pretrained.load_model_and_alphabet_hub(name) batch_converter = alphabet.get_batch_converter() checkpoint_path = config.model.checkpoint logger.info(f"从{checkpoint_path}加载检查点") ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) del pretrained_model logger.info("初始化模型...") model = Pocket_Design_new( config.model, protein_atom_feature_dim=protein_featurizer.feature_dim, ligand_atom_feature_dim=ligand_featurizer.feature_dim, device=device ).to(device) model.load_state_dict(ckpt['model']) logger.info("处理输入数据...") data = name2data(pdb_save_path, sdf_save_path) batch_size = 2 datalist = [data for _ in range(batch_size)] protein_filename = data['protein_filename'] ligand_filename = data['ligand_filename'] whole_protein_name = data['whole_protein_name'] dir_name = os.path.dirname(protein_filename) model.generate_id = 0 model.generate_id1 = 0 test_loader = DataLoader( datalist, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=partial(collate_mols_block, batch_converter=batch_converter) ) logger.info("生成结构...") with torch.no_grad(): model.eval() for batch in tqdm(test_loader, desc='Test'): for key in batch: if torch.is_tensor(batch[key]): batch[key] = batch[key].to(device) aar, rmsd, attend_logits = model.generate(batch, dir_name) logger.info(f'RMSD: {rmsd}') # 创建结果文件 result_path = os.path.join(dir_name, "0_whole.pdb") relaxed_path = os.path.join(dir_name, "0_relaxed.pdb") if os.path.exists(relaxed_path): shutil.copy(relaxed_path, result_path) else: shutil.copy(pdb_save_path, result_path) # 创建包含蛋白质和配体的合并文件用于可视化 combined_path = os.path.join(dir_name, "combined_structure.pdb") visualization_path = create_combined_structure(result_path, sdf_save_path, combined_path) # 创建压缩文件 zip_filename = os.path.join("./generate/upload", f"{unique_id}_results.zip") zip_path = create_zip_file(upload_dir, zip_filename) logger.info(f"结果已保存到 {result_path}") logger.info(f"压缩文件已创建: {zip_path}") summary = f""" 处理完成! 结果摘要: - 均方根偏差 (RMSD): {rmsd} 文件说明: - 所有结果文件已打包为ZIP文件供下载 - 包含原始输入、处理结果等 - 任务ID: {unique_id} """ return visualization_path, zip_path, summary except Exception as e: import traceback error_trace = traceback.format_exc() logger.error(f"处理过程中出错: {error_trace}") return None, None, f"处理过程中出错: {str(e)}" def gradio_interface(pdb_file, sdf_file, config_path): """Gradio接口函数""" if pdb_file is None or sdf_file is None: return None, None, "请上传PDB和SDF文件。" logger.info(f"开始处理{pdb_file}和{sdf_file}") pdb_viewer, zip_path, message = process_files(pdb_file, sdf_file, config_path) if pdb_viewer and os.path.exists(pdb_viewer): return pdb_viewer, zip_path, message else: return None, None, message if message else "处理失败,未知错误。" # 创建Gradio接口 with gr.Blocks(title="蛋白质-配体处理", css=custom_css) as demo: gr.Markdown("# 蛋白质-配体结构处理", elem_classes=["title"]) gr.Markdown("上传PDB和SDF文件进行蛋白质口袋设计和配体对接分析", elem_classes=["subtitle"]) with gr.Row(): with gr.Column(scale=1): pdb_input = gr.File(label="上传PDB文件", file_types=[".pdb"]) sdf_input = gr.File(label="上传SDF文件", file_types=[".sdf"]) config_input = gr.Textbox(label="配置文件路径", value="./configs/train_model_moad.yml") submit_btn = gr.Button("处理文件", variant="primary") with gr.Column(scale=2): # 使用Molecule3D组件,固定为默认样式 view3d = Molecule3D( label="3D结构可视化 (蛋白质卡通 + 配体周围残基棒状)", reps=default_reps ) output_message = gr.Textbox(label="处理状态和结果摘要", lines=8) output_file = gr.File(label="下载完整结果包 (ZIP)") # 处理文件的点击事件 submit_btn.click( fn=gradio_interface, inputs=[pdb_input, sdf_input, config_input], outputs=[view3d, output_file, output_message] ) gr.Markdown(""" ## 使用说明 1. **上传文件**: 上传蛋白质PDB文件和配体SDF文件 2. **配置设置**: 保持默认配置路径或调整为您的配置文件位置 3. **处理文件**: 点击"处理文件"按钮开始处理 4. **结果查看**: - 在3D查看器中交互式查看优化后的蛋白质-配体复合物结构 - 查看详细的处理结果摘要 - 下载包含所有结果文件的ZIP压缩包 ## 3D可视化功能 - **旋转**: 鼠标左键拖拽 - **缩放**: 鼠标滚轮或双指缩放 - **平移**: 鼠标右键拖拽 - **重置视图**: 双击重置到初始视角 可视化样式说明: - 蛋白质以卡通形式显示(白色碳骨架) - 配体周围5Å内的残基以棒状形式显示(绿色碳骨架) ## 下载文件说明 ZIP压缩包包含以下文件: - **protein.pdb**: 原始输入蛋白质文件 - **ligand.sdf**: 原始输入配体文件 - **protein_pocket.pdb**: 提取的蛋白质口袋文件 - **0_whole.pdb**: 优化后的完整蛋白质结构 - **0_relaxed.pdb**: 松弛优化后的蛋白质结构 - **combined_structure.pdb**: 用于可视化的蛋白质-配体复合物 ## 技术说明 该应用程序使用深度学习方法优化蛋白质口袋结构,提高与特定配体的结合能力。主要功能包括: - **蛋白质口袋识别**: 自动识别并提取配体结合口袋 - **结构优化设计**: 使用AI模型优化口袋残基构象 - **分子对接评分**: 使用Vina进行结合能评估 - **交互式3D可视化**: 清晰展示蛋白质-配体相互作用 - **完整结果打包**: 所有中间和最终结果文件统一打包下载 处理可能需要几分钟时间,请耐心等待。 """) gr.Markdown("© 2025 zaixi", elem_classes=["footer"]) if __name__ == "__main__": demo.launch(share=True)