Uhhy committed on
Commit
3c88fa1
·
verified ·
1 Parent(s): 7f78671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +576 -0
app.py CHANGED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ import json
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForSequenceClassification,
8
+ AutoModelForCausalLM,
9
+ TrainingArguments,
10
+ Trainer,
11
+ AutoModelForTextToWaveform
12
+ )
13
+ from fastapi import FastAPI, HTTPException, Request
14
+ from fastapi.responses import HTMLResponse
15
+ import multiprocessing
16
+ import uuid
17
+ import numpy as np
18
+ from diffusers import FluxPipeline
19
+ from tqdm import tqdm
20
+ from google.cloud import storage
21
+ import io
22
+ import spaces
23
+
24
# NOTE(review): the object returned by spaces.GPU(...) is discarded here, so
# this call looks like a no-op — confirm whether it is needed.
spaces.GPU(duration=0)
load_dotenv()

app = FastAPI()

# Language used when a request does not specify one.
default_language = "es"

GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
if GCS_BUCKET_NAME is None:
    raise ValueError("La variable de entorno GCS_BUCKET_NAME no está definida.")

GCS_CREDENTIALS = os.getenv("GCS_CREDENTIALS")
if GCS_CREDENTIALS is None:
    raise ValueError("La variable de entorno GCS_CREDENTIALS no está definida.")
try:
    gcs_credentials_dict = json.loads(GCS_CREDENTIALS)
except json.JSONDecodeError as exc:
    # JSONDecodeError is a ValueError, so the exception type callers see is
    # unchanged; the message now names the offending variable.
    raise ValueError("La variable de entorno GCS_CREDENTIALS no contiene JSON válido.") from exc

