from dotenv import load_dotenv
import os
import json
import io
import time
import uuid
import multiprocessing

import numpy as np
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    AutoModelForTextToWaveform
)
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from diffusers import FluxPipeline
from google.cloud import storage
import spaces

# spaces.GPU is normally used as a decorator; this module-level call is kept
# from the original code and does nothing beyond touching the ZeroGPU helper.
spaces.GPU(duration=0)
load_dotenv()

app = FastAPI()

default_language = "es"
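
# GCS configuration: the bucket name and service-account JSON come from the
# environment; the credentials are written to disk so the google-cloud-storage
# client can pick them up via GOOGLE_APPLICATION_CREDENTIALS.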
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
if GCS_BUCKET_NAME is None:
    raise ValueError("The GCS_BUCKET_NAME environment variable is not set.")

GCS_CREDENTIALS = os.getenv("GCS_CREDENTIALS")
if GCS_CREDENTIALS is None:
    raise ValueError("The GCS_CREDENTIALS environment variable is not set.")

gcs_credentials_dict = json.loads(GCS_CREDENTIALS)
with open('gcs_credentials.json', 'w') as f:
    json.dump(gcs_credentials_dict, f)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gcs_credentials.json"

storage_client = storage.Client()
bucket = storage_client.bucket(GCS_BUCKET_NAME)
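
# transformers and diffusers expect a local filesystem path for cache_dir; a
# gs:// URI is not supported by either library and would only be created as an
# oddly named local folder, so a plain local cache directory (CACHE_DIR, a
# name introduced here) is used instead.
CACHE_DIR = "cache"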

# Warm the cache at startup so the first requests do not block on model
# downloads; the returned objects are discarded on purpose.
AutoTokenizer.from_pretrained("gpt2", cache_dir=CACHE_DIR)
AutoModelForCausalLM.from_pretrained("gpt2", cache_dir=CACHE_DIR)
FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, cache_dir=CACHE_DIR)
AutoTokenizer.from_pretrained("facebook/musicgen-small", cache_dir=CACHE_DIR)
AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small", cache_dir=CACHE_DIR)
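
# ChatbotService generates chat replies with a GPT-2 checkpoint. The model and
# tokenizer are restored from GCS blobs when they exist; until then the
# service reports itself as not ready instead of failing.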
class ChatbotService:
    def __init__(self):
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model()
        self.tokenizer = self.load_tokenizer()

    def get_response(self, user_id, message, language=default_language):
        if self.model is None or self.tokenizer is None:
            return "The model is not ready yet. Please try again later."
        input_text = f"User: {message} Assistant:"
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output = self.model.generate(input_ids=input_ids, max_length=100, num_beams=5,
                                         no_repeat_ngram_size=2, early_stopping=True)
        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
        response = response.replace(input_text, "").strip()
        return response

    def load_model(self):
        blob_name = f"model_{self.model_name}"
        if bucket.blob(blob_name).exists():
            model_bytes = bucket.blob(blob_name).download_as_bytes()
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            model.load_state_dict(torch.load(io.BytesIO(model_bytes), map_location="cpu"))
            # load_state_dict copies tensors into the existing CPU parameters,
            # so the model has to be moved to the GPU explicitly.
            model.to("cuda")
            return model
        return None

    def load_tokenizer(self):
        blob_name = f"tokenizer_{self.tokenizer_name}.json"
        if bucket.blob(blob_name).exists():
            tokenizer_data = json.loads(bucket.blob(blob_name).download_as_bytes())
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            # Restore any saved tokens missing from the base vocabulary.
            # Caveat: the model embeddings are not resized here, so added
            # tokens are only safe once resize_token_embeddings is called.
            new_tokens = [tok for tok in tokenizer_data if tok not in tokenizer.get_vocab()]
            if new_tokens:
                tokenizer.add_tokens(new_tokens)
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        return None


chatbot_service = ChatbotService()
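
# UnifiedModel is a three-label sequence classifier persisted in GCS. Note
# that AutoModelForSequenceClassification is a factory class, so calling
# from_pretrained on this subclass still returns the concrete architecture
# (GPT2ForSequenceClassification for "gpt2"); the subclass mainly namespaces
# load_model.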
class UnifiedModel(AutoModelForSequenceClassification):
    @staticmethod
    def load_model():
        model_name = "unified_model"
        blob_name = f"model_{model_name}"
        if bucket.blob(blob_name).exists():
            model_bytes = bucket.blob(blob_name).download_as_bytes()
            model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
            # GPT-2 has no pad token by default; sequence classification needs
            # one to locate the last token in padded batches.
            model.config.pad_token_id = model.config.eos_token_id
            model.load_state_dict(torch.load(io.BytesIO(model_bytes), map_location="cpu"))
            return model
        else:
            # First run: initialise from the base checkpoint and persist it.
            model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
            model.config.pad_token_id = model.config.eos_token_id
            model_buffer = io.BytesIO()
            torch.save(model.state_dict(), model_buffer)
            model_buffer.seek(0)
            bucket.blob(blob_name).upload_from_file(model_buffer, content_type="application/octet-stream")
            return model


class SyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, data):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        tokens = self.tokenizer(item['text'], padding="max_length", truncation=True,
                                max_length=128, return_tensors="pt")
        return {"input_ids": tokens["input_ids"].squeeze(),
                "attention_mask": tokens["attention_mask"].squeeze(),
                "labels": item['label']}
conversation_history = {}

tokenizer_name = "unified_tokenizer"
tokenizer = None
unified_model = None
image_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
                                              torch_dtype=torch.bfloat16, cache_dir=CACHE_DIR)
image_pipeline.enable_model_cpu_offload()
musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")
musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")
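
# On startup, restore the unified tokenizer vocabulary from GCS when a saved
# copy exists, then load the classifier and move it to the GPU.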
@app.on_event("startup")
async def startup_event():
    global tokenizer, unified_model
    tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=CACHE_DIR)
    blob_name = f"tokenizer_{tokenizer_name}.json"
    if bucket.blob(blob_name).exists():
        tokenizer_data = json.loads(bucket.blob(blob_name).download_as_bytes())
        new_tokens = [tok for tok in tokenizer_data if tok not in tokenizer.get_vocab()]
        if new_tokens:
            tokenizer.add_tokens(new_tokens)
    tokenizer.pad_token = tokenizer.eos_token
    unified_model = UnifiedModel.load_model()
    unified_model.to(torch.device("cuda"))
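
# POST /process handles two request shapes: {"train": true, "user_data": [...]}
# queues labelled examples for the background trainers, while {"message": ...,
# "user_id": ...} classifies the recent conversation, generates a reply, and
# queues the exchange as fresh training data.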
@app.post("/process")
async def process(request: Request):
    global tokenizer, unified_model
    data = await request.json()

    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            # Fall back to a tiny seed set when the caller sends no examples.
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0}
            ]
        if bucket.blob("training_queue.json").exists():
            existing_data = json.loads(bucket.blob("training_queue.json").download_as_bytes())
        else:
            existing_data = []
        new_data = existing_data + [{
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data
        }]
        bucket.blob("training_queue.json").upload_from_string(
            json.dumps(new_data), content_type="application/json")
        return {"message": "Training data received. Model will be updated asynchronously."}
    elif data.get("message"):
        user_id = data.get("user_id")
        text = data['message']
        language = data.get("language", default_language)
        if user_id not in conversation_history:
            conversation_history[user_id] = []
        conversation_history[user_id].append(text)
        # Use the last three turns as context for classification and generation.
        contextualized_text = " ".join(conversation_history[user_id][-3:])
        tokenized_input = tokenizer(contextualized_text, return_tensors="pt").to(unified_model.device)
        with torch.no_grad():
            logits = unified_model(**tokenized_input).logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        response = chatbot_service.get_response(user_id, contextualized_text, language)
        if bucket.blob("training_queue.json").exists():
            existing_data = json.loads(bucket.blob("training_queue.json").download_as_bytes())
        else:
            existing_data = []
        new_data = existing_data + [{
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": [{"text": contextualized_text, "label": predicted_class}]
        }]
        bucket.blob("training_queue.json").upload_from_string(
            json.dumps(new_data), content_type="application/json")
        return {"answer": response}
    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
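
# GET / serves a minimal single-page chat UI; every page load is assigned a
# fresh UUID so conversation history is tracked per browser session.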
@app.get("/")
async def get_home():
    user_id = str(uuid.uuid4())
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Chatbot</title>
        <style>
            body {{
                font-family: 'Arial', sans-serif;
                background-color: #f4f4f9;
                margin: 0;
                padding: 0;
                display: flex;
                align-items: center;
                justify-content: center;
                min-height: 100vh;
            }}
            .container {{
                background-color: #fff;
                border-radius: 10px;
                box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
                overflow: hidden;
                width: 400px;
                max-width: 90%;
            }}
            h1 {{
                color: #333;
                text-align: center;
                padding: 20px;
                margin: 0;
                background-color: #f8f9fa;
                border-bottom: 1px solid #eee;
            }}
            #chatbox {{
                height: 300px;
                overflow-y: auto;
                padding: 10px;
                border-bottom: 1px solid #eee;
            }}
            .message {{
                margin-bottom: 10px;
                padding: 10px;
                border-radius: 5px;
            }}
            .message.user {{
                background-color: #e1f5fe;
                text-align: right;
            }}
            .message.bot {{
                background-color: #f1f1f1;
                text-align: left;
            }}
            #input {{
                display: flex;
                padding: 10px;
            }}
            #input textarea {{
                flex: 1;
                padding: 10px;
                border: 1px solid #ddd;
                border-radius: 4px;
                margin-right: 10px;
            }}
            #input button {{
                padding: 10px 20px;
                border: none;
                border-radius: 4px;
                background-color: #007bff;
                color: #fff;
                cursor: pointer;
            }}
            #input button:hover {{
                background-color: #0056b3;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Chatbot</h1>
            <div id="chatbox"></div>
            <div id="input">
                <textarea id="message" rows="3" placeholder="Type your message here..."></textarea>
                <button id="send">Send</button>
            </div>
        </div>
        <script>
            const chatbox = document.getElementById('chatbox');
            const messageInput = document.getElementById('message');
            const sendButton = document.getElementById('send');

            function appendMessage(text, sender) {{
                const messageDiv = document.createElement('div');
                messageDiv.classList.add('message', sender);
                messageDiv.textContent = text;
                chatbox.appendChild(messageDiv);
                chatbox.scrollTop = chatbox.scrollHeight;
            }}

            async function sendMessage() {{
                const message = messageInput.value;
                if (!message.trim()) return;

                appendMessage(message, 'user');
                messageInput.value = '';

                const response = await fetch('/process', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json'
                    }},
                    body: JSON.stringify({{
                        message: message,
                        user_id: '{user_id}'
                    }})
                }});
                const data = await response.json();
                appendMessage(data.answer, 'bot');
            }}

            sendButton.addEventListener('click', sendMessage);
            messageInput.addEventListener('keypress', (e) => {{
                if (e.key === 'Enter' && !e.shiftKey) {{
                    e.preventDefault();
                    sendMessage();
                }}
            }});
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code)
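
# Placeholder inference hook. It appears to exist only so the module registers
# a @spaces.GPU entry point with the Hugging Face ZeroGPU runtime; the real
# work happens in the endpoints above.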
@spaces.GPU
def my_inference_function(input_data, output_data, mode, max_length, max_new_tokens, model_size):
    pass
def train_unified_model():
    global tokenizer, unified_model
    model_name = "unified_model"
    training_args = TrainingArguments(
        # TrainingArguments needs a local path; results can be synced to GCS
        # separately if desired.
        output_dir="./results",
        per_device_train_batch_size=8,
        num_train_epochs=3,
    )
    # Initialise this process's own copies (see the note above).
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=CACHE_DIR)
        tokenizer.pad_token = tokenizer.eos_token
    if unified_model is None:
        unified_model = UnifiedModel.load_model()
    while True:
        if bucket.blob("training_queue.json").exists():
            training_data_list = json.loads(bucket.blob("training_queue.json").download_as_bytes())
            if training_data_list:
                # Pop the oldest job and write the shortened queue back.
                training_data = training_data_list.pop(0)
                bucket.blob("training_queue.json").upload_from_string(
                    json.dumps(training_data_list), content_type="application/json")

                tokenizer_data = training_data.get("tokenizers")
                if tokenizer_data:
                    queued_name = list(tokenizer_data.keys())[0]
                    new_tokens = [tok for tok in tokenizer_data[queued_name]
                                  if tok not in tokenizer.get_vocab()]
                    if new_tokens:
                        tokenizer.add_tokens(new_tokens)
                        # Keep the embedding matrix in sync with the vocabulary.
                        unified_model.resize_token_embeddings(len(tokenizer))
                data = training_data.get("data", [])
                if data:
                    dataset = SyntheticDataset(tokenizer, data)
                    trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
                    trainer.train()
                    model_buffer = io.BytesIO()
                    torch.save(unified_model.state_dict(), model_buffer)
                    model_buffer.seek(0)
                    bucket.blob(f"model_{model_name}").upload_from_file(
                        model_buffer, content_type="application/octet-stream")
                    bucket.blob(f"tokenizer_{tokenizer_name}.json").upload_from_string(
                        json.dumps(tokenizer.get_vocab()), content_type="application/json")

        # While initial_data.json exists, every cycle also retrains on it.
        if bucket.blob("initial_data.json").exists():
            initial_data = json.loads(bucket.blob("initial_data.json").download_as_bytes())
            dataset = SyntheticDataset(tokenizer, initial_data)
            trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
            trainer.train()
            model_buffer = io.BytesIO()
            torch.save(unified_model.state_dict(), model_buffer)
            model_buffer.seek(0)
            bucket.blob(f"model_{model_name}").upload_from_file(
                model_buffer, content_type="application/octet-stream")

        # Throttle the polling loop instead of hammering GCS.
        time.sleep(10)
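
# train_text_model mirrors train_unified_model but persists its weights under
# model_text_model; both currently fine-tune the same global unified_model and
# poll the same training_queue.json.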
def train_text_model():
    global tokenizer, unified_model
    model_name = "text_model"
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=8,
        num_train_epochs=3,
    )
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=CACHE_DIR)
        tokenizer.pad_token = tokenizer.eos_token
    if unified_model is None:
        unified_model = UnifiedModel.load_model()
    while True:
        if bucket.blob("training_queue.json").exists():
            training_data_list = json.loads(bucket.blob("training_queue.json").download_as_bytes())
            if training_data_list:
                training_data = training_data_list.pop(0)
                bucket.blob("training_queue.json").upload_from_string(
                    json.dumps(training_data_list), content_type="application/json")

                tokenizer_data = training_data.get("tokenizers")
                if tokenizer_data:
                    queued_name = list(tokenizer_data.keys())[0]
                    new_tokens = [tok for tok in tokenizer_data[queued_name]
                                  if tok not in tokenizer.get_vocab()]
                    if new_tokens:
                        tokenizer.add_tokens(new_tokens)
                        unified_model.resize_token_embeddings(len(tokenizer))
                data = training_data.get("data", [])
                if data:
                    dataset = SyntheticDataset(tokenizer, data)
                    trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
                    trainer.train()
                    model_buffer = io.BytesIO()
                    torch.save(unified_model.state_dict(), model_buffer)
                    model_buffer.seek(0)
                    bucket.blob(f"model_{model_name}").upload_from_file(
                        model_buffer, content_type="application/octet-stream")
                    bucket.blob(f"tokenizer_{tokenizer_name}.json").upload_from_string(
                        json.dumps(tokenizer.get_vocab()), content_type="application/json")

        if bucket.blob("initial_data.json").exists():
            initial_data = json.loads(bucket.blob("initial_data.json").download_as_bytes())
            dataset = SyntheticDataset(tokenizer, initial_data)
            trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
            trainer.train()
            model_buffer = io.BytesIO()
            torch.save(unified_model.state_dict(), model_buffer)
            model_buffer.seek(0)
            bucket.blob(f"model_{model_name}").upload_from_file(
                model_buffer, content_type="application/octet-stream")

        time.sleep(10)
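
# NOTE: the image loop below keeps the original structure but is only a
# sketch. FluxPipeline.__call__ runs under no_grad and returns PIL images, so
# no gradients can reach the denoising transformer from its outputs; proper
# FLUX fine-tuning would need a flow-matching objective on latents, which is
# out of scope here.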
def train_image_model():
    global image_pipeline
    while True:
        if bucket.blob("image_training_queue.json").exists():
            image_training_data_list = json.loads(
                bucket.blob("image_training_queue.json").download_as_bytes())
            if image_training_data_list:
                image_training_data = image_training_data_list.pop(0)
                bucket.blob("image_training_queue.json").upload_from_string(
                    json.dumps(image_training_data_list), content_type="application/json")
                loss_fn = torch.nn.MSELoss()
                for epoch in range(3):
                    for i in tqdm(range(len(image_training_data)), desc=f"Epoch {epoch+1}"):
                        image_prompt = image_training_data[i]
                        image = image_pipeline(
                            image_prompt,
                            guidance_scale=0.0,
                            num_inference_steps=4,
                            max_sequence_length=256,
                            generator=torch.Generator("cpu").manual_seed(0)
                        ).images[0]
                        # The MSE against a zero image is logged as a metric
                        # only; see the note above about gradient flow.
                        image_tensor = torch.tensor(np.array(image), dtype=torch.float32).unsqueeze(0)
                        loss = loss_fn(image_tensor, torch.zeros_like(image_tensor))
                        print(f"Epoch {epoch+1}, Step {i+1}/{len(image_training_data)}, Loss: {loss.item()}")
        time.sleep(10)
def train_music_model():
    global musicgen_tokenizer, musicgen_model
    while True:
        if bucket.blob("music_training_queue.json").exists():
            music_training_data_list = json.loads(
                bucket.blob("music_training_queue.json").download_as_bytes())
            if music_training_data_list:
                music_training_data = music_training_data_list.pop(0)
                bucket.blob("music_training_queue.json").upload_from_string(
                    json.dumps(music_training_data_list), content_type="application/json")

                inputs = musicgen_tokenizer(music_training_data, return_tensors="pt", padding=True).to("cuda")
                musicgen_model.to("cuda")
                musicgen_model.train()
                optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
                for epoch in range(3):
                    # Stand-in target (see the note above); one full-batch
                    # step per epoch.
                    outputs = musicgen_model(**inputs, labels=inputs["input_ids"])
                    loss = outputs.loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print(f"Epoch {epoch+1}, Loss: {loss.item()}")
        time.sleep(10)


if __name__ == "__main__":
    import uvicorn

    # Start the background training workers first: uvicorn.run blocks until
    # the server exits, so anything placed after it would never execute.
    print("Starting automatic training of the unified model...")
    auto_learn_process = multiprocessing.Process(target=train_unified_model)
    auto_learn_process.start()

    print("Starting automatic training of the text model...")
    auto_learn_process_2 = multiprocessing.Process(target=train_text_model)
    auto_learn_process_2.start()

    print("Starting automatic training of the image model...")
    auto_learn_process_3 = multiprocessing.Process(target=train_image_model)
    auto_learn_process_3.start()

    print("Starting automatic training of the music model...")
    auto_learn_process_4 = multiprocessing.Process(target=train_music_model)
    auto_learn_process_4.start()

    uvicorn.run(app, host="0.0.0.0", port=7860)