"""Build and query CLIP embedding indexes.

Captions (and, optionally, images) found under ``DATA_ROOT`` are encoded with
CLIP ViT-B/32. The vectors are stored in FAISS inner-product indexes and the
per-vector metadata is stored in SQLite, keyed by the vector's row number in
the corresponding FAISS index.

Note: importing this module opens the SQLite database and creates the tables.
"""
import json
import sqlite3
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from PIL import Image
import clip
import faiss
import numpy as np
import glob

# Storage locations
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
IMAGE_FAISS_INDEX_PATH = 'app/data/image_faiss_index.index'
TEXT_FAISS_INDEX_PATH = 'app/data/text_faiss_index.index'

# Input data locations
DATA_ROOT = '/Users/artteiv/Desktop/Graduated/chore-graduated/Data'
MAIN_DATA_PATH = os.path.join(DATA_ROOT, 'main_data')
CAPTIONS_PATH = os.path.join(DATA_ROOT, 'captions')

# Number of captions pushed through the text encoder per forward pass.
TEXT_BATCH_SIZE = 256

# SQLite connection (module-level; closed by the __main__ block below)
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
cursor = conn.cursor()

# Create the metadata tables for image and text embeddings
cursor.execute('''
    CREATE TABLE IF NOT EXISTS image_embeddings (
        e_index INTEGER PRIMARY KEY,
        image_path TEXT NOT NULL,
        caption TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')

cursor.execute('''
    CREATE TABLE IF NOT EXISTS text_embeddings (
        e_index INTEGER PRIMARY KEY,
        text TEXT NOT NULL,
        category TEXT NOT NULL,
        subcategory TEXT NOT NULL
    )
''')


def insert_image_embedding(e_index, image_path, caption, category, subcategory):
    """Store the metadata row for one image embedding in SQLite.

    ``e_index`` is the primary key. An existing row with the same key is
    replaced, so rebuilding the index (which restarts numbering at 0) does
    not fail with a UNIQUE-constraint error.
    """
    cursor.execute('''
        INSERT OR REPLACE INTO image_embeddings (e_index, image_path, caption, category, subcategory)
        VALUES (?, ?, ?, ?, ?)
    ''', (e_index, image_path, caption, category, subcategory))
    conn.commit()
    print(f"Đã thêm embedding ảnh: {image_path}")


def insert_text_embedding(e_index, text, category, subcategory):
    """Store the metadata row for one text embedding in SQLite.

    ``e_index`` is the primary key. An existing row with the same key is
    replaced, so rebuilding the index (which restarts numbering at 0) does
    not fail with a UNIQUE-constraint error.
    """
    cursor.execute('''
        INSERT OR REPLACE INTO text_embeddings (e_index, text, category, subcategory)
        VALUES (?, ?, ?, ?)
    ''', (e_index, text, category, subcategory))
    conn.commit()
    print(f"Đã thêm embedding văn bản: {text[:50]}...")


def save_faiss_index(index, index_file):
    """Write a FAISS index to ``index_file``."""
    faiss.write_index(index, index_file)
    print(f"Đã lưu FAISS index vào {index_file}")


def load_faiss_index(index_file):
    """Read a FAISS index from ``index_file``.

    Returns the index, or ``None`` when the file does not exist.
    """
    if not os.path.exists(index_file):
        return None
    index = faiss.read_index(index_file)
    print(f"Đã nạp FAISS index từ {index_file}")
    return index


def _list_subdirectories(path):
    """Return the names of the directories directly inside ``path``."""
    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]


def _collect_captioned_images():
    """Find every image under MAIN_DATA_PATH that has a matching caption file.

    The layout walked is ``MAIN_DATA_PATH/<category>/<subcategory>/<name>.<ext>``
    with the caption at ``CAPTIONS_PATH/<category>/<subcategory>/<name>.txt``.
    Files without a caption are skipped; captions that cannot be read are
    reported and skipped.

    Returns a list of ``(image_path, caption, category, subcategory)`` tuples.
    """
    records = []
    for category in _list_subdirectories(MAIN_DATA_PATH):
        category_path = os.path.join(MAIN_DATA_PATH, category)
        for subcategory in _list_subdirectories(category_path):
            image_dir = os.path.join(category_path, subcategory)
            caption_dir = os.path.join(CAPTIONS_PATH, category, subcategory)
            for img_path in glob.glob(os.path.join(image_dir, '*.*')):
                # The caption shares the image's file name, minus the extension.
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                caption_file = os.path.join(caption_dir, f"{base_name}.txt")
                if not os.path.exists(caption_file):
                    continue
                try:
                    with open(caption_file, 'r', encoding='utf-8') as f:
                        caption = f.read().strip()
                except (OSError, UnicodeDecodeError) as e:
                    print(f"Error processing {img_path}: {e}")
                    continue
                records.append((img_path, caption, category, subcategory))
    return records


def _encode_image(model, preprocess, device, image_path):
    """Return the L2-normalised CLIP embedding of one image as a (1, d) float32 array."""
    # Preprocess inside the ``with`` so the file handle is released promptly.
    with Image.open(image_path) as img:
        image = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(image)
    # Cast explicitly: FAISS's Python API works on float32 arrays, while the
    # CLIP model can produce half-precision features when running on a GPU.
    features = np.ascontiguousarray(features.cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(features)
    return features


def _build_image_index(model, preprocess, device, records):
    """Embed every image in ``records``, save the FAISS index, return it (or None)."""
    print("Computing image embeddings...")
    vectors = []
    for img_path, caption, category, subcategory in records:
        try:
            features = _encode_image(model, preprocess, device, img_path)
            # Key the row by its position in the FAISS index, so an image that
            # fails does not shift later rows out of alignment.
            insert_image_embedding(len(vectors), img_path, caption, category, subcategory)
        except Exception as e:  # best effort: report and move on to the next image
            print(f"Error processing image {img_path}: {e}")
            continue
        vectors.append(features[0])

    if not vectors:
        return None
    matrix = np.array(vectors)
    image_index = faiss.IndexFlatIP(matrix.shape[1])
    image_index.add(matrix)
    save_faiss_index(image_index, IMAGE_FAISS_INDEX_PATH)
    return image_index


def _build_text_index(model, device, records):
    """Embed every caption in ``records``, save the FAISS index, return it."""
    texts = [caption for _, caption, _, _ in records]
    print("Computing text embeddings...")
    text_tokens = clip.tokenize(texts, truncate=True)
    print("Kích thước của text_tokens:", text_tokens.shape)

    # Encode in batches so memory use does not grow with the dataset size.
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), TEXT_BATCH_SIZE):
            batch = text_tokens[start:start + TEXT_BATCH_SIZE].to(device)
            chunks.append(model.encode_text(batch).cpu().numpy())
    # Cast explicitly: FAISS's Python API works on float32 arrays.
    text_features = np.ascontiguousarray(np.concatenate(chunks), dtype=np.float32)
    faiss.normalize_L2(text_features)

    text_index = faiss.IndexFlatIP(text_features.shape[1])
    text_index.add(text_features)

    # Row ``idx`` in SQLite describes vector ``idx`` in the FAISS index.
    for idx, (_, caption, category, subcategory) in enumerate(records):
        insert_text_embedding(idx, caption, category, subcategory)

    save_faiss_index(text_index, TEXT_FAISS_INDEX_PATH)
    return text_index


def compute_embeddings(*, include_images=False):
    """Compute CLIP embeddings for the dataset and persist them.

    Args:
        include_images: When true, also embed the images and write
            IMAGE_FAISS_INDEX_PATH. Defaults to false, in which case only the
            caption (text) index is built.

    Returns:
        ``(image_index, text_index)``. Either element is ``None`` when that
        index was not built (no captioned images found, image embedding not
        requested, or every image failed to encode).
    """
    print("Loading CLIP model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    print("Model loaded")

    print("Processing data from directories...")
    records = _collect_captioned_images()

    image_index = None
    text_index = None
    if records:
        if include_images:
            image_index = _build_image_index(model, preprocess, device, records)
        text_index = _build_text_index(model, device, records)

    print("Processing completed")
    return image_index, text_index


def predict_image(image_path, k=10):
    """Search the saved image index for the nearest neighbours of an image.

    Args:
        image_path: Path of the query image.
        k: Number of neighbours to return.

    Returns:
        ``(distances, indices)`` as returned by the FAISS index's ``search``.

    Raises:
        FileNotFoundError: If IMAGE_FAISS_INDEX_PATH does not exist.
    """
    # Check for the index first so a missing file fails before the model loads.
    index = load_faiss_index(IMAGE_FAISS_INDEX_PATH)
    if index is None:
        raise FileNotFoundError(
            f"FAISS image index not found: {IMAGE_FAISS_INDEX_PATH}"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    image_features = _encode_image(model, preprocess, device, image_path)

    distances, indices = index.search(image_features, k=k)
    return distances, indices


if __name__ == '__main__':
    try:
        image_index, text_index = compute_embeddings()
        if image_index is not None:
            print(f"Image index ready with {image_index.ntotal} embeddings")
        if text_index is not None:
            print(f"Text index ready with {text_index.ntotal} embeddings")
    finally:
        conn.close()
        print("SQLite connection closed")