Nam Fam
committed on
Commit
·
ea99abb
1
Parent(s):
a1a0e26
update files
Browse files- .dockerignore +58 -0
- .gitignore +62 -0
- Dockerfile +34 -0
- app.py +72 -0
- llms.py +54 -0
- requirements.txt +13 -0
- tasks/__init__.py +0 -0
- tasks/classification.py +59 -0
- tasks/extraction.py +7 -0
- tasks/grammar_checking.py +27 -0
- tasks/intent_detection.py +54 -0
- tasks/knowledge_graph.py +272 -0
- tasks/ner.py +148 -0
- tasks/pos_tagging.py +178 -0
- tasks/retrieval.py +7 -0
- tasks/segmentation.py +7 -0
- tasks/sentiment_analysis.py +48 -0
- tasks/summarization.py +41 -0
- tasks/topic_classification.py +53 -0
- tasks/translation.py +43 -0
- ui/grammar_ui.py +56 -0
- ui/intent_ui.py +83 -0
- ui/kg_ui.py +231 -0
- ui/ner_ui.py +358 -0
- ui/ner_ui.py.new +362 -0
- ui/pos_ui.py +297 -0
- ui/sentiment_ui.py +108 -0
- ui/summarization_ui.py +101 -0
- ui/topic_ui.py +108 -0
- ui/translation_ui.py +122 -0
- utils/ner_helpers.py +88 -0
- utils/pos_helpers.py +38 -0
- utils/remote_client.py +42 -0
- utils/shared.py +1 -0
.dockerignore
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Git
|
2 |
+
.git
|
3 |
+
.gitignore
|
4 |
+
|
5 |
+
# Virtual Environment
|
6 |
+
.venv
|
7 |
+
venv/
|
8 |
+
env/
|
9 |
+
|
10 |
+
# IDE
|
11 |
+
.vscode/
|
12 |
+
.idea/
|
13 |
+
*.swp
|
14 |
+
*.swo
|
15 |
+
*~
|
16 |
+
|
17 |
+
# Python
|
18 |
+
__pycache__/
|
19 |
+
*.py[cod]
|
20 |
+
*$py.class
|
21 |
+
*.so
|
22 |
+
.pytest_cache/
|
23 |
+
.mypy_cache/
|
24 |
+
|
25 |
+
# Build and distribution
|
26 |
+
build/
|
27 |
+
dist/
|
28 |
+
*.egg-info/
|
29 |
+
|
30 |
+
# Local development
|
31 |
+
*.local
|
32 |
+
|
33 |
+
# Environment files (except example)
|
34 |
+
.env
|
35 |
+
!.env.example
|
36 |
+
|
37 |
+
# Logs and databases
|
38 |
+
*.log
|
39 |
+
*.sqlite
|
40 |
+
*.sqlite3
|
41 |
+
|
42 |
+
# OS generated files
|
43 |
+
.DS_Store
|
44 |
+
.DS_Store?
|
45 |
+
._*
|
46 |
+
.Spotlight-V100
|
47 |
+
.Trashes
|
48 |
+
ehthumbs.db
|
49 |
+
Thumbs.db
|
50 |
+
|
51 |
+
|
52 |
+
# Project specific exclusions
|
53 |
+
modal_inference/
|
54 |
+
tests/
|
55 |
+
api/
|
56 |
+
modal_client.py
|
57 |
+
*.ipynb
|
58 |
+
README1.MD
|
.gitignore
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
*.egg-info/
|
20 |
+
.installed.cfg
|
21 |
+
*.egg
|
22 |
+
|
23 |
+
# Virtual Environment
|
24 |
+
.env
|
25 |
+
.venv
|
26 |
+
env/
|
27 |
+
venv/
|
28 |
+
ENV/
|
29 |
+
|
30 |
+
# IDE
|
31 |
+
.vscode/
|
32 |
+
.idea/
|
33 |
+
*.swp
|
34 |
+
*.swo
|
35 |
+
*~
|
36 |
+
|
37 |
+
# OS
|
38 |
+
.DS_Store
|
39 |
+
Thumbs.db
|
40 |
+
|
41 |
+
# Logs and databases
|
42 |
+
*.log
|
43 |
+
*.sqlite
|
44 |
+
|
45 |
+
# Local development
|
46 |
+
*.local
|
47 |
+
|
48 |
+
# Docker
|
49 |
+
data/
|
50 |
+
Dockerfile.dev
|
51 |
+
|
52 |
+
# Environment files (except example)
|
53 |
+
.env
|
54 |
+
!.env.example
|
55 |
+
|
56 |
+
# Project specific
|
57 |
+
modal_inference/
|
58 |
+
utils/modal_client.py
|
59 |
+
api/
|
60 |
+
tests/
|
61 |
+
*.ipynb
|
62 |
+
README1.MD
|
Dockerfile
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use official Python image as base
FROM python:3.10-slim

# Set environment variables
# NOTE: PIP_NO_CACHE_DIR is a boolean option; a truthy value disables the
# cache. The previous value "off" is parsed as false on current pip, which
# re-enables caching and bloats the image layer.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=on

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage Docker cache
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application
COPY . .

# Expose the port the app runs on
EXPOSE 7860

# Set environment variables for Gradio
ENV GRADIO_SERVER_NAME="0.0.0.0"

# Command to run the application
CMD ["python", "app.py"]
|
app.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from typing import Dict, List, Any, Optional
|
3 |
+
from tasks.knowledge_graph import build_knowledge_graph
|
4 |
+
|
5 |
+
# Import all UI components
|
6 |
+
from ui.summarization_ui import summarization_ui
|
7 |
+
from ui.translation_ui import translation_ui
|
8 |
+
from ui.sentiment_ui import sentiment_ui
|
9 |
+
from ui.topic_ui import topic_ui
|
10 |
+
from ui.ner_ui import ner_ui
|
11 |
+
from ui.pos_ui import pos_ui
|
12 |
+
from ui.kg_ui import kg_ui
|
13 |
+
from ui.intent_ui import intent_ui
|
14 |
+
from ui.grammar_ui import grammar_ui
|
15 |
+
|
16 |
+
# UI function wrappers
|
17 |
+
def summarization_ui_wrapper():
    """Build the summarization tab by delegating to ``summarization_ui``."""
    panel = summarization_ui()
    return panel
|
19 |
+
|
20 |
+
def translation_ui_wrapper():
    """Build the translation tab by delegating to ``translation_ui``."""
    panel = translation_ui()
    return panel
|
22 |
+
|
23 |
+
def sentiment_analysis_ui_wrapper():
    """Build the sentiment analysis tab by delegating to ``sentiment_ui``."""
    panel = sentiment_ui()
    return panel
|
25 |
+
|
26 |
+
def topic_classification_ui_wrapper():
    """Build the topic classification tab by delegating to ``topic_ui``."""
    panel = topic_ui()
    return panel
|
28 |
+
|
29 |
+
def named_entity_recognition_ui_wrapper():
    """Build the NER tab by delegating to ``ner_ui``."""
    panel = ner_ui()
    return panel
|
31 |
+
|
32 |
+
def pos_tagging_ui_wrapper():
    """Build the POS tagging tab by delegating to ``pos_ui``."""
    panel = pos_ui()
    return panel
|
34 |
+
|
35 |
+
def extraction_ui():
    """Return a Markdown placeholder for the not-yet-available extraction tab."""
    notice = "Information Extraction is currently under development."
    return gr.Markdown(notice)
|
37 |
+
|
38 |
+
def retrieval_ui():
    """Return a Markdown placeholder for the not-yet-available retrieval tab."""
    notice = "Text Retrieval is currently under development."
    return gr.Markdown(notice)
|
40 |
+
|
41 |
+
def grammar_ui_wrapper():
    """Build the grammar checking tab by delegating to ``grammar_ui``."""
    panel = grammar_ui()
    return panel
|
43 |
+
|
44 |
+
# Top-level page layout: a header banner followed by one tab per task.
with gr.Blocks(theme=gr.themes.Ocean(), title="Ling - Text Intelligence") as demo:
    gr.HTML('''
    <div style="text-align:center; padding: 24px 0 12px 0;">
    <h1 style="font-size:2.5em; margin-bottom:0.2em; color:#0e7490; letter-spacing:2px; font-family:sans-serif;">Ling</h1>
    <p style="font-size:1.3em; color:#444; margin-bottom:0.2em;">Text Intelligence Platform for Smart Insights</p>
    </div>
    ''')
    # (tab title, callable that renders the tab body), in display order.
    _tab_builders = (
        ("Summarization", summarization_ui_wrapper),
        ("Translation", translation_ui_wrapper),
        ("Sentiment Analysis", sentiment_analysis_ui_wrapper),
        ("Topic Classification", topic_classification_ui_wrapper),
        ("NER", named_entity_recognition_ui_wrapper),
        ("POS Tagging", pos_tagging_ui_wrapper),
        ("Intent Detection", intent_ui),
        ("Grammar Checking", grammar_ui_wrapper),
        ("Knowledge Graph", kg_ui),
        ("Retrieval", retrieval_ui),
    )
    for _tab_title, _build_tab in _tab_builders:
        with gr.Tab(_tab_title):
            _build_tab()

