TakiTakiTa committed
Commit 01ccf7c · verified · 1 Parent(s): 86a6ee8

Update app.py

Files changed (1)
  1. app.py +52 -41
app.py CHANGED
@@ -2,41 +2,43 @@ import gradio as gr
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+from functools import lru_cache
 
-# Global dictionary to store loaded models, keyed by model name.
-loaded_models = {}
-# Global variable to track the currently loaded model's name.
-current_model_name = ""
+# Cache the loaded model and tokenizer based on the model name.
+@lru_cache(maxsize=1)
+def get_model(model_name: str):
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    print("Cached model loaded for:", model_name)
+    return model, tokenizer
 
 @spaces.GPU
 def load_model(model_name: str):
-    global loaded_models, current_model_name
     try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto"
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        loaded_models[model_name] = (model, tokenizer)
-        current_model_name = model_name  # update global state
-        return f"Model '{model_name}' loaded successfully."
+        # Call the caching function. (This will load the model if not already cached.)
+        model, tokenizer = get_model(model_name)
+        # Print to verify caching (will show up in the logs).
+        print("Loaded model:", model_name)
+        return f"Model '{model_name}' loaded successfully.", model_name
     except Exception as e:
-        return f"Failed to load model '{model_name}': {str(e)}"
+        return f"Failed to load model '{model_name}': {str(e)}", ""
 
 @spaces.GPU
-def generate(prompt, history):
-    global loaded_models, current_model_name
-    print("loaded models: ", loaded_models)
-    print("current model: ", current_model_name)
-    if current_model_name == "" or current_model_name not in loaded_models:
-        return "Please load a model first by entering a model name and clicking the Load Model button."
-
-    model, tokenizer = loaded_models[current_model_name]
-
-    # Prepare the messages (with a system prompt and the user's prompt)
+def generate_response(prompt, chat_history, current_model_name):
+    if current_model_name == "":
+        return "Please load a model first by entering a model name and clicking the Load Model button.", current_model_name, chat_history
+    try:
+        model, tokenizer = get_model(current_model_name)
+    except Exception as e:
+        return f"Error loading model: {str(e)}", current_model_name, chat_history
+
+    # Prepare conversation messages.
     messages = [
-        {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
+        {"role": "system", "content": "You are a friendly, helpful assistant."},
         {"role": "user", "content": prompt}
     ]
     text = tokenizer.apply_chat_template(
@@ -44,38 +46,47 @@ def generate(prompt, history):
         tokenize=False,
         add_generation_prompt=True
     )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
     generated_ids = model.generate(
-        **model_inputs,
+        **inputs,
         max_new_tokens=512
    )
-    # Remove the input tokens from the generated tokens.
+    # Strip out the prompt tokens.
     generated_ids = [
-        output_ids[len(input_ids):]
-        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        output_ids[len(input_ids):]
+        for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
     ]
-
     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return response
+    chat_history.append([prompt, response])
+    return "", current_model_name, chat_history
 
-# Build the Gradio UI using Blocks.
 with gr.Blocks() as demo:
     gr.Markdown("## Model Loader")
     with gr.Row():
-        model_name_input = gr.Textbox(
+        model_input = gr.Textbox(
             label="Model Name",
             value="agentica-org/DeepScaleR-1.5B-Preview",
             placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)"
         )
     load_button = gr.Button("Load Model")
-    load_status = gr.Textbox(label="Status", interactive=False)
-
-    # When the Load Model button is clicked, load_model is called.
-    load_button.click(fn=load_model, inputs=model_name_input, outputs=load_status)
+    status_output = gr.Textbox(label="Status", interactive=False)
+    # Hidden state for the model name.
+    model_state = gr.State("")
+
+    # When the load button is clicked, update status and state.
+    load_button.click(fn=load_model, inputs=model_input, outputs=[status_output, model_state])
 
     gr.Markdown("## Chat Interface")
-    # Create the chat interface without extra_inputs.
-    chat_interface = gr.ChatInterface(fn=generate)
+    chatbot = gr.Chatbot()
+    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
+
+    def chat_submit(prompt, history, current_model_name):
+        output, updated_state, history = generate_response(prompt, history, current_model_name)
+        return "", updated_state, history
+
+    # When a prompt is submitted, clear the prompt textbox and update chat history and model state.
+    prompt_box.submit(fn=chat_submit, inputs=[prompt_box, chatbot, model_state],
+                      outputs=[prompt_box, model_state, chatbot])
 
  demo.launch(share=True)
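
For reference, a minimal self-contained sketch (not part of the commit) of the caching behaviour the new get_model relies on: functools.lru_cache(maxsize=1) keeps exactly one entry, so repeated calls with the same model name return the cached (model, tokenizer) pair, while a different name loads fresh and evicts the previous entry. The fake_load helper and the "some-org/other-model" name below are hypothetical stand-ins for the real from_pretrained calls and model IDs.

from functools import lru_cache

def fake_load(model_name: str):
    # Hypothetical stand-in for the AutoModelForCausalLM / AutoTokenizer loading in app.py;
    # this print only fires on a cache miss.
    print(f"loading {model_name} ...")
    return object(), object()  # placeholder (model, tokenizer) pair

@lru_cache(maxsize=1)
def get_model(model_name: str):
    return fake_load(model_name)

m1, t1 = get_model("agentica-org/DeepScaleR-1.5B-Preview")  # miss: loads
m2, t2 = get_model("agentica-org/DeepScaleR-1.5B-Preview")  # hit: cached objects returned
assert m1 is m2 and t1 is t2

get_model("some-org/other-model")                   # miss: evicts the first entry (maxsize=1)
get_model("agentica-org/DeepScaleR-1.5B-Preview")   # miss again: was evicted
print(get_model.cache_info())  # CacheInfo(hits=1, misses=3, maxsize=1, currsize=1)

Because the cache is keyed on the model name alone, the app only needs to pass that string through gr.State; presumably that is why the commit drops the module-level globals and lets get_model decide whether an actual reload is needed.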