""" Flask adapter for Hugging Face Spaces using a merged model """ import os import sys import logging # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Set environment variables for Hugging Face merged_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "merged_model") os.environ["MODEL_DIR"] = merged_model_dir logger.info(f"Setting MODEL_DIR to: {merged_model_dir}") # Modify these variables in server.py to use merged model os.environ["USE_MERGED_MODEL"] = "TRUE" # We'll check for this in the patched server os.environ["BASE_MODEL_PATH"] = merged_model_dir # Use local path instead of HF Hub ID # Force offline mode to prevent HF Hub access attempts os.environ["HF_HUB_OFFLINE"] = "1" # Import the server module first sys.path.append(os.path.dirname(os.path.abspath(__file__))) import server # Patch the server's load_gemma_model function to use our merged model def patched_load_gemma_model(): """Patched function to load merged model""" try: logger.info("Loading merged model...") from transformers import AutoModelForCausalLM, AutoTokenizer import torch model_path = os.environ["BASE_MODEL_PATH"] # Check if model directory exists if not os.path.exists(model_path): logger.error(f"Model directory not found: {model_path}") return False # Load tokenizer logger.info(f"Loading tokenizer from {model_path}...") tokenizer = AutoTokenizer.from_pretrained( model_path, use_fast=True, local_files_only=True ) # Load model logger.info(f"Loading model from {model_path}...") model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float16, device_map="auto", local_files_only=True ) # Set model to evaluation mode model.eval() # Update global variables in server module server.model = model server.tokenizer = tokenizer server.USE_GEMMA_MODEL = True logger.info("Merged model loaded successfully!") return True except Exception as e: logger.error(f"Error loading merged model: {str(e)}") return False # Replace 
# Replace the original function with our patched version so that any code
# path in server that calls load_gemma_model uses the local merged model.
server.load_gemma_model = patched_load_gemma_model

# Initialize the server (runs server-side setup, including model loading)
from server import app, initialize_server

initialize_server()

# Configure for Hugging Face Spaces
if __name__ == "__main__":
    # Port is set by Hugging Face Spaces; default to 7860 (the Spaces
    # convention) when the variable is absent.
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}...")
    # Bind to all interfaces so the Spaces proxy can reach the app.
    app.run(host="0.0.0.0", port=port)