# test_preloaded_model.py
import gradio as gr
import pickle
import os

# --- Dependencies needed for the Class Definition ---
from transformers import pipeline
import torch

# --- Define the Class to Hold Both Pipelines ---
# IMPORTANT: This exact class definition MUST be present here,
# identical to the one in save_combined_model.py, for unpickling to work.
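# (Background: pickle stores an instance as a reference to its class -- module name plus
#  class name -- together with the instance's attribute dict. The class body itself is not
#  saved in the .pkl, which is why an identical definition must be available here at load time.)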
class CombinedAnalyzer:
    """
    A class to encapsulate sentiment analysis and AI text detection pipelines.
    NOTE: This definition must match the one used when saving the .pkl file.
    """
    def __init__(self, sentiment_model_name="distilbert-base-uncased-finetuned-sst-2-english",
                 detector_model_name="Hello-SimpleAI/chatgpt-detector-roberta"):
        print("Initializing CombinedAnalyzer structure...")
        self.device = 0 if torch.cuda.is_available() else -1
        self.sentiment_model_name = sentiment_model_name
        self.detector_model_name = detector_model_name
        self.sentiment_pipeline = None
        self.detector_pipeline = None
        print(f"Class structure defined. Expecting pipelines for models: {sentiment_model_name}, {detector_model_name}")

    def analyze(self, text):
        """
        Analyzes the input text for both sentiment and authenticity.
        """
        if not isinstance(text, str) or not text.strip():
            return "Error: Input text cannot be empty."

        results = []

        # 1. Sentiment Analysis
        if self.sentiment_pipeline and callable(self.sentiment_pipeline):
            try:
                sentiment_result = self.sentiment_pipeline(text)[0]
                sentiment_label = sentiment_result['label']
                sentiment_score = round(sentiment_result['score'] * 100, 2)
                results.append(f"Sentiment: {sentiment_label} (Confidence: {sentiment_score}%)")
            except Exception as e:
                results.append(f"Sentiment Analysis Error in loaded model: {e}")
        else:
            results.append("Sentiment Analysis: Model not available or not callable in loaded object.")

        # 2. AI Text Detection (Authenticity)
        if self.detector_pipeline and callable(self.detector_pipeline):
            try:
                detector_result = self.detector_pipeline(text)[0]
                auth_label_raw = detector_result['label']
                auth_score = round(detector_result['score'] * 100, 2)
                if auth_label_raw.lower() in ['chatgpt', 'ai', 'generated']:
                    auth_label_display = "Likely AI-Generated"
                elif auth_label_raw.lower() in ['human', 'real']:
                    auth_label_display = "Likely Human-Written"
                else:
                    auth_label_display = f"Label: {auth_label_raw}"
                results.append(f"Authenticity: {auth_label_display} (Confidence: {auth_score}%)")
            except Exception as e:
                results.append(f"AI Text Detection Error in loaded model: {e}")
        else:
            results.append("Authenticity: AI Text Detector model not available or not callable in loaded object.")

        return "\n".join(results)
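# For reference, a minimal sketch of what save_combined_model.py might do. This is
# illustrative only: the actual script may differ, but per the note above it defines the
# identical CombinedAnalyzer class and then builds and pickles the object roughly like so:
#
#   analyzer = CombinedAnalyzer()
#   analyzer.sentiment_pipeline = pipeline("sentiment-analysis",
#                                          model=analyzer.sentiment_model_name,
#                                          device=analyzer.device)
#   analyzer.detector_pipeline = pipeline("text-classification",
#                                         model=analyzer.detector_model_name,
#                                         device=analyzer.device)
#   os.makedirs("saved_model", exist_ok=True)
#   with open(os.path.join("saved_model", "combined_analyzer.pkl"), "wb") as f:
#       pickle.dump(analyzer, f)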
# --- Load the Model Automatically on Startup ---
analyzer = None
pickle_filename = "combined_analyzer.pkl"
model_dir = "saved_model"
pickle_filepath = os.path.join(model_dir, pickle_filename)
model_load_error = None  # Store potential loading error message

print(f"Attempting to load pre-saved model from: {pickle_filepath}")
try:
    print("\n--- SECURITY WARNING ---")
    print(f"Loading '{pickle_filepath}'. Unpickling data from untrusted sources is a security risk.")
    print("Ensure this .pkl file was created by you or a trusted source.\n")

    if not os.path.exists(pickle_filepath):
        raise FileNotFoundError(f"Model file not found at {pickle_filepath}")

    with open(pickle_filepath, 'rb') as f:
        analyzer = pickle.load(f)

    if not hasattr(analyzer, 'analyze') or not callable(analyzer.analyze):
        raise TypeError("Loaded object is not a valid analyzer (missing 'analyze' method).")
    else:
        print("Model loaded successfully.")
        sentiment_name = getattr(analyzer, 'sentiment_model_name', 'Unknown')
        detector_name = getattr(analyzer, 'detector_model_name', 'Unknown')
        print(f" -> Sentiment Model: {sentiment_name}")
        print(f" -> Detector Model: {detector_name}")

except FileNotFoundError as e:
    model_load_error = f"ERROR loading model: {e}"
    print(model_load_error)
    print("Please ensure 'save_combined_model.py' was run successfully and")
    print(f"the file '{pickle_filename}' exists in the '{model_dir}' directory.")
except (pickle.UnpicklingError, TypeError, AttributeError) as e:
    model_load_error = f"ERROR loading model: The pickle file might be corrupted, incompatible, or from a different version. Details: {e}"
    print(model_load_error)
    analyzer = None  # Discard any partially loaded / invalid object so the UI reports the error instead
except Exception as e:
    model_load_error = f"An unexpected ERROR occurred during model loading: {e}"
    print(model_load_error)
    analyzer = None
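# Note: pickling the whole analyzer is convenient but not the only option. Hugging Face
# models and tokenizers can also be persisted with save_pretrained() and reloaded with
# from_pretrained() (or simply re-created via pipeline()), which avoids executing
# arbitrary pickled code at load time.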
# --- Define the Gradio Analysis Function ---
def analyze_text_interface(text_input):
    """Function called by Gradio to perform analysis using the pre-loaded model."""
    if analyzer is None:
        # Use the stored error message if available
        error_msg = model_load_error or f"ERROR: The analyzer model could not be loaded from '{pickle_filepath}'."
        return error_msg
    if not text_input or not text_input.strip():
        return "Please enter some text to analyze."

    print(f"Analyzing text: '{text_input[:60]}...'")
    try:
        results = analyzer.analyze(text_input)
        print("Analysis complete.")
        return results
    except Exception as e:
        print(f"Error during analysis: {e}")
        return f"An error occurred during analysis:\n{e}"
# --- Build the Gradio Interface ---
# Build the warning message separately so the description f-string stays simple.
warning_message_text = ""
if analyzer is None:
    warning_message_text = "\n\n***WARNING: MODEL FAILED TO LOAD. ANALYSIS WILL NOT WORK.***"

# Construct the full description string
description_text = (
    f"Enter text to analyze using the pre-loaded model from '{pickle_filepath}'.\n"
    "Checks for sentiment (Positive/Negative) and predicts if the text is Human-Written or AI-Generated."
    f"{warning_message_text}"
)

interface = gr.Interface(
    fn=analyze_text_interface,
    inputs=gr.Textbox(lines=7, label="Text to Analyze", placeholder="Enter review text here..."),
    outputs=gr.Textbox(lines=7, label="Analysis Results", interactive=False),
    title="Sentiment & Authenticity Analyzer",
    description=description_text,
    allow_flagging='never'
)
# --- Launch the Interface ---
if __name__ == "__main__":
    if analyzer is None:
        print("\n--- Interface launched, but MODEL IS NOT LOADED. Analysis will fail. ---")
    else:
        print("\n--- Launching Gradio Interface with pre-loaded model ---")
    interface.launch()
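
# Quick smoke test without the UI (illustrative; the sample sentence is a made-up review,
# and this assumes the .pkl loaded successfully):
#
#   if analyzer is not None:
#       print(analyzer.analyze("The battery life on this phone is fantastic."))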