"""
Flask adapter for Hugging Face Spaces using a merged model
"""
import os
import sys
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set environment variables for Hugging Face
merged_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "merged_model")
os.environ["MODEL_DIR"] = merged_model_dir
logger.info(f"Setting MODEL_DIR to: {merged_model_dir}")

# Environment variables read by server.py to switch to the merged model
os.environ["USE_MERGED_MODEL"] = "TRUE"  # Checked for in the patched server
os.environ["BASE_MODEL_PATH"] = merged_model_dir  # Local path instead of an HF Hub model ID

# Force offline mode to prevent HF Hub access attempts
os.environ["HF_HUB_OFFLINE"] = "1"

# Make this file's directory importable, then import the server module so it can be patched
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import server

# Patch the server's load_gemma_model function to use our merged model
def patched_load_gemma_model():
    """Patched function to load merged model"""
    try:
        logger.info("Loading merged model...")
        
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch
        
        model_path = os.environ["BASE_MODEL_PATH"]
        
        # Check if model directory exists
        if not os.path.exists(model_path):
            logger.error(f"Model directory not found: {model_path}")
            return False
            
        # Load tokenizer
        logger.info(f"Loading tokenizer from {model_path}...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            local_files_only=True
        )
        
        # Load model
        logger.info(f"Loading model from {model_path}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            local_files_only=True
        )
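        # Note: device_map="auto" requires the accelerate package to be installed in the Space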
        
        # Set model to evaluation mode (disables dropout and other training-only behaviour)
        model.eval()
        
        # Update global variables in server module
        server.model = model
        server.tokenizer = tokenizer
        server.USE_GEMMA_MODEL = True
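        # Assumes server.py's request handlers read these module-level globals at request time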
        
        logger.info("Merged model loaded successfully!")
        return True
    except Exception as e:
        # logger.exception includes the traceback, which helps when debugging in the Space logs
        logger.exception(f"Error loading merged model: {e}")
        return False

# Replace the original function with our patched version
server.load_gemma_model = patched_load_gemma_model
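# This takes effect as long as server.py looks up load_gemma_model by name at call time
# rather than holding an earlier reference to the original function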

# Initialize the server; initialize_server() is expected to call the (now patched) load_gemma_model
from server import app, initialize_server
initialize_server()

# Configure for Hugging Face Spaces
if __name__ == "__main__":
    # Port is set by Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}...")
    app.run(host="0.0.0.0", port=port)