Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- src/app.py +238 -0
- src/requirements.txt +3 -0
src/app.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import pipeline
|
3 |
+
import html
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
# Page setup: browser-tab title, icon, and a full-width layout.
st.set_page_config(
    page_title="OpenMed NER Demo",
    page_icon="🏥",
    layout="wide",
)
|
12 |
+
|
13 |
+
# Maps the domain name shown in the sidebar to the Hugging Face model id
# that is loaded for it.
MODELS = {
    "Pharmacology": "OpenMed/OpenMed-NER-PharmaDetect-SuperClinical-434M",
    "Oncology Genetics": "OpenMed/OpenMed-NER-OncologyDetect-SuperMedical-355M",
    "Species Detection": "OpenMed/OpenMed-NER-SpeciesDetect-PubMed-335M",
    "Chemical Detection": "OpenMed/OpenMed-NER-ChemicalDetect-PubMed-335M",
    "Anatomy Detection": "OpenMed/OpenMed-NER-AnatomyDetect-PubMed-335M",
    "Blood Cancer Detection": "OpenMed/OpenMed-NER-BloodCancerDetect-TinyMed-82M",
    "Disease Detection": "OpenMed/OpenMed-NER-DiseaseDetect-SuperClinical-434M",
}
|
23 |
+
|
24 |
+
# Background colour used when rendering each entity type; labels that are
# not listed here fall back to the "default" entry.
ENTITY_COLORS = {
    "DRUG": "#FF9999",      # drug - light red
    "CHEMICAL": "#FFCC99",  # chemical - light orange
    "DISEASE": "#FF99CC",   # disease - light pink
    "ANATOMY": "#99CCFF",   # anatomy - light blue
    "SPECIES": "#99FF99",   # species - light green
    "GENE": "#CC99FF",      # gene - light purple
    "PROTEIN": "#FFFF99",   # protein - light yellow
    "CELL": "#99FFFF",      # cell - light cyan
    "default": "#DDDDDD",   # fallback - light grey
}
|
36 |
+
|
37 |
+
# Seed the per-session values on the first run only; later reruns keep
# whatever the session already holds.
for state_key, initial_value in (
    ("text_input", ""),
    ("entities", []),
    ("model_loaded", None),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
|
44 |
+
|
45 |
+
# Cached model loading
@st.cache_resource
def _load_pipeline(model_name):
    """Build the token-classification pipeline for ``model_name``.

    Exceptions are deliberately allowed to propagate: only a successfully
    built pipeline is a return value, so only that ends up in the
    ``st.cache_resource`` cache.
    """
    return pipeline(
        "token-classification",
        model=model_name,
        aggregation_strategy="simple",
    )


def load_model(model_name):
    """Return the NER pipeline for ``model_name``, or ``None`` on failure.

    Args:
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        The pipeline object, or ``None`` when it could not be created. In
        the failure case the error is also shown with ``st.error``.

    The caching lives in ``_load_pipeline`` rather than here so that a
    failed load (for example a transient download error) is retried on the
    next rerun instead of a cached ``None`` being returned for the rest of
    the process lifetime.
    """
    try:
        return _load_pipeline(model_name)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
|
58 |
+
|
59 |
+
# Highlight the recognised entities inside the text
def highlight_entities(text, entities, colors=None):
    """Return ``text`` as HTML with every entity wrapped in a ``<mark>`` tag.

    Args:
        text: The raw text that was analysed.
        entities: Iterable of dicts with ``start``, ``end``, ``entity_group``
            and ``score`` keys. ``start``/``end`` are character offsets into
            ``text`` (the unescaped string).
        colors: Optional mapping of entity label to CSS colour; its
            ``"default"`` entry is used for unknown labels. When omitted,
            the module-level ``ENTITY_COLORS`` is used.

    Returns:
        An HTML string. All text taken from ``text`` and the entity label
        are HTML-escaped; only the generated ``<mark>`` tags are markup.
    """
    if not entities:
        # Callers render the result as HTML, so escape on this path as well.
        return html.escape(text)

    palette = ENTITY_COLORS if colors is None else colors
    fallback_color = palette.get("default", "#DDDDDD")

    parts = []
    cursor = 0
    for entity in sorted(entities, key=lambda item: item["start"]):
        start, end = entity["start"], entity["end"]
        if start < cursor:
            # Overlaps a span that is already highlighted; emitting it too
            # would duplicate part of the text.
            continue

        # Slice the *raw* text and escape each piece afterwards. Escaping
        # first would lengthen the string ("&" -> "&amp;") and shift every
        # offset that follows.
        parts.append(html.escape(text[cursor:start]))

        label = entity["entity_group"]
        color = palette.get(label, fallback_color)
        score = entity["score"]
        parts.append(
            f'<mark style="background-color: {color}; padding: 2px 4px; border-radius: 3px;" '
            f'title="{html.escape(str(label))} (confidence: {score:.3f})">'
            f'{html.escape(text[start:end])}'
            f'</mark>'
        )
        cursor = end

    # Whatever follows the last entity.
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)
|
99 |
+
|
100 |
+
# Page heading
st.title("🏥 OpenMed Named Entity Recognition Demo")
st.markdown("Using domain-specific pre-trained models for medical text analysis")

