File size: 4,824 Bytes
76684fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import unittest
import os
import sys
from pptx import Presentation

# 添加 src 目录到模块搜索路径,以便可以导入 src 目录中的模块
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))

from data_structures import PowerPoint, Slide, SlideContent
from ppt_generator import generate_presentation

class TestPPTGenerator(unittest.TestCase):
    """
    测试 ppt_generator 模块的 generate_presentation 函数,验证生成的 PowerPoint 文件内容是否符合预期。
    """

    def setUp(self):
        """
        设置测试数据和输出路径。
        """
        # 定义输入 PowerPoint 数据结构
        self.powerpoint_data = PowerPoint(
            title="ChatPPT Demo",
            slides=[
                Slide(
                    layout_id=1,
                    layout_name="Title 1",
                    content=SlideContent(title="ChatPPT Demo")
                ),
                Slide(
                    layout_id=2,
                    layout_name="Title, Content 0",
                    content=SlideContent(
                        title="2024 业绩概述",
                        bullet_points=[
                            {"text": "总收入增长15%", "level": 0},
                            {"text": "市场份额扩大至30%", "level": 0}
                        ]
                    )
                ),
                Slide(
                    layout_id=8,
                    layout_name="Title, Content, Picture 2",
                    content=SlideContent(
                        title="业绩图表",
                        bullet_points=[{"text": "OpenAI 利润不断增加", "level": 0}],
                        image_path="images/performance_chart.png"
                    )
                ),
                Slide(
                    layout_id=8,
                    layout_name="Title, Content, Picture 2",
                    content=SlideContent(
                        title="新产品发布",
                        bullet_points=[
                            {"text": "产品A: **特色功能介绍**", "level": 0},
                            {"text": "增长潜力巨大", "level": 1},
                            {"text": "新兴市场", "level": 1},
                            {"text": "**非洲**市场", "level": 2},
                            {"text": "**东南亚**市场", "level": 2},
                            {"text": "产品B: 市场定位", "level": 0}
                        ],
                        image_path="images/forecast.png"
                    )
                )
            ]
        )

        self.template_path = "templates/SimpleTemplate.pptx"  # 假设存在模板文件
        self.output_path = "outputs/test_presentation.pptx"  # 定义输出文件路径

    def test_generate_presentation(self):
        """
        测试 generate_presentation 函数生成的 PowerPoint 文件是否符合预期。
        """
        # 调用函数生成 PowerPoint 演示文稿
        generate_presentation(self.powerpoint_data, self.template_path, self.output_path)

        # 检查输出文件是否存在
        self.assertTrue(os.path.exists(self.output_path), "输出 PowerPoint 文件未找到。")

        # 打开生成的 PowerPoint 文件并验证内容
        prs = Presentation(self.output_path)
        
        # 检查演示文稿标题
        self.assertEqual(prs.core_properties.title, self.powerpoint_data.title)

        # 检查幻灯片数量
        self.assertEqual(len(prs.slides), len(self.powerpoint_data.slides))

        # 验证每张幻灯片的内容
        for idx, slide_data in enumerate(self.powerpoint_data.slides):
            slide = prs.slides[idx]

            # 验证幻灯片标题
            self.assertEqual(slide.shapes.title.text, slide_data.content.title)

            # 验证项目符号列表内容
            bullet_points = [shape.text_frame.text for shape in slide.shapes if shape.has_text_frame and shape != slide.shapes.title]
            expected_bullets = [point["text"].replace("**", "") for point in slide_data.content.bullet_points]
            for bullet, expected in zip(bullet_points, expected_bullets):
                self.assertIn(expected, bullet)

            # 验证图片路径(如果存在)
            if slide_data.content.image_path:
                images = [shape for shape in slide.shapes if shape.shape_type == 13]  # 13 为图片形状类型
                self.assertGreater(len(images), 0, f"幻灯片 {idx + 1} 应该包含图片,但未找到。")

    def tearDown(self):
        """
        清理生成的文件。
        """
        if os.path.exists(self.output_path):
            os.remove(self.output_path)

if __name__ == "__main__":
    unittest.main()