MekkCyber committed
Commit 5bb569a · 1 Parent(s): 7952bf4

add to org option

Files changed (1)
  1. app.py +44 -8
app.py CHANGED
@@ -18,19 +18,26 @@ def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) ->
 
 
 def check_model_exists(
-    oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name
+    oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name, upload_to_community
 ):
     """Check if a model exists in the user's Hugging Face repository."""
     try:
         models = list_models(author=username, token=oauth_token.token)
+        community_models = list_models(author="bnb-community", token=oauth_token.token)
         model_names = [model.id for model in models]
-        if quantized_model_name:
-            repo_name = f"{username}/{quantized_model_name}"
+        community_model_names = [model.id for model in community_models]
+        if upload_to_community:
+            repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
         else:
-            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
+            if quantized_model_name:
+                repo_name = f"{username}/{quantized_model_name}"
+            else:
+                repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
 
         if repo_name in model_names:
             return f"Model '{repo_name}' already exists in your repository."
+        elif repo_name in community_model_names:
+            return f"Model '{repo_name}' already exists in the bnb-community organization."
         else:
             return None  # Model does not exist
     except Exception as e:
@@ -200,6 +207,7 @@ def save_model(
     auth_token=None,
     quantized_model_name=None,
     public=False,
+    upload_to_community=False,
     progress=gr.Progress(),
 ):
     progress(0.67, desc="Preparing to push")
@@ -214,10 +222,13 @@ def save_model(
     progress(0.75, desc="Preparing to push")
 
     # Prepare repo name and model card
-    if quantized_model_name:
-        repo_name = f"{username}/{quantized_model_name}"
+    if upload_to_community:
+        repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
     else:
-        repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
+        if quantized_model_name:
+            repo_name = f"{username}/{quantized_model_name}"
+        else:
+            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
 
     model_card = create_model_card(
         model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
@@ -291,6 +302,7 @@ def quantize_and_save(
     quant_storage_4,
     quantized_model_name,
     public,
+    upload_to_community,
     progress=gr.Progress(),
 ):
     if oauth_token is None:
@@ -308,7 +320,7 @@ def quantize_and_save(
         </div>
         """
     exists_message = check_model_exists(
-        oauth_token, profile.username, model_name, quantized_model_name
+        oauth_token, profile.username, model_name, quantized_model_name, upload_to_community
     )
     if exists_message:
         return f"""
@@ -341,6 +353,7 @@ def quantize_and_save(
         oauth_token,
         quantized_model_name,
         public,
+        upload_to_community,
         progress,
     )
     # Clean up the model to free memory
@@ -685,6 +698,28 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
                 interactive=True,
                 show_label=True,
             )
+
+            with gr.Row():
+                upload_to_community = gr.Checkbox(
+                    label="🤗 Upload to bnb-community",
+                    info="If checked, the model will be uploaded to the bnb-community organization",
+                    value=False,
+                    interactive=True,
+                    show_label=True,
+                )
+
+            # Add event handler to disable and clear model name when uploading to community
+            def toggle_model_name(upload_to_community_checked):
+                return gr.update(
+                    interactive=not upload_to_community_checked,
+                    value="Can't change model name when uploading to community" if upload_to_community_checked else quantized_model_name.value
+                )
+
+            upload_to_community.change(
+                fn=toggle_model_name,
+                inputs=[upload_to_community],
+                outputs=quantized_model_name
+            )
 
         with gr.Column():
             quantize_button = gr.Button(
@@ -704,6 +739,7 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
                 quant_storage_4,
                 quantized_model_name,
                 public,
+                upload_to_community,
             ],
             outputs=[output_link],
             show_progress="full",
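
For reference, the repo-name selection this commit threads through check_model_exists and save_model reduces to the standalone sketch below. The helper name resolve_repo_name and the example model ids are illustrative only; in app.py the logic stays inlined rather than factored out.

def resolve_repo_name(username, model_name, quantized_model_name=None, upload_to_community=False):
    """Pick the target repo id for the quantized checkpoint (sketch, not app.py code)."""
    if upload_to_community:
        # Community uploads always land under the bnb-community org with the default suffix.
        return f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
    if quantized_model_name:
        # A custom name is only honoured when pushing to the user's own namespace.
        return f"{username}/{quantized_model_name}"
    return f"{username}/{model_name.split('/')[-1]}-bnb-4bit"

print(resolve_repo_name("alice", "meta-llama/Llama-3.1-8B"))                            # alice/Llama-3.1-8B-bnb-4bit
print(resolve_repo_name("alice", "meta-llama/Llama-3.1-8B", upload_to_community=True))  # bnb-community/Llama-3.1-8B-bnb-4bit

The new change handler reinforces this rule in the UI: when the checkbox is ticked, gr.update(interactive=False, ...) locks the custom-name textbox, so community uploads always use the default <model>-bnb-4bit name under the bnb-community organization.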