# hqclip / app.py
# Source: Hugging Face file view — author zhixiangwei, commit eea177e ("x"), 6.17 kB
import functools

import gradio as gr
import numpy as np
import open_clip
import torch
from PIL import Image
# --- 1. Model initialization ---
# Both models (and their transforms/tokenizers) are created once at import time
# and reused by every request.
HQ_CLIP_NAME = "hf-hub:zhixiangwei/hqclip-openai-large-ft-vlm1b"

print("Loading HQ-CLIP model...")
model_hq, _, preprocess_hq = open_clip.create_model_and_transforms(HQ_CLIP_NAME)
tokenizer_hq = open_clip.get_tokenizer(HQ_CLIP_NAME)
print("HQ-CLIP model loaded.")

print("Loading standard OpenAI CLIP model...")
model_openai, _, preprocess_openai = open_clip.create_model_and_transforms(
    "ViT-L-14-quickgelu", "openai"
)
tokenizer_openai = open_clip.get_tokenizer("ViT-L-14-quickgelu")
print("OpenAI CLIP model loaded.")

# Inference is deliberately pinned to the CPU. To use a GPU when one is
# available, replace this with:
#     device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

# Per the open_clip README, models are returned in train mode; eval() switches
# off train-only behaviour (e.g. patch dropout / stochastic depth if configured)
# so inference uses the intended eval-mode forward pass.
model_hq.to(device).eval()
model_openai.to(device).eval()
print(f"Models moved to {device}.")
# --- 2. Core logic ---
def _clip_probabilities(model, preprocess, tokenizer, pil_image, text_inputs):
    """Return softmax probabilities of ``pil_image`` matching each caption.

    Args:
        model: an open_clip model exposing ``encode_image`` / ``encode_text``.
        preprocess: the image transform created together with ``model``.
        tokenizer: the open_clip tokenizer matching ``model``.
        pil_image: the PIL image to score.
        text_inputs: non-empty list of caption strings.

    Returns:
        1-D numpy array with one probability per caption, in input order.
    """
    img_tensor = preprocess(pil_image).unsqueeze(0).to(device)
    # open_clip tokenizers accept a list and return one row per string.
    tokenized_texts = tokenizer(text_inputs).to(device)
    with torch.no_grad():
        image_features = model.encode_image(img_tensor)
        text_features = model.encode_text(tokenized_texts)
        # L2-normalise so the dot product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # The same fixed scale of 100.0 is applied to both models.
        similarities = (100.0 * image_features @ text_features.T).squeeze(0)
        probs = torch.softmax(similarities, dim=-1)
    return probs.cpu().numpy()


def calculate_similarities(image, texts_str):
    """Compare HQ-CLIP and OpenAI CLIP on one image and a set of captions.

    Args:
        image: numpy image array (as delivered by ``gr.Image(type="numpy")``),
            or ``None`` when no image has been provided.
        texts_str: newline-separated captions. Blank lines are ignored and a
            repeated caption is scored only once.

    Returns:
        ``(hq_results, openai_results)``: two ``{caption: probability}`` dicts
        for the two ``gr.Label`` outputs, or ``(None, None)`` (which clears
        both labels) when the image or the captions are missing.
    """
    # One caption per non-blank line. dict.fromkeys drops duplicates while
    # keeping order; a Label dict can only show each caption once, and
    # duplicates would otherwise dilute the softmax.
    texts = list(
        dict.fromkeys(
            line.strip() for line in (texts_str or "").splitlines() if line.strip()
        )
    )
    # Check for None explicitly (not image.any()) so an empty initial state
    # does not raise.
    if image is None or not texts:
        # Arity must match the two outputs wired to this function.
        return None, None

    # Build the PIL image once and share it between both models.
    pil_image = Image.fromarray(np.asarray(image).astype("uint8"))
    probs_hq = _clip_probabilities(
        model_hq, preprocess_hq, tokenizer_hq, pil_image, texts
    )
    probs_openai = _clip_probabilities(
        model_openai, preprocess_openai, tokenizer_openai, pil_image, texts
    )

    # Plain Python floats keep the dicts JSON-serialisable for gr.Label.
    hq_results = {text: float(p) for text, p in zip(texts, probs_hq)}
    openai_results = {text: float(p) for text, p in zip(texts, probs_openai)}
    return hq_results, openai_results
