"""Gradio demo: estimate a person's age from a photo and redraw it in comic style.

Pipeline: upload image -> age estimation (ViT classifier) -> comic-style
image-to-image generation (SD-Turbo). Both models are loaded once at import
time so that every request reuses them.
"""

import os
import math

import torch
import gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from diffusers import AutoPipelineForImage2Image

# -------------------------
# 1) Age-estimation model (available on the Hugging Face Hub)
# This example uses a ViT age classifier. Class labels are converted to an
# age: a range such as "0-2" becomes its midpoint, a single value such as
# "23" is used as-is.
# -------------------------
AGE_MODEL_ID = "nateraw/vit-age-classifier"
age_processor = AutoImageProcessor.from_pretrained(AGE_MODEL_ID)
age_model = AutoModelForImageClassification.from_pretrained(AGE_MODEL_ID)
age_model.eval()


def _label_to_age(label: str) -> float:
    """Convert a classifier label into a representative age.

    Args:
        label: Class label such as "(0-2)", "0-2", "3-9" or "23".

    Returns:
        The midpoint for a range label, the numeric value for a single-value
        label, or NaN when the label cannot be parsed.
    """
    cleaned = label.strip().replace("(", "").replace(")", "")

    # partition() never raises, unlike unpacking split("-"), which fails on
    # labels that contain more than one hyphen.
    low, sep, high = cleaned.partition("-")
    if sep:
        try:
            return (float(low) + float(high)) / 2.0
        except ValueError:
            pass  # not a plain "a-b" range; try the single-value form below

    try:
        return float(cleaned)
    except ValueError:
        # Fallback: unparseable labels become NaN and are skipped by callers.
        # NOTE(review): open-ended labels (e.g. "more than 70", if the model
        # emits one) also end up here and are excluded from the expected-age
        # average - confirm that this is intended.
        return float("nan")


@torch.inference_mode()
def estimate_age(image: Image.Image) -> dict:
    """Estimate the age shown in ``image``.

    Args:
        image: Input picture; converted to RGB before preprocessing.

    Returns:
        A dict with two keys:
        ``"expected_age"`` - probability-weighted mean age over the top
        classes with a parseable label, rounded to one decimal, or ``None``
        when no label could be parsed;
        ``"top5"`` - list of ``"<label>: <percent>%"`` strings for the top
        classes (at most five).
    """
    # Uploaded files may be RGBA, palette or grayscale; normalise to RGB,
    # exactly as stylize_to_comic does for the diffusion pipeline.
    image = image.convert("RGB")

    inputs = age_processor(images=image, return_tensors="pt")
    logits = age_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]

    # Keep the top-5 classes (or fewer if the model has fewer) for display.
    id2label = age_model.config.id2label
    topk = torch.topk(probs, k=min(5, probs.shape[0]))

    items = []
    ages = []
    for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
        label = id2label[idx]
        ages.append((_label_to_age(label), score))
        items.append(f"{label}: {score*100:.1f}%")

    # Expected age: weighted average over the classes whose label parsed,
    # renormalised by the probability mass of those classes.
    ages_valid = [(a, p) for a, p in ages if not math.isnan(a)]
    den = sum(p for _, p in ages_valid)
    if ages_valid and den > 0:
        expected_age = round(sum(a * p for a, p in ages_valid) / den, 1)
    else:
        expected_age = None

    return {"expected_age": expected_age, "top5": items}


# -------------------------
# 2) Comic-style generation (img2img)
# Uses the image-to-image pipeline of "stabilityai/sd-turbo", which is fast;
# the prompt targets a comic/cartoon look. It also runs on CPU, only slower;
# a GPU (T4/A10) gives the best experience.
# -------------------------
IMG2IMG_MODEL_ID = "stabilityai/sd-turbo"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = AutoPipelineForImage2Image.from_pretrained(
    IMG2IMG_MODEL_ID,
    torch_dtype=dtype
)
pipe = pipe.to(device)

DEFAULT_PROMPT = (
    "comic style, manga, cel-shaded, bold ink outlines, clean lineart, high contrast, "
    "professional illustration, vibrant"
)
NEG_PROMPT = "realistic, photorealistic, blurry, noisy, artifacts, watermark, text"


