# Source: Hugging Face Space file view (author: pcuenq, HF staff)
# Commit: e05f54a ("Initial conversion"), file size 8.64 kB
import gradio as gr
import json
from pathlib import Path
from huggingface_hub import hf_hub_download, HfApi
from coremltools import ComputeUnit
from transformers.onnx.utils import get_preprocessor
from exporters.coreml import export
from exporters.coreml.features import FeaturesManager
from exporters.coreml.validate import validate_model_outputs
# UI label -> coremltools ComputeUnit passed to the exporter.
# Insertion order is the order in which the choices are listed in the UI.
compute_units_mapping = dict(
    [
        ("All", ComputeUnit.ALL),
        ("CPU", ComputeUnit.CPU_ONLY),
        ("CPU + GPU", ComputeUnit.CPU_AND_GPU),
        ("CPU + NE", ComputeUnit.CPU_AND_NE),
    ]
)
compute_units_labels = [*compute_units_mapping]
# UI label -> framework identifier handed to FeaturesManager.
framework_mapping = dict(PyTorch="pt", TensorFlow="tf")
framework_labels = [*framework_mapping]
# UI label -> value passed as the exporter's `quantize` argument.
precision_mapping = dict(
    [
        ("Float32", "float32"),
        ("Float16 quantization", "float16"),
    ]
)
precision_labels = [*precision_mapping]
# UI label -> absolute tolerance used when validating the converted model.
# `None` means "use the tolerance defined by the model's Core ML config".
tolerance_mapping = {
    "Model default": None,
    **{label: float(label) for label in ("1e-2", "1e-3", "1e-4")},
}
tolerance_labels = [*tolerance_mapping]
def error_str(error, title="Error"):
    """Render *error* as Markdown: a level-4 heading with *title*, then the error text.

    Returns an empty string when *error* is falsy (for example ``None`` or ``""``),
    so the result can be used directly to clear a Markdown output.
    """
    if not error:
        return ""
    return f"#### {title}\n{error}"
def url_to_model_id(model_id_str):
    """Turn a Hugging Face model URL into a model id.

    Input that does not start with ``https://huggingface.co/`` is assumed to
    already be a model id and is returned as-is (minus surrounding whitespace).

    For URLs, the id is made of the first one or two path segments after the
    host, so all of the following work:

    * ``https://huggingface.co/apple/mobilevit-small`` -> ``apple/mobilevit-small``
    * ``https://huggingface.co/distilbert-base-uncased`` -> ``distilbert-base-uncased``
    * ``https://huggingface.co/apple/mobilevit-small/tree/main`` -> ``apple/mobilevit-small``
    * a trailing slash, query string or fragment is ignored.
    """
    prefix = "https://huggingface.co/"
    model_id_str = model_id_str.strip()
    if not model_id_str.startswith(prefix):
        return model_id_str

    # Keep only the path: drop any query string or fragment.
    path = model_id_str[len(prefix):].split("?", 1)[0].split("#", 1)[0]
    # Filtering empty segments makes trailing or doubled slashes harmless.
    segments = [segment for segment in path.split("/") if segment]
    # A model id is "<name>" or "<namespace>/<name>"; anything after that
    # (e.g. "tree/main") is page navigation, not part of the id.
    return "/".join(segments[:2])
def supported_frameworks(model_id):
    """
    Return a sorted list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.

    The frameworks are read from the model's Hub tags (`pytorch` and `tf`).
    Only PyTorch and TensorFlow are supported; every other tag is ignored.
    An empty list is returned when the model has neither tag (or no tags at all).
    """
    tag_to_label = {"pytorch": "PyTorch", "tf": "TensorFlow"}

    api = HfApi()
    model_info = api.model_info(model_id)
    # `tags` can be missing/None on the returned info object; treat that as "no tags".
    tags = model_info.tags or []
    return sorted(tag_to_label[tag] for tag in tags if tag in tag_to_label)
def on_model_change(model):
    """Load a model's config from the Hub and compute the UI updates for it.

    `model` is a model id or a Hugging Face model URL.

    Returns a 4-tuple of Gradio updates, in this order:
    settings column, task choices, framework choices, error text.

    On success the settings column is shown, the tasks supported for the
    model type are offered (first one preselected), the framework selector is
    shown only when more than one framework is available, and the error text
    is cleared. On any failure the settings are hidden and the error message
    is displayed instead.
    """
    model = url_to_model_id(model)

    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")

        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        model_type = config.get("model_type")
        if not model_type:
            # Without this check the user would only see a bare KeyError message.
            raise ValueError(f"Model {model} does not declare a `model_type` in its config.json")

        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())

        frameworks = supported_frameworks(model)
        selected_framework = frameworks[0] if frameworks else None
    except Exception as e:
        # Always return one update per output component; returning nothing here
        # would make the event handler fail and the error would never be shown.
        return (
            gr.update(visible=False),                # Settings column
            gr.update(choices=[], value=None),       # Tasks
            gr.update(visible=False),                # Frameworks
            gr.update(value=error_str(e)),           # Error
        )

    return (
        gr.update(visible=True),                                                                # Settings column
        gr.update(choices=tasks, value=tasks[0] if tasks else None),                            # Tasks
        gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),   # Frameworks
        gr.update(value=""),                                                                    # Error (cleared)
    )
