skytnt commited on
Commit
e9642f9
1 Parent(s): a23f2ef

mv batch option to argparse

Browse files
Files changed (2) hide show
  1. app.py +6 -2
  2. javascript/app.js +2 -1
app.py CHANGED
@@ -20,7 +20,6 @@ from midi_model import MIDIModel, MIDIModelConfig
20
  from midi_synthesizer import MidiSynthesizer
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
- OUTPUT_BATCH_SIZE = 8
24
  in_space = os.getenv("SYSTEM") == "spaces"
25
 
26
 
@@ -305,7 +304,10 @@ def load_javascript(dir="javascript"):
305
  javascript = ""
306
  for path in scripts_list:
307
  with open(path, "r", encoding="utf8") as jsfile:
308
- javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
 
 
 
309
  template_response_ori = gr.routes.templates.TemplateResponse
310
 
311
  def template_response(*args, **kwargs):
@@ -344,8 +346,10 @@ if __name__ == "__main__":
344
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
345
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
346
  parser.add_argument("--device", type=str, default="cuda", help="device to run model")
 
347
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
348
  opt = parser.parse_args()
 
349
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
350
  thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
351
  synthesizer = MidiSynthesizer(soundfont_path)
 
20
  from midi_synthesizer import MidiSynthesizer
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
 
23
  in_space = os.getenv("SYSTEM") == "spaces"
24
 
25
 
 
304
  javascript = ""
305
  for path in scripts_list:
306
  with open(path, "r", encoding="utf8") as jsfile:
307
+ js_content = jsfile.read()
308
+ js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
309
+ f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
310
+ javascript += f"\n<!-- {path} --><script>{js_content}</script>"
311
  template_response_ori = gr.routes.templates.TemplateResponse
312
 
313
  def template_response(*args, **kwargs):
 
346
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
347
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
348
  parser.add_argument("--device", type=str, default="cuda", help="device to run model")
349
+ parser.add_argument("--batch", type=int, default=8, help="batch size")
350
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
351
  opt = parser.parse_args()
352
+ OUTPUT_BATCH_SIZE = opt.batch
353
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
354
  thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
355
  synthesizer = MidiSynthesizer(soundfont_path)
javascript/app.js CHANGED
@@ -1,4 +1,5 @@
1
- const MIDI_OUTPUT_BATCH_SIZE= 8;
 
2
 
3
  /**
4
  * 自动绕过 shadowRoot 的 querySelector
 
1
+ const MIDI_OUTPUT_BATCH_SIZE=4;
2
+ //Do not change MIDI_OUTPUT_BATCH_SIZE. It will be automatically replaced.
3
 
4
  /**
5
  * 自动绕过 shadowRoot 的 querySelector