def _effective_steps(steps: int, strength: float) -> int:
    """Return a step count for which img2img runs at least one denoising step.

    The img2img pipeline only executes about ``int(steps * strength)`` of the
    requested steps, so a small ``strength`` combined with few steps would
    leave nothing to run. The count is raised just enough to avoid that.

    Raises:
        ValueError: If ``strength`` is not within ``(0, 1]``.
    """
    if not 0.0 < strength <= 1.0:
        raise ValueError(f"strength must be in (0, 1], got {strength}")

    steps = max(int(steps), math.ceil(1.0 / strength), 1)
    # Guard against floating-point rounding in the product.
    while int(steps * strength) < 1:
        steps += 1
    return steps


@torch.inference_mode()
def stylize_to_comic(
    image: Image.Image,
    prompt: str = DEFAULT_PROMPT,
    strength: float = 0.6,
    guidance_scale: float = 0.0,
    steps: int = 4,
    seed: int | None = 42
) -> Image.Image:
    """Redraw ``image`` in comic style with the SD-Turbo img2img pipeline.

    Args:
        image: Source picture; converted to RGB before generation.
        prompt: Style prompt.
        strength: How strongly the source is altered (0.2~0.7 typical; larger
            values change more). Must be within ``(0, 1]``.
        guidance_scale: Guidance scale; sd-turbo commonly uses 0~1.
        steps: Requested number of inference steps (sd-turbo recommends very
            few, 2~6). Raised automatically when ``steps * strength`` would
            be below 1.
        seed: Seed for reproducible output; ``None`` or a negative value
            lets the pipeline pick its own randomness.

    Returns:
        The first generated image.

    Raises:
        ValueError: If ``strength`` is not within ``(0, 1]``.
    """
    if seed is None or seed < 0:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(seed)

    image = image.convert("RGB")
    # NOTE(review): per the diffusers docs the negative prompt is ignored
    # unless guidance_scale > 1 - confirm against the installed version.
    out = pipe(
        prompt=prompt,
        negative_prompt=NEG_PROMPT,
        image=image,
        strength=strength,
        num_inference_steps=_effective_steps(steps, strength),
        guidance_scale=guidance_scale,
        generator=generator
    )
    return out.images[0]


# -------------------------
# 3) Gradio interface
# -------------------------
def process(image, prompt, strength, guidance, steps, seed):
    """Gradio callback: run age estimation and comic stylization on ``image``.

    Returns:
        A ``(text, image)`` pair: the age summary text and the stylized
        picture, or a prompt to upload an image and ``None`` when no image
        was supplied.
    """
    if image is None:
        return "请先上传图片", None

    age_result = estimate_age(image)
    styled = stylize_to_comic(
        image=image,
        prompt=prompt,
        strength=strength,
        guidance_scale=guidance,
        steps=int(steps),
        # An emptied seed field arrives as None; fall back to the default.
        seed=int(seed) if seed is not None else 42
    )

    # Result text
    if age_result["expected_age"] is None:
        age_text = "年龄估计:解析失败(可能检测不到年龄标签)"
    else:
        age_text = f"年龄估计:≈ {age_result['expected_age']} 岁\nTop-5: " + " | ".join(age_result["top5"])

    return age_text, styled


with gr.Blocks(title="Age & Comicify Agent") as demo:
    gr.Markdown("# 🧠 Age & Comicify Agent\n上传图片 → 年龄估计 → 漫画风格生成")

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            in_img = gr.Image(label="上传图片", type="pil")
            prompt = gr.Textbox(label="风格提示词", value=DEFAULT_PROMPT)
            strength = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="风格强度(strength)")
            guidance = gr.Slider(0.0, 3.0, value=0.0, step=0.1, label="引导系数(guidance_scale)")
            steps = gr.Slider(2, 12, value=4, step=1, label="步数(num_inference_steps)")
            seed = gr.Number(value=42, precision=0, label="随机种子(固定可复现)")
            run_btn = gr.Button("🚀 运行")

        # Right column: results.
        with gr.Column(scale=1):
            age_txt = gr.Textbox(label="年龄估计结果")
            out_img = gr.Image(label="漫画风格输出")

    run_btn.click(
        fn=process,
        inputs=[in_img, prompt, strength, guidance, steps, seed],
        outputs=[age_txt, out_img]
    )

if __name__ == "__main__":
    demo.launch()