# ChatTTS-Forge — modules/models_setup.py
# Author: zhzluke96 (commit bed01bd)
import argparse
import logging
from modules import generate_audio
from modules.devices import devices
from modules.Enhancer.ResembleEnhance import load_enhancer
from modules.models import load_chat_tts
from modules.utils import env
def setup_model_args(parser: argparse.ArgumentParser) -> None:
    """Register model/runtime CLI flags on *parser* (mutated in place).

    Flags added:
        --compile         enable torch model compilation
        --no_half         disable half-precision inference
        --off_tqdm        disable tqdm progress bars
        --device_id       default CUDA device id (may require CUDA_VISIBLE_DEVICES)
        --use_cpu         run the listed modules on CPU instead of GPU
        --lru_size        request-cache pool size (0 disables lru_cache)
        --debug_generate  verbose logging for audio generation
        --preload_models  load all models at startup instead of lazily
    """
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    parser.add_argument(
        "--no_half",
        action="store_true",
        # NOTE: fixed typo in user-facing help text ("Disalbe" -> "Disable")
        help="Disable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        # Lower-case the values so choices match case-insensitively.
        type=str.lower,
        choices=["all", "chattts", "enhancer", "trainer"],
    )
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--debug_generate",
        action="store_true",
        help="Enable debug mode for audio generation",
    )
    parser.add_argument(
        "--preload_models",
        action="store_true",
        help="Preload all models at startup",
    )
def process_model_args(args: argparse.Namespace) -> None:
    """Apply the parsed model/runtime flags to the running process.

    Reads each flag via ``env.get_and_update_env`` (which both returns the
    effective value and records it in the environment config), then configures
    the cache, devices, logging level, and optional model preloading.
    """
    lru_size = env.get_and_update_env(args, "lru_size", 64, int)
    # Renamed local from `compile` to avoid shadowing the builtin; the env
    # key string "compile" is unchanged.
    compile_model = env.get_and_update_env(args, "compile", False, bool)
    device_id = env.get_and_update_env(args, "device_id", None, str)
    use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
    no_half = env.get_and_update_env(args, "no_half", False, bool)
    off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
    debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
    preload_models = env.get_and_update_env(args, "preload_models", False, bool)

    # The locals above are bound mainly for readability; the important effect
    # is that get_and_update_env records each value in the env config, which
    # the setup calls below presumably read — TODO confirm.
    generate_audio.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    if debug_generate:
        generate_audio.logger.setLevel(logging.DEBUG)

    if preload_models:
        # Eagerly load both models instead of waiting for the first request.
        load_chat_tts()
        load_enhancer()