Spaces:
Running
Running
import gradio as gr | |
import open_clip | |
import torch | |
import numpy as np | |
from PIL import Image | |
# --- 1. Model Initialization ---
# Both CLIP variants are loaded once at import time and shared by every
# request handled by the app.
_HQ_CLIP_HUB_ID = 'hf-hub:zhixiangwei/hqclip-openai-large-ft-vlm1b'

print("Loading HQ-CLIP model...")
model_hq, _, preprocess_hq = open_clip.create_model_and_transforms(_HQ_CLIP_HUB_ID)
tokenizer_hq = open_clip.get_tokenizer(_HQ_CLIP_HUB_ID)
print("HQ-CLIP model loaded.")

print("Loading standard OpenAI CLIP model...")
model_openai, _, preprocess_openai = open_clip.create_model_and_transforms('ViT-L-14-quickgelu', 'openai')
tokenizer_openai = open_clip.get_tokenizer('ViT-L-14-quickgelu')
print("OpenAI CLIP model loaded.")

# Inference is deliberately pinned to the CPU. To use a GPU when one is
# present, replace this with: "cuda" if torch.cuda.is_available() else "cpu".
device = 'cpu'
model_hq.to(device)
model_openai.to(device)
# The models are only used for inference, so put them in evaluation mode.
model_hq.eval()
model_openai.eval()
print(f"Models moved to {device}.")
# --- 2. Core Logic: Refactored for Simplicity --- | |
def calculate_similarities(image, texts_str):
    """Score an image against candidate captions with both CLIP models.

    Args:
        image: The image as a NumPy array, or ``None`` when no image has
            been provided.
        texts_str: Candidate descriptions, one per line. Surrounding
            whitespace, blank lines and repeated descriptions are ignored.
            ``None`` is treated like an empty string.

    Returns:
        A ``(hq_results, openai_results)`` pair. Each element is a dict
        mapping a description to its softmax probability (a plain Python
        ``float``), suitable for a ``gr.Label`` component. Both elements are
        ``None`` when the image or the descriptions are missing, so the
        result always has exactly two values.
    """
    # Parse the newline-separated string into a list of descriptions.
    texts = [line.strip() for line in (texts_str or "").split('\n') if line.strip()]
    # Drop repeated descriptions (keeping first-seen order): they would split
    # the softmax mass between copies but collapse into one dict key below.
    texts = list(dict.fromkeys(texts))

    # Compare against None explicitly; a NumPy array has no single truth value.
    if image is None or not texts:
        return None, None

    def get_probs(model, preprocess, tokenizer, img_input, text_inputs):
        """Return one softmax probability per text for the given model."""
        img = Image.fromarray(img_input.astype('uint8'))
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        tokenized_texts = torch.cat([tokenizer(text) for text in text_inputs]).to(device)
        with torch.no_grad():
            image_features = model.encode_image(img_tensor)
            text_features = model.encode_text(tokenized_texts)
            # L2-normalise so the dot product below is a cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            # Similarities are scaled by 100 before the softmax.
            similarities = (100.0 * image_features @ text_features.T).squeeze(0)
            probs = torch.softmax(similarities, dim=-1)
        # tolist() yields plain Python floats rather than NumPy scalars.
        return probs.cpu().tolist()

    probs_hq = get_probs(model_hq, preprocess_hq, tokenizer_hq, image, texts)
    probs_openai = get_probs(model_openai, preprocess_openai, tokenizer_openai, image, texts)

    hq_results = dict(zip(texts, probs_hq))
    openai_results = dict(zip(texts, probs_openai))
    return hq_results, openai_results
# --- 3. Gradio Interface: Rebuilt with Default Loading --- | |
# 步骤 1: 将示例数据提取到变量中 | |
# Each example is [image path, newline-separated candidate descriptions].
examples_list = [
    [
        "examples/mnls.jpeg",
        "\n".join([
            "An oil painting of a smiling, long-haired woman",
            "An oil painting of a sad, long-haired woman",
            "A sketch of a smiling, long-haired woman",
            "A photo of a smiling, long-haired woman",
        ]),
    ],
    [
        "examples/su7s.jpg",
        "\n".join([
            "a blue car, with black wheels",
            "a blue car, with blue wheels",
            "a black car, with blue wheels",
            "a black car, with black wheels",
        ]),
    ],
]
# 步骤 2: 创建一个在应用加载时运行的函数 | |
def load_default_example():
    """Load the first bundled example and compute its results.

    Returns:
        A list holding the example image as a NumPy array, the example's
        text descriptions, and then every value returned by
        ``calculate_similarities`` for that image and text.
    """
    print("Loading default example...")
    # The first entry of the examples list is used as the default.
    default_path, default_texts = examples_list[0]
    # Read the image file into a NumPy array.
    pixels = np.array(Image.open(default_path))
    # Run the main scoring function on the default inputs.
    label_values = calculate_similarities(pixels, default_texts)
    return [pixels, default_texts, *label_values]
# Build the UI: inputs on the left, the two models' label outputs on the
# right, and clickable examples underneath.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="HQ-CLIP vs OpenAI CLIP") as demo:
    gr.Markdown(
        """
<div style="text-align: center;">
<h1>🎨 HQ-CLIP vs. OpenAI CLIP: A Visual Comparison</h1>
<p>Upload an image and provide text descriptions (one per line) to compare models.</p>
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="🖼️ Upload Your Image", type="numpy")
            text_input = gr.Textbox(
                label="📝 Enter Text Descriptions (one per line)",
                placeholder="e.g., a cat by the window\na dog in the yard\na sleeping kitten",
                lines=4,  # allow multi-line input
            )
            submit_btn = gr.Button("🔍 Analyze & Compare", variant="primary")
        with gr.Column():
            with gr.Row():
                # num_top_classes is not set, so each label shows every result.
                hq_label = gr.Label(label="HQ-CLIP Results")
                openai_label = gr.Label(label="OpenAI CLIP Results")
    gr.Markdown("---")
    gr.Markdown("### ✨ Try Some Examples")
    gr.Examples(
        examples=examples_list,
        inputs=[image_input, text_input],
        outputs=[hq_label, openai_label],
        fn=calculate_similarities,
        # When examples are not cached, clicking one only fills the inputs
        # unless run_on_click is set; enable it so the label below holds.
        run_on_click=True,
        label="Click an example to run the analysis"
    )
    submit_btn.click(
        fn=calculate_similarities,
        inputs=[image_input, text_input],
        outputs=[hq_label, openai_label]
    )
    # When the app loads, show the default example together with its results.
    demo.load(
        fn=load_default_example,
        inputs=None,  # no inputs are needed on load
        outputs=[image_input, text_input, hq_label, openai_label]
    )

if __name__ == "__main__":
    demo.launch(share=True)