Spaces:
Runtime error
Runtime error
File size: 7,997 Bytes
2e4274a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
# ###########################################################################
#
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
# (C) Cloudera, Inc. 2022
# All rights reserved.
#
# Applicable Open Source License: Apache 2.0
#
# NOTE: Cloudera open source products are modular software products
# made up of hundreds of individual components, each of which was
# individually copyrighted. Each Cloudera open source product is a
# collective work under U.S. Copyright Law. Your license to use the
# collective work is as provided in your written agreement with
# Cloudera. Used apart from the collective work, this file is
# licensed for your use pursuant to the open source license
# identified above.
#
# This code is provided to you pursuant a written agreement with
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
# this code. If you do not have a written agreement with Cloudera nor
# with an authorized and properly licensed third party, you do not
# have any rights to access nor to use this code.
#
# Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
# contrary, (A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
# ARISING FROM OR RELATED TO THE CODE; AND (D) WITH RESPECT TO YOUR EXERCISE
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#
# ###########################################################################
from typing import List
import tokenizers
import streamlit as st
from src.style_transfer import StyleTransfer
from src.style_classification import StyleIntensityClassifier
from src.content_preservation import ContentPreservationScorer
from src.transformer_interpretability import InterpretTransformer
from apps.data_utils import StyleAttributeData, string_to_list_string
# CALLBACKS
def increment_page_progress():
    """Advance `st.session_state.page_progress` by one step.

    Registered as the `on_click` callback of the enabled button created by
    `DisableableButton.create_enabled_button`.
    """
    current_step = st.session_state.page_progress
    st.session_state.page_progress = current_step + 1
def reset_page_progress_state():
    """Discard any stored style-transfer result and rewind progress to step 1.

    Previously this unconditionally ran `del st.session_state.st_result`,
    which raises when no result has been stored yet. Deleting only when the
    key is present makes the reset callback safe to trigger at any time.
    """
    if "st_result" in st.session_state:
        del st.session_state.st_result
    st.session_state.page_progress = 1
# UTILITY CLASSES
class DisableableButton:
    """
    A Streamlit button that can be replaced by a greyed-out copy of itself.

    `create_enabled_button` reserves a single `st.empty()` placeholder and
    renders a clickable button in it; clicking runs `increment_page_progress`.
    `disable` renders a non-clickable button with the same label into that
    same placeholder. `disable` reuses `self.ph`, so `create_enabled_button`
    must have been called first.

    Widget keys embed `button_number` ("ph<N>_before" for the enabled button,
    "ph<N>_after" for the disabled one), so instances with different numbers
    get distinct keys.
    """

    def __init__(self, button_number, button_text):
        # Nothing is rendered here; only the label and key number are stored.
        self.button_text = button_text
        self.button_number = button_number

    def _widget_key(self, suffix):
        """Return this button's widget key for the given state suffix."""
        return f"ph{self.button_number}_{suffix}"

    def _init_placeholder_container(self):
        # Single placeholder slot that both button variants are drawn into.
        self.ph = st.empty()

    def create_enabled_button(self):
        self._init_placeholder_container()
        self.ph.button(
            self.button_text,
            key=self._widget_key("before"),
            disabled=False,
            on_click=increment_page_progress,
        )

    def disable(self):
        self.ph.button(
            self.button_text,
            disabled=True,
            key=self._widget_key("after"),
        )
# CACHED FUNCTIONS
@st.cache(
    hash_funcs={tokenizers.Tokenizer: lambda _: None},
    allow_output_mutation=True,
    show_spinner=False,
)
def get_cached_style_intensity_classifier(
    style_data: StyleAttributeData,
) -> StyleIntensityClassifier:
    """
    Load a StyleIntensityClassifier for `style_data` and cache it via Streamlit.

    The classifier is built from `style_data.cls_model_path`. Afterwards the
    model config's `id2label` / `label2id` entries are created or replaced so
    that id 0 maps to the capitalized source attribute and id 1 to the
    capitalized target attribute.

    Args:
        style_data (StyleAttributeData)

    Returns:
        StyleIntensityClassifier
    """
    classifier = StyleIntensityClassifier(style_data.cls_model_path)

    labels = [
        style_data.source_attribute.capitalize(),
        style_data.target_attribute.capitalize(),
    ]
    id2label = dict(enumerate(labels))

    # Written directly into the config's instance __dict__, so the entries are
    # created if absent and overwritten otherwise.
    config_attrs = classifier.pipeline.model.config.__dict__
    config_attrs["id2label"] = id2label
    config_attrs["label2id"] = {label: idx for idx, label in id2label.items()}
    return classifier
