import os
import sys
import tempfile
import json
import math
import timm
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
from typing import List, Union, Dict
import torchvision.transforms as transforms


# Vision Model
class TimmCNNModel(nn.Module):
    def __init__(self, num_classes: int = 8, model_name: str = "efficientnet_b0"):
        super().__init__()
        # num_classes=0 strips timm's classification head, so backbone(x) returns pooled features
        self.backbone = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
        )
        self.feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits


# Projector Model
class Projector_4to3d(nn.Module):
    def __init__(self, cnn_dim: int = 1280, llm_dim: int = 2048, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.cnn_dim = cnn_dim
        self.llm_dim = llm_dim

        # Spatial positional embeddings for 8x8 grid
        self.spatial_pos_embed = nn.Parameter(torch.randn(64, cnn_dim))

        # Multi-scale feature processing
        self.spatial_conv = nn.Conv2d(cnn_dim, cnn_dim // 2, 1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Enhanced projection layers
        self.input_proj = nn.Sequential(
            nn.Linear(cnn_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Multi-head self-attention for spatial reasoning
        self.spatial_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Cross-attention for text-image alignment
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.norm1 = nn.LayerNorm(llm_dim)
        self.norm2 = nn.LayerNorm(llm_dim)

        # Enhanced FFN
        self.ffn = nn.Sequential(
            nn.Linear(llm_dim, llm_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(llm_dim * 4, llm_dim),
            nn.Dropout(dropout),
        )
        self.norm3 = nn.LayerNorm(llm_dim)

        # Token compression layer: 64 spatial tokens are distilled into 32 learned query tokens
        self.compress_tokens = nn.Parameter(torch.randn(32, llm_dim))
        self.token_compression = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, cnn_features: torch.Tensor, text_embeddings: torch.Tensor = None) -> torch.Tensor:
        batch_size = cnn_features.shape[0]

        # Multi-scale processing (auxiliary features; not fused into the output below)
        spatial_features = self.spatial_conv(cnn_features)
        global_context = self.global_pool(cnn_features).flatten(1)

        # Flatten spatial features and add positional encoding
        x = einops.rearrange(cnn_features, "b c h w -> b (h w) c")
        pos_embeddings = self.spatial_pos_embed.unsqueeze(0).expand(batch_size, -1, -1)
        x = x + pos_embeddings

        # Project to LLM dimension
        x = self.input_proj(x)

        # Self-attention for spatial reasoning
        attended_x, spatial_attn_weights = self.spatial_attention(x, x, x)
        x = self.norm1(x + attended_x)

        # Cross-attention with text (if available)
        if text_embeddings is not None:
            text_embeddings_float = text_embeddings.float()
            cross_attended, cross_attn_weights = self.cross_attention(x, text_embeddings_float, text_embeddings_float)
            x = self.norm2(x + cross_attended)

        # FFN
        ffn_out = self.ffn(x)
        x = self.norm3(x + ffn_out)

        # Token compression: learned queries attend over the spatial tokens
        compress_queries = self.compress_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        compressed_x, _ = self.token_compression(compress_queries, x, x)

        return compressed_x
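
# Illustrative sketch (not called anywhere in the app): a quick shape check for the
# projector. The (B, 1280, 8, 8) input shape below is an assumption based on
# EfficientNet-B0 producing a 1280-channel, 8x8 feature map for 256x256 inputs,
# which matches the 64-position spatial_pos_embed above.
def _projector_shape_check() -> None:
    projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
    dummy_features = torch.randn(2, 1280, 8, 8)         # fake CNN feature map, batch of 2
    tokens = projector(dummy_features)                   # 64 spatial tokens -> 32 compressed tokens
    assert tokens.shape == (2, 32, 2048), tokens.shape   # (batch, compress_tokens, llm_dim)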
# Main VLM Model
class Model(nn.Module):
    def __init__(self, image_model, language_model, projector, tokenizer, prompt="Describe this image:"):
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.projector = projector
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prompt = prompt

        device = next(self.language_model.parameters()).device
        self.image_model.to(device)
        self.projector.to(device)

        # Create prompt embeddings once and cache them as a buffer
        prompt_tokens = tokenizer(text=prompt, return_tensors="pt").input_ids.to(device)
        prompt_embeddings = language_model.get_input_embeddings()(prompt_tokens).detach()
        self.register_buffer('prompt_embeddings', prompt_embeddings)

    @property
    def device(self):
        return next(self.parameters()).device

    def generate(self, patches: torch.Tensor, generator_kwargs: Dict[str, Union[int, float]]):
        device = self.device
        patches = patches.to(device)

        # backbone.forward_features returns the unpooled spatial map expected by the projector
        image_features = self.image_model.backbone.forward_features(patches)
        patch_embeddings = self.projector(image_features)
        patch_embeddings = patch_embeddings.to(torch.bfloat16)

        # Prepend the cached prompt embeddings to the projected image tokens
        embeddings = torch.cat([
            self.prompt_embeddings.expand(patches.size(0), -1, -1),
            patch_embeddings,
        ], dim=1)

        prompt_mask = torch.ones(patches.size(0), self.prompt_embeddings.size(1), device=device)
        patch_mask = torch.ones(patches.size(0), patch_embeddings.size(1), device=device)
        attention_mask = torch.cat([prompt_mask, patch_mask], dim=1)

        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            **generator_kwargs
        )


vlm_model = None
tokenizer = None
transform = None


def download_and_load_models():
    global vlm_model, tokenizer, transform

    print("Starting model download and initialization...")

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print("CUDA available - using GPU")
    else:
        device = torch.device("cpu")
        print("CUDA not available - using CPU")

    repo_id = "aneeshm44/regfinal"
    print(f"Downloading from repo: {repo_id}")

    local_dir = tempfile.mkdtemp(prefix="regfinal_")
    print(f"Local directory: {local_dir}")

    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=local_dir,
            allow_patterns=[
                "llmweights/*",
                "imagemodelweights/finalcheckpoint.pth",
                "projectorweights/projector.pth"
            ],
            local_dir_use_symlinks=False,
        )
        print("Download completed successfully")
    except Exception as e:
        print(f"Download failed: {e}")
        raise e

    llm_path = os.path.join(local_dir, "llmweights")
    image_weights_path = os.path.join(local_dir, "imagemodelweights", "finalcheckpoint.pth")
    projector_weights_path = os.path.join(local_dir, "projectorweights", "projector.pth")

    print("Loading language model...")
    try:
        language_model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        language_model.eval()
        language_model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(llm_path)
        print("Language model loaded successfully")
    except Exception as e:
        print(f"Language model loading failed: {e}")
        raise e

    print("Loading vision model...")
    try:
        image_model = TimmCNNModel(num_classes=8)
        weights = torch.load(image_weights_path, map_location=device)
        image_model.load_state_dict(weights['model_state_dict'])
        for param in image_model.parameters():
            param.requires_grad = False
        image_model.eval()
        image_model.to(device)
        print("Vision model loaded successfully")
    except Exception as e:
        print(f"Vision model loading failed: {e}")
        raise e

    print("Loading projector...")
    try:
        projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
        weights = torch.load(projector_weights_path, map_location=device)
        projector.load_state_dict(weights)
        for param in projector.parameters():
            param.requires_grad = False
        projector.eval()
        projector.to(device)
        print("Projector loaded successfully")
    except Exception as e:
        print(f"Projector loading failed: {e}")
        raise e

    print("Creating VLM model...")
    try:
        vlm_model = Model(image_model, language_model, projector, tokenizer, prompt="Describe this image:")
        vlm_model = vlm_model.to(device)
        print("VLM model created successfully")
    except Exception as e:
        print(f"VLM model creation failed: {e}")
        raise e

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    print("All models loaded successfully!")
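
# Expected layout of the downloaded snapshot, inferred from the allow_patterns and
# the paths read in download_and_load_models() above:
#
#   <local_dir>/
#   ├── llmweights/                      # HF-format causal LM + tokenizer files
#   ├── imagemodelweights/
#   │   └── finalcheckpoint.pth          # dict containing 'model_state_dict' for TimmCNNModel
#   └── projectorweights/
#       └── projector.pth                # plain state_dict for Projector_4to3d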
def tensor_to_pil_image(tensor):
    # Convert a (1, C, H, W) tensor in [0, 1] back to a PIL image for display
    img_tensor = tensor.squeeze(0)
    img_tensor = torch.clamp(img_tensor, 0, 1)
    img_array = img_tensor.permute(1, 2, 0).numpy()
    img_array = (img_array * 255).astype(np.uint8)
    return Image.fromarray(img_array)


def on_image_upload(image):
    if image is not None:
        return "Image received. Click 'Generate Report' to produce the report."
    else:
        return "Models are loaded. Upload an image to get started."


def describe_image(image, temperature, top_p, max_tokens, progress=gr.Progress()):
    global vlm_model, tokenizer, transform

    if vlm_model is None:
        return "Models not loaded yet. Please wait for initialization to complete.", None

    if image is None:
        return "Please upload an image.", None

    try:
        progress(0.1, desc="Starting image processing...")

        # Preprocess image
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif hasattr(image, 'convert'):
            image = image.convert('RGB')

        progress(0.3, desc="Applying image transformations...")
        image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
        processed_image = tensor_to_pil_image(image_tensor)

        progress(0.5, desc="Setting up generation parameters...")
        # Generation parameters
        generator_kwargs = {
            "max_new_tokens": int(max_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "pad_token_id": tokenizer.eos_token_id
        }

        progress(0.7, desc="Generating pathology report...")
        # Generate description
        with torch.no_grad():
            output_ids = vlm_model.generate(image_tensor, generator_kwargs)
            text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        progress(0.9, desc="Finalizing report...")
        # Strip the prompt if the model echoes it back
        if "Describe this image:" in text:
            description = text.split("Describe this image:")[-1].strip()
        else:
            description = text.strip()

        result_text = description if description else "Unable to generate description."

        progress(1.0, desc="Complete!")
        return result_text, processed_image

    except Exception as e:
        return f"Error processing image: {str(e)}", None


def reset_interface():
    return None, "Models are loaded. Upload a WSI file to get started.", None


try:
    download_and_load_models()
    initial_status = "Models are loaded. Upload a WSI file to get started."
except Exception as e:
    initial_status = f"Failed to load models: {str(e)}"
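
# Minimal programmatic usage sketch (illustrative only; "example.png" is a hypothetical
# path). describe_image returns (report_text, processed_preview_image):
#
#   report, preview = describe_image(Image.open("example.png"),
#                                    temperature=0.6, top_p=0.9, max_tokens=100)
#   print(report)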
def create_interface():
    with gr.Blocks(title="WSI Pathology Report using Gemma3n") as demo:
        gr.Markdown("# WSI Pathology Report using Gemma3n")
        gr.Markdown("Upload a pathology WSI to get a concise report")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload WSI file")

                # Generation parameters
                with gr.Row():
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.6,
                        step=0.1,
                        label="Temperature",
                        info="Lower values give more consistent results; higher values produce more creative results"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top-p",
                        info="Lower values sample from a more focused vocabulary; higher values allow a more diverse one"
                    )
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Max Tokens for generation"
                    )

                with gr.Row():
                    submit_btn = gr.Button("Generate Report", variant="primary")
                    reset_btn = gr.Button("Reset", variant="secondary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Pathology Report",
                    lines=8,
                    value=initial_status,
                    show_copy_button=True
                )
                processed_image = gr.Image(
                    label="Processed WSI",
                    show_download_button=True
                )

        image_input.change(
            fn=on_image_upload,
            inputs=[image_input],
            outputs=[output_text]
        )

        submit_btn.click(
            fn=describe_image,
            inputs=[image_input, temperature_slider, top_p_slider, max_tokens_slider],
            outputs=[output_text, processed_image],
            show_progress=True
        )

        reset_btn.click(
            fn=reset_interface,
            inputs=[],
            outputs=[image_input, output_text, processed_image]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
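
# To run locally (sketch, assuming this file is saved as app.py):
#   python app.py
# then open http://localhost:7860. share=False keeps the demo local; setting
# share=True would request a temporary public Gradio link instead.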