# The file holds a service-account key: create it owner-only instead of with
# the process default permissions.
_credentials_fd = os.open("gcs_credentials.json", os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(_credentials_fd, "w") as f:
    json.dump(gcs_credentials_dict, f)
os.chmod("gcs_credentials.json", 0o600)  # also tighten a pre-existing file
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "gcs_credentials.json"

storage_client = storage.Client()
bucket = storage_client.bucket(GCS_BUCKET_NAME)

# Warm the download cache for every checkpoint the application uses. The
# loaded objects are intentionally discarded.
# NOTE(review): cache_dir is presumably interpreted as a local directory path,
# not a bucket URL — confirm this gs:// value behaves as intended.
_HF_CACHE_DIR = f"gs://{GCS_BUCKET_NAME}/cache"
AutoTokenizer.from_pretrained("gpt2", cache_dir=_HF_CACHE_DIR)
AutoModelForCausalLM.from_pretrained("gpt2", cache_dir=_HF_CACHE_DIR)
FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, cache_dir=_HF_CACHE_DIR)
AutoTokenizer.from_pretrained("facebook/musicgen-small", cache_dir=_HF_CACHE_DIR)
AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small", cache_dir=_HF_CACHE_DIR)
51
+
52
class ChatbotService:
    """Generate chat replies with a GPT-2 model whose weights and extra vocabulary are read from the bucket.

    ``model`` and ``tokenizer`` are ``None`` when the corresponding blob does
    not exist; ``get_response`` then returns a "not ready" message.
    """

    def __init__(self):
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model()
        self.tokenizer = self.load_tokenizer()

    def get_response(self, user_id, message, language=default_language):
        """Return the generated reply for *message*.

        ``user_id`` and ``language`` are accepted for interface compatibility
        but are not used by the generation itself.
        """
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
        input_text = f"Usuario: {message} Asistente:"
        # Run on whatever device the model lives on instead of assuming CUDA.
        device = next(self.model.parameters()).device
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(device)
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                # max_length would count the prompt tokens too, leaving no room
                # for a reply once the prompt reaches 100 tokens.
                max_new_tokens=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens rather than stripping the
        # prompt text out of the decoded string.
        generated_ids = output[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def load_model(self):
        """Return GPT-2 with the bucket's ``model_<model_name>`` weights, or ``None`` if that blob is absent."""
        blob = bucket.blob(f"model_{self.model_name}")
        if not blob.exists():
            return None
        model_buffer = io.BytesIO(blob.download_as_bytes())
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        # NOTE: torch.load unpickles its input; the bucket content must be trusted.
        model.load_state_dict(torch.load(model_buffer, map_location=device))
        model.to(device)
        return model

    def load_tokenizer(self):
        """Return the GPT-2 tokenizer extended with the bucket's stored vocabulary, or ``None`` if absent."""
        blob = bucket.blob(f"tokenizer_{self.tokenizer_name}.json")
        if not blob.exists():
            return None
        stored_vocab = json.loads(blob.download_as_bytes())
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        known_tokens = tokenizer.get_vocab()
        missing_tokens = [token for token in stored_vocab if token not in known_tokens]
        if missing_tokens:
            tokenizer.add_tokens(missing_tokens)
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer


chatbot_service = ChatbotService()
100
+
101
+
102
class UnifiedModel(AutoModelForSequenceClassification):
    """Loader for the 3-label GPT-2 sequence classifier persisted in the bucket."""

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def load_model():
        """Return the classifier, restoring ``model_unified_model`` from the bucket when it exists.

        When the blob is missing, the freshly initialised weights are uploaded
        so that later loads find a checkpoint.
        """
        model_name = "unified_model"
        blob = bucket.blob(f"model_{model_name}")
        model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
        # GPT-2 has no padding token; the classification head needs one to
        # handle batches larger than one sequence.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = model.config.eos_token_id

        if blob.exists():
            model_buffer = io.BytesIO(blob.download_as_bytes())
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # NOTE: torch.load unpickles its input; the bucket content must be trusted.
            state_dict = torch.load(model_buffer, map_location=device)
            # A checkpoint saved after tokens were added has a larger embedding
            # matrix; match it before loading to avoid a size mismatch.
            stored_embeddings = state_dict.get("transformer.wte.weight")
            if stored_embeddings is not None and stored_embeddings.shape[0] != model.config.vocab_size:
                model.resize_token_embeddings(stored_embeddings.shape[0])
            model.load_state_dict(state_dict)
        else:
            model_buffer = io.BytesIO()
            torch.save(model.state_dict(), model_buffer)
            model_buffer.seek(0)
            blob.upload_from_file(model_buffer, content_type="application/octet-stream")
        return model
125
+
126
+
127
class SyntheticDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes ``{"text": ..., "label": ...}`` records on access.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, padding=..., truncation=...,
            max_length=..., return_tensors="pt")`` that returns a mapping with
            ``input_ids`` and ``attention_mask`` tensors carrying a leading batch axis.
        data: Sequence of records, each with ``text`` and ``label`` keys.
        max_length: Length every example is padded or truncated to.
    """

    def __init__(self, tokenizer, data, max_length=128):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        encoded = self.tokenizer(
            record["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the batch axis; a bare squeeze() would also remove the
        # sequence axis when it has length 1.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": record["label"],
        }
142
+
143
+
144
# Messages received so far, keyed by the user_id sent with each request.
conversation_history = {}

# Shared classifier state; both stay None until startup_event() fills them in.
tokenizer_name = "unified_tokenizer"
tokenizer = None
unified_model = None

# Use one cache location for every checkpoint so the MusicGen files are read
# from the same cache as the other models instead of a second, default one.
_MODEL_CACHE_DIR = f"gs://{GCS_BUCKET_NAME}/cache"
image_pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    cache_dir=_MODEL_CACHE_DIR,
)
image_pipeline.enable_model_cpu_offload()
musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small", cache_dir=_MODEL_CACHE_DIR)
musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small", cache_dir=_MODEL_CACHE_DIR)
154
+
155
@app.on_event("startup")
async def startup_event():
    """Initialise the module-level ``tokenizer`` and ``unified_model`` when the server starts.

    The GPT-2 tokenizer is extended with any tokens stored in the bucket's
    ``tokenizer_<tokenizer_name>.json``; the classifier is loaded through
    ``UnifiedModel.load_model`` and moved to CUDA when available.
    """
    global tokenizer, unified_model
    tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")

    vocab_blob = bucket.blob(f"tokenizer_{tokenizer_name}.json")
    if vocab_blob.exists():
        stored_vocab = json.loads(vocab_blob.download_as_bytes())
        known_tokens = tokenizer.get_vocab()
        missing_tokens = [token for token in stored_vocab if token not in known_tokens]
        if missing_tokens:
            tokenizer.add_tokens(missing_tokens)
    tokenizer.pad_token = tokenizer.eos_token

    unified_model = UnifiedModel.load_model()
    # Added tokens get ids beyond the pretrained embedding matrix; grow it so
    # those ids do not index out of range.
    if len(tokenizer) > unified_model.get_input_embeddings().num_embeddings:
        unified_model.resize_token_embeddings(len(tokenizer))
    # Fall back to CPU instead of failing at startup on a machine without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unified_model.to(device)
175
+
176
+
177
def _enqueue_training_samples(samples):
    """Append one entry holding *samples* to ``training_queue.json`` in the bucket.

    NOTE(review): this is a read-modify-write of a shared blob with no locking;
    concurrent writers can overwrite each other's entries.
    """
    queue_blob = bucket.blob("training_queue.json")
    queued = json.loads(queue_blob.download_as_bytes()) if queue_blob.exists() else []
    queued.append({
        "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
        "data": samples,
    })
    queue_blob.upload_from_string(json.dumps(queued).encode("utf-8"), content_type="application/json")


@app.post("/process")
async def process(request: Request):
    """Handle a training submission (``train``) or a chat message (``message``).

    * ``{"train": true, "user_data": [...]}`` queues the samples (or three
      built-in defaults when none are given) for asynchronous training.
    * ``{"message": ..., "user_id": ..., "language": ...}`` classifies the last
      three messages of that user, generates a reply, queues the text with the
      predicted class as its label, and returns ``{"answer": reply}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or has neither key.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")

    if data.get("train"):
        user_data = data.get("user_data") or [
            {"text": "Hola", "label": 1},
            {"text": "Necesito ayuda", "label": 2},
            {"text": "No entiendo", "label": 0},
        ]
        _enqueue_training_samples(user_data)
        return {"message": "Training data received. Model will be updated asynchronously."}

    if data.get("message"):
        user_id = data.get("user_id")
        text = data["message"]
        language = data.get("language", default_language)

        history = conversation_history.setdefault(user_id, [])
        history.append(text)
        # Only the last three messages are ever used, so drop older ones
        # rather than letting the list grow for the process lifetime.
        del history[:-3]
        contextualized_text = " ".join(history)

        # Keep the inputs on the model's device and within its maximum length.
        model_device = next(unified_model.parameters()).device
        tokenized_input = tokenizer(contextualized_text, return_tensors="pt", truncation=True).to(model_device)
        with torch.no_grad():
            logits = unified_model(**tokenized_input).logits
        predicted_class = torch.argmax(logits, dim=-1).item()

        response = chatbot_service.get_response(user_id, contextualized_text, language)
        _enqueue_training_samples([{"text": contextualized_text, "label": predicted_class}])
        return {"answer": response}

    raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
