Gggggg / app.py
Uhhy's picture
Update app.py
3c88fa1 verified
from dotenv import load_dotenv
import os
import json
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
AutoModelForTextToWaveform
)
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
import multiprocessing
import uuid
import numpy as np
from diffusers import FluxPipeline
from tqdm import tqdm
from google.cloud import storage
import io
import spaces
spaces.GPU(duration=0)
load_dotenv()
app = FastAPI()
default_language = "es"
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
if GCS_BUCKET_NAME is None:
raise ValueError("La variable de entorno GCS_BUCKET_NAME no está definida.")
GCS_CREDENTIALS = os.getenv("GCS_CREDENTIALS")
if GCS_CREDENTIALS is None:
raise ValueError("La variable de entorno GCS_CREDENTIALS no está definida.")
gcs_credentials_dict = json.loads(GCS_CREDENTIALS)
with open('gcs_credentials.json', 'w') as f:
json.dump(gcs_credentials_dict, f)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gcs_credentials.json"
storage_client = storage.Client()
bucket = storage_client.bucket(GCS_BUCKET_NAME)
AutoTokenizer.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
AutoModelForCausalLM.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
AutoTokenizer.from_pretrained("facebook/musicgen-small", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
class ChatbotService:
def __init__(self):
self.model_name = "response_model"
self.tokenizer_name = "response_tokenizer"
self.model = self.load_model()
self.tokenizer = self.load_tokenizer()
def get_response(self, user_id, message, language=default_language):
if self.model is None or self.tokenizer is None:
return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
input_text = f"Usuario: {message} Asistente:"
input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cuda")
with torch.no_grad():
output = self.model.generate(input_ids=input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2,
early_stopping=True)
response = self.tokenizer.decode(output[0], skip_special_tokens=True)
response = response.replace(input_text, "").strip()
return response
def load_model(self):
model_path = f"gs://{GCS_BUCKET_NAME}/model_{self.model_name}"
if bucket.blob(f"model_{self.model_name}").exists():
blob = bucket.blob(f"model_{self.model_name}")
model_bytes = blob.download_as_bytes()
model_buffer = io.BytesIO(model_bytes)
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.load_state_dict(torch.load(model_buffer, map_location=torch.device("cuda")))
return model
return None
def load_tokenizer(self):
tokenizer_path = f"gs://{GCS_BUCKET_NAME}/tokenizer_{self.tokenizer_name}.json"
if bucket.blob(f"tokenizer_{self.tokenizer_name}.json").exists():
blob = bucket.blob(f"tokenizer_{self.tokenizer_name}.json")
tokenizer_bytes = blob.download_as_bytes()
tokenizer_data = json.loads(tokenizer_bytes)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
existing_tokens = tokenizer.get_vocab()
new_tokens = tokenizer_data
for token, id in new_tokens.items():
if token not in existing_tokens:
tokenizer.add_tokens([token])
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
return None
chatbot_service = ChatbotService()
class UnifiedModel(AutoModelForSequenceClassification):
def __init__(self, config):
super().__init__(config)
@staticmethod
def load_model():
model_name = "unified_model"
model_path = f"gs://{GCS_BUCKET_NAME}/model_{model_name}"
if bucket.blob(f"model_{model_name}").exists():
blob = bucket.blob(f"model_{model_name}")
model_bytes = blob.download_as_bytes()
model_buffer = io.BytesIO(model_bytes)
model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
model.load_state_dict(torch.load(model_buffer, map_location=torch.device("cuda")))
return model
else:
model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
model_buffer = io.BytesIO()
torch.save(model.state_dict(), model_buffer)
model_buffer.seek(0)
blob = bucket.blob(f"model_{model_name}")
blob.upload_from_file(model_buffer, content_type="application/octet-stream")
return model
class SyntheticDataset(torch.utils.data.Dataset):
def __init__(self, tokenizer, data):
self.tokenizer = tokenizer
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
text = item['text']
label = item['label']
tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
return {"input_ids": tokens["input_ids"].squeeze(), "attention_mask": tokens["attention_mask"].squeeze(),
"labels": label}
conversation_history = {}
tokenizer_name = "unified_tokenizer"
tokenizer = None
unified_model = None
image_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16,
cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
image_pipeline.enable_model_cpu_offload()
musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")
musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")
@app.on_event("startup")
async def startup_event():
global tokenizer, unified_model
tokenizer_path = f"gs://{GCS_BUCKET_NAME}/tokenizer_{tokenizer_name}.json"
if bucket.blob(f"tokenizer_{tokenizer_name}.json").exists():
blob = bucket.blob(f"tokenizer_{tokenizer_name}.json")
tokenizer_bytes = blob.download_as_bytes()
tokenizer_data = json.loads(tokenizer_bytes)
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
existing_tokens = tokenizer.get_vocab()
new_tokens = tokenizer_data
for token, id in new_tokens.items():
if token not in existing_tokens:
tokenizer.add_tokens([token])
tokenizer.pad_token = tokenizer.eos_token
else:
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
tokenizer.pad_token = tokenizer.eos_token
unified_model = UnifiedModel.load_model()
unified_model.to(torch.device("cuda"))
@app.post("/process")
async def process(request: Request):
global tokenizer, unified_model
data = await request.json()
if data.get("train"):
user_data = data.get("user_data", [])
if not user_data:
user_data = [
{"text": "Hola", "label": 1},
{"text": "Necesito ayuda", "label": 2},
{"text": "No entiendo", "label": 0}
]
training_queue_path = f"gs://{GCS_BUCKET_NAME}/training_queue.json"
if bucket.blob("training_queue.json").exists():
blob = bucket.blob("training_queue.json")
training_queue_bytes = blob.download_as_bytes()
existing_data = json.loads(training_queue_bytes)
else:
existing_data = []
new_data = existing_data + [{
"tokenizers": {tokenizer_name: tokenizer.get_vocab()},
"data": user_data
}]
new_data_bytes = json.dumps(new_data).encode("utf-8")
blob = bucket.blob("training_queue.json")
blob.upload_from_string(new_data_bytes, content_type="application/json")
return {"message": "Training data received. Model will be updated asynchronously."}
elif data.get("message"):
user_id = data.get("user_id")
text = data['message']
language = data.get("language", default_language)
if user_id not in conversation_history:
conversation_history[user_id] = []
conversation_history[user_id].append(text)
contextualized_text = " ".join(conversation_history[user_id][-3:])
tokenized_input = tokenizer(contextualized_text, return_tensors="pt")
with torch.no_grad():
logits = unified_model(**tokenized_input).logits
predicted_class = torch.argmax(logits, dim=-1).item()
response = chatbot_service.get_response(user_id, contextualized_text, language)
training_queue_path = f"gs://{GCS_BUCKET_NAME}/training_queue.json"
if bucket.blob("training_queue.json").exists():
blob = bucket.blob("training_queue.json")
training_queue_bytes = blob.download_as_bytes()
existing_data = json.loads(training_queue_bytes)
else:
existing_data = []
new_data = existing_data + [{
"tokenizers": {tokenizer_name: tokenizer.get_vocab()},
"data": [{"text": contextualized_text, "label": predicted_class}]
}]
new_data_bytes = json.dumps(new_data).encode("utf-8")
blob = bucket.blob("training_queue.json")
blob.upload_from_string(new_data_bytes, content_type="application/json")
return {"answer": response}
else:
raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
@app.get("/")
async def get_home():
user_id = str(uuid.uuid4())
html_code = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Chatbot</title>
<style>
body {{
font-family: 'Arial', sans-serif;
background-color: #f4f4f9;
margin: 0;
padding: 0;
display: flex;
align-items: center;
justify-content: center;
min-height: 100vh;
}}
.container {{
background-color: #fff;
border-radius: 10px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
overflow: hidden;
width: 400px;
max-width: 90%;
}}
h1 {{
color: #333;
text-align: center;
padding: 20px;
margin: 0;
background-color: #f8f9fa;
border-bottom: 1px solid #eee;
}}
#chatbox {{
height: 300px;
overflow-y: auto;
padding: 10px;
border-bottom: 1px solid #eee;
}}
.message {{
margin-bottom: 10px;
padding: 10px;
border-radius: 5px;
}}
.message.user {{
background-color: #e1f5fe;
text-align: right;
}}
.message.bot {{
background-color: #f1f1f1;
text-align: left;
}}
#input {{
display: flex;
padding: 10px;
}}
#input textarea {{
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
margin-right: 10px;
}}
#input button {{
padding: 10px 20px;
border: none;
border-radius: 4px;
background-color: #007bff;
color: #fff;
cursor: pointer;
}}
#input button:hover {{
background-color: #0056b3;
}}
</style>
</head>
<body>
<div class="container">
<h1>Chatbot</h1>
<div id="chatbox"></div>
<div id="input">
<textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
<button id="send">Enviar</button>
</div>
</div>
<script>
const chatbox = document.getElementById('chatbox');
const messageInput = document.getElementById('message');
const sendButton = document.getElementById('send');
function appendMessage(text, sender) {{
const messageDiv = document.createElement('div');
messageDiv.classList.add('message', sender);
messageDiv.textContent = text;
chatbox.appendChild(messageDiv);
chatbox.scrollTop = chatbox.scrollHeight;
}}
async function sendMessage() {{
const message = messageInput.value;
if (!message.trim()) return;
appendMessage(message, 'user');
messageInput.value = '';
const response = await fetch('/process', {{
method: 'POST',
headers: {{
'Content-Type': 'application/json'
}},
body: JSON.stringify({{
message: message,
user_id: '{user_id}'
}})
}});
const data = await response.json();
appendMessage(data.answer, 'bot');
}}
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keypress', (e) => {{
if (e.key === 'Enter' && !e.shiftKey) {{
e.preventDefault();
sendMessage();
}}
}});
</script>
</body>
</html>
"""
return HTMLResponse(content=html_code)
@spaces.GPU
def my_inference_function(input_data, output_data, mode, max_length, max_new_tokens, model_size):
print("xd")
# Add your inference logic here
# ...
def train_unified_model():
global tokenizer, unified_model
model_name = "unified_model"
model_path = f"gs://{GCS_BUCKET_NAME}/model_{model_name}"
training_args = TrainingArguments(
output_dir=f"gs://{GCS_BUCKET_NAME}/results",
per_device_train_batch_size=8,
num_train_epochs=3,
)
while True:
training_queue_path = f"gs://{GCS_BUCKET_NAME}/training_queue.json"
if bucket.blob("training_queue.json").exists():
blob = bucket.blob("training_queue.json")
training_queue_bytes = blob.download_as_bytes()
training_data_list = json.loads(training_queue_bytes)
if training_data_list:
training_data = training_data_list.pop(0)
new_data_bytes = json.dumps(training_data_list).encode("utf-8")
blob = bucket.blob("training_queue.json")
blob.upload_from_string(new_data_bytes, content_type="application/json")
tokenizer_data = training_data.get("tokenizers")
if tokenizer_data:
tokenizer_name = list(tokenizer_data.keys())[0]
existing_tokens = tokenizer.get_vocab()
new_tokens = tokenizer_data[tokenizer_name]
for token, id in new_tokens.items():
if token not in existing_tokens:
tokenizer.add_tokens([token])
data = training_data.get("data", [])
if data:
dataset = SyntheticDataset(tokenizer, data)
trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
trainer.train()
model_buffer = io.BytesIO()
torch.save(unified_model.state_dict(), model_buffer)
model_buffer.seek(0)
blob = bucket.blob(f"model_{model_name}")
blob.upload_from_file(model_buffer, content_type="application/octet-stream")
new_tokenizer_bytes = json.dumps(tokenizer.get_vocab()).encode("utf-8")
blob = bucket.blob(f"tokenizer_{tokenizer_name}.json")
blob.upload_from_string(new_tokenizer_bytes, content_type="application/json")
initial_data_path = f"gs://{GCS_BUCKET_NAME}/initial_data.json"
if bucket.blob("initial_data.json").exists():
blob = bucket.blob("initial_data.json")
initial_data_bytes = blob.download_as_bytes()
initial_data = json.loads(initial_data_bytes)
dataset = SyntheticDataset(tokenizer, initial_data)
trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
trainer.train()
model_buffer = io.BytesIO()
torch.save(unified_model.state_dict(), model_buffer)
model_buffer.seek(0)
blob = bucket.blob(f"model_{model_name}")
blob.upload_from_file(model_buffer, content_type="application/octet-stream")
def train_text_model():
global tokenizer, unified_model
model_name = "text_model"
model_path = f"gs://{GCS_BUCKET_NAME}/model_{model_name}"
training_args = TrainingArguments(
output_dir=f"gs://{GCS_BUCKET_NAME}/results",
per_device_train_batch_size=8,
num_train_epochs=3,
)
while True:
training_queue_path = f"gs://{GCS_BUCKET_NAME}/training_queue.json"
if bucket.blob("training_queue.json").exists():
blob = bucket.blob("training_queue.json")
training_queue_bytes = blob.download_as_bytes()
training_data_list = json.loads(training_queue_bytes)
if training_data_list:
training_data = training_data_list.pop(0)
new_data_bytes = json.dumps(training_data_list).encode("utf-8")
blob = bucket.blob("training_queue.json")
blob.upload_from_string(new_data_bytes, content_type="application/json")
tokenizer_data = training_data.get("tokenizers")
if tokenizer_data:
tokenizer_name = list(tokenizer_data.keys())[0]
existing_tokens = tokenizer.get_vocab()
new_tokens = tokenizer_data[tokenizer_name]
for token, id in new_tokens.items():
if token not in existing_tokens:
tokenizer.add_tokens([token])
data = training_data.get("data", [])
if data:
dataset = SyntheticDataset(tokenizer, data)
trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
trainer.train()
model_buffer = io.BytesIO()
torch.save(unified_model.state_dict(), model_buffer)
model_buffer.seek(0)
blob = bucket.blob(f"model_{model_name}")
blob.upload_from_file(model_buffer, content_type="application/octet-stream")
new_tokenizer_bytes = json.dumps(tokenizer.get_vocab()).encode("utf-8")
blob = bucket.blob(f"tokenizer_{tokenizer_name}.json")
blob.upload_from_string(new_tokenizer_bytes, content_type="application/json")
initial_data_path = f"gs://{GCS_BUCKET_NAME}/initial_data.json"
if bucket.blob("initial_data.json").exists():
blob = bucket.blob("initial_data.json")
initial_data_bytes = blob.download_as_bytes()
initial_data = json.loads(initial_data_bytes)
dataset = SyntheticDataset(tokenizer, initial_data)
trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
trainer.train()
model_buffer = io.BytesIO()
torch.save(unified_model.state_dict(), model_buffer)
model_buffer.seek(0)
blob = bucket.blob(f"model_{model_name}")
blob.upload_from_file(model_buffer, content_type="application/octet-stream")
def train_image_model():
global image_pipeline
while True:
image_training_queue_path = f"gs://{GCS_BUCKET_NAME}/image_training_queue.json"
if bucket.blob("image_training_queue.json").exists():
blob = bucket.blob("image_training_queue.json")
image_training_queue_bytes = blob.download_as_bytes()
image_training_data_list = json.loads(image_training_queue_bytes)
if image_training_data_list:
image_training_data = image_training_data_list.pop(0)
new_data_bytes = json.dumps(image_training_data_list).encode("utf-8")
blob = bucket.blob("image_training_queue.json")
blob.upload_from_string(new_data_bytes, content_type="application/json")
image_pipeline.model.to("cuda")
image_pipeline.model.train()
optimizer = torch.optim.Adam(image_pipeline.model.parameters(), lr=1e-5)
loss_fn = torch.nn.MSELoss()
for epoch in range(3):
for i in tqdm(range(len(image_training_data)), desc=f"Epoch {epoch+1}"):
image_prompt = image_training_data[i]
image = image_pipeline(
image_prompt,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
generator=torch.Generator("cuda").manual_seed(0)
).images[0]
image_tensor = torch.tensor(np.array(image)).unsqueeze(0).to("cuda")
target_tensor = torch.zeros_like(image_tensor)
outputs = image_pipeline.model(image_tensor)
loss = loss_fn(outputs, target_tensor)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Step {i+1}/{len(image_training_data)}, Loss: {loss.item()}")
def train_music_model():
global musicgen_tokenizer, musicgen_model
while True:
music_training_queue_path = f"gs://{GCS_BUCKET_NAME}/music_training_queue.json"
if bucket.blob("music_training_queue.json").exists():
blob = bucket.blob("music_training_queue.json")
music_training_queue_bytes = blob.download_as_bytes()
music_training_data_list = json.loads(music_training_queue_bytes)
if music_training_data_list:
music_training_data = music_training_data_list.pop(0)
new_data_bytes = json.dumps(music_training_data_list).encode("utf-8")
blob = bucket.blob("music_training_queue.json")
blob.upload_from_string(new_data_bytes, content_type="application/json")
inputs = musicgen_tokenizer(music_training_data, return_tensors="pt", padding=True).to("cuda")
musicgen_model.to("cuda")
musicgen_model.train()
optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(3):
for i in tqdm(range(len(inputs["input_ids"])), desc=f"Epoch {epoch+1}"):
outputs = musicgen_model(**inputs)
loss = loss_fn(outputs.logits, inputs['labels'])
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Step {i+1}/{len(inputs['input_ids'])}, Loss: {loss.item()}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
print("Iniciando entrenamiento automático del modelo unificado...")
auto_learn_process = multiprocessing.Process(target=train_unified_model)
auto_learn_process.start()
print("Iniciando entrenamiento automático del modelo de texto...")
auto_learn_process_2 = multiprocessing.Process(target=train_text_model)
auto_learn_process_2.start()
print("Iniciando entrenamiento automático del modelo de imagen...")
auto_learn_process_3 = multiprocessing.Process(target=train_image_model)
auto_learn_process_3.start()
print("Iniciando entrenamiento automático del modelo de música...")
auto_learn_process_4 = multiprocessing.Process(target=train_music_model)
auto_learn_process_4.start()