# Sidebar: choose which domain model to run
st.sidebar.header("Model Selection")
domain_options = list(MODELS)
selected_domain = st.sidebar.selectbox("Select Domain", domain_options)

# Resolve the model id for the chosen domain
model_name = MODELS[selected_domain]

# When the selection differs from the model recorded in the session, drop
# the stored entities and remember the new model.
model_changed = st.session_state.model_loaded != model_name
if model_changed:
    st.session_state.entities = []
    st.session_state.model_loaded = model_name

ner_pipeline = load_model(model_name)

# Sidebar: show which model is active (only the part after the last "/")
short_model_name = model_name.rsplit('/', 1)[-1]
st.sidebar.header("Model Information")
st.sidebar.write(f"**Domain**: {selected_domain}")
st.sidebar.write(f"**Model**: {short_model_name}")
|
125 |
+
|
126 |
+
# One English sample sentence per domain, keyed by the same domain names
# that are offered in the sidebar.
example_texts = {
    "Pharmacology": (
        "The patient was prescribed aspirin and warfarin for anticoagulation therapy."
    ),
    "Oncology Genetics": (
        "BRCA1 gene mutations are associated with increased risk of breast and ovarian cancer."
    ),
    "Species Detection": (
        "Researchers tested the new drug in a mouse model and observed significant effects."
    ),
    "Chemical Detection": (
        "Glucose and oxygen molecules play key roles in cellular respiration processes."
    ),
    "Anatomy Detection": (
        "The patient reported pain in the right knee joint radiating to the thigh."
    ),
    "Blood Cancer Detection": (
        "The patient was diagnosed with chronic lymphocytic leukemia and requires regular monitoring of lymphocyte counts."
    ),
    "Disease Detection": (
        "Patients with diabetes mellitus often have increased risk of hypertension and cardiovascular disease."
    ),
}
|
136 |
+
|
137 |
+
# Main area: text input on the left, analysis results on the right.
col1, col2 = st.columns([1, 1])

# The last analysis is kept as one record (model, text, entities) so the
# results are always rendered against the exact text they were computed
# from. A record produced by a different model is discarded.
analysis = st.session_state.get("analysis")
if analysis is not None and analysis["model"] != model_name:
    analysis = None
    st.session_state.analysis = None

# The text area is driven solely through its widget key. Seed the key once
# from the stored text so the first render shows the same initial content
# that passing it as ``value=`` would.
if "text_input_widget" not in st.session_state:
    st.session_state.text_input_widget = st.session_state.get("text_input", "")

