Cren committed · verified
Commit 88aba71 · Parent: eb9e6bc

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .cursor/rules/weclone-rules.mdc +23 -0
  2. .gitattributes +2 -0
  3. .github/issue-labeler.yml +47 -0
  4. .github/workflows/issue-labeler.yml +30 -0
  5. .github/workflows/update_space.yml +28 -0
  6. .gitignore +165 -0
  7. LICENSE +661 -0
  8. README.md +277 -12
  9. dataset/res_csv/pt/dataset_info.json +6 -0
  10. dataset/res_csv/sft/dataset_info.json +19 -0
  11. dataset/test_data-privacy.json +224 -0
  12. dataset/test_data.json +157 -0
  13. ds_config.json +28 -0
  14. pyproject.toml +125 -0
  15. requirements.txt +0 -0
  16. settings.template.jsonc +95 -0
  17. spaces_app.py +53 -0
  18. tests/__init__.py +0 -0
  19. tests/full_pipe.jsonc +89 -0
  20. tests/test_full_pipe.py +154 -0
  21. torchvision.whl +3 -0
  22. weclone-audio/README.md +134 -0
  23. weclone-audio/src/Llasa/infer.py +12 -0
  24. weclone-audio/src/Llasa/text_to_speech.py +131 -0
  25. weclone-audio/src/SparkTTS.py +223 -0
  26. weclone-audio/src/__init__.py +0 -0
  27. weclone-audio/src/get_sample_audio.py +35 -0
  28. weclone-audio/src/infer.py +17 -0
  29. weclone-audio/src/sample.wav +3 -0
  30. weclone-audio/src/server未完工/.env.example +14 -0
  31. weclone-audio/src/server未完工/handle_text.py +62 -0
  32. weclone-audio/src/server未完工/requirements.txt +5 -0
  33. weclone-audio/src/server未完工/server.py +167 -0
  34. weclone-audio/src/server未完工/tts_handler.py +133 -0
  35. weclone-audio/src/server未完工/utils.py +38 -0
  36. weclone/__init__.py +0 -0
  37. weclone/cli.py +207 -0
  38. weclone/core/inference/offline_infer.py +120 -0
  39. weclone/core/inference/online_infer.py +40 -0
  40. weclone/data/__init__.py +0 -0
  41. weclone/data/chat_parsers/wechat_parser.py +8 -0
  42. weclone/data/clean/__init__.py +0 -0
  43. weclone/data/clean/get_score.py +64 -0
  44. weclone/data/clean/strategies.py +144 -0
  45. weclone/data/clean/strategies_online.py +155 -0
  46. weclone/data/models.py +74 -0
  47. weclone/data/qa_generator.py +506 -0
  48. weclone/data/strategies.py +60 -0
  49. weclone/eval/__init__.py +0 -0
  50. weclone/eval/cli_demo.py +48 -0
.cursor/rules/weclone-rules.mdc ADDED
@@ -0,0 +1,23 @@
+ ---
+ description:
+ globs:
+ alwaysApply: true
+ ---
+ ---
+ description:
+ globs:
+ alwaysApply: true
+ ---
+
+ # Your rule content
+ - You can @ files here
+ - The project uses uv as the package manager and pyproject.toml as the project configuration file.
+ - Unless I ask you to, code comments don't need to be excessive.
+ - Prefer using the encapsulated logger `from weclone.utils.log import logger` for printing.
+ - When retrieving values from a parameter dictionary read from a configuration file, the `get` method should be preferred whenever possible.
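The last rule is worth a concrete illustration. A minimal Python sketch of the preferred pattern; the keys shown are examples (`model_name_or_path` and `lora_rank` do appear in the project's `settings.jsonc` per the README below, but the default values here are illustrative):

```python
import json

# Parameters loaded from a configuration file (keys are illustrative).
config = json.loads('{"model_name_or_path": "./Qwen2.5-7B-Instruct"}')

# Preferred: dict.get() returns a default instead of raising KeyError
# when a key is missing from the configuration file.
model_path = config.get("model_name_or_path", "./Qwen2.5-7B-Instruct")
lora_rank = config.get("lora_rank", 8)  # falls back to 8 if unset

# Discouraged: config["lora_rank"] raises KeyError if the key is absent.
```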
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ torchvision.whl filter=lfs diff=lfs merge=lfs -text
+ weclone-audio/src/sample.wav filter=lfs diff=lfs merge=lfs -text
.github/issue-labeler.yml ADDED
@@ -0,0 +1,47 @@
+ # Add the Discussion label
+ Discussion:
+ - '(讨论|交流|分享|意见|建议|思考|探讨|交换意见|brainstorm|discussion)'
+
+ # Add the bug label
+ bug:
+ - '(bug|错误|问题|失败|崩溃|异常|报错|不工作|无法运行|broken|crash|error|exception|fails)'
+
+ # Add the chatbot label
+ chatbot:
+ - '(聊天机器人|chatbot|chat bot|对话机器人|聊天助手|AI助手|机器人对话|bot|assistant)'
+
+ # Add the documentation label
+ documentation:
+ - '(文档|说明|使用指南|指导|手册|教程|文档更新|documentation|docs|guide|tutorial|readme)'
+
+ # Add the duplicate label
+ duplicate:
+ - '(重复|已有|duplicate|已经存在|已提交过|重复问题|重复报告|dup)'
+
+ # Add the feature label
+ feature:
+ - '(功能|特性|新增|增加|添加|实现|feature|enhancement|新功能|功能请求|feature request)'
+
+ # Add the good first issue label
+ good first issue:
+ - '(入门|简单|容易|新手|初学者|开始|first|beginner|starter|easy|简单任务|good first issue)'
+
+ # Add the help wanted label
+ help wanted:
+ - '(需要帮助|寻求帮助|请求协助|help|求助|协助|帮忙|help wanted|need help|assistance)'
+
+ # Add the invalid label
+ invalid:
+ - '(无效|不适用|不相关|无关|错误提交|invalid|not relevant|irrelevant|not applicable)'
+
+ # Add the Mac label
+ Mac:
+ - '(Mac|MacOS|macOS|OSX|Mac系统|苹果系统|苹果电脑|MacBook)'
+
+ # Add the question label
+ question:
+ - '(问题|疑问|如何|怎么|请问|是否|能否|可以吗|question|how to|what is|why)'
+
+ # Add the Windows label
+ Windows:
+ - '(Windows|微软|Win10|Win11|Windows系统|微软系统|win)'
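These patterns drive the `github/issue-labeler` action configured in the workflow below, which matches them against incoming issue text. A minimal Python sketch of the matching idea (illustrative only, not the action's actual implementation):

```python
import re

# A small subset of the label patterns defined above.
LABEL_PATTERNS = {
    "bug": r"(bug|错误|问题|失败|崩溃|异常|报错|broken|crash|error|exception|fails)",
    "Windows": r"(Windows|微软|Win10|Win11|win)",
}

def labels_for(issue_text: str) -> list[str]:
    """Return every label whose pattern matches the issue text."""
    return [
        label
        for label, pattern in LABEL_PATTERNS.items()
        if re.search(pattern, issue_text, flags=re.IGNORECASE)
    ]

print(labels_for("App crashes on Win11"))  # ['bug', 'Windows']
```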
.github/workflows/issue-labeler.yml ADDED
@@ -0,0 +1,30 @@
+ name: add labels to Issues
+
+ on:
+   issues:
+     types: [opened, edited]
+
+
+ jobs:
+   label_issues:
+     runs-on: ubuntu-latest
+     permissions:
+       issues: write
+       contents: read
+     steps:
+       - name: checkout
+         uses: actions/checkout@v3
+
+       - name: get_last_run_time
+         id: last_run
+         run: |
+           # Default to the current date minus 1 day (process issues from the last day)
+           echo "date=$(date -d '1 day ago' -u +"%Y-%m-%dT%H:%M:%SZ")" >> $GITHUB_OUTPUT
+
+       - name: RegEx Issue Labeler
+         uses: github/[email protected]
+         with:
+           repo-token: "${{ secrets.GITHUB_TOKEN }}"
+           configuration-path: .github/issue-labeler.yml
+           enable-versioned-regex: 0
+           not-before: ${{ steps.last_run.outputs.date }}
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - n
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,165 @@
+ wandb/
+ weclone_archive-my/
+ **/pycache/
+ events.out.tfevents.*
+ 归档/
+ *.pt
+ *.npz
+ *nohup.out
+ *log.txt
+ *cookie.bin
+ *.gradio/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+
+ *.zip
+ LLaMA-Factory
+ chatglm3-6b
+ cache
+ archive
+ model_output*
+ data/test
+ .vscode
+ *-my*.*
+ *.csv
+ *test.*
+ *users.json
+ Spark-TTS-0.5B/
+ uv.lock
+ output*
+ *.out
+
+ Qwen*/
+ settings.jsonc
+ settings.json
+ dataset/blocked_words.json
+ dataset/wechat/*
LICENSE ADDED
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+ software and other kinds of works, specifically designed to ensure
+ cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ our General Public Licenses are intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+ with two steps: (1) assert copyright on the software, and (2) offer
+ you this License which gives you legal permission to copy, distribute
+ and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+ improvements made in alternate versions of the program, if they
+ receive widespread use, become available for other developers to
+ incorporate. Many developers of free software are heartened and
+ encouraged by the resulting cooperation. However, in the case of
+ software used on network servers, this result may fail to come about.
+ The GNU General Public License permits making a modified version and
+ letting the public access it on a server without ever releasing its
+ source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ ensure that, in such cases, the modified source code becomes available
+ to the community. It requires the operator of a network server to
+ provide the source code of the modified version running there to the
+ users of that server. Therefore, public use of a modified version, on
+ a publicly accessible server, gives the public access to the source
+ code of the modified version.
+
+ An older license, called the Affero General Public License and
+ published by Affero, was designed to accomplish similar goals. This is
+ a different license, not a version of the Affero GPL, but Affero has
+ released a new version of the Affero GPL which permits relicensing under
+ this license.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+ Program, your modified version must prominently offer all users
+ interacting with it remotely through a computer network (if your version
+ supports such interaction) an opportunity to receive the Corresponding
+ Source of your version by providing access to the Corresponding Source
+ from a network server at no charge, through some standard or customary
+ means of facilitating copying of software. This Corresponding Source
+ shall include the Corresponding Source for any work covered by version 3
+ of the GNU General Public License that is incorporated pursuant to the
+ following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+ permission to link or combine any covered work with a work licensed
+ under version 3 of the GNU General Public License into a single
+ combined work, and to convey the resulting work. The terms of this
+ License will continue to apply to the part which is the covered work,
+ but the work with which it is combined will remain governed by version
+ 3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+ the GNU Affero General Public License from time to time. Such new versions
+ will be similar in spirit to the present version, but may differ in detail to
+ address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+ Program specifies that a certain numbered version of the GNU Affero General
+ Public License "or any later version" applies to it, you have the
+ option of following the terms and conditions either of that numbered
+ version or of any later version published by the Free Software
+ Foundation. If the Program does not specify a version number of the
+ GNU Affero General Public License, you may choose any version ever published
+ by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+ versions of the GNU Affero General Public License can be used, that proxy's
+ public statement of acceptance of a version permanently authorizes you
+ to choose that version for the Program.
+
+ Later license versions may give you additional or different
+ permissions. However, no additional obligations are imposed on any
+ author or copyright holder as a result of your choosing to follow a
+ later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+ above cannot be given local legal effect according to their terms,
+ reviewing courts shall apply local law that most closely approximates
+ an absolute waiver of all civil liability in connection with the
+ Program, unless a warranty or assumption of liability accompanies a
+ copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+ possible use to the public, the best way to achieve this is to make it
+ free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+ to attach them to the start of each source file to most effectively
+ state the exclusion of warranty; and each file should have at least
+ the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+ network, you should also make sure that it provides a way for users to
+ get its source. For example, if your program is a web application, its
+ interface could display a "Source" link that leads users to an archive
+ of the code. There are many ways you could offer source, and different
+ solutions will be better for different programs; see section 13 for the
+ specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
+ For more information on this, and how to apply and follow the GNU AGPL, see
+ <https://www.gnu.org/licenses/>.
README.md CHANGED
@@ -1,12 +1,277 @@
- ---
- title: Cren
- emoji: 🏢
- colorFrom: yellow
- colorTo: blue
- sdk: gradio
- sdk_version: 5.32.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: cren
+ app_file: spaces_app.py
+ sdk: gradio
+ sdk_version: 5.21.0
+ ---
+ ![download](https://github.com/user-attachments/assets/5842e84e-004f-4afd-9373-af64e9575b78)
+ <h3 align="center">🚀 One-stop solution for creating your digital avatar from chat history 💡</h3>
+
+
+ <div align="center">
+
+ [![GitHub stars](https://img.shields.io/github/stars/xming521/WeClone?style=for-the-badge&logo=github&label=Stars&logoColor=white&color=ffda65)](https://github.com/xming521/WeClone/stargazers)
+ [![GitHub release](https://img.shields.io/github/v/release/xming521/WeClone?style=for-the-badge&logo=github&label=Release&logoColor=white&color=06d094)](https://github.com/xming521/WeClone/releases)
+ <a href="https://qm.qq.com/cgi-bin/qm/qr?k=QXMsXJ_eqeabS0cck0PGjEMyKjcq7J5d&jump_from=webapi&authKey=KHdy31VbSxj34VQVwXtEOYVi1K7SND45vJcNnm1Z5iCCR6IbGiyWEs9UbPqFI8Jc" target="_blank" style="text-decoration: none;">
+ <img src="https://img.shields.io/badge/QQ群-650118277-12B7F5?style=for-the-badge&logo=qq&logoColor=white" alt="WeClone①" title="WeClone①">
+ </a>
+ [![Twitter](https://img.shields.io/badge/Twitter-@weclone567-000000?style=for-the-badge&logo=x&logoColor=white)](https://x.com/weclone567)
+ [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/+JEdak4m0XEQ3NGNl)
+
+ <a href="https://hellogithub.com/repository/12ab209b56cb4cfd885c8cfd4cfdd53e" target="_blank"><img src="https://abroad.hellogithub.com/v1/widgets/recommend.svg?rid=12ab209b56cb4cfd885c8cfd4cfdd53e&claim_uid=RThlPDoGrFvdMY5" alt="Featured|HelloGitHub" style="width: 150px; height: 28px;" /></a>
+ <a href="https://trendshift.io/repositories/13759" target="_blank"><img src="https://trendshift.io/api/badge/repositories/13759" alt="xming521%2FWeClone | Trendshift" style="width: 220px; height: 50px;" /></a>
+ <a href="https://deepwiki.com/xming521/WeClone"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki" style="width: 134px; height: 23px;margin-bottom: 3px;"></a>
+ </div>
+
+ <p align="center">
+ <a href="https://www.weclone.love/" target="_blank"> Project Homepage </a> |
+ <a href="https://www.weclone.love/what-is-weclone.html" target="_blank"> Documentation </a> |
+ <a href="https://blog.051088.xyz/2025/05/14/WeClone-%E7%94%A8%E5%BE%AE%E4%BF%A1%E8%81%8A%E5%A4%A9%E8%AE%B0%E5%BD%95%E6%89%93%E9%80%A0%E8%87%AA%E5%B7%B1%E7%9A%84AI%E6%95%B0%E5%AD%97%E5%88%86%E8%BA%AB/" target="_blank">Windows Deployment Guide</a> |
+ <a href="https://blog.051088.xyz/posts/weclone-linux-tutorial/" target="_blank"> Linux Deployment Guide (step-by-step)</a>
+ </p>
+
+ > [!IMPORTANT]
+ > <h3> WhatsApp and Telegram chat log integration for digital avatar creation is coming! </h3>
+
+ ## ✨ Core Features
+ - 💫 A full pipeline for building a digital avatar: chat data export, preprocessing, model training, and deployment
+ - 💬 Fine-tune an LLM on your WeChat chat history so the model picks up your personal style
+ - 🔗 Bind it to WeChat, QQ, Telegram, WeCom, or Feishu bots to bring your digital avatar to life
+ - 🛡️ Privacy filtering, with local fine-tuning and deployment keeping your data secure and under your control
+
+ ## 📋 Features and Notes
+
+ > [!IMPORTANT]
+ > - WeClone is still iterating rapidly; current results do not represent the final quality.
+ > - Fine-tuning quality depends heavily on model size and on the quantity and quality of chat data: in theory, the larger the model and the more data, the better the result.
+ > - The Windows environment has not been rigorously tested; WSL can be used as the runtime environment. See the [Windows Deployment Guide](https://blog.051088.xyz/2025/05/14/WeClone-%E7%94%A8%E5%BE%AE%E4%BF%A1%E8%81%8A%E5%A4%A9%E8%AE%B0%E5%BD%95%E6%89%93%E9%80%A0%E8%87%AA%E5%B7%B1%E7%9A%84AI%E6%95%B0%E5%AD%97%E5%88%86%E8%BA%AB/) for details.
+
+ ### Hardware Requirements
+
+ The project defaults to the Qwen2.5-7B-Instruct model, fine-tuned with the LoRA method at the SFT stage, which needs roughly 16GB of VRAM. Other models and methods supported by [LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E6%A8%A1%E5%9E%8B) can also be used.
+
+ Estimated VRAM requirements:
+ | Method | Precision | 7B | 14B | 30B | 70B | `x`B |
+ | ------------------------------- | --------- | ----- | ----- | ----- | ------ | ------- |
+ | Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+ | Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
+ | Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
+ | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
+ | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
+ | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
+
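The `x`B column generalizes the table to an arbitrary model size. A small helper encoding those rules of thumb (estimates only; actual usage also depends on sequence length and batch size):

```python
# Rough VRAM estimates (GB) per billion parameters, taken from the table above.
# These are rules of thumb, not guarantees.
VRAM_GB_PER_BILLION_PARAMS = {
    "full_bf16_or_fp16": 18,   # Full, mixed precision
    "full_pure_bf16": 8,       # Full, pure bf16
    "lora_16bit": 2,           # Freeze/LoRA/GaLore/APOLLO/BAdam
    "qlora_8bit": 1,           # QLoRA, 8-bit
    "qlora_4bit": 0.5,         # QLoRA, 4-bit
    "qlora_2bit": 0.25,        # QLoRA, 2-bit
}

def estimate_vram_gb(model_size_b: float, method: str) -> float:
    """Estimated VRAM in GB for fine-tuning an x-billion-parameter model."""
    return model_size_b * VRAM_GB_PER_BILLION_PARAMS[method]

print(estimate_vram_gb(7, "lora_16bit"))  # ~14 GB; the table rounds up to 16GB for 7B
```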
+ ## Environment Setup
+ 1. Install CUDA (skip if already installed; **version 12.4 or above required**): [LLaMA Factory](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html#cuda)
+
+ 2. [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, is recommended for installing dependencies. After installing uv, create a new Python environment and install the dependencies with the following commands (note: this does not include the dependencies for the voice-cloning feature):
+ ```bash
+ git clone https://github.com/xming521/WeClone.git
+ cd WeClone
+ uv venv .venv --python=3.10
+ source .venv/bin/activate # on Windows, run .venv\Scripts\activate
+ uv pip install --group main -e .
+ ```
+ > [!TIP]
+ > To fine-tune the latest models, install the latest LLaMA Factory manually: `uv pip install --upgrade git+https://github.com/hiyouga/LLaMA-Factory.git`. Other dependency versions may also need adjusting, e.g. vllm, pytorch, transformers.
+
+ 3. Copy the configuration template and rename it to `settings.jsonc`; all subsequent configuration changes go in this file:
+ ```bash
+ cp settings.template.jsonc settings.jsonc
+ ```
+ > [!NOTE]
+ > All training and inference configuration lives in `settings.jsonc`.
+
+ 4. Use the following command to test whether CUDA is correctly configured and visible to PyTorch (not needed on Mac):
+ ```bash
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available());"
+ ```
+
+ 5. (Optional) Install FlashAttention to speed up training and inference: `uv pip install flash-attn --no-build-isolation`
+
+ ## Model Download
+ ```bash
+ git lfs install
+ git clone https://www.modelscope.cn/Qwen/Qwen2.5-7B-Instruct.git
+ ```
+ If the download fails, use another method: [Model download](https://www.modelscope.cn/docs/models/download)
+
+
+ ## Data Preparation
+
+ Use [PyWxDump](https://github.com/xaoyaoo/PyWxDump) to extract WeChat chat history (WeChat 4.0 is not supported). You can first migrate (back up) your phone's chat history to the computer to get more data. After downloading the software and decrypting the database, click chat backup and set the export type to CSV; you can export multiple contacts (group chat records are not recommended). Then place the exported `csv` folder, located at `wxdump_tmp/export`, into the `./dataset` directory; that is, put the folders of different people's chat histories together under `./dataset/csv`.
+
+ ## Data Preprocessing
+
+ - By default, the project removes phone numbers, ID card numbers, email addresses, and URLs from the data. A blocked-word list `blocked_words` is also provided in `settings.jsonc`, where you can add words and phrases to filter (any sentence containing a blocked word is removed in full).
+ > [!IMPORTANT]
+ > 🚨 Be sure to protect personal privacy and do not leak personal information!
+
+ - Run the following command to process the data; adjust `make_dataset_args` in settings.jsonc to match your chatting style.
+ ```bash
+ weclone-cli make-dataset
+ ```
+ - Currently only the time-window strategy is supported: consecutive messages from the same person are joined with commas into one sentence according to `single_combine_time_window`, and question-answer pairs are matched according to `qa_match_time_window` (see the sketch after this list).
+ - You can enable the `enable_clean` option under `clean_dataset` to clean the data for better results. The current system supports scoring chat records with an `llm judge`, via either **vllm offline inference** or **API online inference**. To enable the online API mode, change `"online_llm_clear": false` to `true` in `settings.jsonc` and configure `base_url`, `llm_api_key`, `model_name`, and related parameters. Any model compatible with the OpenAI interface can be used.
+ - Once you have the `llm score distribution`, set the `accept_score` parameter to choose the acceptable score range, and consider lowering `lora_dropout` in `train_sft_args` to improve the model's fit.
+
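A minimal sketch of the time-window idea described above (illustrative only; the real implementation lives in `weclone/data/qa_generator.py`, driven by `single_combine_time_window` and `qa_match_time_window`, and the window values here are examples):

```python
from datetime import datetime, timedelta

# (sender, timestamp, text) tuples, assumed sorted by time.
messages = [
    ("friend", datetime(2024, 1, 1, 12, 0, 0), "在吗"),
    ("friend", datetime(2024, 1, 1, 12, 0, 30), "周末打球去"),
    ("me",     datetime(2024, 1, 1, 12, 2, 0), "好啊"),
]

def combine(msgs, window_minutes=2):
    """Join one sender's consecutive messages within the window with commas."""
    combined = []
    for sender, ts, text in msgs:
        if combined and combined[-1][0] == sender and \
           ts - combined[-1][1] <= timedelta(minutes=window_minutes):
            combined[-1] = (sender, ts, combined[-1][2] + "," + text)
        else:
            combined.append((sender, ts, text))
    return combined

def match_qa(msgs, window_minutes=5):
    """Pair a friend's message with my reply if it arrives within the window."""
    pairs = []
    for (s1, t1, q), (s2, t2, a) in zip(msgs, msgs[1:]):
        if s1 == "friend" and s2 == "me" and t2 - t1 <= timedelta(minutes=window_minutes):
            pairs.append({"instruction": q, "output": a})
    return pairs

print(match_qa(combine(messages)))
# [{'instruction': '在吗,周末打球去', 'output': '好啊'}]
```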
119
+ ## 配置参数并微调模型
120
+
121
+ - (可选)修改 `settings.jsonc` 的 `model_name_or_path` 和 `template` 选择本地下载好的其他模型。
122
+ - 修改`per_device_train_batch_size`以及`gradient_accumulation_steps`来调整显存占用。
123
+ - 可以根据自己数据集的数量和质量修改`train_sft_args`的`num_train_epochs`、`lora_rank`、`lora_dropout`等参数。
124
+
125
+ ### 单卡训练
126
+ ```bash
127
+ weclone-cli train-sft
128
+ ```
129
+ 多卡环境单卡训练,需要先执行 `export CUDA_VISIBLE_DEVICES=0`
130
+
131
+ ### 多卡训练
132
+ 取消`settings.jsonc`中`deepspeed`行代码注释,使用以下命令多卡训练:
133
+ ```bash
134
+ uv pip install deepspeed
135
+ deepspeed --num_gpus=使用显卡数量 weclone/train/train_sft.py
136
+ ```
137
+
138
+ ### Simple Inference with the Browser Demo
139
+ Use this step to find suitable `temperature` and `top_p` values, then update `infer_args` in settings.jsonc so they are applied in later inference (the template defaults are shown after the command).
140
+ ```bash
141
+ weclone-cli webchat-demo
142
+ ```
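+ For reference, the defaults shipped in `settings.template.jsonc` are:
+ ```jsonc
+ "infer_args": {
+     "repetition_penalty": 1.2,
+     "temperature": 0.5,
+     "max_length": 50,
+     "top_p": 0.65
+ }
+ ```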
143
+
144
+ ### Inference via API
145
+
146
+ ```bash
147
+ weclone-cli server
148
+ ```
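+ 
+ The server speaks an OpenAI-style chat-completions protocol, so a request can be sketched as follows (the port is an assumption taken from the deployment example below; adjust it to your setup):
+ ```bash
+ curl http://127.0.0.1:8005/v1/chat/completions \
+   -H "Content-Type: application/json" \
+   -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "吃了吗?"}]}'
+ ```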
149
+
150
+ ### Testing with Common Chat Questions
151
+ The questions are everyday small talk only, with nothing that asks for personal information. The results are written to test_result-my.txt.
152
+ ```bash
153
+ weclone-cli server
154
+ weclone-cli test-model
155
+ ```
156
+
157
+ ## 🖼️ Fine-tuning Results
158
+ With the Qwen2.5-14B-Instruct model and roughly 30,000 processed valid samples, the loss dropped to around 3.5.
159
+ <details>
160
+ <summary>Screenshots</summary>
161
+ <div style="display: flex; flex-wrap: wrap; gap: 10px;">
162
+ <img src="https://github.com/user-attachments/assets/0775ec52-452b-485f-9785-c6eb7b277132" alt="alt text" style="width: 48%; min-width: 150px;">
163
+ <img src="https://github.com/user-attachments/assets/8c7628b5-da70-4c37-9e51-fdfb0eadd2df" alt="alt text" style="width: 48%; min-width: 150px;">
164
+ <img src="https://github.com/user-attachments/assets/523aa742-2aa3-40e9-bd67-b98b336e83a8" alt="alt text" style="width: 48%; min-width: 150px;">
165
+ <img src="https://github.com/user-attachments/assets/dabf0603-dcc4-4a47-b5c3-2bbc036820d9" alt="alt text" style="width: 48%; min-width: 150px;">
166
+ </div>
167
+ </details>
168
+
169
+
170
+ ## 🤖 Deploy to a Chatbot
171
+
172
+ ### AstrBot
173
+
174
+ [AstrBot](https://github.com/AstrBotDevs/AstrBot) is an easy-to-use multi-platform LLM chatbot and development framework ✨ supporting QQ, QQ Channels, Telegram, WeChat, WeCom, and Feishu.
175
+
176
+ Steps:
177
+ 1. Deploy AstrBot
178
+ 2. Deploy a messaging platform in AstrBot
179
+ 3. Run `weclone-cli server` to start the API service
180
+ 4. Add a new service provider in AstrBot: choose OpenAI as the type, set the API Base URL according to how AstrBot is deployed (e.g. for a Docker deployment it may be http://172.17.0.1:8005/v1), set the model to gpt-3.5-turbo, and enter any value for the API Key
181
+ 5. Tool calling is not supported after fine-tuning, so first disable the default tools by sending the command `/tool off all` on the messaging platform; otherwise the fine-tuned behavior will not come through.
182
+ 6. Set the system prompt in AstrBot according to the default_system used during fine-tuning.
183
+ ![5](https://github.com/user-attachments/assets/19de7072-076a-4cdf-8ae6-46b9b89f536a)
184
+ > [!IMPORTANT]
185
+ > Check the api_service logs and make sure the request parameters sent to the LLM service match those used during fine-tuning as closely as possible, with all tool-plugin capabilities disabled.
186
+ 7. Adjust the sampling parameters such as temperature, top_p, and top_k:
187
+ [Configuring custom model parameters](https://astrbot.app/config/model-config.html#%E9%85%8D%E7%BD%AE%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0)
188
+
189
+ ### LangBot
190
+
191
+ [LangBot](https://github.com/RockChinQ/LangBot) is an open-source LLM bot platform that connects to many instant-messaging platforms worldwide and suits a wide range of use cases.
192
+
193
+ 1. [Deploy LangBot](https://github.com/RockChinQ/LangBot#-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8)
194
+ 2. Add a bot in LangBot
195
+ 3. On the models page, add a new model named `gpt-3.5-turbo`, choose OpenAI as the provider, set the request URL to WeClone's address (see the [docs](https://docs.langbot.app/zh/workshop/network-details.html) for connection details), and fill in any value for the API Key.
196
+
197
+ <img width="400px" alt="image" src="https://github.com/user-attachments/assets/fc167dea-7c93-4d94-9c5f-db709d0320ba" />
198
+
199
+ 4. In the pipeline configuration, select the model you just added, or adjust the prompt configuration.
200
+
201
+ <img width="400px" alt="image" src="https://github.com/user-attachments/assets/dbb0fd0a-f760-42db-acd0-bb99c859b52e" />
202
+
203
+ ## 📌 Roadmap
204
+ - [ ] Richer context: including conversation history, chat-partner information, time, etc. + reasoning
205
+ - [ ] Memory support
206
+ - [ ] Multimodal support
207
+ - [ ] Data augmentation
208
+ - [ ] GUI support
209
+
210
+ ## Troubleshooting
211
+ - Fine-tuning issues: [LLaMA-Factory | FAQs | Common Issues](https://github.com/hiyouga/LLaMA-Factory/issues/4614), or the more convenient [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/hiyouga/LLaMA-Factory)
212
+
213
+ ## ❤️ Contributing
214
+ 
215
+ Issues and Pull Requests are welcome!
216
+ 
217
+ You can contribute by browsing the Issues or helping review PRs (pull requests). For new features, please open an Issue for discussion first.
218
+ Run `uv pip install --group dev -e .` to install the development dependencies.
219
+ The project uses `pytest` for testing (test scripts are still being improved), `pyright` for type checking, and `ruff` for code formatting and linting.
220
+
221
+
222
+ ## ⚠️ Disclaimer
223
+ > [!CAUTION]
224
+ > Do not use this project for illegal purposes; you bear all consequences yourself.
225
+ <details>
226
+ <summary>1. Purpose of Use</summary>
227
+ 
228
+ * This project is for learning and exchange only. **Do not use it for illegal purposes**, **do not use it for illegal purposes**, **do not use it for illegal purposes**; otherwise you bear the consequences yourself.
229
+ * The user understands and agrees that any act violating laws and regulations or infringing the legitimate rights and interests of others has nothing to do with this project and its developers, and the consequences are borne by the user alone.
230
+ 
231
+ 2. Period of Use
232
+ 
233
+ * You should delete this project's source code and program within 24 hours of downloading and saving them; any use beyond that period has nothing to do with this project and its developers.
234
+ 
235
+ 3. Rules of Operation
236
+ 
237
+ * This project only permits training on data you are authorized to use; using it for illegal purposes is strictly prohibited, and you assume all related responsibility yourself. Any legal liability arising from a user's violation of this rule is borne by the user alone and has nothing to do with this project and its developers.
238
+ * Using it to steal others' privacy is strictly forbidden, strictly forbidden, strictly forbidden; otherwise you assume all related responsibility yourself.
239
+ 
240
+ 4. Acceptance of the Disclaimer
241
+ 
242
+ * Downloading, saving, or further browsing the source code, or downloading, installing, compiling, or using this program, indicates that you agree to this notice and promise to abide by it;
243
+ 
244
+ 5. Prohibition of Illegal Testing or Penetration
245
+ 
246
+ * Using this project's technology for illegal testing or penetration is prohibited, as is using its code or related techniques for any illegal work; all adverse consequences arising from doing so have nothing to do with this project and its developers.
247
+ * Any resulting adverse consequences, including but not limited to data leaks, system outages, and privacy violations, have nothing to do with this project and its developers; the responsibility is borne by the user.
248
+ 
249
+ 6. Amendments to the Disclaimer
250
+ 
251
+ * This disclaimer may be revised and adjusted according to the project's operation and changes in laws and regulations. Users should check this page regularly for the latest version and comply with it when using this project.
252
+ 
253
+ 7. Miscellaneous
254
+ 
255
+ * Apart from the provisions of this disclaimer, users must observe the relevant laws, regulations, and ethical norms while using this project. This project and its developers accept no responsibility for any dispute or loss caused by a user's violation of the relevant provisions.
256
+ 
257
+ * Please read and understand all of this disclaimer carefully, and be sure to strictly follow the relevant provisions when using this project.
258
+ 
259
+ </details>
260
+ Please read and understand all of this disclaimer carefully, and be sure to strictly follow the relevant provisions when using this project.
261
+
262
+ <br>
263
+ <br>
264
+ <br>
265
+
266
+ ## ⭐ Star History
267
+ > [!TIP]
268
+ > If this project helps you, or if you care about its future development, please give it a Star. Thank you!
269
+
270
+ <div align="center">
271
+
272
+ [![Star History Chart](https://api.star-history.com/svg?repos=xming521/WeClone&type=Date)](https://www.star-history.com/#xming521/WeClone&Date)
273
+
274
+ </div>
275
+
276
+
277
+ <div align="center"> 克隆我们,保留灵魂的芬芳 </div>
dataset/res_csv/pt/dataset_info.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"wechat-pt":{
2
+ "file_name": "./pt-my.json",
3
+ "columns": {
4
+ "prompt": "c"
5
+ }
6
+ }}
dataset/res_csv/sft/dataset_info.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "wechat-sft": {
3
+ "file_name": "sft-my-l.json",
4
+ "columns": {
5
+ "prompt": "instruction",
6
+ "response": "output",
7
+ "system": "system"
8
+ }
9
+ },
10
+ "wechat-sft-with-history": {
11
+ "file_name": "sft-my-l.json",
12
+ "columns": {
13
+ "prompt": "instruction",
14
+ "response": "output",
15
+ "system": "system",
16
+ "history": "history"
17
+ }
18
+ }
19
+ }
dataset/test_data-privacy.json ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "questions": [
3
+ [
4
+ "你多大了?"
5
+ ],
6
+ [
7
+ "你有什么爱好吗?"
8
+ ],
9
+ [
10
+ "你的理想是什么?",
11
+ "你觉得你离你的理想还有多远?"
12
+ ],
13
+ [
14
+ "你最近在忙什么?",
15
+ "工作/学习顺利吗?",
16
+ "有什么有趣的事情发生吗?"
17
+ ],
18
+ [
19
+ "你喜欢看什么类型的电影?",
20
+ "最近看过什么好看的电影吗?",
21
+ "你最喜欢的电影是什么?"
22
+ ],
23
+ [
24
+ "你平时喜欢听什么音乐?",
25
+ "有推荐的歌手或乐队吗?",
26
+ "最近有喜欢的歌曲吗?"
27
+ ],
28
+ [
29
+ "你喜欢旅游吗?",
30
+ "去过哪些地方?",
31
+ "最喜欢的旅游地是哪里?"
32
+ ],
33
+ [
34
+ "你喜欢读书吗?",
35
+ "最近在读什么书?",
36
+ "最喜欢的书是哪本?"
37
+ ],
38
+ [
39
+ "你平时喜欢运动吗?",
40
+ "喜欢做哪些运动?",
41
+ "有固定去锻炼吗?"
42
+ ],
43
+ [
44
+ "周末一般都做些什么?",
45
+ "有没有什么特别的计划?",
46
+ "周末喜欢宅在家还是出去玩?"
47
+ ],
48
+ [
49
+ "你喜欢宠物吗?",
50
+ "有养宠物吗?",
51
+ "最喜欢什么动物?"
52
+ ],
53
+ [
54
+ "你喜欢吃什么类型的食物?",
55
+ "有推荐的餐厅吗?",
56
+ "最喜欢的菜是什么?"
57
+ ],
58
+ [
59
+ "你喜欢什么样的天气?",
60
+ "最喜欢的季节是哪一个?",
61
+ "你觉得今天的天气怎么样?"
62
+ ],
63
+ [
64
+ "你有看电视剧的习惯吗?",
65
+ "最近在追哪部剧?",
66
+ "最喜欢的电视剧是哪部?"
67
+ ],
68
+ [
69
+ "你喜欢玩游戏吗?",
70
+ "最近在玩什么游戏?",
71
+ "有推荐的好玩的游戏吗?"
72
+ ],
73
+ [
74
+ "你会做饭吗?",
75
+ "平时喜欢做哪些菜?",
76
+ "有没有特别拿手的菜?"
77
+ ],
78
+ [
79
+ "你喜欢购物吗?",
80
+ "最近买了什么新东西?",
81
+ "有推荐的购物网站或店铺吗?"
82
+ ],
83
+ [
84
+ "你平时怎么放松自己?",
85
+ "有特别的解压方式吗?",
86
+ "最喜欢的放松活动是什么?"
87
+ ],
88
+ [
89
+ "你喜欢和朋友出去玩吗?",
90
+ "平时会和朋友去哪玩?",
91
+ "最近有没有和朋友聚会的计划?"
92
+ ],
93
+ [
94
+ "你喜欢喝咖啡还是茶?",
95
+ "有没有特别喜欢的咖啡馆或茶馆?",
96
+ "最喜欢的饮品是什么?"
97
+ ],
98
+ [
99
+ "你有兄弟姐妹吗?",
100
+ "和他们关系怎么样?",
101
+ "经常联系吗?"
102
+ ],
103
+ [
104
+ "你喜欢读什么类型的杂志?",
105
+ "最近有看什么有趣的文章吗?",
106
+ "有订阅的杂志吗?"
107
+ ],
108
+ [
109
+ "你喜欢看体育比赛吗?",
110
+ "最喜欢的运动项目是什么?",
111
+ "有没有特别支持的球队或运动员?"
112
+ ],
113
+ [
114
+ "你会说其他语言吗?",
115
+ "最想学的语言是什么?",
116
+ "学习语言有什么技巧吗?"
117
+ ],
118
+ [
119
+ "你对科技产品感兴趣吗?",
120
+ "最近有没有关注什么新科技?",
121
+ "最喜欢的电子产品是什么?"
122
+ ],
123
+ [
124
+ "你喜欢喝什么样的饮料?",
125
+ "有没有自己调饮料的习惯?",
126
+ "最喜欢的饮品品牌是什么?"
127
+ ],
128
+ [
129
+ "你平时用社交媒体吗?",
130
+ "常用哪些平台?",
131
+ "在社交媒体上做什么?"
132
+ ],
133
+ [
134
+ "你对艺术感兴趣吗?",
135
+ "最喜欢的艺术家是谁?",
136
+ "有去过哪些艺术展览?"
137
+ ],
138
+ [
139
+ "你喜欢DIY吗?",
140
+ "平时做些什么手工?",
141
+ "有没有完成的作品可以分享?"
142
+ ],
143
+ [
144
+ "你喜欢种植植物吗?",
145
+ "有养什么植物?",
146
+ "最喜欢的植物是什么?"
147
+ ],
148
+ [
149
+ "你喜欢拍照吗?",
150
+ "喜欢拍什么样的照片?",
151
+ "有没有用什么特别的摄影设备?"
152
+ ],
153
+ [
154
+ "你喜欢听播客吗?",
155
+ "常听哪些主题的播客?",
156
+ "有没有推荐���播客?"
157
+ ],
158
+ [
159
+ "你对历史感兴趣吗?",
160
+ "最喜欢哪个历史时期?",
161
+ "有没有特别喜欢的历史人物?"
162
+ ],
163
+ [
164
+ "你喜欢画画吗?",
165
+ "平时画什么类型的画?",
166
+ "有参加过画展吗?"
167
+ ],
168
+ [
169
+ "你喜欢写作吗?",
170
+ "平时写什么类型的文章?",
171
+ "有没有发表过作品?"
172
+ ],
173
+ [
174
+ "你喜欢钓鱼吗?",
175
+ "平时去哪里钓鱼?",
176
+ "有没有钓到过什么大鱼?"
177
+ ],
178
+ [
179
+ "你喜欢露营吗?",
180
+ "平时会去哪里露营?",
181
+ "有没有什么难忘的露营经历?"
182
+ ],
183
+ [
184
+ "你喜欢摄影吗?",
185
+ "最喜欢拍什么题材?",
186
+ "有没有特别喜欢的摄影师?"
187
+ ],
188
+ [
189
+ "你喜欢喝酒吗?",
190
+ "喜欢什么类型的酒?",
191
+ "有没有推荐的酒吧或品牌?"
192
+ ],
193
+ [
194
+ "你喜欢滑雪吗?",
195
+ "平时去哪里滑雪?",
196
+ "有没有什么滑雪技巧分享?"
197
+ ],
198
+ [
199
+ "你喜欢海边还是山里?",
200
+ "最喜欢去哪个地方度假?",
201
+ "有没有什么特别推荐的景点?"
202
+ ],
203
+ [
204
+ "你喜欢参加音乐节吗?",
205
+ "参加过哪些音乐节?",
206
+ "最喜欢的音乐节是哪一个?"
207
+ ],
208
+ [
209
+ "你喜欢跑步吗?",
210
+ "平时跑多长距离?",
211
+ "有没有参加过马拉松?"
212
+ ],
213
+ [
214
+ "你喜欢参加聚会吗?",
215
+ "平时和朋友聚会做什么?",
216
+ "有没有什么有趣的聚会游戏?"
217
+ ],
218
+ [
219
+ "你喜欢收集东西吗?",
220
+ "收集什么类型的物品?",
221
+ "有没有什么特别的收藏?"
222
+ ]
223
+ ]
224
+ }
dataset/test_data.json ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "questions": [
3
+ [
4
+ "吃了吗?",
5
+ "吃的什么啊",
6
+ "好吃吗",
7
+ "多少钱啊",
8
+ "可以请我吃吗"
9
+ ],
10
+ [
11
+ "干嘛呢?",
12
+ "等会准备干什么去"
13
+ ],
14
+ [
15
+ "在忙什么呢?",
16
+ "今天有什么特别的安排吗?",
17
+ "感觉怎么样?"
18
+ ],
19
+ [
20
+ "最近有什么新鲜事发生吗?",
21
+ "有没有什么有趣的故事可以分享?"
22
+ ],
23
+ [
24
+ "周末过得怎么样?",
25
+ "做了什么好玩的?"
26
+ ],
27
+ [
28
+ "最近看了什么好看的电影或电视剧吗?",
29
+ "有什么推荐的吗?",
30
+ "大概讲了什么内容呀?"
31
+ ],
32
+ [
33
+ "今天天气怎么样?",
34
+ "你那里呢?"
35
+ ],
36
+ [
37
+ "最近工作/学习顺利吗?",
38
+ "有没有遇到什么挑战?"
39
+ ],
40
+ [
41
+ "嗨,这会儿在忙啥呢?",
42
+ "今天有什么特别的安排不?",
43
+ "一切都还顺利吧?"
44
+ ],
45
+ [
46
+ "你那边现在天气咋样啊?",
47
+ "是大晴天还是有点阴沉沉的?",
48
+ "冷不冷,或者热不热呀?"
49
+ ],
50
+ [
51
+ "到饭点儿了没呀?",
52
+ "今天打算犒劳一下自己,吃点啥好吃的?",
53
+ "有没有啥特别想吃的,或者想去哪家馆子尝尝鲜?"
54
+ ],
55
+ [
56
+ "最近网上有啥好玩儿的新闻或者梗吗?",
57
+ "刷到啥有意思的视频或者段子没?分享一下呗!"
58
+ ],
59
+ [
60
+ "待会儿有啥打算呀?",
61
+ "今天剩下的时间准备怎么过呢?"
62
+ ],
63
+ [
64
+ "今天有没有碰到啥让你眼前一亮的小事儿?",
65
+ "随便聊聊呗,有啥轻松点的话题不?"
66
+ ],
67
+ [
68
+ "今天有啥新发现或者小感悟没?",
69
+ "感觉今天过得快不快?节奏怎么样?"
70
+ ],
71
+ [
72
+ "你现在周围环境咋样,吵不吵?",
73
+ "今天出门溜达了没,外面人多不多呀?",
74
+ "瞅瞅窗外,有啥特别的景儿不?"
75
+ ],
76
+ [
77
+ "吃饭了没啊?",
78
+ "吃的啥呀?合胃口不?"
79
+ ],
80
+ [
81
+ "今天怎么样啊?累不累?",
82
+ "有啥事儿不?"
83
+ ],
84
+ [
85
+ "最近身体还好吧?",
86
+ "没什么不舒服的地方吧?"
87
+ ],
88
+ [
89
+ "今天忙不忙啊?",
90
+ "都干啥了呀?"
91
+ ],
92
+ [
93
+ "家里都挺好的吧?",
94
+ "有啥需要帮忙的不?"
95
+ ],
96
+ [
97
+ "今天出门了没?",
98
+ "外面冷不冷/热不热啊?多穿点/注意防暑。"
99
+ ],
100
+ [
101
+ "最近有啥开心的事儿不?说来听听!",
102
+ "或者有啥烦心事儿,跟我说说?"
103
+ ],
104
+ [
105
+ "晚上早点休息啊,别熬太晚。",
106
+ "睡得好不好啊最近?"
107
+ ],
108
+ [
109
+ "缺啥东西不?跟我说。",
110
+ "钱够不够花呀?"
111
+ ],
112
+ [
113
+ "今天看到啥有意思的了没?",
114
+ "或者有啥想跟我分享的?"
115
+ ],
116
+ [
117
+ "周末有啥安排啊?",
118
+ "要不要一起吃个饭/出去转转?"
119
+ ],
120
+ [
121
+ "最近常联系的那些朋友都还好不?",
122
+ "有空多聚聚。"
123
+ ],
124
+ [
125
+ "工作/学习上还顺利吧?",
126
+ "别太给自己压力啊。"
127
+ ],
128
+ [
129
+ "今天做了啥好吃的呀?",
130
+ "下次也给我尝尝呗!"
131
+ ],
132
+ [
133
+ "有啥新闻没有啊最近?",
134
+ "跟我讲讲。"
135
+ ],
136
+ [
137
+ "那谁谁谁最近怎么样了?",
138
+ "好久没听到他/她消息了。"
139
+ ],
140
+ [
141
+ "今天心情好不好呀?",
142
+ "看你气色不错/有点疲惫。"
143
+ ],
144
+ [
145
+ "有啥想吃的没?下次给你做/带。",
146
+ "或者想去哪儿玩,我陪你。"
147
+ ],
148
+ [
149
+ "最近有没有看啥电视剧/电影啊?",
150
+ "有啥好看的推荐给我呗。"
151
+ ],
152
+ [
153
+ "没事儿就早点回家/休息。",
154
+ "注意安全啊。"
155
+ ]
156
+ ]
157
+ }
ds_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "zero_optimization": {
14
+ "stage": 2,
15
+ "allgather_partitions": true,
16
+ "allgather_bucket_size": 5e8,
17
+ "overlap_comm": true,
18
+ "reduce_scatter": true,
19
+ "reduce_bucket_size": 5e8,
20
+ "contiguous_gradients": true
21
+ },
22
+ "gradient_accumulation_steps": "auto",
23
+ "gradient_clipping": "auto",
24
+ "steps_per_print": 2000,
25
+ "train_batch_size": "auto",
26
+ "train_micro_batch_size_per_gpu": "auto",
27
+ "wall_clock_breakdown": false
28
+ }
pyproject.toml ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "WeClone"
3
+ version = "0.2.21"
4
+ description = "从聊天记录创造数字分身的一站式解决方案"
5
+ authors = [{ name = "xming521" }]
6
+ readme = "README.md"
7
+ requires-python = ">=3.10,<3.11"
8
+
9
+ dependencies = [
10
+ "pandas",
11
+ "commentjson",
12
+ "click",
13
+ "pydantic==2.10.6",
14
+ "setuptools>=78.1.0",
15
+ "loguru>=0.7.3",
16
+ "torch>=2.6.0",
17
+ "transformers==4.49.0",
18
+ "tomli; python_version < '3.11'",
19
+ "langchain",
20
+ ]
21
+
22
+ [tool.weclone]
23
+ # Version number of the config file; bump it when the config structure or important defaults change
24
+ config_version = "0.2.21"
25
+
26
+ # Config file changelog
27
+ config_changelog = """
28
+ [0.2.1] - 2025-04-29 - Initial config version.
29
+ [0.2.2] - 2025-05-01 - Added LLM data-cleaning config; moved blocked_words into the unified settings.jsonc config file.
30
+ [0.2.21] - 2025-05-01 - Added online LLM data-cleaning config, compatible with OpenAI-style interfaces.
31
+ """
32
+
33
+ [dependency-groups]
34
+ # xcodec = ["xcodec2==0.1.3"]
35
+ sparktts = [
36
+ "einops>=0.8.1",
37
+ "einx>=0.3.0",
38
+ "numpy==1.26.4",
39
+ "omegaconf>=2.3.0",
40
+ "packaging>=24.2",
41
+ "safetensors>=0.5.2",
42
+ "soundfile>=0.12.1",
43
+ "soxr>=0.5.0.post1",
44
+ "torchaudio>=2.6.0",
45
+ "tqdm>=4.66.5",
46
+ ]
47
+ main = [
48
+ "llamafactory>=0.9.2",
49
+ "openai==1.76.0",
50
+ "vllm==0.8.2; platform_system == 'Linux'",
51
+ ]
52
+ dev = ["pytest", "pytest-order", "pyright", "ruff"]
53
+
54
+ [project.scripts]
55
+ weclone-cli = "weclone.cli:cli"
56
+
57
+ [tool.uv]
58
+ conflicts = [
59
+ # [{ group = "wx" }, { group = "xcodec" }],
60
+ ]
61
+
62
+ [tool.uv.sources]
63
+ torch = [
64
+ { index = "pytorch-cu124", marker = "platform_system == 'Windows'" },
65
+ { index = "pytorch-cu124", marker = "platform_system == 'Linux'" },
66
+ ]
67
+ torchaudio = [
68
+ { index = "pytorch-cu124", marker = "platform_system == 'Windows'" },
69
+ { index = "pytorch-cu124", marker = "platform_system == 'Linux'" },
70
+ ]
71
+ torchvision = [
72
+ { index = "pytorch-cu124", marker = "platform_system == 'Windows'" },
73
+ { index = "pytorch-cu124", marker = "platform_system == 'Linux'" },
74
+ ]
75
+
76
+
77
+ [[tool.uv.index]]
78
+ url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
79
+ default = true
80
+
81
+ [[tool.uv.index]]
82
+ name = "pytorch-cu124"
83
+ url = "https://download.pytorch.org/whl/cu124"
84
+ explicit = true
85
+
86
+ [tool.setuptools.packages.find]
87
+ where = ["."] # 表示在项目根目录开始查找
88
+ include = ["weclone*"] # 只包含名为 weclone 的目录及其子包
89
+ exclude = ["*tests*", "*archive*"] # 可以选择性排除其他模式,比如测试目录
90
+
91
+
92
+ [tool.pyright]
93
+ typeCheckingMode = "basic"
94
+ include = ["weclone/data"]
95
+ exclude = ["**/archive", "**/tests"]
96
+ ignore = ["**/archive"]
97
+
98
+ reportMissingImports = "error"
99
+ reportMissingTypeStubs = false
100
+
101
+ pythonVersion = "3.10"
102
+ pythonPlatform = "Linux"
103
+
104
+ [tool.ruff]
105
+ exclude = [
106
+ "**/archive",
107
+ "**/tests",
108
+ "weclone-audio/src/server未完工",
109
+ "weclone-audio/src/Spark-TTS",
110
+ ]
111
+ line-length = 120
112
+
113
+ lint.ignore = ["F403", "F405", "E501", "E402"]
114
+ lint.select = [
115
+ "F", # Pyflakes
116
+ "W", # pycodestyle warnings
117
+ "E", # pycodestyle errors
118
+ "ASYNC", # flake8-async
119
+ "C4", # flake8-comprehensions
120
+ "Q", # flake8-quotes
121
+ ]
122
+ target-version = "py310"
123
+
124
+ [tool.pytest.ini_options]
125
+ addopts = "-x"
requirements.txt ADDED
Binary file (2.11 kB). View file
 
settings.template.jsonc ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.21",
3
+ "common_args": {
4
+ "model_name_or_path": "./Qwen2.5-7B-Instruct",
5
+ "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir
6
+ "template": "qwen",
7
+ "default_system": "请你扮演一名人类,不要说自己是人工智能",
8
+ "finetuning_type": "lora",
9
+ "trust_remote_code": true
10
+ },
11
+ "cli_args": {
12
+ "full_log": false
13
+ },
14
+ "make_dataset_args": {
15
+ // data-processing configuration
16
+ "include_type": [
17
+ "text",
18
+ // "image"
19
+ ],
20
+ "blocked_words": [ // 禁用词
21
+ "例如 姓名",
22
+ "例如 密码",
23
+ "//....."
24
+ ],
25
+ "single_combine_strategy": "time_window", // 单人组成单句策略
26
+ "qa_match_strategy": "time_window", // 组成qa策略
27
+ "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟),
28
+ "qa_match_time_window": 5, // 组成qa时间窗口(分钟),
29
+ "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用
30
+ "prompt_with_history": false, // 是否在prompt中包含历史对话
31
+ "clean_dataset": {
32
+ "enable_clean": true,
33
+ "clean_strategy": "llm",
34
+ "llm": {
35
+ "accept_score": 2, //可以接受的llm打分阈值,1分最差,5分最好,低于此分数的数据不会用于训练
36
+ }
37
+ },
38
+ "online_llm_clear": false,
39
+ "base_url": "https://xxx/v1",
40
+ "llm_api_key": "xxxxx",
41
+ "model_name": "xxx", //建议使用参数较大的模型,例如DeepSeek-V3
42
+ "clean_batch_size": 10
43
+ },
44
+ "train_pt_args": {
45
+ // pre-training configuration
46
+ "stage": "pt",
47
+ "dataset": "wechat-pt",
48
+ "dataset_dir": "./dataset/res_csv/pt",
49
+ "lora_target": "q_proj,v_proj",
50
+ "lora_rank": 2,
51
+ "lora_dropout": 0.1,
52
+ "output_dir": "model_output",
53
+ "overwrite_cache": true,
54
+ "per_device_train_batch_size": 1,
55
+ "gradient_accumulation_steps": 1,
56
+ "lr_scheduler_type": "cosine",
57
+ "logging_steps": 10,
58
+ "save_steps": 1000,
59
+ "learning_rate": 0.001,
60
+ "num_train_epochs": 30,
61
+ "plot_loss": true,
62
+ "fp16": true
63
+ },
64
+ "train_sft_args": {
65
+ // fine-tuning configuration
66
+ "stage": "sft",
67
+ "dataset": "wechat-sft",
68
+ "dataset_dir": "./dataset/res_csv/sft",
69
+ "use_fast_tokenizer": true,
70
+ "lora_target": "q_proj,v_proj",
71
+ "lora_rank": 4,
72
+ "lora_dropout": 0.3,
73
+ "weight_decay": 0.1,
74
+ "overwrite_cache": true,
75
+ "per_device_train_batch_size": 8,
76
+ "gradient_accumulation_steps": 4,
77
+ "lr_scheduler_type": "cosine",
78
+ "cutoff_len": 256,
79
+ "logging_steps": 10,
80
+ "save_steps": 100,
81
+ "learning_rate": 1e-4,
82
+ "warmup_ratio": 0.1,
83
+ "num_train_epochs": 2,
84
+ "plot_loss": true,
85
+ "fp16": true,
86
+ "flash_attn": "fa2",
87
+ // "deepspeed": "ds_config.json" //多卡训练
88
+ },
89
+ "infer_args": {
90
+ "repetition_penalty": 1.2,
91
+ "temperature": 0.5,
92
+ "max_length": 50,
93
+ "top_p": 0.65
94
+ }
95
+ }
spaces_app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ # Model path - use your local model directory
6
+ MODEL_PATH = "./Qwen2.5-7B-Instruct"
7
+
8
+ # Load the model and tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ MODEL_PATH,
12
+ torch_dtype=torch.bfloat16,
13
+ device_map="auto",
14
+ trust_remote_code=True
15
+ )
16
+
17
+ # Chat function
18
+ def chat(message, history):
19
+ history = history or []
20
+ chat_history = ""
21
+ for human, assistant in history:
22
+ chat_history += f"<|im_start|>user\n{human}<|im_end|>\n"
23
+ chat_history += f"<|im_start|>assistant\n{assistant}<|im_end|>\n"
24
+
25
+ prompt = f"{chat_history}<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
26
+
27
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
28
+ outputs = model.generate(
29
+ **inputs,
30
+ max_new_tokens=512,
31
+ do_sample=True,
32
+ temperature=0.7,
33
+ top_p=0.9,
34
+ repetition_penalty=1.1,
35
+ eos_token_id=tokenizer.eos_token_id
36
+ )
37
+ response = tokenizer.decode(
38
+ outputs[0][inputs.input_ids.shape[1]:],
39
+ skip_special_tokens=True
40
+ )
41
+ return response
42
+
43
+ # Create the interface
44
+ demo = gr.ChatInterface(
45
+ chat,
46
+ title="WeClone AI 助手",
47
+ description="基于 Qwen2.5-7B 的聊天演示",
48
+ theme="soft",
49
+ examples=["你好", "介绍一下你自己", "你能做什么?"]
50
+ )
51
+
52
+ # Export a deployable object
53
+ app = demo
tests/__init__.py ADDED
File without changes
tests/full_pipe.jsonc ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.2",
3
+ "common_args": {
4
+ "model_name_or_path": "./Qwen2.5-3B-Instruct",
5
+ "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir
6
+ "template": "qwen",
7
+ "default_system": "请你扮演一名人类,不要说自己是人工智能",
8
+ "finetuning_type": "lora",
9
+ "trust_remote_code": true
10
+ },
11
+ "cli_args": {
12
+ "full_log": false
13
+ },
14
+ "make_dataset_args": {
15
+ // data-processing configuration
16
+ "include_type": [
17
+ "文本"
18
+ ],
19
+ "blocked_words": [ // 禁用词
20
+ "例如 姓名",
21
+ "例如 密码",
22
+ "//....."
23
+ ],
24
+ "single_combine_strategy": "time_window", // 单人组成单句策略
25
+ "qa_match_strategy": "time_window", // 组成qa策略
26
+ "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟),
27
+ "qa_match_time_window": 5, // 组成qa时间窗口(分钟),
28
+ "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用
29
+ "prompt_with_history": false, // 是否在prompt中包含历史对话
30
+ "clean_dataset": {
31
+ "enable_clean": true,
32
+ "clean_strategy": "llm",
33
+ "llm": {
34
+ "accept_score": 2, //可以接受的llm打分阈值,1分最差,5分最好,低于此分数的数据不会用于训练
35
+ }
36
+ }
37
+ },
38
+ "train_pt_args": {
39
+ // pre-training configuration
40
+ "stage": "pt",
41
+ "dataset": "wechat-pt",
42
+ "dataset_dir": "./dataset/res_csv/pt",
43
+ "lora_target": "q_proj,v_proj",
44
+ "lora_rank": 2,
45
+ "lora_dropout": 0.1,
46
+ "output_dir": "model_output",
47
+ "overwrite_cache": true,
48
+ "per_device_train_batch_size": 1,
49
+ "gradient_accumulation_steps": 1,
50
+ "lr_scheduler_type": "cosine",
51
+ "logging_steps": 10,
52
+ "save_steps": 1000,
53
+ "learning_rate": 0.001,
54
+ "num_train_epochs": 30,
55
+ "plot_loss": true,
56
+ "fp16": true
57
+ },
58
+ "train_sft_args": {
59
+ // fine-tuning configuration
60
+ "stage": "sft",
61
+ "dataset": "wechat-sft",
62
+ "dataset_dir": "./dataset/res_csv/sft",
63
+ "use_fast_tokenizer": true,
64
+ "lora_target": "q_proj,v_proj",
65
+ "lora_rank": 4,
66
+ "lora_dropout": 0.3,
67
+ "weight_decay": 0.1,
68
+ "overwrite_cache": true,
69
+ "per_device_train_batch_size": 8,
70
+ "gradient_accumulation_steps": 4,
71
+ "lr_scheduler_type": "cosine",
72
+ "cutoff_len": 256,
73
+ "logging_steps": 5,
74
+ "save_steps": 10,
75
+ "learning_rate": 1e-4,
76
+ "warmup_ratio": 0.1,
77
+ "num_train_epochs": 1,
78
+ "plot_loss": true,
79
+ "fp16": true,
80
+ "flash_attn": "fa2",
81
+ // "deepspeed": "ds_config.json" //多卡训练
82
+ },
83
+ "infer_args": {
84
+ "repetition_penalty": 1.2,
85
+ "temperature": 0.5,
86
+ "max_length": 50,
87
+ "top_p": 0.65
88
+ }
89
+ }
tests/test_full_pipe.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from unittest import mock
3
+ import sys
4
+ import os
5
+ import shutil
6
+ import functools
7
+ import subprocess
8
+ import time
9
+ from typing import Union, Optional, cast
10
+ from weclone.utils.log import logger
11
+
12
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
13
+ PROJECT_ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
14
+ server_process: Optional[subprocess.Popen] = None
15
+
16
+ test_logger = logger.bind()
17
+ test_logger.remove()
18
+ test_logger.add(
19
+ sys.stderr,
20
+ format="<yellow><b>{message}</b></yellow>",
21
+ colorize=True,
22
+ level="INFO",
23
+ )
24
+
25
+ def print_test_header(test_name: str):
26
+ line_length = 100
27
+ test_logger.info("\n" + "─" * line_length)
28
+ title = f" Testing Phase: {test_name} "
29
+ padding_total = line_length - len(title)
30
+ padding_left = padding_total // 2
31
+ padding_right = padding_total - padding_left
32
+ test_logger.info(" " * padding_left + title + " " * padding_right)
33
+ test_logger.info("─" * line_length)
34
+
35
+ def setup_make_dataset_test_data():
36
+ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
37
+ DATASET_CSV_DIR = os.path.join(PROJECT_ROOT, "dataset", "csv")
38
+
39
+ TESTS_DIR = os.path.dirname(__file__)
40
+ TEST_DATA_PERSON_DIR = os.path.join(TESTS_DIR, "tests_data", "test_person")
41
+
42
+ os.makedirs(DATASET_CSV_DIR, exist_ok=True)
43
+
44
+ if os.path.exists(DATASET_CSV_DIR) and os.listdir(DATASET_CSV_DIR):
45
+ if all(f.startswith('.') or f.lower() == 'readme.md' for f in os.listdir(DATASET_CSV_DIR)):
46
+ for item_name in os.listdir(TEST_DATA_PERSON_DIR):
47
+ source_item_path = os.path.join(TEST_DATA_PERSON_DIR, item_name)
48
+ if os.path.isfile(source_item_path) and item_name.lower().endswith('.csv'):
49
+ destination_item_path = os.path.join(DATASET_CSV_DIR, item_name)
50
+ shutil.copy2(source_item_path, destination_item_path)
51
+
52
+
53
+ def run_cli_command(command: list[str], timeout: int | None = None, background: bool = False) -> Union[subprocess.CompletedProcess, subprocess.Popen]:
54
+ """Execute a CLI command and return the result.
55
+
56
+ Args:
57
+ command: List of commands to execute.
58
+ timeout: Timeout in seconds.
59
+ background: Whether to run in the background.
60
+
61
+ Returns:
62
+ If background=True, returns a Popen object; otherwise, returns a CompletedProcess object.
63
+ """
64
+ env = os.environ.copy()
65
+ env["WECLONE_CONFIG_PATH"] = "tests/full_pipe.jsonc" # Set environment variable
66
+
67
+ if background:
68
+ process = subprocess.Popen(
69
+ [sys.executable, "-m", "weclone.cli"] + command,
70
+ stderr=subprocess.PIPE,
71
+ stdout=subprocess.PIPE,
72
+ text=True,
73
+ cwd=PROJECT_ROOT_DIR,
74
+ env=env
75
+ )
76
+ time.sleep(2)
77
+ return process
78
+ else:
79
+ process = subprocess.run(
80
+ [sys.executable, "-m", "weclone.cli"] + command,
81
+ stderr=None,
82
+ stdout=None,
83
+ text=True,
84
+ cwd=PROJECT_ROOT_DIR, # Execute in the project root directory
85
+ timeout=timeout,
86
+ env=env # Pass the modified environment variables
87
+ )
88
+ return process
89
+
90
+ @pytest.mark.order(1)
91
+ def test_cli_make_dataset():
92
+ """Test the make-dataset command."""
93
+ print_test_header("make-dataset")
94
+ setup_make_dataset_test_data()
95
+ result = run_cli_command(["make-dataset"])
96
+ assert result.returncode == 0, "make-dataset command execution failed"
97
+
98
+ @pytest.mark.order(2)
99
+ def test_cli_train_sft():
100
+ """Test the train-sft command."""
101
+ print_test_header("train-sft")
102
+ try:
103
+ result = run_cli_command(["train-sft"])
104
+ assert result.returncode == 0, "train-sft command failed or did not fail fast as expected"
105
+ except subprocess.TimeoutExpired:
106
+ test_logger.info("train-sft command terminated due to timeout, which is acceptable in testing, indicating the command has started execution.")
107
+ pass
108
+ except Exception as e:
109
+ pytest.fail(f"An unexpected error occurred during train-sft command execution: {e}")
110
+
111
+ @pytest.mark.order(3)
112
+ def test_cli_webchat_demo():
113
+ """Test the webchat-demo command."""
114
+ print_test_header("webchat-demo")
115
+
116
+ with mock.patch("weclone.eval.web_demo.main") as mock_main:
117
+ mock_main.return_value = None
118
+ try:
119
+ result = run_cli_command(["webchat-demo"], timeout=5)
120
+ assert result.returncode == 0, "webchat-demo command execution failed"
121
+ except subprocess.TimeoutExpired:
122
+ pass
123
+
124
+ @pytest.mark.order(4)
125
+ def test_cli_server():
126
+ """Test the server command.
127
+
128
+ Start the server in the background, without blocking subsequent tests.
129
+ """
130
+ print_test_header("server (background)")
131
+ global server_process
132
+ server_process = cast(subprocess.Popen, run_cli_command(["server"], background=True))
133
+ assert server_process.poll() is None, "Server startup failed"
134
+ test_logger.info("服务器已在后台启动")
135
+
136
+ @pytest.mark.order(5)
137
+ def test_cli_test_model():
138
+ """Test the test-model command.
139
+
140
+ Use the server for testing, and shut down the server after the test is complete.
141
+ """
142
+ print_test_header("test-model")
143
+ try:
144
+ result = run_cli_command(["test-model"])
145
+ assert result.returncode == 0, "test-model command execution failed"
146
+ finally:
147
+ global server_process
148
+ if server_process is not None and server_process.poll() is None:
149
+ test_logger.info("测试完成,正在关闭服务器...")
150
+ server_process.terminate()
151
+ server_process.wait(timeout=5)
152
+ if server_process.poll() is None:
153
+ server_process.kill() # Force kill if the process hasn't terminated
154
+ test_logger.info("服务器已关闭")
torchvision.whl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:307e52c2887c1d2b50cc3581cf5f4c169130b8352462e361e71eeda19e0dd263
3
+ size 5660713
weclone-audio/README.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WeClone-audio Module
2
+ 
3
+ WeClone-audio is a module that clones a voice from WeChat voice messages, using a model to produce high-quality speech synthesis.
4
+ ### VRAM Requirements
5
+ **Spark-TTS** (recommended)
6
+ - **0.5B model**: about 4GB of VRAM
7
+ 
8
+ **Llasa** (deprecated)
9
+ - **3B model**: about 16GB of VRAM
10
+ - **1B model**: about 9GB of VRAM
11
+
12
+
13
+
14
+
15
+ ## 1. Export WeChat Voice Data
16
+ 
17
+ ### 1.1 Preparation
18
+ - Use [PyWxDump](https://github.com/xaoyaoo/PyWxDump) to extract WeChat chat history
19
+ - Download the tool and decrypt the database
20
+ - Click "chat backup" and choose "decrypted files" as the export type
21
+
22
+ ### 1.2 Environment Setup
23
+ Voice export is supported on Windows only.
24
+ WeClone Audio uses uv as its package manager.
25
+ ```bash
26
+ # Create a Python environment for PyWxDump and install its dependencies
27
+ #
28
+ uv venv .venv-wx --python=3.10
29
+ .venv-wx\Scripts\activate
30
+ uv pip install pywxdump
31
+ ```
32
+
33
+ ### 1.3 Export the Voice File
34
+ ```bash
35
+ python weclone-audio/src/get_sample_audio.py --db-path "<path to the exported database>" --MsgSvrID "<MsgSvrID field of the exported chat record>"
36
+ ```
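+ A hypothetical invocation (both values are placeholders; the MsgSvrID comes from the decrypted message table):
+ ```bash
+ python weclone-audio/src/get_sample_audio.py --db-path "./wxdump_work/decrypted.db" --MsgSvrID "1234567890123456789"
+ ```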
37
+
38
+ ## 2. Speech Synthesis Inference
39
+ ### Spark-TTS Model
40
+ 
41
+ **Environment setup**
42
+ You can skip creating a new environment and install the `sparktts` dependency group straight into WeClone's shared main environment.
43
+
44
+ ```bash
45
+ uv venv .venv-sparktts --python=3.10
46
+ source .venv-sparktts/bin/activate
47
+ uv pip install --group sparktts -e .
48
+
49
+ git clone https://github.com/SparkAudio/Spark-TTS.git weclone-audio/src/Spark-TTS
50
+ ```
51
+
52
+
53
+ **模型下载**
54
+
55
+ 通过python下载:
56
+ ```python
57
+ from huggingface_hub import snapshot_download
58
+
59
+ # Assumes this Python code runs from the weclone-audio directory; the model is downloaded to weclone-audio/pretrained_models/Spark-TTS-0.5B
60
+ snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
61
+ ```
62
+
63
+ Or download via git:
64
+ ```bash
65
+ # Assumes the current directory is weclone-audio
66
+ mkdir -p pretrained_models
67
+
68
+ # Make sure you have git-lfs installed (https://git-lfs.com)
69
+ git lfs install
70
+ git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B
71
+ ```
72
+ Inference with code:
73
+ ```python
74
+ import os
75
+ # Note: SparkTTS.py lives in weclone-audio/src; make sure it is importable (e.g. add that directory to sys.path)
76
+ import soundfile as sf
77
+ import torch
78
+
79
+ from SparkTTS import SparkTTS
80
+
81
+ # Assumes this Python code runs from the weclone-audio directory
82
+ # The model path is relative to the current directory
83
+ model_path = "pretrained_models/Spark-TTS-0.5B"
84
+ sample_audio = "sample.wav"
85
+ output_audio = "output.wav"
86
+
87
+ model = SparkTTS(model_path, "cuda")
88
+
89
+ with torch.no_grad():
90
+ wav = model.inference(
91
+ text="晚上好啊,小可爱们,该睡觉了哦",
92
+ prompt_speech_path=sample_audio, # relative path
93
+ prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。",
94
+ )
95
+ sf.write(output_audio, wav, samplerate=16000) # relative path
96
+ ```
97
+ ### Llasa Model (deprecated)
98
+ ### 2.1 Environment Setup
99
+ ```bash
100
+ # Create and configure the inference environment
101
+ ## You can skip creating a new environment and share the LLaMA-Factory environment
102
+ uv venv .venv-xcodec --python=3.9
103
+ source .venv-xcodec/bin/activate
104
+ uv pip install --group xcodec -e .
105
+ # Deactivate the environment
106
+ deactivate
107
+
108
+ # Install system dependencies (if needed)
109
+ sudo apt install python3-dev
110
+ sudo apt install build-essential
111
+ ```
112
+
113
+ ### 2.2 Inference with Code
114
+ If you run into problems, try converting the reference audio to WAV or MP3, trimming it to 15 seconds or less, and shortening the prompt text.
115
+ ```python
116
+ import os
117
+ import soundfile as sf
118
+ # Assumes text_to_speech.py is in src/ or another importable location
119
+ from text_to_speech import TextToSpeech
120
+
121
+
122
+ sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本
123
+ # 假设此 Python 代码在 weclone-audio 目录下运行
124
+ # 示例音频路径相对于当前目录
125
+ sample_audio_path = "sample.wav"
126
+ output_audio = "output.wav"
127
+
128
+
129
+ tts = TextToSpeech(sample_audio_path, sample_audio_text)
130
+ target_text = "晚上好啊" # 生成目标文本
131
+ result = tts.infer(target_text)
132
+ sf.write(output_audio, result[1], result[0]) # relative path
133
+ ```
134
+
weclone-audio/src/Llasa/infer.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import soundfile as sf
3
+ from text_to_speech import TextToSpeech
4
+
5
+
6
+ sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本
7
+ sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") # 示例音频路径
8
+ tts = TextToSpeech(sample_audio_path, sample_audio_text)
9
+ target_text = "晚上好啊" # 生成目标文本
10
+ result = tts.infer(target_text)
11
+ sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) # save the generated audio
12
+
weclone-audio/src/Llasa/text_to_speech.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ import soundfile as sf
5
+ from xcodec2.modeling_xcodec2 import XCodec2Model
6
+ import torchaudio
7
+
8
+
9
+ class TextToSpeech:
10
+ def __init__(self, sample_audio_path, sample_audio_text):
11
+ self.sample_audio_text = sample_audio_text
12
+ # Initialize the models
13
+ llasa_3b = "HKUSTAudio/Llasa-3B"
14
+ xcodec2 = "HKUSTAudio/xcodec2"
15
+
16
+ self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
17
+ self.llasa_3b_model = AutoModelForCausalLM.from_pretrained(
18
+ llasa_3b,
19
+ trust_remote_code=True,
20
+ device_map="auto",
21
+ )
22
+ self.llasa_3b_model.eval()
23
+
24
+ self.xcodec_model = XCodec2Model.from_pretrained(xcodec2)
25
+ self.xcodec_model.eval().cuda()
26
+
27
+ # Process the reference audio
28
+ waveform, sample_rate = torchaudio.load(sample_audio_path)
29
+ if len(waveform[0]) / sample_rate > 15:
30
+ print("已将音频裁剪至前15秒。")
31
+ waveform = waveform[:, : sample_rate * 15]
32
+
33
+ # Check whether the audio is stereo
34
+ if waveform.size(0) > 1:
35
+ waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
36
+ else:
37
+ waveform_mono = waveform
38
+
39
+ self.prompt_wav = torchaudio.transforms.Resample(
40
+ orig_freq=sample_rate, new_freq=16000
41
+ )(waveform_mono)
42
+
43
+ # Encode the prompt wav
44
+ vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav)
45
+ vq_code_prompt = vq_code_prompt[0, 0, :]
46
+ self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
47
+ self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
48
+
49
+ def ids_to_speech_tokens(self, speech_ids):
50
+ speech_tokens_str = []
51
+ for speech_id in speech_ids:
52
+ speech_tokens_str.append(f"<|s_{speech_id}|>")
53
+ return speech_tokens_str
54
+
55
+ def extract_speech_ids(self, speech_tokens_str):
56
+ speech_ids = []
57
+ for token_str in speech_tokens_str:
58
+ if token_str.startswith("<|s_") and token_str.endswith("|>"):
59
+ num_str = token_str[4:-2]
60
+ num = int(num_str)
61
+ speech_ids.append(num)
62
+ else:
63
+ print(f"Unexpected token: {token_str}")
64
+ return speech_ids
65
+
66
+ @torch.inference_mode()
67
+ def infer(self, target_text):
68
+ if len(target_text) == 0:
69
+ return None
70
+ elif len(target_text) > 300:
71
+ print("文本过长,请保持在300字符以内。")
72
+ target_text = target_text[:300]
73
+
74
+ input_text = self.sample_audio_text + " " + target_text
75
+
76
+ formatted_text = (
77
+ f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
78
+ )
79
+
80
+ chat = [
81
+ {
82
+ "role": "user",
83
+ "content": "Convert the text to speech:" + formatted_text,
84
+ },
85
+ {
86
+ "role": "assistant",
87
+ "content": "<|SPEECH_GENERATION_START|>"
88
+ + "".join(self.speech_ids_prefix),
89
+ },
90
+ ]
91
+
92
+ input_ids = self.tokenizer.apply_chat_template(
93
+ chat, tokenize=True, return_tensors="pt", continue_final_message=True
94
+ )
95
+ input_ids = input_ids.to("cuda")
96
+
97
+ outputs = self.llasa_3b_model.generate(
98
+ input_ids,
99
+ max_length=2048,
100
+ eos_token_id=self.speech_end_id,
101
+ do_sample=True,
102
+ top_p=1,
103
+ temperature=0.8,
104
+ )
105
+ generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1]
106
+
107
+ speech_tokens = self.tokenizer.batch_decode(
108
+ generated_ids, skip_special_tokens=True
109
+ )
110
+
111
+ speech_tokens = self.extract_speech_ids(speech_tokens)
112
+ speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
113
+
114
+ gen_wav = self.xcodec_model.decode_code(speech_tokens)
115
+ gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:]
116
+
117
+ return (16000, gen_wav[0, 0, :].cpu().numpy())
118
+
119
+
120
+ if __name__ == "__main__":
121
+ # If you run into problems, try converting the reference audio to WAV or MP3, trimming it to 15 seconds or less, and shortening the prompt text.
122
+ sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
123
+ sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav")
124
+
125
+ tts = TextToSpeech(sample_audio_path, sample_audio_text)
126
+ target_text = "晚上好啊,吃了吗您"
127
+ result = tts.infer(target_text)
128
+ sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0])
129
+ target_text = "我是老北京正黄旗!"
130
+ result = tts.infer(target_text)
131
+ sf.write(os.path.join(os.path.dirname(__file__), "output1.wav"), result[1], result[0])
weclone-audio/src/SparkTTS.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ from typing import Optional, Tuple
4
+ from pathlib import Path
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import os
7
+ import sys
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./Spark-TTS")))
9
+ from sparktts.utils.file import load_config
10
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
11
+ from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
12
+
13
+
14
+ class SparkTTS:
15
+ """
16
+ Spark-TTS for text-to-speech generation.
17
+ """
18
+
19
+ def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
20
+ """
21
+ Initializes the SparkTTS model with the provided configurations and device.
22
+
23
+ Args:
24
+ model_dir (Path): Directory containing the model and config files.
25
+ device (torch.device): The device (CPU/GPU) to run the model on.
26
+ """
27
+ self.device = device
28
+ self.model_dir = model_dir
29
+ self.configs = load_config(f"{model_dir}/config.yaml")
30
+ self.sample_rate = self.configs["sample_rate"]
31
+ self._initialize_inference()
32
+
33
+ def _initialize_inference(self):
34
+ """Initializes the tokenizer, model, and audio tokenizer for inference."""
35
+ self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
36
+ self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
37
+ self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
38
+ self.model.to(self.device)
39
+
40
+ def process_prompt(
41
+ self,
42
+ text: str,
43
+ prompt_speech_path: Path,
44
+ prompt_text: Optional[str] = None,
45
+ ) -> Tuple[str, torch.Tensor]:
46
+ """
47
+ Process input for voice cloning.
48
+
49
+ Args:
50
+ text (str): The text input to be converted to speech.
51
+ prompt_speech_path (Path): Path to the audio file used as a prompt.
52
+ prompt_text (str, optional): Transcript of the prompt audio.
53
+
54
+ Return:
55
+ Tuple[str, torch.Tensor]: Input prompt; global tokens
56
+ """
57
+
58
+ global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
59
+ prompt_speech_path
60
+ )
61
+ global_tokens = "".join(
62
+ [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
63
+ )
64
+
65
+ # Prepare the input tokens for the model
66
+ if prompt_text is not None:
67
+ semantic_tokens = "".join(
68
+ [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
69
+ )
70
+ inputs = [
71
+ TASK_TOKEN_MAP["tts"],
72
+ "<|start_content|>",
73
+ prompt_text,
74
+ text,
75
+ "<|end_content|>",
76
+ "<|start_global_token|>",
77
+ global_tokens,
78
+ "<|end_global_token|>",
79
+ "<|start_semantic_token|>",
80
+ semantic_tokens,
81
+ ]
82
+ else:
83
+ inputs = [
84
+ TASK_TOKEN_MAP["tts"],
85
+ "<|start_content|>",
86
+ text,
87
+ "<|end_content|>",
88
+ "<|start_global_token|>",
89
+ global_tokens,
90
+ "<|end_global_token|>",
91
+ ]
92
+
93
+ inputs = "".join(inputs)
94
+
95
+ return inputs, global_token_ids
96
+
97
+ def process_prompt_control(
98
+ self,
99
+ gender: str,
100
+ pitch: str,
101
+ speed: str,
102
+ text: str,
103
+ ):
104
+ """
105
+ Process input for voice creation.
106
+
107
+ Args:
108
+ gender (str): female | male.
109
+ pitch (str): very_low | low | moderate | high | very_high
110
+ speed (str): very_low | low | moderate | high | very_high
111
+ text (str): The text input to be converted to speech.
112
+
113
+ Return:
114
+ str: Input prompt
115
+ """
116
+ assert gender in GENDER_MAP.keys()
117
+ assert pitch in LEVELS_MAP.keys()
118
+ assert speed in LEVELS_MAP.keys()
119
+
120
+ gender_id = GENDER_MAP[gender]
121
+ pitch_level_id = LEVELS_MAP[pitch]
122
+ speed_level_id = LEVELS_MAP[speed]
123
+
124
+ pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
125
+ speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
126
+ gender_tokens = f"<|gender_{gender_id}|>"
127
+
128
+ attribte_tokens = "".join(
129
+ [gender_tokens, pitch_label_tokens, speed_label_tokens]
130
+ )
131
+
132
+ control_tts_inputs = [
133
+ TASK_TOKEN_MAP["controllable_tts"],
134
+ "<|start_content|>",
135
+ text,
136
+ "<|end_content|>",
137
+ "<|start_style_label|>",
138
+ attribute_tokens,
139
+ "<|end_style_label|>",
140
+ ]
141
+
142
+ return "".join(control_tts_inputs)
143
+
144
+ @torch.no_grad()
145
+ def inference(
146
+ self,
147
+ text: str,
148
+ prompt_speech_path: Optional[Path] = None,
149
+ prompt_text: Optional[str] = None,
150
+ gender: Optional[str] = None,
151
+ pitch: Optional[str] = None,
152
+ speed: Optional[str] = None,
153
+ temperature: float = 0.8,
154
+ top_k: float = 50,
155
+ top_p: float = 0.95,
156
+ ) -> torch.Tensor:
157
+ """
158
+ Performs inference to generate speech from text, incorporating prompt audio and/or text.
159
+
160
+ Args:
161
+ text (str): The text input to be converted to speech.
162
+ prompt_speech_path (Path): Path to the audio file used as a prompt.
163
+ prompt_text (str, optional): Transcript of the prompt audio.
164
+ gender (str): female | male.
165
+ pitch (str): very_low | low | moderate | high | very_high
166
+ speed (str): very_low | low | moderate | high | very_high
167
+ temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
168
+ top_k (float, optional): Top-k sampling parameter. Default is 50.
169
+ top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
170
+
171
+ Returns:
172
+ torch.Tensor: Generated waveform as a tensor.
173
+ """
174
+ if gender is not None:
175
+ prompt = self.process_prompt_control(gender, pitch, speed, text)
176
+
177
+ else:
178
+ prompt, global_token_ids = self.process_prompt(
179
+ text, prompt_speech_path, prompt_text
180
+ )
181
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
182
+
183
+ # Generate speech using the model
184
+ generated_ids = self.model.generate(
185
+ **model_inputs,
186
+ max_new_tokens=3000,
187
+ do_sample=True,
188
+ top_k=top_k,
189
+ top_p=top_p,
190
+ temperature=temperature,
191
+ )
192
+
193
+ # Trim the output tokens to remove the input tokens
194
+ generated_ids = [
195
+ output_ids[len(input_ids):]
196
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
197
+ ]
198
+
199
+ # Decode the generated tokens into text
200
+ predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
201
+
202
+ # Extract semantic token IDs from the generated text
203
+ pred_semantic_ids = (
204
+ torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
205
+ .long()
206
+ .unsqueeze(0)
207
+ )
208
+
209
+ if gender is not None:
210
+ global_token_ids = (
211
+ torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
212
+ .long()
213
+ .unsqueeze(0)
214
+ .unsqueeze(0)
215
+ )
216
+
217
+ # Convert semantic tokens back to waveform
218
+ wav = self.audio_tokenizer.detokenize(
219
+ global_token_ids.to(self.device).squeeze(0),
220
+ pred_semantic_ids.to(self.device),
221
+ )
222
+
223
+ return wav
weclone-audio/src/__init__.py ADDED
File without changes
weclone-audio/src/get_sample_audio.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from pywxdump.db import MediaHandler
4
+
5
+ def main():
6
+ parser = argparse.ArgumentParser(description="Extract audio from WeChat database")
7
+ parser.add_argument("--db-path", type=str, required=True,
8
+ help="Path to WeChat database file")
9
+ parser.add_argument("--MsgSvrID", type=str, required=True,
10
+ help="Message server ID of the audio")
11
+ parser.add_argument("--save-path", type=str,
12
+ default=os.path.join(os.path.dirname(__file__), "sample.wav"),
13
+ help="Path to save the audio file (default: sample.wav in script directory)")
14
+ parser.add_argument("--rate", type=int, default=24000,
15
+ help="Sample rate for audio conversion (default: 24000)")
16
+
17
+ args = parser.parse_args()
18
+
19
+ config = {
20
+ "key": "test1",
21
+ "type": "sqlite",
22
+ "path": args.db_path,
23
+ }
24
+
25
+ t1 = MediaHandler(config)
26
+ t1.get_audio(
27
+ MsgSvrID=args.MsgSvrID,
28
+ is_play=True,
29
+ is_wave=True,
30
+ save_path=args.save_path,
31
+ rate=args.rate,
32
+ )
33
+
34
+ if __name__ == "__main__":
35
+ main()
weclone-audio/src/infer.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import soundfile as sf
3
+ import torch
4
+
5
+ from SparkTTS import SparkTTS
6
+
7
+ model = SparkTTS("weclone-audio/pretrained_models/Spark-TTS-0.5B", "cuda")
8
+
9
+
10
+ with torch.no_grad():
11
+ wav = model.inference(
12
+ text="晚上好啊,小可爱们,该睡觉了哦",
13
+ prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"),
14
+ prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。",
15
+ )
16
+ sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000)
17
+ print("生成成功!")
weclone-audio/src/sample.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:014954ddd00ec481f993ca65c904b8f3ff426df1be05ca260e2b03b3e892fc1b
3
+ size 412402
weclone-audio/src/server未完工/.env.example ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ API_KEY=your_api_key_here
2
+ PORT=5050
3
+
4
+ DEFAULT_VOICE=en-US-AvaNeural
5
+ DEFAULT_RESPONSE_FORMAT=mp3
6
+ DEFAULT_SPEED=1.0
7
+
8
+ DEFAULT_LANGUAGE=en-US
9
+
10
+ REQUIRE_API_KEY=True
11
+
12
+ REMOVE_FILTER=False
13
+
14
+ EXPAND_API=True
weclone-audio/src/server未完工/handle_text.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import emoji
3
+
4
+ def prepare_tts_input_with_context(text: str) -> str:
5
+ """
6
+ Prepares text for a TTS API by cleaning Markdown and adding minimal contextual hints
7
+ for certain Markdown elements like headers. Preserves paragraph separation.
8
+
9
+ Args:
10
+ text (str): The raw text containing Markdown or other formatting.
11
+
12
+ Returns:
13
+ str: Cleaned text with contextual hints suitable for TTS input.
14
+ """
15
+
16
+ # Remove emojis
17
+ text = emoji.replace_emoji(text, replace='')
18
+
19
+ # Add context for headers
20
+ def header_replacer(match):
21
+ level = len(match.group(1)) # Number of '#' symbols
22
+ header_text = match.group(2).strip()
23
+ if level == 1:
24
+ return f"Title — {header_text}\n"
25
+ elif level == 2:
26
+ return f"Section — {header_text}\n"
27
+ else:
28
+ return f"Subsection — {header_text}\n"
29
+
30
+ text = re.sub(r"^(#{1,6})\s+(.*)", header_replacer, text, flags=re.MULTILINE)
31
+
32
+ # Announce links (currently commented out for potential future use)
33
+ # text = re.sub(r"\[([^\]]+)\]\((https?:\/\/[^\)]+)\)", r"\1 (link: \2)", text)
34
+
35
+ # Remove links while keeping the link text
36
+ text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)
37
+
38
+ # Describe inline code
39
+ text = re.sub(r"`([^`]+)`", r"code snippet: \1", text)
40
+
41
+ # Remove bold/italic symbols but keep the content
42
+ text = re.sub(r"(\*\*|__|\*|_)", '', text)
43
+
44
+ # Remove code blocks (multi-line) with a description
45
+ text = re.sub(r"```([\s\S]+?)```", r"(code block omitted)", text)
46
+
47
+ # Remove image syntax but add alt text if available
48
+ text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", r"Image: \1", text)
49
+
50
+ # Remove HTML tags
51
+ text = re.sub(r"</?[^>]+(>|$)", '', text)
52
+
53
+ # Normalize line breaks
54
+ text = re.sub(r"\n{2,}", '\n\n', text) # Ensure consistent paragraph separation
55
+
56
+ # Replace multiple spaces within lines
57
+ text = re.sub(r" {2,}", ' ', text)
58
+
59
+ # Trim leading and trailing whitespace from the whole text
60
+ text = text.strip()
61
+
62
+ return text
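+ 
+ 
+ # Rough usage sketch (illustrative; exact whitespace in the output may differ):
+ #   prepare_tts_input_with_context("## Intro\nSee `foo()` **now**")
+ #   -> "Section — Intro\n\nSee code snippet: foo() now"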
weclone-audio/src/server未完工/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ flask
2
+ gevent
3
+ python-dotenv
4
+ edge-tts
5
+ emoji
weclone-audio/src/server未完工/server.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # server.py
2
+
3
+ from flask import Flask, request, send_file, jsonify
4
+ from gevent.pywsgi import WSGIServer
5
+ from dotenv import load_dotenv
6
+ import os
7
+
8
+ from handle_text import prepare_tts_input_with_context
9
+ from tts_handler import generate_speech, get_models, get_voices
10
+ from utils import getenv_bool, require_api_key, AUDIO_FORMAT_MIME_TYPES
11
+
12
+ app = Flask(__name__)
13
+ load_dotenv()
14
+
15
+ API_KEY = os.getenv('API_KEY', 'your_api_key_here')
16
+ PORT = int(os.getenv('PORT', 5050))
17
+
18
+ DEFAULT_VOICE = os.getenv('DEFAULT_VOICE', 'en-US-AvaNeural')
19
+ DEFAULT_RESPONSE_FORMAT = os.getenv('DEFAULT_RESPONSE_FORMAT', 'mp3')
20
+ DEFAULT_SPEED = float(os.getenv('DEFAULT_SPEED', 1.0))
21
+
22
+ REMOVE_FILTER = getenv_bool('REMOVE_FILTER', False)
23
+ EXPAND_API = getenv_bool('EXPAND_API', True)
24
+
25
+ # DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'tts-1')
26
+
27
+ @app.route('/v1/audio/speech', methods=['POST'])
28
+ @app.route('/audio/speech', methods=['POST']) # Add this line for the alias
29
+ @require_api_key
30
+ def text_to_speech():
31
+ data = request.json
32
+ if not data or 'input' not in data:
33
+ return jsonify({"error": "Missing 'input' in request body"}), 400
34
+
35
+ text = data.get('input')
36
+
37
+ if not REMOVE_FILTER:
38
+ text = prepare_tts_input_with_context(text)
39
+
40
+ # model = data.get('model', DEFAULT_MODEL)
41
+ voice = data.get('voice', DEFAULT_VOICE)
42
+
43
+ response_format = data.get('response_format', DEFAULT_RESPONSE_FORMAT)
44
+ speed = float(data.get('speed', DEFAULT_SPEED))
45
+
46
+ mime_type = AUDIO_FORMAT_MIME_TYPES.get(response_format, "audio/mpeg")
47
+
48
+ # Generate the audio file in the specified format with speed adjustment
49
+ output_file_path = generate_speech(text, voice, response_format, speed)
50
+
51
+ # Return the file with the correct MIME type
52
+ return send_file(output_file_path, mimetype=mime_type, as_attachment=True, download_name=f"speech.{response_format}")
53
+
54
+ @app.route('/v1/models', methods=['GET', 'POST'])
55
+ @app.route('/models', methods=['GET', 'POST'])
56
+ @require_api_key
57
+ def list_models():
58
+ return jsonify({"data": get_models()})
59
+
60
+ @app.route('/v1/voices', methods=['GET', 'POST'])
61
+ @app.route('/voices', methods=['GET', 'POST'])
62
+ @require_api_key
63
+ def list_voices():
64
+ specific_language = None
65
+
66
+ data = request.args if request.method == 'GET' else request.json
67
+ if data and ('language' in data or 'locale' in data):
68
+ specific_language = data.get('language') if 'language' in data else data.get('locale')
69
+
70
+ return jsonify({"voices": get_voices(specific_language)})
71
+
72
+ @app.route('/v1/voices/all', methods=['GET', 'POST'])
73
+ @app.route('/voices/all', methods=['GET', 'POST'])
74
+ @require_api_key
75
+ def list_all_voices():
76
+ return jsonify({"voices": get_voices('all')})
77
+
78
+ """
79
+ Support for ElevenLabs and Azure AI Speech
80
+ (currently in beta)
81
+ """
82
+
83
+ # http://localhost:5050/elevenlabs/v1/text-to-speech
84
+ # http://localhost:5050/elevenlabs/v1/text-to-speech/en-US-AndrewNeural
85
+ @app.route('/elevenlabs/v1/text-to-speech/<voice_id>', methods=['POST'])
86
+ @require_api_key
87
+ def elevenlabs_tts(voice_id):
88
+ if not EXPAND_API:
89
+ return jsonify({"error": f"Endpoint not allowed"}), 500
90
+
91
+ # Parse the incoming JSON payload
92
+ try:
93
+ payload = request.json
94
+ if not payload or 'text' not in payload:
95
+ return jsonify({"error": "Missing 'text' in request body"}), 400
96
+ except Exception as e:
97
+ return jsonify({"error": f"Invalid JSON payload: {str(e)}"}), 400
98
+
99
+ text = payload['text']
100
+
101
+ if not REMOVE_FILTER:
102
+ text = prepare_tts_input_with_context(text)
103
+
104
+ voice = voice_id # ElevenLabs uses the voice_id in the URL
105
+
106
+ # Use default settings for edge-tts
107
+ response_format = 'mp3'
108
+ speed = DEFAULT_SPEED # Optional customization via payload.get('speed', DEFAULT_SPEED)
109
+
110
+ # Generate speech using edge-tts
111
+ try:
112
+ output_file_path = generate_speech(text, voice, response_format, speed)
113
+ except Exception as e:
114
+ return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500
115
+
116
+ # Return the generated audio file
117
+ return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3")
118
+
119
+ # tts.speech.microsoft.com/cognitiveservices/v1
120
+ # https://{region}.tts.speech.microsoft.com/cognitiveservices/v1
121
+ # http://localhost:5050/azure/cognitiveservices/v1
122
+ @app.route('/azure/cognitiveservices/v1', methods=['POST'])
123
+ @require_api_key
124
+ def azure_tts():
125
+ if not EXPAND_API:
126
+ return jsonify({"error": f"Endpoint not allowed"}), 500
127
+
128
+ # Parse the SSML payload
129
+ try:
130
+ ssml_data = request.data.decode('utf-8')
131
+ if not ssml_data:
132
+ return jsonify({"error": "Missing SSML payload"}), 400
133
+
134
+ # Extract the text and voice from SSML
135
+ from xml.etree import ElementTree as ET
136
+ root = ET.fromstring(ssml_data)
137
+ text = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').text
138
+ voice = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').get('name')
139
+ except Exception as e:
140
+ return jsonify({"error": f"Invalid SSML payload: {str(e)}"}), 400
141
+
142
+ # Use default settings for edge-tts
143
+ response_format = 'mp3'
144
+ speed = DEFAULT_SPEED
145
+
146
+ if not REMOVE_FILTER:
147
+ text = prepare_tts_input_with_context(text)
148
+
149
+ # Generate speech using edge-tts
150
+ try:
151
+ output_file_path = generate_speech(text, voice, response_format, speed)
152
+ except Exception as e:
153
+ return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500
154
+
155
+ # Return the generated audio file
156
+ return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3")
157
+
158
+ print(f" Edge TTS (Free Azure TTS) Replacement for OpenAI's TTS API")
159
+ print(f" ")
160
+ print(f" * Serving OpenAI Edge TTS")
161
+ print(f" * Server running on http://localhost:{PORT}")
162
+ print(f" * TTS Endpoint: http://localhost:{PORT}/v1/audio/speech")
163
+ print(f" ")
164
+
165
+ if __name__ == '__main__':
166
+ http_server = WSGIServer(('0.0.0.0', PORT), app)
167
+ http_server.serve_forever()
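For reference, a client call against this server might look like the sketch below. It is a minimal example, not part of the commit: the payload field names follow the OpenAI TTS API this server mimics (the `/v1/audio/speech` handler itself sits earlier in server.py), and the host, port, and key are placeholder values from the defaults above.

import requests

resp = requests.post(
    "http://localhost:5050/v1/audio/speech",          # default PORT assumed to be 5050
    headers={"Authorization": "Bearer your_api_key_here"},
    json={"input": "Hello from WeClone", "voice": "alloy", "response_format": "mp3"},
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)  # the route returns the audio file as an attachment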
weclone-audio/src/server未完工/tts_handler.py ADDED
@@ -0,0 +1,133 @@
+import edge_tts
+import asyncio
+import tempfile
+import subprocess
+import os
+from pathlib import Path
+
+# Language default (environment variable)
+DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'en-US')
+
+# OpenAI voice names mapped to edge-tts equivalents
+voice_mapping = {
+    'alloy': 'en-US-AvaNeural',
+    'echo': 'en-US-AndrewNeural',
+    'fable': 'en-GB-SoniaNeural',
+    'onyx': 'en-US-EricNeural',
+    'nova': 'en-US-SteffanNeural',
+    'shimmer': 'en-US-EmmaNeural'
+}
+
+def is_ffmpeg_installed():
+    """Check if FFmpeg is installed and accessible."""
+    try:
+        subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        return True
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        return False
+
+async def _generate_audio(text, voice, response_format, speed):
+    """Generate TTS audio and optionally convert to a different format."""
+    # Determine if the voice is an OpenAI-compatible voice or a direct edge-tts voice
+    edge_tts_voice = voice_mapping.get(voice, voice)  # Use mapping if in OpenAI names, otherwise use as-is
+
+    # Generate the TTS output in mp3 format first
+    temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
+
+    # Convert speed to SSML rate format
+    try:
+        speed_rate = speed_to_rate(speed)  # Convert speed value to "+X%" or "-X%"
+    except Exception as e:
+        print(f"Error converting speed: {e}. Defaulting to +0%.")
+        speed_rate = "+0%"
+
+    # Generate the MP3 file
+    communicator = edge_tts.Communicate(text=text, voice=edge_tts_voice, rate=speed_rate)
+    await communicator.save(temp_output_file.name)
+
+    # If the requested format is mp3, return the generated file directly
+    if response_format == "mp3":
+        return temp_output_file.name
+
+    # Check if FFmpeg is installed
+    if not is_ffmpeg_installed():
+        print("FFmpeg is not available. Returning unmodified mp3 file.")
+        return temp_output_file.name
+
+    # Create a new temporary file for the converted output
+    converted_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}")
+
+    # Build the FFmpeg command
+    ffmpeg_command = [
+        "ffmpeg",
+        "-i", temp_output_file.name,  # Input file
+        "-c:a", {
+            "aac": "aac",
+            "mp3": "libmp3lame",
+            "wav": "pcm_s16le",
+            "opus": "libopus",
+            "flac": "flac"
+        }.get(response_format, "aac"),  # Default to AAC if unknown
+    ]
+    if response_format != "wav":
+        # Bitrate applies only to lossy formats; the original inline conditional
+        # could place a bare None in the argument list and crash subprocess.run.
+        ffmpeg_command += ["-b:a", "192k"]
+    ffmpeg_command += [
+        "-f", {
+            "aac": "mp4",  # AAC in MP4 container
+            "mp3": "mp3",
+            "wav": "wav",
+            "opus": "ogg",
+            "flac": "flac"
+        }.get(response_format, response_format),  # Default to matching format
+        "-y",  # Overwrite without prompt
+        converted_output_file.name  # Output file
+    ]
+
+    try:
+        # Run FFmpeg command and ensure no errors occur
+        subprocess.run(ffmpeg_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"FFmpeg error during audio conversion: {e}")
+
+    # Clean up the original temporary file
+    Path(temp_output_file.name).unlink(missing_ok=True)
+
+    return converted_output_file.name
+
+def generate_speech(text, voice, response_format, speed=1.0):
+    return asyncio.run(_generate_audio(text, voice, response_format, speed))
+
+def get_models():
+    return [
+        {"id": "tts-1", "name": "Text-to-speech v1"},
+        {"id": "tts-1-hd", "name": "Text-to-speech v1 HD"}
+    ]
+
+async def _get_voices(language=None):
+    # List all voices, filter by language if specified
+    all_voices = await edge_tts.list_voices()
+    language = language or DEFAULT_LANGUAGE  # Use default if no language specified
+    filtered_voices = [
+        {"name": v['ShortName'], "gender": v['Gender'], "language": v['Locale']}
+        for v in all_voices if language == 'all' or language is None or v['Locale'] == language
+    ]
+    return filtered_voices
+
+def get_voices(language=None):
+    return asyncio.run(_get_voices(language))
+
+def speed_to_rate(speed: float) -> str:
+    """
+    Converts a multiplicative speed value to the edge-tts "rate" format.
+
+    Args:
+        speed (float): The multiplicative speed value (e.g., 1.5 for +50%, 0.5 for -50%).
+
+    Returns:
+        str: The formatted "rate" string (e.g., "+50%" or "-50%").
+    """
+    if speed < 0 or speed > 2:
+        raise ValueError("Speed must be between 0 and 2 (inclusive).")
+
+    # Convert speed to percentage change
+    percentage_change = (speed - 1) * 100
+
+    # Format with a leading "+" or "-" as required
+    return f"{percentage_change:+.0f}%"
weclone-audio/src/server未完工/utils.py ADDED
@@ -0,0 +1,38 @@
+# utils.py
+
+from flask import request, jsonify
+from functools import wraps
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+def getenv_bool(name: str, default: bool = False) -> bool:
+    return os.getenv(name, str(default)).lower() in ("yes", "y", "true", "1", "t")
+
+API_KEY = os.getenv('API_KEY', 'your_api_key_here')
+REQUIRE_API_KEY = getenv_bool('REQUIRE_API_KEY', True)
+
+def require_api_key(f):
+    @wraps(f)
+    def decorated_function(*args, **kwargs):
+        if not REQUIRE_API_KEY:
+            return f(*args, **kwargs)
+        auth_header = request.headers.get('Authorization')
+        if not auth_header or not auth_header.startswith('Bearer '):
+            return jsonify({"error": "Missing or invalid API key"}), 401
+        token = auth_header.split('Bearer ')[1]
+        if token != API_KEY:
+            return jsonify({"error": "Invalid API key"}), 401
+        return f(*args, **kwargs)
+    return decorated_function
+
+# Mapping of audio format to MIME type
+AUDIO_FORMAT_MIME_TYPES = {
+    "mp3": "audio/mpeg",
+    "opus": "audio/ogg",
+    "aac": "audio/aac",
+    "flac": "audio/flac",
+    "wav": "audio/wav",
+    "pcm": "audio/L16"
+}
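The decorator is self-contained enough to reuse outside server.py. A minimal sketch, assuming this utils module is importable from the working directory; the route name is made up:

from flask import Flask
from utils import require_api_key  # i.e. the module above

app = Flask(__name__)

@app.route("/ping")
@require_api_key  # returns 401 unless "Authorization: Bearer <API_KEY>" is sent
def ping():
    return {"ok": True}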
weclone/__init__.py ADDED
File without changes
weclone/cli.py ADDED
@@ -0,0 +1,207 @@
+import click
+import commentjson
+from pathlib import Path
+import os
+import sys
+import functools
+
+from weclone.utils.log import logger, capture_output
+from weclone.utils.config import load_config
+
+cli_config: dict | None = None
+
+try:
+    import tomllib  # type: ignore  # Python 3.11+
+except ImportError:
+    import tomli as tomllib
+
+
+def clear_argv(func):
+    """
+    Decorator: trims sys.argv down to the script name before calling the wrapped
+    function and restores the original sys.argv afterwards. This prevents CLI
+    arguments from being parsed by Hugging Face's HfArgumentParser and raising
+    a ValueError.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        original_argv = sys.argv.copy()
+        sys.argv = [original_argv[0]]  # Keep only the script name
+        try:
+            return func(*args, **kwargs)
+        finally:
+            sys.argv = original_argv  # Restore the original sys.argv
+
+    return wrapper
+
+
+def apply_common_decorators(capture_output_enabled=False):
+    """
+    A unified decorator for CLI commands: captures output when full_log is
+    enabled in the config, and always applies clear_argv.
+    """
+
+    def decorator(original_cmd_func):
+        @functools.wraps(original_cmd_func)
+        def new_runtime_wrapper(*args, **kwargs):
+            if cli_config and cli_config.get("full_log", False):
+                return capture_output(original_cmd_func)(*args, **kwargs)
+            else:
+                return original_cmd_func(*args, **kwargs)
+
+        func_with_clear_argv = clear_argv(new_runtime_wrapper)
+
+        return functools.wraps(original_cmd_func)(func_with_clear_argv)
+
+    return decorator
+
+
+@click.group()
+def cli():
+    """WeClone: 从聊天记录创造数字分身的一站式解决方案"""
+    _check_project_root()
+    _check_versions()
+    global cli_config
+    cli_config = load_config(arg_type="cli_args")
+
+
+@cli.command("make-dataset", help="处理聊天记录CSV文件,生成问答对数据集。")
+@apply_common_decorators()
+def qa_generator():
+    """处理聊天记录CSV文件,生成问答对数据集。"""
+    from weclone.data.qa_generator import DataProcessor
+
+    processor = DataProcessor()
+    processor.main()
+
+
+@cli.command("train-sft", help="使用准备好的数据集对模型进行微调。")
+@apply_common_decorators()
+def train_sft():
+    """使用准备好的数据集对模型进行微调。"""
+    from weclone.train.train_sft import main as train_sft_main
+
+    train_sft_main()
+
+
+@cli.command("webchat-demo", help="启动 Web UI 与微调后的模型进行交互测试。")  # Command name changed to webchat-demo
+@apply_common_decorators()
+def web_demo():
+    """启动 Web UI 与微调后的模型进行交互测试。"""
+    from weclone.eval.web_demo import main as web_demo_main
+
+    web_demo_main()
+
+
+# TODO: add evaluation support: @cli.command("eval-model", help="使用从训练数据中划分出来的验证集评估。")
+@apply_common_decorators()
+def eval_model():
+    """使用从训练数据中划分出来的验证集评估。"""
+    from weclone.eval.eval_model import main as evaluate_main
+
+    evaluate_main()
+
+
+@cli.command("test-model", help="使用常见聊天问题测试模型。")
+@apply_common_decorators()
+def test_model():
+    """测试"""
+    from weclone.eval.test_model import main as test_main
+
+    test_main()
+
+
+@cli.command("server", help="启动API服务,提供模型推理接口。")
+@apply_common_decorators()
+def server():
+    """启动API服务,提供模型推理接口。"""
+    from weclone.server.api_service import main as server_main
+
+    server_main()
+
+
+def _check_project_root():
+    """Check that the current directory is the project root and verify the project name."""
+    project_root_marker = "pyproject.toml"
+    current_dir = Path(os.getcwd())
+    pyproject_path = current_dir / project_root_marker
+
+    if not pyproject_path.is_file():
+        logger.error(f"未在当前目录找到 {project_root_marker} 文件。")
+        logger.error("请确保在WeClone项目根目录下运行此命令。")
+        sys.exit(1)
+
+    try:
+        with open(pyproject_path, "rb") as f:
+            pyproject_data = tomllib.load(f)
+        project_name = pyproject_data.get("project", {}).get("name")
+        if project_name != "WeClone":
+            logger.error("请确保在正确的 WeClone 项目根目录下运行。")
+            sys.exit(1)
+    except tomllib.TOMLDecodeError as e:
+        logger.error(f"错误:无法解析 {pyproject_path} 文件: {e}")
+        sys.exit(1)
+    except Exception as e:
+        logger.error(f"读取或处理 {pyproject_path} 时发生意外错误: {e}")
+        sys.exit(1)
+
+
+def _check_versions():
+    """Compare the local settings.jsonc version against the config version declared in pyproject.toml."""
+    if tomllib is None:  # Skip check if toml parser failed to import
+        return
+
+    ROOT_DIR = Path(__file__).parent.parent
+    SETTINGS_PATH = ROOT_DIR / "settings.jsonc"
+    PYPROJECT_PATH = ROOT_DIR / "pyproject.toml"
+
+    settings_version = None
+    config_guide_version = None
+    config_changelog = None
+
+    if SETTINGS_PATH.exists():
+        try:
+            with open(SETTINGS_PATH, "r", encoding="utf-8") as f:
+                settings_data = commentjson.load(f)
+            settings_version = settings_data.get("version")
+        except Exception as e:
+            logger.error(f"错误:无法读取或解析 {SETTINGS_PATH}: {e}")
+            logger.error("请确保 settings.jsonc 文件存在且格式正确。")
+            sys.exit(1)
+    else:
+        logger.error(f"错误:未找到配置文件 {SETTINGS_PATH}。")
+        logger.error("请确保 settings.jsonc 文件位于项目根目录。")
+        sys.exit(1)
+
+    if PYPROJECT_PATH.exists():
+        try:
+            with open(PYPROJECT_PATH, "rb") as f:  # tomllib requires binary mode
+                pyproject_data = tomllib.load(f)
+            weclone_tool_data = pyproject_data.get("tool", {}).get("weclone", {})
+            config_guide_version = weclone_tool_data.get("config_version")
+            config_changelog = weclone_tool_data.get("config_changelog", "N/A")
+        except Exception as e:
+            logger.warning(f"警告:无法读取或解析 {PYPROJECT_PATH}: {e}。无法检查配置文件是否为最新。")
+    else:
+        logger.warning(f"警告:未找到文件 {PYPROJECT_PATH}。无法检查配置文件是否为最新。")
+
+    if not settings_version:
+        logger.error(f"错误:在 {SETTINGS_PATH} 中未找到 'version' 字段。")
+        logger.error("请从 settings.template.json 复制或更新您的 settings.jsonc 文件。")
+        sys.exit(1)
+
+    if config_guide_version:
+        if settings_version != config_guide_version:
+            logger.warning(
+                f"警告:您的 settings.jsonc 文件版本 ({settings_version}) 与项目建议的配置版本 ({config_guide_version}) 不一致。"
+            )
+            logger.warning("这可能导致意外行为或错误。请从 settings.template.json 复制或更新您的 settings.jsonc 文件。")
+            # TODO: print only the changelog entries between the two versions
+            logger.warning(f"配置文件更新日志:\n{config_changelog}")
+    elif PYPROJECT_PATH.exists():  # The file exists but no version could be read
+        logger.warning(
+            f"警告:在 {PYPROJECT_PATH} 的 [tool.weclone] 下未找到 'config_version' 字段。"
+            "无法确认您的 settings.jsonc 是否为最新配置版本。"
+        )
+
+
+if __name__ == "__main__":
+    cli()
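The clear_argv behavior can be checked without running any command. A minimal sketch, run from the project root so the module imports resolve; the argument values are made up:

import sys
from weclone.cli import clear_argv

@clear_argv
def show_argv():
    print(sys.argv)  # inside the wrapper only the script name survives

sys.argv = ["prog", "--foo", "bar"]
show_argv()      # prints ['prog'], so HfArgumentParser sees no stray flags
print(sys.argv)  # restored afterwards: ['prog', '--foo', 'bar']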
weclone/core/inference/offline_infer.py ADDED
@@ -0,0 +1,120 @@
+import json
+from typing import List, Optional, Union
+
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
+from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.misc import get_device_count
+from llamafactory.extras.packages import is_vllm_available
+from llamafactory.hparams import get_infer_args
+from llamafactory.model import load_tokenizer
+from pydantic import BaseModel
+from vllm.sampling_params import GuidedDecodingParams
+
+from vllm import LLM, SamplingParams
+from vllm.lora.request import LoRARequest
+
+
+# This does not need much polish: future transformers releases will ship vLLM support natively.
+
+
+def vllm_infer(
+    inputs: Union[str, List[str]],
+    model_name_or_path: str,
+    adapter_name_or_path: Optional[str] = None,
+    dataset: str = "alpaca_en_demo",
+    dataset_dir: str = "data",
+    template: str = "default",
+    cutoff_len: int = 2048,
+    max_samples: Optional[int] = None,
+    vllm_config: str = "{}",
+    save_name: str = "generated_predictions.jsonl",
+    temperature: float = 0.95,
+    top_p: float = 0.7,
+    top_k: int = 50,
+    guided_decoding_class: Optional[type[BaseModel]] = None,
+    bad_words: Optional[List[str]] = None,
+    logprobs: Optional[int] = None,
+    max_new_tokens: int = 1024,
+    repetition_penalty: float = 1.0,
+    skip_special_tokens: bool = True,
+    seed: Optional[int] = None,
+    pipeline_parallel_size: int = 1,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
+):
+    r"""Perform batch generation using the vLLM engine, which supports tensor parallelism."""
+    if pipeline_parallel_size > get_device_count():
+        raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
+
+    model_args, data_args, _, generating_args = get_infer_args(
+        dict(
+            model_name_or_path=model_name_or_path,
+            adapter_name_or_path=adapter_name_or_path,
+            dataset=dataset,
+            dataset_dir=dataset_dir,
+            template=template,
+            cutoff_len=cutoff_len,
+            max_samples=max_samples,
+            preprocessing_num_workers=16,
+            vllm_config=vllm_config,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+        )
+    )
+
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
+    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
+
+    if guided_decoding_class:
+        json_schema = guided_decoding_class.model_json_schema()
+        guided_decoding_params = GuidedDecodingParams(json=json_schema)
+    else:
+        guided_decoding_params = None
+
+    sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must be > 0
+        temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must be > 0
+        top_k=generating_args.top_k or -1,  # vLLM uses -1 to disable top-k
+        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=skip_special_tokens,
+        seed=seed,
+        guided_decoding=guided_decoding_params,
+        bad_words=bad_words,
+    )
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+    else:
+        lora_request = None
+
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
+        # "tensor_parallel_size": 1,
+        # "pipeline_parallel_size": pipeline_parallel_size,
+        # "data_parallel_size": get_device_count(),  # vLLM 0.8.5 supports data parallelism
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+        "enable_prefix_caching": True,  # Enable prefix caching
+        "gpu_memory_utilization": 0.95,
+        # "quantization": "bitsandbytes",  # Enable vLLM's bitsandbytes quantized loading
+        # "load_format": "bitsandbytes",
+    }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)
+
+    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
+    return results
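To illustrate guided decoding, a minimal sketch of how this function is meant to be called (this is how strategies.py below uses it). The model path and template are placeholders, not values fixed by this repo, and running it requires a GPU with vLLM installed:

from pydantic import BaseModel
from weclone.core.inference.offline_infer import vllm_infer

class Judgement(BaseModel):  # hypothetical schema; constrains each completion to this JSON shape
    id: int
    score: int

results = vllm_infer(
    ["Rate answer A ...", "Rate answer B ..."],
    model_name_or_path="Qwen/Qwen2.5-7B-Instruct",  # placeholder checkpoint
    template="qwen",                                # placeholder template name
    temperature=0,
    guided_decoding_class=Judgement,
)
for r in results:
    print(r.outputs[0].text)  # guaranteed to parse as Judgement JSON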
weclone/core/inference/online_infer.py ADDED
@@ -0,0 +1,40 @@
+import json
+import time
+import requests
+from openai import OpenAI
+
+
+class OnlineLLM:
+    def __init__(self, api_key: str, base_url: str, model_name: str, default_system: str):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.model_name = model_name
+        self.default_system = default_system
+        self.client = OpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url
+        )
+
+    def chat(self, prompt_text,
+             temperature: float = 0.7,
+             max_tokens: int = 1024,
+             top_p: float = 0.95,
+             stream: bool = False,
+             enable_thinking: bool = False):
+        messages = [
+            {"role": "system", "content": self.default_system},
+            {"role": "user", "content": prompt_text},
+        ]
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            stream=stream,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            # enable_thinking=enable_thinking  # would toggle Qwen3's dynamic reasoning mode
+        )
+
+        return response
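A minimal usage sketch; the key, endpoint, and model name below are placeholders to be replaced with your own provider's values (any OpenAI-compatible API works):

from weclone.core.inference.online_infer import OnlineLLM

llm = OnlineLLM(
    api_key="sk-...",                      # placeholder key
    base_url="https://api.openai.com/v1",  # placeholder endpoint
    model_name="gpt-4o-mini",              # placeholder model
    default_system="You are a helpful assistant.",
)
response = llm.chat("用一句话介绍WeClone")
print(response.choices[0].message.content)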
weclone/data/__init__.py ADDED
File without changes
weclone/data/chat_parsers/wechat_parser.py ADDED
@@ -0,0 +1,8 @@
+class WeChatParser:
+    def decrypt_wechat_image(self, encrypted_path, output_path):
+        """Decrypt a WeChat-encrypted image file."""
+        pass
+
+    def parse_chat_records(self, db_path):
+        """Parse the chat-record database."""
+        pass
weclone/data/clean/__init__.py ADDED
File without changes
weclone/data/clean/get_score.py ADDED
@@ -0,0 +1,64 @@
+import math
+
+
+# TODO: currently unused
+def adjust_score_tiered(
+    initial_score: int, probabilities: list[float], thresholds: list[float], downgrade_levels: list[int]
+) -> int:
+    """
+    Apply a tiered confidence adjustment to an LLM-assigned score, based on the
+    probability the model attached to that score.
+
+    Args:
+        initial_score: The raw score given by the LLM (an integer from 1 to 5).
+        probabilities: The probabilities of the 5 scores (1 to 5),
+            e.g. [P(1), P(2), P(3), P(4), P(5)].
+        thresholds: A descending list of probability thresholds delimiting the
+            confidence bands, e.g. [0.6, 0.3].
+        downgrade_levels: The downgrade amount per confidence band; one element
+            longer than thresholds, e.g. [0, 1, 2].
+
+    Returns:
+        The confidence-adjusted final score (an integer from 1 to 5).
+
+    Raises:
+        ValueError: If the inputs are invalid (e.g. wrong probability list
+            length, thresholds not in descending order).
+    """
+    # --- Input validation ---
+    if not (1 <= initial_score <= 5):
+        raise ValueError("initial_score must be between 1 and 5.")
+    if len(probabilities) != 5:
+        raise ValueError("probabilities must contain exactly 5 elements.")
+    # Check that the probabilities sum to roughly 1 (allow small float error)
+    if not math.isclose(sum(probabilities), 1.0, abs_tol=1e-6):
+        print(f"Warning: the probabilities sum to {sum(probabilities)}, not 1.0. Check their source.")  # warn instead of raising
+        # raise ValueError("The elements of probabilities must sum to approximately 1.0.")
+    if len(downgrade_levels) != len(thresholds) + 1:
+        raise ValueError("downgrade_levels must be exactly one element longer than thresholds.")
+    if any(thresholds[i] < thresholds[i + 1] for i in range(len(thresholds) - 1)):
+        raise ValueError("thresholds must be in descending order.")
+    if any(level < 0 for level in downgrade_levels):
+        raise ValueError("downgrade_levels must not contain negative values.")
+
+    # --- Core algorithm ---
+    # 1. Look up the probability of the chosen score
+    # List indices start at 0, so score s maps to index s-1
+    try:
+        p_chosen = probabilities[initial_score - 1]
+    except IndexError:
+        # Should be unreachable: initial_score was already validated to be in 1-5
+        raise ValueError(f"Cannot read index {initial_score - 1} from the probabilities list.")
+
+    # 2. Determine the downgrade amount
+    downgrade = downgrade_levels[-1]  # Default to the lowest confidence band
+    # Walk the thresholds from high to low
+    for i in range(len(thresholds)):
+        if p_chosen >= thresholds[i]:
+            downgrade = downgrade_levels[i]  # Found the matching confidence band
+            break
+
+    # 3. Compute the adjusted score
+    preliminary_score = initial_score - downgrade
+    adjusted_score = max(1, preliminary_score)  # Never drop below 1
+
+    # 4. Return the result
+    return adjusted_score
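A worked example of the tiered adjustment, using made-up probabilities: the model outputs 4 with probability 0.5, which falls below the 0.6 band but at or above the 0.3 band, so the score is downgraded by one level.

score = adjust_score_tiered(
    initial_score=4,
    probabilities=[0.05, 0.05, 0.2, 0.5, 0.2],  # sums to 1.0
    thresholds=[0.6, 0.3],
    downgrade_levels=[0, 1, 2],
)
assert score == 3  # 0.3 <= P(4)=0.5 < 0.6  ->  downgrade by 1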
weclone/data/clean/strategies.py ADDED
@@ -0,0 +1,144 @@
+import json
+import pandas as pd
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Union
+from langchain_core.prompts import PromptTemplate
+from weclone.data.models import QaPair, CutMessage, QaPairScore
+from weclone.prompts.clean_data import CLEAN_PROMPT
+import os
+from weclone.utils.log import logger
+
+
+@dataclass
+class CleaningStrategy(ABC):
+    """Abstract base class for data-cleaning strategies"""
+
+    make_dataset_config: Dict
+
+    @abstractmethod
+    def clean(self, data: Any) -> Any:
+        """
+        Run the cleaning operation.
+
+        Args:
+            data: The data to clean.
+
+        Returns:
+            The cleaned data.
+        """
+        pass
+
+
+@dataclass
+class LLMCleaningStrategy(CleaningStrategy):
+    """Cleaning strategy that scores the data with a local LLM"""
+
+    def judge(self, data: List[QaPair]) -> None:
+        """
+        Score the data with the LLM and write the scores directly onto the passed-in QaPair objects.
+        """
+        from weclone.core.inference.offline_infer import vllm_infer
+
+        logger.info("开始使用llm对数据打分")
+        inputs = []
+        prompt_template = PromptTemplate.from_template(CLEAN_PROMPT)
+        for qa in data:
+            inputs.append(prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text)  # type: ignore
+        outputs = vllm_infer(
+            inputs,
+            self.make_dataset_config["model_name_or_path"],
+            template=self.make_dataset_config["template"],
+            temperature=0,
+            guided_decoding_class=QaPairScore,
+            repetition_penalty=1.2,
+            bad_words=[r"\n"],
+        )
+
+        parsed_scores: List[QaPairScore] = []
+        for result in outputs:
+            try:
+                score_data = json.loads(result.outputs[0].text)
+                qa_score = QaPairScore(**score_data)
+                parsed_scores.append(qa_score)
+            except json.JSONDecodeError:
+                logger.error(f"Error decoding JSON: {result.outputs[0].text}")
+
+        score_map = {score.id: score.score for score in parsed_scores}
+        for qa in data:
+            if qa.id in score_map:
+                qa.score = score_map[qa.id]
+            else:
+                logger.warning(f"Warning: Score not found for QaPair with id {qa.id}. Assigning default score.")
+
+        scores = [qa.score for qa in data if qa.score is not None]
+        score_series = pd.Series(scores)
+        score_counts = score_series.value_counts().sort_index()
+        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
+        pd.set_option("display.unicode.east_asian_width", True)  # Attempt to fix column alignment for CJK text
+        distribution_df = pd.DataFrame(  # Combine counts and percentages into one DataFrame for printing
+            {
+                "数量": score_counts,
+                "占比(%)": score_percentages.round(2),
+            }
+        )
+        distribution_df.index.name = "分数"  # Label the index column
+        printable_df_str = distribution_df.reset_index().to_string(index=False)
+        logger.success(f"llm打分分数分布情况:\n{printable_df_str}")
+
+    def clean(self) -> str:
+        """
+        Clean the SFT data and return the path of the cleaned file.
+        If cleaning is disabled, return the original path.
+        """
+        config = self.make_dataset_config
+        dataset_dir = config["dataset_dir"]
+        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")
+
+        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
+        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")
+        accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1)
+
+        if not config.get("clean_dataset", {}).get("enable_clean"):
+            logger.info("未启用清洗功能")
+            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
+            return sft_json_path
+
+        try:
+            with open(sft_json_path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+            filtered_data = [item for item in data if item.get("score", 0) >= accept_score]
+
+            with open(output_json_path, 'w', encoding='utf-8') as f:
+                json.dump(filtered_data, f, ensure_ascii=False, indent=4)
+
+            logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据")
+            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
+            return output_json_path
+
+        except Exception as e:
+            logger.error(f"清洗数据失败,使用原始数据: {str(e)}")
+            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
+            return sft_json_path
+
+    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str):
+        """
+        Update the file_name field inside dataset_info.json.
+        """
+        try:
+            with open(dataset_info_path, "r", encoding="utf-8") as f:
+                dataset_info = json.load(f)
+
+            # Update file_name for every supported dataset key
+            for key in ["wechat-sft", "wechat-sft-with-history"]:
+                if key in dataset_info:
+                    dataset_info[key]["file_name"] = new_file_name
+
+            # Write the file back
+            with open(dataset_info_path, "w", encoding="utf-8") as f:
+                json.dump(dataset_info, f, indent=4, ensure_ascii=False)
+
+            logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}")
+
+        except Exception as e:
+            logger.warning(f"无法更新 dataset_info.json: {e}")
weclone/data/clean/strategies_online.py ADDED
@@ -0,0 +1,155 @@
+import re
+import json
+import pandas as pd
+from tqdm import tqdm
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List
+from langchain_core.prompts import PromptTemplate
+from weclone.data.models import QaPair, QaPairScore
+from weclone.prompts.clean_data import CLEAN_PROMPT, ONLINE_LLM_CLEAN_PROMPT
+from weclone.core.inference.online_infer import OnlineLLM
+from weclone.utils.log import logger
+import os
+
+
+@dataclass
+class CleaningStrategy(ABC):
+    """Abstract base class for data-cleaning strategies"""
+
+    make_dataset_config: Dict
+
+    @abstractmethod
+    def clean(self, data: Any) -> Any:
+        pass
+
+
+@dataclass
+class OlineLLMCleaningStrategy(CleaningStrategy):
+    """Cleaning strategy that scores the data with an online LLM"""
+
+    def judge(self, data: List[QaPair]) -> None:
+        logger.info("开始使用在线模型对数据打分")
+
+        logger.info(f"使用模型 {self.make_dataset_config.get('model_name', '')}")
+
+        client = OnlineLLM(
+            api_key=self.make_dataset_config.get("llm_api_key"),
+            base_url=self.make_dataset_config.get("base_url"),
+            model_name=self.make_dataset_config.get("model_name"),
+            default_system=self.make_dataset_config.get("default_system")
+        )
+        prompt_template = PromptTemplate.from_template(ONLINE_LLM_CLEAN_PROMPT)
+
+        parsed_scores = []
+        clean_batch_size = int(self.make_dataset_config.get("clean_batch_size", 10))
+        for i in tqdm(range(0, len(data), clean_batch_size), desc="在线模型评分进度"):
+            batch = data[i : i + clean_batch_size]
+            # Build the qa_list for the current batch
+            qa_list = [
+                {"id": qa.id, "Q": qa.instruction, "A": qa.output}
+                for qa in batch
+            ]
+            qa_list_json = json.dumps(qa_list, ensure_ascii=False)
+            # Fill in the prompt template
+            prompt_text = prompt_template.invoke({
+                "qa_list": qa_list_json
+            }).text
+            try:
+                response = client.chat(prompt_text)
+                result_text = response.choices[0].message.content
+                # print("LLM returned:", result_text)
+                # If the reply contains <think> … </think>, keep only what follows </think>
+                if "</think>" in result_text:
+                    result_text = result_text.split("</think>", 1)[1]
+                # Strip leading/trailing code-fence markers such as ```json or ```
+                result_text = re.sub(r"^```json\s*|```$", "", result_text.strip(), flags=re.MULTILINE)
+                # Skip the batch on the occasional parse failure
+                try:
+                    score_list = json.loads(result_text)
+                except json.JSONDecodeError as e:
+                    logger.error(f"JSON 解析失败,跳过本批次: {e}\n内容:{result_text}")
+                    continue
+
+                for item in score_list:
+                    parsed_scores.append(QaPairScore(**item))
+            except Exception as e:
+                ids_in_batch = [qa["id"] for qa in qa_list]
+                logger.error(f"调用在线模型或解析结果失败,当前 batch QA ID 列表: {ids_in_batch},错误信息: {str(e)}")
+
+        score_map = {score.id: score.score for score in parsed_scores}
+        for qa in data:
+            if qa.id in score_map:
+                qa.score = score_map[qa.id]
+            else:
+                logger.warning(f"未获取到QA ID {qa.id}的分数,默认赋值0")
+                qa.score = 0
+
+        # Tally the score distribution and log it (kept consistent with the local version)
+        scores = [qa.score for qa in data if qa.score is not None]
+        score_series = pd.Series(scores)
+        score_counts = score_series.value_counts().sort_index()
+        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
+        pd.set_option("display.unicode.east_asian_width", True)
+        distribution_df = pd.DataFrame({
+            "数量": score_counts,
+            "占比(%)": score_percentages.round(2),
+        })
+        distribution_df.index.name = "分数"
+        printable_df_str = distribution_df.reset_index().to_string(index=False)
+        logger.success(f"在线模型打分分数分布情况:\n{printable_df_str}")
+
+    def clean(self) -> str:
+        """
+        Clean the SFT data and return the path of the cleaned file.
+        If cleaning is disabled, return the original path.
+        """
+        config = self.make_dataset_config
+        dataset_dir = config["dataset_dir"]
+        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")
+
+        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
+        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")
+        accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1)
+
+        if not config.get("clean_dataset", {}).get("enable_clean"):
+            logger.info("未启用清洗功能")
+            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
+            return sft_json_path
+
+        try:
+            with open(sft_json_path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+            filtered_data = [item for item in data if item.get("score", 0) >= accept_score]
+
+            with open(output_json_path, 'w', encoding='utf-8') as f:
+                json.dump(filtered_data, f, ensure_ascii=False, indent=4)
+
+            logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据")
+            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
+            return output_json_path
+
+        except Exception as e:
+            logger.error(f"清洗数据失败,使用原始数据: {str(e)}")
+            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
+            return sft_json_path
+
+    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str):
+        """
+        Update the file_name field inside dataset_info.json.
+        """
+        try:
+            with open(dataset_info_path, "r", encoding="utf-8") as f:
+                dataset_info = json.load(f)
+
+            # Update file_name for every supported dataset key
+            for key in ["wechat-sft", "wechat-sft-with-history"]:
+                if key in dataset_info:
+                    dataset_info[key]["file_name"] = new_file_name
+
+            # Write the file back
+            with open(dataset_info_path, "w", encoding="utf-8") as f:
+                json.dump(dataset_info, f, indent=4, ensure_ascii=False)
+
+            logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}")
+
+        except Exception as e:
+            logger.warning(f"无法更新 dataset_info.json: {e}")
weclone/data/models.py ADDED
@@ -0,0 +1,74 @@
+from dataclasses import dataclass
+from pandas import Timestamp
+from pydantic import BaseModel
+
+
+@dataclass
+class ChatMessage:
+    id: int
+    MsgSvrID: int
+    type_name: str
+    is_sender: int
+    talker: str
+    room_name: str
+    msg: str
+    src: str
+    CreateTime: Timestamp
+
+
+@dataclass
+class CutMessage:
+    is_sender: int
+    cut_type: str
+    CreateTime: Timestamp
+
+
+@dataclass
+class QaPair:
+    id: int
+    system: str
+    instruction: str
+    output: str
+    history: list[list[str]]
+    time: Timestamp
+    score: int
+
+
+class QaPairScore(BaseModel):
+    id: int
+    score: int
+
+
+skip_type_list = [
+    "添加好友",
+    "推荐公众号",
+    "动画表情",
+    "位置",
+    "文件",
+    "位置共享",
+    "接龙",
+    "引用回复",
+    "视频号直播或直播回放",
+    "用户上传的GIF表情",
+    "文件(猜)",
+    "群公告",
+    "视频号直播或直播回放等",
+    "游戏相关",
+    "转账",
+    "赠送红包封面",
+    "语音通话",
+    "企业微信打招呼(猜)",
+    "企业微信添加好友(猜)",
+    "系统通知",
+    "消息撤回1",
+    "拍一拍",
+    "消息撤回5",
+    "消息撤回6",
+    "消息撤回33",
+    "消息撤回36",
+    "消息撤回57",
+    "邀请加群",
+    "未知-11000,0",
+]
+# Message types not handled yet
+unprocessed_type_list = []
weclone/data/qa_generator.py ADDED
@@ -0,0 +1,506 @@
+import os
+import sys
+import subprocess
+from typing import Dict, List, Union
+import re
+
+import pandas as pd
+import json
+from pandas import Timestamp
+from llamafactory.extras.packages import is_vllm_available
+
+from weclone.data.clean.strategies import LLMCleaningStrategy
+from weclone.data.clean.strategies_online import OlineLLMCleaningStrategy
+from weclone.utils.config import load_config
+from weclone.utils.log import logger
+from weclone.data.models import ChatMessage, CutMessage, skip_type_list, QaPair
+from weclone.data.strategies import TimeWindowStrategy, LLMStrategy
+
+
+class DataProcessor:
+    def __init__(self):
+        self.config = load_config(arg_type="make_dataset")
+        self.csv_folder = "./dataset/csv"
+        self.system_prompt = self.config["default_system"]
+        self.cut_type_list = [
+            "图片",
+            "视频",
+            "合并转发的聊天记录",
+            "语音",
+            "(分享)音乐",
+            "(分享)卡片式链接",
+            "(分享)笔记",
+            "(分享)小程序",
+            "(分享)收藏夹",
+            "(分享)小说(猜)",
+            "(分享)视频号名片",
+            "(分享)视频号视频",
+            "粘贴的文本",  # Share links that could not be parsed
+        ]
+
+        # blocked_words
+        config_blocked_words = self.config.get("blocked_words", [])
+        file_blocked_words = []
+        try:
+            with open("./dataset/blocked_words.json", encoding="utf-8") as f:
+                file_blocked_words = json.load(f).get("blocked_words", [])
+        except (FileNotFoundError, json.JSONDecodeError):
+            pass
+
+        self.blocked_words = list(set(config_blocked_words + file_blocked_words))
+        # logger.info(f"聊天记录禁用词: {self.blocked_words}")
+
+        if self.config["single_combine_strategy"] == "time_window":
+            self.single_combine_strategy = TimeWindowStrategy(
+                time_window=self.config["single_combine_time_window"] * 60,
+                is_single_chat=True,
+            )
+        elif self.config["single_combine_strategy"] == "llm":
+            self.single_combine_strategy = LLMStrategy(
+                is_single_chat=True,
+            )
+
+        if self.config["qa_match_strategy"] == "time_window":
+            self.qa_match_strategy = TimeWindowStrategy(
+                time_window=self.config["qa_match_time_window"] * 60,
+                is_single_chat=False,
+            )
+        elif self.config["qa_match_strategy"] == "llm":
+            self.qa_match_strategy = LLMStrategy(is_single_chat=False)
+
+        clean_dataset_config = self.config.get("clean_dataset", {})
+        enable_clean = clean_dataset_config.get("enable_clean", False)
+
+        if enable_clean:
+            if self.config.get("prompt_with_history", False):
+                logger.warning("开启 prompt_with_history 不支持 clean_dataset 功能")
+                exit()
+
+            if not is_vllm_available() and not self.config.get("online_llm_clear"):
+                logger.warning("vLLM 不可用,暂不清洗数据集。")
+                clean_dataset_config["enable_clean"] = False
+
+        if self.config.get("clean_dataset", {}).get("enable_clean", False):
+            if self.config.get("clean_dataset", {}).get("clean_strategy", "llm") == "llm":
+                if self.config.get("online_llm_clear"):
+                    self.clean_strategy = OlineLLMCleaningStrategy(make_dataset_config=self.config)
+                else:
+                    self.clean_strategy = LLMCleaningStrategy(make_dataset_config=self.config)
+        self.c = self.config
+
+    def main(self):
+        if not os.path.exists(self.csv_folder) or not os.listdir(self.csv_folder):
+            logger.error(f"错误:目录 '{self.csv_folder}' 不存在或为空,请检查路径并确保其中包含 CSV 聊天数据文件。")
+            return
+
+        csv_files = self.get_csv_files()
+        logger.info(f"共发现 {len(csv_files)} 个 CSV 文件,开始处理")
+        message_list: List[ChatMessage] = []
+        for csv_file in csv_files:
+            logger.debug(f"开始处理 CSV 文件: {csv_file}")
+            chat_messages = self.load_csv(csv_file)
+            message_list.extend(self.group_consecutive_messages(messages=chat_messages))
+            # self.process_by_msgtype(chat_message)
+            logger.debug(f"处理完成: {csv_file},共加载 {len(chat_messages)} 条消息")
+        qa_res = self.match_qa(message_list)
+        if self.c["prompt_with_history"]:
+            qa_res = self.add_history_to_qa(qa_res)
+        else:
+            qa_res = [item for item in qa_res if isinstance(item, QaPair)]
+
+        if self.c.get("clean_dataset", {}).get("enable_clean", False):
+            self.clean_strategy.judge(qa_res)
+            # qa_res = self.clean_strategy.clean(qa_res)
+        self.save_result(qa_res)
+        self._execute_length_cdf_script()
+
+        logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 ./dataset/res_csv/sft/sft-my.json")
+
+    def _execute_length_cdf_script(self):
+        """Run the length_cdf.py script to compute cutoff_len."""
+        try:
+            python_executable = sys.executable
+            # The script path is relative to the project root
+            script_path = os.path.join("weclone", "utils", "length_cdf.py")
+
+            command_parts = [
+                python_executable,
+                script_path,
+                f'--model_name_or_path="{self.c["model_name_or_path"]}"',
+                f'--dataset="{self.c["dataset"]}"',
+                f'--dataset_dir="{self.c["dataset_dir"]}"',
+                f'--template="{self.c["template"]}"',
+                f"--interval={self.c['cutoff_len']}",
+            ]
+
+            child_env = os.environ.copy()
+            child_env["CUDA_VISIBLE_DEVICES"] = "0"
+            child_env["LLAMAFACTORY_VERBOSITY"] = "ERROR"
+
+            process = subprocess.Popen(
+                command_parts,
+                env=child_env,
+                stdout=None,  # None means inherit the parent's stdout (the terminal)
+                stderr=None,  # None means inherit the parent's stderr (the terminal)
+                text=True,
+                bufsize=1,  # Line-buffered
+            )
+            return_code = process.wait()
+            if return_code != 0:
+                logger.error(f"命令 '{' '.join(command_parts)}' 执行失败,返回码 {return_code}")
+        except FileNotFoundError:
+            # command_parts[0] is the python executable, command_parts[1] is the script path
+            logger.error(f"命令执行失败: 找不到可执行文件 '{command_parts[0]}' 或脚本 '{command_parts[1]}'")
+        except KeyError as e:
+            logger.error(f"执行 length_cdf.py 脚本失败:配置项缺失 {str(e)}")
+        except Exception as e:
+            logger.error(f"执行 length_cdf.py 脚本时发生未知错误: {str(e)}")
+
+    def get_csv_files(self):
+        """Walk the folder to collect all CSV file paths, sorted by the starting index in the file name."""
+
+        csv_files = []
+        for chat_obj_folder in os.listdir(self.csv_folder):
+            chat_obj_folder_path = os.path.join(self.csv_folder, chat_obj_folder)
+            for csvfile in os.listdir(chat_obj_folder_path):
+                if not csvfile.endswith(".csv"):
+                    continue
+                csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
+                csv_files.append(csvfile_path)
+        # Extract the starting number from the file name, e.g. wxid_..._0_5000.csv -> 0
+        pattern = re.compile(r"_(\d+)_\d+\.csv$")
+
+        def extract_start(fp: str) -> int:
+            name = os.path.basename(fp)
+            m = pattern.search(name)
+            return int(m.group(1)) if m else 0
+
+        # Sort ascending by the starting number
+        csv_files.sort(key=extract_start)
+        return csv_files
+
+    def match_qa(self, messages: List[ChatMessage]) -> List[Union[QaPair, CutMessage]]:
+        """
+        Match question-answer pairs.
+
+        Args:
+            messages: The message list.
+
+        Returns:
+            List[Union[QaPair, CutMessage]]: QA pairs containing instructions and outputs.
+        """
+        # State definitions
+        WAITING_INSTRUCTION = "waiting_instruction"  # waiting for an instruction
+        WAITING_RESPONSE = "waiting_response"  # waiting for a response
+
+        current_state = WAITING_INSTRUCTION
+        qa_res: List[Union[QaPair, CutMessage]] = []
+        last_message = None
+        current_instruction = None
+        qa_id_counter = 0
+
+        for msg in messages:
+            if isinstance(msg, CutMessage):
+                current_state = WAITING_INSTRUCTION
+                current_instruction = None
+                last_message = None
+                if self.c["prompt_with_history"]:
+                    qa_res.append(msg)
+                continue
+
+            if current_state == WAITING_INSTRUCTION:
+                if msg.is_sender == 0:  # message from the other party
+                    current_instruction = msg.msg
+                    last_message = msg
+                    current_state = WAITING_RESPONSE
+
+            elif current_state == WAITING_RESPONSE:
+                if msg.is_sender == 0:  # message from the other party
+                    current_instruction = msg.msg
+                    last_message = msg
+                    # state unchanged
+                else:  # own reply: let the strategy decide whether it belongs to the same conversation
+                    if last_message and self.qa_match_strategy.is_same_conversation([last_message], msg):
+                        assert current_instruction is not None, (
+                            "current_instruction should not be None when creating a QA pair"
+                        )
+                        qa_pair = QaPair(
+                            id=qa_id_counter,
+                            system=self.system_prompt,
+                            instruction=current_instruction,
+                            output=msg.msg,
+                            history=[],  # No history in this context yet
+                            time=msg.CreateTime,  # Use the response message time
+                            score=0,  # Default score
+                        )
+                        qa_res.append(qa_pair)
+                        qa_id_counter += 1  # bump the counter
+                    else:
+                        if self.c["prompt_with_history"]:
+                            qa_res.append(
+                                CutMessage(
+                                    is_sender=msg.is_sender,
+                                    cut_type=msg.type_name,
+                                    CreateTime=msg.CreateTime,
+                                )
+                            )
+                    # Whether or not a pair was matched, reset the state
+                    current_state = WAITING_INSTRUCTION
+                    current_instruction = None
+                    last_message = None
+
+        return qa_res
+
+    # TODO: need review
+    def add_history_to_qa(self, qa_res: List[Union[QaPair, CutMessage]]) -> List[QaPair]:
+        """
+        Adds conversation history to QaPair objects.
+
+        Args:
+            qa_res: A list containing QaPair and CutMessage objects.
+
+        Returns:
+            A list of QaPair objects with history populated.
+        """
+        qa_res_with_history: List[QaPair] = []
+        current_history: List[List[str]] = []
+        last_timestamp: Timestamp = None  # type: ignore
+
+        for item in qa_res:
+            if isinstance(item, CutMessage):
+                if current_history:
+                    instruction = current_history[-1][0]
+                    output = current_history[-1][1]
+                    history = current_history[:-1]
+                    qa_pair_with_history = QaPair(
+                        id=-1,
+                        system=self.system_prompt,
+                        instruction=instruction,
+                        output=output,
+                        history=history,
+                        time=last_timestamp,
+                        score=0,
+                    )
+                    qa_res_with_history.append(qa_pair_with_history)
+                    current_history = []
+                    last_timestamp = None  # type: ignore
+            elif isinstance(item, QaPair):
+                current_history.append([item.instruction, item.output])
+                last_timestamp = item.time
+
+        if current_history:
+            instruction = current_history[-1][0]
+            output = current_history[-1][1]
+            history = current_history[:-1]
+            # Ensure last_timestamp is not None before assignment
+            final_timestamp_end = last_timestamp
+            assert final_timestamp_end is not None, "Timestamp cannot be None for the final QaPair"
+            qa_pair_with_history = QaPair(
+                id=-1,
+                system=self.system_prompt,
+                instruction=instruction,
+                output=output,
+                history=history,
+                time=final_timestamp_end,
+                score=0,
+            )
+            qa_res_with_history.append(qa_pair_with_history)
+
+        return qa_res_with_history
+
+    def group_consecutive_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
+        """
+        Combine consecutive messages from the same person into one; emit a cut when a cut_type is met.
+
+        Args:
+            messages: The message list.
+
+        Returns:
+            List[ChatMessage]: The combined message list.
+        """
+        if not messages:
+            return []
+
+        def _combine_text(messages: List[ChatMessage]) -> ChatMessage:
+            """
+            Merge several messages into one.
+
+            Args:
+                messages: The messages to merge.
+
+            Returns:
+                ChatMessage: The merged message.
+            """
+            base_msg = messages[0]
+            combined_content = messages[0].msg
+
+            for i in messages[1:]:
+                content = i.msg
+                if not content:
+                    continue
+
+                if combined_content and combined_content[-1] not in ["。", "!", "?", "…", ",", "."]:
+                    combined_content += ","
+
+                combined_content += content
+            if len(combined_content) > self.c["combine_msg_max_length"]:
+                logger.warning(
+                    f"组合后消息长度超过{self.c['combine_msg_max_length']}将截断:\n {combined_content[:50]}"
+                )
+                combined_content = combined_content[: self.c["combine_msg_max_length"]]
+
+            combined_message = ChatMessage(
+                id=base_msg.id,
+                MsgSvrID=base_msg.MsgSvrID,
+                type_name=base_msg.type_name,
+                is_sender=base_msg.is_sender,
+                talker=base_msg.talker,
+                room_name=base_msg.room_name,
+                msg=combined_content,
+                src=base_msg.src,
+                CreateTime=messages[-1].CreateTime,  # Use the timestamp of the last message
+            )
+
+            return combined_message
+
+        def _create_cut_message(message: ChatMessage) -> CutMessage:
+            return CutMessage(
+                is_sender=message.is_sender,
+                cut_type=message.type_name,
+                CreateTime=message.CreateTime,
+            )
+
+        def _combine_current_group(group):
+            """
+            Flush the current message group into grouped_messages.
+
+            Args:
+                group: The current message group.
+            """
+            if len(group) > 1:
+                combined_msg = _combine_text(group)
+                grouped_messages.append(combined_msg)
+            else:
+                grouped_messages.append(group[0])
+
+        grouped_messages = []
+        current_group = []
+
+        for _, current_msg in enumerate(messages):
+            if current_msg.type_name in self.cut_type_list:
+                if current_group:
+                    # The current group has messages: merge it, then append a cut
+                    _combine_current_group(current_group)
+                    current_group = []
+
+                    cut_msg = _create_cut_message(current_msg)
+                    grouped_messages.append(cut_msg)
+                else:
+                    # Current group is empty: check the previous group
+                    if grouped_messages:
+                        if not isinstance(grouped_messages[-1], CutMessage):
+                            cut_msg = _create_cut_message(current_msg)
+                            grouped_messages.append(cut_msg)
+                    # If there is no previous group, or its last item is already a CutMessage, just continue
+                continue
+
+            if not current_group:
+                current_group = [current_msg]
+                continue
+
+            last_msg = current_group[-1]
+
+            # Check whether this is a consecutive message from the same person
+            if (
+                current_msg.is_sender == last_msg.is_sender
+                and current_msg.talker == last_msg.talker
+                and self.single_combine_strategy.is_same_conversation([last_msg], current_msg)
+            ):
+                current_group.append(current_msg)
+            else:
+                # Different sender: flush the current group and start a new one
+                _combine_current_group(current_group)
+                # Start a new group
+                current_group = [current_msg]
+
+        # Flush the final group
+        if current_group:
+            _combine_current_group(current_group)
+
+        return grouped_messages
+
+    def process_by_msgtype(self, chat_message: ChatMessage):
+        if chat_message.type_name == "文本":
+            self.process_text(chat_message)
+        # elif chat_message.type_name == "图片":
+        #     self.process_image(chat_message)
+
+    def load_csv(self, file_path) -> List[ChatMessage]:
+        """
+        First-pass preprocessing over the whole file: drop rows that do not qualify.
+        """
+        df = pd.read_csv(file_path, encoding="utf-8", dtype={"msg": str})
+
+        df = df[~df["type_name"].isin(values=skip_type_list)]
+
+        # If type_name is 文本 (text) and msg contains a phone number, ID-card number, email, or URL, drop the row
+        for i in df.index:
+            if df.loc[i, "type_name"] == "文本":
+                msg_str = str(df.loc[i, "msg"])
+                if (
+                    re.search(r"1\d{10}", msg_str)
+                    or re.search(r"\d{18}", msg_str)
+                    or re.search(r"\w+@\w+", msg_str)
+                    or "http" in msg_str
+                    or r"\\xa0" in msg_str
+                    or r"\\u" in msg_str
+                ):
+                    df = df.drop(index=i)
+                    continue
+                for blocked_word in self.blocked_words:
+                    if blocked_word in msg_str:
+                        df = df.drop(index=i)
+                        break
+            else:
+                df.loc[i, "msg"] = ""
+
+        df = df.dropna(how="all")
+        # Time format: 2021-07-07 10:27:23
+        # Rows with the same is_sender get their msg merged later; a different is_sender starts a new group
+        df["CreateTime"] = pd.to_datetime(df["CreateTime"])
+
+        return [ChatMessage(*row) for row in df.values]
+
+    def process_text(self, chat_message: ChatMessage):
+        pass
+
+    def save_result(self, qa_res: List[QaPair]):
+        """
+        Saves the list of QaPair objects to a JSON file after converting them to dictionaries.
+
+        Args:
+            qa_res: A list of QaPair objects.
+        """
+        processed_qa_res = []
+        for idx, item in enumerate(qa_res):
+            item_dict = {
+                "id": idx,
+                "system": item.system,
+                "instruction": item.instruction,
+                "output": item.output,
+                "history": item.history,
+                "time": item.time.isoformat() if item.time else None,
+                "score": item.score,
+            }
+            processed_qa_res.append(item_dict)
+
+        output_path = "./dataset/res_csv/sft/sft-my.json"
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        with open(output_path, "w", encoding="utf-8") as f:
+            json.dump(processed_qa_res, f, ensure_ascii=False, indent=4)
+        logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 {output_path}")
+
+
+if __name__ == "__main__":
+    processor = DataProcessor()
+    processor.main()
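The privacy filter inside load_csv is worth sanity-checking on its own; a small sketch of the same regexes on made-up sample strings:

import re

def is_sensitive(msg: str) -> bool:
    # Mirrors the checks in load_csv above
    return bool(
        re.search(r"1\d{10}", msg)     # mainland mobile number
        or re.search(r"\d{18}", msg)   # 18-digit ID-card number
        or re.search(r"\w+@\w+", msg)  # email-like token
        or "http" in msg               # links
    )

samples = ["我的手机是13812345678", "正常聊天内容", "见 http://example.com"]
print([is_sensitive(s) for s in samples])  # [True, False, True]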
weclone/data/strategies.py ADDED
@@ -0,0 +1,60 @@
+from dataclasses import dataclass
+from typing import List
+from .models import ChatMessage
+from abc import ABC, abstractmethod
+
+
+@dataclass
+class ConversationStrategy(ABC):
+    """Abstract base class for conversation-grouping strategies"""
+
+    is_single_chat: bool
+
+    @abstractmethod
+    def is_same_conversation(
+        self, history_msg: List[ChatMessage], current_msg: ChatMessage
+    ) -> bool:
+        """Decide whether two messages belong to the same conversation"""
+        pass
+
+
+@dataclass
+class TimeWindowStrategy(ConversationStrategy):
+    """Strategy based on a time window"""
+
+    time_window: int  # time window in seconds (callers convert the configured minutes via * 60)
+
+    def is_same_conversation(
+        self, history_msg: List[ChatMessage], current_msg: ChatMessage
+    ) -> bool:
+        time_diff = abs(
+            (current_msg.CreateTime - history_msg[-1].CreateTime)
+        ).total_seconds()
+        return time_diff <= self.time_window
+
+
+@dataclass
+class LLMStrategy(ConversationStrategy):
+    """Strategy that defers to a large model"""
+
+    def is_same_conversation(
+        self, history_msg: List[ChatMessage], current_msg: ChatMessage
+    ) -> bool:
+        # Fix for the user_id bug: use the talker field instead of user_id
+        return current_msg.talker == history_msg[-1].talker if history_msg else False
+
+
+@dataclass
+class CompositeStrategy(ConversationStrategy):
+    """Composite strategy combining several strategies"""
+
+    strategies: List[ConversationStrategy]
+    require_all: bool = True  # True: all strategies must agree; False: any single strategy suffices
+
+    def is_same_conversation(
+        self, history_msg: List[ChatMessage], current_msg: ChatMessage
+    ) -> bool:
+        results = [
+            s.is_same_conversation(history_msg, current_msg) for s in self.strategies
+        ]
+        return all(results) if self.require_all else any(results)
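A minimal sketch of composing these strategies; the message values are made up, and only the fields the strategies actually read (talker, CreateTime) matter here:

import pandas as pd
from weclone.data.models import ChatMessage
from weclone.data.strategies import TimeWindowStrategy, LLMStrategy, CompositeStrategy

def msg(talker: str, ts: str) -> ChatMessage:
    # Helper with dummy values for the unused fields
    return ChatMessage(0, 0, "文本", 0, talker, "", "hi", "", pd.Timestamp(ts))

combo = CompositeStrategy(
    is_single_chat=False,
    strategies=[
        TimeWindowStrategy(is_single_chat=False, time_window=300),  # within 5 minutes
        LLMStrategy(is_single_chat=False),                          # same talker
    ],
    require_all=True,
)
a = msg("alice", "2021-07-07 10:00:00")
b = msg("alice", "2021-07-07 10:03:00")
print(combo.is_same_conversation([a], b))  # True: same talker and 180s <= 300s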
weclone/eval/__init__.py ADDED
File without changes
weclone/eval/cli_demo.py ADDED
@@ -0,0 +1,48 @@
+from llamafactory.chat import ChatModel
+from llamafactory.extras.misc import torch_gc
+
+
+def main():
+    try:
+        import platform
+
+        if platform.system() != "Windows":
+            import readline  # noqa: F401
+    except ImportError:
+        print("Install `readline` for a better experience.")
+
+    chat_model = ChatModel()
+    messages = []
+    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+    while True:
+        try:
+            query = input("\nUser: ")
+        except UnicodeDecodeError:
+            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
+            continue
+        except Exception:
+            raise
+
+        if query.strip() == "exit":
+            break
+
+        if query.strip() == "clear":
+            messages = []
+            torch_gc()
+            print("History has been removed.")
+            continue
+
+        messages.append({"role": "user", "content": query})
+        print("Assistant: ", end="", flush=True)
+
+        response = ""
+        for new_text in chat_model.stream_chat(messages):
+            print(new_text, end="", flush=True)
+            response += new_text
+        print()
+        messages.append({"role": "assistant", "content": response})
+
+
+if __name__ == "__main__":
+    main()