"""
Flask adapter for Hugging Face Spaces using a merged model
"""
import os
import sys
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the server at the locally bundled merged model
merged_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "merged_model")
os.environ["MODEL_DIR"] = merged_model_dir
logger.info(f"Setting MODEL_DIR to: {merged_model_dir}")

# These variables are read by the patched server to select the merged model
os.environ["USE_MERGED_MODEL"] = "TRUE"  # Checked for in the patched server
os.environ["BASE_MODEL_PATH"] = merged_model_dir  # Local path instead of an HF Hub ID

# Force offline mode to prevent Hugging Face Hub access attempts
os.environ["HF_HUB_OFFLINE"] = "1"
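
# For reference, merged_model/ is expected to hold a standard transformers
# checkpoint (file names below are typical, not verified against this repo):
#   merged_model/
#       config.json
#       tokenizer files (e.g. tokenizer_config.json, tokenizer.json)
#       model weights (e.g. model.safetensors, or sharded files plus an index)
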
# Import the server module first
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import server

# Patch the server's load_gemma_model function so it loads our merged model
def patched_load_gemma_model():
    """Patched loader that reads the merged model from local disk."""
    try:
        logger.info("Loading merged model...")
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        model_path = os.environ["BASE_MODEL_PATH"]

        # Check that the model directory exists
        if not os.path.exists(model_path):
            logger.error(f"Model directory not found: {model_path}")
            return False

        # Load tokenizer
        logger.info(f"Loading tokenizer from {model_path}...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            local_files_only=True
        )

        # Load model in half precision; device_map="auto" requires accelerate
        logger.info(f"Loading model from {model_path}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            local_files_only=True
        )

        # Set model to evaluation mode (disables dropout etc.)
        model.eval()

        # Update the global variables the server module reads at request time
        server.model = model
        server.tokenizer = tokenizer
        server.USE_GEMMA_MODEL = True

        logger.info("Merged model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Error loading merged model: {str(e)}")
        return False

# Replace the original function with our patched version
server.load_gemma_model = patched_load_gemma_model
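
# NOTE (assumption): initialize_server() in server.py is expected to call
# load_gemma_model() through the module-level name, so the reassignment
# above must happen before initialize_server() runs.
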
# Initialize the server
from server import app, initialize_server
initialize_server()

# Configure for Hugging Face Spaces
if __name__ == "__main__":
    # Port is set by Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}...")
    app.run(host="0.0.0.0", port=port)