# Hugging Face Space (status: Running)
import os
import re
import shutil
import subprocess
import sys

import streamlit as st
from huggingface_hub import HfApi
def read_setting(name):
    """Return setting ``name`` from Streamlit secrets, falling back to the environment.

    Args:
        name: Key looked up first in ``st.secrets`` and then in ``os.environ``.

    Returns:
        The secret value when it is truthy, otherwise ``os.environ.get(name)``
        (which may be ``None``).
    """
    try:
        value = st.secrets.get(name)
    except FileNotFoundError:
        # Accessing st.secrets raises when no secrets.toml file exists, which is
        # the normal situation when secrets are provided only as environment
        # variables. NOTE(review): some Streamlit versions may raise a different
        # exception type here — confirm against the pinned Streamlit version.
        value = None
    return value or os.environ.get(name)


# Token passed to HfApi for every Hub call (may be None).
HF_TOKEN = read_setting("HF_TOKEN")
# Namespace under which converted repos are created. SPACE_AUTHOR_NAME is the
# last-resort fallback (presumably provided by the Spaces runtime — confirm).
HF_USERNAME = read_setting("HF_USERNAME") or os.environ.get("SPACE_AUTHOR_NAME")

# transformers.js checkout that provides the conversion script, pinned to a release tag.
TRANSFORMERS_REPOSITORY_URL = "https://github.com/xenova/transformers.js"
TRANSFORMERS_REPOSITORY_REVISION = "2.16.0"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
HF_BASE_URL = "https://huggingface.co"
# Fetch the transformers.js repository (its `scripts.convert` module does the
# actual conversion) pinned to TRANSFORMERS_REPOSITORY_REVISION. Streamlit
# re-executes this script on every interaction; the existence check keeps the
# clone a one-time cost.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    # One shallow clone straight at the pinned tag, run without a shell.
    # check=True surfaces failures instead of silently continuing. A failed
    # clone does not leave a directory behind, so there is no half-prepared,
    # unpinned checkout for later reruns to reuse.
    subprocess.run(
        [
            "git",
            "clone",
            "--depth",
            "1",
            "--branch",
            TRANSFORMERS_REPOSITORY_REVISION,
            TRANSFORMERS_REPOSITORY_URL,
            TRANSFORMERS_REPOSITORY_PATH,
        ],
        check=True,
    )
# One component of a Hub repo id: word characters, '.' and '-', and it must
# start and end with a word character (so it can never start with '-').
_REPO_NAME_PART = r"\w(?:[\w.-]*\w)?"
# Accepts "name" or "namespace/name" and nothing else (no extra slashes, no spaces).
MODEL_ID_PATTERN = re.compile(rf"(?:{_REPO_NAME_PART}/)?{_REPO_NAME_PART}", re.ASCII)

# Files produced by the conversion script and the names they are uploaded under.
# NOTE(review): presumably these are the file names transformers.js loads for
# decoder models — confirm for non-decoder architectures.
ONNX_RENAMES = {
    "model.onnx": "decoder_model_merged.onnx",
    "model_quantized.onnx": "decoder_model_merged_quantized.onnx",
}


def normalize_model_id(raw_model_id, base_url):
    """Turn the text typed by the user into a bare, validated Hub model id.

    Accepts ``name``, ``namespace/name`` or a full ``{base_url}/namespace/name``
    URL (surrounding whitespace and slashes are ignored).

    Returns:
        The model id, or ``None`` if it is not a well-formed id. The id is later
        used as a filesystem path and a command-line argument, so anything
        containing "..", shell metacharacters, or a leading "-" is rejected.
    """
    model_id = raw_model_id.strip()
    url_prefix = f"{base_url}/"
    if model_id.startswith(url_prefix):
        model_id = model_id[len(url_prefix):]
    model_id = model_id.strip("/")
    if ".." in model_id or "--" in model_id:
        return None
    if MODEL_ID_PATTERN.fullmatch(model_id) is None:
        return None
    return model_id