235
+
236
+
237
# Chat page served by get_home(); "__USER_ID__" is replaced per request. A
# plain template avoids doubling every CSS/JS brace as an f-string requires.
_HOME_PAGE_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Chatbot</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f4f4f9;
            margin: 0;
            padding: 0;
            display: flex;
            align-items: center;
            justify-content: center;
            min-height: 100vh;
        }
        .container {
            background-color: #fff;
            border-radius: 10px;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
            overflow: hidden;
            width: 400px;
            max-width: 90%;
        }
        h1 {
            color: #333;
            text-align: center;
            padding: 20px;
            margin: 0;
            background-color: #f8f9fa;
            border-bottom: 1px solid #eee;
        }
        #chatbox {
            height: 300px;
            overflow-y: auto;
            padding: 10px;
            border-bottom: 1px solid #eee;
        }
        .message {
            margin-bottom: 10px;
            padding: 10px;
            border-radius: 5px;
        }
        .message.user {
            background-color: #e1f5fe;
            text-align: right;
        }
        .message.bot {
            background-color: #f1f1f1;
            text-align: left;
        }
        #input {
            display: flex;
            padding: 10px;
        }
        #input textarea {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            margin-right: 10px;
        }
        #input button {
            padding: 10px 20px;
            border: none;
            border-radius: 4px;
            background-color: #007bff;
            color: #fff;
            cursor: pointer;
        }
        #input button:hover {
            background-color: #0056b3;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Chatbot</h1>
        <div id="chatbox"></div>
        <div id="input">
            <textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
            <button id="send">Enviar</button>
        </div>
    </div>
    <script>
        const chatbox = document.getElementById('chatbox');
        const messageInput = document.getElementById('message');
        const sendButton = document.getElementById('send');

        function appendMessage(text, sender) {
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message', sender);
            messageDiv.textContent = text;
            chatbox.appendChild(messageDiv);
            chatbox.scrollTop = chatbox.scrollHeight;
        }

        async function sendMessage() {
            const message = messageInput.value;
            if (!message.trim()) return;

            appendMessage(message, 'user');
            messageInput.value = '';

            try {
                const response = await fetch('/process', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        message: message,
                        user_id: '__USER_ID__'
                    })
                });
                if (!response.ok) {
                    throw new Error('HTTP ' + response.status);
                }
                const data = await response.json();
                appendMessage(data.answer, 'bot');
            } catch (err) {
                appendMessage('No se pudo obtener una respuesta. Inténtalo de nuevo.', 'bot');
            }
        }

        sendButton.addEventListener('click', sendMessage);
        messageInput.addEventListener('keypress', (e) => {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendMessage();
            }
        });
    </script>
