import os
import hashlib
import tempfile
import glob
import re
from pathlib import Path
from subprocess import run, PIPE

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

# -----------------------------------------------------------
# Common helpers
# -----------------------------------------------------------
_CACHE_ROOT = Path(os.path.expanduser("~/.rna_cache"))
_CACHE_ROOT.mkdir(parents=True, exist_ok=True)

_BPPM_DIR = _CACHE_ROOT / "bppm"
_TORSION_DIR = _CACHE_ROOT / "torsion"
_BPPM_DIR.mkdir(exist_ok=True)
_TORSION_DIR.mkdir(exist_ok=True)

TARGET_ANGLES = [
    "alpha", "beta", "gamma", "delta", "epsilon", "zeta",
    "chi", "eta", "theta",
]


def _hash_seq(seq: str) -> str:
    """Return a deterministic SHA-1 hash of the sequence (U -> T, upper-case)."""
    return hashlib.sha1(seq.upper().replace("U", "T").encode()).hexdigest()


# -----------------------------------------------------------
# BPPM utilities
# -----------------------------------------------------------
# Matches "i j sqrt(p) ubox" entries in the RNAfold dot-plot PostScript file.
_BPPM_PATTERN = re.compile(r"(\d+)\s+(\d+)\s+([\d\.]+)\s+ubox")


def _run_rnafold_to_bppm(seq: str) -> np.ndarray:
    """Run ``RNAfold -p`` in a private tmpdir and parse the dot plot."""
    L = len(seq)
    bppm = np.zeros((L, L), dtype=np.float32)
    with tempfile.TemporaryDirectory(prefix="rnafold_") as tmp:
        tmp_path = Path(tmp)
        (tmp_path / "input.fa").write_text(f">seq\n{seq}\n")
        cmd = "RNAfold -p < input.fa"
        result = run(cmd, shell=True, cwd=tmp, stdout=PIPE, stderr=PIPE, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"RNAfold failed:\n{result.stderr}")
        dot_files = glob.glob(str(tmp_path / "*_dp.ps"))
        if not dot_files:
            raise FileNotFoundError("No dot plot produced by RNAfold")
        with open(dot_files[0]) as fh:
            for line in fh:
                m = _BPPM_PATTERN.match(line)
                if m:
                    i, j = int(m.group(1)) - 1, int(m.group(2)) - 1
                    # The dot plot stores sqrt(p); square to recover the
                    # pairing probability p.
                    prob = float(m.group(3)) ** 2
                    bppm[i, j] = bppm[j, i] = prob
    return bppm


def get_bppm_from_sequence_cached(seq: str) -> np.ndarray:
    """Return the BPPM, loading/creating a .npy cache file under ~/.rna_cache/bppm."""
    key = _hash_seq(seq)
    fpath = _BPPM_DIR / f"{key}.npy"
    if fpath.exists():
        return np.load(fpath)
    try:
        bppm = _run_rnafold_to_bppm(seq)
    except (RuntimeError, FileNotFoundError) as e:
        print(f"RNAfold failed, using fallback BPPM: {e}")
        L = len(seq)
        # Fallback: low random background with elevated probabilities on the
        # near-diagonal band (0 < |i - j| <= 3).
        bppm = np.random.rand(L, L).astype(np.float32) * 0.1
        for i in range(L):
            for j in range(max(0, i - 3), min(L, i + 4)):
                if i != j:
                    bppm[i, j] = bppm[j, i] = 0.5 + np.random.rand() * 0.5
    # NB: the fallback matrix is cached too, so repeated lookups of the same
    # sequence stay deterministic.
    np.save(fpath, bppm)
    return bppm
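# Usage sketch for the BPPM cache (comment only, nothing runs at import):
# assumes the ViennaRNA ``RNAfold`` binary is on PATH, and the hairpin below
# is an arbitrary illustrative sequence.
#
#     >>> bppm = get_bppm_from_sequence_cached("GGGGAAAACCCC")
#     >>> bppm.shape
#     (12, 12)
#
# The first call shells out to RNAfold and writes a .npy file under
# ~/.rna_cache/bppm; any later call with the same sequence is a pure cache hit.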
# -----------------------------------------------------------
# Torsion BERT wrapper with caching
# -----------------------------------------------------------
class RNATorsionBERTforTorchCached:
    """Same API as RNATorsionBERTforTorch, but with on-disk angle caching."""

    def __init__(self, model_name: str = "sayby/rna_torsionbert"):
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
        self.model_name = model_name
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True
            )
            # ignore_mismatched_sizes=True tolerates head-size mismatches
            # between the checkpoint and the instantiated architecture.
            self.model = AutoModel.from_pretrained(
                model_name, trust_remote_code=True, ignore_mismatched_sizes=True
            )
            self.model.eval()
            self.encode_kwargs = dict(
                return_tensors="pt",
                padding="max_length",
                max_length=512,
                truncation=True,
            )
            self._ready = True
            print(f"✅ Loaded {model_name}")
        except Exception as e:
            print(f"⚠️ Could not load {model_name}: {e}\n"
                  "   Falling back to random torsion data.")
            self._ready = False

    # -----------------------------
    def _to_kmers(self, seq: str, k: int = 3) -> str:
        """Split ``seq`` into overlapping k-mers joined by spaces."""
        return " ".join(seq[i : i + k] for i in range(len(seq) - k + 1))

    # -----------------------------
    def predict(self, sequence: str) -> torch.Tensor:
        L = len(sequence)
        key = _hash_seq(sequence)
        fpath = _TORSION_DIR / f"{key}.pt"
        if fpath.exists():
            return torch.load(fpath)
        if not self._ready:
            # Random fallback angles; deliberately not cached.
            return torch.randn(L - 2, len(TARGET_ANGLES))

        seq = sequence.upper().replace("U", "T")
        kmers = self._to_kmers(seq)
        inputs = self.tokenizer(kmers, **self.encode_kwargs)
        with torch.no_grad():
            # The tokenizer output must be unpacked into keyword arguments.
            logits = self.model(**inputs)["logits"].squeeze(0)  # [512, 32]
        logits = logits[: L - 2]  # one row per 3-mer token; drop padding rows
        # Even columns hold cos, odd columns sin; recover angles in degrees.
        cos, sin = logits[..., 0::2], logits[..., 1::2]
        angles = torch.atan2(sin, cos) * 180 / torch.pi
        angles = angles[:, : len(TARGET_ANGLES)].float()  # [L-2, 9]
        torch.save(angles, fpath)
        return angles
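# -----------------------------------------------------------
# Smoke test (sketch)
# -----------------------------------------------------------
# Assumes RNAfold is installed and the HuggingFace checkpoint is reachable;
# otherwise both helpers fall back to random data as coded above.  The demo
# sequence is arbitrary and chosen only for illustration.
if __name__ == "__main__":
    demo_seq = "GGGAAAUCCCUAGGGAAAUCCC"

    bppm = get_bppm_from_sequence_cached(demo_seq)
    print("BPPM:", bppm.shape, "max pairing prob:", float(bppm.max()))

    torsion = RNATorsionBERTforTorchCached()
    angles = torsion.predict(demo_seq)
    print("Torsion angles:", tuple(angles.shape))  # (len(demo_seq) - 2, 9)

# -----------------------------------------------------------
# End of module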