Spaces:
Paused
Paused
mv batch option to argparse
Browse files- app.py +6 -2
- javascript/app.js +2 -1
app.py
CHANGED
@@ -20,7 +20,6 @@ from midi_model import MIDIModel, MIDIModelConfig
|
|
20 |
from midi_synthesizer import MidiSynthesizer
|
21 |
|
22 |
MAX_SEED = np.iinfo(np.int32).max
|
23 |
-
OUTPUT_BATCH_SIZE = 8
|
24 |
in_space = os.getenv("SYSTEM") == "spaces"
|
25 |
|
26 |
|
@@ -305,7 +304,10 @@ def load_javascript(dir="javascript"):
|
|
305 |
javascript = ""
|
306 |
for path in scripts_list:
|
307 |
with open(path, "r", encoding="utf8") as jsfile:
|
308 |
-
|
|
|
|
|
|
|
309 |
template_response_ori = gr.routes.templates.TemplateResponse
|
310 |
|
311 |
def template_response(*args, **kwargs):
|
@@ -344,8 +346,10 @@ if __name__ == "__main__":
|
|
344 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
345 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
346 |
parser.add_argument("--device", type=str, default="cuda", help="device to run model")
|
|
|
347 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
348 |
opt = parser.parse_args()
|
|
|
349 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
350 |
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
351 |
synthesizer = MidiSynthesizer(soundfont_path)
|
|
|
20 |
from midi_synthesizer import MidiSynthesizer
|
21 |
|
22 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
23 |
in_space = os.getenv("SYSTEM") == "spaces"
|
24 |
|
25 |
|
|
|
304 |
javascript = ""
|
305 |
for path in scripts_list:
|
306 |
with open(path, "r", encoding="utf8") as jsfile:
|
307 |
+
js_content = jsfile.read()
|
308 |
+
js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
|
309 |
+
f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
|
310 |
+
javascript += f"\n<!-- {path} --><script>{js_content}</script>"
|
311 |
template_response_ori = gr.routes.templates.TemplateResponse
|
312 |
|
313 |
def template_response(*args, **kwargs):
|
|
|
346 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
347 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
348 |
parser.add_argument("--device", type=str, default="cuda", help="device to run model")
|
349 |
+
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
350 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
351 |
opt = parser.parse_args()
|
352 |
+
OUTPUT_BATCH_SIZE = opt.batch
|
353 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
354 |
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
355 |
synthesizer = MidiSynthesizer(soundfont_path)
|
javascript/app.js
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
const MIDI_OUTPUT_BATCH_SIZE=
|
|
|
2 |
|
3 |
/**
|
4 |
* 自动绕过 shadowRoot 的 querySelector
|
|
|
1 |
+
const MIDI_OUTPUT_BATCH_SIZE=4;
|
2 |
+
//Do not change MIDI_OUTPUT_BATCH_SIZE. It will be automatically replaced.
|
3 |
|
4 |
/**
|
5 |
* 自动绕过 shadowRoot 的 querySelector
|