</body>
</html>
"""


@app.get("/")
async def get_home():
    """Serve the chat page with a freshly generated user id embedded in its script.

    The page posts each message to ``/process`` with that id. A failed request
    now shows an error bubble instead of rendering "undefined".
    """
    user_id = str(uuid.uuid4())
    # uuid4 strings contain only hex digits and hyphens, so plain substitution
    # into the quoted JavaScript literal is safe.
    html_code = _HOME_PAGE_TEMPLATE.replace("__USER_ID__", user_id)
    return HTMLResponse(content=html_code)
371
+
372
@spaces.GPU
def my_inference_function(input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """Placeholder GPU entry point: prints a marker and returns ``None``.

    None of the arguments are used yet.
    """
    # TODO: implement the actual inference logic.
    marker = "xd"
    print(marker)
377
+
378
def _upload_model_weights(model, model_name):
    """Serialize *model*'s state dict in memory and upload it as ``model_<model_name>``."""
    weights_buffer = io.BytesIO()
    torch.save(model.state_dict(), weights_buffer)
    weights_buffer.seek(0)
    bucket.blob(f"model_{model_name}").upload_from_file(weights_buffer, content_type="application/octet-stream")


def _fit_and_upload(rows, training_args, model_name):
    """Train the shared classifier on *rows* and upload the resulting weights."""
    dataset = SyntheticDataset(tokenizer, rows)
    Trainer(model=unified_model, args=training_args, train_dataset=dataset).train()
    _upload_model_weights(unified_model, model_name)


