# Captured from the Hugging Face Space page for ptope/rna-folding-demo
# (Space status at capture time: Sleeping).
import html
import os
import tempfile
from pathlib import Path
from urllib.parse import quote

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download

from inference_wrapper import RNAFoldingPredictor
# Download the model checkpoint from the Hugging Face Hub; hf_hub_download
# returns a local file path that is handed to the predictor below.
ckpt_file = hf_hub_download(
    repo_id="ptope/rna-model-weights",
    filename="best_val_model.pt",
)

# Model and output-directory setup.
MODEL_PATH = ckpt_file
STATIC_DIR = Path("static")  # generated .pdb downloads are written here
os.makedirs(STATIC_DIR, exist_ok=True)
predictor = RNAFoldingPredictor(MODEL_PATH)

# Test sequences: the dropdown choices, plus a lookup from normalized
# (stripped, upper-cased) sequence to its target id for precomputed structures.
test_sequences = pd.read_csv("test_sequences.csv")
test_seq_list = test_sequences["sequence"].dropna().unique().tolist()

# Drop rows with a missing sequence or target id. pandas reads empty cells as
# NaN (a float), which has no .strip() and would crash startup; a NaN target id
# would otherwise produce a bogus "nan.pdb" URL. Later duplicates still win.
_valid_rows = test_sequences.dropna(subset=["sequence", "target_id"])
sequence_to_target_id = {
    str(seq).strip().upper(): target_id
    for seq, target_id in zip(_valid_rows["sequence"], _valid_rows["target_id"])
}
# Per-base offsets used to place one pseudo "base" atom next to each C1' atom.
# They are fixed display heuristics (independent of backbone orientation),
# not real nucleotide geometry. Units follow the predicted coordinates
# (presumably angstroms; TODO confirm).
_BASE_OFFSETS = {
    "A": (1.5, 0.0, 0.5),
    "C": (1.2, 0.3, 0.3),
    "G": (1.7, -0.2, 0.7),
    "U": (1.0, 0.5, 0.0),
}
_DEFAULT_BASE_OFFSET = (1.5, 0.0, 0.0)

# Glycosidic nitrogen bonded to C1': purines (A, G) use N9, pyrimidines (C, U) N1.
_GLYCOSIDIC_ATOM = {"A": "N9", "G": "N9", "C": "N1", "U": "N1"}


def _pdb_atom_record(serial, atom_name, resname, resid, x, y, z, element, chain="A"):
    """Format one fixed-column PDB ATOM record (occupancy 1.00, B-factor 0.00)."""
    # Names shorter than 4 characters start in column 14 (one-letter elements).
    name = atom_name if len(atom_name) >= 4 else f" {atom_name:<3}"
    # Columns: 1-6 record, 7-11 serial, 13-16 name, 18-20 resName, 22 chain,
    # 23-26 resSeq, 31-54 x/y/z, 55-60 occupancy, 61-66 B-factor, 77-78 element.
    return (
        f"ATOM  {serial:5d} {name} {resname:>3} {chain}{resid:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {element:>2}"
    )


def create_pdb_from_prediction(prediction_df):
    """Convert per-residue predicted C1' coordinates into PDB-format text.

    Args:
        prediction_df: DataFrame with one row per residue and columns
            ``resname``, ``resid``, ``x_1``, ``y_1``, ``z_1`` (the predicted
            C1' position).

    Returns:
        PDB text: a HEADER line, then for each residue a C1' ATOM record and
        a pseudo base-nitrogen ATOM record (chain A), then CONECT records
        (C1'-base and consecutive C1'-C1' links), then END. An empty frame
        yields just the HEADER and END lines.
    """
    atom_lines = []
    conect_lines = []  # emitted after all ATOM records, as the PDB format expects
    serial = 1
    prev_c1 = None
    columns = (
        prediction_df["resname"],
        prediction_df["resid"],
        prediction_df["x_1"],
        prediction_df["y_1"],
        prediction_df["z_1"],
    )
    for resname, resid, x, y, z in zip(*columns):
        resname = str(resname)
        resid = int(resid)  # also accepts float-typed ids such as 1.0
        x, y, z = float(x), float(y), float(z)
        dx, dy, dz = _BASE_OFFSETS.get(resname, _DEFAULT_BASE_OFFSET)

        c1, base = serial, serial + 1
        serial += 2
        atom_lines.append(_pdb_atom_record(c1, "C1'", resname, resid, x, y, z, "C"))
        atom_lines.append(
            _pdb_atom_record(
                base,
                _GLYCOSIDIC_ATOM.get(resname, "N9"),
                resname,
                resid,
                x + dx,
                y + dy,
                z + dz,
                "N",
            )
        )

        conect_lines.append(f"CONECT{c1:5d}{base:5d}")
        if prev_c1 is not None:
            # Chain consecutive C1' atoms so viewers draw a backbone trace.
            conect_lines.append(f"CONECT{prev_c1:5d}{c1:5d}")
        prev_c1 = c1

    return "\n".join(
        ["HEADER    RNA STRUCTURE PREDICTION", *atom_lines, *conect_lines, "END"]
    )
