# Spaces:
# Sleeping
# Sleeping
import json
import os
import re
from abc import ABC
from io import BytesIO

import requests
from bs4 import BeautifulSoup
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from PIL import Image

from logger import LOG  # project logging utility
class ImageAdvisor(ABC):
    """
    Chatbot base class that suggests illustrations for slide content.

    Workflow: an LLM (driven by the system prompt in ``prompt_file``) proposes
    an image-search query per slide; each query is searched on Bing Images;
    the highest-resolution result is downloaded, resized/compressed and saved
    under ``images/<image_directory>``; finally a Markdown image link is
    inserted below the matching ``## <slide title>`` heading.
    """

    def __init__(self, prompt_file="./prompts/image_advisor.txt"):
        """
        Load the system prompt and build the LLM pipeline.

        Args:
            prompt_file (str): Path to the UTF-8 system-prompt text file.

        Raises:
            FileNotFoundError: If ``prompt_file`` does not exist.
        """
        self.prompt_file = prompt_file
        self.prompt = self.load_prompt()
        self.create_advisor()

    def load_prompt(self):
        """
        Read the system prompt from ``self.prompt_file``.

        Returns:
            str: The file contents with surrounding whitespace stripped.

        Raises:
            FileNotFoundError: Logged, then re-raised, if the file is missing.
        """
        try:
            with open(self.prompt_file, "r", encoding="utf-8") as file:
                return file.read().strip()
        except FileNotFoundError:
            LOG.error(f"找不到提示文件 {self.prompt_file}!")
            raise

    def create_advisor(self):
        """
        Build the prompt -> model chain and store it in ``self.advisor``.

        The chat template holds the loaded system prompt plus a human message
        into which the slide Markdown is substituted via ``{input}``.
        """
        # NOTE(review): ChatPromptTemplate treats "{...}" in the system text as
        # template variables; literal braces in the prompt file must be
        # escaped as "{{" / "}}".
        chat_prompt = ChatPromptTemplate.from_messages([
            ("system", self.prompt),  # system prompt
            ("human", "**Content**:\n\n{input}"),  # placeholder for the slide content
        ])
        self.model = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0.7,
            max_tokens=4096,
        )
        self.advisor = chat_prompt | self.model

    def generate_images(self, markdown_content, image_directory="tmps", num_images=3):
        """
        Find, save and embed one image per slide into PowerPoint Markdown.

        Args:
            markdown_content (str): Raw PowerPoint Markdown.
            image_directory (str): Sub-folder of ``images/`` where downloaded
                images are saved.
            num_images (int): Number of candidate images fetched per slide.

        Returns:
            tuple[str, dict]: ``(content_with_images, image_pair)`` where
            ``content_with_images`` is the Markdown with image links inserted
            and ``image_pair`` maps each slide title to its saved image path.
            Slides whose image could not be found or saved are omitted.
        """
        response = self.advisor.invoke({"input": markdown_content})
        LOG.debug(f"[Advisor 建议配图]\n{response.content}")
        keywords = self.get_keywords(response.content)

        save_directory = f"images/{image_directory}"
        image_pair = {}
        for slide_title, query in keywords.items():
            # Candidates come back sorted by resolution, highest first.
            images = self.get_bing_images(slide_title, query, num_images, timeout=1, retries=3)
            if not images:
                LOG.warning(f"No images found for {slide_title}.")
                continue
            for image in images:
                LOG.debug(f"Name: {image['slide_title']}, Query: {image['query']} 分辨率:{image['width']}x{image['height']}")

            # Only the highest-resolution candidate is kept.
            img = images[0]
            os.makedirs(save_directory, exist_ok=True)
            # Titles come from model output; scrub path separators etc. before
            # using them in a file name (the dict key keeps the real title).
            file_name = f"{self._safe_filename(img['slide_title'])}_1.jpeg"
            save_path = os.path.join(save_directory, file_name)
            # Only reference images that were actually written to disk.
            if self.save_image(img["obj"], save_path):
                image_pair[img["slide_title"]] = save_path

        content_with_images = self.insert_images(markdown_content, image_pair)
        return content_with_images, image_pair

    @staticmethod
    def _safe_filename(name):
        """Return ``name`` with path separators, reserved and control characters replaced by ``_``."""
        return re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", name).strip() or "_"

    def get_keywords(self, advice):
        """
        Extract per-slide search keywords from the advisor's reply.

        Every ``[<slide title>]: <query>`` occurrence (query runs to the end of
        the line) becomes one entry; a later duplicate title overwrites an
        earlier one.

        Args:
            advice (str): The advisor's reply text.

        Returns:
            dict: Mapping of stripped slide title to stripped search query.
        """
        pairs = re.findall(r'\[(.+?)\]:\s*(.+)', advice)
        keywords = {key.strip(): value.strip() for key, value in pairs}
        LOG.debug(f"[检索关键词 正则提取结果]{keywords}")
        return keywords

    def get_bing_images(self, slide_title, query, num_images=5, timeout=1, retries=3):
        """
        Retrieve candidate images for ``query`` from Bing Images.

        Args:
            slide_title (str): Slide title, copied into each result record.
            query (str): Image search keywords.
            num_images (int): Maximum number of image links to download.
            timeout (int | float): Per-request timeout in seconds (default 1).
            retries (int): Attempts per request (default 3); values below 1
                are treated as a single attempt.

        Returns:
            list[dict]: Successfully downloaded images, sorted by pixel count
            (highest first). Each record has ``slide_title``, ``query``,
            ``width``, ``height``, ``resolution`` and ``obj`` (the PIL image).
            Empty if the search page could not be fetched.
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
        }
        attempts = max(1, retries)  # retries <= 0 would otherwise make no request at all

        html = self._fetch_search_page(query, headers, timeout, attempts)
        if html is None:
            return []

        image_data = []
        for link in self._parse_image_links(html, num_images):
            img = self._download_image(link, headers, timeout, attempts)
            if img is None:
                continue
            image_data.append({
                "slide_title": slide_title,
                "query": query,
                "width": img.width,
                "height": img.height,
                "resolution": img.width * img.height,
                "obj": img,
            })
        return sorted(image_data, key=lambda x: x["resolution"], reverse=True)

    @staticmethod
    def _fetch_search_page(query, headers, timeout, attempts):
        """Return the Bing Images result HTML for ``query``, or None after ``attempts`` failed requests."""
        url = "https://www.bing.com/images/search"
        for attempt in range(attempts):
            try:
                # params= percent-encodes the query ("&", "#", "+" would otherwise corrupt the URL).
                response = requests.get(url, params={"q": query}, headers=headers, timeout=timeout)
                response.raise_for_status()
                return response.text
            except requests.RequestException as e:
                LOG.warning(f"Attempt {attempt + 1}/{attempts} failed for query '{query}': {e}")
        LOG.error(f"Max retries reached for query '{query}'.")
        return None

    @staticmethod
    def _parse_image_links(html, num_images):
        """Extract up to ``num_images`` full-size image URLs (the ``murl`` field) from Bing result HTML."""
        if num_images <= 0:
            return []
        soup = BeautifulSoup(html, "html.parser")
        image_links = []
        for anchor in soup.select("a.iusc"):
            m_data = anchor.get("m")
            if not m_data:
                continue
            # The "m" attribute is JSON metadata from a remote page: parse it as
            # data, never execute it (the previous eval() was a code-execution risk).
            try:
                m_json = json.loads(m_data)
            except json.JSONDecodeError:
                continue
            murl = m_json.get("murl") if isinstance(m_json, dict) else None
            if murl:
                image_links.append(murl)
                if len(image_links) >= num_images:
                    break
        return image_links

    @staticmethod
    def _download_image(link, headers, timeout, attempts):
        """Download and open the image at ``link``; return the PIL image, or None after ``attempts`` failures."""
        for attempt in range(attempts):
            try:
                img_data = requests.get(link, headers=headers, timeout=timeout)
                img_data.raise_for_status()  # don't feed HTTP error pages to PIL
                return Image.open(BytesIO(img_data.content))
            # Best-effort: network and PIL decoding raise many unrelated types.
            except Exception as e:
                LOG.warning(f"Attempt {attempt + 1}/{attempts} failed for image '{link}': {e}")
        LOG.error(f"Max retries reached for image '{link}'. Skipping.")
        return None

    def save_image(self, img, save_path, format="JPEG", quality=85, max_size=1080):
        """
        Save an image locally, downscaling and compressing it.

        Args:
            img (PIL.Image.Image): Image to save.
            save_path (str): Destination file path.
            format (str): Output format, default "JPEG". RGBA images are
                always written as PNG to keep transparency.
            quality (int): JPEG quality, default 85.
            max_size (int): Maximum length of the longer side, default 1080.

        Returns:
            bool: True if the file was written; False if saving failed (the
            error is logged, not raised).
        """
        try:
            width, height = img.size
            longest = max(width, height)
            if longest > max_size:
                scaling_factor = max_size / longest
                # max(1, ...) keeps very thin images from collapsing to 0 px.
                new_size = (
                    max(1, int(width * scaling_factor)),
                    max(1, int(height * scaling_factor)),
                )
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            if img.mode == "RGBA":
                # JPEG has no alpha channel; write PNG to keep transparency.
                format = "PNG"
                save_options = {"optimize": True}
            else:
                # JPEG only accepts L/RGB/CMYK; palette/LA/etc. must be converted.
                if str(format).upper() == "JPEG" and img.mode not in ("RGB", "L", "CMYK"):
                    img = img.convert("RGB")
                save_options = {
                    "quality": quality,
                    "optimize": True,
                    "progressive": True,
                }
            img.save(save_path, format=format, **save_options)
            LOG.debug(f"Image saved as {save_path} in {format} format with quality {quality}.")
            return True
        except Exception as e:
            LOG.error(f"Failed to save image: {e}")
            return False

    def insert_images(self, markdown_content, image_pair):
        """
        Embed images into Markdown content.

        Directly after every ``## <title>`` heading whose stripped title is a
        key of ``image_pair``, a Markdown image line ``![<title>](<path>)`` is
        inserted. All other lines are kept unchanged.

        Args:
            markdown_content (str): Markdown content.
            image_pair (dict): Mapping of slide title to image path.

        Returns:
            str: The content with image links inserted.
        """
        new_lines = []
        for line in markdown_content.split('\n'):
            new_lines.append(line)
            if line.startswith('## '):
                slide_title = line[3:].strip()
                if slide_title in image_pair:
                    image_path = image_pair[slide_title]
                    new_lines.append(f"![{slide_title}]({image_path})")
        return '\n'.join(new_lines)