Upload app.py
app.py
ADDED
@@ -0,0 +1,87 @@
+"""
+Flask adapter for Hugging Face Spaces using a merged model
+"""
+import os
+import sys
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Set environment variables for Hugging Face
+merged_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "merged_model")
+os.environ["MODEL_DIR"] = merged_model_dir
+logger.info(f"Setting MODEL_DIR to: {merged_model_dir}")
+
+# Modify these variables in server.py to use merged model
+os.environ["USE_MERGED_MODEL"] = "TRUE"  # We'll check for this in the patched server
+os.environ["BASE_MODEL_PATH"] = merged_model_dir  # Use local path instead of HF Hub ID
+
+# Force offline mode to prevent HF Hub access attempts
+os.environ["HF_HUB_OFFLINE"] = "1"
+
+# Import the server module first
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import server
+
+# Patch the server's load_gemma_model function to use our merged model
+def patched_load_gemma_model():
+    """Patched function to load merged model"""
+    try:
+        logger.info("Loading merged model...")
+
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        import torch
+
+        model_path = os.environ["BASE_MODEL_PATH"]
+
+        # Check if model directory exists
+        if not os.path.exists(model_path):
+            logger.error(f"Model directory not found: {model_path}")
+            return False
+
+        # Load tokenizer
+        logger.info(f"Loading tokenizer from {model_path}...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            use_fast=True,
+            local_files_only=True
+        )
+
+        # Load model
+        logger.info(f"Loading model from {model_path}...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            local_files_only=True
+        )
+
+        # Set model to evaluation mode
+        model.eval()
+
+        # Update global variables in server module
+        server.model = model
+        server.tokenizer = tokenizer
+        server.USE_GEMMA_MODEL = True
+
+        logger.info("Merged model loaded successfully!")
+        return True
+    except Exception as e:
+        logger.error(f"Error loading merged model: {str(e)}")
+        return False
+
+# Replace the original function with our patched version
+server.load_gemma_model = patched_load_gemma_model
+
+# Initialize the server
+from server import app, initialize_server
+initialize_server()
+
+# Configure for Hugging Face Spaces
+if __name__ == "__main__":
+    # Port is set by Hugging Face Spaces
+    port = int(os.environ.get("PORT", 7860))
+    logger.info(f"Starting server on port {port}...")
+    app.run(host="0.0.0.0", port=port)
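For reference, below is a minimal sketch of the server.py surface this adapter assumes, reconstructed only from the names used above (app, initialize_server, load_gemma_model, model, tokenizer, USE_GEMMA_MODEL). It is not the actual file; the /generate route and its body are hypothetical stand-ins added purely for illustration.

# Hypothetical sketch of the server.py interface that app.py patches; not the actual file.
import logging
from flask import Flask, jsonify, request

logger = logging.getLogger(__name__)
app = Flask(__name__)

# Module-level globals that the patched loader in app.py overwrites.
model = None
tokenizer = None
USE_GEMMA_MODEL = False

def load_gemma_model():
    """Original loader; app.py rebinds this name to patched_load_gemma_model."""
    return False

def initialize_server():
    """Startup hook; resolves load_gemma_model at call time, so the rebinding takes effect."""
    loaded = load_gemma_model()
    logger.info("Model loaded: %s", loaded)

@app.route("/generate", methods=["POST"])  # hypothetical endpoint, for illustration only
def generate():
    if not USE_GEMMA_MODEL:
        return jsonify({"error": "model not loaded"}), 503
    prompt = request.get_json(force=True).get("prompt", "")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=128)
    return jsonify({"text": tokenizer.decode(output_ids[0], skip_special_tokens=True)})

Under this assumption, the monkey patch works because initialize_server() looks up load_gemma_model by name when it runs, so rebinding server.load_gemma_model before calling initialize_server() is enough to route model loading through the merged-model loader.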