def _run_queue_training(model_name, poll_seconds=5.0):
    """Poll ``training_queue.json`` forever, training on each entry and saving as ``model_<model_name>``.

    Never returns. Each queue entry is removed before it is trained on.
    NOTE(review): the queue blob is read and rewritten without locking, so
    several workers or the request handler can race on it.
    """
    global tokenizer, unified_model
    import time  # local import: only the polling loop needs it

    # A worker process may start without the FastAPI startup hook having run,
    # in which case the shared globals are still None.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
        tokenizer.pad_token = tokenizer.eos_token
    if unified_model is None:
        unified_model = UnifiedModel.load_model()
    # Batched training of the GPT-2 classifier needs a padding token id.
    if unified_model.config.pad_token_id is None:
        unified_model.config.pad_token_id = tokenizer.pad_token_id

    training_args = TrainingArguments(
        output_dir=f"gs://{GCS_BUCKET_NAME}/results",
        per_device_train_batch_size=8,
        num_train_epochs=3,
    )

    # NOTE(review): the flattened source does not show whether this step sat
    # inside or after the polling loop. Inside, it would retrain on the same
    # data every iteration; after, it would be unreachable. It is therefore
    # run once up front — confirm this is the intent.
    initial_blob = bucket.blob("initial_data.json")
    if initial_blob.exists():
        initial_data = json.loads(initial_blob.download_as_bytes())
        if initial_data:
            _fit_and_upload(initial_data, training_args, model_name)

    while True:
        queue_blob = bucket.blob("training_queue.json")
        pending = json.loads(queue_blob.download_as_bytes()) if queue_blob.exists() else []
        if not pending:
            # Nothing to do: wait instead of hammering the bucket in a tight loop.
            time.sleep(poll_seconds)
            continue

        job = pending.pop(0)
        queue_blob.upload_from_string(json.dumps(pending).encode("utf-8"), content_type="application/json")

        # Default to the module-level name so the vocabulary upload below never
        # hits an unbound variable when the entry carries no "tokenizers" key.
        vocab_name = tokenizer_name
        tokenizer_data = job.get("tokenizers")
        if tokenizer_data:
            vocab_name = next(iter(tokenizer_data))
            known_tokens = tokenizer.get_vocab()
            missing_tokens = [token for token in tokenizer_data[vocab_name] if token not in known_tokens]
            if missing_tokens:
                tokenizer.add_tokens(missing_tokens)
                # New token ids must have embedding rows before training.
                unified_model.resize_token_embeddings(len(tokenizer))

        rows = job.get("data", [])
        if rows:
            _fit_and_upload(rows, training_args, model_name)
            vocab_bytes = json.dumps(tokenizer.get_vocab()).encode("utf-8")
            bucket.blob(f"tokenizer_{vocab_name}.json").upload_from_string(
                vocab_bytes, content_type="application/json"
            )


def train_unified_model():
    """Run the queue-driven training loop, saving weights as ``model_unified_model``. Never returns."""
    _run_queue_training("unified_model")


def train_text_model():
    """Run the queue-driven training loop, saving weights as ``model_text_model``. Never returns.

    Trains the same shared classifier instance as ``train_unified_model``.
    """
    _run_queue_training("text_model")
