Irshadcse2k16's picture
Upload 11 files
2f9781d verified
raw
history blame
1.44 kB
from functools import lru_cache

import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
@lru_cache(maxsize=1)
def _load_classifier(model_path):
    """Load the LoRA-adapted RoBERTa classifier and its tokenizer once per path.

    The result is cached so repeated predictions reuse the in-memory model
    instead of re-reading the base weights and adapter on every request.
    Only the most recent path is kept to avoid holding several models in RAM.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is on CPU and in eval mode.
    """
    config = PeftConfig.from_pretrained(model_path)
    base_model = RobertaForSequenceClassification.from_pretrained(
        config.base_model_name_or_path, num_labels=2
    )
    model = PeftModel.from_pretrained(base_model, model_path)
    tokenizer = RobertaTokenizerFast.from_pretrained(model_path)
    model.to("cpu")
    model.eval()
    return model, tokenizer


def predict_phishing_url(url_to_predict: str, model_path: str = "./roberta_classifier") -> str:
    """Classify a URL as phishing or legitimate.

    Args:
        url_to_predict: The URL text to classify.
        model_path: Directory holding the PEFT adapter config/weights and the
            tokenizer files.

    Returns:
        A display string of the form ``"<label> (Confidence: 0.97)"`` where
        the label is ``"Phishing"`` (class 1) or ``"Legitimate"`` (class 0)
        and the confidence is the softmax probability of that predicted
        label. For empty or whitespace-only input a short prompt is returned
        instead and the model is not invoked.
    """
    # Nothing to classify: avoid loading/running the model on blank input.
    if not url_to_predict or not url_to_predict.strip():
        return "Please enter a URL."

    model, tokenizer = _load_classifier(model_path)
    inputs = tokenizer(
        url_to_predict,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to("cpu")

    with torch.no_grad():
        outputs = model(**inputs)

    # Single input, so row 0 holds the two class probabilities.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    predicted_class_id = int(torch.argmax(probabilities).item())
    # Report the probability of the label actually shown, not always P(phishing),
    # so "Legitimate" verdicts do not display a near-zero "confidence".
    confidence = probabilities[predicted_class_id].item()

    result = "Phishing" if predicted_class_id == 1 else "Legitimate"
    return f"{result} (Confidence: {confidence:.2f})"
# Gradio UI: one single-line textbox for the URL, plain-text verdict as output.
_url_box = gr.Textbox(lines=1, placeholder="Enter URL...")
iface = gr.Interface(
    fn=predict_phishing_url,
    inputs=_url_box,
    outputs="text",
    title="Phishing URL Detector",
    description="Enter a URL to classify it as Phishing or Legitimate using RoBERTa-LoRA model.",
)
# Start the web server for the interface.
iface.launch()