Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| import argparse | |
| import csv | |
| import itertools | |
| import json | |
| import os | |
| import torch | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| import llava | |
| from llava import conversation as conversation_lib | |
| from llava.data.builder import DATASETS | |
| from llava.eval.mmmu_utils.eval_utils import parse_choice | |
| from llava.utils import distributed as dist | |
| from llava.utils import io | |
| from llava.utils.logging import logger | |
def load_existing_ids(output_file):
    """Read previously saved JSONL results so finished items can be skipped.

    Returns a tuple ``(processed_ids, outputs)``: the set of ``"id"`` values
    already present in *output_file* and the list of parsed records. A missing
    or unreadable file yields ``(set(), [])`` so a fresh run starts cleanly.
    """
    if not os.path.exists(output_file):
        return set(), []
    try:
        records = []
        with open(output_file, "r") as fp:
            for raw_line in fp:
                records.append(json.loads(raw_line))
        seen = set()
        for record in records:
            seen.add(record["id"])
        return seen, records
    except Exception as e:
        print(f"Error loading existing outputs: {e}")
        return set(), []
def main() -> None:
    """Run audio-QA inference for one dataset task and save predictions as JSONL.

    Results accumulate across runs: ids already present in the output file are
    skipped, and partial results are checkpointed every 20 new items. Work is
    sharded across distributed ranks; rank 0 writes the gathered output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="auto")
    parser.add_argument("--generation-config", type=json.loads)
    parser.add_argument("--output-dir", type=str, default=None)
    args = parser.parse_args()

    # Set up the distributed environment: each local rank takes a strided
    # slice of the visible CUDA devices.
    dist.init()
    devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
    torch.cuda.set_device(devices[0])

    # Load the stage 3 model.
    model = llava.load(args.model_base, model_base=None, devices=devices)
    # Uncomment below to load the stage 3.5 adapter on top of stage 3 for
    # thinking mode and long-audio mode.
    # model = PeftModel.from_pretrained(
    #     model,
    #     args.model_path,
    #     device_map="auto",
    #     torch_dtype=torch.float16,
    # )

    # Set up the generation config; CLI overrides merge into model defaults.
    generation_config = model.default_generation_config
    if args.generation_config is not None:
        generation_config.update(**args.generation_config)

    # Load data and shard it across ranks (round-robin by rank).
    json_file = DATASETS[args.task]["data_path"]
    instances = io.load(json_file)
    instances = instances[dist.rank() :: dist.size()]

    # Resume support: one canonical output path, reused for every save below.
    output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl")
    processed_ids, outputs = load_existing_ids(output_path)
    count = len(outputs)

    # Run inference
    new_outputs = []
    for instance in tqdm(instances, disable=not dist.is_main()):
        sound_path = instance["sound"]
        if sound_path in processed_ids:
            continue  # Already processed in a previous run
        sound = llava.Sound(sound_path)
        conversations = instance["conversations"]
        question = conversations[0]["value"]
        response = model.generate_content([sound, question], generation_config=generation_config)
        print("response", response)
        output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response}
        new_outputs.append(output)
        count += 1
        if count % 20 == 0:
            # Periodic checkpoint of partial results.
            # NOTE(review): with dist.size() > 1 this gather runs inside the
            # item loop, and ranks reach count % 20 == 0 at different
            # iterations, which can deadlock a collective gather — confirm
            # the intended multi-rank checkpoint behavior.
            if dist.size() > 1:
                gathered = dist.gather(new_outputs, dst=0)
                if dist.is_main():
                    io.save(output_path, outputs + list(itertools.chain(*gathered)))
            else:
                io.save(output_path, outputs + new_outputs)

    # Final save: gather everything on rank 0 and write once.
    if dist.size() > 1:
        new_outputs = dist.gather(new_outputs, dst=0)
        if not dist.is_main():
            return
        new_outputs = list(itertools.chain(*new_outputs))
    final_outputs = outputs + new_outputs
    io.save(output_path, final_outputs)
# Script entry point: run inference when executed directly.
if __name__ == "__main__":
    main()