# --- 3. Gradio interface (with a default example loaded on start-up) ---
# Step 1: keep the example data in one variable so gr.Examples and the
# start-up loader share it. Each entry is [image_path, newline-separated captions].
examples_list = [
    [
        "examples/mnls.jpeg",
        "\n".join(
            [
                "An oil painting of a smiling, long-haired woman",
                "An oil painting of a sad, long-haired woman",
                "A sketch of a smiling, long-haired woman",
                "A photo of a smiling, long-haired woman",
            ]
        ),
    ],
    [
        "examples/su7s.jpg",
        "\n".join(
            [
                "a blue car, with black wheels",
                "a blue car, with blue wheels",
                "a black car, with blue wheels",
                "a black car, with black wheels",
            ]
        ),
    ],
]
# Step 2: the function run when the app loads. Its result is computed once and
# cached, so page loads after the first do not re-run both models on the CPU.
@functools.lru_cache(maxsize=1)
def _compute_default_example():
    """Load the first entry of ``examples_list`` and score it with both models.

    Returns:
        Tuple ``(image, texts_str, hq_results, openai_results)`` matching the
        outputs wired to ``demo.load``. lru_cache does not cache exceptions,
        so a failed load is retried on the next call.
    """
    print("Loading default example...")
    image_path, texts_str = examples_list[0]
    # gr.Image converts uploads to RGB by default (image_mode="RGB"); do the
    # same here. The context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        image = np.array(img.convert("RGB"))
    return (image, texts_str, *calculate_similarities(image, texts_str))


def load_default_example():
    """Return ``[image, texts_str, hq_results, openai_results]`` for the first example."""
    # A fresh list per call so callers never hold the cached tuple itself.
    return list(_compute_default_example())
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="HQ-CLIP vs OpenAI CLIP") as demo:
    # The HTML below is kept flush-left inside the string: leading spaces
    # could make Markdown render it as a code block.
    gr.Markdown(
        """
<div style="text-align: center;">
<h1>🎨 HQ-CLIP vs. OpenAI CLIP: A Visual Comparison</h1>
<p>Upload an image and provide text descriptions (one per line) to compare models.</p>
</div>
"""
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            image_input = gr.Image(label="🖼️ Upload Your Image", type="numpy")
            text_input = gr.Textbox(
                label="📝 Enter Text Descriptions (one per line)",
                placeholder="e.g., a cat by the window\na dog in the yard\na sleeping kitten",
                lines=4,  # multi-line input: one caption per line
            )
            submit_btn = gr.Button("🔍 Analyze & Compare", variant="primary")
        # Right column: one probability Label per model (no num_top_classes,
        # so every caption is shown).
        with gr.Column():
            with gr.Row():
                hq_label = gr.Label(label="HQ-CLIP Results")
                openai_label = gr.Label(label="OpenAI CLIP Results")

    gr.Markdown("---")
    gr.Markdown("### ✨ Try Some Examples")
    gr.Examples(
        examples=examples_list,
        inputs=[image_input, text_input],
        outputs=[hq_label, openai_label],
        fn=calculate_similarities,
        # Without example caching, clicking an example only fills the inputs;
        # run_on_click makes the click also run `fn`, as the label promises.
        run_on_click=True,
        label="Click an example to run the analysis",
    )

    submit_btn.click(
        fn=calculate_similarities,
        inputs=[image_input, text_input],
        outputs=[hq_label, openai_label],
    )

    # Step 3: when the page loads, fill the inputs and both result labels
    # with the default example.
    demo.load(
        fn=load_default_example,
        inputs=None,  # no inputs are needed on load
        outputs=[image_input, text_input, hq_label, openai_label],
    )

if __name__ == "__main__":
    # share=True additionally requests a public share link when run locally.
    demo.launch(share=True)