Spaces:
Running
on
A10G
Running
on
A10G
MekkCyber
commited on
Commit
·
5bb569a
1
Parent(s):
7952bf4
add to org option
Browse files
app.py
CHANGED
@@ -18,19 +18,26 @@ def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) ->
|
|
18 |
|
19 |
|
20 |
def check_model_exists(
|
21 |
-
oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name
|
22 |
):
|
23 |
"""Check if a model exists in the user's Hugging Face repository."""
|
24 |
try:
|
25 |
models = list_models(author=username, token=oauth_token.token)
|
|
|
26 |
model_names = [model.id for model in models]
|
27 |
-
|
28 |
-
|
|
|
29 |
else:
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
if repo_name in model_names:
|
33 |
return f"Model '{repo_name}' already exists in your repository."
|
|
|
|
|
34 |
else:
|
35 |
return None # Model does not exist
|
36 |
except Exception as e:
|
@@ -200,6 +207,7 @@ def save_model(
|
|
200 |
auth_token=None,
|
201 |
quantized_model_name=None,
|
202 |
public=False,
|
|
|
203 |
progress=gr.Progress(),
|
204 |
):
|
205 |
progress(0.67, desc="Preparing to push")
|
@@ -214,10 +222,13 @@ def save_model(
|
|
214 |
progress(0.75, desc="Preparing to push")
|
215 |
|
216 |
# Prepare repo name and model card
|
217 |
-
if
|
218 |
-
repo_name = f"{
|
219 |
else:
|
220 |
-
|
|
|
|
|
|
|
221 |
|
222 |
model_card = create_model_card(
|
223 |
model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
|
@@ -291,6 +302,7 @@ def quantize_and_save(
|
|
291 |
quant_storage_4,
|
292 |
quantized_model_name,
|
293 |
public,
|
|
|
294 |
progress=gr.Progress(),
|
295 |
):
|
296 |
if oauth_token is None:
|
@@ -308,7 +320,7 @@ def quantize_and_save(
|
|
308 |
</div>
|
309 |
"""
|
310 |
exists_message = check_model_exists(
|
311 |
-
oauth_token, profile.username, model_name, quantized_model_name
|
312 |
)
|
313 |
if exists_message:
|
314 |
return f"""
|
@@ -341,6 +353,7 @@ def quantize_and_save(
|
|
341 |
oauth_token,
|
342 |
quantized_model_name,
|
343 |
public,
|
|
|
344 |
progress,
|
345 |
)
|
346 |
# Clean up the model to free memory
|
@@ -685,6 +698,28 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
|
|
685 |
interactive=True,
|
686 |
show_label=True,
|
687 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
688 |
|
689 |
with gr.Column():
|
690 |
quantize_button = gr.Button(
|
@@ -704,6 +739,7 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
|
|
704 |
quant_storage_4,
|
705 |
quantized_model_name,
|
706 |
public,
|
|
|
707 |
],
|
708 |
outputs=[output_link],
|
709 |
show_progress="full",
|
|
|
18 |
|
19 |
|
20 |
def check_model_exists(
|
21 |
+
oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name, upload_to_community
|
22 |
):
|
23 |
"""Check if a model exists in the user's Hugging Face repository."""
|
24 |
try:
|
25 |
models = list_models(author=username, token=oauth_token.token)
|
26 |
+
community_models = list_models(author="bnb-community", token=oauth_token.token)
|
27 |
model_names = [model.id for model in models]
|
28 |
+
community_model_names = [model.id for model in community_models]
|
29 |
+
if upload_to_community:
|
30 |
+
repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
|
31 |
else:
|
32 |
+
if quantized_model_name:
|
33 |
+
repo_name = f"{username}/{quantized_model_name}"
|
34 |
+
else:
|
35 |
+
repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
|
36 |
|
37 |
if repo_name in model_names:
|
38 |
return f"Model '{repo_name}' already exists in your repository."
|
39 |
+
elif repo_name in community_model_names:
|
40 |
+
return f"Model '{repo_name}' already exists in the bnb-community organization."
|
41 |
else:
|
42 |
return None # Model does not exist
|
43 |
except Exception as e:
|
|
|
207 |
auth_token=None,
|
208 |
quantized_model_name=None,
|
209 |
public=False,
|
210 |
+
upload_to_community=False,
|
211 |
progress=gr.Progress(),
|
212 |
):
|
213 |
progress(0.67, desc="Preparing to push")
|
|
|
222 |
progress(0.75, desc="Preparing to push")
|
223 |
|
224 |
# Prepare repo name and model card
|
225 |
+
if upload_to_community:
|
226 |
+
repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
|
227 |
else:
|
228 |
+
if quantized_model_name:
|
229 |
+
repo_name = f"{username}/{quantized_model_name}"
|
230 |
+
else:
|
231 |
+
repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
|
232 |
|
233 |
model_card = create_model_card(
|
234 |
model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
|
|
|
302 |
quant_storage_4,
|
303 |
quantized_model_name,
|
304 |
public,
|
305 |
+
upload_to_community,
|
306 |
progress=gr.Progress(),
|
307 |
):
|
308 |
if oauth_token is None:
|
|
|
320 |
</div>
|
321 |
"""
|
322 |
exists_message = check_model_exists(
|
323 |
+
oauth_token, profile.username, model_name, quantized_model_name, upload_to_community
|
324 |
)
|
325 |
if exists_message:
|
326 |
return f"""
|
|
|
353 |
oauth_token,
|
354 |
quantized_model_name,
|
355 |
public,
|
356 |
+
upload_to_community,
|
357 |
progress,
|
358 |
)
|
359 |
# Clean up the model to free memory
|
|
|
698 |
interactive=True,
|
699 |
show_label=True,
|
700 |
)
|
701 |
+
|
702 |
+
with gr.Row():
|
703 |
+
upload_to_community = gr.Checkbox(
|
704 |
+
label="🤗 Upload to bnb-community",
|
705 |
+
info="If checked, the model will be uploaded to the bnb-community organization",
|
706 |
+
value=False,
|
707 |
+
interactive=True,
|
708 |
+
show_label=True,
|
709 |
+
)
|
710 |
+
|
711 |
+
# Add event handler to disable and clear model name when uploading to community
|
712 |
+
def toggle_model_name(upload_to_community_checked):
|
713 |
+
return gr.update(
|
714 |
+
interactive=not upload_to_community_checked,
|
715 |
+
value="Can't change model name when uploading to community" if upload_to_community_checked else quantized_model_name.value
|
716 |
+
)
|
717 |
+
|
718 |
+
upload_to_community.change(
|
719 |
+
fn=toggle_model_name,
|
720 |
+
inputs=[upload_to_community],
|
721 |
+
outputs=quantized_model_name
|
722 |
+
)
|
723 |
|
724 |
with gr.Column():
|
725 |
quantize_button = gr.Button(
|
|
|
739 |
quant_storage_4,
|
740 |
quantized_model_name,
|
741 |
public,
|
742 |
+
upload_to_community,
|
743 |
],
|
744 |
outputs=[output_link],
|
745 |
show_progress="full",
|