@st.cache(
    hash_funcs={tokenizers.Tokenizer: lambda _: None},
    allow_output_mutation=True,
    show_spinner=False,
)
def get_cached_word_attributions(
    text_sample: str, style_data: StyleAttributeData
) -> str:
    """
    Compute word attributions for `text_sample` and return them as HTML.

    An InterpretTransformer is built from `style_data.cls_model_path`. Its
    explainer's `id2label` / `label2id` attributes are then overwritten so
    that id 0 is the capitalized source attribute and id 1 the capitalized
    target attribute. The returned string is the `.data` attribute of the
    object produced by `visualize_feature_attribution_scores`.

    Args:
        text_sample (str)
        style_data (StyleAttributeData)

    Returns:
        str
    """
    interpreter = InterpretTransformer(
        cls_model_identifier=style_data.cls_model_path
    )
    explainer = interpreter.explainer

    labels = [
        style_data.source_attribute.capitalize(),
        style_data.target_attribute.capitalize(),
    ]
    explainer.id2label = dict(enumerate(labels))
    explainer.label2id = {
        label: idx for idx, label in explainer.id2label.items()
    }

    visualization = interpreter.visualize_feature_attribution_scores(text_sample)
    return visualization.data
@st.cache(
    hash_funcs={tokenizers.Tokenizer: lambda _: None},
    allow_output_mutation=True,
    show_spinner=False,
)
def get_sti_metric(
    input_text: str, output_text: str, style_data: StyleAttributeData
) -> List[float]:
    """
    Calculate the Style Transfer Intensity (STI) from `input_text` to `output_text`.

    A StyleIntensityClassifier is loaded from `style_data.cls_model_path`.
    Both texts are wrapped with `string_to_list_string` and passed to its
    `calculate_transfer_intensity_fraction` method.

    Args:
        input_text (str)
        output_text (str)
        style_data (StyleAttributeData)

    Returns:
        List[float]
    """
    classifier = StyleIntensityClassifier(
        model_identifier=style_data.cls_model_path,
    )
    source_batch = string_to_list_string(input_text)
    target_batch = string_to_list_string(output_text)
    return classifier.calculate_transfer_intensity_fraction(
        source_batch, target_batch
    )
@st.cache(
    hash_funcs={tokenizers.Tokenizer: lambda _: None},
    allow_output_mutation=True,
    show_spinner=False,
)
def get_cps_metric(
    input_text: str, output_text: str, style_data: StyleAttributeData
) -> List[float]:
    """
    Calculate the Content Preservation Score (CPS) between `input_text` and `output_text`.

    A ContentPreservationScorer is built from the classifier and SBERT model
    paths in `style_data`. Both texts are wrapped with `string_to_list_string`
    and scored with `mask_type="none"`.

    Args:
        input_text (str)
        output_text (str)
        style_data (StyleAttributeData)

    Returns:
        List[float]
    """
    scorer = ContentPreservationScorer(
        sbert_model_identifier=style_data.sbert_model_path,
        cls_model_identifier=style_data.cls_model_path,
    )
    source_batch = string_to_list_string(input_text)
    target_batch = string_to_list_string(output_text)
    return scorer.calculate_content_preservation_score(
        source_batch, target_batch, mask_type="none"
    )
def generate_style_transfer(
    text_sample: str,
    style_data: StyleAttributeData,
    max_gen_length: int,
    num_beams: int,
    temperature: int,
):
    """
    Run the seq2seq style-transfer model and store the result in session state.

    A `StyleTransfer` is constructed from `style_data.seq2seq_model_path` with
    the given generation settings. The output of its `transfer(text_sample)`
    call is saved to `st.session_state.st_result`. A spinner is shown while
    this runs. Nothing is returned.

    Args:
        text_sample (str): Text whose style should be transferred.
        style_data (StyleAttributeData): Supplies `seq2seq_model_path`.
        max_gen_length (int): Forwarded to `StyleTransfer` as `max_gen_length`.
        num_beams (int): Forwarded to `StyleTransfer` as `num_beams`.
        temperature (int): Forwarded to `StyleTransfer` as `temperature`.
            NOTE(review): annotated as int; confirm whether a float is intended.
    """
    with st.spinner("Transferring style, hang tight!"):
        transfer_model = StyleTransfer(
            model_identifier=style_data.seq2seq_model_path,
            max_gen_length=max_gen_length,
            num_beams=num_beams,
            temperature=temperature,
        )
        st.session_state.st_result = transfer_model.transfer(text_sample)
|