MisterAI commited on
Commit
c8d3212
·
verified ·
1 Parent(s): a523a29

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #MisterAI/Docker_Ollama
2
+ #app.py_02
3
+ #https://huggingface.co/spaces/MisterAI/Docker_Ollama/
4
+
5
+ import logging
6
+ import requests
7
+ from pydantic import BaseModel
8
+ from langchain_community.llms import Ollama
9
+ from langchain.callbacks.manager import CallbackManager
10
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
11
+ import gradio as gr
12
+ import threading
13
+ import subprocess
14
+ from bs4 import BeautifulSoup
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Cache pour stocker les modèles déjà chargés
20
+ loaded_models = {}
21
+
22
+ # Variable pour suivre l'état du bouton "Stop"
23
+ stop_flag = False
24
+
25
+ def get_model_list():
26
+ url = "https://ollama.com/search"
27
+ response = requests.get(url)
28
+
29
+ # Vérifier si la requête a réussi
30
+ if response.status_code == 200:
31
+ # Utiliser BeautifulSoup pour analyser le HTML
32
+ soup = BeautifulSoup(response.text, 'html.parser')
33
+ model_list = []
34
+
35
+ # Trouver tous les éléments de modèle
36
+ model_elements = soup.find_all('li', {'x-test-model': True})
37
+
38
+ for model_element in model_elements:
39
+ model_name = model_element.find('span', {'x-test-search-response-title': True}).text.strip()
40
+ size_elements = model_element.find_all('span', {'x-test-size': True})
41
+
42
+ # Filtrer les modèles par taille
43
+ for size_element in size_elements:
44
+ size = size_element.text.strip()
45
+ if size.endswith('m'):
46
+ # Tous les modèles en millions sont acceptés
47
+ model_list.append(f"{model_name}:{size}")
48
+ elif size.endswith('b'):
49
+ # Convertir les modèles en milliards en milliards
50
+ size_value = float(size[:-1])
51
+ if size_value <= 10: # Filtrer les modèles <= 10 milliards de paramètres
52
+ model_list.append(f"{model_name}:{size}")
53
+
54
+ return model_list
55
+ else:
56
+ logger.error(f"Erreur lors de la récupération de la liste des modèles : {response.status_code} - {response.text}")
57
+ return []
58
+
59
+ def get_llm(model_name):
60
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
61
+ return Ollama(model=model_name, callback_manager=callback_manager)
62
+
63
+ class InputData(BaseModel):
64
+ model_name: str
65
+ input: str
66
+ max_tokens: int = 256
67
+ temperature: float = 0.7
68
+
69
+ def pull_model(model_name):
70
+ try:
71
+ # Exécuter la commande pour tirer le modèle
72
+ subprocess.run(["ollama", "pull", model_name], check=True)
73
+ logger.info(f"Model {model_name} pulled successfully.")
74
+ except subprocess.CalledProcessError as e:
75
+ logger.error(f"Failed to pull model {model_name}: {e}")
76
+ raise
77
+
78
+ def check_and_load_model(model_name):
79
+ # Vérifier si le modèle est déjà chargé
80
+ if model_name in loaded_models:
81
+ logger.info(f"Model {model_name} is already loaded.")
82
+ return loaded_models[model_name]
83
+ else:
84
+ logger.info(f"Loading model {model_name}...")
85
+ # Tirer le modèle si nécessaire
86
+ pull_model(model_name)
87
+ llm = get_llm(model_name)
88
+ loaded_models[model_name] = llm
89
+ return llm
90
+
91
+ # Interface Gradio
92
+ def gradio_interface(model_name, input, max_tokens, temperature, stop_button=None):
93
+ global stop_flag
94
+ stop_flag = False
95
+ response = None # Initialisez la variable response ici
96
+
97
+ def worker():
98
+ nonlocal response # Utilisez nonlocal pour accéder à la variable response définie dans la fonction parente
99
+ llm = check_and_load_model(model_name)
100
+ response = llm(input, max_tokens=max_tokens, temperature=temperature)
101
+
102
+ thread = threading.Thread(target=worker)
103
+ thread.start()
104
+ thread.join()
105
+
106
+ if stop_flag:
107
+ return "Processing stopped by the user."
108
+ else:
109
+ return response # Maintenant, response est accessible ici
110
+
111
+ model_list = get_model_list()
112
+
113
+ demo = gr.Interface(
114
+ fn=gradio_interface,
115
+ inputs=[
116
+ gr.Dropdown(model_list, label="Select Model", value="mistral:7b"),
117
+ gr.Textbox(label="Input"),
118
+ gr.Slider(minimum=1, maximum=2048, step=1, label="Max Tokens", value=256),
119
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7),
120
+ gr.Button(value="Stop", variant="stop")
121
+ ],
122
+ outputs=[
123
+ gr.Textbox(label="Output")
124
+ ],
125
+ title="Ollama Demo"
126
+ )
127
+
128
+ def stop_processing():
129
+ global stop_flag
130
+ stop_flag = True
131
+
132
+ if __name__ == "__main__":
133
+ demo.launch(server_name="0.0.0.0", server_port=7860)
134
+