VirginiaZane commited on
Commit
b7380ce
·
verified ·
1 Parent(s): 5026aa5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -46
app.py CHANGED
@@ -8,18 +8,14 @@ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
  from diffusers import AutoPipelineForImage2Image
9
 
10
  # -------------------------
11
- # 1) 年龄估计模型(HF 上可用)
12
- # 说明:本示例使用 Hugging Face 的 ViT 年龄估计模型。
13
- # 我们把分类标签转换成年龄(若是"0-2"取区间中点;若是"23"就取整)。
14
  # -------------------------
15
  AGE_MODEL_ID = "nateraw/vit-age-classifier"
16
-
17
  age_processor = AutoImageProcessor.from_pretrained(AGE_MODEL_ID)
18
  age_model = AutoModelForImageClassification.from_pretrained(AGE_MODEL_ID)
19
  age_model.eval()
20
 
21
  def _label_to_age(label: str) -> float:
22
- # 尝试解析类似 "(0-2)"、"0-2"、"3-9" 的标签
23
  label = label.strip().replace("(", "").replace(")", "")
24
  if "-" in label:
25
  a, b = label.split("-")
@@ -27,11 +23,9 @@ def _label_to_age(label: str) -> float:
27
  return (float(a) + float(b)) / 2.0
28
  except:
29
  pass
30
- # 若是单值,如 "23"
31
  try:
32
  return float(label)
33
  except:
34
- # 兜底:无法解析就返回 NaN
35
  return float("nan")
36
 
37
  @torch.inference_mode()
@@ -39,19 +33,16 @@ def estimate_age(image: Image.Image) -> dict:
39
  inputs = age_processor(images=image, return_tensors="pt")
40
  logits = age_model(**inputs).logits
41
  probs = torch.softmax(logits, dim=-1)[0]
42
-
43
- # 取 top-5 以便展示
44
  id2label = age_model.config.id2label
 
45
  topk = torch.topk(probs, k=min(5, probs.shape[0]))
46
- items = []
47
- ages = []
48
  for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
49
  label = id2label[idx]
50
  age = _label_to_age(label)
51
  ages.append((age, score))
52
  items.append(f"{label}: {score*100:.1f}%")
53
 
54
- # 期望年龄(加权平均)
55
  ages_valid = [(a, p) for a, p in ages if not math.isnan(a)]
56
  if ages_valid:
57
  num = sum(a * p for a, p in ages_valid)
@@ -67,19 +58,15 @@ def estimate_age(image: Image.Image) -> dict:
67
 
68
  # -------------------------
69
  # 2) 漫画风格生成(img2img)
70
- # 说明:使用 "stabilityai/sd-turbo" 的图生图,速度较快,提示词主打漫画/卡通风。
71
- # CPU 也能跑,但较慢;有 GPU(T4/A10)体验最佳。
72
  # -------------------------
73
  IMG2IMG_MODEL_ID = "stabilityai/sd-turbo"
74
-
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
76
  dtype = torch.float16 if device == "cuda" else torch.float32
77
 
78
  pipe = AutoPipelineForImage2Image.from_pretrained(
79
  IMG2IMG_MODEL_ID,
80
  torch_dtype=dtype
81
- )
82
- pipe = pipe.to(device)
83
 
84
  DEFAULT_PROMPT = (
85
  "comic style, manga, cel-shaded, bold ink outlines, clean lineart, high contrast, "
@@ -96,32 +83,34 @@ def stylize_to_comic(
96
  steps: int = 4,
97
  seed: int | None = 42
98
  ) -> Image.Image:
99
- if seed is None or seed < 0:
100
- generator = None
101
- else:
102
- generator = torch.Generator(device=device).manual_seed(seed)
103
-
104
  image = image.convert("RGB")
105
  out = pipe(
106
  prompt=prompt,
107
  negative_prompt=NEG_PROMPT,
108
  image=image,
109
- strength=strength, # 0.2~0.7:数值越大改动越明显
110
- num_inference_steps=steps, # sd-turbo 推荐极少步数(2~6)
111
- guidance_scale=guidance_scale, # sd-turbo 常用 0~1
112
- generator=generator
113
  )
114
  return out.images[0]
115
 
116
  # -------------------------
117
- # 3) Gradio 界面
118
  # -------------------------
119
- def process(image, prompt, strength, guidance, steps, seed):
120
  if image is None:
121
- return "请先上传图片", None
 
 
 
 
122
 
123
- age_result = estimate_age(image)
124
- styled = stylize_to_comic(
 
 
125
  image=image,
126
  prompt=prompt,
127
  strength=strength,
@@ -130,19 +119,16 @@ def process(image, prompt, strength, guidance, steps, seed):
130
  seed=int(seed) if seed is not None else 42
131
  )
132
 
133
- # 结果文字
134
- if age_result["expected_age"] is None:
135
- age_text = "年龄估计:解析失败(可能检测不到年龄标签)"
136
- else:
137
- age_text = f"年龄估计:≈ {age_result['expected_age']} 岁\nTop-5: " + " | ".join(age_result["top5"])
138
 
139
- return age_text, styled
 
 
 
140
 
141
- with gr.Blocks(title="Age & Comicify Agent") as demo:
142
- gr.Markdown("# 🧠 Age & Comicify Agent\n上传图片 → 年龄估计 → 漫画风格生成")
143
  with gr.Row():
144
  with gr.Column(scale=1):
145
- run_btn = gr.Button("🚀 运行")
146
  in_img = gr.Image(label="上传图片", type="pil")
147
  prompt = gr.Textbox(label="风格提示词", value=DEFAULT_PROMPT)
148
  strength = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="风格强度(strength)")