# Submit handler: run the model and write a downloadable .pdb file.
def generate_and_return_file(sequence, description=""):
    """Predict a structure for *sequence* and save it as a .pdb file in STATIC_DIR.

    Args:
        sequence: RNA sequence from the textbox. Surrounding whitespace is
            stripped and letters are upper-cased; ``None`` counts as empty.
        description: Optional free text forwarded to ``predictor.predict``.

    Returns:
        ``("Download ready", path)`` on success, where *path* is the written
        .pdb file. Returns ``("Invalid input", None)`` when the sequence is
        empty or contains characters other than A, C, G, U.
    """
    sequence = (sequence or "").strip().upper()
    if not sequence or set(sequence) - set("ACGU"):
        return "Invalid input", None

    df = predictor.predict(sequence, description)
    pdb_text = create_pdb_from_prediction(df)

    # mkstemp atomically creates a uniquely named file. The previous private
    # tempfile._get_candidate_names() only proposed a name, leaving a race
    # between choosing it and opening it.
    fd, path = tempfile.mkstemp(prefix="prediction_", suffix=".pdb", dir=STATIC_DIR)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(pdb_text)
    # NOTE(review): generated files are never removed, so STATIC_DIR grows with
    # every successful request.
    return "Download ready", str(path)
def _molstar_iframe_html(pdb_url):
    """Return an <iframe> that opens the PDB-format file at *pdb_url* in the Mol* web viewer."""
    # The structure URL is percent-encoded so its own ':', '/' or '&' cannot be
    # confused with the viewer's query string. The format is stated explicitly
    # because the viewer otherwise assumes mmCIF.
    viewer_url = (
        "https://molstar.org/viewer/"
        f"?structure-url={quote(pdb_url, safe='')}&structure-url-format=pdb"
    )
    # Escape for use inside an HTML attribute ('&' -> '&amp;'); ids come from a CSV.
    return (
        f'<iframe src="{html.escape(viewer_url)}" width="100%" height="600px" '
        'style="border:1px solid #ccc;"></iframe>'
    )


def serve_precomputed_pdb(sequence, description=""):
    """Return HTML embedding the Mol* viewer for a precomputed test structure.

    Args:
        sequence: RNA sequence. It is stripped and upper-cased before lookup in
            ``sequence_to_target_id``; ``None`` counts as empty.
        description: Unused. It is accepted so this handler can share the
            Submit button's inputs with ``generate_and_return_file``.

    Returns:
        An ``<iframe>`` HTML string pointing Mol* at ``/pdbs/<target_id>.pdb``
        on this Space, or ``""`` when the sequence is not a known test sequence.
    """
    sequence = (sequence or "").strip().upper()
    if not sequence or sequence not in sequence_to_target_id:
        return ""
    target_id = sequence_to_target_id[sequence]
    # NOTE(review): nothing in this file serves a /pdbs/ route. Confirm the Space
    # exposes these files, with CORS headers molstar.org can read.
    pdb_url = f"https://ptope--rna-folding-demo.hf.space/pdbs/{target_id}.pdb"
    return _molstar_iframe_html(pdb_url)
# Gradio UI: inputs on the left; status, .pdb download and the 3D viewer on the right.
with gr.Blocks() as demo:
    # Title emoji restored: the captured literal was mojibake of U+1F9EC (DNA).
    gr.Markdown("## 🧬 RNA 3D Viewer: Precomputed Visualization + Predictive Download")
    with gr.Row():
        with gr.Column():
            dropdown = gr.Dropdown(choices=test_seq_list, label="Select a Test Sequence")
            seq_input = gr.Textbox(label="RNA Sequence", lines=4)
            desc_input = gr.Textbox(label="Description (optional)", lines=1)
            # Choosing a test sequence copies it into the editable textbox.
            dropdown.change(fn=lambda s: s, inputs=dropdown, outputs=seq_input)
            submit_btn = gr.Button("Submit")
        with gr.Column():
            status_output = gr.Textbox(label="Status")
            file_output = gr.File(label="Download .pdb")
            viewer_html = gr.HTML(label="Mol* Viewer")

    # Submit triggers two independent events. One runs the model and offers a
    # .pdb download; the other shows the precomputed structure (if any) in Mol*.
    submit_btn.click(
        fn=generate_and_return_file,
        inputs=[seq_input, desc_input],
        outputs=[status_output, file_output],
    )
    submit_btn.click(
        fn=serve_precomputed_pdb,
        inputs=[seq_input, desc_input],
        outputs=[viewer_html],
    )

# Launch only when run as a script (Spaces runs `python app.py`), so importing
# the module does not start a server.
if __name__ == "__main__":
    demo.launch()