import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import re

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# RUSpam checkpoints on the Hugging Face Hub
models = {
    "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
    "RUSpam/spamNS_v1": "RUSpam/spamNS_v1"
}

# Load each tokenizer/model pair once at startup and put every model
# on the selected device in eval mode.
tokenizers = {}
model_instances = {}
for name, path in models.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(path)
    model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path).to(device).eval()
# Normalise text for spamNS_v1: drop URLs, keep only Cyrillic letters,
# digits and spaces, then lowercase.
def clean_text(text):
    text = re.sub(r'http\S+', '', text)
    text = re.sub(r'[^А-Яа-я0-9 ]+', ' ', text)
    text = text.lower().strip()
    return text
def predict_spam_deberta(text):
    tokenizer = tokenizers["RUSpam/spam_deberta_v4"]
    model = model_instances["RUSpam/spam_deberta_v4"]
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # Move the encoded inputs to the same device as the model before the forward pass
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Class 1 is spam, class 0 is not spam
    predicted_class = torch.argmax(logits, dim=1).item()
    return "Спам" if predicted_class == 1 else "Не спам"
def predict_spam_spamns(text):
    tokenizer = tokenizers["RUSpam/spamNS_v1"]
    model = model_instances["RUSpam/spamNS_v1"]
    text = clean_text(text)
    encoding = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask).logits
    # Single-logit head: sigmoid gives the spam probability
    pred = torch.sigmoid(outputs).cpu().numpy()[0][0]
    return "Спам" if pred >= 0.5 else "Не спам"
def predict_spam(text, model_choice):
    if model_choice == "RUSpam/spam_deberta_v4":
        return predict_spam_deberta(text)
    elif model_choice == "RUSpam/spamNS_v1":
        return predict_spam_spamns(text)
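# Illustrative direct call for quick checks outside the UI (the example message
# is made up, not taken from the original Space):
#   predict_spam("Поздравляем! Вы выиграли приз, перейдите по ссылке", "RUSpam/spam_deberta_v4")
#   -> "Спам" or "Не спам", depending on the model's decision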
# Build the Gradio interface
iface = gr.Interface(
    fn=predict_spam,
    inputs=[
        gr.Textbox(lines=5, label="Введите текст"),
        gr.Radio(choices=list(models.keys()), label="Выберите модель", value="RUSpam/spam_deberta_v4")
    ],
    outputs=gr.Label(label="Результат"),
    title="Определение спама в русскоязычных текстах"
)

iface.launch()
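# Minimal local-run sketch (an assumption about the Space setup, not shown above:
# the dependencies below would normally live in the Space's requirements.txt):
#   pip install gradio transformers torch
#   python app.py   # assuming this file is saved as app.py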