JayeshC137 committed (verified)
Commit a44a44f · 1 parent: 69d30bd

Upload app.py

Files changed (1): app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+"""
+Flask adapter for Hugging Face Spaces using a merged model
+"""
+import os
+import sys
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Set environment variables for Hugging Face
+merged_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "merged_model")
+os.environ["MODEL_DIR"] = merged_model_dir
+logger.info(f"Setting MODEL_DIR to: {merged_model_dir}")
+
+# Point server.py at the merged model via environment variables
+os.environ["USE_MERGED_MODEL"] = "TRUE"  # Checked for in the patched server
+os.environ["BASE_MODEL_PATH"] = merged_model_dir  # Local path instead of an HF Hub ID
+
+# Force offline mode to prevent HF Hub access attempts
+os.environ["HF_HUB_OFFLINE"] = "1"
+
+# Import the server module first
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import server
+
+# Patch the server's load_gemma_model function to use our merged model
+def patched_load_gemma_model():
+    """Patched function to load the merged model"""
+    try:
+        logger.info("Loading merged model...")
+
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        import torch
+
+        model_path = os.environ["BASE_MODEL_PATH"]
+
+        # Check that the model directory exists
+        if not os.path.exists(model_path):
+            logger.error(f"Model directory not found: {model_path}")
+            return False
+
+        # Load tokenizer
+        logger.info(f"Loading tokenizer from {model_path}...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            use_fast=True,
+            local_files_only=True
+        )
+
+        # Load model
+        logger.info(f"Loading model from {model_path}...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            local_files_only=True
+        )
+
+        # Set model to evaluation mode
+        model.eval()
+
+        # Update the global variables in the server module
+        server.model = model
+        server.tokenizer = tokenizer
+        server.USE_GEMMA_MODEL = True
+
+        logger.info("Merged model loaded successfully!")
+        return True
+    except Exception as e:
+        logger.error(f"Error loading merged model: {str(e)}")
+        return False
+
+# Replace the original function with our patched version
+server.load_gemma_model = patched_load_gemma_model
+
+# Initialize the server
+from server import app, initialize_server
+initialize_server()
+
+# Configure for Hugging Face Spaces
+if __name__ == "__main__":
+    # Port is set by Hugging Face Spaces
+    port = int(os.environ.get("PORT", 7860))
+    logger.info(f"Starting server on port {port}...")
+    app.run(host="0.0.0.0", port=port)
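
Note on the monkey-patch: app.py rebinds server.load_gemma_model before calling initialize_server(), so the patch takes effect only because server.py looks the function up through its own module namespace at call time (a plain module-level call does exactly that; a "from server import load_gemma_model" elsewhere would bypass the rebinding). The sketch below shows the interface app.py appears to assume of server.py. The names (app, initialize_server, load_gemma_model, model, tokenizer, USE_GEMMA_MODEL) come from the references in app.py above, but the bodies are hypothetical stand-ins, not the actual server code:

    # server.py -- hypothetical sketch of the interface app.py relies on.
    # Only the names are taken from app.py; the bodies are illustrative.
    import logging
    from flask import Flask

    logger = logging.getLogger(__name__)
    app = Flask(__name__)

    # Module-level globals that the patched loader in app.py overwrites.
    model = None
    tokenizer = None
    USE_GEMMA_MODEL = False

    def load_gemma_model():
        """Default loader; app.py replaces this attribute before it is called."""
        return False

    def initialize_server():
        # load_gemma_model is resolved in this module's namespace at call
        # time, so the rebinding done in app.py is what actually runs here.
        if not load_gemma_model():
            logger.warning("Model not loaded; server starting without it")

Under these assumptions the Space can be smoke-tested locally with "PORT=7860 python app.py", provided merged_model/ sits next to app.py and flask, transformers, torch, and accelerate (required by device_map="auto") are installed.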