Spaces:
Runtime error
Runtime error
File size: 2,624 Bytes
56670e1 5f5dede 56670e1 2223ef6 5f9bfe0 5f5dede 56670e1 81ceb92 56670e1 a95cf91 ddb6473 56670e1 81ceb92 5f5dede dacd217 5f5dede 56670e1 2223ef6 5f5dede 56670e1 46f08aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import re
# UI label -> Hugging Face Hub repository id (the two are identical here).
models = {
    "RUSpam/spam_deberta_v4": "RUSpam/spam_deberta_v4",
    "RUSpam/spamNS_v1": "RUSpam/spamNS_v1",
}

# Load a tokenizer and a sequence-classification model for every entry,
# both keyed by the UI label. Loading is interleaved per model.
tokenizers, model_instances = {}, {}
for name, path in models.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(path)
    model_instances[name] = AutoModelForSequenceClassification.from_pretrained(path)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Only spamNS_v1 is moved to `device` and switched to eval mode here;
# spam_deberta_v4 is left exactly as from_pretrained returned it.
model_instances["RUSpam/spamNS_v1"] = model_instances["RUSpam/spamNS_v1"].to(device).eval()
def clean_text(text):
text = re.sub(r'http\S+', '', text)
text = re.sub(r'[^А-Яа-я0-9 ]+', ' ', text)
text = text.lower().strip()
return text
def predict_spam_deberta(text):
    """Classify *text* with the ``RUSpam/spam_deberta_v4`` model.

    The raw text (no ``clean_text`` preprocessing) is tokenized, truncated
    to 256 tokens, and the argmax over the model's logits picks the class;
    class index 1 is treated as spam.

    Args:
        text: Raw user input.

    Returns:
        "Спам" if the predicted class index is 1, otherwise "Не спам".
    """
    tokenizer = tokenizers["RUSpam/spam_deberta_v4"]
    model = model_instances["RUSpam/spam_deberta_v4"]
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # Send every tokenizer output to the device the model actually lives on.
    # Module setup does not necessarily move this model to the global
    # `device`, so moving tensors there unconditionally could cause a
    # CPU/CUDA mismatch; passing all outputs keeps the exact model inputs.
    model_inputs = {key: tensor.to(model.device) for key, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**model_inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return "Спам" if predicted_class == 1 else "Не спам"
def predict_spam_spamns(text):
    """Classify *text* with the ``RUSpam/spamNS_v1`` model.

    The text is first normalized with ``clean_text``, then padded/truncated
    to exactly 128 tokens. A sigmoid is applied to the first logit of the
    first (only) example, and a probability of at least 0.5 means spam.
    NOTE(review): indexing ``[0][0]`` assumes the model emits a single
    logit per example — TODO confirm against the model config.

    Returns:
        "Спам" or "Не спам".
    """
    tokenizer = tokenizers["RUSpam/spamNS_v1"]
    model = model_instances["RUSpam/spamNS_v1"]
    batch = tokenizer(
        clean_text(text),
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
    spam_probability = torch.sigmoid(logits).cpu().numpy()[0][0]
    if spam_probability >= 0.5:
        return "Спам"
    return "Не спам"
def predict_spam(text, model_choice):
    """Route *text* to the classifier selected in the UI.

    Args:
        text: Raw user input from the Gradio textbox.
        model_choice: The selected model label; expected to be one of
            "RUSpam/spam_deberta_v4" or "RUSpam/spamNS_v1".

    Returns:
        The label string produced by the selected predictor.

    Raises:
        ValueError: If *model_choice* is not one of the supported labels
            (including ``None``), rather than silently returning ``None``.
    """
    if model_choice == "RUSpam/spam_deberta_v4":
        return predict_spam_deberta(text)
    if model_choice == "RUSpam/spamNS_v1":
        return predict_spam_spamns(text)
    raise ValueError(f"Unknown model choice: {model_choice!r}")
# Build the Gradio interface: a text box and a model selector in, one label out.
iface = gr.Interface(
    fn=predict_spam,
    title="Определение спама в русскоязычных текстах",
    inputs=[
        gr.Textbox(label="Введите текст", lines=5),
        gr.Radio(
            label="Выберите модель",
            choices=list(models),
            value="RUSpam/spam_deberta_v4",
        ),
    ],
    outputs=gr.Label(label="Результат"),
)
# Start the web server at import/run time.
iface.launch()
|