with col1:
    st.header("Text Input")

    # Load the example sentence for the selected domain. The widget's state
    # may be assigned here because the text area below has not been
    # instantiated yet in this run.
    if st.button("Load Example Text"):
        st.session_state.text_input_widget = example_texts[selected_domain]
        st.session_state.entities = []  # clear the previous results
        st.session_state.analysis = None
        analysis = None

    # Text input area
    text = st.text_area(
        "Enter text to analyze:",
        height=200,
        help="Enter medical text for analysis",
        key="text_input_widget",
    )

    # Mirror the current text into the session state
    st.session_state.text_input = text

    # Analyze button
    if st.button("Analyze Text", type="primary"):
        if not text.strip():
            st.warning("Please enter text to analyze")
        elif ner_pipeline is None:
            # load_model() returns None when the model could not be loaded;
            # calling it would only produce a confusing TypeError message.
            st.error("Error during analysis: the selected model could not be loaded")
        else:
            with st.spinner("Analyzing..."):
                try:
                    entities = ner_pipeline(text)
                except Exception as e:
                    st.error(f"Error during analysis: {str(e)}")
                else:
                    st.session_state.entities = entities
                    analysis = {
                        "model": model_name,
                        "text": text,
                        "entities": entities,
                    }
                    st.session_state.analysis = analysis
                    st.success("Analysis completed!")

with col2:
    st.header("NER Results")

    if analysis is None:
        st.info("Please enter text and click 'Analyze Text'")
    elif not analysis["entities"]:
        # An analysis ran but the model found nothing.
        st.info("No entities detected")
    else:
        entities = analysis["entities"]

        # Highlighted text, rendered from the text that was analysed (not
        # from the current, possibly edited, content of the text box).
        st.markdown("### Highlighted Text")
        highlighted_text = highlight_entities(analysis["text"], entities)
        st.markdown(highlighted_text, unsafe_allow_html=True)

        # Per-type entity counts
        st.markdown("### Entity Statistics")
        entity_counts = defaultdict(int)
        for entity in entities:
            entity_counts[entity['entity_group']] += 1

        for entity_type, count in entity_counts.items():
            color = ENTITY_COLORS.get(entity_type, ENTITY_COLORS['default'])
            st.markdown(
                f'<span style="background-color: {color}; padding: 4px 8px; '
                f'border-radius: 4px; margin-right: 8px; color: black;">'
                f'{html.escape(str(entity_type))}: {count}'
                f'</span>',
                unsafe_allow_html=True,
            )

        # Detailed entity list. The word and label are escaped because the
        # markup is rendered with unsafe_allow_html and the word originates
        # from user-supplied text.
        st.markdown("### Entity Details")
        for i, entity in enumerate(entities, start=1):
            color = ENTITY_COLORS.get(entity['entity_group'], ENTITY_COLORS['default'])
            st.markdown(
                f"{i}. **{html.escape(str(entity['word']))}** - "
                f"<span style='color: {color};'>{html.escape(str(entity['entity_group']))}</span> "
                f"(confidence: {entity['score']:.3f})",
                unsafe_allow_html=True,
            )
|
218 |
+
|
219 |
+
# Footer: horizontal rule followed by usage instructions
st.markdown("---")
st.markdown(
    "\n".join(
        [
            "### Instructions",
            "1. Select a domain-specific NER model from the left sidebar",
            "2. Enter or paste medical text in the input box",
            "3. Click the 'Analyze Text' button to run the model",
            "4. View the entity recognition results on the right",
            "",
            "Different colored highlights represent different entity types. Hover over entities to see type and confidence.",
        ]
    )
)

# CSS that hides Streamlit's default menu, footer and header
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
st.markdown(hide_st_style, unsafe_allow_html=True)
|
src/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.48.1
|
2 |
+
transformers==4.53.3
|
3 |
+
torch==2.7.1
|