def convert_model(preprocessor, model, model_coreml_config, compute_units, precision, tolerance, output, use_past=False, seq2seq=None):
    """Export `model` to Core ML, save it, and validate its outputs.

    `model_coreml_config` is called with the model's config (plus `use_past`
    and `seq2seq`) to build the Core ML export configuration.

    The exported model is saved at `output`. When `seq2seq` is `"encoder"` or
    `"decoder"`, the file name is prefixed with `encoder_` / `decoder_`.

    `tolerance` is the absolute tolerance used for validation; `None` selects
    the configuration's own `atol_for_validation`.
    """
    coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)
    mlmodel = export(preprocessor, model, coreml_config, quantize=precision, compute_units=compute_units)

    # Encoder and decoder halves are stored side by side under distinct names.
    target = output
    if seq2seq in ("encoder", "decoder"):
        target = target.parent / f"{seq2seq}_{target.name}"
    mlmodel.save(target.as_posix())

    atol = coreml_config.atol_for_validation if tolerance is None else tolerance
    validate_model_outputs(coreml_config, preprocessor, model, mlmodel, atol)
def convert(model, task, compute_units, precision, tolerance, framework):
    """Convert a Hub model to Core ML using the options selected in the UI.

    `model` is a model id or Hugging Face URL; `task` is the feature to
    export; `compute_units`, `precision`, `tolerance` and `framework` are the
    UI labels, which are translated through the module-level `*_mapping`
    dictionaries.

    The converted package is written to
    `exported/<model id>/coreml/<task>/<precision>_model.mlpackage`; for
    seq2seq tasks an encoder and a decoder package are written instead.

    Returns `"Done"` on success, or a Markdown-formatted error message on any
    failure (including invalid or missing selections), so the result can
    always be rendered in the output component.
    """
    try:
        # Validate the selections up front so the user gets a readable message
        # instead of an unhandled error from a dictionary lookup or path join.
        if not task:
            raise ValueError("No task selected. Load the model info and choose a task first.")
        if framework not in framework_mapping:
            raise ValueError("No supported framework selected (expected one of: PyTorch, TensorFlow).")

        model_id = url_to_model_id(model)
        compute_units = compute_units_mapping[compute_units]
        precision = precision_mapping[precision]
        tolerance = tolerance_mapping[tolerance]
        framework = framework_mapping[framework]

        # TODO: support legacy format
        output = Path("exported") / model_id / "coreml" / task
        output.mkdir(parents=True, exist_ok=True)
        output = output / f"{precision}_model.mlpackage"

        preprocessor = get_preprocessor(model_id)
        hf_model = FeaturesManager.get_model_from_feature(task, model_id, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(hf_model, feature=task)

        # Seq2seq models are exported as two separate packages (encoder, decoder);
        # every other task is exported as a single package (seq2seq=None).
        if task in ["seq2seq-lm", "speech-seq2seq"]:
            seq2seq_parts = ("encoder", "decoder")
        else:
            seq2seq_parts = (None,)

        for part in seq2seq_parts:
            convert_model(
                preprocessor,
                hf_model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq=part,
            )

        # TODO: push to hub, whatever
        return "Done"
    except Exception as e:
        return error_str(e)
# Markdown shown at the top of the page.
# NOTE(review): the "Exporters repo" link points at huggingface.co/exporters while the
# issue link below points at github.com/huggingface/exporters. Confirm the former is intended.
DESCRIPTION = (
    "\n"
    "## Convert a transformers model to Core ML\n"
    "With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.\n"
    "Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).\n"
    "After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.\n"
)
# Build the Gradio UI: a column to load the model info, a settings column that
# starts hidden (on_model_change controls its visibility), and a Markdown area
# that receives the output of both buttons.
with gr.Blocks() as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("## 1. Load model info")
input_model = gr.Textbox(
max_lines=1,
label="Model name or URL, such as apple/mobilevit-small",
placeholder="distilbert-base-uncased",
value="distilbert-base-uncased",
)
btn_get_tasks = gr.Button("Load")
with gr.Column(scale=3):
# Conversion settings; created hidden.
with gr.Column(visible=False) as group_settings:
gr.Markdown("## 2. Select Task")
# Choices are filled in by on_model_change.
radio_tasks = gr.Radio(label="Choose the task for the converted model.")
gr.Markdown("The `default` task is suitable for feature extraction.")
# Created hidden; on_model_change shows it only when more than one framework is available.
radio_framework = gr.Radio(
visible=False,
label="Framework",
choices=framework_labels,
value=framework_labels[0],
)
radio_compute = gr.Radio(
label="Compute Units",
choices=compute_units_labels,
value=compute_units_labels[0],
)
radio_precision = gr.Radio(
label="Precision",
choices=precision_labels,
value=precision_labels[0],
)
radio_tolerance = gr.Radio(
label="Absolute Tolerance for Validation",
choices=tolerance_labels,
value=tolerance_labels[0],
)
btn_convert = gr.Button("Convert")
gr.Markdown("Conversion will take a few minutes.")
# Shared output area: shows load errors and the conversion result.
error_output = gr.Markdown(label="Output")
# "Load": the four outputs match, in order, the tuple returned by on_model_change.
# Registered with queue=False.
btn_get_tasks.click(
fn=on_model_change,
inputs=input_model,
outputs=[group_settings, radio_tasks, radio_framework, error_output],
queue=False,
scroll_to_output=True
)
# "Convert": the inputs match, in order, the parameters of convert();
# its return value ("Done" or an error message) is rendered in error_output.
btn_convert.click(
fn=convert,
inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
outputs=error_output,
scroll_to_output=True
)
# gr.HTML("""
# <div style="border-top: 1px solid #303030;">
# <br>
# <p>Footer</p><br>
# <p><img src="https://visitor-badge.glitch.me/badge?page_id=pcuenq.transformers-to-coreml" alt="visitors"></p>
# </div>
# """)
# Enable the request queue (concurrency_count=1, max_size=10), then start the app.
# NOTE(review): presumably one conversion at a time with up to 10 waiting requests; confirm against the Gradio version in use.
demo.queue(concurrency_count=1, max_size=10)
demo.launch(debug=True, share=False)