renyi211 committed on
Commit
a0c6b45
·
verified ·
1 Parent(s): 0e1b31d

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/app.py +238 -0
  2. src/requirements.txt +3 -0
src/app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ import html
4
+ from collections import defaultdict
5
+
6
# Page setup: browser tab title, tab icon, and a full-width layout.
st.set_page_config(page_title="OpenMed NER Demo", page_icon="🏥", layout="wide")
12
+
13
# Model mapping: domain label shown in the sidebar -> model identifier that
# is handed to the transformers pipeline.
MODELS: dict[str, str] = {
    "Pharmacology": "OpenMed/OpenMed-NER-PharmaDetect-SuperClinical-434M",
    "Oncology Genetics": "OpenMed/OpenMed-NER-OncologyDetect-SuperMedical-355M",
    "Species Detection": "OpenMed/OpenMed-NER-SpeciesDetect-PubMed-335M",
    "Chemical Detection": "OpenMed/OpenMed-NER-ChemicalDetect-PubMed-335M",
    "Anatomy Detection": "OpenMed/OpenMed-NER-AnatomyDetect-PubMed-335M",
    "Blood Cancer Detection": "OpenMed/OpenMed-NER-BloodCancerDetect-TinyMed-82M",
    "Disease Detection": "OpenMed/OpenMed-NER-DiseaseDetect-SuperClinical-434M",
}
23
+
24
# Entity-type colour mapping: entity group label -> CSS background colour.
# The "default" entry is the fallback for labels that are not listed here.
ENTITY_COLORS: dict[str, str] = {
    "DRUG": "#FF9999",      # drug - light red
    "CHEMICAL": "#FFCC99",  # chemical - light orange
    "DISEASE": "#FF99CC",   # disease - light pink
    "ANATOMY": "#99CCFF",   # anatomy - light blue
    "SPECIES": "#99FF99",   # species - light green
    "GENE": "#CC99FF",      # gene - light purple
    "PROTEIN": "#FFFF99",   # protein - light yellow
    "CELL": "#99FFFF",      # cell - light cyan
    "default": "#DDDDDD",   # fallback - light grey
}
36
+
37
# Initialise session state: each key is created only if it does not exist
# yet, so values already stored in the session are left untouched.
# The tuple is rebuilt on every script run, so the empty list is always fresh.
for state_key, initial_value in (
    ("text_input", ""),
    ("entities", []),
    ("model_loaded", None),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
44
+
45
# Cached model loading.
@st.cache_resource
def load_model(model_name):
    """Build the token-classification pipeline for ``model_name``.

    The pipeline is created with ``aggregation_strategy="simple"``.

    Args:
        model_name: Model identifier passed to ``pipeline`` as ``model``.

    Returns:
        The pipeline object, or ``None`` when construction raised; in that
        case the error text is shown to the user through ``st.error``.
    """
    # NOTE(review): this function is wrapped in st.cache_resource, so the
    # None returned after a failure is presumably cached as well - confirm
    # whether a retry path is wanted.
    try:
        loaded_pipeline = pipeline(
            "token-classification",
            model=model_name,
            aggregation_strategy="simple",
        )
    except Exception as exc:
        st.error(f"Error loading model: {str(exc)}")
        return None
    else:
        return loaded_pipeline
58
+
59
# Highlight the entities found in a text.
def highlight_entities(text, entities, colors=None):
    """Return ``text`` as HTML with every entity wrapped in a ``<mark>`` tag.

    Entity offsets index into the *original* ``text``. Each piece is therefore
    sliced from the raw text first and HTML-escaped afterwards; escaping the
    whole text up front would shift the offsets as soon as the text contains
    ``&``, ``<``, ``>`` or quotes.

    Args:
        text: The raw text the entities refer to.
        entities: Iterable of dicts with the keys ``start``, ``end``,
            ``entity_group`` and ``score``.
        colors: Optional mapping from entity group to CSS colour. When
            omitted, the module-level ``ENTITY_COLORS`` is used. A
            ``"default"`` entry is used for unknown groups.

    Returns:
        An HTML-safe string. Without entities this is just the escaped text.

    Raises:
        KeyError: If an entity dict lacks one of the required keys.
    """
    if not entities:
        # The result is rendered as HTML, so plain text must be escaped too.
        return html.escape(text)

    palette = ENTITY_COLORS if colors is None else colors
    fallback_color = palette.get("default", "#DDDDDD")

    pieces = []
    cursor = 0
    for entity in sorted(entities, key=lambda item: item["start"]):
        start, end = entity["start"], entity["end"]
        if start < cursor:
            # This span overlaps one that was already emitted; rendering it
            # would duplicate part of the text, so it is skipped.
            continue

        # Plain text between the previous entity and this one.
        pieces.append(html.escape(text[cursor:start]))

        label = entity["entity_group"]
        color = palette.get(label, fallback_color)
        pieces.append(
            f'<mark style="background-color: {color}; padding: 2px 4px; border-radius: 3px;" '
            f'title="{html.escape(str(label))} (confidence: {entity["score"]:.3f})">'
            f'{html.escape(text[start:end])}'
            f'</mark>'
        )
        cursor = end

    # Whatever follows the last entity.
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
99
+
100
# Application title and tagline.
st.title("🏥 OpenMed Named Entity Recognition Demo")
st.markdown("Using domain-specific pre-trained models for medical text analysis")

# Sidebar: domain / model selection.
st.sidebar.header("Model Selection")
domain_choices = list(MODELS)
selected_domain = st.sidebar.selectbox("Select Domain", domain_choices)

# Resolve the model identifier for the chosen domain.
model_name = MODELS[selected_domain]

# When the model changes, drop the entities produced by the previous one.
model_changed = st.session_state.model_loaded != model_name
if model_changed:
    st.session_state.model_loaded = model_name
    st.session_state.entities = []

ner_pipeline = load_model(model_name)

# Sidebar: show which model is active (only the part after the last "/").
st.sidebar.header("Model Information")
st.sidebar.write(f"**Domain**: {selected_domain}")
short_model_name = model_name.rsplit("/", 1)[-1]
st.sidebar.write(f"**Model**: {short_model_name}")
125
+
126
# Example texts (English), keyed by the same domain labels as MODELS.
example_texts: dict[str, str] = {
    "Pharmacology": "The patient was prescribed aspirin and warfarin for anticoagulation therapy.",
    "Oncology Genetics": "BRCA1 gene mutations are associated with increased risk of breast and ovarian cancer.",
    "Species Detection": "Researchers tested the new drug in a mouse model and observed significant effects.",
    "Chemical Detection": "Glucose and oxygen molecules play key roles in cellular respiration processes.",
    "Anatomy Detection": "The patient reported pain in the right knee joint radiating to the thigh.",
    "Blood Cancer Detection": "The patient was diagnosed with chronic lymphocytic leukemia and requires regular monitoring of lymphocyte counts.",
    "Disease Detection": "Patients with diabetes mellitus often have increased risk of hypertension and cardiovascular disease.",
}
136
+
137
# Main area: text input on the left, NER results on the right.
col1, col2 = st.columns([1, 1])

# The text area is driven purely through its widget key. Seed that key once
# from the stored text so the first render shows it.
if "text_input_widget" not in st.session_state:
    st.session_state["text_input_widget"] = st.session_state.text_input

# Results only make sense for the model that produced them; after a model
# switch there is no valid analysis to display.
if st.session_state.get("analyzed_model") != model_name:
    st.session_state["analyzed_text"] = None

with col1:
    st.header("Text Input")

    # Example text button.
    if st.button("Load Example Text"):
        # Writing to the widget key before the widget is created is how the
        # text area content is replaced programmatically.
        st.session_state["text_input_widget"] = example_texts[selected_domain]
        st.session_state.entities = []  # discard earlier results
        st.session_state["analyzed_text"] = None

    # Text input area.
    text = st.text_area(
        "Enter text to analyze:",
        height=200,
        help="Enter medical text for analysis",
        key="text_input_widget",
    )

    # Mirror the current text into session state.
    st.session_state.text_input = text

    # Analyze button.
    if st.button("Analyze Text", type="primary"):
        if not text.strip():
            st.warning("Please enter text to analyze")
        elif ner_pipeline is None:
            # load_model returns None when the model could not be built.
            st.error("The selected model could not be loaded, so the text cannot be analyzed.")
        else:
            with st.spinner("Analyzing..."):
                try:
                    entities = ner_pipeline(text)
                except Exception as e:
                    st.error(f"Error during analysis: {str(e)}")
                else:
                    st.session_state.entities = entities
                    # Remember exactly what was analysed: the entity offsets
                    # refer to this text, not to later edits in the text area.
                    st.session_state["analyzed_text"] = text
                    st.session_state["analyzed_model"] = model_name
                    st.success("Analysis completed!")

with col2:
    st.header("NER Results")

    analyzed_text = st.session_state.get("analyzed_text")
    entities = st.session_state.entities

    if analyzed_text is None:
        st.info("Please enter text and click 'Analyze Text'")
    elif not entities:
        st.info("No entities detected")
    else:
        # Highlighted text, rendered from the text that was analysed.
        st.markdown("### Highlighted Text")
        highlighted_text = highlight_entities(analyzed_text, entities)
        st.markdown(highlighted_text, unsafe_allow_html=True)

        # Entity statistics: number of entities per entity group.
        st.markdown("### Entity Statistics")
        entity_counts = defaultdict(int)
        for entity in entities:
            entity_counts[entity['entity_group']] += 1

        for entity_type, count in entity_counts.items():
            color = ENTITY_COLORS.get(entity_type, ENTITY_COLORS['default'])
            st.markdown(
                f'<span style="background-color: {color}; padding: 4px 8px; '
                f'border-radius: 4px; margin-right: 8px; color: black;">'
                f'{html.escape(str(entity_type))}: {count}'
                f'</span>',
                unsafe_allow_html=True
            )

        # Detailed entity list. The word and label are rendered as raw HTML,
        # so both are escaped first.
        st.markdown("### Entity Details")
        for i, entity in enumerate(entities):
            color = ENTITY_COLORS.get(entity['entity_group'], ENTITY_COLORS['default'])
            st.markdown(
                f"{i+1}. **{html.escape(str(entity['word']))}** - "
                f"<span style='color: {color};'>{html.escape(str(entity['entity_group']))}</span> "
                f"(confidence: {entity['score']:.3f})",
                unsafe_allow_html=True
            )
218
+
219
# Footer: a divider followed by usage instructions.
st.markdown("---")
instruction_lines = [
    "### Instructions",
    "1. Select a domain-specific NER model from the left sidebar",
    "2. Enter or paste medical text in the input box",
    "3. Click the 'Analyze Text' button to run the model",
    "4. View the entity recognition results on the right",
    "",
    "Different colored highlights represent different entity types. Hover over entities to see type and confidence.",
]
st.markdown("\n".join(instruction_lines))

# Hide Streamlit's default chrome (main menu, footer, header) via CSS.
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
st.markdown(hide_st_style, unsafe_allow_html=True)
src/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit==1.48.1
2
+ transformers==4.53.3
3
+ torch==2.7.1