| """ | |
| Flask adapter for Hugging Face Spaces using a merged model | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Set environment variables for Hugging Face | |
| merged_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "merged_model") | |
| os.environ["MODEL_DIR"] = merged_model_dir | |
| logger.info(f"Setting MODEL_DIR to: {merged_model_dir}") | |
# These variables tell server.py to use the merged model
os.environ["USE_MERGED_MODEL"] = "TRUE"  # Checked by the patched server
os.environ["BASE_MODEL_PATH"] = merged_model_dir  # Local path instead of an HF Hub ID

# Force offline mode to prevent HF Hub access attempts
os.environ["HF_HUB_OFFLINE"] = "1"
# Import the server module first
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import server

# Patch the server's load_gemma_model function to use our merged model
def patched_load_gemma_model():
    """Patched function to load the merged model"""
    try:
        logger.info("Loading merged model...")
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        model_path = os.environ["BASE_MODEL_PATH"]

        # Check that the model directory exists
        if not os.path.exists(model_path):
            logger.error(f"Model directory not found: {model_path}")
            return False

        # Load tokenizer
        logger.info(f"Loading tokenizer from {model_path}...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            local_files_only=True
        )

        # Load model
        logger.info(f"Loading model from {model_path}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            local_files_only=True
        )
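        # Note: device_map="auto" requires the accelerate package; if it is not
        # installed on the Space, drop device_map and move the model explicitly
        # with model.to("cuda") or model.to("cpu") instead.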
        # Set model to evaluation mode
        model.eval()

        # Update global variables in the server module
        server.model = model
        server.tokenizer = tokenizer
        server.USE_GEMMA_MODEL = True

        logger.info("Merged model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading merged model: {str(e)}")
        return False
# Replace the original function with our patched version
server.load_gemma_model = patched_load_gemma_model
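# The patch takes effect because it is applied before initialize_server() runs
# below: any call server.py makes to load_gemma_model through its module
# globals now resolves to patched_load_gemma_model.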
# Initialize the server
from server import app, initialize_server
initialize_server()

# Configure for Hugging Face Spaces
if __name__ == "__main__":
    # Port is set by Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}...")
    app.run(host="0.0.0.0", port=port)
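
# Local smoke test (assumes a merged_model/ directory next to this file and the
# server's dependencies installed; app.py is the usual Spaces entry-point name):
#
#     python app.py
#     # -> listens on http://0.0.0.0:7860, the Hugging Face Spaces default port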