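"""Gradio Space that converts a 🤗 Transformers model checkpoint to Core ML,
using the Hugging Face `exporters` package under the hood."""
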
import gradio as gr
import json

from pathlib import Path

from huggingface_hub import hf_hub_download, HfApi
from coremltools import ComputeUnit
from transformers.onnx.utils import get_preprocessor

from exporters.coreml import export
from exporters.coreml.features import FeaturesManager
from exporters.coreml.validate import validate_model_outputs
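
# UI label → value mappings for the conversion options below. "NE" stands for
# the Apple Neural Engine; the tolerance entries are absolute tolerances used
# when validating the converted model's outputs.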
compute_units_mapping = {
    "All": ComputeUnit.ALL,
    "CPU": ComputeUnit.CPU_ONLY,
    "CPU + GPU": ComputeUnit.CPU_AND_GPU,
    "CPU + NE": ComputeUnit.CPU_AND_NE,
}
compute_units_labels = list(compute_units_mapping.keys())

framework_mapping = {
    "PyTorch": "pt",
    "TensorFlow": "tf",
}
framework_labels = list(framework_mapping.keys())

precision_mapping = {
    "Float32": "float32",
    "Float16 quantization": "float16",
}
precision_labels = list(precision_mapping.keys())

tolerance_mapping = {
    "Model default": None,
    "1e-2": 1e-2,
    "1e-3": 1e-3,
    "1e-4": 1e-4,
}
tolerance_labels = list(tolerance_mapping.keys())

def error_str(error, title="Error"):
    return f"""#### {title}
{error}""" if error else ""

def url_to_model_id(model_id_str):
    if not model_id_str.startswith("https://huggingface.co/"):
        return model_id_str
    return "/".join(model_id_str.split("/")[-2:])
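
# e.g. url_to_model_id("https://huggingface.co/apple/mobilevit-small") == "apple/mobilevit-small"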

def supported_frameworks(model_id):
    """
    Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
    Only PyTorch and TensorFlow are supported.
    """
    api = HfApi()
    model_info = api.model_info(model_id)
    tags = model_info.tags
    frameworks = [tag for tag in tags if tag in ["pytorch", "tf"]]
    return sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks])

def on_model_change(model):
    model = url_to_model_id(model)
    tasks = None
    error = None
    model_type = None
    frameworks = []
    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")
        with open(config_file, "r") as f:
            config_json = f.read()
        config = json.loads(config_json)
        model_type = config["model_type"]
        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())
        frameworks = supported_frameworks(model)
    except Exception as e:
        # Keep the settings hidden and surface the error in the UI.
        error = e
        model_type = None
    selected_framework = frameworks[0] if frameworks else None
    return (
        gr.update(visible=bool(model_type)),  # Settings column
        gr.update(choices=tasks, value=tasks[0] if tasks else None),  # Tasks
        gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),  # Frameworks
        gr.update(value=error_str(error)),  # Error
    )

def convert_model(preprocessor, model, model_coreml_config, compute_units, precision, tolerance, output, use_past=False, seq2seq=None):
    coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)
    mlmodel = export(
        preprocessor,
        model,
        coreml_config,
        quantize=precision,
        compute_units=compute_units,
    )

    filename = output
    if seq2seq == "encoder":
        filename = filename.parent / ("encoder_" + filename.name)
    elif seq2seq == "decoder":
        filename = filename.parent / ("decoder_" + filename.name)
    filename = filename.as_posix()
    mlmodel.save(filename)

    if tolerance is None:
        tolerance = coreml_config.atol_for_validation
    # Check that the converted model's outputs match the original's within `tolerance`.
    validate_model_outputs(coreml_config, preprocessor, model, mlmodel, tolerance)

def convert(model, task, compute_units, precision, tolerance, framework):
    model = url_to_model_id(model)
    compute_units = compute_units_mapping[compute_units]
    precision = precision_mapping[precision]
    tolerance = tolerance_mapping[tolerance]
    framework = framework_mapping[framework]

    # TODO: support legacy format
    output = Path("exported") / model / "coreml" / task
    output.mkdir(parents=True, exist_ok=True)
    output = output / f"{precision}_model.mlpackage"

    try:
        preprocessor = get_preprocessor(model)
        model = FeaturesManager.get_model_from_feature(task, model, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(model, feature=task)

        if task in ["seq2seq-lm", "speech-seq2seq"]:
            # Convert the encoder and the decoder as separate Core ML models.
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq="encoder",
            )
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq="decoder",
            )
        else:
            convert_model(
                preprocessor,
                model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
            )
        # TODO: push the converted model to the Hub
        return "Done"
    except Exception as e:
        return error_str(e)
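
# Example: converting distilbert-base-uncased for the "default" task at Float32
# writes exported/distilbert-base-uncased/coreml/default/float32_model.mlpackage.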

DESCRIPTION = """
## Convert a transformers model to Core ML

With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.

Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).

After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1. Load model info")
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name or URL, such as apple/mobilevit-small",
                placeholder="distilbert-base-uncased",
                value="distilbert-base-uncased",
            )
            btn_get_tasks = gr.Button("Load")
        with gr.Column(scale=3):
            with gr.Column(visible=False) as group_settings:
                gr.Markdown("## 2. Select Task")
                radio_tasks = gr.Radio(label="Choose the task for the converted model.")
                gr.Markdown("The `default` task is suitable for feature extraction.")
                radio_framework = gr.Radio(
                    visible=False,
                    label="Framework",
                    choices=framework_labels,
                    value=framework_labels[0],
                )
                radio_compute = gr.Radio(
                    label="Compute Units",
                    choices=compute_units_labels,
                    value=compute_units_labels[0],
                )
                radio_precision = gr.Radio(
                    label="Precision",
                    choices=precision_labels,
                    value=precision_labels[0],
                )
                radio_tolerance = gr.Radio(
                    label="Absolute Tolerance for Validation",
                    choices=tolerance_labels,
                    value=tolerance_labels[0],
                )
                btn_convert = gr.Button("Convert")
                gr.Markdown("Conversion will take a few minutes.")

    error_output = gr.Markdown(label="Output")

    btn_get_tasks.click(
        fn=on_model_change,
        inputs=input_model,
        outputs=[group_settings, radio_tasks, radio_framework, error_output],
        queue=False,
        scroll_to_output=True,
    )

    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
        outputs=error_output,
        scroll_to_output=True,
    )

demo.queue(concurrency_count=1, max_size=10)
demo.launch(debug=True, share=False)
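
# Note: when run locally (`python app.py`), Gradio serves the UI on
# http://127.0.0.1:7860 by default; this assumes gradio, transformers,
# coremltools, huggingface_hub, and the exporters package are installed.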