@@ -153,11 +139,10 @@ with gr.Blocks(title="Age & Comicify Agent") as demo:
153
  age_txt = gr.Textbox(label="年龄估计结果")
154
  out_img = gr.Image(label="漫画风格输出")
155
 
156
- run_btn.click(
157
- fn=process,
158
- inputs=[in_img, prompt, strength, guidance, steps, seed],
159
- outputs=[age_txt, out_img]
160
- )
161
 
162
  if __name__ == "__main__":
163
- demo.launch()
 
 
8
  from diffusers import AutoPipelineForImage2Image
9
 
10
  # -------------------------
11
+ # 1) 年龄估计模型
 
 
12
  # -------------------------
13
  AGE_MODEL_ID = "nateraw/vit-age-classifier"
 
14
  age_processor = AutoImageProcessor.from_pretrained(AGE_MODEL_ID)
15
  age_model = AutoModelForImageClassification.from_pretrained(AGE_MODEL_ID)
16
  age_model.eval()
17
 
18
  def _label_to_age(label: str) -> float:
 
19
  label = label.strip().replace("(", "").replace(")", "")
20
  if "-" in label:
21
  a, b = label.split("-")
 
23
  return (float(a) + float(b)) / 2.0
24
  except:
25
  pass
 
26
  try:
27
  return float(label)
28
  except:
 
29
  return float("nan")
30
 
31
  @torch.inference_mode()
 
33
  inputs = age_processor(images=image, return_tensors="pt")
34
  logits = age_model(**inputs).logits
35
  probs = torch.softmax(logits, dim=-1)[0]
 
 
36
  id2label = age_model.config.id2label
37
+
38
  topk = torch.topk(probs, k=min(5, probs.shape[0]))
39
+ items, ages = [], []
 
40
  for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
41
  label = id2label[idx]
42
  age = _label_to_age(label)
43
  ages.append((age, score))
44
  items.append(f"{label}: {score*100:.1f}%")
45
 
 
46
  ages_valid = [(a, p) for a, p in ages if not math.isnan(a)]
47
  if ages_valid:
48
  num = sum(a * p for a, p in ages_valid)
 
58
 
59
  # -------------------------
60
  # 2) 漫画风格生成(img2img)
 
 
61
  # -------------------------
62
  IMG2IMG_MODEL_ID = "stabilityai/sd-turbo"
 
63
  device = "cuda" if torch.cuda.is_available() else "cpu"
64
  dtype = torch.float16 if device == "cuda" else torch.float32
65
 
66
  pipe = AutoPipelineForImage2Image.from_pretrained(
67
  IMG2IMG_MODEL_ID,
68
  torch_dtype=dtype
69
+ ).to(device)
 
70
 
71
  DEFAULT_PROMPT = (
72
  "comic style, manga, cel-shaded, bold ink outlines, clean lineart, high contrast, "
 
83
  steps: int = 4,
84
  seed: int | None = 42
85
  ) -> Image.Image:
86
+ generator = None if (seed is None or seed < 0) else torch.Generator(device=device).manual_seed(int(seed))
 
 
 
 
87
  image = image.convert("RGB")
88
  out = pipe(
89
  prompt=prompt,
90
  negative_prompt=NEG_PROMPT,
91
  image=image,
92
+ strength=float(strength),
93
+ num_inference_steps=int(steps),
94
+ guidance_scale=float(guidance_scale),
95
+ generator=generator,
96
  )
97
  return out.images[0]
98
 
99
  # -------------------------
100
+ # 3) Gradio 界面(两个按钮都在最上面)
101
  # -------------------------
102
+ def ui_estimate_age(image):
103
  if image is None:
104
+ return "请先上传图片"
105
+ res = estimate_age(image)
106
+ if res["expected_age"] is None:
107
+ return "年龄估计:解析失败(可能检测不到年龄标签)"
108
+ return f"年龄估计:≈ {res['expected_age']} 岁\nTop-5: " + " | ".join(res["top5"])
109
 
110
+ def ui_stylize(image, prompt, strength, guidance, steps, seed):
111
+ if image is None:
112
+ return None
113
+ return stylize_to_comic(
114
  image=image,
115
  prompt=prompt,
116
  strength=strength,
 
119
  seed=int(seed) if seed is not None else 42
120
  )
121
 
122
+ with gr.Blocks(title="Age & Comicify Agent") as demo:
123
+ gr.Markdown("# 🧠 Age & Comicify Agent\n上传图片 → ① 估计年龄 ② 生成���画风格图片")
 
 
 
124
 
125
+ # 顶部两个按钮
126
+ with gr.Row():
127
+ btn_est = gr.Button("🧮 估计年龄", variant="primary")
128
+ btn_gen = gr.Button("🎨 生成漫画图片", variant="secondary")
129
 
 
 
130
  with gr.Row():
131
  with gr.Column(scale=1):
 
132
  in_img = gr.Image(label="上传图片", type="pil")
133
  prompt = gr.Textbox(label="风格提示词", value=DEFAULT_PROMPT)
134
  strength = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="风格强度(strength)")
 
139
  age_txt = gr.Textbox(label="年龄估计结果")
140
  out_img = gr.Image(label="漫画风格输出")
141
 
142
+ # 绑定:按钮各自只触发一个功能
143
+ btn_est.click(fn=ui_estimate_age, inputs=[in_img], outputs=[age_txt])
144
+ btn_gen.click(fn=ui_stylize, inputs=[in_img, prompt, strength, guidance, steps, seed], outputs=[out_img])
 
 
145
 
146
  if __name__ == "__main__":
147
+ # 可选:并发/队列
148
+ demo.queue().launch()