492
+
493
def train_image_model():
    """Poll ``image_training_queue.json`` forever and run the image training loop on each entry.

    Never returns. Each queue entry is removed before it is processed and is
    iterated as a sequence of prompts.
    """
    import time  # local import: only the polling loop needs it

    poll_seconds = 5.0
    while True:
        queue_blob = bucket.blob("image_training_queue.json")
        pending = json.loads(queue_blob.download_as_bytes()) if queue_blob.exists() else []
        if not pending:
            # Wait between polls instead of spinning on the bucket.
            time.sleep(poll_seconds)
            continue

        image_training_data = pending.pop(0)
        queue_blob.upload_from_string(json.dumps(pending).encode("utf-8"), content_type="application/json")

        # NOTE(review): confirm that the pipeline object exposes a `.model`
        # attribute, and that feeding a generated image back into it with an
        # all-zero MSE target is the intended objective.
        image_pipeline.model.to("cuda")
        image_pipeline.model.train()
        optimizer = torch.optim.Adam(image_pipeline.model.parameters(), lr=1e-5)
        loss_fn = torch.nn.MSELoss()
        total_steps = len(image_training_data)
        for epoch in range(3):
            progress = tqdm(image_training_data, desc=f"Epoch {epoch+1}")
            for step, image_prompt in enumerate(progress, start=1):
                image = image_pipeline(
                    image_prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    # Re-seeded on every call, so each prompt is generated
                    # with the same random state.
                    generator=torch.Generator("cuda").manual_seed(0),
                ).images[0]
                image_tensor = torch.tensor(np.array(image)).unsqueeze(0).to("cuda")
                target_tensor = torch.zeros_like(image_tensor)
                outputs = image_pipeline.model(image_tensor)
                loss = loss_fn(outputs, target_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(f"Epoch {epoch+1}, Step {step}/{total_steps}, Loss: {loss.item()}")
528
+
529
def train_music_model():
    """Poll ``music_training_queue.json`` forever and run the MusicGen training loop on each entry.

    Never returns. Each queue entry is removed before it is processed and is
    passed to the tokenizer as a whole.
    """
    import time  # local import: only the polling loop needs it

    poll_seconds = 5.0
    while True:
        queue_blob = bucket.blob("music_training_queue.json")
        pending = json.loads(queue_blob.download_as_bytes()) if queue_blob.exists() else []
        if not pending:
            # Wait between polls instead of spinning on the bucket.
            time.sleep(poll_seconds)
            continue

        music_training_data = pending.pop(0)
        queue_blob.upload_from_string(json.dumps(pending).encode("utf-8"), content_type="application/json")

        inputs = musicgen_tokenizer(music_training_data, return_tensors="pt", padding=True).to("cuda")
        musicgen_model.to("cuda")
        musicgen_model.train()
        optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
        loss_fn = torch.nn.CrossEntropyLoss()
        total_steps = len(inputs["input_ids"])
        for epoch in range(3):
            # Every step feeds the whole tokenized batch; the step count just
            # equals the number of sequences in it.
            for step in tqdm(range(1, total_steps + 1), desc=f"Epoch {epoch+1}"):
                outputs = musicgen_model(**inputs)
                # NOTE(review): nothing visible here adds a "labels" entry to
                # the tokenizer output — confirm where the targets come from.
                loss = loss_fn(outputs.logits, inputs['labels'])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(f"Epoch {epoch+1}, Step {step}/{total_steps}, Loss: {loss.item()}")
556
+
557
+
558
def _run_text_worker(train_fn):
    """Process target: make sure the shared tokenizer and classifier exist, then run *train_fn*.

    Workers are started before the web server, so the FastAPI startup hook
    that normally fills these globals has not run in them.
    """
    global tokenizer, unified_model
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=f"gs://{GCS_BUCKET_NAME}/cache")
        tokenizer.pad_token = tokenizer.eos_token
    if unified_model is None:
        unified_model = UnifiedModel.load_model()
    train_fn()


if __name__ == "__main__":
    import uvicorn

    # uvicorn.run() blocks until the server shuts down, so the background
    # trainers must be started first or they only begin after it exits.
    worker_specs = [
        ("Iniciando entrenamiento automático del modelo unificado...", _run_text_worker, (train_unified_model,)),
        ("Iniciando entrenamiento automático del modelo de texto...", _run_text_worker, (train_text_model,)),
        ("Iniciando entrenamiento automático del modelo de imagen...", train_image_model, ()),
        ("Iniciando entrenamiento automático del modelo de música...", train_music_model, ()),
    ]
    auto_learn_processes = []
    for announcement, target, target_args in worker_specs:
        print(announcement)
        worker = multiprocessing.Process(target=target, args=target_args)
        worker.start()
        auto_learn_processes.append(worker)

    try:
        uvicorn.run(app, host="0.0.0.0", port=7860)
    finally:
        # The trainers loop forever; stop them when the server exits so the
        # program can terminate.
        for worker in auto_learn_processes:
            worker.terminate()
        for worker in auto_learn_processes:
            worker.join()