SDC-multi-classifier / local_embedder.py
DocUA's picture
добавлення функціонала для підключення моделей для локального ембедінга
aaec566
import numpy as np
import torch
from typing import List, Union, Dict
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
import json
class LocalEmbedder:
def __init__(self, model_name: str, device: str = None, batch_size: int = 32):
"""
Ініціалізація локальної моделі для ембедінгів
Args:
model_name: назва або шлях до моделі (з HuggingFace або локальна)
device: пристрій для обчислень ('cuda', 'cpu' або None - автовибір)
batch_size: розмір батчу для інференсу
"""
self.model_name = model_name
self.batch_size = batch_size
# Визначення пристрою
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
# Завантаження моделі та токенізатора
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
# Максимальна довжина послідовності
self.max_length = self.tokenizer.model_max_length
if self.max_length > 512:
self.max_length = 512
def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
"""
L2-нормалізація ембедінгів
Args:
embeddings: матриця ембедінгів
Returns:
np.ndarray: нормалізована матриця ембедінгів
"""
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
return embeddings / norms
def get_embeddings(self, texts: Union[str, List[str]]) -> np.ndarray:
"""
Отримання ембедінгів для тексту або списку текстів
Args:
texts: текст або список текстів
Returns:
np.ndarray: матриця нормалізованих ембедінгів
"""
if isinstance(texts, str):
texts = [texts]
all_embeddings = []
with torch.no_grad():
for i in range(0, len(texts), self.batch_size):
batch_texts = texts[i:i + self.batch_size]
# Токенізація
encoded = self.tokenizer.batch_encode_plus(
batch_texts,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors='pt'
)
# Переміщуємо тензори на потрібний пристрій
input_ids = encoded['input_ids'].to(self.device)
attention_mask = encoded['attention_mask'].to(self.device)
# Отримуємо ембедінги
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask
)
# Використовуємо [CLS] токен як ембедінг
embeddings = outputs.last_hidden_state[:, 0, :]
all_embeddings.append(embeddings.cpu().numpy())
# Об'єднуємо всі батчі
embeddings = np.vstack(all_embeddings)
# Нормалізуємо ембедінги
normalized_embeddings = self._normalize_embeddings(embeddings)
return normalized_embeddings
def get_model_info(self) -> Dict[str, any]:
"""
Отримання інформації про модель
Returns:
Dict: інформація про модель
"""
return {
'model_name': self.model_name,
'device': self.device,
'embedding_size': self.model.config.hidden_size,
'max_length': self.max_length,
'batch_size': self.batch_size
}