# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
| """ | |
| Modified form: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq_cli/generate.py | |
| """ | |
import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain

import numpy as np
import torch
from omegaconf import DictConfig

from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter


def main(cfg: DictConfig):

    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for generation!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "generate-{}.txt".format(cfg.dataset.gen_subset),
        )
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    else:
        return _main(cfg, sys.stdout)
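

# NOTE: retained from fairseq's generate.py; not called in this modified script.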
def get_symbols_to_strip_from_output(generator):
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos}


def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]
    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    def _fp_convert_sample(sample):
        def apply_half(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.half)
            return t

        def apply_bfloat16(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.bfloat16)
            return t

        if cfg.common.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        if cfg.common.bf16:
            sample = utils.apply_to_sample(apply_bfloat16, sample)

        return sample
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )
    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, None)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = _fp_convert_sample(sample)
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models[0],
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
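        # Each hypothesis is a dict carrying the predicted discrete units
        # ("unit") together with per-unit durations ("duration") and
        # force-aligned source tokens ("fa_src"), all printed below.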
        num_generated_tokens = sum(len(h["unit"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                ).cpu()
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = " ".join(map(str, target_tokens.numpy().tolist()))

            src_str = decode_fn(src_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)
            # Process top predictions
            j = 0
            hypo = hypos[i]
            hypo_tokens = hypo["unit"].int().cpu()
            hypo_str = " ".join(map(str, hypo_tokens.numpy().tolist()))
            alignment = None
            detok_hypo_str = hypo_str
            # add duration prediction
            hypo_duration = " ".join(map(str, hypo["duration"].int().cpu().numpy().tolist()))
            hypo_fa_src_str = src_dict.string(hypo["fa_src"].cpu().numpy(), cfg.common_eval.post_process)
            # hypo_fa_src_str = " ".join(map(str, hypo["fa_src"].int().cpu().numpy() - 4))
            if not cfg.common_eval.quiet:
                # score = hypo["score"] / math.log(2)  # convert to base 2
                score = 0.00
                # original hypothesis (after tokenization and BPE)
                # print(
                #     "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                #     file=output_file,
                # )
                # detokenized hypothesis
                print(
                    "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                    file=output_file,
                )
                # duration prediction
                print(
                    "L-{}\t{}\t{}".format(sample_id, score, hypo_duration),
                    file=output_file,
                )
                # force-aligned upsampled src-tokens
                print(
                    "U-{}\t{}\t{}".format(sample_id, score, hypo_fa_src_str),
                    file=output_file,
                )
                # print(
                #     "P-{}\t{}".format(
                #         sample_id,
                #         " ".join(
                #             map(
                #                 lambda x: "{:.4f}".format(x),
                #                 # convert from base e to base 2
                #                 hypo["positional_scores"]
                #                 .div_(math.log(2))
                #                 .tolist(),
                #             )
                #         ),
                #     ),
                #     file=output_file,
                # )
                if cfg.generation.print_alignment == "hard":
                    print(
                        "A-{}\t{}".format(
                            sample_id,
                            " ".join(
                                [
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]
                            ),
                        ),
                        file=output_file,
                    )
                if cfg.generation.print_alignment == "soft":
                    print(
                        "A-{}\t{}".format(
                            sample_id,
                            " ".join(
                                [",".join(src_probs) for src_probs in alignment]
                            ),
                        ),
                        file=output_file,
                    )
            # Score only the top hypothesis
            if has_target and j == 0:
                if hasattr(scorer, "add_string"):
                    scorer.add_string(target_str, detok_hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )
| logger.info("NOTE: hypothesis and token scores are output in base 2") | |
| logger.info( | |
| "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( | |
| num_sentences, | |
| gen_timer.n, | |
| gen_timer.sum, | |
| num_sentences / gen_timer.sum, | |
| 1.0 / gen_timer.avg, | |
| ) | |
| ) | |
| if has_target: | |
| if cfg.bpe and not cfg.generation.sacrebleu: | |
| if cfg.common_eval.post_process: | |
| logger.warning( | |
| "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" | |
| ) | |
| else: | |
| logger.warning( | |
| "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization" | |
| ) | |
| # use print to be consistent with other main outputs: S-, H-, T-, D- and so on | |
| print( | |
| "Generate {} with beam={}: {}".format( | |
| cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() | |
| ), | |
| file=output_file, | |
| ) | |
| return scorer | |


def cli_main():
    parser = options.get_generation_parser()
    # TODO: replace this workaround with refactoring of `AudioPretraining`
    parser.add_argument(
        "--arch",
        "-a",
        metavar="ARCH",
        default="wav2vec2",
        help="Model architecture. For constructing tasks that rely on "
        "model args (e.g. `AudioPretraining`)",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()