VirginiaZane commited on
Commit
e3a4271
·
verified ·
1 Parent(s): a907e5b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+ from diffusers import AutoPipelineForImage2Image
9
+
10
+ # -------------------------
11
+ # 1) 年龄估计模型(HF 上可用)
12
+ # 说明:本示例使用 Hugging Face 的 ViT 年龄估计模型。
13
+ # 我们把分类标签转换成年龄(若是"0-2"取区间中点;若是"23"就取整)。
14
+ # -------------------------
15
+ AGE_MODEL_ID = "nateraw/vit-age-classifier"
16
+
17
+ age_processor = AutoImageProcessor.from_pretrained(AGE_MODEL_ID)
18
+ age_model = AutoModelForImageClassification.from_pretrained(AGE_MODEL_ID)
19
+ age_model.eval()
20
+
21
+ def _label_to_age(label: str) -> float:
22
+ # 尝试解析类似 "(0-2)"、"0-2"、"3-9" 的标签
23
+ label = label.strip().replace("(", "").replace(")", "")
24
+ if "-" in label:
25
+ a, b = label.split("-")
26
+ try:
27
+ return (float(a) + float(b)) / 2.0
28
+ except:
29
+ pass
30
+ # 若是单值,如 "23"
31
+ try:
32
+ return float(label)
33
+ except:
34
+ # 兜底:无法解析就返回 NaN
35
+ return float("nan")
36
+
37
+ @torch.inference_mode()
38
+ def estimate_age(image: Image.Image) -> dict:
39
+ inputs = age_processor(images=image, return_tensors="pt")
40
+ logits = age_model(**inputs).logits
41
+ probs = torch.softmax(logits, dim=-1)[0]
42
+
43
+ # 取 top-5 以便展示
44
+ id2label = age_model.config.id2label
45
+ topk = torch.topk(probs, k=min(5, probs.shape[0]))
46
+ items = []
47
+ ages = []
48
+ for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
49
+ label = id2label[idx]
50
+ age = _label_to_age(label)
51
+ ages.append((age, score))
52
+ items.append(f"{label}: {score*100:.1f}%")
53
+
54
+ # 期望年龄(加权平均)
55
+ ages_valid = [(a, p) for a, p in ages if not math.isnan(a)]
56
+ if ages_valid:
57
+ num = sum(a * p for a, p in ages_valid)
58
+ den = sum(p for _, p in ages_valid)
59
+ expected_age = num / den
60
+ else:
61
+ expected_age = float("nan")
62
+
63
+ return {
64
+ "expected_age": None if math.isnan(expected_age) else round(expected_age, 1),
65
+ "top5": items
66
+ }
67
+
68
+ # -------------------------
69
+ # 2) 漫画风格生成(img2img)
70
+ # 说明:使用 "stabilityai/sd-turbo" 的图生图,速度较快,提示词主打漫画/卡通风。
71
+ # CPU 也能跑,但较慢;有 GPU(T4/A10)体验最佳。
72
+ # -------------------------
73
+ IMG2IMG_MODEL_ID = "stabilityai/sd-turbo"
74
+
75
+ device = "cuda" if torch.cuda.is_available() else "cpu"
76
+ dtype = torch.float16 if device == "cuda" else torch.float32
77
+
78
+ pipe = AutoPipelineForImage2Image.from_pretrained(
79
+ IMG2IMG_MODEL_ID,
80
+ torch_dtype=dtype
81
+ )
82
+ pipe = pipe.to(device)
83
+
84
+ DEFAULT_PROMPT = (
85
+ "comic style, manga, cel-shaded, bold ink outlines, clean lineart, high contrast, "
86
+ "professional illustration, vibrant"
87
+ )
88
+ NEG_PROMPT = "realistic, photorealistic, blurry, noisy, artifacts, watermark, text"
89
+
90
+ @torch.inference_mode()
91
+ def stylize_to_comic(
92
+ image: Image.Image,
93
+ prompt: str = DEFAULT_PROMPT,
94
+ strength: float = 0.6,
95
+ guidance_scale: float = 0.0,
96
+ steps: int = 4,
97
+ seed: int | None = 42
98
+ ) -> Image.Image:
99
+ if seed is None or seed < 0:
100
+ generator = None
101
+ else:
102
+ generator = torch.Generator(device=device).manual_seed(seed)
103
+
104
+ image = image.convert("RGB")
105
+ out = pipe(
106
+ prompt=prompt,
107
+ negative_prompt=NEG_PROMPT,
108
+ image=image,
109
+ strength=strength, # 0.2~0.7:数值越大改动越明显
110
+ num_inference_steps=steps, # sd-turbo 推荐极少步数(2~6)
111
+ guidance_scale=guidance_scale, # sd-turbo 常用 0~1
112
+ generator=generator
113
+ )
114
+ return out.images[0]
115
+
116
+ # -------------------------
117
+ # 3) Gradio 界面
118
+ # -------------------------
119
+ def process(image, prompt, strength, guidance, steps, seed):
120
+ if image is None:
121
+ return "请先上传图片", None
122
+
123
+ age_result = estimate_age(image)
124
+ styled = stylize_to_comic(
125
+ image=image,
126
+ prompt=prompt,
127
+ strength=strength,
128
+ guidance_scale=guidance,
129
+ steps=int(steps),
130
+ seed=int(seed) if seed is not None else 42
131
+ )
132
+
133
+ # 结果文字
134
+ if age_result["expected_age"] is None:
135
+ age_text = "年龄估计:解析失败(可能检测不到年龄标签)"
136
+ else:
137
+ age_text = f"年龄估计:≈ {age_result['expected_age']} 岁\nTop-5: " + " | ".join(age_result["top5"])
138
+
139
+ return age_text, styled
140
+
141
+ with gr.Blocks(title="Age & Comicify Agent") as demo:
142
+ gr.Markdown("# 🧠 Age & Comicify Agent\n上传图片 → 年龄估计 → 漫画风格生成")
143
+ with gr.Row():
144
+ with gr.Column(scale=1):
145
+ in_img = gr.Image(label="上传图片", type="pil")
146
+ prompt = gr.Textbox(label="风格提示词", value=DEFAULT_PROMPT)
147
+ strength = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="风格强度(strength)")
148
+ guidance = gr.Slider(0.0, 3.0, value=0.0, step=0.1, label="引导系数(guidance_scale)")
149
+ steps = gr.Slider(2, 12, value=4, step=1, label="步数(num_inference_steps)")
150
+ seed = gr.Number(value=42, precision=0, label="随机种子(固定可复现)")
151
+ run_btn = gr.Button("🚀 运行")
152
+ with gr.Column(scale=1):
153
+ age_txt = gr.Textbox(label="年龄估计结果")
154
+ out_img = gr.Image(label="漫画风格输出")
155
+
156
+ run_btn.click(
157
+ fn=process,
158
+ inputs=[in_img, prompt, strength, guidance, steps, seed],
159
+ outputs=[age_txt, out_img]
160
+ )
161
+
162
+ if __name__ == "__main__":
163
+ demo.launch()