# genshin_impact_ccip / run_video_ccip.py
# svjack's picture
# Update run_video_ccip.py
# ef60e7a verified
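'''
Usage example:
python run_video_ccip.py Beyond_the_Boundary_Videos_sm Beyond_the_Boundary_Videos_sm_named --image_dir named_image_dir

Example notebook snippet that reads the generated JSON files, keeps the "Same" predictions
with difference <= 0.1, and copies the matching videos and captions into tgt_dir: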
import pandas as pd
import pathlib
import json
def read_j(x):
with open(x, "r") as f:
return json.load(f)
path_s = pd.Series(list(pathlib.Path("Beyond_the_Boundary_Videos_sm_named/").rglob("*.json"))).map(str)
df = pd.DataFrame(path_s.head(int(1e10)).map(
lambda x: (x, read_j(x))
).values.tolist()
).explode(1).applymap(
lambda x: x["results"] if type(x) == type({}) else x
).explode(1)
df
right_df = pd.json_normalize(df[1])
df = pd.concat([df.reset_index().iloc[:, 1:], right_df.reset_index().iloc[:,1:]], axis = 1)
df = df[
df["prediction"] == "Same"
]
###df[0].sort_values().drop_duplicates()
df
!git clone https://huggingface.co/datasets/svjack/Beyond_the_Boundary_Videos_Captioned
import os
from shutil import copy2
s = df[
df["difference"] <= 0.1
][0].sort_values().map(
lambda x: x.replace("_named", "").replace(".json", ".mp4")
)
import pathlib
import numpy as np
all_paths_mp4 = pd.Series(list(pathlib.Path("Beyond_the_Boundary_Videos_Captioned").rglob("*.mp4"))).map(str).map(
lambda x: x if any(map(lambda y: x.endswith(y.split("/")[-1]), s.values.tolist())) else np.nan
).dropna()
all_paths_txt = all_paths_mp4.map(lambda x: x.replace(".mp4", ".txt")).map(lambda x: x if os.path.exists(x) else np.nan).dropna()
os.makedirs("tgt_dir", exist_ok=True)
for ele in all_paths_mp4.values.tolist() + all_paths_txt.values.tolist():
copy2(ele, os.path.join("tgt_dir", ele.split("/")[-1]))

Another usage example:
python run_video_ccip.py Beyond_the_Boundary_Videos Beyond_the_Boundary_Videos_named --image_dir named_image_dir
'''
import os
import json
from tqdm import tqdm
from PIL import Image
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
import pathlib
import argparse
from moviepy.editor import VideoFileClip
def load_images_from_directory(image_dir):
    """
    Load every reference image found under ``image_dir`` into a dictionary.

    The directory is searched recursively for ``*.png``, ``*.jpg``,
    ``*.jpeg`` and ``*.webp`` files, in that order.

    Args:
        image_dir: Directory that contains the reference images.

    Returns:
        dict: Maps each file name without its extension to the loaded
        ``PIL.Image`` object. When several files share the same name, the
        one encountered last replaces the earlier ones.
    """
    root = pathlib.Path(image_dir)
    image_paths = []
    for pattern in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
        image_paths.extend(root.rglob(pattern))

    name_image_dict = {}
    for image_path in tqdm(image_paths, desc="Loading images"):
        # Image.open() is lazy and keeps the underlying file open until the
        # pixel data is read. Read it right away and release the handle so a
        # large reference set cannot exhaust the process's file descriptors.
        with Image.open(image_path) as image:
            image.load()
        # File name without the extension is the lookup key.
        name_image_dict[image_path.stem] = image
    return name_image_dict
def _compare_with_dataset(imagex, model_name, name_image_dict):
    """
    Compare one image against every reference image in ``name_image_dict``.

    Args:
        imagex: Image to compare (passed as the first argument of
            ``ccip_difference``).
        model_name: Model name handed to ``ccip_default_threshold`` to obtain
            the "Same" / "Not Same" cut-off.
        name_image_dict: Mapping of reference name to reference image.

    Returns:
        list[dict]: One entry per reference image with the keys
        ``"difference"``, ``"prediction"`` and ``"name"``, ordered by
        ascending ``"difference"``.
    """
    same_cutoff = ccip_default_threshold(model_name)

    def _score(ref_name, ref_image):
        # NOTE(review): model_name only selects the threshold; it is not
        # forwarded to ccip_difference. Confirm this is intended.
        distance = ccip_difference(imagex, ref_image)
        label = 'Same' if distance <= same_cutoff else 'Not Same'
        return {
            "difference": distance,
            "prediction": label,
            "name": ref_name,
        }

    scored = [_score(ref_name, ref_image) for ref_name, ref_image in name_image_dict.items()]
    # Closest matches first.
    return sorted(scored, key=lambda entry: entry["difference"])
def process_video(video_path, model_name, output_dir, max_frames, name_image_dict):
    """
    Sample frames from a video, compare each with the reference images and
    write the results to ``<output_dir>/<video name>.json``.

    Args:
        video_path: Path of the video file to open.
        model_name: Model name forwarded to ``_compare_with_dataset``.
        output_dir: Directory that receives the JSON file (must exist).
        max_frames: Upper bound on the number of frames sampled from the
            video; must be at least 1.
        name_image_dict: Mapping of reference name to reference image.

    Raises:
        ValueError: If ``max_frames`` is smaller than 1.

    The JSON file holds a list with one object per sampled frame, each with
    the keys ``"frame_time"`` (seconds) and ``"results"``.
    """
    if max_frames < 1:
        # Previously 0 crashed with ZeroDivisionError and negative values
        # silently processed every frame; reject both up front.
        raise ValueError("max_frames must be a positive integer.")

    # Output file is named after the video, without its extension.
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    output_file = os.path.join(output_dir, f"{video_name}.json")

    clip = VideoFileClip(video_path)
    try:
        fps = clip.fps
        total_frames = int(clip.duration * fps)

        # Evenly spaced sampling step.
        frame_interval = max(1, total_frames // max_frames)
        # Floor division can leave room for extra samples (10 frames with
        # max_frames=3 gives step 3 and indices 0, 3, 6, 9), so truncate to
        # honour the documented maximum.
        frame_indices = range(0, total_frames, frame_interval)[:max_frames]

        results = []
        for i in tqdm(frame_indices, desc="Processing frames"):
            frame = clip.get_frame(i / fps)
            image = Image.fromarray(frame)
            frame_results = _compare_with_dataset(image, model_name, name_image_dict)
            results.append({
                "frame_time": i / fps,
                "results": frame_results
            })
    finally:
        # Release the video reader even if a frame fails to process.
        clip.close()

    # Save the results as JSON.
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=4)
def main():
    """
    Command-line entry point.

    Parses the arguments, creates the output directory, loads the reference
    images, collects the videos to process (a single file, or every ``*.mp4``
    and ``*.avi`` found recursively in a directory) and runs
    ``process_video`` on each of them.

    Raises:
        ValueError: If ``input_path`` is neither an existing file nor an
            existing directory.
    """
    cli = argparse.ArgumentParser(description="Compare videos with a dataset and save results as JSON.")
    cli.add_argument("input_path", type=str, help="Path to the input video or directory containing videos.")
    cli.add_argument("output_dir", type=str, help="Directory to save the output JSON files.")
    cli.add_argument("--image_dir", type=str, required=True, help="Directory containing images to compare with.")
    cli.add_argument("--model", type=str, default=_DEFAULT_MODEL_NAMES, choices=_VALID_MODEL_NAMES, help="Model to use for comparison.")
    cli.add_argument("--max_frames", type=int, default=3, help="Maximum number of frames to process per video.")
    options = cli.parse_args()

    # Make sure the output directory exists before anything is written.
    os.makedirs(options.output_dir, exist_ok=True)

    # Load the reference image set once; it is shared by every video.
    name_image_dict = load_images_from_directory(options.image_dir)

    # Work out which videos to process: one file, or a directory tree.
    source = options.input_path
    if os.path.isfile(source):
        video_paths = [source]
    elif os.path.isdir(source):
        source_root = pathlib.Path(source)
        video_paths = [*source_root.rglob("*.mp4"), *source_root.rglob("*.avi")]
    else:
        raise ValueError("Input path must be a valid file or directory.")

    # process_video expects plain string paths.
    video_paths = [str(path) for path in video_paths]

    for video_path in tqdm(video_paths, desc="Processing videos"):
        process_video(video_path, options.model, options.output_dir, options.max_frames, name_image_dict)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()