# Source: id-align / playground / sgl_llava_inference_multinode.py
# Uploaded by zooblastlbz using huggingface_hub (commit a9e1e1a, verified).
import argparse
import json
import time
import os
import tqdm
import sglang as sgl
from sglang.test.test_utils import select_sglang_backend
from sglang.utils import dump_state_text
@sgl.function
def image_description(s, image_file):
    """Ask the model for a detailed description of a single image.

    Appends a user turn (the image followed by a fixed instruction) and an
    assistant turn to ``s``. The assistant turn's generation is stored under
    the name ``"answer"`` and is requested with ``max_tokens=1024`` and
    ``temperature=0.0``.

    Args:
        s: The object the conversation turns are appended to.
        image_file: The image reference handed to ``sgl.image``.
    """
    instruction = "Please generate detailed descriptions of the given image."

    # User turn: image first, then the textual instruction.
    user_turn = sgl.user(sgl.image(image_file) + instruction)
    s += user_turn

    # Assistant turn: the generated text is retrievable as "answer".
    answer = sgl.gen("answer", max_tokens=1024, temperature=0.0)
    s += sgl.assistant(answer)
def load_progress(progress_file):
    """Load the resumable progress record for one shard, or start a new one.

    Args:
        progress_file: Path to the JSON progress file.

    Returns:
        dict: The stored progress when ``progress_file`` exists, otherwise a
        fresh record. In both cases the keys ``"last_index"``,
        ``"last_chunk"``, ``"results"`` and ``"annotations"`` are present:
        any key missing from a stored record is filled with its default
        (``-1`` for the two indices, an empty list for the two lists).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    defaults = {"last_index": -1, "last_chunk": -1, "results": [], "annotations": []}

    print(f"Load progress from {progress_file}")
    if not os.path.exists(progress_file):
        return defaults

    with open(progress_file, "r") as f:
        stored = json.load(f)

    # A record written without some of the keys (e.g. hand-edited or produced
    # by a different script version) must not make callers fail on a plain
    # ``progress["annotations"]`` lookup, so back-fill whatever is absent.
    for key, value in defaults.items():
        stored.setdefault(key, value)
    return stored
def save_progress(progress_file, progress_data):
    """Write the progress record to disk without risking the previous copy.

    The record is serialized to a sibling temporary file first and then moved
    over ``progress_file`` with ``os.replace``. If serialization or the write
    fails part-way, the existing progress file is left untouched instead of
    being truncated, so a later run can still resume from it.

    Args:
        progress_file: Destination path of the JSON progress file.
        progress_data: JSON-serializable progress record.

    Raises:
        TypeError: If ``progress_data`` is not JSON-serializable.
        OSError: If the temporary file cannot be written or moved.
    """
    tmp_file = f"{progress_file}.tmp"
    try:
        with open(tmp_file, "w") as f:
            json.dump(progress_data, f, indent=2)
        os.replace(tmp_file, progress_file)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # removes the partial file left behind by a failed write.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
def find_images_in_subfolders(folder_path):
    """Recursively collect image file paths under ``folder_path``.

    A file counts as an image when its name ends with ``.png``, ``.jpg``,
    ``.jpeg``, ``.gif`` or ``.bmp``, compared case-insensitively so that
    names such as ``photo.JPG`` are included.

    Directories and files are visited in sorted order, so the returned list
    is the same on every call for the same tree; ``os.walk`` itself does not
    guarantee an order.

    Args:
        folder_path: Root directory to search.

    Returns:
        list[str]: Paths built by joining each walked directory with the
        matching file name. Empty if ``folder_path`` does not exist or holds
        no images.
    """
    image_extensions = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
    image_files = []
    for root, dirs, files in os.walk(folder_path):
        # Sorting in place steers the order in which os.walk descends.
        dirs.sort()
        image_files.extend(
            os.path.join(root, name)
            for name in sorted(files)
            if name.lower().endswith(image_extensions)
        )
    return image_files
def main(args):
    """Caption this rank's shard of the dataset, with resumable progress.

    Reads the JSON list at ``args.json_path`` (each record must have an
    ``"image"`` key), resolves the images against ``args.images_root``,
    optionally truncates the list to ``args.limit`` entries, and splits it
    into ``args.total_dist`` contiguous shards; the last shard also takes the
    remainder. The shard selected by ``args.dist`` is captioned in batches of
    ``args.parallel`` images via ``image_description.run_batch``.

    Progress is kept in ``progress_<dist>_or_<total_dist>.json`` next to the
    result file. On restart, work resumes at the first image after the stored
    ``last_index``, independent of how the previous run's batches were laid
    out. The final output is written next to ``args.result_file`` with a
    ``_shard_<dist>_or_<total_dist>`` suffix inserted before the extension.

    Args:
        args: Parsed command-line namespace. This function reads ``dist``,
            ``total_dist``, ``result_file``, ``json_path``, ``images_root``,
            ``limit``, ``parallel`` and ``backend``; the whole namespace is
            also passed to ``select_sglang_backend``.

    Raises:
        ValueError: If ``total_dist`` is not positive or ``dist`` is outside
            ``[0, total_dist)``.
    """
    dist_rank = args.dist
    dist_size = args.total_dist
    if dist_size <= 0 or not 0 <= dist_rank < dist_size:
        raise ValueError(
            f"--dist must satisfy 0 <= dist < total_dist, got dist={dist_rank}, total_dist={dist_size}"
        )

    # dirname() is "" for a bare file name and os.makedirs("") raises, so fall
    # back to the current directory in that case.
    base_dir = os.path.dirname(args.result_file) or "."
    os.makedirs(base_dir, exist_ok=True)  # Ensure the base directory exists
    progress_file = f"{base_dir}/progress_{dist_rank}_or_{dist_size}.json"
    progress_data = load_progress(progress_file)
    annotations = progress_data.setdefault("annotations", [])

    with open(args.json_path, "r") as fp:
        data = json.load(fp)

    image_files = [os.path.join(args.images_root, item["image"]) for item in data]
    if args.limit > 0:
        image_files = image_files[: args.limit]

    # Shard the data: equal contiguous slices, the last rank takes the remainder.
    shard_size = len(image_files) // dist_size
    start_index = shard_size * dist_rank
    end_index = start_index + shard_size if dist_rank < dist_size - 1 else len(image_files)
    shard_files = image_files[start_index:end_index]
    print(f"Querying {len(shard_files)} images from index {start_index} to {end_index - 1}")

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # NOTE(review): args.max_tokens is not read in this function; the token
    # budget used for generation is the one set inside image_description.

    # Resume at item granularity. last_index is the shard-relative index of
    # the last captioned image, so the next image to process is last_index + 1
    # even if the previous run stopped mid-batch or used another --parallel.
    resume_from = progress_data.get("last_index", -1) + 1
    if resume_from > 0:
        print(f"Resuming at shard index {resume_from}; earlier images are already captioned")

    tic = time.time()
    batch_size = args.parallel
    for batch_start in tqdm.tqdm(range(resume_from, len(shard_files), batch_size)):
        batch_end = min(batch_start + batch_size, len(shard_files))
        batch_arguments = [{"image_file": image_file} for image_file in shard_files[batch_start:batch_end]]
        try:
            batch_states = image_description.run_batch(
                batch_arguments,
                temperature=0,
                num_threads=args.parallel,
                progress_bar=False,
            )
            for i, ret in enumerate(batch_states):
                image_file = batch_arguments[i]["image_file"]
                caption = ret.text().split("ASSISTANT:")[-1].strip()
                annotations.append({"image_file": image_file, "caption": caption})
                progress_data["last_index"] = batch_start + i  # Relative to this rank's shard
        except Exception as e:
            # Keep whatever was captioned before the failure, then stop so the
            # partial results below are still written out.
            print(f"Error during batch processing: {e}")
            save_progress(progress_file, progress_data)
            break
        # One write per batch: the record holds every annotation so far, so
        # rewriting it after each image would grow quadratically.
        save_progress(progress_file, progress_data)

    latency = time.time() - tic
    print(f"Latency: {latency:.3f}")

    value = {
        "task": "image_captioning",
        "backend": args.backend,
        "num_gpus": 1,
        "latency": round(latency, 3),
        "num_requests": len(shard_files),
        "parallel": args.parallel,
        "results": annotations,
    }

    # Insert the shard suffix before the extension only. A plain
    # str.replace(".json", ...) would also rewrite ".json" elsewhere in the
    # path and would add no suffix at all for other extensions, letting ranks
    # overwrite each other's output.
    result_root, result_ext = os.path.splitext(args.result_file)
    result_file = f"{result_root}_shard_{dist_rank}_or_{dist_size}{result_ext}"
    print(f"Write output to {result_file}")
    with open(result_file, "w") as fout:
        json.dump(value, fout, indent=2)

    save_progress(progress_file, progress_data)
if __name__ == "__main__":
    # Command-line options, registered in this order:
    # (flag, keyword arguments for add_argument).
    option_specs = (
        # Dataset inputs
        ("--images_root", {"type": str, "default": "/mnt/bn/vl-research/data/llava_data/cc3m"}),
        (
            "--json_path",
            {"type": str, "default": "/mnt/bn/vl-research/data/llava_instruct/cc3m_recap_requery_363707.json"},
        ),
        # Generation / batching
        ("--max_tokens", {"type": int, "default": 1024}),
        ("--parallel", {"type": int, "default": 32}),
        # Backend connection
        ("--backend", {"type": str, "default": "srt"}),
        ("--host", {"type": str, "default": "http://127.0.0.1"}),
        ("--port", {"type": int, "default": 30000}),
        # Output
        (
            "--result_file",
            {
                "type": str,
                "default": "/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/playground/sgl_llava_inference.json",
            },
        ),
        ("--limit", {"type": int, "default": -1}),
        # Sharding across machines
        ("--dist", {"type": int, "default": 0, "help": "The rank of the distributed machine"}),
        ("--total_dist", {"type": int, "default": 6, "help": "Total number of distributed machines"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()
    main(args)