import gradio as gr
import networkx as nx
import matplotlib
matplotlib.use("Agg")  # headless backend so plotting works in server threads
import matplotlib.pyplot as plt
import io
from PIL import Image
import json
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import os
import re
import torch

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

# Load NER model
print("Loading NER model...")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")

# Load REBEL model with error handling and optimization
print("Loading REBEL model...")
rebel_pipeline = None

def load_rebel_model():
    global rebel_pipeline
    models_to_try = [
        "Babelscape/rebel-small",
        "Babelscape/rebel-base"
    ]
    for model_name in models_to_try:
        try:
            print(f"Trying to load {model_name}...")
            rebel_pipeline = pipeline(
                "text2text-generation",
                model=model_name,
                tokenizer=model_name,
                device_map="auto" if torch.cuda.is_available() else None,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            print(f"✅ Successfully loaded {model_name}")
            return True
        except Exception as e:
            print(f"❌ Failed to load {model_name}: {str(e)[:100]}...")
            continue
    print("❌ Could not load any REBEL model")
    return False

# Try to load REBEL model
rebel_loaded = load_rebel_model()

def extract_entities(text):
    """Extract named entities using BERT-NER"""
    entities = ner_pipeline(text)
    processed_entities = []
    for entity in entities:
        processed_entities.append({
            "text": entity["word"].replace("##", ""),  # Clean stray subword markers
            "label": entity["entity_group"],
            "start": int(entity["start"]),
            "end": int(entity["end"]),
            "confidence": round(float(entity["score"]), 3)
        })
    return processed_entities

def parse_rebel_output(generated_text):
    """Parse REBEL's linearized output: <triplet> subject <subj> object <obj> relation"""
    triplets = []
    print(f"Raw REBEL output: {generated_text}")

    # Clean the output
    generated_text = generated_text.strip()

    # Multiple parsing strategies, from strictest to loosest.
    # Each pattern captures groups in REBEL's order: subject, object, relation.
    parsing_patterns = [
        # Standard REBEL linearization, repeated once per triplet
        r'<triplet>\s*([^<]+?)\s*<subj>\s*([^<]+?)\s*<obj>\s*([^<]+?)\s*(?=<triplet>|</s>|$)',
        # Variant with an explicit closing tag
        r'<triplet>\s*([^<]+?)\s*<subj>\s*([^<]+?)\s*<obj>\s*([^<]+?)\s*</triplet>',
        # Simplified format without the leading <triplet> marker
        r'([^<\n]+?)\s*<subj>\s*([^<\n]+?)\s*<obj>\s*([^<\n]+?)(?:\n|$)',
    ]

    for i, pattern in enumerate(parsing_patterns):
        matches = re.findall(pattern, generated_text, re.IGNORECASE | re.MULTILINE)
        print(f"Pattern {i+1} found {len(matches)} matches")
        for match in matches:
            if len(match) >= 3:
                subject = clean_text(match[0])
                obj = clean_text(match[1])
                relation = clean_text(match[2])
                if validate_triplet(subject, relation, obj):
                    triplets.append({
                        "subject": subject,
                        "relation": format_relation(relation),
                        "object": obj,
                        "confidence": 0.9,
                        "source": "REBEL"
                    })
        # If we found valid triplets, use them
        if triplets:
            break

    # Fallback: try to extract any "subject relation object" word patterns
    if not triplets:
        print("Trying fallback parsing...")
        fallback_pattern = r'([A-Za-z][A-Za-z\s]+?)\s+([a-z_]+)\s+([A-Za-z][A-Za-z\s]+?)(?:\.|$|\n)'
        matches = re.findall(fallback_pattern, generated_text)
        for match in matches:
            subject = clean_text(match[0])
            relation = clean_text(match[1])
            obj = clean_text(match[2])
            if validate_triplet(subject, relation, obj):
                triplets.append({
                    "subject": subject,
                    "relation": format_relation(relation),
                    "object": obj,
                    "confidence": 0.7,
                    "source": "REBEL-fallback"
                })

    return triplets
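# Illustrative parse (assuming REBEL's documented linearization; actual relation
# strings depend on the checkpoint):
#
#   parse_rebel_output("<triplet> Jony Ive <subj> OpenAI <obj> employer")
#   -> [{"subject": "Jony Ive", "relation": "Employer", "object": "OpenAI",
#        "confidence": 0.9, "source": "REBEL"}]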
"""Clean extracted text""" if not text: return "" # Remove HTML tags and special tokens text = re.sub(r'<[^>]+>', '', text) # Remove extra whitespace text = ' '.join(text.split()) # Remove leading/trailing punctuation text = text.strip('.,!?;: ') return text def format_relation(relation): """Format relation text for better readability""" if not relation: return "related_to" # Common relation mappings relation_map = { 'ceo': 'CEO_of', 'founder': 'founded_by', 'president': 'president_of', 'member': 'member_of', 'location': 'located_in', 'country': 'country_of', 'spouse': 'married_to', 'parent': 'parent_of', 'child': 'child_of', 'sibling': 'sibling_of', 'employee': 'works_for', 'owner': 'owns', 'creator': 'created_by' } relation_lower = relation.lower().strip() # Check direct mapping if relation_lower in relation_map: return relation_map[relation_lower] # Format underscores and spaces formatted = relation.replace('_', ' ').replace('-', ' ') formatted = ' '.join(word.capitalize() for word in formatted.split()) return formatted def validate_triplet(subject, relation, object_text): """Validate if a triplet makes sense""" if not subject or not relation or not object_text: return False # Check minimum length if len(subject) < 2 or len(object_text) < 2: return False # Check if subject and object are different if subject.lower() == object_text.lower(): return False # Check for reasonable length (not too long) if len(subject) > 50 or len(object_text) > 50 or len(relation) > 30: return False # Check for non-alphabetic content if not re.search(r'[A-Za-z]', subject) or not re.search(r'[A-Za-z]', object_text): return False return True def extract_relations_rebel(text): """Extract relations using REBEL model with optimized parameters""" if not rebel_pipeline: return [] try: # Preprocess text for better REBEL performance text = preprocess_text_for_rebel(text) # Generate with optimized parameters generated_tokens = rebel_pipeline( text, max_length=512, min_length=10, num_beams=3, do_sample=False, early_stopping=True, return_full_text=False, clean_up_tokenization_spaces=True ) generated_text = generated_tokens[0]["generated_text"] # Parse the output triplets = parse_rebel_output(generated_text) print(f"REBEL extracted {len(triplets)} relations") return triplets except Exception as e: print(f"REBEL extraction error: {e}") return [] def preprocess_text_for_rebel(text): """Preprocess text to improve REBEL performance""" # Limit length for better processing sentences = re.split(r'[.!?]+', text) # Take first 2-3 sentences if text is too long if len(' '.join(sentences)) > 200: text = '. '.join(sentences[:3]) + '.' 
def preprocess_text_for_rebel(text):
    """Preprocess text to improve REBEL performance"""
    # Limit length for better processing
    sentences = re.split(r'[.!?]+', text)
    # Take the first three sentences if the text is too long
    if len(' '.join(sentences)) > 200:
        text = '. '.join(sentences[:3]) + '.'
    # Clean up the text
    text = re.sub(r'\s+', ' ', text)  # Collapse whitespace
    text = text.strip()
    return text

def create_simple_fallback_relations(entities):
    """Create simple relations when REBEL fails"""
    relations = []
    if len(entities) < 2:
        return relations
    # Create relations based on entity types and adjacency in the text
    for i, ent1 in enumerate(entities[:-1]):
        ent2 = entities[i + 1]
        relation_type = determine_relation_by_type(ent1["label"], ent2["label"])
        relations.append({
            "subject": ent1["text"],
            "relation": relation_type,
            "object": ent2["text"],
            "confidence": 0.5,
            "source": "type-based"
        })
    return relations[:5]  # Limit to 5 relations

def determine_relation_by_type(type1, type2):
    """Determine a plausible relation from the two entity types"""
    type_relations = {
        ("PER", "ORG"): "works_for",
        ("ORG", "PER"): "employs",
        ("PER", "LOC"): "lives_in",
        ("ORG", "LOC"): "located_in",
        ("ORG", "ORG"): "partners_with",
        ("PER", "PER"): "knows",
        ("LOC", "LOC"): "near",
        ("MISC", "ORG"): "owned_by",
        ("MISC", "PER"): "used_by"
    }
    return type_relations.get((type1, type2), "related_to")

def extract_relations(text):
    """Main relation extraction function"""
    try:
        entities = extract_entities(text)
        print(f"Found {len(entities)} entities")

        if rebel_loaded:
            # Try REBEL first
            relations = extract_relations_rebel(text)
            if relations:
                return relations
            print("REBEL didn't return relations, using fallback...")

        # Fall back to simple type-based relations
        relations = create_simple_fallback_relations(entities)
        return relations
    except Exception as e:
        print(f"Relation extraction error: {e}")
        return []

def create_knowledge_graph(triplets):
    if not triplets:
        return None, "No relations found. Try entering text with clearer relationships."

    G = nx.DiGraph()

    # Add edges with labels
    edge_labels = {}
    for triplet in triplets:
        subject = triplet["subject"]
        obj = triplet["object"]
        relation = triplet["relation"]
        if subject and obj and subject != obj:
            G.add_edge(subject, obj)
            edge_labels[(subject, obj)] = relation

    if len(G.nodes()) == 0:
        return None, "No valid graph nodes created."

    # Create visualization
    plt.figure(figsize=(14, 10))

    # Layout: spread small graphs out more
    if len(G.nodes()) <= 6:
        pos = nx.spring_layout(G, k=3, iterations=100, seed=42)
    else:
        pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=4000,
                           alpha=0.8, linewidths=2, edgecolors='darkblue')
    # Draw edges
    nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True, arrowsize=25,
                           alpha=0.6, width=2, connectionstyle="arc3,rad=0.1")
    # Draw labels
    nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')
    nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=8,
                                 font_color='red', font_weight='bold')

    plt.title("Knowledge Graph (REBEL + Fallback)", size=16, weight='bold')
    plt.axis('off')
    plt.tight_layout()

    # Save to buffer
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', dpi=200, bbox_inches='tight')
    img_buffer.seek(0)
    img = Image.open(img_buffer)
    plt.close()

    return img, f"Graph created with {len(G.nodes())} nodes and {len(G.edges())} edges."

def format_entities_for_display(entities):
    """Format entities as (text, label) pairs for gr.HighlightedText"""
    return [(entity["text"], entity["label"]) for entity in entities]
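# Illustrative fallback behavior (entity dicts abbreviated):
#
#   create_simple_fallback_relations([
#       {"text": "Elon Musk", "label": "PER"},
#       {"text": "Tesla", "label": "ORG"},
#   ])
#   -> [{"subject": "Elon Musk", "relation": "works_for", "object": "Tesla",
#        "confidence": 0.5, "source": "type-based"}]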
def process_news_text(text):
    if not text.strip():
        return [], "No text provided", None, "Please enter some text to analyze."
    try:
        entities = extract_entities(text)
        entity_display = format_entities_for_display(entities)

        triplets = extract_relations(text)
        graph_img, graph_status = create_knowledge_graph(triplets)

        results = {
            "entities_found": len(entities),
            "relations_found": len(triplets),
            "rebel_model_loaded": rebel_loaded,
            "entities": entities,
            "triplets": triplets
        }

        status = f"✅ Found {len(entities)} entities, {len(triplets)} relations"
        if rebel_loaded:
            status += " (REBEL enabled)"
        else:
            status += " (REBEL not available, using fallback)"

        return entity_display, json.dumps(results, indent=2), graph_img, status
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return [], "{}", None, error_msg

# Examples
examples = [
    "AI is reshaping corporate fortunes: while Nvidia, Microsoft, and Google surge with AI-driven gains, Apple and Tesla have lagged, revealing a growing split among the 'Magnificent Seven'—with investors watching to see if laggards can catch up or if the group will fracture entirely.",
    "Elon Musk is steering Tesla toward becoming an AI robotics powerhouse, integrating his startup xAI into Tesla vehicles—he’s also asked shareholders to approve Tesla funding xAI, marking a bold shift away from traditional EV focus toward autonomous driving, humanoid robots, and supercomputing infrastructure.",
    "OpenAI, valued at $300 billion with over 500 million weekly users, is under pressure from rivals like Meta, Google, Amazon, and xAI—despite strong uptake, it’s battling talent poaching, delayed model launches due to safety reviews, and legal disputes with Microsoft over partnership terms and AGI control.",
    "Microsoft’s AI chief Mustafa Suleyman stresses a pragmatic, human-centered AI strategy: his focus is on safe, real-world tools like Copilot and Bing, not speculative AGI; he estimates AGI is at least a decade away, reflecting Microsoft’s measured balance with its OpenAI partnership.",
    "Jony Ive, the legendary designer behind the iPhone, is joining OpenAI after a $6.5 billion acquisition of his hardware startup io; the deal sets OpenAI on course to develop consumer AI devices, signaling a major push beyond software into hardware innovation."
]

# Create Gradio interface
with gr.Blocks(title="REBEL Knowledge Graph Extractor", theme=gr.themes.Soft()) as demo:
    gr.HTML(f"""

    <div style="text-align: center;">
        <h1>🤖 News Knowledge Graph Extractor</h1>
        <p>Optimized for REBEL model relation extraction</p>
        <p>Status: REBEL Model {'✅ Loaded' if rebel_loaded else '❌ Not Available'}</p>
    </div>

""") with gr.Row(): with gr.Column(): input_text = gr.Textbox( label="Input Text", placeholder="Enter your text here...", lines=6 ) process_btn = gr.Button("Extract Relations", variant="primary") status_output = gr.Textbox( label="Status", interactive=False, max_lines=2 ) with gr.Column(): entity_output = gr.HighlightedText( label="Named Entities", color_map={ "PER": "lightblue", "ORG": "lightgreen", "LOC": "orange", "MISC": "lightpink" } ) results_output = gr.JSON(label="Detailed Results") with gr.Row(): graph_output = gr.Image(label="Knowledge Graph", height=600) with gr.Row(): gr.Examples(examples=examples, inputs=[input_text]) # Event handlers process_btn.click( fn=process_news_text, inputs=[input_text], outputs=[entity_output, results_output, graph_output, status_output] ) input_text.submit( fn=process_news_text, inputs=[input_text], outputs=[entity_output, results_output, graph_output, status_output] ) if __name__ == "__main__": demo.launch(server_port=7860, share=False)