# Start the Gradio server when the module is executed.
demo.launch()
|
llms.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chat_models import init_chat_model
|
2 |
+
from langchain_core.messages import HumanMessage
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from typing import List
|
5 |
+
from langchain.tools import BaseTool
|
6 |
+
from langchain.agents import initialize_agent, AgentType
|
7 |
+
|
8 |
+
_ = load_dotenv()
|
9 |
+
|
10 |
+
class LLM:
    """Thin convenience wrapper around a chat model built with ``init_chat_model``."""

    def __init__(
        self,
        model: str = "gemini-2.0-flash",
        model_provider: str = "google_genai",
        temperature: float = 0.0,
        max_tokens: int = 1000
    ):
        """Create the underlying chat model from the given settings."""
        model_settings = dict(
            model=model,
            model_provider=model_provider,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.chat_model = init_chat_model(**model_settings)

    def generate(self, prompt: str) -> str:
        """Send ``prompt`` as a single human message and return the reply's ``content``."""
        reply = self.chat_model.invoke([HumanMessage(content=prompt)])
        return reply.content

    def bind_tools(self, tools: List[BaseTool], agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION):
        """
        Bind LangChain tools to this model and return the result of ``initialize_agent``.
        """
        agent = initialize_agent(
            tools,
            self.chat_model,
            agent=agent_type,
            verbose=False
        )
        return agent

    def set_temperature(self, temperature: float):
        """
        Overwrite the ``temperature`` attribute of the wrapped chat model.
        """
        self.chat_model.temperature = temperature

    def set_max_tokens(self, max_tokens: int):
        """
        Overwrite the ``max_tokens`` attribute of the wrapped chat model.
        """
        # NOTE(review): whether the provider's model class accepts this
        # attribute name after construction is not visible here - confirm.
        self.chat_model.max_tokens = max_tokens
|
52 |
+
|
53 |
+
|
54 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==5.21.0
|
2 |
+
spacy==3.7.5
|
3 |
+
networkx==3.2.1
|
4 |
+
matplotlib==3.8.3
|
5 |
+
fastapi==0.115.11
|
6 |
+
uvicorn==0.27.0.post1
|
7 |
+
pydantic==2.9.2
|
8 |
+
langchain-core==0.3.58
|
9 |
+
langchain-community==0.3.7
|
10 |
+
google-generativeai==0.8.3
|
11 |
+
python-dotenv==1.0.1
|
12 |
+
pyvis==0.3.2
|
13 |
+
ipython
|
tasks/__init__.py
ADDED
File without changes
|
tasks/classification.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llms import LLM
|
2 |
+
from utils.remote_client import execute_remote_task
|
3 |
+
|
4 |
+
def text_classification(text: str, model: str, task: str = "topic", candidate_labels=None, custom_instructions: str = "", use_llm: bool = True) -> str:
    """
    Classify text using either LLM or traditional (Modal API) method.

    Args:
        text: The text to classify
        model: The model to use
        task: Either "sentiment" or "topic" (only forwarded to the LLM path)
        candidate_labels: For topic classification, the list of candidate labels
        custom_instructions: Optional instructions for LLM
        use_llm: Whether to use LLM or traditional method

    Returns:
        An empty string when ``text`` is None, empty or whitespace-only;
        otherwise whatever the selected backend helper returns.
    """
    # Guard against None as well as blank input so a cleared UI field does
    # not raise AttributeError on ``.strip()``.
    if not text or not text.strip():
        return ""
    if use_llm:
        return _classification_with_llm(text, model, task, candidate_labels, custom_instructions)
    return _classification_with_traditional(text, model, candidate_labels)
|
22 |
+
|
23 |
+
def _classification_with_llm(text: str, model: str, task: str, candidate_labels=None, custom_instructions: str = "") -> str:
    """Classify ``text`` by prompting an ``LLM`` instance built for ``model``.

    ``task == "sentiment"`` asks for positive/negative/neutral; any other
    value is treated as topic classification over ``candidate_labels``.
    Any exception is printed and converted into a generic error message.
    """
    try:
        client = LLM(model=model)
        # Optional user instructions go on their own line inside the prompt.
        extra = f"{custom_instructions}\n" if custom_instructions else ""

        if task == "sentiment":
            pieces = [
                "Analyze the sentiment of the following text. Return ONLY one value: 'positive', 'negative', or 'neutral'.\n",
                extra,
                f"Text: {text}\nSentiment:",
            ]
        else:  # topic classification
            if candidate_labels:
                labels_str = ", ".join(candidate_labels)
            else:
                labels_str = "any appropriate topic"
            pieces = [
                f"Classify the following text into ONE of these categories: {labels_str}.\n",
                "Return ONLY the most appropriate category name.\n",
                extra,
                f"Text: {text}\nCategory:",
            ]

        answer = client.generate("".join(pieces))
        return answer.strip()
    except Exception as e:
        print(f"Error in LLM classification: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
47 |
+
|
48 |
+
def _classification_with_traditional(text: str, model: str, labels=None) -> str:
    """Classify ``text`` through ``execute_remote_task("classification", ...)``.

    ``labels`` is added to the payload only when it is not None. A response
    containing an ``"error"`` key, or any raised exception, yields the
    generic error message; otherwise the response's ``"labels"`` value
    (default ``""``) is returned.
    """
    failure_message = "Oops! Something went wrong. Please try again later."
    try:
        request = {"text": text, "model": model}
        if labels is not None:
            request["labels"] = labels
        response = execute_remote_task("classification", request)
        if "error" not in response:
            return response.get("labels", "")
        return failure_message
    except Exception as e:
        print(f"Error in traditional classification: {str(e)}")
        return failure_message
|
tasks/extraction.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.remote_client import execute_remote_task
|
2 |
+
|
3 |
+
def information_extraction(text: str, model: str) -> str:
    """Run the remote ``"extraction"`` task and return its ``"entities"`` value.

    Returns the generic error message when the response contains an
    ``"error"`` key, and ``""`` when ``"entities"`` is missing.
    """
    request = {"text": text, "model": model}
    response = execute_remote_task("extraction", request)
    if "error" not in response:
        return response.get("entities", "")
    return "Oops! Something went wrong. Please try again later."
|
tasks/grammar_checking.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
from llms import LLM
|
3 |
+
|
4 |
+
def grammar_checking(text: str, model: str, custom_instructions: Optional[str] = None, use_llm: bool = True) -> str:
    """Correct grammar and spelling with the LLM or the traditional backend.

    Returns a prompt to enter text when ``text`` is None, empty or
    whitespace-only; otherwise dispatches on ``use_llm``.
    """
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return "Please enter input text."
    if not use_llm:
        return _grammar_checking_with_traditional(text, model)
    return _grammar_checking_with_llm(text, model, custom_instructions)
|
12 |
+
|
13 |
+
def _grammar_checking_with_llm(text: str, model: str, custom_instructions: Optional[str]) -> str:
    """Ask an ``LLM`` built for ``model`` to correct ``text``.

    Custom instructions, when given, are placed on a line before the task
    description. Any exception is printed and converted into the generic
    error message.
    """
    try:
        client = LLM(model=model)
        if custom_instructions:
            preamble = custom_instructions + "\n"
        else:
            preamble = ""
        prompt = preamble + f"Check and correct grammar and spelling for the following text.\nText: {text}\nCorrected:"
        corrected = client.generate(prompt)
        return corrected.strip()
    except Exception as e:
        print(f"Error in LLM grammar checking: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
24 |
+
|
25 |
+
def _grammar_checking_with_traditional(text: str, model: str) -> str:
|
26 |
+
# Placeholder for traditional grammar checking (could use LanguageTool or similar)
|
27 |
+
return "[Traditional grammar checking is not implemented. Please use LLM mode.]"
|
tasks/intent_detection.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
|
3 |
+
def intent_detection(
    text: str,
    model: str,
    candidate_intents: Optional[List[str]] = None,
    custom_instructions: str = "",
    use_llm: bool = True
) -> str:
    """Detect the intent of ``text`` with the LLM or the traditional backend.

    Returns a prompt to enter text when ``text`` is None, empty or
    whitespace-only; otherwise dispatches on ``use_llm``.
    """
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return "Please enter input text."
    if not use_llm:
        return _intent_detection_with_traditional(text, model, candidate_intents)
    return _intent_detection_with_llm(text, model, candidate_intents, custom_instructions)
|
16 |
+
|
17 |
+
from llms import LLM
|
18 |
+
|
19 |
+
def _intent_detection_with_llm(
    text: str,
    model: str,
    candidate_intents: Optional[List[str]],
    custom_instructions: str
) -> str:
    """Prompt an ``LLM`` built for ``model`` for the intent of ``text``.

    With ``candidate_intents`` the model is asked to pick from that list;
    without, it is asked for a free-form intent name. Any exception is
    printed and converted into the generic error message.
    """
    try:
        client = LLM(model=model)
        if candidate_intents:
            header = (
                f"Classify the intent of the following text from this list: {', '.join(candidate_intents)}.\n"
                "Return ONLY the best intent name.\n"
            )
        else:
            header = (
                "Detect the intent of the following text.\n"
                "Return ONLY the intent name, do not explain.\n"
            )
        # Optional user instructions sit between the header and the text.
        extra = f"{custom_instructions}\n" if custom_instructions else ""
        prompt = header + extra + f"Text: {text}\nIntent:"
        answer = client.generate(prompt)
        return answer.strip()
    except Exception as e:
        print(f"Error in LLM intent detection: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
46 |
+
|
47 |
+
|
48 |
+
def _intent_detection_with_traditional(
|
49 |
+
text: str,
|
50 |
+
model: str,
|
51 |
+
candidate_intents: Optional[List[str]]
|
52 |
+
) -> str:
|
53 |
+
# TODO: Implement traditional model inference
|
54 |
+
return "[Traditional model intent detection not implemented yet]"
|
tasks/knowledge_graph.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Any, Tuple, Optional
|
2 |
+
import spacy
|
3 |
+
import networkx as nx
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from io import BytesIO
|
6 |
+
import base64
|
7 |
+
import re
|
8 |
+
import json
|
9 |
+
from langchain_core.messages import HumanMessage
|
10 |
+
from langchain.chat_models import init_chat_model
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
import os
|
13 |
+
# Interactive visualization
|
14 |
+
from pyvis.network import Network
|
15 |
+
|
16 |
+
# Load environment variables
|
17 |
+
_ = load_dotenv()
|
18 |
+
|
19 |
+
class LLMKnowledgeGraph:
    """Extract entities and subject-relation-object triples from text with a chat model."""

    def __init__(self, model: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
        """Initialize the LLM for knowledge graph generation."""
        self.llm = init_chat_model(
            model=model,
            model_provider=model_provider,
            temperature=0.1,  # Lower temperature for more deterministic results
            max_tokens=2000
        )
        # Prompts end with "Text: " so the input text can be appended directly.
        self.entity_prompt = (
            "\n"
            "Extract all named entities from the following text and categorize them into the following types:\n"
            "- PERSON: People, including fictional\n"
            "- ORG: Companies, agencies, institutions, etc.\n"
            "- GPE: Countries, cities, states\n"
            "- DATE: Absolute or relative dates or periods\n"
            "- MONEY: Monetary values\n"
            "- PERCENT: Percentage values\n"
            "- QUANTITY: Measurements, weights, distances\n"
            "- EVENT: Named hurricanes, battles, wars, sports events, etc.\n"
            "- WORK_OF_ART: Titles of books, songs, etc.\n"
            "- LAW: Legal document titles\n"
            "- LANGUAGE: Any named language\n"
            "\n"
            "Return the entities in JSON format with the following structure:\n"
            "[\n"
            '{"text": "entity text", "label": "ENTITY_TYPE", "start": character_start, "end": character_end}\n'
            "]\n"
            "\n"
            "Text: "
        )

        self.relation_prompt = (
            "\n"
            "Analyze the following text and extract relationships between entities in the form of subject-relation-object triples.\n"
            "For each relation, provide:\n"
            "- The subject (entity that is the source of the relation)\n"
            "- The relation type (e.g., 'works at', 'located in', 'part of')\n"
            "- The object (entity that is the target of the relation)\n"
            "\n"
            "Return the relations in JSON format with the following structure:\n"
            "[\n"
            '{"subject": "subject text", "relation": "relation type", "object": "object text"}\n'
            "]\n"
            "\n"
            "Text: "
        )

    @staticmethod
    def _parse_json_array(content: str) -> Any:
        """Parse the JSON array embedded in an LLM reply.

        The span from the first ``[`` to the last ``]`` is parsed, which
        tolerates Markdown code fences and prose around the array. When no
        such span exists the stripped reply is parsed as-is. Raises whatever
        ``json.loads`` raises on invalid input.
        """
        content = content.strip()
        start = content.find("[")
        end = content.rfind("]")
        if start != -1 and end > start:
            content = content[start:end + 1]
        return json.loads(content)

    def _extract_with_prompt(self, prompt: str, text: str, what: str) -> list:
        """Invoke the model with ``prompt + text`` and parse the JSON reply; ``[]`` on any failure."""
        # Pre-bind so the error handler can tell "call failed" from "bad reply".
        response = None
        try:
            response = self.llm.invoke([HumanMessage(content=prompt + text)])
            # Handle case where response might be a string or a message object
            if hasattr(response, 'content'):
                content = response.content
            else:
                content = str(response)
            return self._parse_json_array(content)
        except Exception as e:
            print(f"Error extracting {what} with LLM: {str(e)}")
            if response is not None:
                print(f"Response content: {getattr(response, 'content', str(response))}")
            return []

    def extract_entities_with_llm(self, text: str) -> List[Dict[str, Any]]:
        """Extract entities from text using LLM.

        Returns the parsed JSON reply, or an empty list when the model call
        or the JSON parsing fails (the error is printed).
        """
        return self._extract_with_prompt(self.entity_prompt, text, "entities")

    def extract_relations_with_llm(self, text: str) -> List[Dict[str, str]]:
        """Extract relations between entities using LLM.

        Returns the parsed JSON reply, or an empty list when the model call
        or the JSON parsing fails (the error is printed).
        """
        return self._extract_with_prompt(self.relation_prompt, text, "relations")
|
110 |
+
|
111 |
+
def extract_relations(text: str, model_name: str = "gemini-2.0-flash", use_llm: bool = True) -> Dict[str, Any]:
    """
    Extract entities and their relations from text to build a knowledge graph.

    Args:
        text: Input text to process
        model_name: Name of the model to use (spaCy model or LLM)
        use_llm: Whether to use LLM for relation extraction (default: True)

    Returns:
        Dictionary with the keys "entities" and "relations"
    """
    if use_llm:
        # LLM handles both entity and relation extraction.
        extractor = LLMKnowledgeGraph(model=model_name)
        found_entities = extractor.extract_entities_with_llm(text)
        found_relations = extractor.extract_relations_with_llm(text)
        return {
            "entities": found_entities,
            "relations": found_relations
        }

    # spaCy path: load the pipeline, downloading it once if it is missing.
    try:
        nlp = spacy.load(model_name)
    except OSError:
        import subprocess
        import sys
        subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
        nlp = spacy.load(model_name)

    doc = nlp(text)

    found_entities = []
    for ent in doc.ents:
        found_entities.append(
            {"text": ent.text, "label": ent.label_, "start": ent.start_char, "end": ent.end_char}
        )

    # Subject-verb-object triples: a nominal subject of a verb, paired with
    # that verb's first direct object.
    found_relations = []
    for sent in doc.sents:
        for token in sent:
            if token.dep_ != "nsubj" or token.head.pos_ != "VERB":
                continue
            verb = token.head
            subject_text = token.text
            verb_lemma = verb.lemma_
            object_text = next(
                (child.text for child in verb.children if child.dep_ == "dobj"),
                "",
            )
            if subject_text and object_text and verb_lemma:
                found_relations.append({
                    "subject": subject_text,
                    "relation": verb_lemma,
                    "object": object_text
                })

    return {
        "entities": found_entities,
        "relations": found_relations
    }
|
180 |
+
|
181 |
+
def build_nx_graph(entities: List[Dict], relations: List[Dict]) -> nx.DiGraph:
    """Build a NetworkX DiGraph from entities and relations. Ensure all nodes have a 'label'.

    Args:
        entities: Dicts with the node name under "text" (or "word") and the
            entity type under "label" (or "type").
        relations: Dicts with "subject", "object" and optionally "relation"
            (edge label, defaults to "related_to").

    Returns:
        A directed graph whose nodes carry ``label`` and ``type`` attributes
        and whose edges carry a ``label`` attribute.

    Entities without any text and relations missing a subject or object are
    skipped: networkx raises ValueError when None is used as a node.
    """
    G = nx.DiGraph()
    # Add entities as nodes
    for entity in entities:
        text = entity.get("text") or entity.get("word")
        if text is None:
            continue
        label = entity.get("label") or entity.get("type") or "ENTITY"
        G.add_node(text, label=label, type="entity")
    # Add edges and ensure nodes exist with label
    for rel in relations:
        subj = rel.get("subject")
        obj = rel.get("object")
        if subj is None or obj is None:
            # Incomplete triple (e.g. malformed LLM output): nothing to connect.
            continue
        rel_label = rel.get("relation", "related_to")
        for endpoint in (subj, obj):
            if endpoint not in G:
                G.add_node(endpoint, label="ENTITY", type="entity")
        G.add_edge(subj, obj, label=rel_label)
    return G
|
200 |
+
|
201 |
+
def visualize_knowledge_graph(entities: List[Dict], relations: List[Dict]) -> str:
    """
    Generate a static PNG visualization of the knowledge graph, returned as base64 string for HTML embedding.

    Args:
        entities: Entity dicts, passed to ``build_nx_graph``.
        relations: Relation dicts, passed to ``build_nx_graph``.

    Returns:
        A ``data:image/png;base64,...`` URI containing the rendered graph.
    """
    G = build_nx_graph(entities, relations)
    fig = plt.figure(figsize=(12, 8))
    try:
        pos = nx.spring_layout(G, k=0.5, iterations=50)
        # Color nodes by entity type. Sorting makes the type -> color
        # assignment stable across runs (set order is not).
        entity_types = sorted({d.get('label', 'ENTITY') for _, d in G.nodes(data=True)}, key=str)
        color_map = {etype: plt.cm.tab20(i % 20) for i, etype in enumerate(entity_types)}
        node_colors = [color_map[d.get('label', 'ENTITY')] for _, d in G.nodes(data=True)]
        nx.draw_networkx_nodes(G, pos, node_size=2000, node_color=node_colors, alpha=0.8)
        nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True, arrowsize=20)
        nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')
        edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails;
        # otherwise pyplot keeps it alive for the life of the process.
        plt.close(fig)
    img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{img_str}"
|
222 |
+
|
223 |
+
def visualize_knowledge_graph_interactive(entities: List[Dict], relations: List[Dict]) -> str:
    """
    Generate an interactive HTML visualization of the knowledge graph using pyvis.
    Returns HTML as a string for embedding in Gradio or web UI.

    Args:
        entities: Entity dicts, passed to ``build_nx_graph``.
        relations: Relation dicts, passed to ``build_nx_graph``.

    Returns:
        The content between ``<body>`` and ``</body>`` of the generated
        page, or the whole page when those tags cannot be located.
    """
    G = build_nx_graph(entities, relations)
    net = Network(height="600px", width="100%", directed=True, notebook=False)
    # Color map for entity types
    entity_types = list(set([d.get('label', 'ENTITY') for n, d in G.nodes(data=True)]))
    color_palette = ["#e3f2fd", "#e8f5e9", "#fff8e1", "#f3e5f5", "#e8eaf6", "#e0f7fa", "#f1f8e9", "#fce4ec", "#e8f5e9", "#f5f5f5", "#fafafa", "#e1f5fe", "#fff3e0", "#d7ccc8", "#f9fbe7", "#fbe9e7", "#ede7f6", "#e0f2f1"]
    color_map = {etype: color_palette[i % len(color_palette)] for i, etype in enumerate(entity_types)}
    for n, d in G.nodes(data=True):
        label = d.get('label', 'ENTITY')
        net.add_node(n, label=n, title=f"{n}<br>Type: {label}", color=color_map[label])
    for u, v, d in G.edges(data=True):
        net.add_edge(u, v, label=d['label'], title=d['label'])
    # pyvis parses the options text as JSON starting at the first "{", so it
    # must not carry a trailing ";" - serialise a dict to be safe.
    options = {
        "edges": {"arrows": {"to": {"enabled": True}}, "color": {"color": "#888"}},
        "nodes": {"font": {"size": 18}},
        "physics": {"enabled": True},
    }
    net.set_options("var options = " + json.dumps(options))
    # write_html() expects a file name; generate_html() returns the page as
    # a string without touching the file system.
    html = net.generate_html(notebook=False)
    # Remove <html>, <body> wrappers to allow embedding in Gradio
    body_open = html.find('<body>')
    body_end = html.find('</body>')
    if body_open == -1 or body_end == -1:
        return html
    # NOTE(review): the page <head> (where the vis.js includes presumably
    # live) is dropped here - confirm the embedding page supplies them.
    return html[body_open + len('<body>'):body_end]
|
249 |
+
|
250 |
+
def build_knowledge_graph(text: str, model_name: str = "gemini-2.0-flash", use_llm: bool = True) -> Dict[str, Any]:
    """
    Main function to build a knowledge graph from text.

    Args:
        text: Input text to process
        model_name: Name of the model to use (spaCy model or LLM)
        use_llm: Whether to use LLM for relation extraction (default: True)

    Returns:
        The dictionary from ``extract_relations`` with an added
        "visualization" key: the rendered image URI, or None when either
        the entities or the relations are empty.
    """
    graph_data = extract_relations(text, model_name, use_llm)

    # Only render when there is something to draw on both sides.
    image_uri = None
    if graph_data.get("entities") and graph_data.get("relations"):
        image_uri = visualize_knowledge_graph(graph_data["entities"], graph_data["relations"])
    graph_data["visualization"] = image_uri

    return graph_data
|
tasks/ner.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Union, Optional
|
2 |
+
from llms import LLM
|
3 |
+
from dataclasses import dataclass, asdict
|
4 |
+
import json
|
5 |
+
|
6 |
+
@dataclass
class Entity:
    """A single named entity found in a piece of text."""

    # Surface text of the entity.
    text: str
    # Entity category name.
    type: str
    # Start and end positions of the entity in the source text.
    start: int
    end: int
    # Optional extras; both default to None when not supplied.
    confidence: Optional[float] = None
    description: Optional[str] = None
|
14 |
+
|
15 |
+
def named_entity_recognition(
    text: str,
    model: str = "gemini-2.0-flash",
    use_llm: bool = True,
    entity_types: Optional[List[str]] = None
) -> Union[str, List[Dict]]:
    """
    Dispatch named entity recognition to the LLM or the traditional backend.

    Args:
        text: Input text to analyze
        model: Model name passed through to the selected backend
        use_llm: True routes to _ner_with_llm, False to _ner_with_traditional
        entity_types: Entity types to extract (only forwarded to the LLM backend)

    Returns:
        An empty list for blank or whitespace-only text; otherwise whatever
        the selected backend returns.
    """
    # Blank input: nothing to extract, so no backend is called.
    if text.strip() == "":
        return []

    if not use_llm:
        return _ner_with_traditional(text, model)
    return _ner_with_llm(text, model, entity_types)
|
40 |
+
|
41 |
+
def _load_entity_array(response: str) -> list:
    """Decode the JSON array of entities contained in a raw LLM reply.

    Several shapes are accepted because the prompt ends with an opening "[":
    a complete array (optionally inside a Markdown fence or other text), or a
    continuation that starts right after that bracket.

    Raises:
        ValueError: If no candidate substring decodes to a JSON list.
    """
    cleaned = response.strip()
    close = cleaned.rfind(']')
    if close == -1:
        raise ValueError("No JSON array found in LLM response")

    candidates = []
    open_bracket = cleaned.find('[')
    if 0 <= open_bracket < close:
        # Complete array, possibly wrapped in a fence or surrounding text.
        candidates.append(cleaned[open_bracket:close + 1])
    open_brace = cleaned.find('{')
    if 0 <= open_brace < close:
        # Reply that continues after the "[" the prompt already supplied.
        candidates.append('[' + cleaned[open_brace:close + 1])
    # Last resort: treat the whole reply as the remainder of the array
    # (covers a bare "]" meaning "no entities").
    candidates.append('[' + cleaned[:close + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            return parsed
    raise ValueError("No JSON array found in LLM response")


def _ner_with_llm(
    text: str,
    model_name: str,
    entity_types: Optional[List[str]] = None
) -> List[Dict]:
    """Extract named entities by prompting an LLM for a JSON array.

    Args:
        text: Text to analyze.
        model_name: Model identifier passed to ``LLM``.
        entity_types: Entity labels to request; a default label set is used
            when None.

    Returns:
        A list of dicts with the keys text, type, start, end, confidence and
        description. Malformed items in the reply are skipped individually.
        If the LLM call fails or no JSON array can be decoded, the result of
        ``_ner_with_traditional(text, "en_core_web_md")`` is returned instead.
    """
    # Default entity types if none provided
    if entity_types is None:
        entity_types = [
            "PERSON", "ORG", "GPE", "LOC", "PRODUCT", "EVENT",
            "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME",
            "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"
        ]

    # Create the prompt. The user text is concatenated (not interpolated) so
    # braces inside it cannot interfere with the f-string.
    entity_types_str = ", ".join(entity_types)
    prompt = f"""
Extract named entities from the following text and categorize them into these types: {entity_types_str}.
For each entity, provide:
- The entity text
- The entity type (from the list above)
- The start and end character positions
- (Optional) A brief description of the entity
- (Optional) Confidence score (0-1)

Return the entities as a JSON array of objects with these fields:
- text: The entity text
- type: The entity type
- start: Start character position
- end: End character position
- description: (Optional) Brief description
- confidence: (Optional) Confidence score (0-1)

Text: """ + text + """

JSON response (only the array, no other text):
["""

    try:
        llm = LLM(model=model_name, temperature=0.1)
        response = llm.generate(prompt)
        entities = _load_entity_array(response)

        # Convert to Entity objects and validate
        valid_entities = []
        for ent in entities:
            try:
                entity = Entity(
                    text=ent['text'],
                    type=ent['type'],
                    start=int(ent['start']),
                    end=int(ent['end']),
                    confidence=ent.get('confidence'),
                    description=ent.get('description')
                )
            except (KeyError, TypeError, ValueError) as e:
                # TypeError covers non-dict items and null positions; one bad
                # item must not discard the entities that did parse.
                print(f"Error parsing entity: {e}")
                continue
            valid_entities.append(asdict(entity))

        return valid_entities

    except Exception as e:
        # Deliberate best-effort boundary: any LLM or decoding failure falls
        # back to traditional NER.
        print(f"Error in LLM NER: {str(e)}")
        return _ner_with_traditional(text, "en_core_web_md")
|
118 |
+
|
119 |
+
def _ner_with_traditional(text: str, model: str) -> List[Dict]:
|
120 |
+
"""Fallback to traditional NER models."""
|
121 |
+
try:
|
122 |
+
import spacy
|
123 |
+
|
124 |
+
# Load the appropriate model
|
125 |
+
if model == "en_core_web_sm" or model == "en_core_web_md" or model == "en_core_web_lg":
|
126 |
+
nlp = spacy.load(model)
|
127 |
+
else:
|
128 |
+
nlp = spacy.load("en_core_web_md")
|
129 |
+
|
130 |
+
# Process the text
|
131 |
+
doc = nlp(text)
|
132 |
+
|
133 |
+
# Convert to our entity format
|
134 |
+
entities = []
|
135 |
+
for ent in doc.ents:
|
136 |
+
entities.append({
|
137 |
+
'text': ent.text,
|
138 |
+
'type': ent.label_,
|
139 |
+
'start': ent.start_char,
|
140 |
+
'end': ent.end_char,
|
141 |
+
'confidence': 1.0 # Traditional NER doesn't provide confidence
|
142 |
+
})
|
143 |
+
|
144 |
+
return entities
|
145 |
+
|
146 |
+
except Exception as e:
|
147 |
+
print(f"Error in traditional NER: {str(e)}")
|
148 |
+
return []
|
tasks/pos_tagging.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Union, Optional
|
2 |
+
from llms import LLM
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
|
6 |
+
def pos_tagging(
    text: str,
    model: str = "en_core_web_sm",
    use_llm: bool = False,
    custom_instructions: str = ""
) -> Dict[str, List[Union[str, List[str]]]]:
    """
    Dispatch Part-of-Speech tagging to the LLM or the traditional backend.

    Args:
        text: The input text to tag
        model: Model name passed through to the selected backend
        use_llm: True routes to _pos_tagging_with_llm, False to _pos_tagging_traditional
        custom_instructions: Extra instructions, forwarded to the LLM backend only

    Returns:
        A dictionary containing 'tokens' and 'tags' lists; both lists are
        empty for blank or whitespace-only text.
    """
    # Blank input: return the empty shape without touching any backend.
    if text.strip() == "":
        return {"tokens": [], "tags": []}

    if not use_llm:
        return _pos_tagging_traditional(text, model)
    return _pos_tagging_with_llm(text, model, custom_instructions)
|
31 |
+
|
32 |
+
def _extract_json_array(text: str) -> str:
|
33 |
+
"""Extract JSON array from text, handling various formats."""
|
34 |
+
import re
|
35 |
+
|
36 |
+
# Try to find JSON array pattern
|
37 |
+
json_match = re.search(r'\[\s*\{.*\}\s*\]', text, re.DOTALL)
|
38 |
+
if json_match:
|
39 |
+
return json_match.group(0)
|
40 |
+
|
41 |
+
# If not found, try to find array between square brackets
|
42 |
+
start = text.find('[')
|
43 |
+
end = text.rfind(']')
|
44 |
+
if start >= 0 and end > start:
|
45 |
+
return text[start:end+1]
|
46 |
+
|
47 |
+
return text
|
48 |
+
|
49 |
+
def _pos_tagging_with_llm(
    text: str,
    model_name: str,
    custom_instructions: str = ""
) -> Dict[str, List[str]]:
    """Tag text by prompting an LLM for a JSON array of token/tag objects.

    Args:
        text: Text to tag.
        model_name: Model identifier passed to ``LLM``.
        custom_instructions: Optional text placed before the standard prompt.

    Returns:
        A dict with parallel 'tokens' and 'tags' lists. On any failure (LLM
        error, undecodable reply, no usable items) the result of
        ``_pos_tagging_traditional(text, "en_core_web_sm")`` is returned.
    """
    # Create the prompt with clear instructions. The tag list follows
    # Universal Dependencies, whose coordinating-conjunction tag is CCONJ.
    prompt = """Analyze the following text and provide Part-of-Speech (POS) tags for each token.
Return the result as a JSON array of objects with 'token' and 'tag' keys.

Use standard Universal Dependencies POS tags:
- ADJ: adjective
- ADP: adposition
- ADV: adverb
- AUX: auxiliary verb
- CCONJ: coordinating conjunction
- DET: determiner
- INTJ: interjection
- NOUN: noun
- NUM: numeral
- PART: particle
- PRON: pronoun
- PROPN: proper noun
- PUNCT: punctuation
- SCONJ: subordinating conjunction
- SYM: symbol
- VERB: verb
- X: other

Example output format:
[
{"token": "Hello", "tag": "INTJ"},
{"token": "world", "tag": "NOUN"},
{"token": ".", "tag": "PUNCT"}
]

Text to analyze:
"""

    if custom_instructions:
        prompt = f"{custom_instructions}\n\n{prompt}"

    prompt += f'"{text}"'

    try:
        # Lower temperature for more deterministic output
        llm = LLM(model=model_name, temperature=0.1, max_tokens=2000)

        response = llm.generate(prompt)
        print(f"LLM Raw Response: {response[:500]}...")  # Log first 500 chars

        if not response.strip():
            raise ValueError("Empty response from LLM")

        json_str = _extract_json_array(response)
        if not json_str:
            raise ValueError("No JSON array found in response")

        try:
            pos_tags = json.loads(json_str)
        except json.JSONDecodeError:
            # Best-effort repair of common issues: single quotes and
            # unquoted keys. Keys are only quoted directly after "{" or ","
            # so colons inside values (e.g. "10:30") are left alone.
            json_str = json_str.replace("'", '"')
            json_str = re.sub(r'([{,]\s*)(\w+)\s*:', r'\1"\2":', json_str)
            pos_tags = json.loads(json_str)

        if not isinstance(pos_tags, list):
            raise ValueError(f"Expected list, got {type(pos_tags).__name__}")

        tokens = []
        tags = []

        for item in pos_tags:
            if not isinstance(item, dict):
                continue

            token = item.get('token', '')
            tag = item.get('tag', '')

            if token and tag:  # Only add if both token and tag are non-empty
                tokens.append(str(token).strip())
                tags.append(str(tag).strip())

        if not tokens or not tags:
            raise ValueError("No valid tokens and tags found in response")

        return {
            'tokens': tokens,
            'tags': tags
        }

    except Exception as e:
        # Deliberate best-effort boundary: fall back to spaCy on any failure.
        print(f"Error in LLM POS tagging: {str(e)}")
        print("Falling back to traditional POS tagging...")
        return _pos_tagging_traditional(text, "en_core_web_sm")
|
148 |
+
|
149 |
+
def _pos_tagging_traditional(text: str, model: str) -> Dict[str, List[str]]:
|
150 |
+
"""Use traditional POS tagging models."""
|
151 |
+
try:
|
152 |
+
import spacy
|
153 |
+
|
154 |
+
# Load the appropriate model
|
155 |
+
try:
|
156 |
+
nlp = spacy.load(model)
|
157 |
+
except OSError:
|
158 |
+
# Fallback to small English model if specified model is not found
|
159 |
+
nlp = spacy.load("en_core_web_sm")
|
160 |
+
|
161 |
+
# Process the text
|
162 |
+
doc = nlp(text)
|
163 |
+
|
164 |
+
# Extract tokens and POS tags
|
165 |
+
tokens = []
|
166 |
+
tags = []
|
167 |
+
for token in doc:
|
168 |
+
tokens.append(token.text)
|
169 |
+
tags.append(token.pos_)
|
170 |
+
|
171 |
+
return {
|
172 |
+
'tokens': tokens,
|
173 |
+
'tags': tags
|
174 |
+
}
|
175 |
+
|
176 |
+
except Exception as e:
|
177 |
+
print(f"Error in traditional POS tagging: {str(e)}")
|
178 |
+
return {"tokens": [], "tags": []}
|
tasks/retrieval.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.remote_client import execute_remote_task
|
2 |
+
|
3 |
+
def text_retrieval(query: str, model: str, documents: list) -> str:
    """Run the remote "retrieval" task and return its matches.

    Args:
        query: Sent as the "query" field of the request payload.
        model: Sent as the "model" field.
        documents: Sent as the "documents" field.

    Returns:
        A generic error message when the response contains an "error" key;
        otherwise the response's "matches" value ("" when absent).
    """
    request = {"query": query, "model": model, "documents": documents}
    response = execute_remote_task("retrieval", request)
    if "error" not in response:
        return response.get("matches", "")
    return "Oops! Something went wrong. Please try again later."
|
tasks/segmentation.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.remote_client import execute_remote_task
|
2 |
+
|
3 |
+
def text_segmentation(text: str, model: str = "") -> str:
    """Run the remote "segmentation" task and return its segments.

    Args:
        text: Sent as the "text" field of the request payload.
        model: Sent as the "model" field (defaults to "").

    Returns:
        A generic error message when the response contains an "error" key;
        otherwise the response's "segments" value ("" when absent).
    """
    request = {"text": text, "model": model}
    response = execute_remote_task("segmentation", request)
    if "error" not in response:
        return response.get("segments", "")
    return "Oops! Something went wrong. Please try again later."
|
tasks/sentiment_analysis.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llms import LLM
|
2 |
+
from utils.remote_client import execute_remote_task
|
3 |
+
|
4 |
+
def sentiment_analysis(text: str, model: str, custom_instructions: str = "", use_llm: bool = True) -> str:
    """
    Dispatch sentiment analysis to the LLM or the traditional (remote) backend.

    Args:
        text: The text to analyze
        model: Model name passed through to the selected backend
        custom_instructions: Optional instructions, forwarded to the LLM backend only
        use_llm: True routes to _sentiment_with_llm, False to _sentiment_with_traditional

    Returns:
        "" for blank or whitespace-only text; otherwise the backend's result.
    """
    if text.strip() == "":
        return ""
    if not use_llm:
        return _sentiment_with_traditional(text, model)
    return _sentiment_with_llm(text, model, custom_instructions)
|
20 |
+
|
21 |
+
def _sentiment_with_llm(text: str, model: str, custom_instructions: str = "") -> str:
    """Ask an LLM for a one-word sentiment label.

    Args:
        text: Text to analyze.
        model: Model identifier passed to ``LLM``.
        custom_instructions: Optional line inserted between the task
            description and the text.

    Returns:
        The model's reply with surrounding whitespace removed (it is not
        checked against the three requested labels). Any exception is printed
        and a generic error message is returned instead.
    """
    try:
        client = LLM(model=model)
        pieces = [
            "Analyze the sentiment of the following text. Return ONLY one value: 'positive', 'negative', or 'neutral'.\n"
        ]
        if custom_instructions:
            pieces.append(f"{custom_instructions}\n")
        pieces.append(f"Text: {text}\nSentiment:")
        reply = client.generate("".join(pieces))
        return reply.strip()
    except Exception as e:
        print(f"Error in LLM sentiment analysis: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
34 |
+
|
35 |
+
def _sentiment_with_traditional(text: str, model: str) -> str:
    """Run sentiment analysis through the remote "classification" task.

    Args:
        text: Sent as the "text" field of the payload.
        model: Sent as the "model" field.

    Returns:
        The response's "labels" value ("" when absent). A response containing
        an "error" key, or any exception (which is printed), yields a generic
        error message.
    """
    try:
        response = execute_remote_task(
            "classification",
            {"text": text, "model": model, "task": "sentiment"},
        )
        if "error" not in response:
            return response.get("labels", "")
        return "Oops! Something went wrong. Please try again later."
    except Exception as e:
        print(f"Error in traditional sentiment analysis: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
tasks/summarization.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llms import LLM
|
2 |
+
from utils.remote_client import execute_remote_task
|
3 |
+
|
4 |
+
def text_summarization(text: str, model: str, summary_length: str, use_llm: bool = True) -> str:
    """
    Dispatch summarization to the LLM or the traditional (remote) backend.

    Args:
        text: The text to summarize
        model: Model name passed through to the selected backend
        summary_length: Length hint passed through to the selected backend
        use_llm: True routes to _summarization_with_llm, False to _summarization_with_traditional

    Returns:
        "" for blank or whitespace-only text; otherwise the backend's result.
    """
    if text.strip() == "":
        return ""
    if not use_llm:
        return _summarization_with_traditional(text, model, summary_length)
    return _summarization_with_llm(text, model, summary_length)
|
14 |
+
|
15 |
+
def _summarization_with_llm(text: str, model: str, summary_length: str) -> str:
    """Ask an LLM to summarize text.

    Args:
        text: Text to summarize.
        model: Model identifier passed to ``LLM``.
        summary_length: Inserted verbatim into the prompt as the level of detail.

    Returns:
        The model's reply with surrounding whitespace removed. Any exception
        is printed and a generic error message is returned instead.
    """
    try:
        client = LLM(model=model)
        request = f"Summarize the following text in {summary_length} detail. Text: {text}\nSummary:"
        reply = client.generate(request)
        return reply.strip()
    except Exception as e:
        print(f"Error in LLM summarization: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
27 |
+
|
28 |
+
def _summarization_with_traditional(text: str, model: str, summary_length: str) -> str:
    """Summarize text through the remote "summarization" task.

    Args:
        text: Sent as the "text" field of the payload.
        model: Sent as the "model" field.
        summary_length: Sent as the "summary_length" field.

    Returns:
        The response's "summary" value ("" when absent). A response containing
        an "error" key, or any exception (which is printed), yields a generic
        error message.
    """
    try:
        response = execute_remote_task(
            "summarization",
            {"text": text, "model": model, "summary_length": summary_length},
        )
        if "error" not in response:
            return response.get("summary", "")
        return "Oops! Something went wrong. Please try again later."
    except Exception as e:
        print(f"Error in traditional summarization: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
tasks/topic_classification.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llms import LLM
|
2 |
+
from utils.remote_client import execute_remote_task
|
3 |
+
|
4 |
+
def topic_classification(text: str, model: str, candidate_labels=None, custom_instructions: str = "", use_llm: bool = True) -> str:
    """
    Dispatch topic classification to the LLM or the traditional (remote) backend.

    Args:
        text: The text to classify
        model: Model name passed through to the selected backend
        candidate_labels: List of candidate topics/categories
        custom_instructions: Optional instructions, forwarded to the LLM backend only
        use_llm: True routes to the LLM backend, False to the traditional one

    Returns:
        "" for blank or whitespace-only text; otherwise the backend's result.
    """
    if text.strip() == "":
        return ""
    if not use_llm:
        return _topic_classification_with_traditional(text, model, candidate_labels)
    return _topic_classification_with_llm(text, model, candidate_labels, custom_instructions)
|
21 |
+
|
22 |
+
def _topic_classification_with_llm(text: str, model: str, candidate_labels=None, custom_instructions: str = "") -> str:
    """Ask an LLM to pick one category for the text.

    Args:
        text: Text to classify.
        model: Model identifier passed to ``LLM``.
        candidate_labels: Labels offered to the model, joined with ", ";
            when empty or None the prompt says "any appropriate topic".
        custom_instructions: Optional line inserted before the text.

    Returns:
        The model's reply with surrounding whitespace removed. Any exception
        is printed and a generic error message is returned instead.
    """
    try:
        client = LLM(model=model)
        if candidate_labels:
            labels_str = ", ".join(candidate_labels)
        else:
            labels_str = "any appropriate topic"

        pieces = [
            f"Classify the following text into ONE of these categories: {labels_str}.\n",
            "Return ONLY the most appropriate category name.\n",
        ]
        if custom_instructions:
            pieces.append(f"{custom_instructions}\n")
        pieces.append(f"Text: {text}\nCategory:")

        reply = client.generate("".join(pieces))
        return reply.strip()
    except Exception as e:
        print(f"Error in LLM topic classification: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
37 |
+
|
38 |
+
def _topic_classification_with_traditional(text: str, model: str, labels=None) -> str:
    """Classify text through the remote "classification" task.

    Args:
        text: Sent as the "text" field of the payload.
        model: Sent as the "model" field.
        labels: When not None, sent as the "labels" field.

    Returns:
        The response's "labels" value ("" when absent). A response containing
        an "error" key, or any exception (which is printed), yields a generic
        error message.
    """
    try:
        request = {"text": text, "model": model, "task": "topic"}
        if labels is not None:
            request["labels"] = labels

        response = execute_remote_task("classification", request)
        if "error" not in response:
            return response.get("labels", "")
        return "Oops! Something went wrong. Please try again later."
    except Exception as e:
        print(f"Error in traditional topic classification: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
tasks/translation.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llms import LLM
|
2 |
+
from utils.remote_client import execute_remote_task
|
3 |
+
|
4 |
+
def text_translation(text: str, model: str, src_lang: str, tgt_lang: str, custom_instructions: str = "", use_llm: bool = True) -> str:
    """
    Dispatch translation to the LLM or the traditional (remote) backend.

    Args:
        text: The text to translate
        model: Model name passed through to the selected backend
        src_lang: Source language, passed through to the selected backend
        tgt_lang: Target language, passed through to the selected backend
        custom_instructions: Optional instructions, forwarded to the LLM backend only
        use_llm: True routes to _translation_with_llm, False to _translation_with_traditional

    Returns:
        "" for blank or whitespace-only text; otherwise the backend's result.
    """
    if text.strip() == "":
        return ""
    if not use_llm:
        return _translation_with_traditional(text, model, src_lang, tgt_lang)
    return _translation_with_llm(text, model, src_lang, tgt_lang, custom_instructions)
|
14 |
+
|
15 |
+
def _translation_with_llm(text: str, model: str, src_lang: str, tgt_lang: str, custom_instructions: str = "") -> str:
    """Ask an LLM to translate text.

    Args:
        text: Text to translate.
        model: Model identifier passed to ``LLM``.
        src_lang: Source language, inserted verbatim into the prompt.
        tgt_lang: Target language, inserted verbatim into the prompt.
        custom_instructions: Optional line inserted before the text.

    Returns:
        The model's reply with surrounding whitespace removed. Any exception
        is printed and a generic error message is returned instead.
    """
    try:
        client = LLM(model=model)
        pieces = [f"Translate the following text from {src_lang} to {tgt_lang}.\n"]
        if custom_instructions:
            pieces.append(f"{custom_instructions}\n")
        pieces.append(f"Text: {text}\nTranslation:")

        reply = client.generate("".join(pieces))
        return reply.strip()
    except Exception as e:
        print(f"Error in LLM translation: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
28 |
+
|
29 |
+
def _translation_with_traditional(text: str, model: str, src_lang: str, tgt_lang: str) -> str:
    """Translate text through the remote "translation" task.

    Args:
        text: Sent as the "text" field of the payload.
        model: Sent as the "model" field.
        src_lang: Sent as the "src_lang" field.
        tgt_lang: Sent as the "tgt_lang" field.

    Returns:
        The response's "translation" value ("" when absent). A response
        containing an "error" key, or any exception (which is printed),
        yields a generic error message.
    """
    try:
        response = execute_remote_task(
            "translation",
            {"text": text, "model": model, "src_lang": src_lang, "tgt_lang": tgt_lang},
        )
        if "error" not in response:
            return response.get("translation", "")
        return "Oops! Something went wrong. Please try again later."
    except Exception as e:
        print(f"Error in traditional translation: {str(e)}")
        return "Oops! Something went wrong. Please try again later."
|
ui/grammar_ui.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from tasks.grammar_checking import grammar_checking
|
3 |
+
|
4 |
+
GRAMMAR_MODELS = ["gemini-2.0-flash"]
DEFAULT_MODEL = "gemini-2.0-flash"


def grammar_ui():
    """Build the grammar and spelling checker panel and wire its button.

    Creates the input, model and instruction widgets plus a read-only output
    box, and connects the button to grammar_checking with use_llm=True.

    NOTE(review): component nesting was reconstructed from a flattened diff;
    confirm the column placement against the original layout.
    """
    sample_inputs = [
        ["This is a smple sentence with errrors."],
        ["I has went to the store yesterday."]
    ]

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Input Text",
                lines=6,
                placeholder="Enter text to check grammar and spelling...",
                elem_id="grammar-input-text"
            )
            gr.Examples(
                examples=sample_inputs,
                inputs=[input_text],
                label="Examples"
            )
            check_btn = gr.Button("Check Grammar & Spelling", variant="primary")
            model_dropdown = gr.Dropdown(
                GRAMMAR_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                interactive=True,
                elem_id="grammar-model-dropdown"
            )
            custom_instructions = gr.Textbox(
                label="Custom Instructions (optional)",
                lines=2,
                placeholder="Add any custom instructions for the model...",
                elem_id="grammar-custom-instructions"
            )
        with gr.Column(scale=1):
            output_box = gr.Textbox(
                label="Corrected Text",
                lines=3,
                interactive=False,
                elem_id="grammar-output"
            )

    def run_grammar_checking(text, model, custom_instructions):
        # Button handler: this panel always uses the LLM path.
        return grammar_checking(
            text=text,
            model=model,
            custom_instructions=custom_instructions,
            use_llm=True
        )

    check_btn.click(
        run_grammar_checking,
        inputs=[input_text, model_dropdown, custom_instructions],
        outputs=output_box
    )
|
ui/intent_ui.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from tasks.intent_detection import intent_detection
|
3 |
+
|
4 |
+
DEFAULT_MODEL = "gemini-2.0-flash"
INTENT_MODELS = [DEFAULT_MODEL]
DEFAULT_INTENTS = ["book_flight", "order_food", "check_weather", "greeting", "goodbye"]


def intent_ui():
    """Build the intent detection panel and wire its controls.

    The checkbox toggles visibility of the candidate-intent text area. The
    button calls intent_detection with use_llm=True, passing the non-empty
    lines of the text area as candidates when the checkbox is ticked and
    None otherwise.

    NOTE(review): component nesting was reconstructed from a flattened diff;
    confirm the column placement against the original layout.
    """
    sample_inputs = [
        ["I want to book a flight to Paris next week."],
        ["Can you tell me what the weather is like in Hanoi?"]
    ]

    with gr.Row():
        # Left column: input/config
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Input Text",
                lines=6,
                placeholder="Enter text to detect intent...",
                elem_id="intent-input-text"
            )
            gr.Examples(
                examples=sample_inputs,
                inputs=[input_text],
                label="Examples"
            )
            use_custom_intents = gr.Checkbox(
                label="Use custom intents",
                value=True,
                elem_id="intent-use-custom-intents"
            )
            intents_area = gr.TextArea(
                label="Candidate Intents (one per line)",
                value='\n'.join(DEFAULT_INTENTS),
                lines=5,
                visible=True,
                elem_id="intent-candidate-intents"
            )

            def toggle_intent_area(use_custom):
                # Show the candidate list only while custom intents are enabled.
                return gr.update(visible=use_custom)

            use_custom_intents.change(toggle_intent_area, inputs=use_custom_intents, outputs=intents_area)
            detect_btn = gr.Button("Detect Intent", variant="primary")
            model_dropdown = gr.Dropdown(
                INTENT_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                interactive=True,
                elem_id="intent-model-dropdown"
            )
            custom_instructions = gr.Textbox(
                label="Custom Instructions (optional)",
                lines=2,
                placeholder="Add any custom instructions for the model...",
                elem_id="intent-custom-instructions"
            )
        # Right column: output/result
        with gr.Column(scale=1):
            output_box = gr.Textbox(
                label="Detected Intent",
                lines=1,
                interactive=False,
                elem_id="intent-output"
            )

    # Logic for button
    def run_intent_detection(text, model, use_custom, intents, custom_instructions):
        if use_custom:
            # One candidate per non-blank line, trimmed.
            candidate_intents = [s.strip() for s in intents.split("\n") if s.strip()]
        else:
            candidate_intents = None
        return intent_detection(
            text=text,
            model=model,
            candidate_intents=candidate_intents,
            custom_instructions=custom_instructions,
            use_llm=True
        )

    detect_btn.click(
        run_intent_detection,
        inputs=[input_text, model_dropdown, use_custom_intents, intents_area, custom_instructions],
        outputs=output_box
    )
|
ui/kg_ui.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from utils.ner_helpers import is_llm_model
|
4 |
+
from typing import Dict, List, Any, Tuple
|
5 |
+
from tasks.knowledge_graph import build_knowledge_graph, visualize_knowledge_graph_interactive
|
6 |
+
import base64
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
def kg_ui():
|
10 |
+
"""Knowledge Graph UI component"""
|
11 |
+
|
12 |
+
# Define models
|
13 |
+
KG_MODELS = [
|
14 |
+
"gemini-2.0-flash",
|
15 |
+
"gpt-4",
|
16 |
+
"claude-2",
|
17 |
+
"en_core_web_sm",
|
18 |
+
"en_core_web_md"
|
19 |
+
]
|
20 |
+
DEFAULT_MODEL = "gemini-2.0-flash"
|
21 |
+
|
22 |
+
def build_kg(text, model, custom_instructions, interactive=False):
|
23 |
+
"""Process text for knowledge graph generation"""
|
24 |
+
import gradio as gr
|
25 |
+
if not text.strip():
|
26 |
+
# Trả về các giá trị rỗng cho tất cả các tab
|
27 |
+
return (
|
28 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>No text provided</div>",
|
29 |
+
pd.DataFrame(),
|
30 |
+
pd.DataFrame(),
|
31 |
+
False, True, False, True, False, True
|
32 |
+
)
|
33 |
+
use_llm = is_llm_model(model)
|
34 |
+
result = build_knowledge_graph(
|
35 |
+
text=text,
|
36 |
+
model_name=model,
|
37 |
+
use_llm=use_llm
|
38 |
+
)
|
39 |
+
entities = result.get("entities", [])
|
40 |
+
relations = result.get("relations", [])
|
41 |
+
visualization = result.get("visualization")
|
42 |
+
# DataFrames
|
43 |
+
if entities:
|
44 |
+
entities_df = pd.DataFrame(entities)
|
45 |
+
entities_df = entities_df.rename(columns={
|
46 |
+
"text": "Entity",
|
47 |
+
"label": "Type",
|
48 |
+
"start": "Start Position",
|
49 |
+
"end": "End Position"
|
50 |
+
})
|
51 |
+
else:
|
52 |
+
entities_df = pd.DataFrame()
|
53 |
+
if relations:
|
54 |
+
relations_df = pd.DataFrame(relations)
|
55 |
+
relations_df = relations_df.rename(columns={
|
56 |
+
"subject": "Subject",
|
57 |
+
"relation": "Relation",
|
58 |
+
"object": "Object"
|
59 |
+
})
|
60 |
+
else:
|
61 |
+
relations_df = pd.DataFrame()
|
62 |
+
# Visualization
|
63 |
+
if interactive and entities and relations:
|
64 |
+
try:
|
65 |
+
interactive_html = visualize_knowledge_graph_interactive(entities, relations)
|
66 |
+
visualization_html = f"<div style='width:100%;overflow-x:auto'>{interactive_html}</div>"
|
67 |
+
viz_vis = True
|
68 |
+
no_viz_vis = False
|
69 |
+
except Exception as e:
|
70 |
+
visualization_html = f"<div style='color:#d32f2f;padding:20px;'>Error rendering interactive graph: {e}</div>"
|
71 |
+
viz_vis = True
|
72 |
+
no_viz_vis = False
|
73 |
+
elif visualization:
|
74 |
+
visualization_html = f"<img src='data:image/png;base64,{visualization}' style='max-width:100%;height:auto;'/>"
|
75 |
+
viz_vis = True
|
76 |
+
no_viz_vis = False
|
77 |
+
else:
|
78 |
+
visualization_html = ""
|
79 |
+
viz_vis = False
|
80 |
+
no_viz_vis = True
|
81 |
+
# Visibility flags
|
82 |
+
entities_vis = not entities_df.empty
|
83 |
+
no_entities_vis = not entities_vis
|
84 |
+
relations_vis = not relations_df.empty
|
85 |
+
no_relations_vis = not relations_vis
|
86 |
+
# Return
|
87 |
+
return (
|
88 |
+
visualization_html,
|
89 |
+
entities_df,
|
90 |
+
relations_df,
|
91 |
+
viz_vis,
|
92 |
+
no_viz_vis,
|
93 |
+
entities_vis,
|
94 |
+
no_entities_vis,
|
95 |
+
relations_vis,
|
96 |
+
no_relations_vis
|
97 |
+
)
|
98 |
+
|
99 |
+
# UI Components
|
100 |
+
with gr.Row():
|
101 |
+
with gr.Column(scale=2):
|
102 |
+
input_text = gr.Textbox(
|
103 |
+
label="Input Text",
|
104 |
+
lines=8,
|
105 |
+
placeholder="Enter text to extract knowledge graph...",
|
106 |
+
elem_id="kg-input-text"
|
107 |
+
)
|
108 |
+
gr.Examples(
|
109 |
+
examples=[
|
110 |
+
["Elon Musk founded SpaceX and Tesla in the United States."],
|
111 |
+
["Amazon acquired Whole Foods in 2017."]
|
112 |
+
],
|
113 |
+
inputs=[input_text],
|
114 |
+
label="Examples"
|
115 |
+
)
|
116 |
+
# Model selection
|
117 |
+
model = gr.Dropdown(
|
118 |
+
KG_MODELS,
|
119 |
+
value=DEFAULT_MODEL,
|
120 |
+
label="Model",
|
121 |
+
interactive=True,
|
122 |
+
elem_id="kg-model-dropdown"
|
123 |
+
)
|
124 |
+
with gr.Accordion("Advanced Options", open=False, elem_id="kg-advanced-options"):
|
125 |
+
custom_instructions = gr.Textbox(
|
126 |
+
label="Custom Instructions",
|
127 |
+
lines=2,
|
128 |
+
placeholder="(Optional) Add specific instructions for knowledge graph generation...",
|
129 |
+
elem_id="kg-custom-instructions"
|
130 |
+
)
|
131 |
+
btn = gr.Button("Generate Knowledge Graph", elem_id="kg-btn")
|
132 |
+
|
133 |
+
with gr.Column(scale=3):
|
134 |
+
# Results container with tabs
|
135 |
+
with gr.Tabs() as output_tabs:
|
136 |
+
with gr.Tab("Graph Visualization", id="kg-viz-tab"):
|
137 |
+
no_viz_html = gr.HTML(
|
138 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
139 |
+
"Generate a knowledge graph to visualize relationships.</div>",
|
140 |
+
visible=True,
|
141 |
+
elem_id="kg-no-viz"
|
142 |
+
)
|
143 |
+
viz_html = gr.HTML(
|
144 |
+
label="Knowledge Graph Visualization",
|
145 |
+
visible=False,
|
146 |
+
elem_id="kg-viz-html"
|
147 |
+
)
|
148 |
+
|
149 |
+
with gr.Tab("Entities", id="kg-entities-tab"):
|
150 |
+
no_entities_html = gr.HTML(
|
151 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
152 |
+
"No entities found. Try generating a knowledge graph first.</div>",
|
153 |
+
visible=True,
|
154 |
+
elem_id="kg-no-entities"
|
155 |
+
)
|
156 |
+
entities_table = gr.DataFrame(
|
157 |
+
headers=["Entity", "Type", "Start Position", "End Position"],
|
158 |
+
datatype=["str", "str", "number", "number"],
|
159 |
+
visible=False,
|
160 |
+
elem_id="kg-entities-table"
|
161 |
+
)
|
162 |
+
|
163 |
+
with gr.Tab("Relationships", id="kg-relations-tab"):
|
164 |
+
no_relations_html = gr.HTML(
|
165 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
166 |
+
"No relationships found. Try generating a knowledge graph first.</div>",
|
167 |
+
visible=True,
|
168 |
+
elem_id="kg-no-relations"
|
169 |
+
)
|
170 |
+
relations_table = gr.DataFrame(
|
171 |
+
headers=["Subject", "Relation", "Object"],
|
172 |
+
datatype=["str", "str", "str"],
|
173 |
+
visible=False,
|
174 |
+
elem_id="kg-relations-table"
|
175 |
+
)
|
176 |
+
|
177 |
+
with gr.Accordion("About Knowledge Graphs", open=False):
|
178 |
+
gr.Markdown("""
|
179 |
+
## Knowledge Graphs
|
180 |
+
|
181 |
+
Knowledge graphs represent relationships between entities in text as a network. This tool:
|
182 |
+
|
183 |
+
- **Extracts entities**: Identifies people, places, organizations, and concepts
|
184 |
+
- **Maps relationships**: Shows how entities are connected to each other
|
185 |
+
- **Visualizes connections**: Creates an interactive graph you can explore
|
186 |
+
|
187 |
+
### How it works
|
188 |
+
|
189 |
+
- **LLM models** can understand complex relationships in text
|
190 |
+
- **Traditional models** use pattern matching and syntactic parsing
|
191 |
+
|
192 |
+
Knowledge graphs are particularly useful for:
|
193 |
+
- Research and analysis
|
194 |
+
- Content exploration
|
195 |
+
- Understanding complex narratives
|
196 |
+
- Discovering hidden connections
|
197 |
+
|
198 |
+
Try it with news articles, scientific papers, or story excerpts to see different types of relationships.
|
199 |
+
""")
|
200 |
+
|
201 |
+
# Toggle for interactive/static visualization
|
202 |
+
with gr.Row():
|
203 |
+
interactive_toggle = gr.Checkbox(
|
204 |
+
label="Interactive Graph (pyvis)",
|
205 |
+
value=True,
|
206 |
+
elem_id="kg-interactive-toggle"
|
207 |
+
)
|
208 |
+
# Event handler: use build_kg for all outputs
|
209 |
+
def process_and_update_ui(text, model, custom_instructions, interactive):
    """Forward the four UI inputs to ``build_kg`` and hand back its result unchanged.

    NOTE(review): the button below is wired through a lambda that calls
    ``build_kg`` directly, so this wrapper looks unused here - confirm before
    relying on it.
    """
    kg_outputs = build_kg(text, model, custom_instructions, interactive)
    return kg_outputs
|
211 |
+
|
212 |
+
# Wire button to unified handler
|
213 |
+
def gradio_output_adapter(visualization_html, entities_df, relations_df, viz_vis, no_viz_vis, entities_vis, no_entities_vis, relations_vis, no_relations_vis):
    """Turn nine positional values into six ``gr.update`` objects.

    The first three updates carry a value plus a visibility flag
    (visualization, entities table, relations table); the last three only
    toggle the visibility of the matching "no results" placeholders. The order
    mirrors the ``outputs=[...]`` list of the click handler.
    """
    # (payload, visible) pairs for the components that actually display data.
    data_components = (
        (visualization_html, viz_vis),
        (entities_df, entities_vis),
        (relations_df, relations_vis),
    )
    # Visibility flags for the placeholder messages, in the same order.
    placeholder_flags = (no_viz_vis, no_entities_vis, no_relations_vis)

    updates = [gr.update(value=payload, visible=shown) for payload, shown in data_components]
    updates.extend(gr.update(visible=shown) for shown in placeholder_flags)
    return updates
|
222 |
+
|
223 |
+
btn.click(
|
224 |
+
fn=lambda text, model, custom_instructions, interactive: gradio_output_adapter(*build_kg(text, model, custom_instructions, interactive)),
|
225 |
+
inputs=[input_text, model, custom_instructions, interactive_toggle],
|
226 |
+
outputs=[
|
227 |
+
viz_html, entities_table, relations_table,
|
228 |
+
no_viz_html, no_entities_html, no_relations_html,
|
229 |
+
]
|
230 |
+
)
|
231 |
+
return None
|
ui/ner_ui.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from typing import Dict, List, Any
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
import html as html_lib
|
7 |
+
from tasks.ner import named_entity_recognition
|
8 |
+
from utils.ner_helpers import NER_ENTITY_TYPES, DEFAULT_SELECTED_ENTITIES, is_llm_model
|
9 |
+
|
10 |
+
# The ner_ui function and related logic moved from app.py
|
11 |
+
|
12 |
+
def ner_ui():
|
13 |
+
# Default entity types for the multi-select
|
14 |
+
DEFAULT_ENTITY_TYPES = list(NER_ENTITY_TYPES.keys())
|
15 |
+
|
16 |
+
def ner(text: str, model: str, entity_types: List[str]) -> Dict[str, Any]:
    """Extract named entities and normalise them into the record shape the UI uses.

    Models for which ``is_llm_model`` is true receive ``entity_types`` directly;
    for every other model the requested types are applied as a post-filter.

    Args:
        text: Raw input text.
        model: Model identifier forwarded to ``named_entity_recognition``.
        entity_types: Entity labels to keep; a falsy value disables the
            post-filter for non-LLM models.

    Returns:
        ``{"entities": [...]}`` where each record has the keys ``entity``,
        ``word``, ``start``, ``end``, ``score`` and ``description``. Blank
        input additionally carries ``"text": ""``. Any exception raised while
        extracting is printed and results in an empty entity list.
    """
    if not text.strip():
        return {"text": "", "entities": []}

    try:
        use_llm = is_llm_model(model)
        entities = named_entity_recognition(
            text=text,
            model=model,
            use_llm=use_llm,
            entity_types=entity_types if use_llm else None,
        )

        # Anything other than a list is treated as "no entities".
        if not isinstance(entities, list):
            entities = []

        if not use_llm and entity_types:
            entities = [
                e for e in entities
                if e.get("type", "") in entity_types or e.get("entity", "") in entity_types
            ]

        # The filter above accepts records keyed by "entity", so the conversion
        # must fall back to the same alternative keys; otherwise such records
        # would reach the UI with an empty label and an empty word.
        return {
            "entities": [
                {
                    "entity": e.get("type") or e.get("entity", ""),
                    "word": e.get("text") or e.get("word", ""),
                    "start": e.get("start", 0),
                    "end": e.get("end", 0),
                    "score": e.get("confidence", e.get("score", 1.0)),
                    "description": e.get("description", ""),
                }
                for e in entities
            ]
        }

    except Exception as e:
        # UI boundary: report the failure and degrade to "no entities".
        print(f"Error in NER: {str(e)}")
        return {"entities": []}
|
55 |
+
|
56 |
+
def render_ner_html(text, entities):
    """Render ``text`` as HTML with every locatable entity highlighted inline.

    Args:
        text: The original input text.
        entities: Iterable of dicts. The label is read from ``type`` or
            ``entity`` and the surface form from ``text`` or ``word``;
            ``start``/``end`` character offsets are optional.

    Returns:
        A placeholder message when there is no text or no entities; otherwise
        a ``<div class='ner-highlight'>`` holding the HTML-escaped text with
        one coloured ``<span>`` (surface form + label) per entity.

    Entities without a label or surface form, and entities whose surface form
    cannot be found in ``text``, are skipped. When two entities overlap, the
    one that starts first wins.
    """
    if not text.strip() or not entities:
        return "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"

    COLORS = [
        '#e3f2fd', '#e8f5e9', '#fff8e1', '#f3e5f5', '#e8eaf6', '#e0f7fa',
        '#f1f8e9', '#fce4ec', '#e8f5e9', '#f5f5f5', '#fafafa', '#e1f5fe',
        '#fff3e0', '#d7ccc8', '#f9fbe7', '#fbe9e7', '#ede7f6', '#e0f2f1'
    ]

    clean_entities = []
    label_colors = {}
    # Per surface form: where the next search should resume, so that repeated
    # mentions without usable offsets map to successive occurrences instead of
    # all collapsing onto the first one.
    next_search_pos = {}

    for ent in entities:
        label = ent.get('type') or ent.get('entity')
        if not label:
            continue  # Skip entities without label

        entity_text = ent.get('text') or ent.get('word')
        if not entity_text:
            continue  # Skip entities without text

        start = ent.get('start', -1)
        end = ent.get('end', -1)

        located = False
        if isinstance(start, int) and isinstance(end, int) and 0 <= start < end <= len(text):
            # Trust the reported span only if it really contains the surface
            # form, and tighten it to the exact match so that no surrounding
            # characters are swallowed when rendering.
            offset = text[start:end].find(entity_text)
            if offset >= 0:
                start += offset
                end = start + len(entity_text)
                located = True

        if not located:
            # Missing or wrong offsets: look the surface form up in the text.
            start = text.find(entity_text, next_search_pos.get(entity_text, 0))
            if start < 0:
                start = text.find(entity_text)
            if start < 0:
                # Not present at all; rendering it would corrupt the output.
                continue
            end = start + len(entity_text)

        next_search_pos[entity_text] = end

        # One colour per label, assigned in order of first appearance.
        if label not in label_colors:
            label_colors[label] = COLORS[len(label_colors) % len(COLORS)]

        clean_entities.append({
            'text': entity_text,
            'label': label,
            'color': label_colors[label],
            'start': start,
            'end': end
        })

    # Rendering walks the text left to right, so entities must be ordered.
    clean_entities.sort(key=lambda item: item['start'])

    # Drop any entity that begins inside the previously accepted one.
    non_overlapping = []
    for candidate in clean_entities:
        if non_overlapping and candidate['start'] < non_overlapping[-1]['end']:
            continue
        non_overlapping.append(candidate)

    parts = ["<div class='ner-highlight' style='line-height:1.6;padding:15px;border:1px solid #e0e0e0;border-radius:4px;background:#f9f9f9;white-space:pre-wrap;'>"]

    last_pos = 0
    for entity in non_overlapping:
        start = entity['start']
        end = entity['end']

        # Plain text between the previous entity and this one.
        if start > last_pos:
            parts.append(html_lib.escape(text[last_pos:start]))

        # The entity itself followed by its label badge.
        parts.append(f"<span style='background:{entity['color']};border-radius:3px;padding:2px 4px;margin:0 1px;border:1px solid rgba(0,0,0,0.1);'>")
        parts.append(f"{html_lib.escape(entity['text'])} ")
        parts.append(f"<span style='font-size:0.8em;font-weight:bold;color:#555;border-radius:2px;padding:0 2px;background:rgba(255,255,255,0.7);'>{html_lib.escape(entity['label'])}</span>")
        parts.append("</span>")

        last_pos = end

    # Whatever follows the last entity.
    if last_pos < len(text):
        parts.append(html_lib.escape(text[last_pos:]))

    parts.append("</div>")
    return "".join(parts)
|
165 |
+
|
166 |
+
def update_ui(model_id: str) -> Dict:
    """Build the visibility update for the entity-type selector.

    Returns a one-entry mapping from ``entity_types_group`` to a ``gr.Group``
    whose ``visible`` flag is the result of ``is_llm_model(model_id)``.
    """
    show_entity_picker = is_llm_model(model_id)
    group_update = gr.Group(visible=show_entity_picker)
    return {entity_types_group: group_update}
|
172 |
+
|
173 |
+
with gr.Row():
|
174 |
+
with gr.Column(scale=2):
|
175 |
+
input_text = gr.Textbox(
|
176 |
+
label="Input Text",
|
177 |
+
lines=8,
|
178 |
+
placeholder="Enter text to analyze for named entities..."
|
179 |
+
)
|
180 |
+
|
181 |
+
gr.Examples(
|
182 |
+
examples=[
|
183 |
+
["Barack Obama was born in Hawaii and became the 44th President of the United States."],
|
184 |
+
["Google is headquartered in Mountain View, California."]
|
185 |
+
],
|
186 |
+
inputs=[input_text],
|
187 |
+
label="Examples"
|
188 |
+
)
|
189 |
+
model_dropdown = gr.Dropdown(
|
190 |
+
["gemini-2.0-flash"], # Only allow gemini-2.0-flash for now
|
191 |
+
value="gemini-2.0-flash",
|
192 |
+
label="Model"
|
193 |
+
)
|
194 |
+
|
195 |
+
with gr.Group() as entity_types_group:
|
196 |
+
entity_types = gr.CheckboxGroup(
|
197 |
+
label="Entity Types to Extract",
|
198 |
+
choices=DEFAULT_ENTITY_TYPES,
|
199 |
+
value=DEFAULT_SELECTED_ENTITIES,
|
200 |
+
interactive=True
|
201 |
+
)
|
202 |
+
with gr.Row():
|
203 |
+
select_all_btn = gr.Button("Select All", size="sm")
|
204 |
+
clear_all_btn = gr.Button("Clear All", size="sm")
|
205 |
+
|
206 |
+
btn = gr.Button("Extract Entities", variant="primary")
|
207 |
+
|
208 |
+
# Button handlers for entity selection
|
209 |
+
def select_all_entities():
    """Return a CheckboxGroup update with every entry of ``DEFAULT_ENTITY_TYPES`` ticked."""
    every_type = DEFAULT_ENTITY_TYPES
    return gr.CheckboxGroup(value=every_type)
|
211 |
+
|
212 |
+
def clear_all_entities():
    """Return a CheckboxGroup update with nothing ticked."""
    nothing_selected = []
    return gr.CheckboxGroup(value=nothing_selected)
|
214 |
+
|
215 |
+
select_all_btn.click(
|
216 |
+
fn=select_all_entities,
|
217 |
+
outputs=[entity_types]
|
218 |
+
)
|
219 |
+
|
220 |
+
clear_all_btn.click(
|
221 |
+
fn=clear_all_entities,
|
222 |
+
outputs=[entity_types]
|
223 |
+
)
|
224 |
+
|
225 |
+
with gr.Column(scale=3):
|
226 |
+
# Output with tabs
|
227 |
+
with gr.Tabs() as output_tabs:
|
228 |
+
with gr.Tab("Tagged View", id="tagged-view-ner"):
|
229 |
+
no_results_html = gr.HTML(
|
230 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
231 |
+
"Enter text and click 'Extract Entities' to get results.</div>",
|
232 |
+
visible=True
|
233 |
+
)
|
234 |
+
output_html = gr.HTML(
|
235 |
+
label="NER Highlighted",
|
236 |
+
elem_id="ner-output-html",
|
237 |
+
visible=False
|
238 |
+
)
|
239 |
+
# Add CSS for NER tags (scoped to this component)
|
240 |
+
# Every rule block below must be closed: an unterminated `.pos-tag` block
# would swallow all of the per-label colour rules that follow it.
# NOTE(review): render_ner_html emits inline styles and only the
# `ner-highlight` class, so these selectors may currently be unused - confirm.
gr.HTML("""
<style>
#ner-output-html .pos-highlight {
    white-space: pre-wrap;
    line-height: 1.8;
    font-size: 14px;
    padding: 15px;
    border: 1px solid #e0e0e0;
    border-radius: 4px;
    background: #f9f9f9;
}
#ner-output-html .pos-token {
    display: inline-block;
    margin: 0 2px 4px 0;
    vertical-align: top;
    text-align: center;
}
#ner-output-html .token-text {
    display: block;
    padding: 2px 8px;
    background: #f0f4f8;
    border-radius: 4px 4px 0 0;
    border: 1px solid #dbe4ed;
    border-bottom: none;
    font-size: 0.9em;
}
#ner-output-html .pos-tag {
    display: block;
    padding: 2px 8px;
    border-radius: 0 0 4px 4px;
    font-size: 0.8em;
    font-family: 'Courier New', monospace;
    border: 1px solid;
    border-top: none;
}
/* Example color coding for common NER labels (customize as needed) */
#ner-output-html .PERSON { background-color: #e3f2fd; border-color: #bbdefb; color: #0d47a1; }
#ner-output-html .ORG { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; }
#ner-output-html .GPE { background-color: #fff8e1; border-color: #ffecb3; color: #ff6f00; }
#ner-output-html .LOC { background-color: #f3e5f5; border-color: #e1bee7; color: #4a148c; }
#ner-output-html .PRODUCT { background-color: #e8eaf6; border-color: #c5cae9; color: #1a237e; }
#ner-output-html .EVENT { background-color: #e0f7fa; border-color: #b2ebf2; color: #006064; }
#ner-output-html .WORK_OF_ART { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; }
#ner-output-html .LAW { background-color: #fce4ec; border-color: #f8bbd0; color: #880e4f; }
#ner-output-html .LANGUAGE { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; font-weight: bold; }
#ner-output-html .DATE { background-color: #f5f5f5; border-color: #e0e0e0; color: #424242; }
#ner-output-html .TIME { background-color: #fafafa; border-color: #f5f5f5; color: #616161; }
#ner-output-html .PERCENT { background-color: #e1f5fe; border-color: #b3e5fc; color: #01579b; font-weight: bold; }
#ner-output-html .MONEY { background-color: #f3e5f5; border-color: #e1bee7; color: #6a1b9a; }
#ner-output-html .QUANTITY { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; font-style: italic; }
#ner-output-html .ORDINAL { background-color: #fff3e0; border-color: #ffe0b2; color: #e65100; }
#ner-output-html .CARDINAL { background-color: #ede7f6; border-color: #d1c4e9; color: #4527a0; }
</style>
""")
|
282 |
+
with gr.Tab("Table View", id="table-view-ner"):
|
283 |
+
no_results_table = gr.HTML(
|
284 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
285 |
+
"Enter text and click 'Extract Entities' to get results.</div>",
|
286 |
+
visible=True
|
287 |
+
)
|
288 |
+
output_table = gr.Dataframe(
|
289 |
+
label="Extracted Entities",
|
290 |
+
headers=["Type", "Text", "Confidence", "Description"],
|
291 |
+
datatype=["str", "str", "number", "str"],
|
292 |
+
interactive=False,
|
293 |
+
wrap=True,
|
294 |
+
visible=False
|
295 |
+
)
|
296 |
+
|
297 |
+
# Update the UI when the model changes
|
298 |
+
model_dropdown.change(
|
299 |
+
fn=update_ui,
|
300 |
+
inputs=[model_dropdown],
|
301 |
+
outputs=[entity_types_group]
|
302 |
+
)
|
303 |
+
|
304 |
+
def process_and_show_results(text: str, model: str, entity_types: List[str]):
    """Run NER on ``text`` and build the updates for both output tabs.

    Args:
        text: Input text from the textbox.
        model: Selected model identifier.
        entity_types: Ticked entity labels; when empty, every key of
            ``NER_ENTITY_TYPES`` is used.

    Returns:
        A list of four component updates, in the order ``output_html``,
        ``no_results_html``, ``output_table``, ``no_results_table``.
    """
    if not text.strip():
        msg = "<div style='text-align: center; color: #f44336; padding: 20px;'>Please enter some text to analyze.</div>"
        return [
            gr.HTML(visible=False),  # output_html
            gr.HTML(msg, visible=True),  # no_results_html
            gr.DataFrame(visible=False),  # output_table
            gr.HTML(msg, visible=True)  # no_results_table
        ]

    if not entity_types:
        entity_types = list(NER_ENTITY_TYPES.keys())

    result = ner(text, model, entity_types)
    entities = result["entities"] if result and "entities" in result else []

    if entities:
        df = pd.DataFrame(entities)

        # Order the rows by position in the text. This has to happen before
        # the projection below, which drops the `start` column. Offsets are
        # coerced to numbers so that an odd value cannot break the sort.
        if 'start' in df.columns:
            df = (
                df.assign(_sort_pos=pd.to_numeric(df['start'], errors='coerce'))
                .sort_values('_sort_pos', kind='mergesort')
                .drop(columns='_sort_pos')
            )

        df = df.rename(columns={
            "entity": "Type",
            "word": "Text",
            "score": "Confidence",
            "description": "Description"
        })
        display_columns = ["Type", "Text", "Confidence", "Description"]
        df = df[[col for col in display_columns if col in df.columns]].reset_index(drop=True)

        highlighted = render_ner_html(text, entities)
        return [
            gr.HTML(highlighted, visible=True),  # output_html
            gr.HTML(visible=False),  # no_results_html
            gr.DataFrame(value=df, visible=True),  # output_table
            gr.HTML(visible=False)  # no_results_table
        ]

    # No entities found
    msg = "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"
    return [
        gr.HTML(msg, visible=True),  # output_html
        gr.HTML(visible=False),  # no_results_html
        gr.DataFrame(visible=False),  # output_table
        gr.HTML(msg, visible=True)  # no_results_table
    ]
|
347 |
+
|
348 |
+
# Set up the button click handler
|
349 |
+
btn.click(
|
350 |
+
fn=process_and_show_results,
|
351 |
+
inputs=[input_text, model_dropdown, entity_types],
|
352 |
+
outputs=[output_html, no_results_html, output_table, no_results_table]
|
353 |
+
)
|
354 |
+
|
355 |
+
# Initial UI update
|
356 |
+
update_ui(model_dropdown.value)
|
357 |
+
|
358 |
+
return None
|
ui/ner_ui.py.new
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from typing import Dict, List, Any
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
import html as html_lib
|
7 |
+
from tasks.ner import named_entity_recognition
|
8 |
+
from utils.ner_helpers import NER_ENTITY_TYPES, DEFAULT_SELECTED_ENTITIES, is_llm_model
|
9 |
+
|
10 |
+
# The ner_ui function and related logic moved from app.py
|
11 |
+
|
12 |
+
def ner_ui():
|
13 |
+
# Default entity types for the multi-select
|
14 |
+
DEFAULT_ENTITY_TYPES = list(NER_ENTITY_TYPES.keys())
|
15 |
+
|
16 |
+
def ner(text: str, model: str, entity_types: List[str]) -> Dict[str, Any]:
    """Extract named entities and normalise them into the record shape the UI uses.

    Models for which ``is_llm_model`` is true receive ``entity_types`` directly;
    for every other model the requested types are applied as a post-filter.

    Args:
        text: Raw input text.
        model: Model identifier forwarded to ``named_entity_recognition``.
        entity_types: Entity labels to keep; a falsy value disables the
            post-filter for non-LLM models.

    Returns:
        ``{"entities": [...]}`` where each record has the keys ``entity``,
        ``word``, ``start``, ``end``, ``score`` and ``description``. Blank
        input additionally carries ``"text": ""``. Any exception raised while
        extracting is printed and results in an empty entity list.
    """
    if not text.strip():
        return {"text": "", "entities": []}

    try:
        use_llm = is_llm_model(model)
        entities = named_entity_recognition(
            text=text,
            model=model,
            use_llm=use_llm,
            entity_types=entity_types if use_llm else None,
        )

        # Anything other than a list is treated as "no entities".
        if not isinstance(entities, list):
            entities = []

        if not use_llm and entity_types:
            entities = [
                e for e in entities
                if e.get("type", "") in entity_types or e.get("entity", "") in entity_types
            ]

        # The filter above accepts records keyed by "entity", so the conversion
        # must fall back to the same alternative keys; otherwise such records
        # would reach the UI with an empty label and an empty word.
        return {
            "entities": [
                {
                    "entity": e.get("type") or e.get("entity", ""),
                    "word": e.get("text") or e.get("word", ""),
                    "start": e.get("start", 0),
                    "end": e.get("end", 0),
                    "score": e.get("confidence", e.get("score", 1.0)),
                    "description": e.get("description", ""),
                }
                for e in entities
            ]
        }

    except Exception as e:
        # UI boundary: report the failure and degrade to "no entities".
        print(f"Error in NER: {str(e)}")
        return {"entities": []}
|
55 |
+
|
56 |
+
def render_ner_html(text, entities):
    """Render ``text`` as HTML with every locatable entity highlighted inline.

    Args:
        text: The original input text.
        entities: Iterable of dicts. The label is read from ``type`` or
            ``entity`` and the surface form from ``text`` or ``word``;
            ``start``/``end`` character offsets are optional.

    Returns:
        A placeholder message when there is no text or no entities; otherwise
        a ``<div class='ner-highlight'>`` holding the HTML-escaped text with
        one coloured ``<span>`` (surface form + label) per entity.

    Entities without a label or surface form, and entities whose surface form
    cannot be found in ``text``, are skipped. When two entities overlap, the
    one that starts first wins.
    """
    if not text.strip() or not entities:
        return "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"

    COLORS = [
        '#e3f2fd', '#e8f5e9', '#fff8e1', '#f3e5f5', '#e8eaf6', '#e0f7fa',
        '#f1f8e9', '#fce4ec', '#e8f5e9', '#f5f5f5', '#fafafa', '#e1f5fe',
        '#fff3e0', '#d7ccc8', '#f9fbe7', '#fbe9e7', '#ede7f6', '#e0f2f1'
    ]

    clean_entities = []
    label_colors = {}
    # Per surface form: where the next search should resume, so that repeated
    # mentions without usable offsets map to successive occurrences instead of
    # all collapsing onto the first one.
    next_search_pos = {}

    for ent in entities:
        label = ent.get('type') or ent.get('entity')
        if not label:
            continue  # Skip entities without label

        entity_text = ent.get('text') or ent.get('word')
        if not entity_text:
            continue  # Skip entities without text

        start = ent.get('start', -1)
        end = ent.get('end', -1)

        located = False
        if isinstance(start, int) and isinstance(end, int) and 0 <= start < end <= len(text):
            # Trust the reported span only if it really contains the surface
            # form, and tighten it to the exact match so that no surrounding
            # characters are swallowed when rendering.
            offset = text[start:end].find(entity_text)
            if offset >= 0:
                start += offset
                end = start + len(entity_text)
                located = True

        if not located:
            # Missing or wrong offsets: look the surface form up in the text.
            start = text.find(entity_text, next_search_pos.get(entity_text, 0))
            if start < 0:
                start = text.find(entity_text)
            if start < 0:
                # Not present at all; rendering it would corrupt the output.
                continue
            end = start + len(entity_text)

        next_search_pos[entity_text] = end

        # One colour per label, assigned in order of first appearance.
        if label not in label_colors:
            label_colors[label] = COLORS[len(label_colors) % len(COLORS)]

        clean_entities.append({
            'text': entity_text,
            'label': label,
            'color': label_colors[label],
            'start': start,
            'end': end
        })

    # Rendering walks the text left to right, so entities must be ordered.
    clean_entities.sort(key=lambda item: item['start'])

    # Drop any entity that begins inside the previously accepted one.
    non_overlapping = []
    for candidate in clean_entities:
        if non_overlapping and candidate['start'] < non_overlapping[-1]['end']:
            continue
        non_overlapping.append(candidate)

    parts = ["<div class='ner-highlight' style='line-height:1.6;padding:15px;border:1px solid #e0e0e0;border-radius:4px;background:#f9f9f9;white-space:pre-wrap;'>"]

    last_pos = 0
    for entity in non_overlapping:
        start = entity['start']
        end = entity['end']

        # Plain text between the previous entity and this one.
        if start > last_pos:
            parts.append(html_lib.escape(text[last_pos:start]))

        # The entity itself followed by its label badge.
        parts.append(f"<span style='background:{entity['color']};border-radius:3px;padding:2px 4px;margin:0 1px;border:1px solid rgba(0,0,0,0.1);'>")
        parts.append(f"{html_lib.escape(entity['text'])} ")
        parts.append(f"<span style='font-size:0.8em;font-weight:bold;color:#555;border-radius:2px;padding:0 2px;background:rgba(255,255,255,0.7);'>{html_lib.escape(entity['label'])}</span>")
        parts.append("</span>")

        last_pos = end

    # Whatever follows the last entity.
    if last_pos < len(text):
        parts.append(html_lib.escape(text[last_pos:]))

    parts.append("</div>")
    return "".join(parts)
|
165 |
+
|
166 |
+
def update_ui(model_id: str) -> Dict:
    """Build the visibility update for the entity-type selector.

    Returns a one-entry mapping from ``entity_types_group`` to a ``gr.Group``
    whose ``visible`` flag is the result of ``is_llm_model(model_id)``.
    """
    show_entity_picker = is_llm_model(model_id)
    group_update = gr.Group(visible=show_entity_picker)
    return {entity_types_group: group_update}
|
172 |
+
|
173 |
+
with gr.Row():
|
174 |
+
with gr.Column(scale=2):
|
175 |
+
input_text = gr.Textbox(
|
176 |
+
label="Input Text",
|
177 |
+
lines=8,
|
178 |
+
placeholder="Enter text to analyze for named entities..."
|
179 |
+
)
|
180 |
+
|
181 |
+
model_dropdown = gr.Dropdown(
|
182 |
+
["gemini-2.0-flash", "gpt-4", "claude-2", "en_core_web_sm", "en_core_web_md", "en_core_web_lg"],
|
183 |
+
value="gemini-2.0-flash",
|
184 |
+
label="Model"
|
185 |
+
)
|
186 |
+
|
187 |
+
with gr.Group() as entity_types_group:
|
188 |
+
entity_types = gr.CheckboxGroup(
|
189 |
+
label="Entity Types to Extract",
|
190 |
+
choices=DEFAULT_ENTITY_TYPES,
|
191 |
+
value=DEFAULT_SELECTED_ENTITIES,
|
192 |
+
interactive=True
|
193 |
+
)
|
194 |
+
with gr.Row():
|
195 |
+
select_all_btn = gr.Button("Select All", size="sm")
|
196 |
+
clear_all_btn = gr.Button("Clear All", size="sm")
|
197 |
+
|
198 |
+
btn = gr.Button("Extract Entities", variant="primary")
|
199 |
+
|
200 |
+
# Button handlers for entity selection
|
201 |
+
def select_all_entities():
    """Return a CheckboxGroup update with every entry of ``DEFAULT_ENTITY_TYPES`` ticked."""
    every_type = DEFAULT_ENTITY_TYPES
    return gr.CheckboxGroup(value=every_type)
|
203 |
+
|
204 |
+
def clear_all_entities():
    """Return a CheckboxGroup update with nothing ticked."""
    nothing_selected = []
    return gr.CheckboxGroup(value=nothing_selected)
|
206 |
+
|
207 |
+
select_all_btn.click(
|
208 |
+
fn=select_all_entities,
|
209 |
+
outputs=[entity_types]
|
210 |
+
)
|
211 |
+
|
212 |
+
clear_all_btn.click(
|
213 |
+
fn=clear_all_entities,
|
214 |
+
outputs=[entity_types]
|
215 |
+
)
|
216 |
+
|
217 |
+
with gr.Column(scale=3):
|
218 |
+
# Output with tabs
|
219 |
+
with gr.Tabs() as output_tabs:
|
220 |
+
with gr.Tab("Tagged View", id="tagged-view-ner"):
|
221 |
+
no_results_html = gr.HTML(
|
222 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
223 |
+
"Enter text and click 'Extract Entities' to get results.</div>",
|
224 |
+
visible=True
|
225 |
+
)
|
226 |
+
output_html = gr.HTML(
|
227 |
+
label="NER Highlighted",
|
228 |
+
elem_id="ner-output-html",
|
229 |
+
visible=False
|
230 |
+
)
|
231 |
+
# Add CSS for NER tags (scoped to this component)
|
232 |
+
gr.HTML("""
|
233 |
+
<style>
|
234 |
+
#ner-output-html .pos-highlight {
|
235 |
+
white-space: pre-wrap;
|
236 |
+
line-height: 1.8;
|
237 |
+
font-size: 14px;
|
238 |
+
padding: 15px;
|
239 |
+
border: 1px solid #e0e0e0;
|
240 |
+
border-radius: 4px;
|
241 |
+
background: #f9f9f9;
|
242 |
+
}
|
243 |
+
#ner-output-html .pos-token {
|
244 |
+
display: inline-block;
|
245 |
+
margin: 0 2px 4px 0;
|
246 |
+
vertical-align: top;
|
247 |
+
text-align: center;
|
248 |
+
}
|
249 |
+
#ner-output-html .token-text {
|
250 |
+
display: block;
|
251 |
+
padding: 2px 8px;
|
252 |
+
background: #f0f4f8;
|
253 |
+
border-radius: 4px 4px 0 0;
|
254 |
+
border: 1px solid #dbe4ed;
|
255 |
+
border-bottom: none;
|
256 |
+
font-size: 0.9em;
|
257 |
+
}
|
258 |
+
#ner-output-html .pos-tag {
|
259 |
+
display: block;
|
260 |
+
padding: 2px 8px;
|
261 |
+
border-radius: 0 0 4px 4px;
|
262 |
+
font-size: 0.8em;
|
263 |
+
font-family: 'Courier New', monospace;
|
264 |
+
border: 1px solid;
|
265 |
+
border-top: none;
|
266 |
+
}
|
267 |
+
/* Example color coding for common NER labels (customize as needed) */
|
268 |
+
#ner-output-html .PERSON { background-color: #e3f2fd; border-color: #bbdefb; color: #0d47a1; }
|
269 |
+
#ner-output-html .ORG { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; }
|
270 |
+
#ner-output-html .GPE { background-color: #fff8e1; border-color: #ffecb3; color: #ff6f00; }
|
271 |
+
#ner-output-html .LOC { background-color: #f3e5f5; border-color: #e1bee7; color: #4a148c; }
|
272 |
+
#ner-output-html .PRODUCT { background-color: #e8eaf6; border-color: #c5cae9; color: #1a237e; }
|
273 |
+
#ner-output-html .EVENT { background-color: #e0f7fa; border-color: #b2ebf2; color: #006064; }
|
274 |
+
#ner-output-html .WORK_OF_ART { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; }
|
275 |
+
#ner-output-html .LAW { background-color: #fce4ec; border-color: #f8bbd0; color: #880e4f; }
|
276 |
+
#ner-output-html .LANGUAGE { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; font-weight: bold; }
|
277 |
+
#ner-output-html .DATE { background-color: #f5f5f5; border-color: #e0e0e0; color: #424242; }
|
278 |
+
#ner-output-html .TIME { background-color: #fafafa; border-color: #f5f5f5; color: #616161; }
|
279 |
+
#ner-output-html .PERCENT { background-color: #e1f5fe; border-color: #b3e5fc; color: #01579b; font-weight: bold; }
|
280 |
+
#ner-output-html .MONEY { background-color: #f3e5f5; border-color: #e1bee7; color: #6a1b9a; }
|
281 |
+
#ner-output-html .QUANTITY { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; font-style: italic; }
|
282 |
+
#ner-output-html .ORDINAL { background-color: #fff3e0; border-color: #ffe0b2; color: #e65100; }
|
283 |
+
#ner-output-html .CARDINAL { background-color: #ede7f6; border-color: #d1c4e9; color: #4527a0; }
|
284 |
+
</style>
|
285 |
+
""")
|
286 |
+
with gr.Tab("Table View", id="table-view-ner"):
|
287 |
+
no_results_table = gr.HTML(
|
288 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
289 |
+
"Enter text and click 'Extract Entities' to get results.</div>",
|
290 |
+
visible=True
|
291 |
+
)
|
292 |
+
output_table = gr.Dataframe(
|
293 |
+
label="Extracted Entities",
|
294 |
+
headers=["Type", "Text", "Confidence", "Description"],
|
295 |
+
datatype=["str", "str", "number", "str"],
|
296 |
+
interactive=False,
|
297 |
+
wrap=True,
|
298 |
+
visible=False
|
299 |
+
)
|
300 |
+
|
301 |
+
# Update the UI when the model changes
|
302 |
+
model_dropdown.change(
|
303 |
+
fn=update_ui,
|
304 |
+
inputs=[model_dropdown],
|
305 |
+
outputs=[entity_types_group]
|
306 |
+
)
|
307 |
+
|
308 |
+
def process_and_show_results(text: str, model: str, entity_types: List[str]):
    """Run NER on ``text`` and build the updates for both output tabs.

    Args:
        text: Input text from the textbox.
        model: Selected model identifier.
        entity_types: Ticked entity labels; when empty, every key of
            ``NER_ENTITY_TYPES`` is used.

    Returns:
        A list of four component updates, in the order ``output_html``,
        ``no_results_html``, ``output_table``, ``no_results_table``.
    """
    if not text.strip():
        msg = "<div style='text-align: center; color: #f44336; padding: 20px;'>Please enter some text to analyze.</div>"
        return [
            gr.HTML(visible=False),  # output_html
            gr.HTML(msg, visible=True),  # no_results_html
            gr.DataFrame(visible=False),  # output_table
            gr.HTML(msg, visible=True)  # no_results_table
        ]

    if not entity_types:
        entity_types = list(NER_ENTITY_TYPES.keys())

    result = ner(text, model, entity_types)
    entities = result["entities"] if result and "entities" in result else []

    if entities:
        df = pd.DataFrame(entities)

        # Order the rows by position in the text. This has to happen before
        # the projection below, which drops the `start` column. Offsets are
        # coerced to numbers so that an odd value cannot break the sort.
        if 'start' in df.columns:
            df = (
                df.assign(_sort_pos=pd.to_numeric(df['start'], errors='coerce'))
                .sort_values('_sort_pos', kind='mergesort')
                .drop(columns='_sort_pos')
            )

        df = df.rename(columns={
            "entity": "Type",
            "word": "Text",
            "score": "Confidence",
            "description": "Description"
        })
        display_columns = ["Type", "Text", "Confidence", "Description"]
        df = df[[col for col in display_columns if col in df.columns]].reset_index(drop=True)

        highlighted = render_ner_html(text, entities)
        return [
            gr.HTML(highlighted, visible=True),  # output_html
            gr.HTML(visible=False),  # no_results_html
            gr.DataFrame(value=df, visible=True),  # output_table
            gr.HTML(visible=False)  # no_results_table
        ]

    # No entities found
    msg = "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"
    return [
        gr.HTML(msg, visible=True),  # output_html
        gr.HTML(visible=False),  # no_results_html
        gr.DataFrame(visible=False),  # output_table
        gr.HTML(msg, visible=True)  # no_results_table
    ]
|
351 |
+
|
352 |
+
# Set up the button click handler
|
353 |
+
btn.click(
|
354 |
+
fn=process_and_show_results,
|
355 |
+
inputs=[input_text, model_dropdown, entity_types],
|
356 |
+
outputs=[output_html, no_results_html, output_table, no_results_table]
|
357 |
+
)
|
358 |
+
|
359 |
+
# Initial UI update
|
360 |
+
update_ui(model_dropdown.value)
|
361 |
+
|
362 |
+
return None
|
ui/pos_ui.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.ner_helpers import is_llm_model
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
from typing import Dict, List
|
6 |
+
from tasks.pos_tagging import pos_tagging
|
7 |
+
from utils.pos_helpers import *
|
8 |
+
|
9 |
+
# POS UI
|
10 |
+
|
11 |
+
def pos_ui():
|
12 |
+
|
13 |
+
# UI Components
|
14 |
+
with gr.Row():
|
15 |
+
with gr.Column(scale=2):
|
16 |
+
input_text = gr.Textbox(
|
17 |
+
label="Input Text",
|
18 |
+
lines=8,
|
19 |
+
placeholder="Enter text to analyze for part-of-speech tags...",
|
20 |
+
elem_id="pos-input-text"
|
21 |
+
)
|
22 |
+
gr.Examples(
|
23 |
+
examples=[
|
24 |
+
["The cat is sitting on the mat."],
|
25 |
+
["She quickly finished her homework before dinner."]
|
26 |
+
],
|
27 |
+
inputs=[input_text],
|
28 |
+
label="Examples"
|
29 |
+
)
|
30 |
+
# Tag selection
|
31 |
+
with gr.Group():
|
32 |
+
tag_selection = gr.CheckboxGroup(
|
33 |
+
label="POS Tags to Display",
|
34 |
+
# choices=[(f"{tag} - {desc}", tag) for tag, desc in POS_TAG_DESCRIPTIONS.items()],
|
35 |
+
choices=[tag for tag in POS_TAG_DESCRIPTIONS.keys()],
|
36 |
+
value=DEFAULT_SELECTED_TAGS,
|
37 |
+
interactive=True
|
38 |
+
)
|
39 |
+
with gr.Row():
|
40 |
+
select_all_btn = gr.Button("Select All", size="sm")
|
41 |
+
clear_all_btn = gr.Button("Clear All", size="sm")
|
42 |
+
# Model selection at the bottom
|
43 |
+
with gr.Row():
|
44 |
+
model_dropdown = gr.Dropdown(
|
45 |
+
POS_MODELS,
|
46 |
+
value=DEFAULT_MODEL,
|
47 |
+
label="Model",
|
48 |
+
interactive=True,
|
49 |
+
elem_id="pos-model-dropdown"
|
50 |
+
)
|
51 |
+
custom_instructions = gr.Textbox(
|
52 |
+
label="Custom Instructions (optional)",
|
53 |
+
lines=2,
|
54 |
+
placeholder="Add any custom instructions for the model...",
|
55 |
+
elem_id="pos-custom-instructions"
|
56 |
+
)
|
57 |
+
# Submit button
|
58 |
+
submit_btn = gr.Button("Tag Text", variant="primary", elem_id="pos-submit-btn")
|
59 |
+
# Button event handlers
|
60 |
+
def select_all_tags():
    """Select every POS tag offered by the tag checkbox group.

    The group's choices are the keys of ``POS_TAG_DESCRIPTIONS``, so
    "Select All" returns that full list. ``DEFAULT_SELECTED_TAGS`` is only
    the group's initial value and is not guaranteed to contain every choice.
    """
    return gr.CheckboxGroup(value=list(POS_TAG_DESCRIPTIONS))
|
62 |
+
def clear_all_tags():
    """Deselect every POS tag by giving the checkbox group an empty value."""
    nothing_selected = []
    return gr.CheckboxGroup(value=nothing_selected)
|
64 |
+
select_all_btn.click(
|
65 |
+
fn=select_all_tags,
|
66 |
+
outputs=[tag_selection]
|
67 |
+
)
|
68 |
+
clear_all_btn.click(
|
69 |
+
fn=clear_all_tags,
|
70 |
+
outputs=[tag_selection]
|
71 |
+
)
|
72 |
+
with gr.Column(scale=3):
|
73 |
+
# Results container with tabs
|
74 |
+
with gr.Tabs() as output_tabs:
|
75 |
+
with gr.Tab("Tagged View", id="tagged-view"):
|
76 |
+
no_results_html = gr.HTML(
|
77 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
78 |
+
"Enter text and click 'Tag Text' to analyze.</div>",
|
79 |
+
visible=True
|
80 |
+
)
|
81 |
+
output_html = gr.HTML(
|
82 |
+
label="POS Tags",
|
83 |
+
elem_id="pos-output-html",
|
84 |
+
visible=False
|
85 |
+
)
|
86 |
+
with gr.Tab("Table View", id="table-view"):
|
87 |
+
no_results_table = gr.HTML(
|
88 |
+
"<div style='text-align: center; color: #666; padding: 20px;'>"
|
89 |
+
"Enter text and click 'Tag Text' to analyze.</div>",
|
90 |
+
visible=True
|
91 |
+
)
|
92 |
+
output_table = gr.Dataframe(
|
93 |
+
label="POS Tags",
|
94 |
+
headers=["Token", "POS Tag"],
|
95 |
+
datatype=["str", "str"],
|
96 |
+
interactive=False,
|
97 |
+
wrap=True,
|
98 |
+
elem_id="pos-output-table",
|
99 |
+
visible=False
|
100 |
+
)
|
101 |
+
# Add CSS for the POS tags (scoped to this component)
|
102 |
+
gr.HTML("""
|
103 |
+
<style>
|
104 |
+
#pos-output-html .pos-highlight {
|
105 |
+
white-space: pre-wrap;
|
106 |
+
line-height: 1.8;
|
107 |
+
font-size: 14px;
|
108 |
+
padding: 15px;
|
109 |
+
border: 1px solid #e0e0e0;
|
110 |
+
border-radius: 4px;
|
111 |
+
background: #f9f9f9;
|
112 |
+
}
|
113 |
+
#pos-output-html .pos-token {
|
114 |
+
display: inline-block;
|
115 |
+
margin: 0 2px 4px 0;
|
116 |
+
vertical-align: top;
|
117 |
+
text-align: center;
|
118 |
+
}
|
119 |
+
#pos-output-html .token-text {
|
120 |
+
display: block;
|
121 |
+
padding: 2px 8px;
|
122 |
+
background: #f0f4f8;
|
123 |
+
border-radius: 4px 4px 0 0;
|
124 |
+
border: 1px solid #dbe4ed;
|
125 |
+
border-bottom: none;
|
126 |
+
font-size: 0.9em;
|
127 |
+
}
|
128 |
+
#pos-output-html .pos-tag {
|
129 |
+
display: block;
|
130 |
+
padding: 2px 8px;
|
131 |
+
border-radius: 0 0 4px 4px;
|
132 |
+
font-size: 0.8em;
|
133 |
+
font-family: 'Courier New', monospace;
|
134 |
+
border: 1px solid;
|
135 |
+
border-top: none;
|
136 |
+
}
|
137 |
+
/* Color coding for common POS tags */
|
138 |
+
#pos-output-html .NOUN { background-color: #e3f2fd; border-color: #bbdefb; color: #0d47a1; }
|
139 |
+
#pos-output-html .VERB { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; }
|
140 |
+
#pos-output-html .ADJ { background-color: #fff8e1; border-color: #ffecb3; color: #ff6f00; }
|
141 |
+
#pos-output-html .ADV { background-color: #f3e5f5; border-color: #e1bee7; color: #4a148c; }
|
142 |
+
#pos-output-html .PRON { background-color: #e8eaf6; border-color: #c5cae9; color: #1a237e; }
|
143 |
+
#pos-output-html .DET { background-color: #e0f7fa; border-color: #b2ebf2; color: #006064; }
|
144 |
+
#pos-output-html .ADP { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; }
|
145 |
+
#pos-output-html .CONJ, #pos-output-html .CCONJ, #pos-output-html .SCONJ { background-color: #fce4ec; border-color: #f8bbd0; color: #880e4f; }
|
146 |
+
#pos-output-html .NUM { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; font-weight: bold; }
|
147 |
+
#pos-output-html .PUNCT { background-color: #f5f5f5; border-color: #e0e0e0; color: #424242; }
|
148 |
+
#pos-output-html .X, #pos-output-html .SYM { background-color: #fafafa; border-color: #f5f5f5; color: #616161; }
|
149 |
+
#pos-output-html .PROPN { background-color: #e1f5fe; border-color: #b3e5fc; color: #01579b; font-weight: bold; }
|
150 |
+
#pos-output-html .AUX { background-color: #f3e5f5; border-color: #e1bee7; color: #6a1b9a; }
|
151 |
+
#pos-output-html .PART { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; font-style: italic; }
|
152 |
+
#pos-output-html .INTJ { background-color: #fff3e0; border-color: #ffe0b2; color: #e65100; }
|
153 |
+
</style>
|
154 |
+
""")
|
155 |
+
def format_pos_result(result, selected_tags=None):
    """Render a POS-tagging result as highlighted HTML plus a table.

    Args:
        result: Mapping holding parallel ``"tokens"`` and ``"tags"``
            sequences.
        selected_tags: Tags to highlight in the HTML and keep in the table.
            ``None`` means every key of ``POS_TAG_DESCRIPTIONS``.

    Returns:
        A ``(html, dataframe)`` tuple. The dataframe always has the columns
        ``"Token"`` and ``"POS Tag"`` and only contains rows whose tag is in
        ``selected_tags``. For a missing or malformed ``result`` the HTML is
        an explanatory message and the dataframe is empty.
    """
    import html

    # NOTE: ``pd`` is the module-level pandas import. Re-importing it inside
    # this function would make ``pd`` a local name and turn the early return
    # below into an UnboundLocalError.
    columns = ["Token", "POS Tag"]

    if not result or "tokens" not in result or "tags" not in result:
        return (
            "<div style='text-align: center; color: #666; padding: 20px;'>No POS tags found or invalid result format.</div>",
            pd.DataFrame(columns=columns),
        )

    if selected_tags is None:
        selected_tags = list(POS_TAG_DESCRIPTIONS.keys())

    pos_colors = {
        "NOUN": "#e3f2fd", "VERB": "#e8f5e9", "ADJ": "#fff8e1",
        "ADV": "#f3e5f5", "PRON": "#e8eaf6", "DET": "#e0f7fa",
        "ADP": "#f1f8e9", "CONJ": "#fce4ec", "CCONJ": "#fce4ec",
        "SCONJ": "#fce4ec", "NUM": "#e8f5e9", "PUNCT": "#f5f5f5",
        "X": "#fafafa", "SYM": "#fafafa", "PROPN": "#e1f5fe",
        "AUX": "#f3e5f5", "PART": "#f1f8e9", "INTJ": "#fff3e0",
    }

    html_parts = ['<div style="line-height:1.6;padding:15px;border:1px solid #e0e0e0;border-radius:4px;background:#f9f9f9;white-space:pre-wrap;">']
    df_data = []
    for word, tag in zip(result["tokens"], result["tags"]):
        # Keep only the part of the tag before the first '-' or '_', and
        # fall back to "X" for anything outside the standard tag set.
        clean_tag = tag.split('-')[0].split('_')[0].upper()
        if clean_tag not in STANDARD_POS_TAGS:
            clean_tag = "X"
        df_data.append({"Token": word, "POS Tag": clean_tag})

        if clean_tag not in selected_tags:
            # Unselected tags are rendered as plain, unhighlighted text.
            html_parts.append(f'{html.escape(word)} ')
            continue

        color = pos_colors.get(clean_tag, "#f0f0f0")
        html_parts.append(f'<span style="background:{color};border-radius:3px;padding:0 2px;margin:0 1px;border:1px solid rgba(0,0,0,0.1);">')
        html_parts.append(f'{html.escape(word)} ')
        html_parts.append(f'<span style="font-size:0.7em;font-weight:bold;color:#555;border-radius:2px;padding:0 2px;background:rgba(255,255,255,0.7);">{clean_tag}</span>')
        html_parts.append('</span>')
    html_parts.append('</div>')

    # Passing ``columns`` keeps the "POS Tag" column present even when there
    # are no tokens, so the filter below cannot raise KeyError.
    df = pd.DataFrame(df_data, columns=columns)
    df = df[df["POS Tag"].isin(selected_tags)].reset_index(drop=True)
    return "".join(html_parts), df
|
191 |
+
def process_pos(text: str, model: str, custom_instructions: str, selected_tags: list):
    """Tag ``text`` and stream UI updates for the POS output components.

    This function is a generator, so every UI update has to be *yielded*.
    Each yielded list holds updates in the order
    ``[output_html, no_results_html, output_table, no_results_table]``.

    Args:
        text: Text to tag.
        model: Identifier of the model to run.
        custom_instructions: Extra instructions; forwarded only when
            ``model`` is an LLM model, otherwise replaced by ``""``.
        selected_tags: Tags to display; an empty selection means all keys of
            ``POS_TAG_DESCRIPTIONS``.
    """
    from html import escape as escape_html

    if not text or not text.strip():
        # A value given to ``return`` inside a generator is attached to
        # StopIteration and never produced as an item, so the validation
        # message must be yielded before stopping.
        yield [
            gr.HTML("<div style='color: #f44336; padding: 20px;'>Please enter some text to analyze.</div>", visible=True),
            gr.HTML(visible=False),  # no_results_html
            gr.DataFrame(visible=False),  # output_table
            gr.HTML(visible=False)  # no_results_table
        ]
        return

    use_llm = is_llm_model(model)
    if not selected_tags:
        selected_tags = list(POS_TAG_DESCRIPTIONS.keys())

    try:
        # First update: progress placeholder shown while the model runs.
        yield [
            gr.HTML("<div class='pos-highlight'>Processing... This may take a moment for large texts.</div>", visible=True),
            gr.HTML(visible=False),  # no_results_html
            gr.DataFrame(visible=False),  # output_table
            gr.HTML(visible=False)  # no_results_table
        ]

        result = pos_tagging(
            text=text,
            model=model,
            custom_instructions=custom_instructions if use_llm else "",
            use_llm=use_llm
        )

        if "error" in result:
            error_msg = result['error']
            if "API key" in error_msg or "authentication" in error_msg.lower():
                error_msg += " Please check your API key configuration."
            # The message comes from the tagging backend; escape it so it is
            # shown as text instead of being interpreted as markup.
            yield [
                gr.HTML(f"<div style='color: #d32f2f; padding: 20px;'>{escape_html(str(error_msg))}</div>", visible=True),
                gr.HTML(visible=False),  # no_results_html
                gr.DataFrame(visible=False),  # output_table
                gr.HTML(visible=False)  # no_results_table
            ]
            return

        rendered_html, table = format_pos_result(result, selected_tags)
        if not table.empty:
            yield [
                gr.HTML(rendered_html, visible=True),  # output_html
                gr.HTML(visible=False),  # no_results_html
                gr.DataFrame(value=table, visible=True),  # output_table
                gr.HTML(visible=False)  # no_results_table
            ]
        else:
            empty_msg = "<div class='pos-highlight' style='text-align: center; color: #666; padding: 20px;'>No POS tags could be extracted from the text.</div>"
            yield [
                gr.HTML(empty_msg, visible=True),  # output_html
                gr.HTML(visible=False),  # no_results_html
                gr.DataFrame(visible=False),  # output_table
                gr.HTML(empty_msg, visible=True)  # no_results_table
            ]
    except Exception as e:
        # UI boundary: log the full traceback, show only a generic message.
        import traceback
        error_msg = f"Error processing request: {str(e)}\n\n{traceback.format_exc()}"
        print(error_msg)  # Log the full error
        yield [
            gr.HTML("<div class='pos-highlight' style='color: #d32f2f; padding: 20px;'>An error occurred while processing your request. Please try again.</div>", visible=True),
            gr.HTML(visible=False),  # no_results_html
            gr.DataFrame(visible=False),  # output_table
            gr.HTML(visible=False)  # no_results_table
        ]
|
252 |
+
def update_ui(model_name: str) -> Dict:
    """Show the custom-instructions box only for LLM models.

    Returns a mapping from the ``custom_instructions`` component to a
    textbox update whose visibility is ``is_llm_model(model_name)``.
    """
    return {custom_instructions: gr.Textbox(visible=is_llm_model(model_name))}
|
257 |
+
def clear_inputs():
    """Return a tuple of three empty strings."""
    blank = ""
    return blank, blank, blank
|
259 |
+
model_dropdown.change(
|
260 |
+
fn=update_ui,
|
261 |
+
inputs=[model_dropdown],
|
262 |
+
outputs=[custom_instructions]
|
263 |
+
)
|
264 |
+
submit_btn.click(
|
265 |
+
fn=process_pos,
|
266 |
+
inputs=[input_text, model_dropdown, custom_instructions, tag_selection],
|
267 |
+
outputs=[output_html, no_results_html, output_table, no_results_table],
|
268 |
+
show_progress=True
|
269 |
+
)
|
270 |
+
gr.HTML("""
|
271 |
+
<style>
|
272 |
+
/* Style for the tabs */
|
273 |
+
#tagged-view, #table-view {
|
274 |
+
padding: 15px;
|
275 |
+
}
|
276 |
+
/* Make the tabs more visible */
|
277 |
+
.tab-nav {
|
278 |
+
margin-bottom: 10px;
|
279 |
+
border-bottom: 1px solid #e0e0e0;
|
280 |
+
}
|
281 |
+
.tab-nav button {
|
282 |
+
padding: 8px 16px;
|
283 |
+
margin-right: 5px;
|
284 |
+
border: 1px solid #e0e0e0;
|
285 |
+
background: #f5f5f5;
|
286 |
+
border-radius: 4px 4px 0 0;
|
287 |
+
cursor: pointer;
|
288 |
+
}
|
289 |
+
.tab-nav button.selected {
|
290 |
+
background: #ffffff;
|
291 |
+
border-bottom: 2px solid #0e7490;
|
292 |
+
font-weight: bold;
|
293 |
+
}
|
294 |
+
</style>
|
295 |
+
""")
|
296 |
+
custom_instructions.visible = is_llm_model(DEFAULT_MODEL)
|
297 |
+
return None
|
ui/sentiment_ui.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.ner_helpers import is_llm_model
|
3 |
+
from typing import Dict, List, Any
|
4 |
+
from tasks.sentiment_analysis import sentiment_analysis
|
5 |
+
|
6 |
+
def sentiment_ui():
|
7 |
+
"""Sentiment analysis UI component"""
|
8 |
+
|
9 |
+
# Define models
|
10 |
+
SENTIMENT_MODELS = [
|
11 |
+
"gemini-2.0-flash" # Only allow gemini-2.0-flash for now
|
12 |
+
# "gpt-4",
|
13 |
+
# "claude-2",
|
14 |
+
# "distilbert-base-uncased-finetuned-sst-2-english",
|
15 |
+
# "finiteautomata/bertweet-base-sentiment-analysis"
|
16 |
+
]
|
17 |
+
DEFAULT_MODEL = "gemini-2.0-flash"
|
18 |
+
|
19 |
+
def analyze_sentiment(text, model, custom_instructions):
    """Run sentiment analysis on ``text`` and return a display label.

    Blank input yields ``"No text provided"``. Otherwise the model output is
    lower-cased and stripped; the first of "positive", "negative" or
    "neutral" found in it selects the capitalised label, and any other
    output is returned in its normalised form.
    """
    if not text.strip():
        return "No text provided"

    raw_output = sentiment_analysis(
        text=text,
        model=model,
        custom_instructions=custom_instructions,
        use_llm=is_llm_model(model),
    )

    normalized = raw_output.lower().strip()
    keyword_to_label = (
        ("positive", "Positive"),
        ("negative", "Negative"),
        ("neutral", "Neutral"),
    )
    for keyword, label in keyword_to_label:
        if keyword in normalized:
            return label
    # Anything unrecognised is passed through as-is.
    return normalized
|
43 |
+
|
44 |
+
# UI Components
|
45 |
+
with gr.Row():
|
46 |
+
with gr.Column():
|
47 |
+
input_text = gr.Textbox(
|
48 |
+
label="Input Text",
|
49 |
+
lines=6,
|
50 |
+
placeholder="Enter text to analyze sentiment...",
|
51 |
+
elem_id="sentiment-input-text"
|
52 |
+
)
|
53 |
+
gr.Examples(
|
54 |
+
examples=[
|
55 |
+
["I am very satisfied with the customer service of this company."],
|
56 |
+
["The product did not meet my expectations and I am disappointed."]
|
57 |
+
],
|
58 |
+
inputs=[input_text],
|
59 |
+
label="Examples"
|
60 |
+
)
|
61 |
+
model = gr.Dropdown(
|
62 |
+
SENTIMENT_MODELS,
|
63 |
+
value=DEFAULT_MODEL,
|
64 |
+
label="Model",
|
65 |
+
interactive=True,
|
66 |
+
elem_id="sentiment-model-dropdown"
|
67 |
+
)
|
68 |
+
custom_instructions = gr.Textbox(
|
69 |
+
label="Custom Instructions (optional)",
|
70 |
+
lines=2,
|
71 |
+
placeholder="Add any custom instructions for the model...",
|
72 |
+
elem_id="sentiment-custom-instructions"
|
73 |
+
)
|
74 |
+
|
75 |
+
btn = gr.Button("Analyze Sentiment", variant="primary", elem_id="sentiment-analyze-btn")
|
76 |
+
|
77 |
+
with gr.Column():
|
78 |
+
output = gr.Textbox(
|
79 |
+
label="Sentiment Analysis",
|
80 |
+
elem_id="sentiment-output"
|
81 |
+
)
|
82 |
+
|
83 |
+
# with gr.Accordion("About Sentiment Analysis", open=False):
|
84 |
+
# gr.Markdown("""
|
85 |
+
# ## Sentiment Analysis
|
86 |
+
|
87 |
+
# Sentiment analysis identifies the emotional tone behind text. The model analyzes your input text and classifies it as:
|
88 |
+
|
89 |
+
# - **Positive**: Text expresses positive emotions, approval, or optimism
|
90 |
+
# - **Negative**: Text expresses negative emotions, criticism, or pessimism
|
91 |
+
# - **Neutral**: Text is factual or does not express strong sentiment
|
92 |
+
|
93 |
+
# ### Model Types
|
94 |
+
|
95 |
+
# - **LLM Models** (Gemini, GPT, Claude): Provide sophisticated analysis with better understanding of context
|
96 |
+
# - **Traditional Models**: Specialized models trained specifically for sentiment analysis tasks
|
97 |
+
|
98 |
+
# Use the advanced options to customize how the model analyzes your text.
|
99 |
+
# """)
|
100 |
+
|
101 |
+
# Event handlers
|
102 |
+
btn.click(
|
103 |
+
analyze_sentiment,
|
104 |
+
inputs=[input_text, model, custom_instructions],
|
105 |
+
outputs=output
|
106 |
+
)
|
107 |
+
|
108 |
+
return None
|
ui/summarization_ui.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.ner_helpers import is_llm_model
|
3 |
+
from typing import Dict, List, Any
|
4 |
+
from tasks.summarization import text_summarization
|
5 |
+
|
6 |
+
def summarization_ui():
|
7 |
+
"""Summarization UI component"""
|
8 |
+
|
9 |
+
# Define models
|
10 |
+
SUMMARY_MODELS = [
|
11 |
+
"gemini-2.0-flash" # Only allow gemini-2.0-flash for now
|
12 |
+
# "gpt-4",
|
13 |
+
# "claude-2",
|
14 |
+
# "facebook/bart-large-cnn",
|
15 |
+
# "t5-small",
|
16 |
+
# "qwen/Qwen2.5-3B-Instruct"
|
17 |
+
]
|
18 |
+
DEFAULT_MODEL = "gemini-2.0-flash"
|
19 |
+
|
20 |
+
def summarize(text, model, summary_length, custom_instructions):
    """Summarize ``text`` with the chosen model and length setting.

    Blank input yields ``"No text provided"``; otherwise the result of
    ``text_summarization`` is returned unchanged.
    """
    if text.strip():
        # NOTE: custom_instructions is accepted but not forwarded; it is
        # reserved for future use once the API supports it.
        return text_summarization(
            text=text,
            model=model,
            summary_length=summary_length,
            use_llm=is_llm_model(model),
        )
    return "No text provided"
|
36 |
+
|
37 |
+
# UI Components
|
38 |
+
with gr.Row():
|
39 |
+
with gr.Column():
|
40 |
+
input_text = gr.Textbox(
|
41 |
+
label="Input Text",
|
42 |
+
lines=8,
|
43 |
+
placeholder="Enter text to summarize...",
|
44 |
+
elem_id="summary-input-text"
|
45 |
+
)
|
46 |
+
|
47 |
+
summary_length = gr.Radio(
|
48 |
+
["Short", "Medium", "Long"],
|
49 |
+
value="Medium",
|
50 |
+
label="Summary Length",
|
51 |
+
elem_id="summary-length-radio"
|
52 |
+
)
|
53 |
+
model = gr.Dropdown(
|
54 |
+
SUMMARY_MODELS,
|
55 |
+
value=DEFAULT_MODEL,
|
56 |
+
label="Model",
|
57 |
+
interactive=True,
|
58 |
+
elem_id="summary-model-dropdown"
|
59 |
+
)
|
60 |
+
custom_instructions = gr.Textbox(
|
61 |
+
label="Custom Instructions (optional)",
|
62 |
+
lines=2,
|
63 |
+
placeholder="Add any custom instructions for the model...",
|
64 |
+
elem_id="summary-custom-instructions"
|
65 |
+
)
|
66 |
+
|
67 |
+
btn = gr.Button("Summarize", variant="primary", elem_id="summary-btn")
|
68 |
+
|
69 |
+
with gr.Column():
|
70 |
+
output = gr.Textbox(
|
71 |
+
label="Summary",
|
72 |
+
lines=10,
|
73 |
+
elem_id="summary-output"
|
74 |
+
)
|
75 |
+
|
76 |
+
# with gr.Accordion("About Summarization", open=False):
|
77 |
+
# gr.Markdown("""
|
78 |
+
# ## Text Summarization
|
79 |
+
|
80 |
+
# Text summarization condenses a document while preserving key information. This tool offers:
|
81 |
+
|
82 |
+
# - **Length control**: Choose between short, medium, or long summaries
|
83 |
+
# - **Multiple models**: Select from LLMs (like Gemini and GPT) or traditional models
|
84 |
+
# - **Custom instructions**: Tailor the summarization to your specific needs
|
85 |
+
|
86 |
+
# ### How it works
|
87 |
+
|
88 |
+
# - **LLM models** process your text using natural language understanding
|
89 |
+
# - **Traditional models** use extractive or abstractive techniques to identify and condense key information
|
90 |
+
|
91 |
+
# For best results with long texts, try different summary lengths to find the right balance between brevity and detail.
|
92 |
+
# """)
|
93 |
+
|
94 |
+
# Event handlers
|
95 |
+
btn.click(
|
96 |
+
summarize,
|
97 |
+
inputs=[input_text, model, summary_length, custom_instructions],
|
98 |
+
outputs=output
|
99 |
+
)
|
100 |
+
|
101 |
+
return None
|
ui/topic_ui.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.ner_helpers import is_llm_model
|
3 |
+
from typing import Dict, List, Any
|
4 |
+
from tasks.topic_classification import topic_classification
|
5 |
+
|
6 |
+
def topic_ui():
|
7 |
+
"""Topic classification UI component"""
|
8 |
+
|
9 |
+
# Define models and default labels
|
10 |
+
TOPIC_MODELS = [
|
11 |
+
"gemini-2.0-flash" # Only allow gemini-2.0-flash for now
|
12 |
+
# "gpt-4",
|
13 |
+
# "claude-2",
|
14 |
+
# "facebook/bart-large-mnli",
|
15 |
+
# "joeddav/xlm-roberta-large-xnli"
|
16 |
+
]
|
17 |
+
DEFAULT_MODEL = "gemini-2.0-flash"
|
18 |
+
DEFAULT_LABELS = [
|
19 |
+
"Sports", "Economy", "Politics", "Entertainment", "Technology", "Education", "Law"
|
20 |
+
]
|
21 |
+
|
22 |
+
def classify(text, model, use_custom, labels, custom_instructions):
    """Classify ``text`` into a topic and return the stripped model output.

    When ``use_custom`` is set, ``labels`` is split on newlines into the
    candidate list (blank lines dropped) and at least one label is required.
    Otherwise ``None`` is passed as the candidate list.
    """
    if not text.strip():
        return "No text provided"

    use_llm = is_llm_model(model)

    candidates = None
    if use_custom:
        stripped_lines = (line.strip() for line in labels.split('\n'))
        candidates = [entry for entry in stripped_lines if entry]
        if not candidates:
            return "Please provide at least one category"

    outcome = topic_classification(
        text=text,
        model=model,
        candidate_labels=candidates,
        custom_instructions=custom_instructions,
        use_llm=use_llm,
    )
    return outcome.strip()
|
38 |
+
|
39 |
+
# UI Components
|
40 |
+
with gr.Row():
|
41 |
+
with gr.Column():
|
42 |
+
input_text = gr.Textbox(
|
43 |
+
label="Input Text",
|
44 |
+
lines=6,
|
45 |
+
placeholder="Enter text to classify...",
|
46 |
+
elem_id="topic-input-text"
|
47 |
+
)
|
48 |
+
gr.Examples(
|
49 |
+
examples=[
|
50 |
+
["Apple has announced the release of a new iPhone model this fall."],
|
51 |
+
["The United Nations held a climate summit to discuss global warming solutions."]
|
52 |
+
],
|
53 |
+
inputs=[input_text],
|
54 |
+
label="Examples"
|
55 |
+
)
|
56 |
+
use_custom_topics = gr.Checkbox(
|
57 |
+
label="Use custom topics",
|
58 |
+
value=True,
|
59 |
+
elem_id="topic-use-custom-topics"
|
60 |
+
)
|
61 |
+
topics_area = gr.TextArea(
|
62 |
+
label="Candidate Topics (one per line)",
|
63 |
+
value='\n'.join(DEFAULT_LABELS),
|
64 |
+
lines=5,
|
65 |
+
visible=True,
|
66 |
+
elem_id="topic-candidate-topics"
|
67 |
+
)
|
68 |
+
def toggle_topics_area(use_custom):
    """Return an update that sets the topics area's visibility to ``use_custom``."""
    visibility_update = gr.update(visible=use_custom)
    return visibility_update
|
70 |
+
use_custom_topics.change(toggle_topics_area, inputs=use_custom_topics, outputs=topics_area)
|
71 |
+
model = gr.Dropdown(
|
72 |
+
TOPIC_MODELS,
|
73 |
+
value=DEFAULT_MODEL,
|
74 |
+
label="Model",
|
75 |
+
interactive=True,
|
76 |
+
elem_id="topic-model-dropdown"
|
77 |
+
)
|
78 |
+
custom_instructions = gr.Textbox(
|
79 |
+
label="Custom Instructions (optional)",
|
80 |
+
lines=2,
|
81 |
+
placeholder="Add any custom instructions for the model...",
|
82 |
+
elem_id="topic-custom-instructions"
|
83 |
+
)
|
84 |
+
classify_btn = gr.Button("Classify Topic", variant="primary", elem_id="topic-classify-btn")
|
85 |
+
with gr.Column():
|
86 |
+
output_box = gr.Textbox(
|
87 |
+
label="Classification Result",
|
88 |
+
lines=2,
|
89 |
+
elem_id="topic-output"
|
90 |
+
)
|
91 |
+
def run_topic_classification(text, model, use_custom, topics, custom_instructions):
    """Click handler that forwards its arguments to ``classify``.

    ``topics`` is passed as the ``labels`` argument of ``classify``.
    """
    return classify(
        text=text,
        model=model,
        use_custom=use_custom,
        labels=topics,
        custom_instructions=custom_instructions,
    )
|
93 |
+
classify_btn.click(
|
94 |
+
run_topic_classification,
|
95 |
+
inputs=[input_text, model, use_custom_topics, topics_area, custom_instructions],
|
96 |
+
outputs=output_box
|
97 |
+
)
|
98 |
+
# 4. Click "Classify" to analyze
|
99 |
+
|
100 |
+
# ### Model Types
|
101 |
+
|
102 |
+
# - **LLM Models** (Gemini, GPT, Claude): Provide sophisticated classification with better understanding of context and nuance
|
103 |
+
# - **Traditional Models**: Specialized models trained specifically for zero-shot classification tasks
|
104 |
+
|
105 |
+
# Use the advanced options to customize how the model classifies your text.
|
106 |
+
# """)
|
107 |
+
|
108 |
+
return None
|
ui/translation_ui.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.ner_helpers import is_llm_model
|
3 |
+
from typing import Dict, List, Any
|
4 |
+
from tasks.translation import text_translation
|
5 |
+
|
6 |
+
def translation_ui():
|
7 |
+
"""Translation UI component"""
|
8 |
+
|
9 |
+
# Define models
|
10 |
+
TRANSLATION_MODELS = [
|
11 |
+
"gemini-2.0-flash" # Only allow gemini-2.0-flash for now
|
12 |
+
# "gpt-4",
|
13 |
+
# "claude-2",
|
14 |
+
# "Helsinki-NLP/opus-mt-en-vi",
|
15 |
+
# "Helsinki-NLP/opus-mt-vi-en"
|
16 |
+
]
|
17 |
+
DEFAULT_MODEL = "gemini-2.0-flash"
|
18 |
+
|
19 |
+
def translate(text, model, src_lang, tgt_lang, custom_instructions):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Blank input yields ``"No text provided"``; otherwise the result of
    ``text_translation`` is returned unchanged.
    """
    if text.strip():
        return text_translation(
            text=text,
            model=model,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            custom_instructions=custom_instructions,
            use_llm=is_llm_model(model),
        )
    return "No text provided"
|
35 |
+
|
36 |
+
# UI Components
|
37 |
+
with gr.Row():
|
38 |
+
with gr.Column():
|
39 |
+
input_text = gr.Textbox(
|
40 |
+
label="Input Text",
|
41 |
+
lines=8,
|
42 |
+
placeholder="Enter text to translate...",
|
43 |
+
elem_id="translation-input-text"
|
44 |
+
)
|
45 |
+
gr.Examples(
|
46 |
+
examples=[
|
47 |
+
["Vietnam's economy has grown rapidly in the past decade."],
|
48 |
+
["The football match between Manchester United and Chelsea was very exciting."]
|
49 |
+
],
|
50 |
+
inputs=[input_text],
|
51 |
+
label="Examples"
|
52 |
+
)
|
53 |
+
with gr.Row():
|
54 |
+
pass
|
55 |
+
src_lang = gr.Textbox(
|
56 |
+
label="Source Language (e.g., en, vi, ja)",
|
57 |
+
value="en",
|
58 |
+
elem_id="translation-src-lang"
|
59 |
+
)
|
60 |
+
tgt_lang = gr.Textbox(
|
61 |
+
label="Target Language (e.g., en, vi, ja)",
|
62 |
+
value="vi",
|
63 |
+
elem_id="translation-tgt-lang"
|
64 |
+
)
|
65 |
+
model = gr.Dropdown(
|
66 |
+
TRANSLATION_MODELS,
|
67 |
+
value=DEFAULT_MODEL,
|
68 |
+
label="Model",
|
69 |
+
interactive=True,
|
70 |
+
elem_id="translation-model-dropdown"
|
71 |
+
)
|
72 |
+
custom_instructions = gr.Textbox(
|
73 |
+
label="Custom Instructions (optional)",
|
74 |
+
lines=2,
|
75 |
+
placeholder="Add any custom instructions for the model...",
|
76 |
+
elem_id="translation-custom-instructions"
|
77 |
+
)
|
78 |
+
|
79 |
+
btn = gr.Button("Translate", variant="primary", elem_id="translation-btn")
|
80 |
+
|
81 |
+
with gr.Column():
|
82 |
+
output = gr.Textbox(
|
83 |
+
label="Translation",
|
84 |
+
lines=10,
|
85 |
+
elem_id="translation-output"
|
86 |
+
)
|
87 |
+
|
88 |
+
# with gr.Accordion("About Translation", open=False):
|
89 |
+
# gr.Markdown("""
|
90 |
+
# ## Text Translation
|
91 |
+
|
92 |
+
# Text translation converts text from one language to another. This tool offers:
|
93 |
+
|
94 |
+
# - **Multiple languages**: Translate between any language pair
|
95 |
+
# - **Multiple models**: Select from LLMs (like Gemini and GPT) or specialized translation models
|
96 |
+
# - **Custom instructions**: Tailor the translation to specific domains or styles
|
97 |
+
|
98 |
+
# ### Language Codes
|
99 |
+
|
100 |
+
# Use standard language codes like:
|
101 |
+
# - `en` for English
|
102 |
+
# - `vi` for Vietnamese
|
103 |
+
# - `ja` for Japanese
|
104 |
+
# - `fr` for French
|
105 |
+
# - `es` for Spanish
|
106 |
+
# - `ko` for Korean
|
107 |
+
|
108 |
+
# ### Tips
|
109 |
+
|
110 |
+
# - LLM models perform better on complex or nuanced translations
|
111 |
+
# - Specialized models might be faster for common language pairs
|
112 |
+
# - Use custom instructions to specify tones (formal/informal) or domains (technical/literary)
|
113 |
+
# """)
|
114 |
+
|
115 |
+
# Event handlers
|
116 |
+
btn.click(
|
117 |
+
translate,
|
118 |
+
inputs=[input_text, model, src_lang, tgt_lang, custom_instructions],
|
119 |
+
outputs=output
|
120 |
+
)
|
121 |
+
|
122 |
+
return None
|
utils/ner_helpers.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NER helpers and constants
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
# Standard NER entity types with descriptions
|
5 |
+
# Maps each supported NER label to a short human-readable description.
# Every label appears exactly once: duplicate keys in a dict literal are
# silently collapsed by Python, which hides copy/paste mistakes.
# Insertion order matters because DEFAULT_SELECTED_ENTITIES takes the
# first five keys.
NER_ENTITY_TYPES = {
    "PERSON": "People, including fictional",
    "ORG": "Companies, agencies, institutions, etc.",
    "GPE": "Countries, cities, states",
    "LOC": "Non-GPE locations, mountain ranges, bodies of water",
    "PRODUCT": "Objects, vehicles, foods, etc. (not services)",
    "EVENT": "Named hurricanes, battles, wars, sports events, etc.",
    "WORK_OF_ART": "Titles of books, songs, etc.",
    "LAW": "Named documents made into laws",
    "LANGUAGE": "Any named language",
    "DATE": "Absolute or relative dates or periods",
    "TIME": "Times smaller than a day",
    "PERCENT": "Percentage (including '%')",
    "MONEY": "Monetary values, including unit",
    "QUANTITY": "Measurements, as of weight or distance",
    "ORDINAL": "'first', 'second', etc.",
    "CARDINAL": "Numerals that do not fall under another type",
    "NORP": "Nationalities or religious or political groups",
    "FAC": "Buildings, airports, highways, bridges, etc.",
}

# Default selected entity types (first 5 by default)
DEFAULT_SELECTED_ENTITIES = list(NER_ENTITY_TYPES)[:5]

# Substrings that mark a model id as LLM-backed (see is_llm_model).
LLM_MODELS = ["gemini", "gpt", "claude"]
|
35 |
+
|
36 |
+
def is_llm_model(model_id: str) -> bool:
    """Return True when ``model_id`` names an LLM-based model.

    The check is a case-insensitive substring match of ``model_id``
    against every entry of ``LLM_MODELS``.
    """
    lowered = model_id.lower()
    for family in LLM_MODELS:
        if family in lowered:
            return True
    return False
|
39 |
+
|
40 |
+
# Render NER HTML for tagged view
|
41 |
+
def render_ner_html(text, entities, selected_entity_types=None):
    """Render ``text`` as an HTML snippet with named entities highlighted.

    Args:
        text: Source text that the entity character offsets refer to.
        entities: Iterable of dicts. Each dict carries character offsets
            under ``start`` and ``end``, a label under ``type`` (falling
            back to ``entity``), and optionally the surface form under
            ``word`` (falling back to ``text``). Entities lacking either
            offset are ignored.
        selected_entity_types: Labels to highlight. ``None`` selects every
            key of ``NER_ENTITY_TYPES``.

    Returns:
        An HTML string. When ``text`` is empty/blank or there are no
        entities, a centred "No named entities found" message is returned.
        Otherwise a ``ner-highlight`` wrapper is returned that contains
        either the highlighted text or a message saying that no entity of
        the selected types was found.
    """
    import html as html_lib
    import zlib

    if not text or not text.strip() or not entities:
        return "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"
    if selected_entity_types is None:
        selected_entity_types = list(NER_ENTITY_TYPES.keys())
    selected = set(selected_entity_types)
    COLORS = [
        '#e3f2fd', '#e8f5e9', '#fff8e1', '#f3e5f5', '#e8eaf6', '#e0f7fa',
        '#f1f8e9', '#fce4ec', '#f5f5f5', '#fafafa', '#e1f5fe', '#f3e5f5', '#f1f8e9'
    ]

    # Drop entities without usable offsets up front so that sorting and the
    # overlap check below cannot fail on a missing key.
    located = [
        e for e in entities
        if e.get('start') is not None and e.get('end') is not None
    ]
    located.sort(key=lambda e: e['start'])

    # Walk entities left to right, keeping only selected labels and skipping
    # any entity that starts inside the previously kept one.
    highlighted = []
    for e in located:
        label = e.get('type', e.get('entity', ''))
        if label not in selected:
            continue
        if highlighted and e['start'] < highlighted[-1]['end']:
            continue
        # crc32 is stable across interpreter runs, unlike the built-in
        # hash() of a str, so a label keeps its colour between restarts.
        color = COLORS[zlib.crc32(label.encode('utf-8')) % len(COLORS)]
        highlighted.append({
            'start': e['start'],
            'end': e['end'],
            'label': label,
            # Fall back to the text slice when no surface form is supplied.
            'text': e.get('word', e.get('text', text[e['start']:e['end']])),
            'color': color
        })

    html = ["<div class='ner-highlight' style='line-height:1.6;padding:15px;border:1px solid #e0e0e0;border-radius:4px;background:#f9f9f9;white-space:pre-wrap;'>"]
    if not highlighted:
        html.append("<div style='text-align: center; color: #666; padding: 20px;'>")
        html.append("No entities of the selected types found in the text.")
        html.append("</div>")
    else:
        last_pos = 0
        for entity in highlighted:
            start = entity['start']
            end = entity['end']
            # Emit the plain (escaped) text between the previous entity and this one.
            if start > last_pos:
                html.append(html_lib.escape(text[last_pos:start]))
            html.append(f"<span style='background:{entity['color']};border-radius:3px;padding:2px 4px;margin:0 1px;border:1px solid rgba(0,0,0,0.1);'>")
            html.append(f"{html_lib.escape(entity['text'])} ")
            html.append(f"<span style='font-size:0.8em;font-weight:bold;color:#555;border-radius:2px;padding:0 2px;background:rgba(255,255,255,0.7);'>{html_lib.escape(entity['label'])}</span>")
            html.append("</span>")
            last_pos = end
        # Trailing text after the last entity.
        if last_pos < len(text):
            html.append(html_lib.escape(text[last_pos:]))
    html.append("</div>")
    return "".join(html)
|
utils/pos_helpers.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# POS helpers and constants
|
2 |
+
|
3 |
+
# Models offered for POS tagging. Only gemini-2.0-flash is enabled for now.
# Currently disabled candidates: gpt-4, claude-2,
# vblagoje/bert-english-uncased-finetuned-pos,
# QCRI/bert-base-multilingual-cased-pos-english.
POS_MODELS = ["gemini-2.0-flash"]

DEFAULT_MODEL = "gemini-2.0-flash"

# NOTE(review): "CONJ" is listed here but has no entry in
# POS_TAG_DESCRIPTIONS below -- confirm whether that is intended.
STANDARD_POS_TAGS = [
    "ADJ", "ADP", "ADV", "AUX", "CONJ", "CCONJ",
    "DET", "INTJ", "NOUN", "NUM", "PART", "PRON",
    "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X",
]

# Human-readable description (with examples) for each POS tag.
POS_TAG_DESCRIPTIONS = dict(
    ADJ="Adjective (big, old, green, interesting)",
    ADP="Adposition (in, to, during)",
    ADV="Adverb (very, well, there, tomorrow)",
    AUX="Auxiliary verb (is, has (done), will (do), should (do))",
    CCONJ="Coordinating conjunction (and, or, but)",
    DET="Determiner (a, an, the, this, those)",
    INTJ="Interjection (oh, hey, oops, hmm)",
    NOUN="Noun (dog, cat, man, house, idea)",
    NUM="Numeral (one, two, 3, 55, 2019)",
    PART="Particle (not, 's, let's)",
    PRON="Pronoun (I, you, he, she, it, we, they, me, him, her, us, them)",
    PROPN="Proper noun (John, Mary, London, Microsoft)",
    PUNCT="Punctuation (.,!?;:)",
    SCONJ="Subordinating conjunction (if, because, as, that)",
    SYM="Symbol (%, $, §, ©)",
    VERB="Verb (run, runs, running, eat, ate, eaten)",
    X="Other (foreign words, typos, etc.)",
)

# Every described tag is selected by default.
DEFAULT_SELECTED_TAGS = [*POS_TAG_DESCRIPTIONS]
|
utils/remote_client.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
from typing import Dict, Any
|
4 |
+
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
load_dotenv()
|
7 |
+
|
8 |
+
# Timeout in seconds passed to requests.post by execute_remote_task.
# Read from the REMOTE_SERVICE_TIMEOUT environment variable; defaults to 300.
TIMEOUT = int(os.getenv("REMOTE_SERVICE_TIMEOUT", "300"))
|
10 |
+
|
11 |
+
def execute_remote_task(task_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute a remote task using the configured remote service.

    The endpoint URL is read from the environment variable
    ``REMOTE_ENDPOINT_<TASK_NAME>`` (task name upper-cased). The request body
    is ``payload`` plus a ``"task"`` key, sent as JSON via HTTP POST.

    Args:
        task_name: Name of the task to execute (e.g., 'summarization', 'translation')
        payload: Dictionary containing task-specific parameters

    Returns:
        The decoded JSON response on success. If the request fails, the
        server answers with an error status, or the body is not valid JSON,
        a dictionary of the form ``{"error": "<message>"}`` is returned
        instead of raising.

    Raises:
        ValueError: If no endpoint is configured for ``task_name``.
    """
    # Get the endpoint from environment variables
    endpoint = os.getenv(f"REMOTE_ENDPOINT_{task_name.upper()}")
    if not endpoint:
        raise ValueError(f"No endpoint configured for task: {task_name}")

    try:
        response = requests.post(
            url=endpoint,
            json={"task": task_name, **payload},
            timeout=TIMEOUT
        )
        response.raise_for_status()
        return response.json()
    except requests.RequestException as e:
        error_msg = f"Error calling remote service for {task_name}: {str(e)}"
        # Include the server's response body when one is attached to the error.
        failed_response = getattr(e, 'response', None)
        if failed_response is not None:
            error_msg += f" - {failed_response.text}"
        return {"error": error_msg}
    except ValueError as e:
        # Older requests releases raise a plain ValueError (not a
        # RequestException) when the body is not valid JSON; report it the
        # same way as every other remote failure.
        return {"error": f"Error calling remote service for {task_name}: {str(e)}"}
|
utils/shared.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Shared helpers and utilities
|