def output_repo_id(model_id, username):
    """Return the id of the ONNX repo created for ``model_id`` under ``username``.

    ``namespace/name`` becomes ``{username}/{namespace}-{name}-ONNX``. The
    namespace is dropped only when it *is* ``username``, so converting your own
    model does not repeat your name.
    """
    namespace, _, name = model_id.rpartition("/")
    if namespace and namespace != username:
        name = f"{namespace}-{name}"
    return f"{username}/{name}-ONNX"


def convert_model(model_id):
    """Run transformers.js's ``scripts.convert --quantize`` for ``model_id``.

    Runs inside the transformers.js checkout with the interpreter running this
    app (``sys.executable``), so it sees the same installed packages. The
    arguments are passed as a list (no shell). Output is captured as text; the
    caller inspects ``returncode`` and ``stderr``.
    """
    return subprocess.run(
        [sys.executable, "-m", "scripts.convert", "--quantize", "--model_id", model_id],
        cwd=TRANSFORMERS_REPOSITORY_PATH,
        capture_output=True,
        text=True,
    )


def rename_onnx_files(model_folder_path):
    """Apply ``ONNX_RENAMES`` inside ``<model_folder_path>/onnx``.

    Files that were not produced are skipped instead of raising.
    """
    onnx_dir = os.path.join(model_folder_path, "onnx")
    for source_name, target_name in ONNX_RENAMES.items():
        source_path = os.path.join(onnx_dir, source_name)
        if os.path.exists(source_path):
            os.rename(source_path, os.path.join(onnx_dir, target_name))


def upload_model(api, model_folder_path, repo_id):
    """Create ``repo_id`` (public) on the Hub if needed and upload the folder to it.

    Returns:
        ``None`` on success, otherwise a non-empty error message for the UI.
    """
    try:
        repository = api.create_repo(repo_id, exist_ok=True, private=False)
        api.upload_folder(folder_path=model_folder_path, repo_id=repository.repo_id)
    except Exception as error:  # UI boundary: report any Hub/network failure to the user.
        # NOTE(review): a failed upload can leave an empty repo behind, which makes
        # the app report "already converted" on the next attempt.
        # repr() fallback: an exception with an empty message is still a failure.
        return str(error) or repr(error)
    return None


def main():
    """Render the converter page; convert and upload the model when "Proceed" is clicked."""
    st.write("## Convert a HuggingFace model to ONNX")
    raw_model_id = st.text_input(
        "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
    )
    if not raw_model_id:
        return

    if not HF_USERNAME:
        st.error(
            "No Hugging Face username is configured; set HF_USERNAME "
            "(or SPACE_AUTHOR_NAME) before converting models."
        )
        return

    input_model_id = normalize_model_id(raw_model_id, HF_BASE_URL)
    if input_model_id is None:
        st.error(f"`{raw_model_id}` is not a valid Hugging Face model ID.")
        return

    output_model_id = output_repo_id(input_model_id, HF_USERNAME)
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
        return

    st.write("This model will be converted and uploaded to the following URL:")
    st.code(output_model_url, language="plaintext")
    if not st.button(label="Proceed", type="primary"):
        return

    model_folder_path = os.path.join(TRANSFORMERS_REPOSITORY_PATH, "models", input_model_id)
    try:
        with st.spinner("Converting model..."):
            output = convert_model(input_model_id)
        if output.returncode != 0 or not os.path.isdir(model_folder_path):
            st.error("Conversion failed!")
            st.code(output.stderr)
            return
        rename_onnx_files(model_folder_path)
        st.success("Conversion successful!")
        st.code(output.stderr)

        with st.spinner("Uploading model..."):
            upload_error_message = upload_model(api, model_folder_path, output_model_id)
    finally:
        # Remove the local conversion output on every path, failures included.
        # The id was validated above, so this path stays inside the models folder.
        shutil.rmtree(model_folder_path, ignore_errors=True)

    if upload_error_message is not None:
        st.error(f"Upload failed: {upload_error_message}")
        return
    st.success("Upload successful!")
    st.write("You can now go and view the model on HuggingFace!")
    st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")


# `streamlit run` executes this file as the __main__ module.
if __name__ == "__main__":
    main()