Viju Sudhi commited on
Commit
c4727bc
·
1 Parent(s): e1e06da

adding application file

Browse files
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 4.17.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ python_version: 3.8
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List, Optional
3
+
4
+ import gradio as gr
5
+ from gradio.components import Markdown
6
+
7
+ from src.dto.dto import ExplanationGranularity, ExplanationDto
8
+ from src.utils.registry import EXPLAINERS, MODELS, PERTURBERS, COMPARATORS
9
+ from src.utils.segregate import PercentileBasedSegregator
10
+ from src.utils.visualizer import Visualizer
11
+
12
+
13
+ class MockExplainerUI:
14
+ def __init__(
15
+ self,
16
+ logo_path: str,
17
+ css_path: str,
18
+ visualizer: Visualizer,
19
+ window_title: str,
20
+ title: str,
21
+ examples: Optional[List[str]] = None,
22
+ ):
23
+ self.__logo_path = logo_path
24
+ self.__css_path = css_path
25
+ self.__examples = examples
26
+ self.__window_title = window_title
27
+ self.__title = title
28
+ self.__visualizer = visualizer
29
+
30
+ self.app: gr.Blocks = self.build_app()
31
+
32
+ def build_app(self):
33
+ with gr.Blocks(
34
+ theme=gr.themes.Monochrome().set(
35
+ button_primary_background_fill="#009374",
36
+ button_primary_background_fill_hover="#009374C4",
37
+ checkbox_label_background_fill_selected="#028A6EFF",
38
+ ),
39
+ css=self.__css_path,
40
+ title=self.__window_title,
41
+ ) as demo:
42
+ self.__build_app_title()
43
+ (
44
+ user_input,
45
+ system_response,
46
+ granularity,
47
+ upper_percentile,
48
+ middle_percentile,
49
+ lower_percentile,
50
+ explainer_name,
51
+ model_name,
52
+ perturber_name,
53
+ comparator_name,
54
+ generator_vis,
55
+ submit_btn,
56
+ ) = self.__build_chat_and_explain()
57
+
58
+ submit_btn.click(
59
+ fn=self.run,
60
+ inputs=[
61
+ user_input,
62
+ granularity,
63
+ upper_percentile,
64
+ middle_percentile,
65
+ lower_percentile,
66
+ explainer_name,
67
+ model_name,
68
+ perturber_name,
69
+ comparator_name,
70
+ ],
71
+ outputs=[system_response, generator_vis],
72
+ )
73
+
74
+ return demo
75
+
76
+ def run(
77
+ self,
78
+ user_input: str,
79
+ granularity: ExplanationGranularity,
80
+ upper_percentile: str,
81
+ middle_percentile: str,
82
+ lower_percentile: str,
83
+ explainer_name: str,
84
+ model_name: str,
85
+ perturber_name: str,
86
+ comparator_name: str,
87
+ ):
88
+ print(user_input)
89
+ with open(
90
+ "data/en_q_001.json",
91
+ "r",
92
+ ) as f:
93
+ data = json.load(f)
94
+ data = data[0]
95
+ explanation_dto = ExplanationDto.parse_obj(data)
96
+ user_input = explanation_dto.input_text
97
+
98
+ system_response = explanation_dto.output_text
99
+ generator_vis = self.__visualize_explanations(
100
+ user_input=user_input,
101
+ system_response=system_response,
102
+ generator_explanations=explanation_dto,
103
+ upper_percentile=int(upper_percentile),
104
+ middle_percentile=int(middle_percentile),
105
+ lower_percentile=int(lower_percentile),
106
+ )
107
+ return system_response, generator_vis
108
+
109
+ def __build_app_title(self):
110
+ with gr.Row():
111
+ with gr.Column(min_width=50, scale=1):
112
+ gr.Image(
113
+ value=self.__logo_path,
114
+ width=50,
115
+ height=50,
116
+ show_download_button=False,
117
+ container=False,
118
+ )
119
+ with gr.Column(scale=2):
120
+ Markdown(
121
+ f'<p style="text-align: left; font-size:200%; font-weight: bold"'
122
+ f">{self.__title}"
123
+ f"</p>"
124
+ )
125
+
126
+ def __build_chat_and_explain(self):
127
+ with gr.Row():
128
+ with gr.Column(scale=2):
129
+ gr.Textbox(
130
+ label="Attention!",
131
+ value="This is a demo version of the tool! For running the full version, please follow the instructions in ...",
132
+ container=False,
133
+ interactive=False,
134
+ )
135
+
136
+ with gr.Row():
137
+ with gr.Column(scale=2):
138
+ user_input = gr.Radio(
139
+ # placeholder="Type your question here and press Enter.",
140
+ label="Question",
141
+ container=True,
142
+ choices=["Question 1 EN", "Question 1 DE"],
143
+ )
144
+ with gr.Column(scale=1):
145
+ granularity = gr.Radio(
146
+ choices=[e for e in ExplanationGranularity],
147
+ value=ExplanationGranularity.SENTENCE_LEVEL,
148
+ label="Explanation Granularity",
149
+ )
150
+
151
+ with gr.Accordion(label="Settings", open=False, elem_id="accordion"):
152
+ with gr.Row(variant="compact"):
153
+ explainer_name = gr.Radio(
154
+ label="Explainer",
155
+ choices=list(EXPLAINERS.keys()),
156
+ value=list(EXPLAINERS.keys())[0],
157
+ container=True,
158
+ )
159
+ with gr.Row(variant="compact"):
160
+ upper_percentile = gr.Textbox(label="Upper", value="85", container=True)
161
+ middle_percentile = gr.Textbox(
162
+ label="Middle", value="75", container=True
163
+ )
164
+ lower_percentile = gr.Textbox(label="Lower", value="10", container=True)
165
+
166
+ with gr.Row(variant="compact"):
167
+ model_name = gr.Radio(
168
+ label="Model",
169
+ choices=list(MODELS.keys()),
170
+ value=list(MODELS.keys())[0],
171
+ container=True,
172
+ )
173
+ with gr.Row(variant="compact"):
174
+ perturber_name = gr.Radio(
175
+ label="Perturber",
176
+ choices=list(PERTURBERS.keys()),
177
+ value=list(PERTURBERS.keys())[0],
178
+ container=True,
179
+ )
180
+ with gr.Row(variant="compact"):
181
+ comparator_name = gr.Radio(
182
+ label="Comparator",
183
+ choices=list(COMPARATORS.keys()),
184
+ value=list(COMPARATORS.keys())[0],
185
+ container=True,
186
+ )
187
+ with gr.Row(variant="compact"):
188
+ # passing "elem_id" to use a custom style for the component
189
+ # in the CSS passed.
190
+ submit_btn = gr.Button(
191
+ value="🛠 Submit",
192
+ variant="secondary",
193
+ elem_id="button",
194
+ interactive=True,
195
+ )
196
+
197
+ with gr.Row():
198
+ generator_vis = gr.HTML(label="Explanations")
199
+
200
+ with gr.Row():
201
+ system_response = gr.Textbox(
202
+ label="System Response",
203
+ container=True,
204
+ interactive=False,
205
+ )
206
+
207
+ return (
208
+ user_input,
209
+ system_response,
210
+ granularity,
211
+ upper_percentile,
212
+ middle_percentile,
213
+ lower_percentile,
214
+ explainer_name,
215
+ model_name,
216
+ perturber_name,
217
+ comparator_name,
218
+ generator_vis,
219
+ submit_btn,
220
+ )
221
+
222
+ def __visualize_explanations(
223
+ self,
224
+ user_input: str,
225
+ system_response: Optional[str],
226
+ generator_explanations: ExplanationDto,
227
+ upper_percentile: Optional[int],
228
+ middle_percentile: Optional[int],
229
+ lower_percentile: Optional[int],
230
+ ) -> str:
231
+ segregator = PercentileBasedSegregator(
232
+ upper_bound_percentile=upper_percentile,
233
+ middle_bound_percentile=middle_percentile,
234
+ lower_bound_percentile=lower_percentile,
235
+ )
236
+ return self.__visualizer.visualize(
237
+ segregator=segregator,
238
+ explanations=generator_explanations,
239
+ output_from_explanations=user_input,
240
+ )
data/en_q_001.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==4.17.0
2
+ pydantic~=1.8.2
3
+ numpy~=1.22.4
src/__init__.py ADDED
File without changes
src/dto/__init__.py ADDED
File without changes
src/dto/dto.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import List, Optional
3
+
4
+ from pydantic import BaseModel
5
+
6
+
7
+ class FeatureImportance(BaseModel):
8
+ feature: str
9
+ score: float
10
+ token_field: Optional[str] = None
11
+
12
+
13
+ class ExplanationDto(BaseModel):
14
+ explanations: List[FeatureImportance]
15
+ input_text: str
16
+ output_text: str
17
+
18
+
19
+ class ExplanationGranularity(str, Enum):
20
+ WORD_LEVEL = "word_level_granularity"
21
+ SENTENCE_LEVEL = "sentence_level_granularity"
22
+ PARAGRAPH_LEVEL = "paragraph_level_granularity"
23
+ PHRASE_LEVEL = "phrase_level_granularity"
24
+
25
+
26
+ class SimilarityMetric(Enum):
27
+ COSINE = "cosine"
src/utils/__init__.py ADDED
File without changes
src/utils/registry.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ COMPARATORS = {
2
+ "base_llm_based_comparator": ...,
3
+ "sentence_transformers_based_comparator": ...,
4
+ "levenshtein_comparator": ...,
5
+ "jaro_winkler_comparator": ...,
6
+ "n_gram_comparator": ...,
7
+ }
8
+
9
+ EXPLAINERS = {
10
+ "generic_explainer": "generic_explainer",
11
+ }
12
+
13
+ MODELS = {"flan-t5-xxl": "flan-t5-xxl"}
14
+
15
+ PERTURBERS = {
16
+ "leave_one_out": ...,
17
+ "random_word_perturber": ...,
18
+ "reorder_perturber": ...,
19
+ "antonym_perturber": ...,
20
+ "synonym_perturber": ...,
21
+ "entity_perturber": ...,
22
+ }
src/utils/segregate.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple, List
3
+
4
+ import numpy as np
5
+
6
+ from src.dto.dto import ExplanationDto
7
+
8
+
9
+ class Segregator(ABC):
10
+ @abstractmethod
11
+ def segregate(
12
+ self, explanations: ExplanationDto
13
+ ) -> Tuple[List[str], List[str], List[str]]:
14
+ ...
15
+
16
+
17
+ class PercentileBasedSegregator(Segregator):
18
+ def __init__(
19
+ self,
20
+ upper_bound_percentile: int = 85,
21
+ middle_bound_percentile: int = 75,
22
+ lower_bound_percentile: int = 10,
23
+ ):
24
+ self.__upper_bound_percentile = upper_bound_percentile
25
+ self.__middle_bound_percentile = middle_bound_percentile
26
+ self.__lower_bound_percentile = lower_bound_percentile
27
+
28
+ def segregate(
29
+ self,
30
+ explanations: ExplanationDto,
31
+ ) -> Tuple[List[str], List[str], List[str]]:
32
+ scores = [explanation.score for explanation in explanations.explanations]
33
+ scores = np.asarray(scores)
34
+ upper_bound = np.percentile(scores, self.__upper_bound_percentile)
35
+ mid_bound = np.percentile(scores, self.__middle_bound_percentile)
36
+ lower_bound = np.percentile(scores, self.__lower_bound_percentile)
37
+
38
+ pos_features = [
39
+ explanation.feature
40
+ for explanation in explanations.explanations
41
+ if explanation.score >= upper_bound and explanation.score != 0
42
+ ]
43
+ mid_features = [
44
+ explanation.feature
45
+ for explanation in explanations.explanations
46
+ if upper_bound > explanation.score >= mid_bound > 0
47
+ ]
48
+ low_features = [
49
+ explanation.feature
50
+ for explanation in explanations.explanations
51
+ if mid_bound > explanation.score >= lower_bound > 0
52
+ ]
53
+
54
+ return pos_features, mid_features, low_features
src/utils/visualizer.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from xlm.demo.utils.segregate import Segregator
2
+ from xlm.dto.dto import ExplanationDto
3
+
4
+
5
+ UPPER_COLOR = "#D4EFDF" # green
6
+ MID_COLOR = "#FBFBB8BF" # amber
7
+ LOW_COLOR = "black"
8
+
9
+
10
+ class Visualizer:
11
+ def __init__(self, show_mid_features: bool = True, show_low_features: bool = True):
12
+ self.__show_mid_features = show_mid_features
13
+ self.__show_low_features = show_low_features
14
+
15
+ def visualize(
16
+ self,
17
+ segregator: Segregator,
18
+ explanations: ExplanationDto,
19
+ output_from_explanations: str,
20
+ avoid_exp_label: bool = False,
21
+ ) -> str:
22
+ highlighted_text = output_from_explanations
23
+
24
+ pos_features, mid_features, low_features = segregator.segregate(
25
+ explanations=explanations
26
+ )
27
+
28
+ if not self.__show_mid_features:
29
+ mid_features = []
30
+
31
+ if not self.__show_low_features:
32
+ low_features = []
33
+
34
+ for explanation in explanations.explanations:
35
+ score = round(explanation.score, 2)
36
+
37
+ if explanation.feature in pos_features:
38
+ token_str = (
39
+ '<span title="'
40
+ + str(score)
41
+ + '"style="font-weight:bold;background-color:'
42
+ + UPPER_COLOR
43
+ + '">'
44
+ + explanation.feature
45
+ + "</span>"
46
+ )
47
+ elif explanation.feature in mid_features:
48
+ token_str = (
49
+ '<span title="'
50
+ + str(score)
51
+ + '"style="font-weight:bold;background-color:'
52
+ + MID_COLOR
53
+ + '">'
54
+ + explanation.feature
55
+ + "</span>"
56
+ )
57
+ else:
58
+ token_str = (
59
+ '<span title="'
60
+ + str(score)
61
+ + '"style="color:'
62
+ + LOW_COLOR
63
+ + '">'
64
+ + explanation.feature
65
+ + "</span>"
66
+ )
67
+
68
+ highlighted_text = highlighted_text.replace(explanation.feature, token_str)
69
+
70
+ if avoid_exp_label:
71
+ vis = "<p>" + highlighted_text + "</p>"
72
+ else:
73
+ vis = "<p><b>Explanations:</b><br>" + highlighted_text + "</p>"
74
+ vis = vis.replace("\n", "<br>")
75
+
76
+ legend = "<p align='right'"
77
+
78
+ legend += (
79
+ '<span title="' + '"style="color:' + LOW_COLOR + '">' + "💡" + "</span>"
80
+ )
81
+
82
+ legend += "&emsp;"
83
+
84
+ legend += (
85
+ '<span title="'
86
+ + '"style="color:'
87
+ + LOW_COLOR
88
+ + '">'
89
+ + "not important"
90
+ + "</span>"
91
+ )
92
+
93
+ legend += "&emsp;⇢&emsp;"
94
+
95
+ legend += (
96
+ '<span title="'
97
+ + '"style="font-weight:bold;background-color:'
98
+ + MID_COLOR
99
+ + '">'
100
+ + " important "
101
+ + "</span>"
102
+ )
103
+
104
+ legend += "&emsp;⇢&emsp;"
105
+
106
+ legend += (
107
+ '<span title="'
108
+ + '"style="font-weight:bold;background-color:'
109
+ + UPPER_COLOR
110
+ + '">'
111
+ + " very important "
112
+ + "</span>"
113
+ )
114
+
115
+ legend += "</p>"
116
+
117
+ html_str = legend + vis
118
+
119
+ return html_str