YiFzhao commited on
Commit
efec384
·
verified ·
1 Parent(s): adf6da0

upload code0304

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ final_Graph.json filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/eval-checkpoint.ipynb ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
10
+ "import torch\n",
11
+ "\n",
12
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
13
+ "\n",
14
+ "MODEL_NAME = \"/workspace/model\"\n",
15
+ "model_token = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 2,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "import json\n",
25
+ "import torch\n",
26
+ "from transformers import AutoTokenizer\n",
27
+ "\n",
28
+ "tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
29
+ "tokenizer.pad_token = tokenizer.eos_token "
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "json_path = \"final_Graph.json\"\n",
39
+ "with open(json_path, \"r\") as f:\n",
40
+ " data = json.load(f)\n",
41
+ "\n",
42
+ "test_data = data[0]\n"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 4,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "ROLE_TOKENS = {\n",
52
+ " \"human\": \"<|User|>\", \n",
53
+ " \"gpt\": \"<|Assistant|>\", \n",
54
+ "}\n",
55
+ "GRAPH_LENGTH = 512\n",
56
+ "max_seq_length = 1100 + GRAPH_LENGTH"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 5,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "conversations = test_data.get(\"conversations\")\n",
66
+ "embeddings = test_data.get(\"embedding\") \n",
67
+ "\n",
68
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 6,
74
+ "metadata": {},
75
+ "outputs": [
76
+ {
77
+ "data": {
78
+ "text/plain": [
79
+ "'What are the signal definitions in the Verilog code for the calculator module, and what are their purposes?'"
80
+ ]
81
+ },
82
+ "execution_count": 6,
83
+ "metadata": {},
84
+ "output_type": "execute_result"
85
+ }
86
+ ],
87
+ "source": [
88
+ "question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
89
+ "question1"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 7,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "import json\n",
99
+ "import torch\n",
100
+ "import os\n",
101
+ "from transformers import AutoTokenizer\n",
102
+ "# tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
103
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
104
+ "from torch.utils.data import Dataset\n",
105
+ "from transformers import AutoModelForCausalLM\n",
106
+ "import torch\n",
107
+ "import torch.nn as nn\n",
108
+ "\n",
109
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
110
+ " def __init__(self, config):\n",
111
+ " super().__init__(config)\n",
112
+ " self.model = AutoModelForCausalLM.from_config(config)\n",
113
+ " \n",
114
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
115
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
116
+ "\n",
117
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
118
+ " \"\"\"\n",
119
+ " `graph_embedding` 形状: (batch_size, 512)\n",
120
+ " `input_ids` 形状: (batch_size, seq_len)\n",
121
+ " \"\"\"\n",
122
+ " # ✅ 获取 token embedding\n",
123
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
124
+ "\n",
125
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
126
+ " graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
127
+ "\n",
128
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
129
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
130
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
131
+ "\n",
132
+ " # ✅ 调整 attention mask\n",
133
+ " if attention_mask is not None:\n",
134
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
135
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
136
+ "\n",
137
+ " # ✅ 传入模型\n",
138
+ " outputs = self.model(\n",
139
+ " inputs_embeds=inputs_embeds,\n",
140
+ " attention_mask=attention_mask,\n",
141
+ " labels=labels,\n",
142
+ " )\n",
143
+ "\n",
144
+ " return outputs\n",
145
+ "\n",
146
+ " @classmethod\n",
147
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
148
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
149
+ " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
150
+ " return model\n",
151
+ "\n"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": 8,
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "name": "stderr",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
164
+ ]
165
+ }
166
+ ],
167
+ "source": [
168
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
169
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME).to(device)"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 9,
175
+ "metadata": {},
176
+ "outputs": [
177
+ {
178
+ "data": {
179
+ "text/plain": [
180
+ "tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
181
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
182
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
183
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
184
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
185
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
186
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
187
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
188
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
189
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
190
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
191
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
192
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
193
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
194
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
195
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
196
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
197
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
198
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
199
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
200
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
201
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
202
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
203
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
204
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
205
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
206
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
207
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
208
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
209
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
210
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
211
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
212
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
213
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
214
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
215
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
216
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
217
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
218
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
219
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
220
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
221
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
222
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
223
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
224
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
225
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
226
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
227
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
228
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
229
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
230
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
231
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
232
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
233
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
234
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
235
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
236
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
237
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
238
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
239
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
240
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
241
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
242
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
243
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
244
+ " device='cuda:0')"
245
+ ]
246
+ },
247
+ "execution_count": 9,
248
+ "metadata": {},
249
+ "output_type": "execute_result"
250
+ }
251
+ ],
252
+ "source": [
253
+ "from transformers import AutoTokenizer\n",
254
+ "\n",
255
+ "# ✅ 加载分词器\n",
256
+ "tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
257
+ "\n",
258
+ "# ✅ 输入文本\n",
259
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
260
+ "\n",
261
+ "graph_embedding.to(device)\n",
262
+ "\n"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 10,
268
+ "metadata": {},
269
+ "outputs": [
270
+ {
271
+ "ename": "RuntimeError",
272
+ "evalue": "The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3",
273
+ "output_type": "error",
274
+ "traceback": [
275
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
276
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
277
+ "Cell \u001b[0;32mIn[10], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_new_tokens):\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# ✅ 计算 logits 并进行生成\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 6\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgenerated_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, 512)\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :] \u001b[38;5;66;03m# 取最后一个 token 的 logits\u001b[39;00m\n\u001b[1;32m 14\u001b[0m next_token \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;66;03m# 贪心解码\u001b[39;00m\n",
278
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
279
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
280
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[1;32m 170\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
281
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:856\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 853\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 855\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 856\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 857\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 858\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 859\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 860\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 861\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 862\u001b[0m \u001b[43m 
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 863\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 864\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 865\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 866\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 867\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 870\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 871\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
282
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
283
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
284
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:579\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 567\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 568\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 569\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 576\u001b[0m position_embeddings,\n\u001b[1;32m 577\u001b[0m )\n\u001b[1;32m 578\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 580\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 584\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 585\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 586\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 587\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 588\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
285
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
286
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
287
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:260\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 257\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 260\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 271\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 273\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
288
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
289
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
290
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:192\u001b[0m, in \u001b[0;36mQwen2Attention.forward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 190\u001b[0m attention_interface \u001b[38;5;241m=\u001b[39m ALL_ATTENTION_FUNCTIONS[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_attn_implementation]\n\u001b[0;32m--> 192\u001b[0m attn_output, attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattention_interface\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mdropout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mscaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43msliding_window\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msliding_window\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# main diff with Llama\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_output\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m 205\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(attn_output)\n",
291
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:123\u001b[0m, in \u001b[0;36meager_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, scaling, dropout, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m causal_mask \u001b[38;5;241m=\u001b[39m attention_mask[:, :, :, : key_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m]]\n\u001b[0;32m--> 123\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcausal_mask\u001b[49m\n\u001b[1;32m 125\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(attn_weights, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 126\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(attn_weights, p\u001b[38;5;241m=\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39mmodule\u001b[38;5;241m.\u001b[39mtraining)\n",
292
+ "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3"
293
+ ]
294
+ }
295
+ ],
296
+ "source": [
297
+ "\n",
298
+ "generated_ids = inputs[\"input_ids\"]\n",
299
+ "max_new_tokens = 1024\n",
300
+ "for _ in range(max_new_tokens):\n",
301
+ " # ✅ 计算 logits 并进行生成\n",
302
+ " with torch.no_grad():\n",
303
+ " outputs = model(\n",
304
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
305
+ " attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
306
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
307
+ " )\n",
308
+ "\n",
309
+ "\n",
310
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
311
+ " next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
312
+ "\n",
313
+ "\n",
314
+ " # ✅ **拼接到已生成序列**\n",
315
+ " generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
316
+ "\n",
317
+ " if next_token[:, 0] == tokenizer.eos_token_id:\n",
318
+ " break\n",
319
+ "\n",
320
+ "# ✅ 解码最终输出\n",
321
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
322
+ "print(\"Generated Response:\", generated_text)"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": null,
328
+ "metadata": {},
329
+ "outputs": [
330
+ {
331
+ "name": "stdout",
332
+ "output_type": "stream",
333
+ "text": [
334
+ "Generated Response: How does the code handle combinational logic? What are the signal definitions in the Verilog code for the 4-to-1 multiplexer?\n",
335
+ "The code uses assign statements to handle combinational logic. The first assign statement selects between the four inputs (in0, in1, in2, in3) based on the select signals (s0, s1) and assigns the result to the output (out). The second assign statement uses a ternary operator to check the value of the select signals (s0, s1) and assigns the corresponding input to the output (out). The signal definitions include in0, in1, in2, in3 as data inputs, s0 and s1 as select signals, and out as the output signal.\n",
336
+ "How does the code handle sequential logic? What are the signal definitions in the sequential logic part of the Verilog code?\n",
337
+ "The sequential logic part of the code uses an always block with a sensitivity list that includes posedge clk, indicating that it is a sequential logic block. The output (out) is updated on the rising edge of the clock signal (clk). The input (in0) is also included in the sensitivity list, but since it is not used in the logic, it might be a mistake or an unused input. The sequential logic part is the clocked flip-flop that updates the output (out) based on the current value of the input (in0) and the select signals (s0, s1).\n",
338
+ "What is the function of the circuit described in the Verilog code?\n",
339
+ "The circuit is a 4-to-1 multiplexer with a registered output. It selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop on the rising edge of the clock signal (clk). The output (out) is the value of the selected input stored in the flip-flop.\n",
340
+ "How can the circuit be implemented in hardware?\n",
341
+ "The circuit can be implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The multiplexer can be constructed using AND-OR gates or transmission gates, and the output of the multiplexer can be connected to the D input of the flip-flop. The clock signal (clk) should be connected to the clock input of the flip-flop. The select signals (s0, s1) should be connected to the control inputs of the multiplexer. The data inputs (in0, in1, in2, in3) should be connected to the respective inputs of the multiplexer. The output of the flip-flop (out) should be connected to the output of the circuit. It is important to ensure that the timing constraints for the clock signal (clk) are met to avoid setup and hold time violations. The unused input (in0) in the sensitivity list of the always block might indicate a mistake in the code, as it is not used in the logic. However, it could be a typo or an oversight in the code. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop. The unused input (in0) should be noted as a potential issue but should not affect the functionality of the circuit as described in the code. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). 
The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit\n"
342
+ ]
343
+ }
344
+ ],
345
+ "source": [
346
+ "generated_ids = inputs[\"input_ids\"]\n",
347
+ "max_new_tokens = 1024\n",
348
+ "for _ in range(max_new_tokens):\n",
349
+ " # ✅ 计算 logits 并进行生成\n",
350
+ " with torch.no_grad():\n",
351
+ " outputs = model(\n",
352
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
353
+ " attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
354
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
355
+ " )\n",
356
+ "\n",
357
+ "\n",
358
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
359
+ " next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
360
+ "\n",
361
+ "\n",
362
+ " # ✅ **拼接到已生成序列**\n",
363
+ " generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
364
+ "\n",
365
+ " if next_token[:, 0] == tokenizer.eos_token_id:\n",
366
+ " break\n",
367
+ "\n",
368
+ "# ✅ 解码最终输出\n",
369
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
370
+ "print(\"Generated Response:\", generated_text)"
371
+ ]
372
+ }
373
+ ],
374
+ "metadata": {
375
+ "kernelspec": {
376
+ "display_name": "Python 3 (ipykernel)",
377
+ "language": "python",
378
+ "name": "python3"
379
+ },
380
+ "language_info": {
381
+ "codemirror_mode": {
382
+ "name": "ipython",
383
+ "version": 3
384
+ },
385
+ "file_extension": ".py",
386
+ "mimetype": "text/x-python",
387
+ "name": "python",
388
+ "nbconvert_exporter": "python",
389
+ "pygments_lexer": "ipython3",
390
+ "version": "3.10.12"
391
+ }
392
+ },
393
+ "nbformat": 4,
394
+ "nbformat_minor": 4
395
+ }
.ipynb_checkpoints/graph_train-checkpoint.ipynb ADDED
@@ -0,0 +1,1591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "ROLE_TOKENS = {\n",
11
+ " \"human\": \"<|User|>\", \n",
12
+ " \"gpt\": \"<|Assistant|>\", \n",
13
+ "}\n",
14
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
15
+ "GRAPH_LENGTH = 512\n",
16
+ "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new\""
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "bba6e6db-4b79-4461-ba13-75fd41019358",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "CUDA 可用: True\n",
30
+ "GPU 数量: 1\n",
31
+ "当前 GPU: 0\n",
32
+ "GPU 名称: NVIDIA A100 80GB PCIe\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# !pip install transformers accelerate datasets\n",
38
+ "# !pip install galora\n",
39
+ "# !pip install huggingface_hub\n",
40
+ "import torch\n",
41
+ "print(\"CUDA 可用:\", torch.cuda.is_available())\n",
42
+ "print(\"GPU 数量:\", torch.cuda.device_count())\n",
43
+ "print(\"当前 GPU:\", torch.cuda.current_device())\n",
44
+ "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "train_data 重新加载成功,数据量: 12384\n"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
75
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
76
+ " warnings.warn(\n",
77
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
78
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "text/html": [
84
+ "Tracking run with wandb version 0.19.7"
85
+ ],
86
+ "text/plain": [
87
+ "<IPython.core.display.HTML object>"
88
+ ]
89
+ },
90
+ "metadata": {},
91
+ "output_type": "display_data"
92
+ },
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_081255-v0v96nik</code>"
97
+ ],
98
+ "text/plain": [
99
+ "<IPython.core.display.HTML object>"
100
+ ]
101
+ },
102
+ "metadata": {},
103
+ "output_type": "display_data"
104
+ },
105
+ {
106
+ "data": {
107
+ "text/html": [
108
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
109
+ ],
110
+ "text/plain": [
111
+ "<IPython.core.display.HTML object>"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
+ {
118
+ "data": {
119
+ "text/html": [
120
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
121
+ ],
122
+ "text/plain": [
123
+ "<IPython.core.display.HTML object>"
124
+ ]
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/html": [
132
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik</a>"
133
+ ],
134
+ "text/plain": [
135
+ "<IPython.core.display.HTML object>"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [5310/5310 1:23:11, Epoch 3/3]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Step</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <td>50</td>\n",
160
+ " <td>5.349900</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <td>100</td>\n",
164
+ " <td>5.305900</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <td>150</td>\n",
168
+ " <td>4.849500</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <td>200</td>\n",
172
+ " <td>3.910800</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <td>250</td>\n",
176
+ " <td>3.325600</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <td>300</td>\n",
180
+ " <td>3.144900</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <td>350</td>\n",
184
+ " <td>2.904200</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <td>400</td>\n",
188
+ " <td>2.082100</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <td>450</td>\n",
192
+ " <td>1.214300</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <td>500</td>\n",
196
+ " <td>1.011600</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <td>550</td>\n",
200
+ " <td>0.889300</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <td>600</td>\n",
204
+ " <td>0.907300</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <td>650</td>\n",
208
+ " <td>1.190400</td>\n",
209
+ " </tr>\n",
210
+ " <tr>\n",
211
+ " <td>700</td>\n",
212
+ " <td>1.889100</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <td>750</td>\n",
216
+ " <td>4.505600</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <td>800</td>\n",
220
+ " <td>6.402800</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <td>850</td>\n",
224
+ " <td>6.479300</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <td>900</td>\n",
228
+ " <td>7.337900</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <td>950</td>\n",
232
+ " <td>8.937600</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <td>1000</td>\n",
236
+ " <td>8.938700</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <td>1050</td>\n",
240
+ " <td>8.860100</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <td>1100</td>\n",
244
+ " <td>8.693600</td>\n",
245
+ " </tr>\n",
246
+ " <tr>\n",
247
+ " <td>1150</td>\n",
248
+ " <td>9.234000</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <td>1200</td>\n",
252
+ " <td>9.347500</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <td>1250</td>\n",
256
+ " <td>8.010300</td>\n",
257
+ " </tr>\n",
258
+ " <tr>\n",
259
+ " <td>1300</td>\n",
260
+ " <td>5.952900</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <td>1350</td>\n",
264
+ " <td>5.205900</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <td>1400</td>\n",
268
+ " <td>4.969600</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>1450</td>\n",
272
+ " <td>4.884600</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <td>1500</td>\n",
276
+ " <td>4.934200</td>\n",
277
+ " </tr>\n",
278
+ " <tr>\n",
279
+ " <td>1550</td>\n",
280
+ " <td>5.156900</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <td>1600</td>\n",
284
+ " <td>5.115500</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <td>1650</td>\n",
288
+ " <td>5.373600</td>\n",
289
+ " </tr>\n",
290
+ " <tr>\n",
291
+ " <td>1700</td>\n",
292
+ " <td>4.481800</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <td>1750</td>\n",
296
+ " <td>3.957000</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>1800</td>\n",
300
+ " <td>3.092500</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <td>1850</td>\n",
304
+ " <td>1.791000</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <td>1900</td>\n",
308
+ " <td>1.934400</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <td>1950</td>\n",
312
+ " <td>2.176800</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <td>2000</td>\n",
316
+ " <td>2.112400</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>2050</td>\n",
320
+ " <td>2.127900</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <td>2100</td>\n",
324
+ " <td>2.390200</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <td>2150</td>\n",
328
+ " <td>3.091400</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <td>2200</td>\n",
332
+ " <td>3.959500</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <td>2250</td>\n",
336
+ " <td>3.905000</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>2300</td>\n",
340
+ " <td>3.777500</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <td>2350</td>\n",
344
+ " <td>3.282900</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <td>2400</td>\n",
348
+ " <td>2.630300</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <td>2450</td>\n",
352
+ " <td>3.705000</td>\n",
353
+ " </tr>\n",
354
+ " <tr>\n",
355
+ " <td>2500</td>\n",
356
+ " <td>4.266300</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>2550</td>\n",
360
+ " <td>4.285300</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <td>2600</td>\n",
364
+ " <td>4.634000</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <td>2650</td>\n",
368
+ " <td>4.474700</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <td>2700</td>\n",
372
+ " <td>3.591300</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <td>2750</td>\n",
376
+ " <td>2.486800</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>2800</td>\n",
380
+ " <td>1.911800</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>2850</td>\n",
384
+ " <td>2.088100</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>2900</td>\n",
388
+ " <td>2.015400</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>2950</td>\n",
392
+ " <td>1.988500</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>3000</td>\n",
396
+ " <td>1.976900</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>3050</td>\n",
400
+ " <td>2.097700</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>3100</td>\n",
404
+ " <td>1.987400</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>3150</td>\n",
408
+ " <td>2.065000</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>3200</td>\n",
412
+ " <td>2.112100</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>3250</td>\n",
416
+ " <td>2.075300</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>3300</td>\n",
420
+ " <td>1.983300</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>3350</td>\n",
424
+ " <td>2.181900</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>3400</td>\n",
428
+ " <td>2.446500</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>3450</td>\n",
432
+ " <td>2.434200</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>3500</td>\n",
436
+ " <td>2.357000</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>3550</td>\n",
440
+ " <td>2.157400</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>3600</td>\n",
444
+ " <td>1.992900</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>3650</td>\n",
448
+ " <td>2.018400</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>3700</td>\n",
452
+ " <td>2.010200</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>3750</td>\n",
456
+ " <td>2.009500</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>3800</td>\n",
460
+ " <td>2.034900</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>3850</td>\n",
464
+ " <td>2.038800</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>3900</td>\n",
468
+ " <td>2.007600</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>3950</td>\n",
472
+ " <td>1.983200</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>4000</td>\n",
476
+ " <td>2.005300</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>4050</td>\n",
480
+ " <td>2.014900</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>4100</td>\n",
484
+ " <td>2.018100</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>4150</td>\n",
488
+ " <td>2.033900</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>4200</td>\n",
492
+ " <td>2.024600</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>4250</td>\n",
496
+ " <td>1.995300</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>4300</td>\n",
500
+ " <td>2.018000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>4350</td>\n",
504
+ " <td>1.998300</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>4400</td>\n",
508
+ " <td>2.032800</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>4450</td>\n",
512
+ " <td>1.985900</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>4500</td>\n",
516
+ " <td>1.967700</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>4550</td>\n",
520
+ " <td>1.989400</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>4600</td>\n",
524
+ " <td>2.004700</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>4650</td>\n",
528
+ " <td>2.005800</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>4700</td>\n",
532
+ " <td>2.014400</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>4750</td>\n",
536
+ " <td>2.009200</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>4800</td>\n",
540
+ " <td>2.002200</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>4850</td>\n",
544
+ " <td>1.914300</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>4900</td>\n",
548
+ " <td>2.016900</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>4950</td>\n",
552
+ " <td>1.972900</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>5000</td>\n",
556
+ " <td>2.010300</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>5050</td>\n",
560
+ " <td>2.046600</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>5100</td>\n",
564
+ " <td>1.993900</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>5150</td>\n",
568
+ " <td>2.084500</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>5200</td>\n",
572
+ " <td>2.011900</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>5250</td>\n",
576
+ " <td>1.996500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>5300</td>\n",
580
+ " <td>1.997900</td>\n",
581
+ " </tr>\n",
582
+ " </tbody>\n",
583
+ "</table><p>"
584
+ ],
585
+ "text/plain": [
586
+ "<IPython.core.display.HTML object>"
587
+ ]
588
+ },
589
+ "metadata": {},
590
+ "output_type": "display_data"
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "text/plain": [
602
+ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new/commit/231f89403dca9aa67966e4f321e62bdb41076960', commit_message='End of training', commit_description='', oid='231f89403dca9aa67966e4f321e62bdb41076960', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new'), pr_revision=None, pr_num=None)"
603
+ ]
604
+ },
605
+ "execution_count": 4,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "import json\n",
612
+ "import torch\n",
613
+ "import os\n",
614
+ "from transformers import AutoTokenizer\n",
615
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
616
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
617
+ "if 'train_data' not in globals():\n",
618
+ " train_data_path = \"train_data.pt\"\n",
619
+ " \n",
620
+ " if os.path.exists(train_data_path): #确保文件存在\n",
621
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
622
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
623
+ " else:\n",
624
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
625
+ " exit()\n",
626
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
627
+ "if \"MODEL_NAME\" not in globals():\n",
628
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
629
+ "\n",
630
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
631
+ "\n",
632
+ "\n",
633
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
634
+ "\n",
635
+ "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
636
+ "\n",
637
+ "\n",
638
+ "from torch.utils.data import Dataset\n",
639
+ "\n",
640
+ "class GraphDataset(Dataset):\n",
641
+ " def __init__(self, data):\n",
642
+ " self.data = data\n",
643
+ "\n",
644
+ " def __len__(self):\n",
645
+ " return len(self.data)\n",
646
+ "\n",
647
+ " def __getitem__(self, idx):\n",
648
+ " sample = self.data[idx]\n",
649
+ " return {\n",
650
+ " \"input_ids\": sample[\"input_ids\"],\n",
651
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
652
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
653
+ " \"labels\": sample[\"labels\"],\n",
654
+ " }\n",
655
+ "\n",
656
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
657
+ "import torch\n",
658
+ "import torch.nn as nn\n",
659
+ "\n",
660
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
661
+ " def __init__(self, config):\n",
662
+ " super().__init__(config)\n",
663
+ "\n",
664
+ " # self.model = AutoModelForCausalLM.from_config(config)\n",
665
+ " \n",
666
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
667
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
668
+ "\n",
669
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
670
+ " \"\"\"\n",
671
+ " `graph_embedding` 形状: (batch_size, 512)\n",
672
+ " `input_ids` 形状: (batch_size, seq_len)\n",
673
+ " \"\"\"\n",
674
+ " # ✅ 获取 token embedding\n",
675
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
676
+ "\n",
677
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
678
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
679
+ "\n",
680
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
681
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
682
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
683
+ "\n",
684
+ " # ✅ 调整 attention mask\n",
685
+ " if attention_mask is not None:\n",
686
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
687
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
688
+ "\n",
689
+ " # ✅ 传入模型\n",
690
+ " outputs = self.model(\n",
691
+ " inputs_embeds=inputs_embeds,\n",
692
+ " attention_mask=attention_mask,\n",
693
+ " labels=labels,\n",
694
+ " )\n",
695
+ "\n",
696
+ " return outputs\n",
697
+ "\n",
698
+ "from transformers import Trainer\n",
699
+ "\n",
700
+ "class GraphTrainer(Trainer):\n",
701
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
702
+ " input_ids = inputs[\"input_ids\"]\n",
703
+ " attention_mask = inputs[\"attention_mask\"]\n",
704
+ " labels = inputs[\"labels\"]\n",
705
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
706
+ "\n",
707
+ " if graph_embedding is not None:\n",
708
+ " outputs = model(\n",
709
+ " input_ids=input_ids,\n",
710
+ " attention_mask=attention_mask,\n",
711
+ " labels=labels,\n",
712
+ " graph_embedding=graph_embedding, \n",
713
+ " )\n",
714
+ " else:\n",
715
+ " outputs = model(\n",
716
+ " input_ids=input_ids,\n",
717
+ " attention_mask=attention_mask,\n",
718
+ " labels=labels,\n",
719
+ " )\n",
720
+ "\n",
721
+ " loss = outputs.loss\n",
722
+ " return (loss, outputs) if return_outputs else loss\n",
723
+ "\n",
724
+ "\n",
725
+ "from transformers import AutoConfig\n",
726
+ "\n",
727
+ "# 1. 加载模型的配置\n",
728
+ "config = AutoConfig.from_pretrained(MODEL_NAME)\n",
729
+ "\n",
730
+ "# 2. 使用配置创建 GraphAwareLM 实例\n",
731
+ "model = GraphAwareLM.from_config(config) \n",
732
+ "\n",
733
+ "pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
734
+ "model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
735
+ "\n",
736
+ "# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
737
+ "# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
738
+ "# model.config.use_sliding_window_attention = False\n",
739
+ "\n",
740
+ "# ✅ 训练参数\n",
741
+ "training_args = TrainingArguments(\n",
742
+ " output_dir=\"./results\",\n",
743
+ " per_device_train_batch_size=7,\n",
744
+ " eval_strategy=\"no\",\n",
745
+ " save_strategy=\"steps\",\n",
746
+ " save_steps=3000,\n",
747
+ " logging_steps=50,\n",
748
+ " bf16=True,\n",
749
+ " optim=\"galore_adamw\",\n",
750
+ " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
751
+ " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
752
+ " warmup_steps=1000,\n",
753
+ " num_train_epochs=3,\n",
754
+ " push_to_hub=True,\n",
755
+ " hub_model_id=HF_NAME,\n",
756
+ " hub_strategy=\"every_save\",\n",
757
+ " run_name = \"experi0304\"\n",
758
+ ")\n",
759
+ "\n",
760
+ "\n",
761
+ "# ✅ 转换 `train_data` 为 `Dataset`\n",
762
+ "train_dataset = GraphDataset(train_data)\n",
763
+ "\n",
764
+ "# ✅ 训练\n",
765
+ "trainer = GraphTrainer(\n",
766
+ " model=model,\n",
767
+ " args=training_args,\n",
768
+ " train_dataset=train_dataset,\n",
769
+ ")\n",
770
+ "\n",
771
+ "trainer.train()\n",
772
+ "trainer.save_model(\"/workspace/model\")\n",
773
+ "trainer.push_to_hub()\n",
774
+ "\n",
775
+ "\n"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": 5,
781
+ "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
782
+ "metadata": {},
783
+ "outputs": [
784
+ {
785
+ "name": "stdout",
786
+ "output_type": "stream",
787
+ "text": [
788
+ "🚀 处理后数据条数: 12384\n",
789
+ "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
790
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
791
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
792
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
793
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
794
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
795
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
796
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
797
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
798
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
799
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
800
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
801
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
802
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
803
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
804
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
805
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
806
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
807
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
808
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
809
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
810
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
811
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
812
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
813
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
814
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
815
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
816
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
817
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
818
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
819
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
820
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
821
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
822
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
823
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
824
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
825
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
826
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
827
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
828
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
829
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
830
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
831
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
832
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
833
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
834
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
835
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
836
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
837
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
838
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
839
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
840
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
841
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
842
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
843
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
844
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
845
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
846
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
847
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
848
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
849
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
850
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
851
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
852
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
853
+ "✅ train_data 已保存到 train_data.pt\n"
854
+ ]
855
+ }
856
+ ],
857
+ "source": [
858
+ "import json\n",
859
+ "import torch\n",
860
+ "from transformers import AutoTokenizer\n",
861
+ "\n",
862
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
863
+ "tokenizer.pad_token = tokenizer.eos_token \n",
864
+ "\n",
865
+ "json_path = \"final_Graph.json\"\n",
866
+ "with open(json_path, \"r\") as f:\n",
867
+ " data = json.load(f)\n",
868
+ "\n",
869
+ "train_data = []\n",
870
+ "\n",
871
+ "\n",
872
+ "for sample in data:\n",
873
+ " conversations = sample.get(\"conversations\", [])\n",
874
+ " embeddings = sample.get(\"embedding\", []) \n",
875
+ "\n",
876
+ " if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
877
+ " print(f\"无效的 embedding,跳过样本:{sample}\")\n",
878
+ " continue\n",
879
+ "\n",
880
+ " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
881
+ "\n",
882
+ " #拼接所有对话\n",
883
+ " dialogue_text = \"\"\n",
884
+ " for conv in conversations:\n",
885
+ " role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
886
+ " content = conv[\"value\"]\n",
887
+ " content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
888
+ " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
889
+ " dialogue_text += f\"{role_token} {content}\\n\"\n",
890
+ "\n",
891
+ " tokenized = tokenizer(\n",
892
+ " dialogue_text,\n",
893
+ " padding=\"max_length\",\n",
894
+ " truncation=True,\n",
895
+ " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
896
+ " return_tensors=\"pt\",\n",
897
+ " )\n",
898
+ "\n",
899
+ " input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
900
+ " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
901
+ "\n",
902
+ " train_data.append({\n",
903
+ " \"input_ids\": input_ids,\n",
904
+ " \"attention_mask\": attention_mask,\n",
905
+ " \"labels\": input_ids.clone(),\n",
906
+ " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
907
+ " })\n",
908
+ "\n",
909
+ "print(\"🚀 处理后数据条数:\", len(train_data))\n",
910
+ "print(\"✅ 示例数据:\", train_data[0])\n",
911
+ "torch.save(train_data, \"train_data.pt\")\n",
912
+ "print(\"✅ train_data 已保存到 train_data.pt\")\n"
913
+ ]
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "execution_count": 6,
918
+ "id": "a33bffb9-2ff9-4a4d-af2c-b89b30a69f7d",
919
+ "metadata": {
920
+ "scrolled": true
921
+ },
922
+ "outputs": [
923
+ {
924
+ "name": "stdout",
925
+ "output_type": "stream",
926
+ "text": [
927
+ "train_data 重新加载成功,数据量: 12384\n"
928
+ ]
929
+ },
930
+ {
931
+ "name": "stderr",
932
+ "output_type": "stream",
933
+ "text": [
934
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
935
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:49: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
936
+ " warnings.warn(\n",
937
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
938
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
939
+ ]
940
+ },
941
+ {
942
+ "data": {
943
+ "text/html": [
944
+ "Tracking run with wandb version 0.19.7"
945
+ ],
946
+ "text/plain": [
947
+ "<IPython.core.display.HTML object>"
948
+ ]
949
+ },
950
+ "metadata": {},
951
+ "output_type": "display_data"
952
+ },
953
+ {
954
+ "data": {
955
+ "text/html": [
956
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_074031-ofm5zhvd</code>"
957
+ ],
958
+ "text/plain": [
959
+ "<IPython.core.display.HTML object>"
960
+ ]
961
+ },
962
+ "metadata": {},
963
+ "output_type": "display_data"
964
+ },
965
+ {
966
+ "data": {
967
+ "text/html": [
968
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
969
+ ],
970
+ "text/plain": [
971
+ "<IPython.core.display.HTML object>"
972
+ ]
973
+ },
974
+ "metadata": {},
975
+ "output_type": "display_data"
976
+ },
977
+ {
978
+ "data": {
979
+ "text/html": [
980
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
981
+ ],
982
+ "text/plain": [
983
+ "<IPython.core.display.HTML object>"
984
+ ]
985
+ },
986
+ "metadata": {},
987
+ "output_type": "display_data"
988
+ },
989
+ {
990
+ "data": {
991
+ "text/html": [
992
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd</a>"
993
+ ],
994
+ "text/plain": [
995
+ "<IPython.core.display.HTML object>"
996
+ ]
997
+ },
998
+ "metadata": {},
999
+ "output_type": "display_data"
1000
+ },
1001
+ {
1002
+ "data": {
1003
+ "text/html": [
1004
+ "\n",
1005
+ " <div>\n",
1006
+ " \n",
1007
+ " <progress value='89' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1008
+ " [ 89/5310 01:06 < 1:06:24, 1.31 it/s, Epoch 0.05/3]\n",
1009
+ " </div>\n",
1010
+ " <table border=\"1\" class=\"dataframe\">\n",
1011
+ " <thead>\n",
1012
+ " <tr style=\"text-align: left;\">\n",
1013
+ " <th>Step</th>\n",
1014
+ " <th>Training Loss</th>\n",
1015
+ " </tr>\n",
1016
+ " </thead>\n",
1017
+ " <tbody>\n",
1018
+ " <tr>\n",
1019
+ " <td>50</td>\n",
1020
+ " <td>0.000000</td>\n",
1021
+ " </tr>\n",
1022
+ " </tbody>\n",
1023
+ "</table><p>"
1024
+ ],
1025
+ "text/plain": [
1026
+ "<IPython.core.display.HTML object>"
1027
+ ]
1028
+ },
1029
+ "metadata": {},
1030
+ "output_type": "display_data"
1031
+ },
1032
+ {
1033
+ "ename": "KeyboardInterrupt",
1034
+ "evalue": "",
1035
+ "output_type": "error",
1036
+ "traceback": [
1037
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1038
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1039
+ "Cell \u001b[0;32mIn[6], line 150\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# ✅ 训练\u001b[39;00m\n\u001b[1;32m 144\u001b[0m trainer \u001b[38;5;241m=\u001b[39m GraphTrainer(\n\u001b[1;32m 145\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 146\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m 147\u001b[0m train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m 148\u001b[0m )\n\u001b[0;32m--> 150\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m trainer\u001b[38;5;241m.\u001b[39mpush_to_hub()\n\u001b[1;32m 152\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/workspace/model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
1040
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2232\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 2232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2233\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2234\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2235\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2236\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2237\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2238\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2239\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
1041
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2548\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2541\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2542\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2543\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2544\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m!=\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED\n\u001b[1;32m 2545\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m 2546\u001b[0m )\n\u001b[1;32m 2547\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2548\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2551\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m 
torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2554\u001b[0m ):\n\u001b[1;32m 2555\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2556\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
1042
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3740\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 3737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m==\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED:\n\u001b[1;32m 3738\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscale_wrt_gas\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 3740\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3742\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach()\n",
1043
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:2325\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 2324\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 2325\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2326\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m learning_rate \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_lomo_optimizer:\n\u001b[1;32m 2327\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n",
1044
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 484\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 485\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 490\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 492\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1045
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 246\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 252\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 255\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 256\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1046
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1047
+ ]
1048
+ }
1049
+ ],
1050
+ "source": [
1051
+ "import json\n",
1052
+ "import torch\n",
1053
+ "import os\n",
1054
+ "from transformers import AutoTokenizer\n",
1055
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
1056
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
1057
+ "if 'train_data' not in globals():\n",
1058
+ " train_data_path = \"train_data.pt\"\n",
1059
+ " \n",
1060
+ " if os.path.exists(train_data_path): #确保文件存在\n",
1061
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
1062
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
1063
+ " else:\n",
1064
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
1065
+ " exit()\n",
1066
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
1067
+ "if \"MODEL_NAME\" not in globals():\n",
1068
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
1069
+ "\n",
1070
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1071
+ "\n",
1072
+ "\n",
1073
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
1074
+ "\n",
1075
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
1076
+ "\n",
1077
+ "\n",
1078
+ "from torch.utils.data import Dataset\n",
1079
+ "\n",
1080
+ "class GraphDataset(Dataset):\n",
1081
+ " def __init__(self, data):\n",
1082
+ " self.data = data\n",
1083
+ "\n",
1084
+ " def __len__(self):\n",
1085
+ " return len(self.data)\n",
1086
+ "\n",
1087
+ " def __getitem__(self, idx):\n",
1088
+ " sample = self.data[idx]\n",
1089
+ " return {\n",
1090
+ " \"input_ids\": sample[\"input_ids\"],\n",
1091
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
1092
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
1093
+ " \"labels\": sample[\"labels\"],\n",
1094
+ " }\n",
1095
+ "\n",
1096
+ "from transformers import AutoModelForCausalLM\n",
1097
+ "import torch\n",
1098
+ "import torch.nn as nn\n",
1099
+ "\n",
1100
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1101
+ " def __init__(self, config):\n",
1102
+ " super().__init__(config)\n",
1103
+ " self.model = AutoModelForCausalLM.from_pretrained(config)\n",
1104
+ " \n",
1105
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
1106
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
1107
+ "\n",
1108
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1109
+ " \"\"\"\n",
1110
+ " `graph_embedding` 形状: (batch_size, 512)\n",
1111
+ " `input_ids` 形状: (batch_size, seq_len)\n",
1112
+ " \"\"\"\n",
1113
+ " # ✅ 获取 token embedding\n",
1114
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1115
+ "\n",
1116
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
1117
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
1118
+ "\n",
1119
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
1120
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1121
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1122
+ "\n",
1123
+ " # ✅ 调整 attention mask\n",
1124
+ " if attention_mask is not None:\n",
1125
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1126
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1127
+ "\n",
1128
+ " # ✅ 传入模型\n",
1129
+ " outputs = self.model(\n",
1130
+ " inputs_embeds=inputs_embeds,\n",
1131
+ " attention_mask=attention_mask,\n",
1132
+ " labels=labels,\n",
1133
+ " )\n",
1134
+ "\n",
1135
+ " return outputs\n",
1136
+ "\n",
1137
+ "from transformers import Trainer\n",
1138
+ "\n",
1139
+ "class GraphTrainer(Trainer):\n",
1140
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
1141
+ " input_ids = inputs[\"input_ids\"]\n",
1142
+ " attention_mask = inputs[\"attention_mask\"]\n",
1143
+ " labels = inputs[\"labels\"]\n",
1144
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
1145
+ "\n",
1146
+ " if graph_embedding is not None:\n",
1147
+ " outputs = model(\n",
1148
+ " input_ids=input_ids,\n",
1149
+ " attention_mask=attention_mask,\n",
1150
+ " labels=labels,\n",
1151
+ " graph_embedding=graph_embedding, \n",
1152
+ " )\n",
1153
+ " else:\n",
1154
+ " outputs = model(\n",
1155
+ " input_ids=input_ids,\n",
1156
+ " attention_mask=attention_mask,\n",
1157
+ " labels=labels,\n",
1158
+ " )\n",
1159
+ "\n",
1160
+ " loss = outputs.loss\n",
1161
+ " return (loss, outputs) if return_outputs else loss\n",
1162
+ "\n",
1163
+ "\n",
1164
+ "\n",
1165
+ "# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
1166
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
1167
+ "# model.config.use_sliding_window_attention = False\n",
1168
+ "\n",
1169
+ "# ✅ 训练参数\n",
1170
+ "training_args = TrainingArguments(\n",
1171
+ " output_dir=\"./results\",\n",
1172
+ " per_device_train_batch_size=7,\n",
1173
+ " eval_strategy=\"no\",\n",
1174
+ " save_strategy=\"steps\",\n",
1175
+ " save_steps=3000,\n",
1176
+ " logging_steps=50,\n",
1177
+ " fp16=True,\n",
1178
+ " optim=\"galore_adamw\",\n",
1179
+ " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
1180
+ " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
1181
+ " warmup_steps=1000,\n",
1182
+ " num_train_epochs=3,\n",
1183
+ " push_to_hub=True,\n",
1184
+ " hub_model_id=HF_NAME,\n",
1185
+ " hub_strategy=\"every_save\",\n",
1186
+ " run_name = \"experi0304\"\n",
1187
+ ")\n",
1188
+ "\n",
1189
+ "\n",
1190
+ "# ✅ 转换 `train_data` 为 `Dataset`\n",
1191
+ "train_dataset = GraphDataset(train_data)\n",
1192
+ "\n",
1193
+ "# ✅ 训练\n",
1194
+ "trainer = GraphTrainer(\n",
1195
+ " model=model,\n",
1196
+ " args=training_args,\n",
1197
+ " train_dataset=train_dataset,\n",
1198
+ ")\n",
1199
+ "\n",
1200
+ "trainer.train()\n",
1201
+ "trainer.push_to_hub()\n",
1202
+ "trainer.save_model(\"/workspace/model\")\n",
1203
+ "\n"
1204
+ ]
1205
+ },
1206
+ {
1207
+ "cell_type": "code",
1208
+ "execution_count": 1,
1209
+ "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
1210
+ "metadata": {},
1211
+ "outputs": [],
1212
+ "source": [
1213
+ "from transformers import AutoModelForCausalLM\n",
1214
+ "import torch\n",
1215
+ "import torch.nn as nn\n",
1216
+ "\n",
1217
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1218
+ "\n",
1219
+ " \n",
1220
+ " def __init__(self, config):\n",
1221
+ " super().__init__(config)\n",
1222
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
1223
+ "\n",
1224
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1225
+ " \"\"\"\n",
1226
+ " `graph_embedding` 形状: (batch_size, 512)\n",
1227
+ " `input_ids` 形状: (batch_size, seq_len)\n",
1228
+ " \"\"\"\n",
1229
+ " # ✅ 获取 token embedding\n",
1230
+ " inputs_embeds = self.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1231
+ "\n",
1232
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
1233
+ " graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
1234
+ "\n",
1235
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
1236
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1237
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1238
+ "\n",
1239
+ " # ✅ 调整 attention mask\n",
1240
+ " if attention_mask is not None:\n",
1241
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1242
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1243
+ "\n",
1244
+ " # ✅ 传入模型\n",
1245
+ " outputs = self.model(\n",
1246
+ " inputs_embeds=inputs_embeds,\n",
1247
+ " attention_mask=attention_mask,\n",
1248
+ " labels=labels,\n",
1249
+ " )\n",
1250
+ "\n",
1251
+ " return outputs\n",
1252
+ "\n",
1253
+ " @classmethod\n",
1254
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
1255
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
1256
+ " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
1257
+ " return model\n"
1258
+ ]
1259
+ },
1260
+ {
1261
+ "cell_type": "code",
1262
+ "execution_count": 2,
1263
+ "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
1264
+ "metadata": {},
1265
+ "outputs": [],
1266
+ "source": [
1267
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
1268
+ ]
1269
+ },
1270
+ {
1271
+ "cell_type": "code",
1272
+ "execution_count": 3,
1273
+ "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
1274
+ "metadata": {},
1275
+ "outputs": [
1276
+ {
1277
+ "name": "stderr",
1278
+ "output_type": "stream",
1279
+ "text": [
1280
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
1281
+ ]
1282
+ },
1283
+ {
1284
+ "data": {
1285
+ "text/plain": [
1286
+ "Qwen2ForCausalLM(\n",
1287
+ " (model): Qwen2Model(\n",
1288
+ " (embed_tokens): Embedding(151936, 1536)\n",
1289
+ " (layers): ModuleList(\n",
1290
+ " (0-27): 28 x Qwen2DecoderLayer(\n",
1291
+ " (self_attn): Qwen2Attention(\n",
1292
+ " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
1293
+ " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1294
+ " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1295
+ " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
1296
+ " )\n",
1297
+ " (mlp): Qwen2MLP(\n",
1298
+ " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1299
+ " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1300
+ " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
1301
+ " (act_fn): SiLU()\n",
1302
+ " )\n",
1303
+ " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1304
+ " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1305
+ " )\n",
1306
+ " )\n",
1307
+ " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1308
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
1309
+ " )\n",
1310
+ " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
1311
+ " (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
1312
+ ")"
1313
+ ]
1314
+ },
1315
+ "execution_count": 3,
1316
+ "metadata": {},
1317
+ "output_type": "execute_result"
1318
+ }
1319
+ ],
1320
+ "source": [
1321
+ "import torch\n",
1322
+ "from transformers import AutoTokenizer\n",
1323
+ "\n",
1324
+ "# 加载 tokenizer\n",
1325
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
1326
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1327
+ "\n",
1328
+ "# 加载训练好的模型\n",
1329
+ "model_path = \"/workspace/model\"\n",
1330
+ "model = GraphAwareLM.from_pretrained(model_path).to(device)\n",
1331
+ "model.eval() # 设置为推理模式\n"
1332
+ ]
1333
+ },
1334
+ {
1335
+ "cell_type": "code",
1336
+ "execution_count": 8,
1337
+ "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
1338
+ "metadata": {},
1339
+ "outputs": [],
1340
+ "source": [
1341
+ "import json\n",
1342
+ "json_path = \"final_Graph.json\"\n",
1343
+ "with open(json_path, \"r\") as f:\n",
1344
+ " data = json.load(f)\n",
1345
+ "\n",
1346
+ "test_data = data[0]\n",
1347
+ "\n",
1348
+ "conversations = test_data.get(\"conversations\")\n",
1349
+ "embeddings = test_data.get(\"embedding\") \n",
1350
+ "\n",
1351
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).to(device)\n",
1352
+ "\n",
1353
+ "question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
1354
+ "\n",
1355
+ "from transformers import AutoTokenizer\n",
1356
+ "\n",
1357
+ "# ✅ 输入文本\n",
1358
+ "ROLE_TOKENS = {\n",
1359
+ " \"human\": \"<|User|>\", \n",
1360
+ " \"gpt\": \"<|Assistant|>\", \n",
1361
+ "}\n",
1362
+ "GRAPH_LENGTH = 512\n",
1363
+ "max_seq_length = 1100 + GRAPH_LENGTH\n",
1364
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
1365
+ "\n",
1366
+ "input_ids = inputs[\"input_ids\"]\n",
1367
+ "attention_mask = inputs[\"attention_mask\"]\n"
1368
+ ]
1369
+ },
1370
+ {
1371
+ "cell_type": "code",
1372
+ "execution_count": 5,
1373
+ "id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
1374
+ "metadata": {},
1375
+ "outputs": [
1376
+ {
1377
+ "data": {
1378
+ "text/plain": [
1379
+ "tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1380
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1381
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1382
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1383
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1384
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1385
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1386
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1387
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1388
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1389
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1390
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1391
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1392
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1393
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1394
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1395
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1396
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1397
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1398
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1399
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1400
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1401
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1402
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1403
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1404
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1405
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1406
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1407
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1408
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1409
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1410
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1411
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1412
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1413
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1414
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1415
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1416
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1417
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1418
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1419
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1420
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1421
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1422
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1423
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1424
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1425
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1426
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1427
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1428
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1429
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1430
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1431
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1432
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1433
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1434
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1435
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1436
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1437
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1438
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1439
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1440
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1441
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1442
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
1443
+ " device='cuda:0')"
1444
+ ]
1445
+ },
1446
+ "execution_count": 5,
1447
+ "metadata": {},
1448
+ "output_type": "execute_result"
1449
+ }
1450
+ ],
1451
+ "source": [
1452
+ "graph_embedding"
1453
+ ]
1454
+ },
1455
+ {
1456
+ "cell_type": "code",
1457
+ "execution_count": 15,
1458
+ "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
1459
+ "metadata": {},
1460
+ "outputs": [
1461
+ {
1462
+ "name": "stdout",
1463
+ "output_type": "stream",
1464
+ "text": [
1465
+ "模型回复: How\n"
1466
+ ]
1467
+ }
1468
+ ],
1469
+ "source": [
1470
+ "# ✅ 进行前向传播\n",
1471
+ "with torch.no_grad():\n",
1472
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
1473
+ "\n",
1474
+ "# ✅ 提取 logits 并进行贪心解码\n",
1475
+ "logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
1476
+ "predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
1477
+ "\n",
1478
+ "# ✅ 反向编码为文本\n",
1479
+ "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
1480
+ "\n",
1481
+ "print(\"模型回复:\", response_text)"
1482
+ ]
1483
+ },
1484
+ {
1485
+ "cell_type": "code",
1486
+ "execution_count": 9,
1487
+ "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
1488
+ "metadata": {},
1489
+ "outputs": [
1490
+ {
1491
+ "name": "stdout",
1492
+ "output_type": "stream",
1493
+ "text": [
1494
+ "Generated Response: Is there any sequential logic in the module, and if so, how is it handled? What are the module's inputs and outputs?\n",
1495
+ "What are the module's inputs and outputs?\n",
1496
+ "What are the module's inputs and outputs?\n",
1497
+ "What are the module's inputs and outputs?\n",
1498
+ "What is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and module's output, and module's input, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. 
Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and 
module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module\n"
1499
+ ]
1500
+ }
1501
+ ],
1502
+ "source": [
1503
+ "max_new_tokens = 1024\n",
1504
+ "generated_ids = input_ids.clone()\n",
1505
+ "generated_attention_mask = attention_mask.clone()\n",
1506
+ "for _ in range(max_new_tokens):\n",
1507
+ " # ✅ 计算 logits 并进行生成\n",
1508
+ " with torch.no_grad():\n",
1509
+ " outputs = model(\n",
1510
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
1511
+ " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
1512
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
1513
+ " )\n",
1514
+ "\n",
1515
+ "\n",
1516
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
1517
+ " next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
1518
+ " # print(next_token)\n",
1519
+ "\n",
1520
+ "\n",
1521
+ " # ✅ **拼接到已生成序列**\n",
1522
+ " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
1523
+ "\n",
1524
+ " # print(generated_ids)\n",
1525
+ "\n",
1526
+ " if next_token.item() == tokenizer.eos_token_id:\n",
1527
+ " break\n",
1528
+ "\n",
1529
+ " generated_attention_mask = torch.cat(\n",
1530
+ " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
1531
+ " ) \n",
1532
+ "\n",
1533
+ "# ✅ 解码最终输出\n",
1534
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1535
+ "print(\"Generated Response:\", generated_text)"
1536
+ ]
1537
+ },
1538
+ {
1539
+ "cell_type": "code",
1540
+ "execution_count": 10,
1541
+ "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
1542
+ "metadata": {},
1543
+ "outputs": [
1544
+ {
1545
+ "data": {
1546
+ "text/plain": [
1547
+ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
1548
+ " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
1549
+ " 525, 862, 9895, 30]], device='cuda:0')"
1550
+ ]
1551
+ },
1552
+ "execution_count": 10,
1553
+ "metadata": {},
1554
+ "output_type": "execute_result"
1555
+ }
1556
+ ],
1557
+ "source": [
1558
+ "generated_ids"
1559
+ ]
1560
+ },
1561
+ {
1562
+ "cell_type": "code",
1563
+ "execution_count": null,
1564
+ "id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
1565
+ "metadata": {},
1566
+ "outputs": [],
1567
+ "source": []
1568
+ }
1569
+ ],
1570
+ "metadata": {
1571
+ "kernelspec": {
1572
+ "display_name": "Python 3 (ipykernel)",
1573
+ "language": "python",
1574
+ "name": "python3"
1575
+ },
1576
+ "language_info": {
1577
+ "codemirror_mode": {
1578
+ "name": "ipython",
1579
+ "version": 3
1580
+ },
1581
+ "file_extension": ".py",
1582
+ "mimetype": "text/x-python",
1583
+ "name": "python",
1584
+ "nbconvert_exporter": "python",
1585
+ "pygments_lexer": "ipython3",
1586
+ "version": "3.10.12"
1587
+ }
1588
+ },
1589
+ "nbformat": 4,
1590
+ "nbformat_minor": 5
1591
+ }
.ipynb_checkpoints/graph_train2-checkpoint.ipynb ADDED
@@ -0,0 +1,1674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "ROLE_TOKENS = {\n",
11
+ " \"human\": \"<|User|>\", \n",
12
+ " \"gpt\": \"<|Assistant|>\", \n",
13
+ "}\n",
14
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
15
+ "GRAPH_LENGTH = 512\n",
16
+ "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new2\""
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 3,
22
+ "id": "bba6e6db-4b79-4461-ba13-75fd41019358",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "CUDA 可用: True\n",
30
+ "GPU 数量: 1\n",
31
+ "当前 GPU: 0\n",
32
+ "GPU 名称: NVIDIA A100 80GB PCIe\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# !pip install transformers accelerate datasets\n",
38
+ "# !pip install galora\n",
39
+ "# !pip install huggingface_hub\n",
40
+ "import torch\n",
41
+ "print(\"CUDA 可用:\", torch.cuda.is_available())\n",
42
+ "print(\"GPU 数量:\", torch.cuda.device_count())\n",
43
+ "print(\"当前 GPU:\", torch.cuda.current_device())\n",
44
+ "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "train_data 重新加载成功,数据量: 12384\n"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
75
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
76
+ " warnings.warn(\n",
77
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
78
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "text/html": [
84
+ "Tracking run with wandb version 0.19.7"
85
+ ],
86
+ "text/plain": [
87
+ "<IPython.core.display.HTML object>"
88
+ ]
89
+ },
90
+ "metadata": {},
91
+ "output_type": "display_data"
92
+ },
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_111730-i9v1vlu1</code>"
97
+ ],
98
+ "text/plain": [
99
+ "<IPython.core.display.HTML object>"
100
+ ]
101
+ },
102
+ "metadata": {},
103
+ "output_type": "display_data"
104
+ },
105
+ {
106
+ "data": {
107
+ "text/html": [
108
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">experi030402</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
109
+ ],
110
+ "text/plain": [
111
+ "<IPython.core.display.HTML object>"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
+ {
118
+ "data": {
119
+ "text/html": [
120
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
121
+ ],
122
+ "text/plain": [
123
+ "<IPython.core.display.HTML object>"
124
+ ]
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/html": [
132
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1</a>"
133
+ ],
134
+ "text/plain": [
135
+ "<IPython.core.display.HTML object>"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [5310/5310 1:34:08, Epoch 3/3]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Step</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <td>50</td>\n",
160
+ " <td>5.319300</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <td>100</td>\n",
164
+ " <td>3.641300</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <td>150</td>\n",
168
+ " <td>1.521800</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <td>200</td>\n",
172
+ " <td>1.027500</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <td>250</td>\n",
176
+ " <td>0.922400</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <td>300</td>\n",
180
+ " <td>0.866900</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <td>350</td>\n",
184
+ " <td>0.800500</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <td>400</td>\n",
188
+ " <td>0.721600</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <td>450</td>\n",
192
+ " <td>0.740400</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <td>500</td>\n",
196
+ " <td>0.737000</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <td>550</td>\n",
200
+ " <td>0.713500</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <td>600</td>\n",
204
+ " <td>0.747000</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <td>650</td>\n",
208
+ " <td>0.869500</td>\n",
209
+ " </tr>\n",
210
+ " <tr>\n",
211
+ " <td>700</td>\n",
212
+ " <td>1.473300</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <td>750</td>\n",
216
+ " <td>0.753000</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <td>800</td>\n",
220
+ " <td>0.741300</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <td>850</td>\n",
224
+ " <td>0.751400</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <td>900</td>\n",
228
+ " <td>0.787600</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <td>950</td>\n",
232
+ " <td>0.783200</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <td>1000</td>\n",
236
+ " <td>0.780200</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <td>1050</td>\n",
240
+ " <td>1.012900</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <td>1100</td>\n",
244
+ " <td>1.411700</td>\n",
245
+ " </tr>\n",
246
+ " <tr>\n",
247
+ " <td>1150</td>\n",
248
+ " <td>1.536400</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <td>1200</td>\n",
252
+ " <td>0.853800</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <td>1250</td>\n",
256
+ " <td>0.756500</td>\n",
257
+ " </tr>\n",
258
+ " <tr>\n",
259
+ " <td>1300</td>\n",
260
+ " <td>0.750800</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <td>1350</td>\n",
264
+ " <td>0.747400</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <td>1400</td>\n",
268
+ " <td>0.844400</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>1450</td>\n",
272
+ " <td>0.858400</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <td>1500</td>\n",
276
+ " <td>1.053400</td>\n",
277
+ " </tr>\n",
278
+ " <tr>\n",
279
+ " <td>1550</td>\n",
280
+ " <td>1.591600</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <td>1600</td>\n",
284
+ " <td>1.498900</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <td>1650</td>\n",
288
+ " <td>1.471700</td>\n",
289
+ " </tr>\n",
290
+ " <tr>\n",
291
+ " <td>1700</td>\n",
292
+ " <td>1.221100</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <td>1750</td>\n",
296
+ " <td>1.802300</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>1800</td>\n",
300
+ " <td>1.826000</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <td>1850</td>\n",
304
+ " <td>1.857300</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <td>1900</td>\n",
308
+ " <td>1.561800</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <td>1950</td>\n",
312
+ " <td>1.398800</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <td>2000</td>\n",
316
+ " <td>1.398900</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>2050</td>\n",
320
+ " <td>1.381600</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <td>2100</td>\n",
324
+ " <td>0.890300</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <td>2150</td>\n",
328
+ " <td>0.763700</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <td>2200</td>\n",
332
+ " <td>0.753100</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <td>2250</td>\n",
336
+ " <td>0.745500</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>2300</td>\n",
340
+ " <td>1.186100</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <td>2350</td>\n",
344
+ " <td>0.862000</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <td>2400</td>\n",
348
+ " <td>1.024600</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <td>2450</td>\n",
352
+ " <td>1.028400</td>\n",
353
+ " </tr>\n",
354
+ " <tr>\n",
355
+ " <td>2500</td>\n",
356
+ " <td>1.008500</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>2550</td>\n",
360
+ " <td>0.942800</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <td>2600</td>\n",
364
+ " <td>0.849700</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <td>2650</td>\n",
368
+ " <td>0.771400</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <td>2700</td>\n",
372
+ " <td>0.794100</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <td>2750</td>\n",
376
+ " <td>0.819200</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>2800</td>\n",
380
+ " <td>0.937500</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>2850</td>\n",
384
+ " <td>1.064500</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>2900</td>\n",
388
+ " <td>1.189300</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>2950</td>\n",
392
+ " <td>1.071100</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>3000</td>\n",
396
+ " <td>1.003300</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>3050</td>\n",
400
+ " <td>1.073900</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>3100</td>\n",
404
+ " <td>1.043100</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>3150</td>\n",
408
+ " <td>1.282600</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>3200</td>\n",
412
+ " <td>2.145400</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>3250</td>\n",
416
+ " <td>1.925800</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>3300</td>\n",
420
+ " <td>2.005600</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>3350</td>\n",
424
+ " <td>2.122600</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>3400</td>\n",
428
+ " <td>2.163000</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>3450</td>\n",
432
+ " <td>2.046600</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>3500</td>\n",
436
+ " <td>2.152200</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>3550</td>\n",
440
+ " <td>2.151700</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>3600</td>\n",
444
+ " <td>5.394900</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>3650</td>\n",
448
+ " <td>4.677800</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>3700</td>\n",
452
+ " <td>4.122200</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>3750</td>\n",
456
+ " <td>3.710200</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>3800</td>\n",
460
+ " <td>3.350800</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>3850</td>\n",
464
+ " <td>3.126300</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>3900</td>\n",
468
+ " <td>2.988700</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>3950</td>\n",
472
+ " <td>2.872000</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>4000</td>\n",
476
+ " <td>2.848200</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>4050</td>\n",
480
+ " <td>2.823900</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>4100</td>\n",
484
+ " <td>2.781200</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>4150</td>\n",
488
+ " <td>2.735000</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>4200</td>\n",
492
+ " <td>2.725900</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>4250</td>\n",
496
+ " <td>2.644400</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>4300</td>\n",
500
+ " <td>2.700000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>4350</td>\n",
504
+ " <td>2.650100</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>4400</td>\n",
508
+ " <td>2.704500</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>4450</td>\n",
512
+ " <td>2.596700</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>4500</td>\n",
516
+ " <td>2.510500</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>4550</td>\n",
520
+ " <td>2.515800</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>4600</td>\n",
524
+ " <td>2.498100</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>4650</td>\n",
528
+ " <td>2.458900</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>4700</td>\n",
532
+ " <td>2.449700</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>4750</td>\n",
536
+ " <td>2.425000</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>4800</td>\n",
540
+ " <td>2.362300</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>4850</td>\n",
544
+ " <td>2.232000</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>4900</td>\n",
548
+ " <td>2.361500</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>4950</td>\n",
552
+ " <td>2.302300</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>5000</td>\n",
556
+ " <td>2.333900</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>5050</td>\n",
560
+ " <td>2.367200</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>5100</td>\n",
564
+ " <td>2.288300</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>5150</td>\n",
568
+ " <td>2.426100</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>5200</td>\n",
572
+ " <td>2.344100</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>5250</td>\n",
576
+ " <td>2.283500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>5300</td>\n",
580
+ " <td>2.296500</td>\n",
581
+ " </tr>\n",
582
+ " </tbody>\n",
583
+ "</table><p>"
584
+ ],
585
+ "text/plain": [
586
+ "<IPython.core.display.HTML object>"
587
+ ]
588
+ },
589
+ "metadata": {},
590
+ "output_type": "display_data"
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "text/plain": [
602
+ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2/commit/291285a5f2155c79a0da893645d8df9bbca98f63', commit_message='End of training', commit_description='', oid='291285a5f2155c79a0da893645d8df9bbca98f63', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new2'), pr_revision=None, pr_num=None)"
603
+ ]
604
+ },
605
+ "execution_count": 4,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "import json\n",
612
+ "import torch\n",
613
+ "import os\n",
614
+ "from transformers import AutoTokenizer\n",
615
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
616
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
617
+ "if 'train_data' not in globals():\n",
618
+ " train_data_path = \"train_data.pt\"\n",
619
+ " \n",
620
+ " if os.path.exists(train_data_path): #确保文件存在\n",
621
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
622
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
623
+ " else:\n",
624
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
625
+ " exit()\n",
626
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
627
+ "if \"MODEL_NAME\" not in globals():\n",
628
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
629
+ "\n",
630
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
631
+ "\n",
632
+ "\n",
633
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
634
+ "\n",
635
+ "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
636
+ "\n",
637
+ "\n",
638
+ "from torch.utils.data import Dataset\n",
639
+ "\n",
640
+ "class GraphDataset(Dataset):\n",
641
+ " def __init__(self, data):\n",
642
+ " self.data = data\n",
643
+ "\n",
644
+ " def __len__(self):\n",
645
+ " return len(self.data)\n",
646
+ "\n",
647
+ " def __getitem__(self, idx):\n",
648
+ " sample = self.data[idx]\n",
649
+ " return {\n",
650
+ " \"input_ids\": sample[\"input_ids\"],\n",
651
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
652
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
653
+ " \"labels\": sample[\"labels\"],\n",
654
+ " }\n",
655
+ "\n",
656
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
657
+ "import torch\n",
658
+ "import torch.nn as nn\n",
659
+ "\n",
660
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
661
+ " def __init__(self, pretrained_model_name_or_path):\n",
662
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
663
+ " \n",
664
+ " # ✅ 载入 `MODEL_NAME` 预训练模型\n",
665
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
666
+ "\n",
667
+ " \n",
668
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
669
+ " self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
670
+ "\n",
671
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
672
+ " \"\"\"\n",
673
+ " `graph_embedding` 形状: (batch_size, 512)\n",
674
+ " `input_ids` 形状: (batch_size, seq_len)\n",
675
+ " \"\"\"\n",
676
+ " # ✅ 获取 token embedding\n",
677
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
678
+ "\n",
679
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
680
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
681
+ "\n",
682
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
683
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
684
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
685
+ "\n",
686
+ " # ✅ 调整 attention mask\n",
687
+ " if attention_mask is not None:\n",
688
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
689
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
690
+ "\n",
691
+ " # ✅ 传入模型\n",
692
+ " outputs = self.model(\n",
693
+ " inputs_embeds=inputs_embeds,\n",
694
+ " attention_mask=attention_mask,\n",
695
+ " labels=labels,\n",
696
+ " )\n",
697
+ "\n",
698
+ " return outputs\n",
699
+ "\n",
700
+ "from transformers import Trainer\n",
701
+ "\n",
702
+ "class GraphTrainer(Trainer):\n",
703
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
704
+ " input_ids = inputs[\"input_ids\"]\n",
705
+ " attention_mask = inputs[\"attention_mask\"]\n",
706
+ " labels = inputs[\"labels\"]\n",
707
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
708
+ "\n",
709
+ " if graph_embedding is not None:\n",
710
+ " outputs = model(\n",
711
+ " input_ids=input_ids,\n",
712
+ " attention_mask=attention_mask,\n",
713
+ " labels=labels,\n",
714
+ " graph_embedding=graph_embedding, \n",
715
+ " )\n",
716
+ " else:\n",
717
+ " outputs = model(\n",
718
+ " input_ids=input_ids,\n",
719
+ " attention_mask=attention_mask,\n",
720
+ " labels=labels,\n",
721
+ " )\n",
722
+ "\n",
723
+ " loss = outputs.loss\n",
724
+ " return (loss, outputs) if return_outputs else loss\n",
725
+ "\n",
726
+ "\n",
727
+ "from transformers import AutoConfig\n",
728
+ "\n",
729
+ "# ✅ 载入微调模型\n",
730
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
731
+ "\n",
732
+ "# # 1. 加载模型的配置\n",
733
+ "# config = AutoConfig.from_pretrained(MODEL_NAME)\n",
734
+ "\n",
735
+ "# # 2. 使用配置创建 GraphAwareLM 实例\n",
736
+ "# model = GraphAwareLM.from_config(config) \n",
737
+ "\n",
738
+ "# pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
739
+ "# model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
740
+ "\n",
741
+ "# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
742
+ "# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
743
+ "# model.config.use_sliding_window_attention = False\n",
744
+ "\n",
745
+ "# ✅ 训练参数\n",
746
+ "training_args = TrainingArguments(\n",
747
+ " output_dir=\"./results2\",\n",
748
+ " per_device_train_batch_size=7,\n",
749
+ " eval_strategy=\"no\",\n",
750
+ " save_strategy=\"steps\",\n",
751
+ " save_steps=3000,\n",
752
+ " logging_steps=50,\n",
753
+ " bf16=True,\n",
754
+ " optim=\"galore_adamw\",\n",
755
+ " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
756
+ " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
757
+ " warmup_steps=1000,\n",
758
+ " num_train_epochs=3,\n",
759
+ " push_to_hub=True,\n",
760
+ " hub_model_id=HF_NAME,\n",
761
+ " hub_strategy=\"every_save\",\n",
762
+ " run_name = \"experi030402\"\n",
763
+ ")\n",
764
+ "\n",
765
+ "\n",
766
+ "# ✅ 转换 `train_data` 为 `Dataset`\n",
767
+ "train_dataset = GraphDataset(train_data)\n",
768
+ "\n",
769
+ "# ✅ 训练\n",
770
+ "trainer = GraphTrainer(\n",
771
+ " model=model,\n",
772
+ " args=training_args,\n",
773
+ " train_dataset=train_dataset,\n",
774
+ ")\n",
775
+ "\n",
776
+ "trainer.train()\n",
777
+ "trainer.save_model(\"/workspace/model2\")\n",
778
+ "trainer.push_to_hub()\n",
779
+ "\n",
780
+ "\n"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "code",
785
+ "execution_count": 7,
786
+ "id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
787
+ "metadata": {},
788
+ "outputs": [
789
+ {
790
+ "data": {
791
+ "text/plain": [
792
+ "RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb')"
793
+ ]
794
+ },
795
+ "execution_count": 7,
796
+ "metadata": {},
797
+ "output_type": "execute_result"
798
+ }
799
+ ],
800
+ "source": [
801
+ "from huggingface_hub import HfApi\n",
802
+ "\n",
803
+ "api = HfApi()\n",
804
+ "repo_name = \"r1q1.5_graph_lora-wandb\" # 你的模型名称\n",
805
+ "api.create_repo(repo_name, exist_ok=True)"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 6,
811
+ "id": "73c434b9-5d58-4819-8526-24aa18ca1010",
812
+ "metadata": {},
813
+ "outputs": [
814
+ {
815
+ "data": {
816
+ "application/vnd.jupyter.widget-view+json": {
817
+ "model_id": "727ca342a20348d38a4a1c6d286963e0",
818
+ "version_major": 2,
819
+ "version_minor": 0
820
+ },
821
+ "text/plain": [
822
+ "optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
823
+ ]
824
+ },
825
+ "metadata": {},
826
+ "output_type": "display_data"
827
+ },
828
+ {
829
+ "data": {
830
+ "application/vnd.jupyter.widget-view+json": {
831
+ "model_id": "9d4967e00603441091d4527bccf89e43",
832
+ "version_major": 2,
833
+ "version_minor": 0
834
+ },
835
+ "text/plain": [
836
+ "scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
837
+ ]
838
+ },
839
+ "metadata": {},
840
+ "output_type": "display_data"
841
+ },
842
+ {
843
+ "data": {
844
+ "application/vnd.jupyter.widget-view+json": {
845
+ "model_id": "71abe1a9a91a4851bdd941995145dc8e",
846
+ "version_major": 2,
847
+ "version_minor": 0
848
+ },
849
+ "text/plain": [
850
+ "Upload 15 LFS files: 0%| | 0/15 [00:00<?, ?it/s]"
851
+ ]
852
+ },
853
+ "metadata": {},
854
+ "output_type": "display_data"
855
+ },
856
+ {
857
+ "data": {
858
+ "application/vnd.jupyter.widget-view+json": {
859
+ "model_id": "59323569e2fb415bbdbb006b3ad3bceb",
860
+ "version_major": 2,
861
+ "version_minor": 0
862
+ },
863
+ "text/plain": [
864
+ "rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
865
+ ]
866
+ },
867
+ "metadata": {},
868
+ "output_type": "display_data"
869
+ },
870
+ {
871
+ "data": {
872
+ "application/vnd.jupyter.widget-view+json": {
873
+ "model_id": "8855625556094031ba6b14bbd9a062b1",
874
+ "version_major": 2,
875
+ "version_minor": 0
876
+ },
877
+ "text/plain": [
878
+ "model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
879
+ ]
880
+ },
881
+ "metadata": {},
882
+ "output_type": "display_data"
883
+ },
884
+ {
885
+ "data": {
886
+ "application/vnd.jupyter.widget-view+json": {
887
+ "model_id": "c1f85411d4ac43ffa27ed9d24c18f468",
888
+ "version_major": 2,
889
+ "version_minor": 0
890
+ },
891
+ "text/plain": [
892
+ "model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
893
+ ]
894
+ },
895
+ "metadata": {},
896
+ "output_type": "display_data"
897
+ },
898
+ {
899
+ "data": {
900
+ "application/vnd.jupyter.widget-view+json": {
901
+ "model_id": "e52bc8f9a5f34887a3844215e4fdfde2",
902
+ "version_major": 2,
903
+ "version_minor": 0
904
+ },
905
+ "text/plain": [
906
+ "training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
907
+ ]
908
+ },
909
+ "metadata": {},
910
+ "output_type": "display_data"
911
+ },
912
+ {
913
+ "data": {
914
+ "application/vnd.jupyter.widget-view+json": {
915
+ "model_id": "10d59eeefafe474488b7972ccf3ca70a",
916
+ "version_major": 2,
917
+ "version_minor": 0
918
+ },
919
+ "text/plain": [
920
+ "model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
921
+ ]
922
+ },
923
+ "metadata": {},
924
+ "output_type": "display_data"
925
+ },
926
+ {
927
+ "data": {
928
+ "application/vnd.jupyter.widget-view+json": {
929
+ "model_id": "dfc1c91010114f09b760450f1b998aa4",
930
+ "version_major": 2,
931
+ "version_minor": 0
932
+ },
933
+ "text/plain": [
934
+ "model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
935
+ ]
936
+ },
937
+ "metadata": {},
938
+ "output_type": "display_data"
939
+ },
940
+ {
941
+ "data": {
942
+ "application/vnd.jupyter.widget-view+json": {
943
+ "model_id": "9ebc84de361240ad9317c70342168308",
944
+ "version_major": 2,
945
+ "version_minor": 0
946
+ },
947
+ "text/plain": [
948
+ "optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
949
+ ]
950
+ },
951
+ "metadata": {},
952
+ "output_type": "display_data"
953
+ },
954
+ {
955
+ "data": {
956
+ "application/vnd.jupyter.widget-view+json": {
957
+ "model_id": "4f6e8853b6ff41e88b967ae9c13a81ab",
958
+ "version_major": 2,
959
+ "version_minor": 0
960
+ },
961
+ "text/plain": [
962
+ "rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
963
+ ]
964
+ },
965
+ "metadata": {},
966
+ "output_type": "display_data"
967
+ },
968
+ {
969
+ "data": {
970
+ "application/vnd.jupyter.widget-view+json": {
971
+ "model_id": "75eb5a5e71c749b5945b66e2d41c441c",
972
+ "version_major": 2,
973
+ "version_minor": 0
974
+ },
975
+ "text/plain": [
976
+ "scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
977
+ ]
978
+ },
979
+ "metadata": {},
980
+ "output_type": "display_data"
981
+ },
982
+ {
983
+ "data": {
984
+ "application/vnd.jupyter.widget-view+json": {
985
+ "model_id": "d748aea8d80c4f81a36633a71ce0a0f7",
986
+ "version_major": 2,
987
+ "version_minor": 0
988
+ },
989
+ "text/plain": [
990
+ "training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
991
+ ]
992
+ },
993
+ "metadata": {},
994
+ "output_type": "display_data"
995
+ },
996
+ {
997
+ "data": {
998
+ "application/vnd.jupyter.widget-view+json": {
999
+ "model_id": "d1a84a5ec38f459bab4791065dc8efc3",
1000
+ "version_major": 2,
1001
+ "version_minor": 0
1002
+ },
1003
+ "text/plain": [
1004
+ "model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
1005
+ ]
1006
+ },
1007
+ "metadata": {},
1008
+ "output_type": "display_data"
1009
+ },
1010
+ {
1011
+ "data": {
1012
+ "application/vnd.jupyter.widget-view+json": {
1013
+ "model_id": "ae50659ac6024e5c8fc43ce898982f0c",
1014
+ "version_major": 2,
1015
+ "version_minor": 0
1016
+ },
1017
+ "text/plain": [
1018
+ "model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
1019
+ ]
1020
+ },
1021
+ "metadata": {},
1022
+ "output_type": "display_data"
1023
+ },
1024
+ {
1025
+ "data": {
1026
+ "application/vnd.jupyter.widget-view+json": {
1027
+ "model_id": "c109c00c979d41d59fa48d4314bac1e4",
1028
+ "version_major": 2,
1029
+ "version_minor": 0
1030
+ },
1031
+ "text/plain": [
1032
+ "training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
1033
+ ]
1034
+ },
1035
+ "metadata": {},
1036
+ "output_type": "display_data"
1037
+ },
1038
+ {
1039
+ "data": {
1040
+ "text/plain": [
1041
+ "CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora-results3/commit/5a14bfa05e9bd78cd030104d0e9ff02638731668', commit_message='upload results3', commit_description='', oid='5a14bfa05e9bd78cd030104d0e9ff02638731668', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-results3', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-results3'), pr_revision=None, pr_num=None)"
1042
+ ]
1043
+ },
1044
+ "execution_count": 6,
1045
+ "metadata": {},
1046
+ "output_type": "execute_result"
1047
+ }
1048
+ ],
1049
+ "source": [
1050
+ "from huggingface_hub import upload_folder\n",
1051
+ "\n",
1052
+ "upload_folder(\n",
1053
+ " folder_path = \"/workspace/wandb\",\n",
1054
+ " repo_id = \"YiFzhao/r1q1.5_graph_lora-wandb\",\n",
1055
+ " commit_message = \"upload wandb\",\n",
1056
+ ")"
1057
+ ]
1058
+ },
1059
+ {
1060
+ "cell_type": "code",
1061
+ "execution_count": 5,
1062
+ "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
1063
+ "metadata": {},
1064
+ "outputs": [
1065
+ {
1066
+ "name": "stdout",
1067
+ "output_type": "stream",
1068
+ "text": [
1069
+ "🚀 处理后数据条数: 12384\n",
1070
+ "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1071
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1072
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1073
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1074
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1075
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1076
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1077
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1078
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1079
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1080
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1081
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1082
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1083
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1084
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1085
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1086
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1087
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1088
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1089
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1090
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1091
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1092
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1093
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1094
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1095
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1096
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1097
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1098
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1099
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1100
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1101
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1102
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1103
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1104
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1105
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1106
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1107
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1108
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1109
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1110
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1111
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1112
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1113
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1114
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1115
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1116
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1117
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1118
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1119
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1120
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1121
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1122
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1123
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1124
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1125
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1126
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1127
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1128
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1129
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1130
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1131
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1132
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1133
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
1134
+ "✅ train_data 已保存到 train_data.pt\n"
1135
+ ]
1136
+ }
1137
+ ],
1138
+ "source": [
1139
+ "import json\n",
1140
+ "import torch\n",
1141
+ "from transformers import AutoTokenizer\n",
1142
+ "\n",
1143
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1144
+ "tokenizer.pad_token = tokenizer.eos_token \n",
1145
+ "\n",
1146
+ "json_path = \"final_Graph.json\"\n",
1147
+ "with open(json_path, \"r\") as f:\n",
1148
+ " data = json.load(f)\n",
1149
+ "\n",
1150
+ "train_data = []\n",
1151
+ "\n",
1152
+ "\n",
1153
+ "for sample in data:\n",
1154
+ " conversations = sample.get(\"conversations\", [])\n",
1155
+ " embeddings = sample.get(\"embedding\", []) \n",
1156
+ "\n",
1157
+ " if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
1158
+ " print(f\"无效的 embedding,跳过样本:{sample}\")\n",
1159
+ " continue\n",
1160
+ "\n",
1161
+ " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
1162
+ "\n",
1163
+ " #拼接所有对话\n",
1164
+ " dialogue_text = \"\"\n",
1165
+ " for conv in conversations:\n",
1166
+ " role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
1167
+ " content = conv[\"value\"]\n",
1168
+ " content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
1169
+ " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
1170
+ " dialogue_text += f\"{role_token} {content}\\n\"\n",
1171
+ "\n",
1172
+ " tokenized = tokenizer(\n",
1173
+ " dialogue_text,\n",
1174
+ " padding=\"max_length\",\n",
1175
+ " truncation=True,\n",
1176
+ " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
1177
+ " return_tensors=\"pt\",\n",
1178
+ " )\n",
1179
+ "\n",
1180
+ " input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
1181
+ " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
1182
+ "\n",
1183
+ " train_data.append({\n",
1184
+ " \"input_ids\": input_ids,\n",
1185
+ " \"attention_mask\": attention_mask,\n",
1186
+ " \"labels\": input_ids.clone(),\n",
1187
+ " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
1188
+ " })\n",
1189
+ "\n",
1190
+ "print(\"🚀 处理后数据条数:\", len(train_data))\n",
1191
+ "print(\"✅ 示例数据:\", train_data[0])\n",
1192
+ "torch.save(train_data, \"train_data.pt\")\n",
1193
+ "print(\"✅ train_data 已保存到 train_data.pt\")\n"
1194
+ ]
1195
+ },
1196
+ {
1197
+ "cell_type": "code",
1198
+ "execution_count": 10,
1199
+ "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
1200
+ "metadata": {},
1201
+ "outputs": [],
1202
+ "source": [
1203
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
1204
+ "import torch\n",
1205
+ "import torch.nn as nn\n",
1206
+ "\n",
1207
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1208
+ " def __init__(self, pretrained_model_name_or_path):\n",
1209
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
1210
+ " \n",
1211
+ " # ✅ 载入 `MODEL_NAME` 预训练模型\n",
1212
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
1213
+ "\n",
1214
+ " \n",
1215
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
1216
+ " self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
1217
+ "\n",
1218
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1219
+ " \"\"\"\n",
1220
+ " `graph_embedding` 形状: (batch_size, 512)\n",
1221
+ " `input_ids` 形状: (batch_size, seq_len)\n",
1222
+ " \"\"\"\n",
1223
+ " # ✅ 获取 token embedding\n",
1224
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1225
+ "\n",
1226
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
1227
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
1228
+ "\n",
1229
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
1230
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1231
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1232
+ "\n",
1233
+ " # ✅ 调整 attention mask\n",
1234
+ " if attention_mask is not None:\n",
1235
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1236
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1237
+ "\n",
1238
+ " # ✅ 传入模型\n",
1239
+ " outputs = self.model(\n",
1240
+ " inputs_embeds=inputs_embeds,\n",
1241
+ " attention_mask=attention_mask,\n",
1242
+ " labels=labels,\n",
1243
+ " )\n",
1244
+ "\n",
1245
+ " return outputs\n",
1246
+ "\n",
1247
+ " def generate_with_graph(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
1248
+ " \"\"\"\n",
1249
+ " ✅ 自定义 `generate()`,支持 `graph_embedding`\n",
1250
+ " `inputs`: 分词器输出,包含 `input_ids`,可选 `attention_mask`\n",
1251
+ " `graph_embedding`: 形状为 (1, 512) 的张量\n",
1252
+ " \"\"\"\n",
1253
+ " # ✅ 2. 处理 `graph_embedding`\n",
1254
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (1, hidden_size)\n",
1255
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (1, 1, hidden_size)\n",
1256
+ "\n",
1257
+ " # ✅ 3. 获取 Token Embeddings 并拼接\n",
1258
+ " inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
1259
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
1260
+ "\n",
1261
+ " # ✅ 4. 调整 `attention_mask`\n",
1262
+ " if \"attention_mask\" in inputs:\n",
1263
+ " graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
1264
+ " attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
1265
+ " else:\n",
1266
+ " attention_mask = None\n",
1267
+ "\n",
1268
+ " # ✅ 5. 进行文本生成\n",
1269
+ " with torch.no_grad():\n",
1270
+ " output_ids = self.model.generate(\n",
1271
+ " inputs_embeds=inputs_embeds,\n",
1272
+ " attention_mask=attention_mask,\n",
1273
+ " max_length=max_length,\n",
1274
+ " temperature=temperature,\n",
1275
+ " top_k=top_k,\n",
1276
+ " top_p=top_p,\n",
1277
+ " num_return_sequences=1\n",
1278
+ " )\n",
1279
+ "\n",
1280
+ " # ✅ 6. 解码生成的文本\n",
1281
+ " generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
1282
+ " return generated_text\n",
1283
+ "\n",
1284
+ " @classmethod\n",
1285
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
1286
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
1287
+ " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
1288
+ " return model"
1289
+ ]
1290
+ },
1291
+ {
1292
+ "cell_type": "code",
1293
+ "execution_count": 11,
1294
+ "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
1295
+ "metadata": {},
1296
+ "outputs": [],
1297
+ "source": [
1298
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
1299
+ ]
1300
+ },
1301
+ {
1302
+ "cell_type": "code",
1303
+ "execution_count": 12,
1304
+ "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
1305
+ "metadata": {},
1306
+ "outputs": [
1307
+ {
1308
+ "data": {
1309
+ "application/vnd.jupyter.widget-view+json": {
1310
+ "model_id": "7ad289c5523340f39799ad11e3bc1bb5",
1311
+ "version_major": 2,
1312
+ "version_minor": 0
1313
+ },
1314
+ "text/plain": [
1315
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
1316
+ ]
1317
+ },
1318
+ "metadata": {},
1319
+ "output_type": "display_data"
1320
+ },
1321
+ {
1322
+ "data": {
1323
+ "text/plain": [
1324
+ "Qwen2ForCausalLM(\n",
1325
+ " (model): Qwen2Model(\n",
1326
+ " (embed_tokens): Embedding(151936, 1536)\n",
1327
+ " (layers): ModuleList(\n",
1328
+ " (0-27): 28 x Qwen2DecoderLayer(\n",
1329
+ " (self_attn): Qwen2Attention(\n",
1330
+ " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
1331
+ " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1332
+ " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1333
+ " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
1334
+ " )\n",
1335
+ " (mlp): Qwen2MLP(\n",
1336
+ " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1337
+ " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1338
+ " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
1339
+ " (act_fn): SiLU()\n",
1340
+ " )\n",
1341
+ " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1342
+ " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1343
+ " )\n",
1344
+ " )\n",
1345
+ " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1346
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
1347
+ " )\n",
1348
+ " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
1349
+ " (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
1350
+ ")"
1351
+ ]
1352
+ },
1353
+ "execution_count": 12,
1354
+ "metadata": {},
1355
+ "output_type": "execute_result"
1356
+ }
1357
+ ],
1358
+ "source": [
1359
+ "import torch\n",
1360
+ "from transformers import AutoTokenizer\n",
1361
+ "\n",
1362
+ "# 加载 tokenizer\n",
1363
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
1364
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1365
+ "\n",
1366
+ "# 加载训练好的模型\n",
1367
+ "model_path = \"/workspace/model2\"\n",
1368
+ "model = GraphAwareLM.from_pretrained(\"/workspace/results2/checkpoint-5310\").to(device)\n",
1369
+ "model.eval() # 设置为推理模式\n"
1370
+ ]
1371
+ },
1372
+ {
1373
+ "cell_type": "code",
1374
+ "execution_count": 13,
1375
+ "id": "51995891-8906-4049-9401-2d22e06a84e8",
1376
+ "metadata": {},
1377
+ "outputs": [
1378
+ {
1379
+ "name": "stdout",
1380
+ "output_type": "stream",
1381
+ "text": [
1382
+ "Parameter containing:\n",
1383
+ "tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
1384
+ " [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
1385
+ " [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
1386
+ " ...,\n",
1387
+ " [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
1388
+ " [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
1389
+ " [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
1390
+ " device='cuda:0', requires_grad=True)\n"
1391
+ ]
1392
+ }
1393
+ ],
1394
+ "source": [
1395
+ "print(model.graph_proj.weight)\n"
1396
+ ]
1397
+ },
1398
+ {
1399
+ "cell_type": "code",
1400
+ "execution_count": 14,
1401
+ "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
1402
+ "metadata": {},
1403
+ "outputs": [],
1404
+ "source": [
1405
+ "import json\n",
1406
+ "json_path = \"final_Graph.json\"\n",
1407
+ "with open(json_path, \"r\") as f:\n",
1408
+ " data = json.load(f)\n",
1409
+ "\n",
1410
+ "test_data = data[0]\n",
1411
+ "\n",
1412
+ "conversations = test_data.get(\"conversations\")\n",
1413
+ "embeddings = test_data.get(\"embedding\") \n",
1414
+ "\n",
1415
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
1416
+ "\n",
1417
+ "question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
1418
+ "\n",
1419
+ "from transformers import AutoTokenizer\n",
1420
+ "\n",
1421
+ "# ✅ 输入文本\n",
1422
+ "ROLE_TOKENS = {\n",
1423
+ " \"human\": \"<|User|>\", \n",
1424
+ " \"gpt\": \"<|Assistant|>\", \n",
1425
+ "}\n",
1426
+ "GRAPH_LENGTH = 512\n",
1427
+ "max_seq_length = 1100 + GRAPH_LENGTH\n",
1428
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
1429
+ "\n",
1430
+ "input_ids = inputs[\"input_ids\"]\n",
1431
+ "attention_mask = inputs[\"attention_mask\"]\n"
1432
+ ]
1433
+ },
1434
+ {
1435
+ "cell_type": "code",
1436
+ "execution_count": 15,
1437
+ "id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
1438
+ "metadata": {},
1439
+ "outputs": [
1440
+ {
1441
+ "ename": "AttributeError",
1442
+ "evalue": "'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'",
1443
+ "output_type": "error",
1444
+ "traceback": [
1445
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1446
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1447
+ "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_with_graph\u001b[49m(inputs, graph_embedding)\n",
1448
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
1449
+ "\u001b[0;31mAttributeError\u001b[0m: 'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'"
1450
+ ]
1451
+ }
1452
+ ],
1453
+ "source": [
1454
+ "generated_text = model.generate_with_graph(inputs, graph_embedding)"
1455
+ ]
1456
+ },
1457
+ {
1458
+ "cell_type": "code",
1459
+ "execution_count": 5,
1460
+ "id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
1461
+ "metadata": {},
1462
+ "outputs": [
1463
+ {
1464
+ "data": {
1465
+ "text/plain": [
1466
+ "tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1467
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1468
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1469
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1470
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1471
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1472
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1473
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1474
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1475
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1476
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1477
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1478
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1479
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1480
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1481
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1482
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1483
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1484
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1485
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1486
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1487
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1488
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1489
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1490
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1491
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1492
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1493
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1494
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1495
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1496
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1497
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1498
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1499
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1500
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1501
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1502
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1503
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1504
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1505
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1506
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1507
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1508
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1509
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1510
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1511
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1512
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1513
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1514
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1515
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1516
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1517
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1518
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1519
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1520
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1521
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1522
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1523
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1524
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1525
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1526
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1527
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1528
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1529
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
1530
+ " device='cuda:0')"
1531
+ ]
1532
+ },
1533
+ "execution_count": 5,
1534
+ "metadata": {},
1535
+ "output_type": "execute_result"
1536
+ }
1537
+ ],
1538
+ "source": [
1539
+ "graph_embedding"
1540
+ ]
1541
+ },
1542
+ {
1543
+ "cell_type": "code",
1544
+ "execution_count": 15,
1545
+ "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
1546
+ "metadata": {},
1547
+ "outputs": [
1548
+ {
1549
+ "name": "stdout",
1550
+ "output_type": "stream",
1551
+ "text": [
1552
+ "模型回复: How\n"
1553
+ ]
1554
+ }
1555
+ ],
1556
+ "source": [
1557
+ "# ✅ 进行前向传播\n",
1558
+ "with torch.no_grad():\n",
1559
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
1560
+ "\n",
1561
+ "# ✅ 提取 logits 并进行贪心解码\n",
1562
+ "logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
1563
+ "predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
1564
+ "\n",
1565
+ "# ✅ 反向编码为文本\n",
1566
+ "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
1567
+ "\n",
1568
+ "print(\"模型回复:\", response_text)"
1569
+ ]
1570
+ },
1571
+ {
1572
+ "cell_type": "code",
1573
+ "execution_count": 17,
1574
+ "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
1575
+ "metadata": {},
1576
+ "outputs": [
1577
+ {
1578
+ "name": "stdout",
1579
+ "output_type": "stream",
1580
+ "text": [
1581
+ "Generated Response: Is there any sequential logic in the module, and if so, how is it handled? `data` is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is 
the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit data, and the output is the output of the `data` is a 1-bit\n"
1582
+ ]
1583
+ }
1584
+ ],
1585
+ "source": [
1586
+ "max_new_tokens = 1024\n",
1587
+ "generated_ids = input_ids.clone()\n",
1588
+ "generated_attention_mask = attention_mask.clone()\n",
1589
+ "for _ in range(max_new_tokens):\n",
1590
+ " # ✅ 计算 logits 并进行生成\n",
1591
+ " with torch.no_grad():\n",
1592
+ " outputs = model(\n",
1593
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
1594
+ " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
1595
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
1596
+ " )\n",
1597
+ "\n",
1598
+ "\n",
1599
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
1600
+ " next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
1601
+ " # print(next_token)\n",
1602
+ "\n",
1603
+ "\n",
1604
+ " # ✅ **拼接到已生成序列**\n",
1605
+ " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
1606
+ "\n",
1607
+ " # print(generated_ids)\n",
1608
+ "\n",
1609
+ " if next_token.item() == tokenizer.eos_token_id:\n",
1610
+ " break\n",
1611
+ "\n",
1612
+ " generated_attention_mask = torch.cat(\n",
1613
+ " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
1614
+ " ) \n",
1615
+ "\n",
1616
+ "# ✅ 解码最终输出\n",
1617
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1618
+ "print(\"Generated Response:\", generated_text)"
1619
+ ]
1620
+ },
1621
+ {
1622
+ "cell_type": "code",
1623
+ "execution_count": 10,
1624
+ "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
1625
+ "metadata": {},
1626
+ "outputs": [
1627
+ {
1628
+ "data": {
1629
+ "text/plain": [
1630
+ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
1631
+ " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
1632
+ " 525, 862, 9895, 30]], device='cuda:0')"
1633
+ ]
1634
+ },
1635
+ "execution_count": 10,
1636
+ "metadata": {},
1637
+ "output_type": "execute_result"
1638
+ }
1639
+ ],
1640
+ "source": [
1641
+ "generated_ids"
1642
+ ]
1643
+ },
1644
+ {
1645
+ "cell_type": "code",
1646
+ "execution_count": null,
1647
+ "id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
1648
+ "metadata": {},
1649
+ "outputs": [],
1650
+ "source": []
1651
+ }
1652
+ ],
1653
+ "metadata": {
1654
+ "kernelspec": {
1655
+ "display_name": "Python 3 (ipykernel)",
1656
+ "language": "python",
1657
+ "name": "python3"
1658
+ },
1659
+ "language_info": {
1660
+ "codemirror_mode": {
1661
+ "name": "ipython",
1662
+ "version": 3
1663
+ },
1664
+ "file_extension": ".py",
1665
+ "mimetype": "text/x-python",
1666
+ "name": "python",
1667
+ "nbconvert_exporter": "python",
1668
+ "pygments_lexer": "ipython3",
1669
+ "version": "3.10.12"
1670
+ }
1671
+ },
1672
+ "nbformat": 4,
1673
+ "nbformat_minor": 5
1674
+ }
.ipynb_checkpoints/graph_train3-checkpoint.ipynb ADDED
@@ -0,0 +1,1588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "ROLE_TOKENS = {\n",
11
+ " \"human\": \"<|User|>\", \n",
12
+ " \"gpt\": \"<|Assistant|>\", \n",
13
+ "}\n",
14
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
15
+ "GRAPH_LENGTH = 512\n",
16
+ "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new3\""
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "bba6e6db-4b79-4461-ba13-75fd41019358",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "CUDA 可用: True\n",
30
+ "GPU 数量: 1\n",
31
+ "当前 GPU: 0\n",
32
+ "GPU 名称: NVIDIA A100 80GB PCIe\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# !pip install transformers accelerate datasets\n",
38
+ "# !pip install galora\n",
39
+ "# !pip install huggingface_hub\n",
40
+ "import torch\n",
41
+ "print(\"CUDA 可用:\", torch.cuda.is_available())\n",
42
+ "print(\"GPU 数量:\", torch.cuda.device_count())\n",
43
+ "print(\"当前 GPU:\", torch.cuda.current_device())\n",
44
+ "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "train_data 重新加载成功,数据量: 12384\n"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
75
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
76
+ " warnings.warn(\n",
77
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
78
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "text/html": [
84
+ "Tracking run with wandb version 0.19.7"
85
+ ],
86
+ "text/plain": [
87
+ "<IPython.core.display.HTML object>"
88
+ ]
89
+ },
90
+ "metadata": {},
91
+ "output_type": "display_data"
92
+ },
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_134403-e0v0giuw</code>"
97
+ ],
98
+ "text/plain": [
99
+ "<IPython.core.display.HTML object>"
100
+ ]
101
+ },
102
+ "metadata": {},
103
+ "output_type": "display_data"
104
+ },
105
+ {
106
+ "data": {
107
+ "text/html": [
108
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">experi030403</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
109
+ ],
110
+ "text/plain": [
111
+ "<IPython.core.display.HTML object>"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
+ {
118
+ "data": {
119
+ "text/html": [
120
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
121
+ ],
122
+ "text/plain": [
123
+ "<IPython.core.display.HTML object>"
124
+ ]
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/html": [
132
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw</a>"
133
+ ],
134
+ "text/plain": [
135
+ "<IPython.core.display.HTML object>"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [5310/5310 1:33:59, Epoch 3/3]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Step</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <td>50</td>\n",
160
+ " <td>5.319300</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <td>100</td>\n",
164
+ " <td>3.641300</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <td>150</td>\n",
168
+ " <td>1.521800</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <td>200</td>\n",
172
+ " <td>1.027500</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <td>250</td>\n",
176
+ " <td>0.922400</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <td>300</td>\n",
180
+ " <td>0.866900</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <td>350</td>\n",
184
+ " <td>0.800500</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <td>400</td>\n",
188
+ " <td>0.721600</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <td>450</td>\n",
192
+ " <td>0.740400</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <td>500</td>\n",
196
+ " <td>0.737000</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <td>550</td>\n",
200
+ " <td>0.713500</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <td>600</td>\n",
204
+ " <td>0.747000</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <td>650</td>\n",
208
+ " <td>0.869500</td>\n",
209
+ " </tr>\n",
210
+ " <tr>\n",
211
+ " <td>700</td>\n",
212
+ " <td>1.473300</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <td>750</td>\n",
216
+ " <td>0.753000</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <td>800</td>\n",
220
+ " <td>0.741300</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <td>850</td>\n",
224
+ " <td>0.751400</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <td>900</td>\n",
228
+ " <td>0.787600</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <td>950</td>\n",
232
+ " <td>0.783200</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <td>1000</td>\n",
236
+ " <td>0.780200</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <td>1050</td>\n",
240
+ " <td>1.012900</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <td>1100</td>\n",
244
+ " <td>1.411700</td>\n",
245
+ " </tr>\n",
246
+ " <tr>\n",
247
+ " <td>1150</td>\n",
248
+ " <td>1.536400</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <td>1200</td>\n",
252
+ " <td>0.853800</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <td>1250</td>\n",
256
+ " <td>0.756500</td>\n",
257
+ " </tr>\n",
258
+ " <tr>\n",
259
+ " <td>1300</td>\n",
260
+ " <td>0.750800</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <td>1350</td>\n",
264
+ " <td>0.747400</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <td>1400</td>\n",
268
+ " <td>0.844400</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>1450</td>\n",
272
+ " <td>0.858400</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <td>1500</td>\n",
276
+ " <td>1.053400</td>\n",
277
+ " </tr>\n",
278
+ " <tr>\n",
279
+ " <td>1550</td>\n",
280
+ " <td>1.591600</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <td>1600</td>\n",
284
+ " <td>1.498900</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <td>1650</td>\n",
288
+ " <td>1.471700</td>\n",
289
+ " </tr>\n",
290
+ " <tr>\n",
291
+ " <td>1700</td>\n",
292
+ " <td>1.221100</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <td>1750</td>\n",
296
+ " <td>1.802300</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>1800</td>\n",
300
+ " <td>1.826000</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <td>1850</td>\n",
304
+ " <td>1.857300</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <td>1900</td>\n",
308
+ " <td>1.561800</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <td>1950</td>\n",
312
+ " <td>1.398800</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <td>2000</td>\n",
316
+ " <td>1.398900</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>2050</td>\n",
320
+ " <td>1.381600</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <td>2100</td>\n",
324
+ " <td>0.890300</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <td>2150</td>\n",
328
+ " <td>0.763700</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <td>2200</td>\n",
332
+ " <td>0.753100</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <td>2250</td>\n",
336
+ " <td>0.745500</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>2300</td>\n",
340
+ " <td>1.186100</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <td>2350</td>\n",
344
+ " <td>0.862000</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <td>2400</td>\n",
348
+ " <td>1.024600</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <td>2450</td>\n",
352
+ " <td>1.028400</td>\n",
353
+ " </tr>\n",
354
+ " <tr>\n",
355
+ " <td>2500</td>\n",
356
+ " <td>1.008500</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>2550</td>\n",
360
+ " <td>0.942800</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <td>2600</td>\n",
364
+ " <td>0.849700</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <td>2650</td>\n",
368
+ " <td>0.771400</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <td>2700</td>\n",
372
+ " <td>0.794100</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <td>2750</td>\n",
376
+ " <td>0.819200</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>2800</td>\n",
380
+ " <td>0.937500</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>2850</td>\n",
384
+ " <td>1.064500</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>2900</td>\n",
388
+ " <td>1.189300</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>2950</td>\n",
392
+ " <td>1.071100</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>3000</td>\n",
396
+ " <td>1.003300</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>3050</td>\n",
400
+ " <td>1.073900</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>3100</td>\n",
404
+ " <td>1.043100</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>3150</td>\n",
408
+ " <td>1.282600</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>3200</td>\n",
412
+ " <td>2.145400</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>3250</td>\n",
416
+ " <td>1.925800</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>3300</td>\n",
420
+ " <td>2.005600</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>3350</td>\n",
424
+ " <td>2.122600</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>3400</td>\n",
428
+ " <td>2.163000</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>3450</td>\n",
432
+ " <td>2.046600</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>3500</td>\n",
436
+ " <td>2.152200</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>3550</td>\n",
440
+ " <td>2.151700</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>3600</td>\n",
444
+ " <td>5.394900</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>3650</td>\n",
448
+ " <td>4.677800</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>3700</td>\n",
452
+ " <td>4.122200</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>3750</td>\n",
456
+ " <td>3.710200</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>3800</td>\n",
460
+ " <td>3.350800</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>3850</td>\n",
464
+ " <td>3.126300</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>3900</td>\n",
468
+ " <td>2.988700</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>3950</td>\n",
472
+ " <td>2.872000</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>4000</td>\n",
476
+ " <td>2.848200</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>4050</td>\n",
480
+ " <td>2.823900</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>4100</td>\n",
484
+ " <td>2.781200</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>4150</td>\n",
488
+ " <td>2.735000</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>4200</td>\n",
492
+ " <td>2.725900</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>4250</td>\n",
496
+ " <td>2.644400</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>4300</td>\n",
500
+ " <td>2.700000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>4350</td>\n",
504
+ " <td>2.650100</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>4400</td>\n",
508
+ " <td>2.704500</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>4450</td>\n",
512
+ " <td>2.596700</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>4500</td>\n",
516
+ " <td>2.510500</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>4550</td>\n",
520
+ " <td>2.515800</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>4600</td>\n",
524
+ " <td>2.498100</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>4650</td>\n",
528
+ " <td>2.458900</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>4700</td>\n",
532
+ " <td>2.449700</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>4750</td>\n",
536
+ " <td>2.425000</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>4800</td>\n",
540
+ " <td>2.362300</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>4850</td>\n",
544
+ " <td>2.232000</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>4900</td>\n",
548
+ " <td>2.361500</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>4950</td>\n",
552
+ " <td>2.302300</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>5000</td>\n",
556
+ " <td>2.333900</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>5050</td>\n",
560
+ " <td>2.367200</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>5100</td>\n",
564
+ " <td>2.288300</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>5150</td>\n",
568
+ " <td>2.426100</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>5200</td>\n",
572
+ " <td>2.344100</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>5250</td>\n",
576
+ " <td>2.283500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>5300</td>\n",
580
+ " <td>2.296500</td>\n",
581
+ " </tr>\n",
582
+ " </tbody>\n",
583
+ "</table><p>"
584
+ ],
585
+ "text/plain": [
586
+ "<IPython.core.display.HTML object>"
587
+ ]
588
+ },
589
+ "metadata": {},
590
+ "output_type": "display_data"
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "text/plain": [
602
+ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3/commit/b9472b66316be8654c6f7c173fa4561889bd3446', commit_message='End of training', commit_description='', oid='b9472b66316be8654c6f7c173fa4561889bd3446', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new3'), pr_revision=None, pr_num=None)"
603
+ ]
604
+ },
605
+ "execution_count": 4,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "import json\n",
612
+ "import torch\n",
613
+ "import os\n",
614
+ "from transformers import AutoTokenizer\n",
615
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
616
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
617
+ "if 'train_data' not in globals():\n",
618
+ " train_data_path = \"train_data.pt\"\n",
619
+ " \n",
620
+ " if os.path.exists(train_data_path): #确保文件存在\n",
621
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
622
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
623
+ " else:\n",
624
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
625
+ " exit()\n",
626
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
627
+ "if \"MODEL_NAME\" not in globals():\n",
628
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
629
+ "\n",
630
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
631
+ "\n",
632
+ "\n",
633
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
634
+ "\n",
635
+ "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
636
+ "\n",
637
+ "\n",
638
+ "from torch.utils.data import Dataset\n",
639
+ "\n",
640
+ "class GraphDataset(Dataset):\n",
641
+ " def __init__(self, data):\n",
642
+ " self.data = data\n",
643
+ "\n",
644
+ " def __len__(self):\n",
645
+ " return len(self.data)\n",
646
+ "\n",
647
+ " def __getitem__(self, idx):\n",
648
+ " sample = self.data[idx]\n",
649
+ " return {\n",
650
+ " \"input_ids\": sample[\"input_ids\"],\n",
651
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
652
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
653
+ " \"labels\": sample[\"labels\"],\n",
654
+ " }\n",
655
+ "\n",
656
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
657
+ "import torch\n",
658
+ "import torch.nn as nn\n",
659
+ "\n",
660
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
661
+ " def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
662
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
663
+ " \n",
664
+ " # ✅ 载入 LLM 预训练模型\n",
665
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
666
+ "\n",
667
+ " # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
668
+ " self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
669
+ "\n",
670
+ " # ✅ 2. 多头注意力层\n",
671
+ " self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
672
+ "\n",
673
+ " # ✅ 3. 线性变换\n",
674
+ " self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
675
+ "\n",
676
+ " # ✅ 4. 残差连接 + LayerNorm\n",
677
+ " self.norm = nn.LayerNorm(self.config.hidden_size)\n",
678
+ " \n",
679
+ "\n",
680
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
681
+ " \"\"\"\n",
682
+ " `graph_embedding` 形状: (batch_size, 512)\n",
683
+ " `input_ids` 形状: (batch_size, seq_len)\n",
684
+ " \"\"\"\n",
685
+ " # ✅ 获取 token embedding\n",
686
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
687
+ "\n",
688
+ " # ✅ 1. 线性变换 `graph_embedding`\n",
689
+ " graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n",
690
+ "\n",
691
+ " # ✅ 2. 多头注意力计算(自注意力机制)\n",
692
+ " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
693
+ " \n",
694
+ " # ✅ 3. 线性层 + 残差连接\n",
695
+ " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n",
696
+ "\n",
697
+ " # ✅ 4. 归一化\n",
698
+ " graph_embedding_token = self.norm(graph_embedding_token)\n",
699
+ "\n",
700
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
701
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
702
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
703
+ "\n",
704
+ " # ✅ 调整 attention mask\n",
705
+ " if attention_mask is not None:\n",
706
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
707
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
708
+ "\n",
709
+ " # ✅ 传入模型\n",
710
+ " outputs = self.model(\n",
711
+ " inputs_embeds=inputs_embeds,\n",
712
+ " attention_mask=attention_mask,\n",
713
+ " labels=labels,\n",
714
+ " )\n",
715
+ "\n",
716
+ " return outputs\n",
717
+ "\n",
718
+ "from transformers import Trainer\n",
719
+ "\n",
720
+ "class GraphTrainer(Trainer):\n",
721
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
722
+ " input_ids = inputs[\"input_ids\"]\n",
723
+ " attention_mask = inputs[\"attention_mask\"]\n",
724
+ " labels = inputs[\"labels\"]\n",
725
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
726
+ "\n",
727
+ " if graph_embedding is not None:\n",
728
+ " outputs = model(\n",
729
+ " input_ids=input_ids,\n",
730
+ " attention_mask=attention_mask,\n",
731
+ " labels=labels,\n",
732
+ " graph_embedding=graph_embedding, \n",
733
+ " )\n",
734
+ " else:\n",
735
+ " outputs = model(\n",
736
+ " input_ids=input_ids,\n",
737
+ " attention_mask=attention_mask,\n",
738
+ " labels=labels,\n",
739
+ " )\n",
740
+ "\n",
741
+ " loss = outputs.loss\n",
742
+ " return (loss, outputs) if return_outputs else loss\n",
743
+ "\n",
744
+ "\n",
745
+ "from transformers import AutoConfig\n",
746
+ "\n",
747
+ "# ✅ 载入微调模型\n",
748
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
749
+ "\n",
750
+ "# ✅ 训练参数\n",
751
+ "training_args = TrainingArguments(\n",
752
+ " output_dir=\"./results3\",\n",
753
+ " per_device_train_batch_size=7,\n",
754
+ " eval_strategy=\"no\",\n",
755
+ " save_strategy=\"steps\",\n",
756
+ " save_steps=3000,\n",
757
+ " logging_steps=50,\n",
758
+ " bf16=True,\n",
759
+ " optim=\"galore_adamw\",\n",
760
+ " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
761
+ " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
762
+ " warmup_steps=1000,\n",
763
+ " num_train_epochs=3,\n",
764
+ " push_to_hub=True,\n",
765
+ " hub_model_id=HF_NAME,\n",
766
+ " hub_strategy=\"every_save\",\n",
767
+ " run_name = \"experi030403\"\n",
768
+ ")\n",
769
+ "\n",
770
+ "\n",
771
+ "# ✅ 转换 `train_data` 为 `Dataset`\n",
772
+ "train_dataset = GraphDataset(train_data)\n",
773
+ "\n",
774
+ "# ✅ 训练\n",
775
+ "trainer = GraphTrainer(\n",
776
+ " model=model,\n",
777
+ " args=training_args,\n",
778
+ " train_dataset=train_dataset,\n",
779
+ ")\n",
780
+ "\n",
781
+ "trainer.train()\n",
782
+ "trainer.save_model(\"/workspace/model3\")\n",
783
+ "trainer.push_to_hub()\n",
784
+ "\n",
785
+ "\n"
786
+ ]
787
+ },
788
+ {
789
+ "cell_type": "code",
790
+ "execution_count": 2,
791
+ "id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
792
+ "metadata": {},
793
+ "outputs": [
794
+ {
795
+ "data": {
796
+ "text/plain": [
797
+ "RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000')"
798
+ ]
799
+ },
800
+ "execution_count": 2,
801
+ "metadata": {},
802
+ "output_type": "execute_result"
803
+ }
804
+ ],
805
+ "source": [
806
+ "from huggingface_hub import HfApi\n",
807
+ "\n",
808
+ "api = HfApi()\n",
809
+ "repo_name = \"r1q1.5_graph_lora-results3\" # 你的模型名称\n",
810
+ "api.create_repo(repo_name, exist_ok=True)"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "code",
815
+ "execution_count": 3,
816
+ "id": "73c434b9-5d58-4819-8526-24aa18ca1010",
817
+ "metadata": {},
818
+ "outputs": [
819
+ {
820
+ "data": {
821
+ "application/vnd.jupyter.widget-view+json": {
822
+ "model_id": "8b896f21685e4086b0b59404b2b1a866",
823
+ "version_major": 2,
824
+ "version_minor": 0
825
+ },
826
+ "text/plain": [
827
+ "model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
828
+ ]
829
+ },
830
+ "metadata": {},
831
+ "output_type": "display_data"
832
+ },
833
+ {
834
+ "data": {
835
+ "application/vnd.jupyter.widget-view+json": {
836
+ "model_id": "d20bff067ca44c4583378181da817897",
837
+ "version_major": 2,
838
+ "version_minor": 0
839
+ },
840
+ "text/plain": [
841
+ "scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
842
+ ]
843
+ },
844
+ "metadata": {},
845
+ "output_type": "display_data"
846
+ },
847
+ {
848
+ "data": {
849
+ "application/vnd.jupyter.widget-view+json": {
850
+ "model_id": "c4b7114a53b341539a3244f2eea8aacf",
851
+ "version_major": 2,
852
+ "version_minor": 0
853
+ },
854
+ "text/plain": [
855
+ "Upload 6 LFS files: 0%| | 0/6 [00:00<?, ?it/s]"
856
+ ]
857
+ },
858
+ "metadata": {},
859
+ "output_type": "display_data"
860
+ },
861
+ {
862
+ "data": {
863
+ "application/vnd.jupyter.widget-view+json": {
864
+ "model_id": "74c6045017b640bdba86fe3ed1bb9c92",
865
+ "version_major": 2,
866
+ "version_minor": 0
867
+ },
868
+ "text/plain": [
869
+ "model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
870
+ ]
871
+ },
872
+ "metadata": {},
873
+ "output_type": "display_data"
874
+ },
875
+ {
876
+ "data": {
877
+ "application/vnd.jupyter.widget-view+json": {
878
+ "model_id": "97436b084bc4420f8b273ec462c50e61",
879
+ "version_major": 2,
880
+ "version_minor": 0
881
+ },
882
+ "text/plain": [
883
+ "optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
884
+ ]
885
+ },
886
+ "metadata": {},
887
+ "output_type": "display_data"
888
+ },
889
+ {
890
+ "data": {
891
+ "application/vnd.jupyter.widget-view+json": {
892
+ "model_id": "d7f10ccff3674e6fa8bcb42553c12b19",
893
+ "version_major": 2,
894
+ "version_minor": 0
895
+ },
896
+ "text/plain": [
897
+ "rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
898
+ ]
899
+ },
900
+ "metadata": {},
901
+ "output_type": "display_data"
902
+ },
903
+ {
904
+ "data": {
905
+ "application/vnd.jupyter.widget-view+json": {
906
+ "model_id": "c5b1a010fd0845f9ba9112291afa8f17",
907
+ "version_major": 2,
908
+ "version_minor": 0
909
+ },
910
+ "text/plain": [
911
+ "training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
912
+ ]
913
+ },
914
+ "metadata": {},
915
+ "output_type": "display_data"
916
+ },
917
+ {
918
+ "data": {
919
+ "text/plain": [
920
+ "CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000/commit/4088de651a0ce2cc39fcb0c950898e54ce91bdea', commit_message='upload checkpoint-3000', commit_description='', oid='4088de651a0ce2cc39fcb0c950898e54ce91bdea', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000'), pr_revision=None, pr_num=None)"
921
+ ]
922
+ },
923
+ "execution_count": 3,
924
+ "metadata": {},
925
+ "output_type": "execute_result"
926
+ }
927
+ ],
928
+ "source": [
929
+ "from huggingface_hub import upload_folder\n",
930
+ "\n",
931
+ "upload_folder(\n",
932
+ " folder_path = \"/workspace/results3\",\n",
933
+ " repo_id = \"YiFzhao/r1q1.5_graph_lora-results3\",\n",
934
+ " commit_message = \"upload results2\",\n",
935
+ ")"
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "code",
940
+ "execution_count": 5,
941
+ "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
942
+ "metadata": {},
943
+ "outputs": [
944
+ {
945
+ "name": "stdout",
946
+ "output_type": "stream",
947
+ "text": [
948
+ "🚀 处理后数据条数: 12384\n",
949
+ "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
950
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
951
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
952
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
953
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
954
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
955
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
956
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
957
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
958
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
959
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
960
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
961
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
962
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
963
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
964
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
965
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
966
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
967
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
968
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
969
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
970
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
971
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
972
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
973
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
974
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
975
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
976
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
977
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
978
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
979
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
980
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
981
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
982
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
983
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
984
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
985
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
986
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
987
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
988
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
989
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
990
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
991
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
992
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
993
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
994
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
995
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
996
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
997
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
998
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
999
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1000
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1001
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1002
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1003
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1004
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1005
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1006
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1007
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1008
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1009
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1010
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1011
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1012
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
1013
+ "✅ train_data 已保存到 train_data.pt\n"
1014
+ ]
1015
+ }
1016
+ ],
1017
+ "source": [
1018
+ "import json\n",
1019
+ "import torch\n",
1020
+ "from transformers import AutoTokenizer\n",
1021
+ "\n",
1022
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1023
+ "tokenizer.pad_token = tokenizer.eos_token \n",
1024
+ "\n",
1025
+ "json_path = \"final_Graph.json\"\n",
1026
+ "with open(json_path, \"r\") as f:\n",
1027
+ " data = json.load(f)\n",
1028
+ "\n",
1029
+ "train_data = []\n",
1030
+ "\n",
1031
+ "\n",
1032
+ "for sample in data:\n",
1033
+ " conversations = sample.get(\"conversations\", [])\n",
1034
+ " embeddings = sample.get(\"embedding\", []) \n",
1035
+ "\n",
1036
+ " if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
1037
+ " print(f\"无效的 embedding,跳过样本:{sample}\")\n",
1038
+ " continue\n",
1039
+ "\n",
1040
+ " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
1041
+ "\n",
1042
+ " #拼接所有对话\n",
1043
+ " dialogue_text = \"\"\n",
1044
+ " for conv in conversations:\n",
1045
+ " role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
1046
+ " content = conv[\"value\"]\n",
1047
+ " content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
1048
+ " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
1049
+ " dialogue_text += f\"{role_token} {content}\\n\"\n",
1050
+ "\n",
1051
+ " tokenized = tokenizer(\n",
1052
+ " dialogue_text,\n",
1053
+ " padding=\"max_length\",\n",
1054
+ " truncation=True,\n",
1055
+ " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
1056
+ " return_tensors=\"pt\",\n",
1057
+ " )\n",
1058
+ "\n",
1059
+ " input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
1060
+ " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
1061
+ "\n",
1062
+ " train_data.append({\n",
1063
+ " \"input_ids\": input_ids,\n",
1064
+ " \"attention_mask\": attention_mask,\n",
1065
+ " \"labels\": input_ids.clone(),\n",
1066
+ " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
1067
+ " })\n",
1068
+ "\n",
1069
+ "print(\"🚀 处理后数据条数:\", len(train_data))\n",
1070
+ "print(\"✅ 示例数据:\", train_data[0])\n",
1071
+ "torch.save(train_data, \"train_data.pt\")\n",
1072
+ "print(\"✅ train_data 已保存到 train_data.pt\")\n"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "code",
1077
+ "execution_count": 6,
1078
+ "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
1079
+ "metadata": {},
1080
+ "outputs": [],
1081
+ "source": [
1082
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
1083
+ "import torch\n",
1084
+ "import torch.nn as nn\n",
1085
+ "\n",
1086
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1087
+ " def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
1088
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
1089
+ " \n",
1090
+ " # ✅ 载入 LLM 预训练模型\n",
1091
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
1092
+ "\n",
1093
+ " # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
1094
+ " self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
1095
+ "\n",
1096
+ " # ✅ 2. 多头注意力层\n",
1097
+ " self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
1098
+ "\n",
1099
+ " # ✅ 3. 线性变换\n",
1100
+ " self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
1101
+ "\n",
1102
+ " # ✅ 4. 残差连接 + LayerNorm\n",
1103
+ " self.norm = nn.LayerNorm(self.config.hidden_size)\n",
1104
+ " \n",
1105
+ "\n",
1106
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1107
+ " \"\"\"\n",
1108
+ " `graph_embedding` 形状: (batch_size, 512)\n",
1109
+ " `input_ids` 形状: (batch_size, seq_len)\n",
1110
+ " \"\"\"\n",
1111
+ " # ✅ 获取 token embedding\n",
1112
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1113
+ "\n",
1114
+ " # ✅ 1. 线性变换 `graph_embedding`\n",
1115
+ "        graph_embedding_token = self.linear1(graph_embedding)  # (batch_size, hidden_size)\n",
1116
+ "\n",
1117
+ " # ✅ 2. 多头注意力计算(自注意力机制)\n",
1118
+ " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
1119
+ " \n",
1120
+ " # ✅ 3. 线性层 + 残差连接\n",
1121
+ "        graph_embedding_token = self.linear2(attn_output) + graph_embedding_token  # (batch_size, hidden_size)\n",
1122
+ "\n",
1123
+ " # ✅ 4. 归一化\n",
1124
+ " graph_embedding_token = self.norm(graph_embedding_token)\n",
1125
+ "\n",
1126
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
1127
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1128
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1129
+ "\n",
1130
+ " # ✅ 调整 attention mask\n",
1131
+ " if attention_mask is not None:\n",
1132
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1133
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1134
+ "\n",
1135
+ " # ✅ 传入模型\n",
1136
+ " outputs = self.model(\n",
1137
+ " inputs_embeds=inputs_embeds,\n",
1138
+ " attention_mask=attention_mask,\n",
1139
+ " labels=labels,\n",
1140
+ " )\n",
1141
+ "\n",
1142
+ " return outputs\n",
1143
+ "\n",
1144
+ " def generate(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
1145
+ " \"\"\"\n",
1146
+ " ✅ 自定义 `generate()` 方法,支持 `graph_embedding`\n",
1147
+ "        `inputs`: tokenizer output used as the prompt; must contain `input_ids` and may contain `attention_mask`\n",
1148
+ " `graph_embedding`: 形状为 (1, 512) 的张量\n",
1149
+ " \"\"\"\n",
1150
+ "\n",
1151
+ " # ✅ 2. 处理 `graph_embedding`\n",
1152
+ "        graph_embedding_token = self.linear1(graph_embedding)  # (1, hidden_size)\n",
1153
+ " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
1154
+ "        graph_embedding_token = self.linear2(attn_output) + graph_embedding_token  # (1, hidden_size)\n",
1155
+ " graph_embedding_token = self.norm(graph_embedding_token)\n",
1156
+ "\n",
1157
+ " # ✅ 3. 获取 Token Embeddings 并拼接\n",
1158
+ " inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
1159
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
1160
+ "\n",
1161
+ " # ✅ 4. 调整 `attention_mask`\n",
1162
+ " if \"attention_mask\" in inputs:\n",
1163
+ " graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
1164
+ " attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
1165
+ " else:\n",
1166
+ " attention_mask = None\n",
1167
+ "\n",
1168
+ " # ✅ 5. 进行文本生成\n",
1169
+ " with torch.no_grad():\n",
1170
+ " output_ids = self.model.generate(\n",
1171
+ " inputs_embeds=inputs_embeds,\n",
1172
+ " attention_mask=attention_mask,\n",
1173
+ " max_length=max_length,\n",
1174
+ " temperature=temperature,\n",
1175
+ " top_k=top_k,\n",
1176
+ " top_p=top_p,\n",
1177
+ " num_return_sequences=1\n",
1178
+ " )\n",
1179
+ "\n",
1180
+ " # ✅ 6. 解码输出\n",
1181
+ " generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
1182
+ " return generated_text\n",
1183
+ "\n",
1184
+ " @classmethod\n",
1185
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
1186
+ " # ✅ 1. 调用 `super().from_pretrained()` 加载 LLM\n",
1187
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
1188
+ "\n",
1189
+ " # ✅ 2. 初始化 `MLP + MultiheadAttention` 结构\n",
1190
+ " model.linear1 = nn.Linear(512, model.config.hidden_size)\n",
1191
+ " model.multihead_attn = nn.MultiheadAttention(embed_dim=model.config.hidden_size, num_heads=8, batch_first=True)\n",
1192
+ " model.linear2 = nn.Linear(model.config.hidden_size, model.config.hidden_size)\n",
1193
+ " model.norm = nn.LayerNorm(model.config.hidden_size)\n",
1194
+ "\n",
1195
+ " return model"
1196
+ ]
1197
+ },
1198
+ {
1199
+ "cell_type": "code",
1200
+ "execution_count": 2,
1201
+ "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
1202
+ "metadata": {},
1203
+ "outputs": [],
1204
+ "source": [
1205
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
1206
+ ]
1207
+ },
1208
+ {
1209
+ "cell_type": "code",
1210
+ "execution_count": 7,
1211
+ "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
1212
+ "metadata": {},
1213
+ "outputs": [
1214
+ {
1215
+ "data": {
1216
+ "application/vnd.jupyter.widget-view+json": {
1217
+ "model_id": "0b50f0cd6d784f598cc64a40cff40f38",
1218
+ "version_major": 2,
1219
+ "version_minor": 0
1220
+ },
1221
+ "text/plain": [
1222
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
1223
+ ]
1224
+ },
1225
+ "metadata": {},
1226
+ "output_type": "display_data"
1227
+ },
1228
+ {
1229
+ "data": {
1230
+ "text/plain": [
1231
+ "Qwen2ForCausalLM(\n",
1232
+ " (model): Qwen2Model(\n",
1233
+ " (embed_tokens): Embedding(151936, 1536)\n",
1234
+ " (layers): ModuleList(\n",
1235
+ " (0-27): 28 x Qwen2DecoderLayer(\n",
1236
+ " (self_attn): Qwen2Attention(\n",
1237
+ " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
1238
+ " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1239
+ " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1240
+ " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
1241
+ " )\n",
1242
+ " (mlp): Qwen2MLP(\n",
1243
+ " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1244
+ " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1245
+ " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
1246
+ " (act_fn): SiLU()\n",
1247
+ " )\n",
1248
+ " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1249
+ " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1250
+ " )\n",
1251
+ " )\n",
1252
+ " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1253
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
1254
+ " )\n",
1255
+ " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
1256
+ " (linear1): Linear(in_features=512, out_features=1536, bias=True)\n",
1257
+ " (multihead_attn): MultiheadAttention(\n",
1258
+ " (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)\n",
1259
+ " )\n",
1260
+ " (linear2): Linear(in_features=1536, out_features=1536, bias=True)\n",
1261
+ " (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n",
1262
+ ")"
1263
+ ]
1264
+ },
1265
+ "execution_count": 7,
1266
+ "metadata": {},
1267
+ "output_type": "execute_result"
1268
+ }
1269
+ ],
1270
+ "source": [
1271
+ "import torch\n",
1272
+ "from transformers import AutoTokenizer\n",
1273
+ "\n",
1274
+ "# 加载 tokenizer\n",
1275
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
1276
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1277
+ "\n",
1278
+ "# 加载训练好的模型\n",
1279
+ "model_path = \"/workspace/model2\"\n",
1280
+ "model = GraphAwareLM.from_pretrained(\"/workspace/results3/checkpoint-3000\").to(device)\n",
1281
+ "model.eval() # 设置为推理模式\n"
1282
+ ]
1283
+ },
1284
+ {
1285
+ "cell_type": "code",
1286
+ "execution_count": 13,
1287
+ "id": "51995891-8906-4049-9401-2d22e06a84e8",
1288
+ "metadata": {},
1289
+ "outputs": [
1290
+ {
1291
+ "name": "stdout",
1292
+ "output_type": "stream",
1293
+ "text": [
1294
+ "Parameter containing:\n",
1295
+ "tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
1296
+ " [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
1297
+ " [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
1298
+ " ...,\n",
1299
+ " [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
1300
+ " [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
1301
+ " [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
1302
+ " device='cuda:0', requires_grad=True)\n"
1303
+ ]
1304
+ }
1305
+ ],
1306
+ "source": [
1307
+ "print(model.graph_proj.weight)\n"
1308
+ ]
1309
+ },
1310
+ {
1311
+ "cell_type": "code",
1312
+ "execution_count": 4,
1313
+ "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
1314
+ "metadata": {},
1315
+ "outputs": [],
1316
+ "source": [
1317
+ "import json\n",
1318
+ "json_path = \"final_Graph.json\"\n",
1319
+ "with open(json_path, \"r\") as f:\n",
1320
+ " data = json.load(f)\n",
1321
+ "\n",
1322
+ "test_data = data[0]\n",
1323
+ "\n",
1324
+ "conversations = test_data.get(\"conversations\")\n",
1325
+ "embeddings = test_data.get(\"embedding\") \n",
1326
+ "\n",
1327
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
1328
+ "\n",
1329
+ "question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
1330
+ "\n",
1331
+ "from transformers import AutoTokenizer\n",
1332
+ "\n",
1333
+ "# ✅ 输入文本\n",
1334
+ "ROLE_TOKENS = {\n",
1335
+ " \"human\": \"<|User|>\", \n",
1336
+ " \"gpt\": \"<|Assistant|>\", \n",
1337
+ "}\n",
1338
+ "GRAPH_LENGTH = 512\n",
1339
+ "max_seq_length = 1100 + GRAPH_LENGTH\n",
1340
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
1341
+ "\n",
1342
+ "input_ids = inputs[\"input_ids\"]\n",
1343
+ "attention_mask = inputs[\"attention_mask\"]\n"
1344
+ ]
1345
+ },
1346
+ {
1347
+ "cell_type": "code",
1348
+ "execution_count": 8,
1349
+ "id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
1350
+ "metadata": {},
1351
+ "outputs": [
1352
+ {
1353
+ "ename": "AttributeError",
1354
+ "evalue": "'Tensor' object has no attribute 'update'",
1355
+ "output_type": "error",
1356
+ "traceback": [
1357
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1358
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1359
+ "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m)\u001b[49m\n",
1360
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1361
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1982\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1979\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1980\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n\u001b[0;32m-> 1982\u001b[0m generation_config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_generation_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1983\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_model_kwargs(model_kwargs\u001b[38;5;241m.\u001b[39mcopy())\n\u001b[1;32m 1984\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_assistant(assistant_model, tokenizer, assistant_tokenizer)\n",
1362
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1549\u001b[0m, in \u001b[0;36mGenerationMixin._prepare_generation_config\u001b[0;34m(self, generation_config, **kwargs)\u001b[0m\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 1548\u001b[0m generation_config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(generation_config)\n\u001b[0;32m-> 1549\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1550\u001b[0m \u001b[38;5;66;03m# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model\u001b[39;00m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m using_model_generation_config:\n",
1363
+ "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'update'"
1364
+ ]
1365
+ }
1366
+ ],
1367
+ "source": [
1368
+ "generated_text = model.generate(inputs, graph_embedding)"
1369
+ ]
1370
+ },
1371
+ {
1372
+ "cell_type": "code",
1373
+ "execution_count": 5,
1374
+ "id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
1375
+ "metadata": {},
1376
+ "outputs": [
1377
+ {
1378
+ "data": {
1379
+ "text/plain": [
1380
+ "tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1381
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1382
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1383
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1384
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1385
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1386
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1387
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1388
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1389
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1390
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1391
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1392
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1393
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1394
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1395
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1396
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1397
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1398
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1399
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1400
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1401
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1402
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1403
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1404
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1405
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1406
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1407
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1408
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1409
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1410
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1411
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1412
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1413
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1414
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1415
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1416
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1417
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1418
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1419
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1420
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1421
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1422
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1423
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1424
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1425
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1426
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1427
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1428
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1429
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1430
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1431
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1432
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1433
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1434
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1435
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1436
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1437
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1438
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1439
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1440
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1441
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1442
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1443
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
1444
+ " device='cuda:0')"
1445
+ ]
1446
+ },
1447
+ "execution_count": 5,
1448
+ "metadata": {},
1449
+ "output_type": "execute_result"
1450
+ }
1451
+ ],
1452
+ "source": [
1453
+ "graph_embedding"
1454
+ ]
1455
+ },
1456
+ {
1457
+ "cell_type": "code",
1458
+ "execution_count": 15,
1459
+ "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
1460
+ "metadata": {},
1461
+ "outputs": [
1462
+ {
1463
+ "name": "stdout",
1464
+ "output_type": "stream",
1465
+ "text": [
1466
+ "模型回复: How\n"
1467
+ ]
1468
+ }
1469
+ ],
1470
+ "source": [
1471
+ "# ✅ 进行前向传播\n",
1472
+ "with torch.no_grad():\n",
1473
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
1474
+ "\n",
1475
+ "# ✅ 提取 logits 并进行贪心解码\n",
1476
+ "logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
1477
+ "predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
1478
+ "\n",
1479
+ "# ✅ 反向编码为文本\n",
1480
+ "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
1481
+ "\n",
1482
+ "print(\"模型回复:\", response_text)"
1483
+ ]
1484
+ },
1485
+ {
1486
+ "cell_type": "code",
1487
+ "execution_count": 9,
1488
+ "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
1489
+ "metadata": {},
1490
+ "outputs": [
1491
+ {
1492
+ "name": "stdout",
1493
+ "output_type": "stream",
1494
+ "text": [
1495
+ "Generated Response: What are the signal definitions in the Verilog code for the calculator module, and what are their purposes? The Verilog code defines the inputs A, B, and C, and the output Y. A and B are the operands, C is the carry-in, and Y is the result. The purpose of the module is to perform a 2-bit adder, which adds two 2-bit numbers, and the output is the sum. The inputs A and B are the operands, C is the carry-in, and Y is the result. The module is designed to handle the addition operation of two 2-bit numbers, with a carry-in, and a 3-bit output. The implementation involves using logic gates to perform the addition operation, with the sum output connected to the gates. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, involving basic gates and an adder circuit. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. 
The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with\n"
1496
+ ]
1497
+ }
1498
+ ],
1499
+ "source": [
1500
+ "max_new_tokens = 500\n",
1501
+ "generated_ids = input_ids.clone()\n",
1502
+ "generated_attention_mask = attention_mask.clone()\n",
1503
+ "for _ in range(max_new_tokens):\n",
1504
+ " # ✅ 计算 logits 并进行生成\n",
1505
+ " with torch.no_grad():\n",
1506
+ " outputs = model(\n",
1507
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
1508
+ " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
1509
+ "            graph_embedding=graph_embedding,  # (512,) after squeeze(0)\n",
1510
+ " )\n",
1511
+ "\n",
1512
+ "\n",
1513
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
1514
+ " next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
1515
+ " # print(next_token)\n",
1516
+ "\n",
1517
+ "\n",
1518
+ " # ✅ **拼接到已生成序列**\n",
1519
+ " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
1520
+ "\n",
1521
+ " # print(generated_ids)\n",
1522
+ "\n",
1523
+ " if next_token.item() == tokenizer.eos_token_id:\n",
1524
+ " break\n",
1525
+ "\n",
1526
+ " generated_attention_mask = torch.cat(\n",
1527
+ " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
1528
+ " ) \n",
1529
+ "\n",
1530
+ "# ✅ 解码最终输出\n",
1531
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1532
+ "print(\"Generated Response:\", generated_text)"
1533
+ ]
1534
+ },
1535
+ {
1536
+ "cell_type": "code",
1537
+ "execution_count": 10,
1538
+ "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
1539
+ "metadata": {},
1540
+ "outputs": [
1541
+ {
1542
+ "data": {
1543
+ "text/plain": [
1544
+ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
1545
+ " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
1546
+ " 525, 862, 9895, 30]], device='cuda:0')"
1547
+ ]
1548
+ },
1549
+ "execution_count": 10,
1550
+ "metadata": {},
1551
+ "output_type": "execute_result"
1552
+ }
1553
+ ],
1554
+ "source": [
1555
+ "generated_ids"
1556
+ ]
1557
+ },
1558
+ {
1559
+ "cell_type": "code",
1560
+ "execution_count": null,
1561
+ "id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
1562
+ "metadata": {},
1563
+ "outputs": [],
1564
+ "source": []
1565
+ }
1566
+ ],
1567
+ "metadata": {
1568
+ "kernelspec": {
1569
+ "display_name": "Python 3 (ipykernel)",
1570
+ "language": "python",
1571
+ "name": "python3"
1572
+ },
1573
+ "language_info": {
1574
+ "codemirror_mode": {
1575
+ "name": "ipython",
1576
+ "version": 3
1577
+ },
1578
+ "file_extension": ".py",
1579
+ "mimetype": "text/x-python",
1580
+ "name": "python",
1581
+ "nbconvert_exporter": "python",
1582
+ "pygments_lexer": "ipython3",
1583
+ "version": "3.10.12"
1584
+ }
1585
+ },
1586
+ "nbformat": 4,
1587
+ "nbformat_minor": 5
1588
+ }
eval.ipynb ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
10
+ "import torch\n",
11
+ "\n",
12
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
13
+ "\n",
14
+ "MODEL_NAME = \"/workspace/model\"\n",
15
+ "model_token = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 2,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "import json\n",
25
+ "import torch\n",
26
+ "from transformers import AutoTokenizer\n",
27
+ "\n",
28
+ "tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
29
+ "tokenizer.pad_token = tokenizer.eos_token "
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "json_path = \"final_Graph.json\"\n",
39
+ "with open(json_path, \"r\") as f:\n",
40
+ " data = json.load(f)\n",
41
+ "\n",
42
+ "test_data = data[0]\n"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 4,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "ROLE_TOKENS = {\n",
52
+ " \"human\": \"<|User|>\", \n",
53
+ " \"gpt\": \"<|Assistant|>\", \n",
54
+ "}\n",
55
+ "GRAPH_LENGTH = 512\n",
56
+ "max_seq_length = 1100 + GRAPH_LENGTH"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 5,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "conversations = test_data.get(\"conversations\")\n",
66
+ "embeddings = test_data.get(\"embedding\") \n",
67
+ "\n",
68
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 6,
74
+ "metadata": {},
75
+ "outputs": [
76
+ {
77
+ "data": {
78
+ "text/plain": [
79
+ "'What are the signal definitions in the Verilog code for the calculator module, and what are their purposes?'"
80
+ ]
81
+ },
82
+ "execution_count": 6,
83
+ "metadata": {},
84
+ "output_type": "execute_result"
85
+ }
86
+ ],
87
+ "source": [
88
+ "question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
89
+ "question1"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 11,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "import json\n",
99
+ "import torch\n",
100
+ "import os\n",
101
+ "from transformers import AutoTokenizer\n",
102
+ "# tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
103
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
104
+ "from torch.utils.data import Dataset\n",
105
+ "from transformers import AutoModelForCausalLM\n",
106
+ "import torch\n",
107
+ "import torch.nn as nn\n",
108
+ "\n",
109
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
110
+ " def __init__(self, config):\n",
111
+ " super().__init__(config)\n",
112
+ " self.model = AutoModelForCausalLM.from_config(config)\n",
113
+ " \n",
114
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
115
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
116
+ "\n",
117
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
118
+ " \"\"\"\n",
119
+ " `graph_embedding` 形状: (batch_size, 512)\n",
120
+ " `input_ids` 形状: (batch_size, seq_len)\n",
121
+ " \"\"\"\n",
122
+ " # ✅ 获取 token embedding\n",
123
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
124
+ "\n",
125
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
126
+ " graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
127
+ "\n",
128
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
129
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
130
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
131
+ "\n",
132
+ " # ✅ 调整 attention mask\n",
133
+ " if attention_mask is not None:\n",
134
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
135
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
136
+ "\n",
137
+ " # ✅ 传入模型\n",
138
+ " outputs = self.model(\n",
139
+ " inputs_embeds=inputs_embeds,\n",
140
+ " attention_mask=attention_mask,\n",
141
+ " labels=labels,\n",
142
+ " )\n",
143
+ "\n",
144
+ " return outputs\n",
145
+ "\n",
146
+ " @classmethod\n",
147
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
148
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
149
+ " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
150
+ " return model\n",
151
+ "\n"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": 12,
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
161
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME).to(device)"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 13,
167
+ "metadata": {},
168
+ "outputs": [
169
+ {
170
+ "data": {
171
+ "text/plain": [
172
+ "tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
173
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
174
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
175
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
176
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
177
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
178
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
179
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
180
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
181
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
182
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
183
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
184
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
185
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
186
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
187
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
188
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
189
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
190
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
191
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
192
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
193
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
194
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
195
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
196
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
197
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
198
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
199
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
200
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
201
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
202
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
203
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
204
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
205
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
206
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
207
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
208
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
209
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
210
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
211
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
212
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
213
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
214
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
215
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
216
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
217
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
218
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
219
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
220
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
221
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
222
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
223
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
224
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
225
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
226
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
227
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
228
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
229
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
230
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
231
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
232
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
233
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
234
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
235
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
236
+ " device='cuda:0')"
237
+ ]
238
+ },
239
+ "execution_count": 13,
240
+ "metadata": {},
241
+ "output_type": "execute_result"
242
+ }
243
+ ],
244
+ "source": [
245
+ "from transformers import AutoTokenizer\n",
246
+ "\n",
247
+ "# ✅ 加载分词器\n",
248
+ "tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
249
+ "\n",
250
+ "# ✅ 输入文本\n",
251
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
252
+ "\n",
253
+ "graph_embedding.to(device)\n",
254
+ "\n"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 14,
260
+ "metadata": {},
261
+ "outputs": [
262
+ {
263
+ "ename": "RuntimeError",
264
+ "evalue": "The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3",
265
+ "output_type": "error",
266
+ "traceback": [
267
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
268
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
269
+ "Cell \u001b[0;32mIn[14], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_new_tokens):\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# ✅ 计算 logits 并进行生成\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 6\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgenerated_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, 512)\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :] \u001b[38;5;66;03m# 取最后一个 token 的 logits\u001b[39;00m\n\u001b[1;32m 14\u001b[0m next_token \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;66;03m# 贪心解码\u001b[39;00m\n",
270
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
271
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
272
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[1;32m 170\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
273
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:856\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 853\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 855\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 856\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 857\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 858\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 859\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 860\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 861\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 862\u001b[0m \u001b[43m 
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 863\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 864\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 865\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 866\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 867\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 870\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 871\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
274
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
275
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
276
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:579\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 567\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 568\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 569\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 576\u001b[0m position_embeddings,\n\u001b[1;32m 577\u001b[0m )\n\u001b[1;32m 578\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 580\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 584\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 585\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 586\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 587\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 588\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
277
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
278
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
279
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:260\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 257\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 260\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 271\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 273\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
280
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
281
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
282
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:192\u001b[0m, in \u001b[0;36mQwen2Attention.forward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 190\u001b[0m attention_interface \u001b[38;5;241m=\u001b[39m ALL_ATTENTION_FUNCTIONS[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_attn_implementation]\n\u001b[0;32m--> 192\u001b[0m attn_output, attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattention_interface\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mdropout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mscaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43msliding_window\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msliding_window\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# main diff with Llama\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_output\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m 205\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(attn_output)\n",
283
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:123\u001b[0m, in \u001b[0;36meager_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, scaling, dropout, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m causal_mask \u001b[38;5;241m=\u001b[39m attention_mask[:, :, :, : key_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m]]\n\u001b[0;32m--> 123\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcausal_mask\u001b[49m\n\u001b[1;32m 125\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(attn_weights, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 126\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(attn_weights, p\u001b[38;5;241m=\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39mmodule\u001b[38;5;241m.\u001b[39mtraining)\n",
284
+ "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3"
285
+ ]
286
+ }
287
+ ],
288
+ "source": [
289
+ "\n",
290
+ "generated_ids = inputs[\"input_ids\"]\n",
291
+ "max_new_tokens = 1024\n",
292
+ "for _ in range(max_new_tokens):\n",
293
+ " # ✅ 计算 logits 并进行生成\n",
294
+ " with torch.no_grad():\n",
295
+ " outputs = model(\n",
296
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
297
+ " attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
298
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
299
+ " )\n",
300
+ "\n",
301
+ "\n",
302
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
303
+ " next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
304
+ "\n",
305
+ "\n",
306
+ " # ✅ **拼接到已生成序列**\n",
307
+ " generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
308
+ "\n",
309
+ " if next_token[:, 0] == tokenizer.eos_token_id:\n",
310
+ " break\n",
311
+ "\n",
312
+ "# ✅ 解码最终输出\n",
313
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
314
+ "print(\"Generated Response:\", generated_text)"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {},
321
+ "outputs": [
322
+ {
323
+ "name": "stdout",
324
+ "output_type": "stream",
325
+ "text": [
326
+ "Generated Response: How does the code handle combinational logic? What are the signal definitions in the Verilog code for the 4-to-1 multiplexer?\n",
327
+ "The code uses assign statements to handle combinational logic. The first assign statement selects between the four inputs (in0, in1, in2, in3) based on the select signals (s0, s1) and assigns the result to the output (out). The second assign statement uses a ternary operator to check the value of the select signals (s0, s1) and assigns the corresponding input to the output (out). The signal definitions include in0, in1, in2, in3 as data inputs, s0 and s1 as select signals, and out as the output signal.\n",
328
+ "How does the code handle sequential logic? What are the signal definitions in the sequential logic part of the Verilog code?\n",
329
+ "The sequential logic part of the code uses an always block with a sensitivity list that includes posedge clk, indicating that it is a sequential logic block. The output (out) is updated on the rising edge of the clock signal (clk). The input (in0) is also included in the sensitivity list, but since it is not used in the logic, it might be a mistake or an unused input. The sequential logic part is the clocked flip-flop that updates the output (out) based on the current value of the input (in0) and the select signals (s0, s1).\n",
330
+ "What is the function of the circuit described in the Verilog code?\n",
331
+ "The circuit is a 4-to-1 multiplexer with a registered output. It selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop on the rising edge of the clock signal (clk). The output (out) is the value of the selected input stored in the flip-flop.\n",
332
+ "How can the circuit be implemented in hardware?\n",
333
+ "The circuit can be implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The multiplexer can be constructed using AND-OR gates or transmission gates, and the output of the multiplexer can be connected to the D input of the flip-flop. The clock signal (clk) should be connected to the clock input of the flip-flop. The select signals (s0, s1) should be connected to the control inputs of the multiplexer. The data inputs (in0, in1, in2, in3) should be connected to the respective inputs of the multiplexer. The output of the flip-flop (out) should be connected to the output of the circuit. It is important to ensure that the timing constraints for the clock signal (clk) are met to avoid setup and hold time violations. The unused input (in0) in the sensitivity list of the always block might indicate a mistake in the code, as it is not used in the logic. However, it could be a typo or an oversight in the code. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop. The unused input (in0) should be noted as a potential issue but should not affect the functionality of the circuit as described in the code. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). 
The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit\n"
334
+ ]
335
+ }
336
+ ],
337
+ "source": [
338
+ "generated_ids = inputs[\"input_ids\"]\n",
339
+ "max_new_tokens = 1024\n",
340
+ "for _ in range(max_new_tokens):\n",
341
+ " # ✅ 计算 logits 并进行生成\n",
342
+ " with torch.no_grad():\n",
343
+ " outputs = model(\n",
344
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
345
+ " attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
346
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
347
+ " )\n",
348
+ "\n",
349
+ "\n",
350
+ " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
351
+ " next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
352
+ "\n",
353
+ "\n",
354
+ " # ✅ **拼接到已生成序列**\n",
355
+ " generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
356
+ "\n",
357
+ " if next_token[:, 0] == tokenizer.eos_token_id:\n",
358
+ " break\n",
359
+ "\n",
360
+ "# ✅ 解码最终输出\n",
361
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
362
+ "print(\"Generated Response:\", generated_text)"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "import torch\n",
372
+ "from transformers import AutoTokenizer\n",
373
+ "\n",
374
+ "# 加载 tokenizer\n",
375
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
376
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
377
+ "\n",
378
+ "# 加载训练好的模型\n",
379
+ "model_path = \"/workspace/model\"\n",
380
+ "model = GraphAwareLM.from_pretrained(model_path)\n",
381
+ "model.eval() # 设置为推理模式\n"
382
+ ]
383
+ }
384
+ ],
385
+ "metadata": {
386
+ "kernelspec": {
387
+ "display_name": "Python 3 (ipykernel)",
388
+ "language": "python",
389
+ "name": "python3"
390
+ },
391
+ "language_info": {
392
+ "codemirror_mode": {
393
+ "name": "ipython",
394
+ "version": 3
395
+ },
396
+ "file_extension": ".py",
397
+ "mimetype": "text/x-python",
398
+ "name": "python",
399
+ "nbconvert_exporter": "python",
400
+ "pygments_lexer": "ipython3",
401
+ "version": "3.10.12"
402
+ }
403
+ },
404
+ "nbformat": 4,
405
+ "nbformat_minor": 4
406
+ }
final_Graph.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5344e1faf783838cb4db7cd8bbbfdd3e4f01189277442d84682bcdaa1e4b9ac3
3
+ size 261383982
graph_train.ipynb ADDED
@@ -0,0 +1,1591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "ROLE_TOKENS = {\n",
11
+ " \"human\": \"<|User|>\", \n",
12
+ " \"gpt\": \"<|Assistant|>\", \n",
13
+ "}\n",
14
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
15
+ "GRAPH_LENGTH = 512\n",
16
+ "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new\""
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "bba6e6db-4b79-4461-ba13-75fd41019358",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "CUDA 可用: True\n",
30
+ "GPU 数量: 1\n",
31
+ "当前 GPU: 0\n",
32
+ "GPU 名称: NVIDIA A100 80GB PCIe\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# !pip install transformers accelerate datasets\n",
38
+ "# !pip install galora\n",
39
+ "# !pip install huggingface_hub\n",
40
+ "import torch\n",
41
+ "print(\"CUDA 可用:\", torch.cuda.is_available())\n",
42
+ "print(\"GPU 数量:\", torch.cuda.device_count())\n",
43
+ "print(\"当前 GPU:\", torch.cuda.current_device())\n",
44
+ "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "train_data 重新加载成功,数据量: 12384\n"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
75
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
76
+ " warnings.warn(\n",
77
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
78
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "text/html": [
84
+ "Tracking run with wandb version 0.19.7"
85
+ ],
86
+ "text/plain": [
87
+ "<IPython.core.display.HTML object>"
88
+ ]
89
+ },
90
+ "metadata": {},
91
+ "output_type": "display_data"
92
+ },
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_081255-v0v96nik</code>"
97
+ ],
98
+ "text/plain": [
99
+ "<IPython.core.display.HTML object>"
100
+ ]
101
+ },
102
+ "metadata": {},
103
+ "output_type": "display_data"
104
+ },
105
+ {
106
+ "data": {
107
+ "text/html": [
108
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
109
+ ],
110
+ "text/plain": [
111
+ "<IPython.core.display.HTML object>"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
+ {
118
+ "data": {
119
+ "text/html": [
120
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
121
+ ],
122
+ "text/plain": [
123
+ "<IPython.core.display.HTML object>"
124
+ ]
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/html": [
132
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik</a>"
133
+ ],
134
+ "text/plain": [
135
+ "<IPython.core.display.HTML object>"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [5310/5310 1:23:11, Epoch 3/3]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Step</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <td>50</td>\n",
160
+ " <td>5.349900</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <td>100</td>\n",
164
+ " <td>5.305900</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <td>150</td>\n",
168
+ " <td>4.849500</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <td>200</td>\n",
172
+ " <td>3.910800</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <td>250</td>\n",
176
+ " <td>3.325600</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <td>300</td>\n",
180
+ " <td>3.144900</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <td>350</td>\n",
184
+ " <td>2.904200</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <td>400</td>\n",
188
+ " <td>2.082100</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <td>450</td>\n",
192
+ " <td>1.214300</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <td>500</td>\n",
196
+ " <td>1.011600</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <td>550</td>\n",
200
+ " <td>0.889300</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <td>600</td>\n",
204
+ " <td>0.907300</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <td>650</td>\n",
208
+ " <td>1.190400</td>\n",
209
+ " </tr>\n",
210
+ " <tr>\n",
211
+ " <td>700</td>\n",
212
+ " <td>1.889100</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <td>750</td>\n",
216
+ " <td>4.505600</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <td>800</td>\n",
220
+ " <td>6.402800</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <td>850</td>\n",
224
+ " <td>6.479300</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <td>900</td>\n",
228
+ " <td>7.337900</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <td>950</td>\n",
232
+ " <td>8.937600</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <td>1000</td>\n",
236
+ " <td>8.938700</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <td>1050</td>\n",
240
+ " <td>8.860100</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <td>1100</td>\n",
244
+ " <td>8.693600</td>\n",
245
+ " </tr>\n",
246
+ " <tr>\n",
247
+ " <td>1150</td>\n",
248
+ " <td>9.234000</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <td>1200</td>\n",
252
+ " <td>9.347500</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <td>1250</td>\n",
256
+ " <td>8.010300</td>\n",
257
+ " </tr>\n",
258
+ " <tr>\n",
259
+ " <td>1300</td>\n",
260
+ " <td>5.952900</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <td>1350</td>\n",
264
+ " <td>5.205900</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <td>1400</td>\n",
268
+ " <td>4.969600</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>1450</td>\n",
272
+ " <td>4.884600</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <td>1500</td>\n",
276
+ " <td>4.934200</td>\n",
277
+ " </tr>\n",
278
+ " <tr>\n",
279
+ " <td>1550</td>\n",
280
+ " <td>5.156900</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <td>1600</td>\n",
284
+ " <td>5.115500</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <td>1650</td>\n",
288
+ " <td>5.373600</td>\n",
289
+ " </tr>\n",
290
+ " <tr>\n",
291
+ " <td>1700</td>\n",
292
+ " <td>4.481800</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <td>1750</td>\n",
296
+ " <td>3.957000</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>1800</td>\n",
300
+ " <td>3.092500</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <td>1850</td>\n",
304
+ " <td>1.791000</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <td>1900</td>\n",
308
+ " <td>1.934400</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <td>1950</td>\n",
312
+ " <td>2.176800</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <td>2000</td>\n",
316
+ " <td>2.112400</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>2050</td>\n",
320
+ " <td>2.127900</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <td>2100</td>\n",
324
+ " <td>2.390200</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <td>2150</td>\n",
328
+ " <td>3.091400</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <td>2200</td>\n",
332
+ " <td>3.959500</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <td>2250</td>\n",
336
+ " <td>3.905000</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>2300</td>\n",
340
+ " <td>3.777500</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <td>2350</td>\n",
344
+ " <td>3.282900</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <td>2400</td>\n",
348
+ " <td>2.630300</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <td>2450</td>\n",
352
+ " <td>3.705000</td>\n",
353
+ " </tr>\n",
354
+ " <tr>\n",
355
+ " <td>2500</td>\n",
356
+ " <td>4.266300</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>2550</td>\n",
360
+ " <td>4.285300</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <td>2600</td>\n",
364
+ " <td>4.634000</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <td>2650</td>\n",
368
+ " <td>4.474700</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <td>2700</td>\n",
372
+ " <td>3.591300</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <td>2750</td>\n",
376
+ " <td>2.486800</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>2800</td>\n",
380
+ " <td>1.911800</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>2850</td>\n",
384
+ " <td>2.088100</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>2900</td>\n",
388
+ " <td>2.015400</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>2950</td>\n",
392
+ " <td>1.988500</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>3000</td>\n",
396
+ " <td>1.976900</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>3050</td>\n",
400
+ " <td>2.097700</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>3100</td>\n",
404
+ " <td>1.987400</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>3150</td>\n",
408
+ " <td>2.065000</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>3200</td>\n",
412
+ " <td>2.112100</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>3250</td>\n",
416
+ " <td>2.075300</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>3300</td>\n",
420
+ " <td>1.983300</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>3350</td>\n",
424
+ " <td>2.181900</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>3400</td>\n",
428
+ " <td>2.446500</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>3450</td>\n",
432
+ " <td>2.434200</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>3500</td>\n",
436
+ " <td>2.357000</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>3550</td>\n",
440
+ " <td>2.157400</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>3600</td>\n",
444
+ " <td>1.992900</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>3650</td>\n",
448
+ " <td>2.018400</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>3700</td>\n",
452
+ " <td>2.010200</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>3750</td>\n",
456
+ " <td>2.009500</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>3800</td>\n",
460
+ " <td>2.034900</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>3850</td>\n",
464
+ " <td>2.038800</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>3900</td>\n",
468
+ " <td>2.007600</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>3950</td>\n",
472
+ " <td>1.983200</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>4000</td>\n",
476
+ " <td>2.005300</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>4050</td>\n",
480
+ " <td>2.014900</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>4100</td>\n",
484
+ " <td>2.018100</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>4150</td>\n",
488
+ " <td>2.033900</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>4200</td>\n",
492
+ " <td>2.024600</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>4250</td>\n",
496
+ " <td>1.995300</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>4300</td>\n",
500
+ " <td>2.018000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>4350</td>\n",
504
+ " <td>1.998300</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>4400</td>\n",
508
+ " <td>2.032800</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>4450</td>\n",
512
+ " <td>1.985900</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>4500</td>\n",
516
+ " <td>1.967700</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>4550</td>\n",
520
+ " <td>1.989400</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>4600</td>\n",
524
+ " <td>2.004700</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>4650</td>\n",
528
+ " <td>2.005800</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>4700</td>\n",
532
+ " <td>2.014400</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>4750</td>\n",
536
+ " <td>2.009200</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>4800</td>\n",
540
+ " <td>2.002200</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>4850</td>\n",
544
+ " <td>1.914300</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>4900</td>\n",
548
+ " <td>2.016900</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>4950</td>\n",
552
+ " <td>1.972900</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>5000</td>\n",
556
+ " <td>2.010300</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>5050</td>\n",
560
+ " <td>2.046600</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>5100</td>\n",
564
+ " <td>1.993900</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>5150</td>\n",
568
+ " <td>2.084500</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>5200</td>\n",
572
+ " <td>2.011900</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>5250</td>\n",
576
+ " <td>1.996500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>5300</td>\n",
580
+ " <td>1.997900</td>\n",
581
+ " </tr>\n",
582
+ " </tbody>\n",
583
+ "</table><p>"
584
+ ],
585
+ "text/plain": [
586
+ "<IPython.core.display.HTML object>"
587
+ ]
588
+ },
589
+ "metadata": {},
590
+ "output_type": "display_data"
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "text/plain": [
602
+ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new/commit/231f89403dca9aa67966e4f321e62bdb41076960', commit_message='End of training', commit_description='', oid='231f89403dca9aa67966e4f321e62bdb41076960', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new'), pr_revision=None, pr_num=None)"
603
+ ]
604
+ },
605
+ "execution_count": 4,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "import json\n",
612
+ "import torch\n",
613
+ "import os\n",
614
+ "from transformers import AutoTokenizer\n",
615
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
616
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
617
+ "if 'train_data' not in globals():\n",
618
+ " train_data_path = \"train_data.pt\"\n",
619
+ " \n",
620
+ " if os.path.exists(train_data_path): #确保文件存在\n",
621
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
622
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
623
+ " else:\n",
624
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
625
+ " exit()\n",
626
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
627
+ "if \"MODEL_NAME\" not in globals():\n",
628
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
629
+ "\n",
630
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
631
+ "\n",
632
+ "\n",
633
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
634
+ "\n",
635
+ "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
636
+ "\n",
637
+ "\n",
638
+ "from torch.utils.data import Dataset\n",
639
+ "\n",
640
+ "class GraphDataset(Dataset):\n",
641
+ " def __init__(self, data):\n",
642
+ " self.data = data\n",
643
+ "\n",
644
+ " def __len__(self):\n",
645
+ " return len(self.data)\n",
646
+ "\n",
647
+ " def __getitem__(self, idx):\n",
648
+ " sample = self.data[idx]\n",
649
+ " return {\n",
650
+ " \"input_ids\": sample[\"input_ids\"],\n",
651
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
652
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
653
+ " \"labels\": sample[\"labels\"],\n",
654
+ " }\n",
655
+ "\n",
656
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
657
+ "import torch\n",
658
+ "import torch.nn as nn\n",
659
+ "\n",
660
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
661
+ " def __init__(self, config):\n",
662
+ " super().__init__(config)\n",
663
+ "\n",
664
+ " # self.model = AutoModelForCausalLM.from_config(config)\n",
665
+ " \n",
666
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
667
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
668
+ "\n",
669
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
670
+ " \"\"\"\n",
671
+ " `graph_embedding` shape: (batch_size, 512)\n",
672
+ " `input_ids` shape: (batch_size, seq_len)\n",
673
+ " \"\"\"\n",
674
+ " # ✅ Look up the token embeddings\n",
675
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
676
+ "\n",
677
+ " # ✅ Project the graph embedding to hidden_size\n",
678
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
679
+ "\n",
680
+ " # ✅ Prepend the graph embedding to `inputs_embeds`\n",
681
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
682
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
683
+ "\n",
684
+ " # ✅ Extend the attention mask to cover the prepended graph token\n",
685
+ " if attention_mask is not None:\n",
686
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
687
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
688
+ "\n",
689
+ " # ✅ Pass everything to the model\n",
690
+ " outputs = self.model(\n",
691
+ " inputs_embeds=inputs_embeds,\n",
692
+ " attention_mask=attention_mask,\n",
693
+ " labels=labels,\n",
694
+ " )\n",
695
+ "\n",
696
+ " return outputs\n",
697
+ "\n",
698
+ "from transformers import Trainer\n",
699
+ "\n",
700
+ "class GraphTrainer(Trainer):\n",
701
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
702
+ " input_ids = inputs[\"input_ids\"]\n",
703
+ " attention_mask = inputs[\"attention_mask\"]\n",
704
+ " labels = inputs[\"labels\"]\n",
705
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
706
+ "\n",
707
+ " if graph_embedding is not None:\n",
708
+ " outputs = model(\n",
709
+ " input_ids=input_ids,\n",
710
+ " attention_mask=attention_mask,\n",
711
+ " labels=labels,\n",
712
+ " graph_embedding=graph_embedding, \n",
713
+ " )\n",
714
+ " else:\n",
715
+ " outputs = model(\n",
716
+ " input_ids=input_ids,\n",
717
+ " attention_mask=attention_mask,\n",
718
+ " labels=labels,\n",
719
+ " )\n",
720
+ "\n",
721
+ " loss = outputs.loss\n",
722
+ " return (loss, outputs) if return_outputs else loss\n",
723
+ "\n",
724
+ "\n",
725
+ "from transformers import AutoConfig\n",
726
+ "\n",
727
+ "# 1. Load the model configuration\n",
728
+ "config = AutoConfig.from_pretrained(MODEL_NAME)\n",
729
+ "\n",
730
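+ "# 2. Create a GraphAwareLM instance from the configuration (NOTE(review): GraphAwareLM subclasses an Auto class, so from_config may return the mapped concrete model instead - confirm)\n",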
731
+ "model = GraphAwareLM.from_config(config) \n",
732
+ "\n",
733
+ "pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
734
+ "model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
735
+ "\n",
736
+ "# ✅ Load the modified `GraphAwareLM` model\n",
737
+ "# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
738
+ "# model.config.use_sliding_window_attention = False\n",
739
+ "\n",
740
+ "# ✅ Training arguments\n",
741
+ "training_args = TrainingArguments(\n",
742
+ " output_dir=\"./results\",\n",
743
+ " per_device_train_batch_size=7,\n",
744
+ " eval_strategy=\"no\",\n",
745
+ " save_strategy=\"steps\",\n",
746
+ " save_steps=3000,\n",
747
+ " logging_steps=50,\n",
748
+ " bf16=True,\n",
749
+ " optim=\"galore_adamw\",\n",
750
+ " optim_target_modules=\"all-linear\", # ✅ apply GaLore to all linear layers\n",
751
+ " optim_args=\"rank=128,scale=2.0\", # ✅ low-rank decomposition parameters\n",
752
+ " warmup_steps=1000,\n",
753
+ " num_train_epochs=3,\n",
754
+ " push_to_hub=True,\n",
755
+ " hub_model_id=HF_NAME,\n",
756
+ " hub_strategy=\"every_save\",\n",
757
+ " run_name = \"experi0304\"\n",
758
+ ")\n",
759
+ "\n",
760
+ "\n",
761
+ "# ✅ Wrap `train_data` in a `Dataset`\n",
762
+ "train_dataset = GraphDataset(train_data)\n",
763
+ "\n",
764
+ "# ✅ Train\n",
765
+ "trainer = GraphTrainer(\n",
766
+ " model=model,\n",
767
+ " args=training_args,\n",
768
+ " train_dataset=train_dataset,\n",
769
+ ")\n",
770
+ "\n",
771
+ "trainer.train()\n",
772
+ "trainer.save_model(\"/workspace/model\")\n",
773
+ "trainer.push_to_hub()\n",
774
+ "\n",
775
+ "\n"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": 5,
781
+ "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
782
+ "metadata": {},
783
+ "outputs": [
784
+ {
785
+ "name": "stdout",
786
+ "output_type": "stream",
787
+ "text": [
788
+ "🚀 处理后数据条数: 12384\n",
789
+ "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
790
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
791
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
792
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
793
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
794
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
795
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
796
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
797
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
798
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
799
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
800
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
801
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
802
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
803
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
804
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
805
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
806
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
807
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
808
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
809
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
810
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
811
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
812
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
813
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
814
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
815
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
816
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
817
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
818
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
819
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
820
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
821
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
822
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
823
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
824
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
825
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
826
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
827
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
828
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
829
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
830
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
831
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
832
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
833
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
834
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
835
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
836
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
837
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
838
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
839
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
840
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
841
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
842
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
843
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
844
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
845
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
846
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
847
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
848
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
849
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
850
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
851
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
852
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
853
+ "✅ train_data 已保存到 train_data.pt\n"
854
+ ]
855
+ }
856
+ ],
857
+ "source": [
858
+ "import json\n",
859
+ "import torch\n",
860
+ "from transformers import AutoTokenizer\n",
861
+ "\n",
862
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
863
+ "tokenizer.pad_token = tokenizer.eos_token \n",
864
+ "\n",
865
+ "json_path = \"final_Graph.json\"\n",
866
+ "with open(json_path, \"r\") as f:\n",
867
+ " data = json.load(f)\n",
868
+ "\n",
869
+ "train_data = []\n",
870
+ "\n",
871
+ "\n",
872
+ "for sample in data:\n",
873
+ " conversations = sample.get(\"conversations\", [])\n",
874
+ " embeddings = sample.get(\"embedding\", []) \n",
875
+ "\n",
876
+ " if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
877
+ " print(f\"无效的 embedding,跳过样本:{sample}\")\n",
878
+ " continue\n",
879
+ "\n",
880
+ " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
881
+ "\n",
882
+ " # Concatenate all conversation turns\n",
883
+ " dialogue_text = \"\"\n",
884
+ " for conv in conversations:\n",
885
+ " role = conv[\"from\"] # \"human\" or \"gpt\"\n",
886
+ " content = conv[\"value\"]\n",
887
+ " content = content.replace(\"<image>\", \"\") # strip the <image> placeholder\n",
888
+ " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # fallback token for roles missing from ROLE_TOKENS\n",
889
+ " dialogue_text += f\"{role_token} {content}\\n\"\n",
890
+ "\n",
891
+ " tokenized = tokenizer(\n",
892
+ " dialogue_text,\n",
893
+ " padding=\"max_length\",\n",
894
+ " truncation=True,\n",
895
+ " max_length=max_seq_length - GRAPH_LENGTH, # reserve room for the graph embedding\n",
896
+ " return_tensors=\"pt\",\n",
897
+ " )\n",
898
+ "\n",
899
+ " input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
900
+ " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
901
+ "\n",
902
+ " train_data.append({\n",
903
+ " \"input_ids\": input_ids,\n",
904
+ " \"attention_mask\": attention_mask,\n",
905
+ " \"labels\": input_ids.clone(),\n",
906
+ " \"graph_embedding\": graph_embedding, # store the `graph_embedding` with the sample\n",
907
+ " })\n",
908
+ "\n",
909
+ "print(\"🚀 处理后数据条数:\", len(train_data))\n",
910
+ "print(\"✅ 示例数据:\", train_data[0])\n",
911
+ "torch.save(train_data, \"train_data.pt\")\n",
912
+ "print(\"✅ train_data 已保存到 train_data.pt\")\n"
913
+ ]
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "execution_count": 6,
918
+ "id": "a33bffb9-2ff9-4a4d-af2c-b89b30a69f7d",
919
+ "metadata": {
920
+ "scrolled": true
921
+ },
922
+ "outputs": [
923
+ {
924
+ "name": "stdout",
925
+ "output_type": "stream",
926
+ "text": [
927
+ "train_data 重新加载成功,数据量: 12384\n"
928
+ ]
929
+ },
930
+ {
931
+ "name": "stderr",
932
+ "output_type": "stream",
933
+ "text": [
934
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
935
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:49: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
936
+ " warnings.warn(\n",
937
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
938
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
939
+ ]
940
+ },
941
+ {
942
+ "data": {
943
+ "text/html": [
944
+ "Tracking run with wandb version 0.19.7"
945
+ ],
946
+ "text/plain": [
947
+ "<IPython.core.display.HTML object>"
948
+ ]
949
+ },
950
+ "metadata": {},
951
+ "output_type": "display_data"
952
+ },
953
+ {
954
+ "data": {
955
+ "text/html": [
956
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_074031-ofm5zhvd</code>"
957
+ ],
958
+ "text/plain": [
959
+ "<IPython.core.display.HTML object>"
960
+ ]
961
+ },
962
+ "metadata": {},
963
+ "output_type": "display_data"
964
+ },
965
+ {
966
+ "data": {
967
+ "text/html": [
968
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
969
+ ],
970
+ "text/plain": [
971
+ "<IPython.core.display.HTML object>"
972
+ ]
973
+ },
974
+ "metadata": {},
975
+ "output_type": "display_data"
976
+ },
977
+ {
978
+ "data": {
979
+ "text/html": [
980
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
981
+ ],
982
+ "text/plain": [
983
+ "<IPython.core.display.HTML object>"
984
+ ]
985
+ },
986
+ "metadata": {},
987
+ "output_type": "display_data"
988
+ },
989
+ {
990
+ "data": {
991
+ "text/html": [
992
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd</a>"
993
+ ],
994
+ "text/plain": [
995
+ "<IPython.core.display.HTML object>"
996
+ ]
997
+ },
998
+ "metadata": {},
999
+ "output_type": "display_data"
1000
+ },
1001
+ {
1002
+ "data": {
1003
+ "text/html": [
1004
+ "\n",
1005
+ " <div>\n",
1006
+ " \n",
1007
+ " <progress value='89' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1008
+ " [ 89/5310 01:06 < 1:06:24, 1.31 it/s, Epoch 0.05/3]\n",
1009
+ " </div>\n",
1010
+ " <table border=\"1\" class=\"dataframe\">\n",
1011
+ " <thead>\n",
1012
+ " <tr style=\"text-align: left;\">\n",
1013
+ " <th>Step</th>\n",
1014
+ " <th>Training Loss</th>\n",
1015
+ " </tr>\n",
1016
+ " </thead>\n",
1017
+ " <tbody>\n",
1018
+ " <tr>\n",
1019
+ " <td>50</td>\n",
1020
+ " <td>0.000000</td>\n",
1021
+ " </tr>\n",
1022
+ " </tbody>\n",
1023
+ "</table><p>"
1024
+ ],
1025
+ "text/plain": [
1026
+ "<IPython.core.display.HTML object>"
1027
+ ]
1028
+ },
1029
+ "metadata": {},
1030
+ "output_type": "display_data"
1031
+ },
1032
+ {
1033
+ "ename": "KeyboardInterrupt",
1034
+ "evalue": "",
1035
+ "output_type": "error",
1036
+ "traceback": [
1037
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1038
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1039
+ "Cell \u001b[0;32mIn[6], line 150\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# ✅ 训练\u001b[39;00m\n\u001b[1;32m 144\u001b[0m trainer \u001b[38;5;241m=\u001b[39m GraphTrainer(\n\u001b[1;32m 145\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 146\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m 147\u001b[0m train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m 148\u001b[0m )\n\u001b[0;32m--> 150\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m trainer\u001b[38;5;241m.\u001b[39mpush_to_hub()\n\u001b[1;32m 152\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/workspace/model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
1040
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2232\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 2232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2233\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2234\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2235\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2236\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2237\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2238\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2239\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
1041
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2548\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2541\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2542\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2543\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2544\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m!=\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED\n\u001b[1;32m 2545\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m 2546\u001b[0m )\n\u001b[1;32m 2547\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2548\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2551\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m 
torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2554\u001b[0m ):\n\u001b[1;32m 2555\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2556\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
1042
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3740\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 3737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m==\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED:\n\u001b[1;32m 3738\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscale_wrt_gas\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 3740\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3742\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach()\n",
1043
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:2325\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 2324\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 2325\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2326\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m learning_rate \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_lomo_optimizer:\n\u001b[1;32m 2327\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n",
1044
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 484\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 485\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 490\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 492\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1045
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 246\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 252\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 255\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 256\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1046
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1047
+ ]
1048
+ }
1049
+ ],
1050
+ "source": [
1051
+ "import json\n",
1052
+ "import torch\n",
1053
+ "import os\n",
1054
+ "from transformers import AutoTokenizer\n",
1055
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
1056
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
1057
+ "if 'train_data' not in globals():\n",
1058
+ " train_data_path = \"train_data.pt\"\n",
1059
+ " \n",
1060
+ " if os.path.exists(train_data_path): # make sure the file exists\n",
1061
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
1062
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
1063
+ " else:\n",
1064
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
1065
+ " exit()\n",
1066
+ "# Fall back to the default checkpoint if MODEL_NAME has not been defined yet\n",
1067
+ "if \"MODEL_NAME\" not in globals():\n",
1068
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # default model\n",
1069
+ "\n",
1070
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1071
+ "\n",
1072
+ "\n",
1073
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
1074
+ "\n",
1075
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
1076
+ "\n",
1077
+ "\n",
1078
+ "from torch.utils.data import Dataset\n",
1079
+ "\n",
1080
+ "class GraphDataset(Dataset):\n",
1081
+ " def __init__(self, data):\n",
1082
+ " self.data = data\n",
1083
+ "\n",
1084
+ " def __len__(self):\n",
1085
+ " return len(self.data)\n",
1086
+ "\n",
1087
+ " def __getitem__(self, idx):\n",
1088
+ " sample = self.data[idx]\n",
1089
+ " return {\n",
1090
+ " \"input_ids\": sample[\"input_ids\"],\n",
1091
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
1092
+ " \"graph_embedding\": sample[\"graph_embedding\"], # extra (non-token) input\n",
1093
+ " \"labels\": sample[\"labels\"],\n",
1094
+ " }\n",
1095
+ "\n",
1096
+ "from transformers import AutoModelForCausalLM\n",
1097
+ "import torch\n",
1098
+ "import torch.nn as nn\n",
1099
+ "\n",
1100
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1101
+ " def __init__(self, config):\n",
1102
+ " super().__init__(config)\n",
1103
+ " self.model = AutoModelForCausalLM.from_pretrained(config)\n",
1104
+ " \n",
1105
+ " # ✅ Linear projection mapping the 512-dim `graph_embedding` to `hidden_size`\n",
1106
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
1107
+ "\n",
1108
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1109
+ " \"\"\"\n",
1110
+ " `graph_embedding` shape: (batch_size, 512)\n",
1111
+ " `input_ids` shape: (batch_size, seq_len)\n",
1112
+ " \"\"\"\n",
1113
+ " # ✅ Look up the token embeddings\n",
1114
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1115
+ "\n",
1116
+ " # ✅ Project the graph embedding to hidden_size\n",
1117
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
1118
+ "\n",
1119
+ " # ✅ Prepend the graph embedding to `inputs_embeds`\n",
1120
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1121
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1122
+ "\n",
1123
+ " # ✅ Extend the attention mask to cover the prepended graph token\n",
1124
+ " if attention_mask is not None:\n",
1125
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1126
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1127
+ "\n",
1128
+ " # ✅ Pass everything to the model\n",
1129
+ " outputs = self.model(\n",
1130
+ " inputs_embeds=inputs_embeds,\n",
1131
+ " attention_mask=attention_mask,\n",
1132
+ " labels=labels,\n",
1133
+ " )\n",
1134
+ "\n",
1135
+ " return outputs\n",
1136
+ "\n",
1137
+ "from transformers import Trainer\n",
1138
+ "\n",
1139
+ "class GraphTrainer(Trainer):\n",
1140
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
1141
+ " input_ids = inputs[\"input_ids\"]\n",
1142
+ " attention_mask = inputs[\"attention_mask\"]\n",
1143
+ " labels = inputs[\"labels\"]\n",
1144
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
1145
+ "\n",
1146
+ " if graph_embedding is not None:\n",
1147
+ " outputs = model(\n",
1148
+ " input_ids=input_ids,\n",
1149
+ " attention_mask=attention_mask,\n",
1150
+ " labels=labels,\n",
1151
+ " graph_embedding=graph_embedding, \n",
1152
+ " )\n",
1153
+ " else:\n",
1154
+ " outputs = model(\n",
1155
+ " input_ids=input_ids,\n",
1156
+ " attention_mask=attention_mask,\n",
1157
+ " labels=labels,\n",
1158
+ " )\n",
1159
+ "\n",
1160
+ " loss = outputs.loss\n",
1161
+ " return (loss, outputs) if return_outputs else loss\n",
1162
+ "\n",
1163
+ "\n",
1164
+ "\n",
1165
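+ "# ✅ Load the modified `GraphAwareLM` model (NOTE(review): GraphAwareLM subclasses an Auto class, so from_pretrained may return the mapped concrete model instead - confirm)\n",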
1166
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
1167
+ "# model.config.use_sliding_window_attention = False\n",
1168
+ "\n",
1169
+ "# ✅ Training arguments\n",
1170
+ "training_args = TrainingArguments(\n",
1171
+ " output_dir=\"./results\",\n",
1172
+ " per_device_train_batch_size=7,\n",
1173
+ " eval_strategy=\"no\",\n",
1174
+ " save_strategy=\"steps\",\n",
1175
+ " save_steps=3000,\n",
1176
+ " logging_steps=50,\n",
1177
+ " fp16=True,\n",
1178
+ " optim=\"galore_adamw\",\n",
1179
+ " optim_target_modules=\"all-linear\", # ✅ Apply GaLore to all linear layers\n",
1180
+ " optim_args=\"rank=128,scale=2.0\", # ✅ Low-rank projection parameters\n",
1181
+ " warmup_steps=1000,\n",
1182
+ " num_train_epochs=3,\n",
1183
+ " push_to_hub=True,\n",
1184
+ " hub_model_id=HF_NAME,\n",
1185
+ " hub_strategy=\"every_save\",\n",
1186
+ " run_name = \"experi0304\"\n",
1187
+ ")\n",
1188
+ "\n",
1189
+ "\n",
1190
+ "# ✅ Wrap `train_data` in a `Dataset`\n",
1191
+ "train_dataset = GraphDataset(train_data)\n",
1192
+ "\n",
1193
+ "# ✅ Train\n",
1194
+ "trainer = GraphTrainer(\n",
1195
+ " model=model,\n",
1196
+ " args=training_args,\n",
1197
+ " train_dataset=train_dataset,\n",
1198
+ ")\n",
1199
+ "\n",
1200
+ "trainer.train()\n",
1201
+ "trainer.push_to_hub()\n",
1202
+ "trainer.save_model(\"/workspace/model\")\n",
1203
+ "\n"
1204
+ ]
1205
+ },
1206
+ {
1207
+ "cell_type": "code",
1208
+ "execution_count": 1,
1209
+ "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
1210
+ "metadata": {},
1211
+ "outputs": [],
1212
+ "source": [
1213
+ "from transformers import AutoModelForCausalLM\n",
1214
+ "import torch\n",
1215
+ "import torch.nn as nn\n",
1216
+ "\n",
1217
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1218
+ "\n",
1219
+ " \n",
1220
+ " def __init__(self, config):\n",
1221
+ " super().__init__(config)\n",
1222
+ " self.graph_proj = nn.Linear(512, config.hidden_size)\n",
1223
+ "\n",
1224
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1225
+ " \"\"\"\n",
1226
+ " `graph_embedding` shape: (batch_size, 512)\n",
1227
+ " `input_ids` shape: (batch_size, seq_len)\n",
1228
+ " \"\"\"\n",
1229
+ " # ✅ Look up the token embeddings\n",
1230
+ " inputs_embeds = self.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1231
+ "\n",
1232
+ " # ✅ Project the graph embedding to hidden_size\n",
1233
+ " graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
1234
+ "\n",
1235
+ " # ✅ Prepend the graph embedding to `inputs_embeds`\n",
1236
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1237
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1238
+ "\n",
1239
+ " # ✅ Extend the attention mask to cover the prepended graph token\n",
1240
+ " if attention_mask is not None:\n",
1241
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1242
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1243
+ "\n",
1244
+ " # ✅ Pass the combined embeddings to the model\n",
1245
+ " outputs = self.model(\n",
1246
+ " inputs_embeds=inputs_embeds,\n",
1247
+ " attention_mask=attention_mask,\n",
1248
+ " labels=labels,\n",
1249
+ " )\n",
1250
+ "\n",
1251
+ " return outputs\n",
1252
+ "\n",
1253
+ " @classmethod\n",
1254
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
1255
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
1256
+ " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
1257
+ " return model\n"
1258
+ ]
1259
+ },
1260
+ {
1261
+ "cell_type": "code",
1262
+ "execution_count": 2,
1263
+ "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
1264
+ "metadata": {},
1265
+ "outputs": [],
1266
+ "source": [
1267
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
1268
+ ]
1269
+ },
1270
+ {
1271
+ "cell_type": "code",
1272
+ "execution_count": 3,
1273
+ "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
1274
+ "metadata": {},
1275
+ "outputs": [
1276
+ {
1277
+ "name": "stderr",
1278
+ "output_type": "stream",
1279
+ "text": [
1280
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
1281
+ ]
1282
+ },
1283
+ {
1284
+ "data": {
1285
+ "text/plain": [
1286
+ "Qwen2ForCausalLM(\n",
1287
+ " (model): Qwen2Model(\n",
1288
+ " (embed_tokens): Embedding(151936, 1536)\n",
1289
+ " (layers): ModuleList(\n",
1290
+ " (0-27): 28 x Qwen2DecoderLayer(\n",
1291
+ " (self_attn): Qwen2Attention(\n",
1292
+ " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
1293
+ " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1294
+ " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1295
+ " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
1296
+ " )\n",
1297
+ " (mlp): Qwen2MLP(\n",
1298
+ " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1299
+ " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1300
+ " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
1301
+ " (act_fn): SiLU()\n",
1302
+ " )\n",
1303
+ " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1304
+ " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1305
+ " )\n",
1306
+ " )\n",
1307
+ " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1308
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
1309
+ " )\n",
1310
+ " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
1311
+ " (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
1312
+ ")"
1313
+ ]
1314
+ },
1315
+ "execution_count": 3,
1316
+ "metadata": {},
1317
+ "output_type": "execute_result"
1318
+ }
1319
+ ],
1320
+ "source": [
1321
+ "import torch\n",
1322
+ "from transformers import AutoTokenizer\n",
1323
+ "\n",
1324
+ "# Load the tokenizer\n",
1325
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
1326
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1327
+ "\n",
1328
+ "# Load the fine-tuned model\n",
1329
+ "model_path = \"/workspace/model\"\n",
1330
+ "model = GraphAwareLM.from_pretrained(model_path).to(device)\n",
1331
+ "model.eval() # Switch to inference mode\n"
1332
+ ]
1333
+ },
1334
+ {
1335
+ "cell_type": "code",
1336
+ "execution_count": 8,
1337
+ "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
1338
+ "metadata": {},
1339
+ "outputs": [],
1340
+ "source": [
1341
+ "import json\n",
1342
+ "json_path = \"final_Graph.json\"\n",
1343
+ "with open(json_path, \"r\") as f:\n",
1344
+ " data = json.load(f)\n",
1345
+ "\n",
1346
+ "test_data = data[0]\n",
1347
+ "\n",
1348
+ "conversations = test_data.get(\"conversations\")\n",
1349
+ "embeddings = test_data.get(\"embedding\") \n",
1350
+ "\n",
1351
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).to(device)\n",
1352
+ "\n",
1353
+ "question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
1354
+ "\n",
1355
+ "from transformers import AutoTokenizer\n",
1356
+ "\n",
1357
+ "# ✅ Input text\n",
1358
+ "ROLE_TOKENS = {\n",
1359
+ " \"human\": \"<|User|>\", \n",
1360
+ " \"gpt\": \"<|Assistant|>\", \n",
1361
+ "}\n",
1362
+ "GRAPH_LENGTH = 512\n",
1363
+ "max_seq_length = 1100 + GRAPH_LENGTH\n",
1364
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
1365
+ "\n",
1366
+ "input_ids = inputs[\"input_ids\"]\n",
1367
+ "attention_mask = inputs[\"attention_mask\"]\n"
1368
+ ]
1369
+ },
1370
+ {
1371
+ "cell_type": "code",
1372
+ "execution_count": 5,
1373
+ "id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
1374
+ "metadata": {},
1375
+ "outputs": [
1376
+ {
1377
+ "data": {
1378
+ "text/plain": [
1379
+ "tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1380
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1381
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1382
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1383
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1384
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1385
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1386
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1387
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1388
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1389
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1390
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1391
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1392
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1393
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1394
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1395
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1396
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1397
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1398
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1399
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1400
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1401
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1402
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1403
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1404
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1405
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1406
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1407
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1408
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1409
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1410
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1411
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1412
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1413
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1414
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1415
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1416
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1417
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1418
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1419
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1420
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1421
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1422
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1423
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1424
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1425
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1426
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1427
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1428
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1429
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1430
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1431
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1432
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1433
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1434
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1435
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1436
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1437
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1438
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1439
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1440
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1441
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1442
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
1443
+ " device='cuda:0')"
1444
+ ]
1445
+ },
1446
+ "execution_count": 5,
1447
+ "metadata": {},
1448
+ "output_type": "execute_result"
1449
+ }
1450
+ ],
1451
+ "source": [
1452
+ "graph_embedding"
1453
+ ]
1454
+ },
1455
+ {
1456
+ "cell_type": "code",
1457
+ "execution_count": 15,
1458
+ "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
1459
+ "metadata": {},
1460
+ "outputs": [
1461
+ {
1462
+ "name": "stdout",
1463
+ "output_type": "stream",
1464
+ "text": [
1465
+ "模型回复: How\n"
1466
+ ]
1467
+ }
1468
+ ],
1469
+ "source": [
1470
+ "# ✅ Run the forward pass\n",
1471
+ "with torch.no_grad():\n",
1472
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
1473
+ "\n",
1474
+ "# ✅ Extract the logits and decode greedily\n",
1475
+ "logits = outputs.logits[:, -1, :] # Logits of the last token\n",
1476
+ "predicted_id = torch.argmax(logits, dim=-1) # Pick the highest-probability token\n",
1477
+ "\n",
1478
+ "# ✅ Decode back to text\n",
1479
+ "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
1480
+ "\n",
1481
+ "print(\"模型回复:\", response_text)"
1482
+ ]
1483
+ },
1484
+ {
1485
+ "cell_type": "code",
1486
+ "execution_count": 9,
1487
+ "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
1488
+ "metadata": {},
1489
+ "outputs": [
1490
+ {
1491
+ "name": "stdout",
1492
+ "output_type": "stream",
1493
+ "text": [
1494
+ "Generated Response: Is there any sequential logic in the module, and if so, how is it handled? What are the module's inputs and outputs?\n",
1495
+ "What are the module's inputs and outputs?\n",
1496
+ "What are the module's inputs and outputs?\n",
1497
+ "What are the module's inputs and outputs?\n",
1498
+ "What is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and module's output, and module's input, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. 
Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and 
module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module\n"
1499
+ ]
1500
+ }
1501
+ ],
1502
+ "source": [
1503
+ "max_new_tokens = 1024\n",
1504
+ "generated_ids = input_ids.clone()\n",
1505
+ "generated_attention_mask = attention_mask.clone()\n",
1506
+ "for _ in range(max_new_tokens):\n",
1507
+ " # ✅ Compute the logits for the next token\n",
1508
+ " with torch.no_grad():\n",
1509
+ " outputs = model(\n",
1510
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
1511
+ " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
1512
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
1513
+ " )\n",
1514
+ "\n",
1515
+ "\n",
1516
+ " logits = outputs.logits[:, -1, :] # Logits of the last token\n",
1517
+ " next_token = torch.argmax(logits, dim=-1) # Greedy decoding\n",
1518
+ " # print(next_token)\n",
1519
+ "\n",
1520
+ "\n",
1521
+ " # ✅ **Append to the generated sequence**\n",
1522
+ " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
1523
+ "\n",
1524
+ " # print(generated_ids)\n",
1525
+ "\n",
1526
+ " if next_token.item() == tokenizer.eos_token_id:\n",
1527
+ " break\n",
1528
+ "\n",
1529
+ " generated_attention_mask = torch.cat(\n",
1530
+ " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
1531
+ " ) \n",
1532
+ "\n",
1533
+ "# ✅ Decode the final output\n",
1534
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1535
+ "print(\"Generated Response:\", generated_text)"
1536
+ ]
1537
+ },
1538
+ {
1539
+ "cell_type": "code",
1540
+ "execution_count": 10,
1541
+ "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
1542
+ "metadata": {},
1543
+ "outputs": [
1544
+ {
1545
+ "data": {
1546
+ "text/plain": [
1547
+ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
1548
+ " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
1549
+ " 525, 862, 9895, 30]], device='cuda:0')"
1550
+ ]
1551
+ },
1552
+ "execution_count": 10,
1553
+ "metadata": {},
1554
+ "output_type": "execute_result"
1555
+ }
1556
+ ],
1557
+ "source": [
1558
+ "generated_ids"
1559
+ ]
1560
+ },
1561
+ {
1562
+ "cell_type": "code",
1563
+ "execution_count": null,
1564
+ "id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
1565
+ "metadata": {},
1566
+ "outputs": [],
1567
+ "source": []
1568
+ }
1569
+ ],
1570
+ "metadata": {
1571
+ "kernelspec": {
1572
+ "display_name": "Python 3 (ipykernel)",
1573
+ "language": "python",
1574
+ "name": "python3"
1575
+ },
1576
+ "language_info": {
1577
+ "codemirror_mode": {
1578
+ "name": "ipython",
1579
+ "version": 3
1580
+ },
1581
+ "file_extension": ".py",
1582
+ "mimetype": "text/x-python",
1583
+ "name": "python",
1584
+ "nbconvert_exporter": "python",
1585
+ "pygments_lexer": "ipython3",
1586
+ "version": "3.10.12"
1587
+ }
1588
+ },
1589
+ "nbformat": 4,
1590
+ "nbformat_minor": 5
1591
+ }
graph_train2.ipynb ADDED
@@ -0,0 +1,1506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "ROLE_TOKENS = {\n",
11
+ " \"human\": \"<|User|>\", \n",
12
+ " \"gpt\": \"<|Assistant|>\", \n",
13
+ "}\n",
14
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
15
+ "GRAPH_LENGTH = 512\n",
16
+ "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new2\""
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 3,
22
+ "id": "bba6e6db-4b79-4461-ba13-75fd41019358",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "CUDA 可用: True\n",
30
+ "GPU 数量: 1\n",
31
+ "当前 GPU: 0\n",
32
+ "GPU 名称: NVIDIA A100 80GB PCIe\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# !pip install transformers accelerate datasets\n",
38
+ "# !pip install galore-torch\n",
39
+ "# !pip install huggingface_hub\n",
40
+ "import torch\n",
41
+ "print(\"CUDA 可用:\", torch.cuda.is_available())\n",
42
+ "print(\"GPU 数量:\", torch.cuda.device_count())\n",
43
+ "print(\"当前 GPU:\", torch.cuda.current_device())\n",
44
+ "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "max_seq_length = 1100 + GRAPH_LENGTH # Maximum sequence length"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "train_data 重新加载成功,数据量: 12384\n"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
75
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
76
+ " warnings.warn(\n",
77
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
78
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "text/html": [
84
+ "Tracking run with wandb version 0.19.7"
85
+ ],
86
+ "text/plain": [
87
+ "<IPython.core.display.HTML object>"
88
+ ]
89
+ },
90
+ "metadata": {},
91
+ "output_type": "display_data"
92
+ },
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_111730-i9v1vlu1</code>"
97
+ ],
98
+ "text/plain": [
99
+ "<IPython.core.display.HTML object>"
100
+ ]
101
+ },
102
+ "metadata": {},
103
+ "output_type": "display_data"
104
+ },
105
+ {
106
+ "data": {
107
+ "text/html": [
108
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">experi030402</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
109
+ ],
110
+ "text/plain": [
111
+ "<IPython.core.display.HTML object>"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
+ {
118
+ "data": {
119
+ "text/html": [
120
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
121
+ ],
122
+ "text/plain": [
123
+ "<IPython.core.display.HTML object>"
124
+ ]
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/html": [
132
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1</a>"
133
+ ],
134
+ "text/plain": [
135
+ "<IPython.core.display.HTML object>"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [5310/5310 1:34:08, Epoch 3/3]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Step</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <td>50</td>\n",
160
+ " <td>5.319300</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <td>100</td>\n",
164
+ " <td>3.641300</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <td>150</td>\n",
168
+ " <td>1.521800</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <td>200</td>\n",
172
+ " <td>1.027500</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <td>250</td>\n",
176
+ " <td>0.922400</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <td>300</td>\n",
180
+ " <td>0.866900</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <td>350</td>\n",
184
+ " <td>0.800500</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <td>400</td>\n",
188
+ " <td>0.721600</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <td>450</td>\n",
192
+ " <td>0.740400</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <td>500</td>\n",
196
+ " <td>0.737000</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <td>550</td>\n",
200
+ " <td>0.713500</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <td>600</td>\n",
204
+ " <td>0.747000</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <td>650</td>\n",
208
+ " <td>0.869500</td>\n",
209
+ " </tr>\n",
210
+ " <tr>\n",
211
+ " <td>700</td>\n",
212
+ " <td>1.473300</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <td>750</td>\n",
216
+ " <td>0.753000</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <td>800</td>\n",
220
+ " <td>0.741300</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <td>850</td>\n",
224
+ " <td>0.751400</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <td>900</td>\n",
228
+ " <td>0.787600</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <td>950</td>\n",
232
+ " <td>0.783200</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <td>1000</td>\n",
236
+ " <td>0.780200</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <td>1050</td>\n",
240
+ " <td>1.012900</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <td>1100</td>\n",
244
+ " <td>1.411700</td>\n",
245
+ " </tr>\n",
246
+ " <tr>\n",
247
+ " <td>1150</td>\n",
248
+ " <td>1.536400</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <td>1200</td>\n",
252
+ " <td>0.853800</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <td>1250</td>\n",
256
+ " <td>0.756500</td>\n",
257
+ " </tr>\n",
258
+ " <tr>\n",
259
+ " <td>1300</td>\n",
260
+ " <td>0.750800</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <td>1350</td>\n",
264
+ " <td>0.747400</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <td>1400</td>\n",
268
+ " <td>0.844400</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>1450</td>\n",
272
+ " <td>0.858400</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <td>1500</td>\n",
276
+ " <td>1.053400</td>\n",
277
+ " </tr>\n",
278
+ " <tr>\n",
279
+ " <td>1550</td>\n",
280
+ " <td>1.591600</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <td>1600</td>\n",
284
+ " <td>1.498900</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <td>1650</td>\n",
288
+ " <td>1.471700</td>\n",
289
+ " </tr>\n",
290
+ " <tr>\n",
291
+ " <td>1700</td>\n",
292
+ " <td>1.221100</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <td>1750</td>\n",
296
+ " <td>1.802300</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>1800</td>\n",
300
+ " <td>1.826000</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <td>1850</td>\n",
304
+ " <td>1.857300</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <td>1900</td>\n",
308
+ " <td>1.561800</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <td>1950</td>\n",
312
+ " <td>1.398800</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <td>2000</td>\n",
316
+ " <td>1.398900</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>2050</td>\n",
320
+ " <td>1.381600</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <td>2100</td>\n",
324
+ " <td>0.890300</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <td>2150</td>\n",
328
+ " <td>0.763700</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <td>2200</td>\n",
332
+ " <td>0.753100</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <td>2250</td>\n",
336
+ " <td>0.745500</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>2300</td>\n",
340
+ " <td>1.186100</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <td>2350</td>\n",
344
+ " <td>0.862000</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <td>2400</td>\n",
348
+ " <td>1.024600</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <td>2450</td>\n",
352
+ " <td>1.028400</td>\n",
353
+ " </tr>\n",
354
+ " <tr>\n",
355
+ " <td>2500</td>\n",
356
+ " <td>1.008500</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>2550</td>\n",
360
+ " <td>0.942800</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <td>2600</td>\n",
364
+ " <td>0.849700</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <td>2650</td>\n",
368
+ " <td>0.771400</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <td>2700</td>\n",
372
+ " <td>0.794100</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <td>2750</td>\n",
376
+ " <td>0.819200</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>2800</td>\n",
380
+ " <td>0.937500</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>2850</td>\n",
384
+ " <td>1.064500</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>2900</td>\n",
388
+ " <td>1.189300</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>2950</td>\n",
392
+ " <td>1.071100</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>3000</td>\n",
396
+ " <td>1.003300</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>3050</td>\n",
400
+ " <td>1.073900</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>3100</td>\n",
404
+ " <td>1.043100</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>3150</td>\n",
408
+ " <td>1.282600</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>3200</td>\n",
412
+ " <td>2.145400</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>3250</td>\n",
416
+ " <td>1.925800</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>3300</td>\n",
420
+ " <td>2.005600</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>3350</td>\n",
424
+ " <td>2.122600</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>3400</td>\n",
428
+ " <td>2.163000</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>3450</td>\n",
432
+ " <td>2.046600</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>3500</td>\n",
436
+ " <td>2.152200</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>3550</td>\n",
440
+ " <td>2.151700</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>3600</td>\n",
444
+ " <td>5.394900</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>3650</td>\n",
448
+ " <td>4.677800</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>3700</td>\n",
452
+ " <td>4.122200</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>3750</td>\n",
456
+ " <td>3.710200</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>3800</td>\n",
460
+ " <td>3.350800</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>3850</td>\n",
464
+ " <td>3.126300</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>3900</td>\n",
468
+ " <td>2.988700</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>3950</td>\n",
472
+ " <td>2.872000</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>4000</td>\n",
476
+ " <td>2.848200</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>4050</td>\n",
480
+ " <td>2.823900</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>4100</td>\n",
484
+ " <td>2.781200</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>4150</td>\n",
488
+ " <td>2.735000</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>4200</td>\n",
492
+ " <td>2.725900</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>4250</td>\n",
496
+ " <td>2.644400</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>4300</td>\n",
500
+ " <td>2.700000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>4350</td>\n",
504
+ " <td>2.650100</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>4400</td>\n",
508
+ " <td>2.704500</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>4450</td>\n",
512
+ " <td>2.596700</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>4500</td>\n",
516
+ " <td>2.510500</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>4550</td>\n",
520
+ " <td>2.515800</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>4600</td>\n",
524
+ " <td>2.498100</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>4650</td>\n",
528
+ " <td>2.458900</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>4700</td>\n",
532
+ " <td>2.449700</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>4750</td>\n",
536
+ " <td>2.425000</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>4800</td>\n",
540
+ " <td>2.362300</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>4850</td>\n",
544
+ " <td>2.232000</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>4900</td>\n",
548
+ " <td>2.361500</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>4950</td>\n",
552
+ " <td>2.302300</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>5000</td>\n",
556
+ " <td>2.333900</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>5050</td>\n",
560
+ " <td>2.367200</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>5100</td>\n",
564
+ " <td>2.288300</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>5150</td>\n",
568
+ " <td>2.426100</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>5200</td>\n",
572
+ " <td>2.344100</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>5250</td>\n",
576
+ " <td>2.283500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>5300</td>\n",
580
+ " <td>2.296500</td>\n",
581
+ " </tr>\n",
582
+ " </tbody>\n",
583
+ "</table><p>"
584
+ ],
585
+ "text/plain": [
586
+ "<IPython.core.display.HTML object>"
587
+ ]
588
+ },
589
+ "metadata": {},
590
+ "output_type": "display_data"
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "text/plain": [
602
+ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2/commit/291285a5f2155c79a0da893645d8df9bbca98f63', commit_message='End of training', commit_description='', oid='291285a5f2155c79a0da893645d8df9bbca98f63', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new2'), pr_revision=None, pr_num=None)"
603
+ ]
604
+ },
605
+ "execution_count": 4,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "import json\n",
612
+ "import torch\n",
613
+ "import os\n",
614
+ "from transformers import AutoTokenizer\n",
615
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
616
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
617
+ "if 'train_data' not in globals():\n",
618
+ " train_data_path = \"train_data.pt\"\n",
619
+ " \n",
620
+ " if os.path.exists(train_data_path): #确保文件存在\n",
621
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
622
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
623
+ " else:\n",
624
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
625
+ " exit()\n",
626
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
627
+ "if \"MODEL_NAME\" not in globals():\n",
628
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
629
+ "\n",
630
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
631
+ "\n",
632
+ "\n",
633
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
634
+ "\n",
635
+ "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
636
+ "\n",
637
+ "\n",
638
+ "from torch.utils.data import Dataset\n",
639
+ "\n",
640
+ "class GraphDataset(Dataset):\n",
641
+ " def __init__(self, data):\n",
642
+ " self.data = data\n",
643
+ "\n",
644
+ " def __len__(self):\n",
645
+ " return len(self.data)\n",
646
+ "\n",
647
+ " def __getitem__(self, idx):\n",
648
+ " sample = self.data[idx]\n",
649
+ " return {\n",
650
+ " \"input_ids\": sample[\"input_ids\"],\n",
651
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
652
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
653
+ " \"labels\": sample[\"labels\"],\n",
654
+ " }\n",
655
+ "\n",
656
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
657
+ "import torch\n",
658
+ "import torch.nn as nn\n",
659
+ "\n",
660
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
661
+ " def __init__(self, pretrained_model_name_or_path):\n",
662
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
663
+ " \n",
664
+ " # ✅ 载入 `MODEL_NAME` 预训练模型\n",
665
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
666
+ "\n",
667
+ " \n",
668
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
669
+ " self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
670
+ "\n",
671
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
672
+ " \"\"\"\n",
673
+ " `graph_embedding` 形状: (batch_size, 512)\n",
674
+ " `input_ids` 形状: (batch_size, seq_len)\n",
675
+ " \"\"\"\n",
676
+ " # ✅ 获取 token embedding\n",
677
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
678
+ "\n",
679
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
680
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
681
+ "\n",
682
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
683
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
684
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
685
+ "\n",
686
+ " # ✅ 调整 attention mask\n",
687
+ " if attention_mask is not None:\n",
688
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
689
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
690
+ "\n",
691
+ " # ✅ 传入模型\n",
692
+ " outputs = self.model(\n",
693
+ " inputs_embeds=inputs_embeds,\n",
694
+ " attention_mask=attention_mask,\n",
695
+ " labels=labels,\n",
696
+ " )\n",
697
+ "\n",
698
+ " return outputs\n",
699
+ "\n",
700
+ "from transformers import Trainer\n",
701
+ "\n",
702
+ "class GraphTrainer(Trainer):\n",
703
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
704
+ " input_ids = inputs[\"input_ids\"]\n",
705
+ " attention_mask = inputs[\"attention_mask\"]\n",
706
+ " labels = inputs[\"labels\"]\n",
707
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
708
+ "\n",
709
+ " if graph_embedding is not None:\n",
710
+ " outputs = model(\n",
711
+ " input_ids=input_ids,\n",
712
+ " attention_mask=attention_mask,\n",
713
+ " labels=labels,\n",
714
+ " graph_embedding=graph_embedding, \n",
715
+ " )\n",
716
+ " else:\n",
717
+ " outputs = model(\n",
718
+ " input_ids=input_ids,\n",
719
+ " attention_mask=attention_mask,\n",
720
+ " labels=labels,\n",
721
+ " )\n",
722
+ "\n",
723
+ " loss = outputs.loss\n",
724
+ " return (loss, outputs) if return_outputs else loss\n",
725
+ "\n",
726
+ "\n",
727
+ "from transformers import AutoConfig\n",
728
+ "\n",
729
+ "# ✅ 载入微调模型\n",
730
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
731
+ "\n",
732
+ "# # 1. 加载模型的配置\n",
733
+ "# config = AutoConfig.from_pretrained(MODEL_NAME)\n",
734
+ "\n",
735
+ "# # 2. 使用配置创建 GraphAwareLM 实例\n",
736
+ "# model = GraphAwareLM.from_config(config) \n",
737
+ "\n",
738
+ "# pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
739
+ "# model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
740
+ "\n",
741
+ "# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
742
+ "# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
743
+ "# model.config.use_sliding_window_attention = False\n",
744
+ "\n",
745
+ "# ✅ 训练参数\n",
746
+ "training_args = TrainingArguments(\n",
747
+ " output_dir=\"./results2\",\n",
748
+ " per_device_train_batch_size=7,\n",
749
+ " eval_strategy=\"no\",\n",
750
+ " save_strategy=\"steps\",\n",
751
+ " save_steps=3000,\n",
752
+ " logging_steps=50,\n",
753
+ " bf16=True,\n",
754
+ " optim=\"galore_adamw\",\n",
755
+ " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
756
+ " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
757
+ " warmup_steps=1000,\n",
758
+ " num_train_epochs=3,\n",
759
+ " push_to_hub=True,\n",
760
+ " hub_model_id=HF_NAME,\n",
761
+ " hub_strategy=\"every_save\",\n",
762
+ " run_name = \"experi030402\"\n",
763
+ ")\n",
764
+ "\n",
765
+ "\n",
766
+ "# ✅ 转换 `train_data` 为 `Dataset`\n",
767
+ "train_dataset = GraphDataset(train_data)\n",
768
+ "\n",
769
+ "# ✅ 训练\n",
770
+ "trainer = GraphTrainer(\n",
771
+ " model=model,\n",
772
+ " args=training_args,\n",
773
+ " train_dataset=train_dataset,\n",
774
+ ")\n",
775
+ "\n",
776
+ "trainer.train()\n",
777
+ "trainer.save_model(\"/workspace/model2\")\n",
778
+ "trainer.push_to_hub()\n",
779
+ "\n",
780
+ "\n"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "code",
785
+ "execution_count": 7,
786
+ "id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
787
+ "metadata": {},
788
+ "outputs": [
789
+ {
790
+ "data": {
791
+ "text/plain": [
792
+ "RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb')"
793
+ ]
794
+ },
795
+ "execution_count": 7,
796
+ "metadata": {},
797
+ "output_type": "execute_result"
798
+ }
799
+ ],
800
+ "source": [
801
+ "from huggingface_hub import HfApi\n",
802
+ "\n",
803
+ "api = HfApi()\n",
804
+ "repo_name = \"r1q1.5_graph_lora-wandb\" # 你的模型名称\n",
805
+ "api.create_repo(repo_name, exist_ok=True)"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 8,
811
+ "id": "73c434b9-5d58-4819-8526-24aa18ca1010",
812
+ "metadata": {},
813
+ "outputs": [
814
+ {
815
+ "data": {
816
+ "application/vnd.jupyter.widget-view+json": {
817
+ "model_id": "2bf786e437064435b543c4b364404933",
818
+ "version_major": 2,
819
+ "version_minor": 0
820
+ },
821
+ "text/plain": [
822
+ "run-v0v96nik.wandb: 0%| | 0.00/582k [00:00<?, ?B/s]"
823
+ ]
824
+ },
825
+ "metadata": {},
826
+ "output_type": "display_data"
827
+ },
828
+ {
829
+ "data": {
830
+ "application/vnd.jupyter.widget-view+json": {
831
+ "model_id": "d8d8867a8978418cbba012ae48c6a461",
832
+ "version_major": 2,
833
+ "version_minor": 0
834
+ },
835
+ "text/plain": [
836
+ "run-i9v1vlu1.wandb: 0%| | 0.00/617k [00:00<?, ?B/s]"
837
+ ]
838
+ },
839
+ "metadata": {},
840
+ "output_type": "display_data"
841
+ },
842
+ {
843
+ "data": {
844
+ "application/vnd.jupyter.widget-view+json": {
845
+ "model_id": "aa41d13f7f204554a401f018f535d83a",
846
+ "version_major": 2,
847
+ "version_minor": 0
848
+ },
849
+ "text/plain": [
850
+ "Upload 3 LFS files: 0%| | 0/3 [00:00<?, ?it/s]"
851
+ ]
852
+ },
853
+ "metadata": {},
854
+ "output_type": "display_data"
855
+ },
856
+ {
857
+ "data": {
858
+ "application/vnd.jupyter.widget-view+json": {
859
+ "model_id": "4f6a00ed3d4e43c9806cb5050b812bf8",
860
+ "version_major": 2,
861
+ "version_minor": 0
862
+ },
863
+ "text/plain": [
864
+ "run-e0v0giuw.wandb: 0%| | 0.00/616k [00:00<?, ?B/s]"
865
+ ]
866
+ },
867
+ "metadata": {},
868
+ "output_type": "display_data"
869
+ },
870
+ {
871
+ "data": {
872
+ "text/plain": [
873
+ "CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb/commit/81d72bb1534aa8769166ca2e2dd6f4a657ab3742', commit_message='upload wandb', commit_description='', oid='81d72bb1534aa8769166ca2e2dd6f4a657ab3742', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb'), pr_revision=None, pr_num=None)"
874
+ ]
875
+ },
876
+ "execution_count": 8,
877
+ "metadata": {},
878
+ "output_type": "execute_result"
879
+ }
880
+ ],
881
+ "source": [
882
+ "from huggingface_hub import upload_folder\n",
883
+ "\n",
884
+ "upload_folder(\n",
885
+ " folder_path = \"/workspace/wandb\",\n",
886
+ " repo_id = \"YiFzhao/r1q1.5_graph_lora-wandb\",\n",
887
+ " commit_message = \"upload wandb\",\n",
888
+ ")"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": 5,
894
+ "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
895
+ "metadata": {},
896
+ "outputs": [
897
+ {
898
+ "name": "stdout",
899
+ "output_type": "stream",
900
+ "text": [
901
+ "🚀 处理后数据条数: 12384\n",
902
+ "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
903
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
904
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
905
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
906
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
907
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
908
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
909
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
910
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
911
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
912
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
913
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
914
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
915
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
916
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
917
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
918
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
919
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
920
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
921
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
922
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
923
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
924
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
925
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
926
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
927
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
928
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
929
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
930
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
931
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
932
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
933
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
934
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
935
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
936
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
937
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
938
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
939
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
940
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
941
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
942
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
943
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
944
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
945
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
946
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
947
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
948
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
949
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
950
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
951
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
952
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
953
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
954
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
955
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
956
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
957
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
958
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
959
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
960
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
961
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
962
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
963
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
964
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
965
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
966
+ "✅ train_data 已保存到 train_data.pt\n"
967
+ ]
968
+ }
969
+ ],
970
+ "source": [
971
+ "import json\n",
972
+ "import torch\n",
973
+ "from transformers import AutoTokenizer\n",
974
+ "\n",
975
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
976
+ "tokenizer.pad_token = tokenizer.eos_token \n",
977
+ "\n",
978
+ "json_path = \"final_Graph.json\"\n",
979
+ "with open(json_path, \"r\") as f:\n",
980
+ " data = json.load(f)\n",
981
+ "\n",
982
+ "train_data = []\n",
983
+ "\n",
984
+ "\n",
985
+ "for sample in data:\n",
986
+ " conversations = sample.get(\"conversations\", [])\n",
987
+ " embeddings = sample.get(\"embedding\", []) \n",
988
+ "\n",
989
+ " if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
990
+ " print(f\"无效的 embedding,跳过样本:{sample}\")\n",
991
+ " continue\n",
992
+ "\n",
993
+ " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
994
+ "\n",
995
+ " #拼接所有对话\n",
996
+ " dialogue_text = \"\"\n",
997
+ " for conv in conversations:\n",
998
+ " role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
999
+ " content = conv[\"value\"]\n",
1000
+ " content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
1001
+ " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
1002
+ " dialogue_text += f\"{role_token} {content}\\n\"\n",
1003
+ "\n",
1004
+ " tokenized = tokenizer(\n",
1005
+ " dialogue_text,\n",
1006
+ " padding=\"max_length\",\n",
1007
+ " truncation=True,\n",
1008
+ " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
1009
+ " return_tensors=\"pt\",\n",
1010
+ " )\n",
1011
+ "\n",
1012
+ " input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
1013
+ " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
1014
+ "\n",
1015
+ " train_data.append({\n",
1016
+ " \"input_ids\": input_ids,\n",
1017
+ " \"attention_mask\": attention_mask,\n",
1018
+ " \"labels\": input_ids.clone(),\n",
1019
+ " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
1020
+ " })\n",
1021
+ "\n",
1022
+ "print(\"🚀 处理后数据条数:\", len(train_data))\n",
1023
+ "print(\"✅ 示例数据:\", train_data[0])\n",
1024
+ "torch.save(train_data, \"train_data.pt\")\n",
1025
+ "print(\"✅ train_data 已保存到 train_data.pt\")\n"
1026
+ ]
1027
+ },
1028
+ {
1029
+ "cell_type": "code",
1030
+ "execution_count": 10,
1031
+ "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
1032
+ "metadata": {},
1033
+ "outputs": [],
1034
+ "source": [
1035
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
1036
+ "import torch\n",
1037
+ "import torch.nn as nn\n",
1038
+ "\n",
1039
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1040
+ " def __init__(self, pretrained_model_name_or_path):\n",
1041
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
1042
+ " \n",
1043
+ " # ✅ 载入 `MODEL_NAME` 预训练模型\n",
1044
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
1045
+ "\n",
1046
+ " \n",
1047
+ " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
1048
+ " self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
1049
+ "\n",
1050
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1051
+ " \"\"\"\n",
1052
+ " `graph_embedding` 形状: (batch_size, 512)\n",
1053
+ " `input_ids` 形状: (batch_size, seq_len)\n",
1054
+ " \"\"\"\n",
1055
+ " # ✅ 获取 token embedding\n",
1056
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1057
+ "\n",
1058
+ " # ✅ 变换 graph embedding 到 hidden_size\n",
1059
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
1060
+ "\n",
1061
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
1062
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1063
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1064
+ "\n",
1065
+ " # ✅ 调整 attention mask\n",
1066
+ " if attention_mask is not None:\n",
1067
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1068
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1069
+ "\n",
1070
+ " # ✅ 传入模型\n",
1071
+ " outputs = self.model(\n",
1072
+ " inputs_embeds=inputs_embeds,\n",
1073
+ " attention_mask=attention_mask,\n",
1074
+ " labels=labels,\n",
1075
+ " )\n",
1076
+ "\n",
1077
+ " return outputs\n",
1078
+ "\n",
1079
+ " def generate_with_graph(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
1080
+ " \"\"\"\n",
1081
+ " ✅ 自定义 `generate()`,支持 `graph_embedding`\n",
1082
+ "        `inputs`: 需要生成文本的输入(分词后的结果,包含 `input_ids`,可选 `attention_mask`)\n",
1083
+ " `graph_embedding`: 形状为 (1, 512) 的张量\n",
1084
+ " \"\"\"\n",
1085
+ " # ✅ 2. 处理 `graph_embedding`\n",
1086
+ " graph_embedding_token = self.graph_proj(graph_embedding) # (1, hidden_size)\n",
1087
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (1, 1, hidden_size)\n",
1088
+ "\n",
1089
+ "        # ✅ 3. 获取 Token Embeddings 并拼接\n",
1090
+ " inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
1091
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
1092
+ "\n",
1093
+ " # ✅ 4. 调整 `attention_mask`\n",
1094
+ " if \"attention_mask\" in inputs:\n",
1095
+ " graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
1096
+ " attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
1097
+ " else:\n",
1098
+ " attention_mask = None\n",
1099
+ "\n",
1100
+ " # ✅ 5. 进行文本生成\n",
1101
+ " with torch.no_grad():\n",
1102
+ " output_ids = self.model.generate(\n",
1103
+ " inputs_embeds=inputs_embeds,\n",
1104
+ " attention_mask=attention_mask,\n",
1105
+ " max_length=max_length,\n",
1106
+ " temperature=temperature,\n",
1107
+ " top_k=top_k,\n",
1108
+ " top_p=top_p,\n",
1109
+ " num_return_sequences=1\n",
1110
+ " )\n",
1111
+ "\n",
1112
+ " # ✅ 6. 解码生成的文本\n",
1113
+ " generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
1114
+ " return generated_text\n",
1115
+ "\n",
1116
+ " @classmethod\n",
1117
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
1118
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
1119
+ " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
1120
+ " return model"
1121
+ ]
1122
+ },
1123
+ {
1124
+ "cell_type": "code",
1125
+ "execution_count": 11,
1126
+ "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
1127
+ "metadata": {},
1128
+ "outputs": [],
1129
+ "source": [
1130
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
1131
+ ]
1132
+ },
1133
+ {
1134
+ "cell_type": "code",
1135
+ "execution_count": 12,
1136
+ "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
1137
+ "metadata": {},
1138
+ "outputs": [
1139
+ {
1140
+ "data": {
1141
+ "application/vnd.jupyter.widget-view+json": {
1142
+ "model_id": "7ad289c5523340f39799ad11e3bc1bb5",
1143
+ "version_major": 2,
1144
+ "version_minor": 0
1145
+ },
1146
+ "text/plain": [
1147
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
1148
+ ]
1149
+ },
1150
+ "metadata": {},
1151
+ "output_type": "display_data"
1152
+ },
1153
+ {
1154
+ "data": {
1155
+ "text/plain": [
1156
+ "Qwen2ForCausalLM(\n",
1157
+ " (model): Qwen2Model(\n",
1158
+ " (embed_tokens): Embedding(151936, 1536)\n",
1159
+ " (layers): ModuleList(\n",
1160
+ " (0-27): 28 x Qwen2DecoderLayer(\n",
1161
+ " (self_attn): Qwen2Attention(\n",
1162
+ " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
1163
+ " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1164
+ " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1165
+ " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
1166
+ " )\n",
1167
+ " (mlp): Qwen2MLP(\n",
1168
+ " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1169
+ " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1170
+ " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
1171
+ " (act_fn): SiLU()\n",
1172
+ " )\n",
1173
+ " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1174
+ " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1175
+ " )\n",
1176
+ " )\n",
1177
+ " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1178
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
1179
+ " )\n",
1180
+ " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
1181
+ " (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
1182
+ ")"
1183
+ ]
1184
+ },
1185
+ "execution_count": 12,
1186
+ "metadata": {},
1187
+ "output_type": "execute_result"
1188
+ }
1189
+ ],
1190
+ "source": [
1191
+ "import torch\n",
1192
+ "from transformers import AutoTokenizer\n",
1193
+ "\n",
1194
+ "# Load the tokenizer\n",
1195
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
1196
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1197
+ "\n",
1198
+ "# Load the trained model from a saved checkpoint\n",
1199
+ "model_path = \"/workspace/model2\"\n",
1200
+ "model = GraphAwareLM.from_pretrained(\"/workspace/results2/checkpoint-5310\").to(device)\n",
1201
+ "model.eval() # switch to inference (eval) mode\n"
1202
+ ]
1203
+ },
1204
+ {
1205
+ "cell_type": "code",
1206
+ "execution_count": 13,
1207
+ "id": "51995891-8906-4049-9401-2d22e06a84e8",
1208
+ "metadata": {},
1209
+ "outputs": [
1210
+ {
1211
+ "name": "stdout",
1212
+ "output_type": "stream",
1213
+ "text": [
1214
+ "Parameter containing:\n",
1215
+ "tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
1216
+ " [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
1217
+ " [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
1218
+ " ...,\n",
1219
+ " [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
1220
+ " [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
1221
+ " [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
1222
+ " device='cuda:0', requires_grad=True)\n"
1223
+ ]
1224
+ }
1225
+ ],
1226
+ "source": [
1227
+ "print(model.graph_proj.weight)\n"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "cell_type": "code",
1232
+ "execution_count": 14,
1233
+ "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
1234
+ "metadata": {},
1235
+ "outputs": [],
1236
+ "source": [
1237
+ "import json\n",
1238
+ "json_path = \"final_Graph.json\"\n",
1239
+ "with open(json_path, \"r\") as f:\n",
1240
+ " data = json.load(f)\n",
1241
+ "\n",
1242
+ "test_data = data[0]\n",
1243
+ "\n",
1244
+ "conversations = test_data.get(\"conversations\")\n",
1245
+ "embeddings = test_data.get(\"embedding\") \n",
1246
+ "\n",
1247
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
1248
+ "\n",
1249
+ "question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
1250
+ "\n",
1251
+ "from transformers import AutoTokenizer\n",
1252
+ "\n",
1253
+ "# ✅ Role tokens and sequence-length limits used when tokenizing the input text\n",
1254
+ "ROLE_TOKENS = {\n",
1255
+ " \"human\": \"<|User|>\", \n",
1256
+ " \"gpt\": \"<|Assistant|>\", \n",
1257
+ "}\n",
1258
+ "GRAPH_LENGTH = 512\n",
1259
+ "max_seq_length = 1100 + GRAPH_LENGTH\n",
1260
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
1261
+ "\n",
1262
+ "input_ids = inputs[\"input_ids\"]\n",
1263
+ "attention_mask = inputs[\"attention_mask\"]\n"
1264
+ ]
1265
+ },
1266
+ {
1267
+ "cell_type": "code",
1268
+ "execution_count": 15,
1269
+ "id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
1270
+ "metadata": {},
1271
+ "outputs": [
1272
+ {
1273
+ "ename": "AttributeError",
1274
+ "evalue": "'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'",
1275
+ "output_type": "error",
1276
+ "traceback": [
1277
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1278
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1279
+ "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_with_graph\u001b[49m(inputs, graph_embedding)\n",
1280
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
1281
+ "\u001b[0;31mAttributeError\u001b[0m: 'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'"
1282
+ ]
1283
+ }
1284
+ ],
1285
+ "source": [
1286
+ "generated_text = model.generate_with_graph(inputs, graph_embedding)"
1287
+ ]
1288
+ },
1289
+ {
1290
+ "cell_type": "code",
1291
+ "execution_count": 5,
1292
+ "id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
1293
+ "metadata": {},
1294
+ "outputs": [
1295
+ {
1296
+ "data": {
1297
+ "text/plain": [
1298
+ "tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1299
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1300
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1301
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1302
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1303
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1304
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1305
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1306
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1307
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1308
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1309
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1310
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1311
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1312
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1313
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1314
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1315
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1316
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1317
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1318
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1319
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1320
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1321
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1322
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1323
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1324
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1325
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1326
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1327
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1328
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1329
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1330
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1331
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1332
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1333
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1334
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1335
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1336
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1337
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1338
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1339
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1340
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1341
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1342
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1343
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1344
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1345
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1346
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1347
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1348
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1349
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1350
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1351
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1352
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1353
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1354
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1355
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1356
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1357
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1358
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1359
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1360
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1361
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
1362
+ " device='cuda:0')"
1363
+ ]
1364
+ },
1365
+ "execution_count": 5,
1366
+ "metadata": {},
1367
+ "output_type": "execute_result"
1368
+ }
1369
+ ],
1370
+ "source": [
1371
+ "graph_embedding"
1372
+ ]
1373
+ },
1374
+ {
1375
+ "cell_type": "code",
1376
+ "execution_count": 15,
1377
+ "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
1378
+ "metadata": {},
1379
+ "outputs": [
1380
+ {
1381
+ "name": "stdout",
1382
+ "output_type": "stream",
1383
+ "text": [
1384
+ "模型回复: How\n"
1385
+ ]
1386
+ }
1387
+ ],
1388
+ "source": [
1389
+ "# ✅ Run a forward pass\n",
1390
+ "with torch.no_grad():\n",
1391
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
1392
+ "\n",
1393
+ "# ✅ Extract the logits and decode greedily\n",
1394
+ "logits = outputs.logits[:, -1, :] # logits at the last sequence position\n",
1395
+ "predicted_id = torch.argmax(logits, dim=-1) # pick the highest-scoring token\n",
1396
+ "\n",
1397
+ "# ✅ Decode the predicted token id back into text\n",
1398
+ "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
1399
+ "\n",
1400
+ "print(\"模型回复:\", response_text)"
1401
+ ]
1402
+ },
1403
+ {
1404
+ "cell_type": "code",
1405
+ "execution_count": 17,
1406
+ "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
1407
+ "metadata": {},
1408
+ "outputs": [
1409
+ {
1410
+ "name": "stdout",
1411
+ "output_type": "stream",
1412
+ "text": [
1413
+ "Generated Response: Is there any sequential logic in the module, and if so, how is it handled? `data` is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is 
the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit data, and the output is the output of the `data` is a 1-bit\n"
1414
+ ]
1415
+ }
1416
+ ],
1417
+ "source": [
1418
+ "max_new_tokens = 1024\n",
1419
+ "generated_ids = input_ids.clone()\n",
1420
+ "generated_attention_mask = attention_mask.clone()\n",
1421
+ "for _ in range(max_new_tokens):\n",
1422
+ " # ✅ Compute logits for the sequence generated so far\n",
1423
+ " with torch.no_grad():\n",
1424
+ " outputs = model(\n",
1425
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
1426
+ " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
1427
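+ " graph_embedding=graph_embedding, # NOTE(review): built with .squeeze(0) and displayed as a 1-D tensor, so this may be (512,) rather than (batch_size, 512) - confirm\n",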
1428
+ " )\n",
1429
+ "\n",
1430
+ "\n",
1431
+ " logits = outputs.logits[:, -1, :] # logits at the last sequence position\n",
1432
+ " next_token = torch.argmax(logits, dim=-1) # greedy decoding\n",
1433
+ " # print(next_token)\n",
1434
+ "\n",
1435
+ "\n",
1436
+ " # ✅ Append the new token to the generated sequence\n",
1437
+ " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
1438
+ "\n",
1439
+ " # print(generated_ids)\n",
1440
+ "\n",
1441
+ " if next_token.item() == tokenizer.eos_token_id:\n",
1442
+ " break\n",
1443
+ "\n",
1444
+ " generated_attention_mask = torch.cat(\n",
1445
+ " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
1446
+ " ) \n",
1447
+ "\n",
1448
+ "# ✅ Decode the final output\n",
1449
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1450
+ "print(\"Generated Response:\", generated_text)"
1451
+ ]
1452
+ },
1453
+ {
1454
+ "cell_type": "code",
1455
+ "execution_count": 10,
1456
+ "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
1457
+ "metadata": {},
1458
+ "outputs": [
1459
+ {
1460
+ "data": {
1461
+ "text/plain": [
1462
+ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
1463
+ " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
1464
+ " 525, 862, 9895, 30]], device='cuda:0')"
1465
+ ]
1466
+ },
1467
+ "execution_count": 10,
1468
+ "metadata": {},
1469
+ "output_type": "execute_result"
1470
+ }
1471
+ ],
1472
+ "source": [
1473
+ "generated_ids"
1474
+ ]
1475
+ },
1476
+ {
1477
+ "cell_type": "code",
1478
+ "execution_count": null,
1479
+ "id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
1480
+ "metadata": {},
1481
+ "outputs": [],
1482
+ "source": []
1483
+ }
1484
+ ],
1485
+ "metadata": {
1486
+ "kernelspec": {
1487
+ "display_name": "Python 3 (ipykernel)",
1488
+ "language": "python",
1489
+ "name": "python3"
1490
+ },
1491
+ "language_info": {
1492
+ "codemirror_mode": {
1493
+ "name": "ipython",
1494
+ "version": 3
1495
+ },
1496
+ "file_extension": ".py",
1497
+ "mimetype": "text/x-python",
1498
+ "name": "python",
1499
+ "nbconvert_exporter": "python",
1500
+ "pygments_lexer": "ipython3",
1501
+ "version": "3.10.12"
1502
+ }
1503
+ },
1504
+ "nbformat": 4,
1505
+ "nbformat_minor": 5
1506
+ }
graph_train3.ipynb ADDED
@@ -0,0 +1,1588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "ROLE_TOKENS = {\n",
11
+ " \"human\": \"<|User|>\", \n",
12
+ " \"gpt\": \"<|Assistant|>\", \n",
13
+ "}\n",
14
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
15
+ "GRAPH_LENGTH = 512\n",
16
+ "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new3\""
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "bba6e6db-4b79-4461-ba13-75fd41019358",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "CUDA 可用: True\n",
30
+ "GPU 数量: 1\n",
31
+ "当前 GPU: 0\n",
32
+ "GPU 名称: NVIDIA A100 80GB PCIe\n"
33
+ ]
34
+ }
35
+ ],
36
+ "source": [
37
+ "# !pip install transformers accelerate datasets\n",
38
+ "# !pip install galora\n",
39
+ "# !pip install huggingface_hub\n",
40
+ "import torch\n",
41
+ "print(\"CUDA 可用:\", torch.cuda.is_available())\n",
42
+ "print(\"GPU 数量:\", torch.cuda.device_count())\n",
43
+ "print(\"当前 GPU:\", torch.cuda.current_device())\n",
44
+ "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "max_seq_length = 1100 + GRAPH_LENGTH # maximum sequence length"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "train_data 重新加载成功,数据量: 12384\n"
68
+ ]
69
+ },
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
75
+ "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
76
+ " warnings.warn(\n",
77
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
78
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
79
+ ]
80
+ },
81
+ {
82
+ "data": {
83
+ "text/html": [
84
+ "Tracking run with wandb version 0.19.7"
85
+ ],
86
+ "text/plain": [
87
+ "<IPython.core.display.HTML object>"
88
+ ]
89
+ },
90
+ "metadata": {},
91
+ "output_type": "display_data"
92
+ },
93
+ {
94
+ "data": {
95
+ "text/html": [
96
+ "Run data is saved locally in <code>/workspace/wandb/run-20250304_134403-e0v0giuw</code>"
97
+ ],
98
+ "text/plain": [
99
+ "<IPython.core.display.HTML object>"
100
+ ]
101
+ },
102
+ "metadata": {},
103
+ "output_type": "display_data"
104
+ },
105
+ {
106
+ "data": {
107
+ "text/html": [
108
+ "Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">experi030403</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
109
+ ],
110
+ "text/plain": [
111
+ "<IPython.core.display.HTML object>"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
+ {
118
+ "data": {
119
+ "text/html": [
120
+ " View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
121
+ ],
122
+ "text/plain": [
123
+ "<IPython.core.display.HTML object>"
124
+ ]
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/html": [
132
+ " View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw</a>"
133
+ ],
134
+ "text/plain": [
135
+ "<IPython.core.display.HTML object>"
136
+ ]
137
+ },
138
+ "metadata": {},
139
+ "output_type": "display_data"
140
+ },
141
+ {
142
+ "data": {
143
+ "text/html": [
144
+ "\n",
145
+ " <div>\n",
146
+ " \n",
147
+ " <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
148
+ " [5310/5310 1:33:59, Epoch 3/3]\n",
149
+ " </div>\n",
150
+ " <table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: left;\">\n",
153
+ " <th>Step</th>\n",
154
+ " <th>Training Loss</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <td>50</td>\n",
160
+ " <td>5.319300</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <td>100</td>\n",
164
+ " <td>3.641300</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <td>150</td>\n",
168
+ " <td>1.521800</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <td>200</td>\n",
172
+ " <td>1.027500</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <td>250</td>\n",
176
+ " <td>0.922400</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <td>300</td>\n",
180
+ " <td>0.866900</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <td>350</td>\n",
184
+ " <td>0.800500</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <td>400</td>\n",
188
+ " <td>0.721600</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <td>450</td>\n",
192
+ " <td>0.740400</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <td>500</td>\n",
196
+ " <td>0.737000</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <td>550</td>\n",
200
+ " <td>0.713500</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <td>600</td>\n",
204
+ " <td>0.747000</td>\n",
205
+ " </tr>\n",
206
+ " <tr>\n",
207
+ " <td>650</td>\n",
208
+ " <td>0.869500</td>\n",
209
+ " </tr>\n",
210
+ " <tr>\n",
211
+ " <td>700</td>\n",
212
+ " <td>1.473300</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <td>750</td>\n",
216
+ " <td>0.753000</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <td>800</td>\n",
220
+ " <td>0.741300</td>\n",
221
+ " </tr>\n",
222
+ " <tr>\n",
223
+ " <td>850</td>\n",
224
+ " <td>0.751400</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <td>900</td>\n",
228
+ " <td>0.787600</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <td>950</td>\n",
232
+ " <td>0.783200</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <td>1000</td>\n",
236
+ " <td>0.780200</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <td>1050</td>\n",
240
+ " <td>1.012900</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <td>1100</td>\n",
244
+ " <td>1.411700</td>\n",
245
+ " </tr>\n",
246
+ " <tr>\n",
247
+ " <td>1150</td>\n",
248
+ " <td>1.536400</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <td>1200</td>\n",
252
+ " <td>0.853800</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <td>1250</td>\n",
256
+ " <td>0.756500</td>\n",
257
+ " </tr>\n",
258
+ " <tr>\n",
259
+ " <td>1300</td>\n",
260
+ " <td>0.750800</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <td>1350</td>\n",
264
+ " <td>0.747400</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <td>1400</td>\n",
268
+ " <td>0.844400</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>1450</td>\n",
272
+ " <td>0.858400</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <td>1500</td>\n",
276
+ " <td>1.053400</td>\n",
277
+ " </tr>\n",
278
+ " <tr>\n",
279
+ " <td>1550</td>\n",
280
+ " <td>1.591600</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <td>1600</td>\n",
284
+ " <td>1.498900</td>\n",
285
+ " </tr>\n",
286
+ " <tr>\n",
287
+ " <td>1650</td>\n",
288
+ " <td>1.471700</td>\n",
289
+ " </tr>\n",
290
+ " <tr>\n",
291
+ " <td>1700</td>\n",
292
+ " <td>1.221100</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <td>1750</td>\n",
296
+ " <td>1.802300</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>1800</td>\n",
300
+ " <td>1.826000</td>\n",
301
+ " </tr>\n",
302
+ " <tr>\n",
303
+ " <td>1850</td>\n",
304
+ " <td>1.857300</td>\n",
305
+ " </tr>\n",
306
+ " <tr>\n",
307
+ " <td>1900</td>\n",
308
+ " <td>1.561800</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <td>1950</td>\n",
312
+ " <td>1.398800</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <td>2000</td>\n",
316
+ " <td>1.398900</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>2050</td>\n",
320
+ " <td>1.381600</td>\n",
321
+ " </tr>\n",
322
+ " <tr>\n",
323
+ " <td>2100</td>\n",
324
+ " <td>0.890300</td>\n",
325
+ " </tr>\n",
326
+ " <tr>\n",
327
+ " <td>2150</td>\n",
328
+ " <td>0.763700</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <td>2200</td>\n",
332
+ " <td>0.753100</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <td>2250</td>\n",
336
+ " <td>0.745500</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>2300</td>\n",
340
+ " <td>1.186100</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <td>2350</td>\n",
344
+ " <td>0.862000</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <td>2400</td>\n",
348
+ " <td>1.024600</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <td>2450</td>\n",
352
+ " <td>1.028400</td>\n",
353
+ " </tr>\n",
354
+ " <tr>\n",
355
+ " <td>2500</td>\n",
356
+ " <td>1.008500</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>2550</td>\n",
360
+ " <td>0.942800</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <td>2600</td>\n",
364
+ " <td>0.849700</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <td>2650</td>\n",
368
+ " <td>0.771400</td>\n",
369
+ " </tr>\n",
370
+ " <tr>\n",
371
+ " <td>2700</td>\n",
372
+ " <td>0.794100</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <td>2750</td>\n",
376
+ " <td>0.819200</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>2800</td>\n",
380
+ " <td>0.937500</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>2850</td>\n",
384
+ " <td>1.064500</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>2900</td>\n",
388
+ " <td>1.189300</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>2950</td>\n",
392
+ " <td>1.071100</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>3000</td>\n",
396
+ " <td>1.003300</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>3050</td>\n",
400
+ " <td>1.073900</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>3100</td>\n",
404
+ " <td>1.043100</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>3150</td>\n",
408
+ " <td>1.282600</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>3200</td>\n",
412
+ " <td>2.145400</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>3250</td>\n",
416
+ " <td>1.925800</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>3300</td>\n",
420
+ " <td>2.005600</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>3350</td>\n",
424
+ " <td>2.122600</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>3400</td>\n",
428
+ " <td>2.163000</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>3450</td>\n",
432
+ " <td>2.046600</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>3500</td>\n",
436
+ " <td>2.152200</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>3550</td>\n",
440
+ " <td>2.151700</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>3600</td>\n",
444
+ " <td>5.394900</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>3650</td>\n",
448
+ " <td>4.677800</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>3700</td>\n",
452
+ " <td>4.122200</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>3750</td>\n",
456
+ " <td>3.710200</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>3800</td>\n",
460
+ " <td>3.350800</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>3850</td>\n",
464
+ " <td>3.126300</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>3900</td>\n",
468
+ " <td>2.988700</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>3950</td>\n",
472
+ " <td>2.872000</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>4000</td>\n",
476
+ " <td>2.848200</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>4050</td>\n",
480
+ " <td>2.823900</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>4100</td>\n",
484
+ " <td>2.781200</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>4150</td>\n",
488
+ " <td>2.735000</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>4200</td>\n",
492
+ " <td>2.725900</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>4250</td>\n",
496
+ " <td>2.644400</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>4300</td>\n",
500
+ " <td>2.700000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>4350</td>\n",
504
+ " <td>2.650100</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>4400</td>\n",
508
+ " <td>2.704500</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>4450</td>\n",
512
+ " <td>2.596700</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>4500</td>\n",
516
+ " <td>2.510500</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>4550</td>\n",
520
+ " <td>2.515800</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>4600</td>\n",
524
+ " <td>2.498100</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>4650</td>\n",
528
+ " <td>2.458900</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>4700</td>\n",
532
+ " <td>2.449700</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>4750</td>\n",
536
+ " <td>2.425000</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>4800</td>\n",
540
+ " <td>2.362300</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>4850</td>\n",
544
+ " <td>2.232000</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>4900</td>\n",
548
+ " <td>2.361500</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>4950</td>\n",
552
+ " <td>2.302300</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>5000</td>\n",
556
+ " <td>2.333900</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>5050</td>\n",
560
+ " <td>2.367200</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>5100</td>\n",
564
+ " <td>2.288300</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>5150</td>\n",
568
+ " <td>2.426100</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>5200</td>\n",
572
+ " <td>2.344100</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>5250</td>\n",
576
+ " <td>2.283500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>5300</td>\n",
580
+ " <td>2.296500</td>\n",
581
+ " </tr>\n",
582
+ " </tbody>\n",
583
+ "</table><p>"
584
+ ],
585
+ "text/plain": [
586
+ "<IPython.core.display.HTML object>"
587
+ ]
588
+ },
589
+ "metadata": {},
590
+ "output_type": "display_data"
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n"
597
+ ]
598
+ },
599
+ {
600
+ "data": {
601
+ "text/plain": [
602
+ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3/commit/b9472b66316be8654c6f7c173fa4561889bd3446', commit_message='End of training', commit_description='', oid='b9472b66316be8654c6f7c173fa4561889bd3446', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new3'), pr_revision=None, pr_num=None)"
603
+ ]
604
+ },
605
+ "execution_count": 4,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "import json\n",
612
+ "import torch\n",
613
+ "import os\n",
614
+ "from transformers import AutoTokenizer\n",
615
+ "train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
616
+ "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
617
+ "if 'train_data' not in globals():\n",
618
+ " train_data_path = \"train_data.pt\"\n",
619
+ " \n",
620
+ " if os.path.exists(train_data_path): #确保文件存在\n",
621
+ " train_data = torch.load(train_data_path, weights_only=False)\n",
622
+ " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
623
+ " else:\n",
624
+ " print(f\"未找到 {train_data_path},请检查路径!\")\n",
625
+ " exit()\n",
626
+ "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
627
+ "if \"MODEL_NAME\" not in globals():\n",
628
+ " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
629
+ "\n",
630
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
631
+ "\n",
632
+ "\n",
633
+ "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
634
+ "\n",
635
+ "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
636
+ "\n",
637
+ "\n",
638
+ "from torch.utils.data import Dataset\n",
639
+ "\n",
640
+ "class GraphDataset(Dataset):\n",
641
+ " def __init__(self, data):\n",
642
+ " self.data = data\n",
643
+ "\n",
644
+ " def __len__(self):\n",
645
+ " return len(self.data)\n",
646
+ "\n",
647
+ " def __getitem__(self, idx):\n",
648
+ " sample = self.data[idx]\n",
649
+ " return {\n",
650
+ " \"input_ids\": sample[\"input_ids\"],\n",
651
+ " \"attention_mask\": sample[\"attention_mask\"],\n",
652
+ " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
653
+ " \"labels\": sample[\"labels\"],\n",
654
+ " }\n",
655
+ "\n",
656
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
657
+ "import torch\n",
658
+ "import torch.nn as nn\n",
659
+ "\n",
660
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
661
+ " def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
662
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
663
+ " \n",
664
+ " # ✅ 载入 LLM 预训练模型\n",
665
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
666
+ "\n",
667
+ " # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
668
+ " self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
669
+ "\n",
670
+ " # ✅ 2. 多头注意力层\n",
671
+ " self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
672
+ "\n",
673
+ " # ✅ 3. 线性变换\n",
674
+ " self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
675
+ "\n",
676
+ " # ✅ 4. 残差连接 + LayerNorm\n",
677
+ " self.norm = nn.LayerNorm(self.config.hidden_size)\n",
678
+ " \n",
679
+ "\n",
680
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
681
+ " \"\"\"\n",
682
+ " `graph_embedding` 形状: (batch_size, 512)\n",
683
+ " `input_ids` 形状: (batch_size, seq_len)\n",
684
+ " \"\"\"\n",
685
+ " # ✅ 获取 token embedding\n",
686
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
687
+ "\n",
688
+ " # ✅ 1. 线性变换 `graph_embedding`\n",
689
+ " graph_embedding_token = self.linear1(graph_embedding) # (batch_size, hidden_size)\n",
690
+ "\n",
691
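+ " # ✅ 2. Multi-head self-attention over the projected graph embedding. NOTE(review): the input is 2-D (batch_size, hidden_size), which nn.MultiheadAttention treats as one unbatched sequence, so samples in a batch attend to each other - confirm this is intended\n",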
692
+ " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
693
+ " \n",
694
+ " # ✅ 3. 线性层 + 残差连接\n",
695
+ " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, hidden_size)\n",
696
+ "\n",
697
+ " # ✅ 4. 归一化\n",
698
+ " graph_embedding_token = self.norm(graph_embedding_token)\n",
699
+ "\n",
700
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
701
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
702
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
703
+ "\n",
704
+ " # ✅ 调整 attention mask\n",
705
+ " if attention_mask is not None:\n",
706
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
707
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
708
+ "\n",
709
+ " # ✅ 传入模型\n",
710
+ " outputs = self.model(\n",
711
+ " inputs_embeds=inputs_embeds,\n",
712
+ " attention_mask=attention_mask,\n",
713
+ " labels=labels,\n",
714
+ " )\n",
715
+ "\n",
716
+ " return outputs\n",
717
+ "\n",
718
+ "from transformers import Trainer\n",
719
+ "\n",
720
+ "class GraphTrainer(Trainer):\n",
721
+ " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
722
+ " input_ids = inputs[\"input_ids\"]\n",
723
+ " attention_mask = inputs[\"attention_mask\"]\n",
724
+ " labels = inputs[\"labels\"]\n",
725
+ " graph_embedding = inputs.get(\"graph_embedding\", None) \n",
726
+ "\n",
727
+ " if graph_embedding is not None:\n",
728
+ " outputs = model(\n",
729
+ " input_ids=input_ids,\n",
730
+ " attention_mask=attention_mask,\n",
731
+ " labels=labels,\n",
732
+ " graph_embedding=graph_embedding, \n",
733
+ " )\n",
734
+ " else:\n",
735
+ " outputs = model(\n",
736
+ " input_ids=input_ids,\n",
737
+ " attention_mask=attention_mask,\n",
738
+ " labels=labels,\n",
739
+ " )\n",
740
+ "\n",
741
+ " loss = outputs.loss\n",
742
+ " return (loss, outputs) if return_outputs else loss\n",
743
+ "\n",
744
+ "\n",
745
+ "from transformers import AutoConfig\n",
746
+ "\n",
747
+ "# ✅ 载入微调模型\n",
748
+ "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
749
+ "\n",
750
+ "# ✅ 训练参数\n",
751
+ "training_args = TrainingArguments(\n",
752
+ " output_dir=\"./results3\",\n",
753
+ " per_device_train_batch_size=7,\n",
754
+ " eval_strategy=\"no\",\n",
755
+ " save_strategy=\"steps\",\n",
756
+ " save_steps=3000,\n",
757
+ " logging_steps=50,\n",
758
+ " bf16=True,\n",
759
+ " optim=\"galore_adamw\",\n",
760
+ " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
761
+ " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
762
+ " warmup_steps=1000,\n",
763
+ " num_train_epochs=3,\n",
764
+ " push_to_hub=True,\n",
765
+ " hub_model_id=HF_NAME,\n",
766
+ " hub_strategy=\"every_save\",\n",
767
+ " run_name = \"experi030403\"\n",
768
+ ")\n",
769
+ "\n",
770
+ "\n",
771
+ "# ✅ 转换 `train_data` 为 `Dataset`\n",
772
+ "train_dataset = GraphDataset(train_data)\n",
773
+ "\n",
774
+ "# ✅ 训练\n",
775
+ "trainer = GraphTrainer(\n",
776
+ " model=model,\n",
777
+ " args=training_args,\n",
778
+ " train_dataset=train_dataset,\n",
779
+ ")\n",
780
+ "\n",
781
+ "trainer.train()\n",
782
+ "trainer.save_model(\"/workspace/model3\")\n",
783
+ "trainer.push_to_hub()\n",
784
+ "\n",
785
+ "\n"
786
+ ]
787
+ },
788
+ {
789
+ "cell_type": "code",
790
+ "execution_count": 2,
791
+ "id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
792
+ "metadata": {},
793
+ "outputs": [
794
+ {
795
+ "data": {
796
+ "text/plain": [
797
+ "RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000')"
798
+ ]
799
+ },
800
+ "execution_count": 2,
801
+ "metadata": {},
802
+ "output_type": "execute_result"
803
+ }
804
+ ],
805
+ "source": [
806
+ "from huggingface_hub import HfApi\n",
807
+ "\n",
808
+ "api = HfApi()\n",
809
+ "repo_name = \"r1q1.5_graph_lora-results3\" # 你的模型名称\n",
810
+ "api.create_repo(repo_name, exist_ok=True)"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "code",
815
+ "execution_count": 3,
816
+ "id": "73c434b9-5d58-4819-8526-24aa18ca1010",
817
+ "metadata": {},
818
+ "outputs": [
819
+ {
820
+ "data": {
821
+ "application/vnd.jupyter.widget-view+json": {
822
+ "model_id": "8b896f21685e4086b0b59404b2b1a866",
823
+ "version_major": 2,
824
+ "version_minor": 0
825
+ },
826
+ "text/plain": [
827
+ "model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
828
+ ]
829
+ },
830
+ "metadata": {},
831
+ "output_type": "display_data"
832
+ },
833
+ {
834
+ "data": {
835
+ "application/vnd.jupyter.widget-view+json": {
836
+ "model_id": "d20bff067ca44c4583378181da817897",
837
+ "version_major": 2,
838
+ "version_minor": 0
839
+ },
840
+ "text/plain": [
841
+ "scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
842
+ ]
843
+ },
844
+ "metadata": {},
845
+ "output_type": "display_data"
846
+ },
847
+ {
848
+ "data": {
849
+ "application/vnd.jupyter.widget-view+json": {
850
+ "model_id": "c4b7114a53b341539a3244f2eea8aacf",
851
+ "version_major": 2,
852
+ "version_minor": 0
853
+ },
854
+ "text/plain": [
855
+ "Upload 6 LFS files: 0%| | 0/6 [00:00<?, ?it/s]"
856
+ ]
857
+ },
858
+ "metadata": {},
859
+ "output_type": "display_data"
860
+ },
861
+ {
862
+ "data": {
863
+ "application/vnd.jupyter.widget-view+json": {
864
+ "model_id": "74c6045017b640bdba86fe3ed1bb9c92",
865
+ "version_major": 2,
866
+ "version_minor": 0
867
+ },
868
+ "text/plain": [
869
+ "model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
870
+ ]
871
+ },
872
+ "metadata": {},
873
+ "output_type": "display_data"
874
+ },
875
+ {
876
+ "data": {
877
+ "application/vnd.jupyter.widget-view+json": {
878
+ "model_id": "97436b084bc4420f8b273ec462c50e61",
879
+ "version_major": 2,
880
+ "version_minor": 0
881
+ },
882
+ "text/plain": [
883
+ "optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
884
+ ]
885
+ },
886
+ "metadata": {},
887
+ "output_type": "display_data"
888
+ },
889
+ {
890
+ "data": {
891
+ "application/vnd.jupyter.widget-view+json": {
892
+ "model_id": "d7f10ccff3674e6fa8bcb42553c12b19",
893
+ "version_major": 2,
894
+ "version_minor": 0
895
+ },
896
+ "text/plain": [
897
+ "rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
898
+ ]
899
+ },
900
+ "metadata": {},
901
+ "output_type": "display_data"
902
+ },
903
+ {
904
+ "data": {
905
+ "application/vnd.jupyter.widget-view+json": {
906
+ "model_id": "c5b1a010fd0845f9ba9112291afa8f17",
907
+ "version_major": 2,
908
+ "version_minor": 0
909
+ },
910
+ "text/plain": [
911
+ "training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
912
+ ]
913
+ },
914
+ "metadata": {},
915
+ "output_type": "display_data"
916
+ },
917
+ {
918
+ "data": {
919
+ "text/plain": [
920
+ "CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000/commit/4088de651a0ce2cc39fcb0c950898e54ce91bdea', commit_message='upload checkpoint-3000', commit_description='', oid='4088de651a0ce2cc39fcb0c950898e54ce91bdea', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000'), pr_revision=None, pr_num=None)"
921
+ ]
922
+ },
923
+ "execution_count": 3,
924
+ "metadata": {},
925
+ "output_type": "execute_result"
926
+ }
927
+ ],
928
+ "source": [
929
+ "from huggingface_hub import upload_folder\n",
930
+ "\n",
931
+ "upload_folder(\n",
932
+ " folder_path = \"/workspace/results3\",\n",
933
+ " repo_id = \"YiFzhao/r1q1.5_graph_lora-results3\",\n",
934
+ " commit_message = \"upload results2\",\n",
935
+ ")"
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "code",
940
+ "execution_count": 5,
941
+ "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
942
+ "metadata": {},
943
+ "outputs": [
944
+ {
945
+ "name": "stdout",
946
+ "output_type": "stream",
947
+ "text": [
948
+ "🚀 处理后数据条数: 12384\n",
949
+ "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
950
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
951
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
952
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
953
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
954
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
955
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
956
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
957
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
958
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
959
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
960
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
961
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
962
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
963
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
964
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
965
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
966
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
967
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
968
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
969
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
970
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
971
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
972
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
973
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
974
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
975
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
976
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
977
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
978
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
979
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
980
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
981
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
982
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
983
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
984
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
985
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
986
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
987
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
988
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
989
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
990
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
991
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
992
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
993
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
994
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
995
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
996
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
997
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
998
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
999
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1000
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1001
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1002
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1003
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1004
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1005
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1006
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1007
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1008
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1009
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1010
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1011
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1012
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
1013
+ "✅ train_data 已保存到 train_data.pt\n"
1014
+ ]
1015
+ }
1016
+ ],
1017
+ "source": [
1018
+ "import json\n",
1019
+ "import torch\n",
1020
+ "from transformers import AutoTokenizer\n",
1021
+ "\n",
1022
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1023
+ "tokenizer.pad_token = tokenizer.eos_token \n",
1024
+ "\n",
1025
+ "json_path = \"final_Graph.json\"\n",
1026
+ "with open(json_path, \"r\") as f:\n",
1027
+ " data = json.load(f)\n",
1028
+ "\n",
1029
+ "train_data = []\n",
1030
+ "\n",
1031
+ "\n",
1032
+ "for sample in data:\n",
1033
+ " conversations = sample.get(\"conversations\", [])\n",
1034
+ " embeddings = sample.get(\"embedding\", []) \n",
1035
+ "\n",
1036
+ " if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
1037
+ " print(f\"无效的 embedding,跳过样本:{sample}\")\n",
1038
+ " continue\n",
1039
+ "\n",
1040
+ " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
1041
+ "\n",
1042
+ " #拼接所有对话\n",
1043
+ " dialogue_text = \"\"\n",
1044
+ " for conv in conversations:\n",
1045
+ " role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
1046
+ " content = conv[\"value\"]\n",
1047
+ " content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
1048
+ " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
1049
+ " dialogue_text += f\"{role_token} {content}\\n\"\n",
1050
+ "\n",
1051
+ " tokenized = tokenizer(\n",
1052
+ " dialogue_text,\n",
1053
+ " padding=\"max_length\",\n",
1054
+ " truncation=True,\n",
1055
+ " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
1056
+ " return_tensors=\"pt\",\n",
1057
+ " )\n",
1058
+ "\n",
1059
+ " input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
1060
+ " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
1061
+ "\n",
1062
+ " train_data.append({\n",
1063
+ " \"input_ids\": input_ids,\n",
1064
+ " \"attention_mask\": attention_mask,\n",
1065
+ " \"labels\": input_ids.clone(),\n",
1066
+ " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
1067
+ " })\n",
1068
+ "\n",
1069
+ "print(\"🚀 处理后数据条数:\", len(train_data))\n",
1070
+ "print(\"✅ 示例数据:\", train_data[0])\n",
1071
+ "torch.save(train_data, \"train_data.pt\")\n",
1072
+ "print(\"✅ train_data 已保存到 train_data.pt\")\n"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "code",
1077
+ "execution_count": 6,
1078
+ "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
1079
+ "metadata": {},
1080
+ "outputs": [],
1081
+ "source": [
1082
+ "from transformers import AutoModelForCausalLM, AutoConfig\n",
1083
+ "import torch\n",
1084
+ "import torch.nn as nn\n",
1085
+ "\n",
1086
+ "class GraphAwareLM(AutoModelForCausalLM):\n",
1087
+ " def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
1088
+ " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
1089
+ " \n",
1090
+ " # ✅ 载入 LLM 预训练模型\n",
1091
+ " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
1092
+ "\n",
1093
+ " # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
1094
+ " self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
1095
+ "\n",
1096
+ " # ✅ 2. 多头注意力层\n",
1097
+ " self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
1098
+ "\n",
1099
+ " # ✅ 3. 线性变换\n",
1100
+ " self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
1101
+ "\n",
1102
+ " # ✅ 4. 残差连接 + LayerNorm\n",
1103
+ " self.norm = nn.LayerNorm(self.config.hidden_size)\n",
1104
+ " \n",
1105
+ "\n",
1106
+ " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
1107
+ " \"\"\"\n",
1108
+ " `graph_embedding` 形状: (batch_size, 512)\n",
1109
+ " `input_ids` 形状: (batch_size, seq_len)\n",
1110
+ " \"\"\"\n",
1111
+ " # ✅ 获取 token embedding\n",
1112
+ " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
1113
+ "\n",
1114
+ " # ✅ 1. 线性变换 `graph_embedding`\n",
1115
+ " graph_embedding_token = self.linear1(graph_embedding) # (batch_size, hidden_size)\n",
1116
+ "\n",
1117
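+ " # ✅ 2. Multi-head self-attention over the projected graph embedding. NOTE(review): the input is 2-D (batch_size, hidden_size), which nn.MultiheadAttention treats as one unbatched sequence, so samples in a batch attend to each other - confirm this is intended\n",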
1118
+ " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
1119
+ " \n",
1120
+ " # ✅ 3. 线性层 + 残差连接\n",
1121
+ " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, hidden_size)\n",
1122
+ "\n",
1123
+ " # ✅ 4. 归一化\n",
1124
+ " graph_embedding_token = self.norm(graph_embedding_token)\n",
1125
+ "\n",
1126
+ " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
1127
+ " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
1128
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
1129
+ "\n",
1130
+ " # ✅ 调整 attention mask\n",
1131
+ " if attention_mask is not None:\n",
1132
+ " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
1133
+ " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
1134
+ "\n",
1135
+ " # ✅ 传入模型\n",
1136
+ " outputs = self.model(\n",
1137
+ " inputs_embeds=inputs_embeds,\n",
1138
+ " attention_mask=attention_mask,\n",
1139
+ " labels=labels,\n",
1140
+ " )\n",
1141
+ "\n",
1142
+ " return outputs\n",
1143
+ "\n",
1144
+ " def generate(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
1145
+ " \"\"\"\n",
1146
+ " ✅ 自定义 `generate()` 方法,支持 `graph_embedding`\n",
1147
+ " `inputs`: tokenized prompt; must provide `input_ids` and may provide `attention_mask`\n",
1148
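+ " `graph_embedding`: tensor documented as shape (1, 512). NOTE(review): the torch.cat with the 3-D token embeddings in this method needs a 3-D graph token, i.e. a (1, 1, 512) input - confirm what callers pass\n",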
1149
+ " \"\"\"\n",
1150
+ "\n",
1151
+ " # ✅ 2. 处理 `graph_embedding`\n",
1152
+ " graph_embedding_token = self.linear1(graph_embedding) # (1, 1, hidden_size)\n",
1153
+ " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
1154
+ " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (1, 1, hidden_size)\n",
1155
+ " graph_embedding_token = self.norm(graph_embedding_token)\n",
1156
+ "\n",
1157
+ " # ✅ 3. 获取 Token Embeddings 并拼接\n",
1158
+ " inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
1159
+ " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
1160
+ "\n",
1161
+ " # ✅ 4. 调整 `attention_mask`\n",
1162
+ " if \"attention_mask\" in inputs:\n",
1163
+ " graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
1164
+ " attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
1165
+ " else:\n",
1166
+ " attention_mask = None\n",
1167
+ "\n",
1168
+ " # ✅ 5. 进行文本生成\n",
1169
+ " with torch.no_grad():\n",
1170
+ " output_ids = self.model.generate(\n",
1171
+ " inputs_embeds=inputs_embeds,\n",
1172
+ " attention_mask=attention_mask,\n",
1173
+ " max_length=max_length,\n",
1174
+ " temperature=temperature,\n",
1175
+ " top_k=top_k,\n",
1176
+ " top_p=top_p,\n",
1177
+ " num_return_sequences=1\n",
1178
+ " )\n",
1179
+ "\n",
1180
+ " # ✅ 6. 解码输出\n",
1181
+ " generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
1182
+ " return generated_text\n",
1183
+ "\n",
1184
+ " @classmethod\n",
1185
+ " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
1186
+ " # ✅ 1. 调用 `super().from_pretrained()` 加载 LLM\n",
1187
+ " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
1188
+ "\n",
1189
+ " # ✅ 2. 初始化 `MLP + MultiheadAttention` 结构\n",
1190
+ " model.linear1 = nn.Linear(512, model.config.hidden_size)\n",
1191
+ " model.multihead_attn = nn.MultiheadAttention(embed_dim=model.config.hidden_size, num_heads=8, batch_first=True)\n",
1192
+ " model.linear2 = nn.Linear(model.config.hidden_size, model.config.hidden_size)\n",
1193
+ " model.norm = nn.LayerNorm(model.config.hidden_size)\n",
1194
+ "\n",
1195
+ " return model"
1196
+ ]
1197
+ },
1198
+ {
1199
+ "cell_type": "code",
1200
+ "execution_count": 2,
1201
+ "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
1202
+ "metadata": {},
1203
+ "outputs": [],
1204
+ "source": [
1205
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
1206
+ ]
1207
+ },
1208
+ {
1209
+ "cell_type": "code",
1210
+ "execution_count": 7,
1211
+ "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
1212
+ "metadata": {},
1213
+ "outputs": [
1214
+ {
1215
+ "data": {
1216
+ "application/vnd.jupyter.widget-view+json": {
1217
+ "model_id": "0b50f0cd6d784f598cc64a40cff40f38",
1218
+ "version_major": 2,
1219
+ "version_minor": 0
1220
+ },
1221
+ "text/plain": [
1222
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
1223
+ ]
1224
+ },
1225
+ "metadata": {},
1226
+ "output_type": "display_data"
1227
+ },
1228
+ {
1229
+ "data": {
1230
+ "text/plain": [
1231
+ "Qwen2ForCausalLM(\n",
1232
+ " (model): Qwen2Model(\n",
1233
+ " (embed_tokens): Embedding(151936, 1536)\n",
1234
+ " (layers): ModuleList(\n",
1235
+ " (0-27): 28 x Qwen2DecoderLayer(\n",
1236
+ " (self_attn): Qwen2Attention(\n",
1237
+ " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
1238
+ " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1239
+ " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
1240
+ " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
1241
+ " )\n",
1242
+ " (mlp): Qwen2MLP(\n",
1243
+ " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1244
+ " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
1245
+ " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
1246
+ " (act_fn): SiLU()\n",
1247
+ " )\n",
1248
+ " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1249
+ " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1250
+ " )\n",
1251
+ " )\n",
1252
+ " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
1253
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
1254
+ " )\n",
1255
+ " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
1256
+ " (linear1): Linear(in_features=512, out_features=1536, bias=True)\n",
1257
+ " (multihead_attn): MultiheadAttention(\n",
1258
+ " (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)\n",
1259
+ " )\n",
1260
+ " (linear2): Linear(in_features=1536, out_features=1536, bias=True)\n",
1261
+ " (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n",
1262
+ ")"
1263
+ ]
1264
+ },
1265
+ "execution_count": 7,
1266
+ "metadata": {},
1267
+ "output_type": "execute_result"
1268
+ }
1269
+ ],
1270
+ "source": [
1271
+ "import torch\n",
1272
+ "from transformers import AutoTokenizer\n",
1273
+ "\n",
1274
+ "# Load the tokenizer\n",
1275
+ "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
1276
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
1277
+ "\n",
1278
+ "# Load the trained model\n",
1279
+ "model_path = \"/workspace/model2\"\n",
1280
+ "model = GraphAwareLM.from_pretrained(\"/workspace/results3/checkpoint-3000\").to(device)\n",
1281
+ "model.eval() # switch to inference (eval) mode\n"
1282
+ ]
1283
+ },
1284
+ {
1285
+ "cell_type": "code",
1286
+ "execution_count": 13,
1287
+ "id": "51995891-8906-4049-9401-2d22e06a84e8",
1288
+ "metadata": {},
1289
+ "outputs": [
1290
+ {
1291
+ "name": "stdout",
1292
+ "output_type": "stream",
1293
+ "text": [
1294
+ "Parameter containing:\n",
1295
+ "tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
1296
+ " [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
1297
+ " [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
1298
+ " ...,\n",
1299
+ " [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
1300
+ " [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
1301
+ " [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
1302
+ " device='cuda:0', requires_grad=True)\n"
1303
+ ]
1304
+ }
1305
+ ],
1306
+ "source": [
1307
+ "print(model.graph_proj.weight)\n"
1308
+ ]
1309
+ },
1310
+ {
1311
+ "cell_type": "code",
1312
+ "execution_count": 4,
1313
+ "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
1314
+ "metadata": {},
1315
+ "outputs": [],
1316
+ "source": [
1317
+ "import json\n",
1318
+ "json_path = \"final_Graph.json\"\n",
1319
+ "with open(json_path, \"r\") as f:\n",
1320
+ " data = json.load(f)\n",
1321
+ "\n",
1322
+ "test_data = data[0]\n",
1323
+ "\n",
1324
+ "conversations = test_data.get(\"conversations\")\n",
1325
+ "embeddings = test_data.get(\"embedding\") \n",
1326
+ "\n",
1327
+ "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
1328
+ "\n",
1329
+ "question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
1330
+ "\n",
1331
+ "from transformers import AutoTokenizer\n",
1332
+ "\n",
1333
+ "# ✅ Input text\n",
1334
+ "ROLE_TOKENS = {\n",
1335
+ " \"human\": \"<|User|>\", \n",
1336
+ " \"gpt\": \"<|Assistant|>\", \n",
1337
+ "}\n",
1338
+ "GRAPH_LENGTH = 512\n",
1339
+ "max_seq_length = 1100 + GRAPH_LENGTH\n",
1340
+ "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
1341
+ "\n",
1342
+ "input_ids = inputs[\"input_ids\"]\n",
1343
+ "attention_mask = inputs[\"attention_mask\"]\n"
1344
+ ]
1345
+ },
1346
+ {
1347
+ "cell_type": "code",
1348
+ "execution_count": 8,
1349
+ "id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
1350
+ "metadata": {},
1351
+ "outputs": [
1352
+ {
1353
+ "ename": "AttributeError",
1354
+ "evalue": "'Tensor' object has no attribute 'update'",
1355
+ "output_type": "error",
1356
+ "traceback": [
1357
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1358
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1359
+ "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m)\u001b[49m\n",
1360
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
1361
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1982\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1979\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1980\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n\u001b[0;32m-> 1982\u001b[0m generation_config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_generation_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1983\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_model_kwargs(model_kwargs\u001b[38;5;241m.\u001b[39mcopy())\n\u001b[1;32m 1984\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_assistant(assistant_model, tokenizer, assistant_tokenizer)\n",
1362
+ "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1549\u001b[0m, in \u001b[0;36mGenerationMixin._prepare_generation_config\u001b[0;34m(self, generation_config, **kwargs)\u001b[0m\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 1548\u001b[0m generation_config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(generation_config)\n\u001b[0;32m-> 1549\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1550\u001b[0m \u001b[38;5;66;03m# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model\u001b[39;00m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m using_model_generation_config:\n",
1363
+ "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'update'"
1364
+ ]
1365
+ }
1366
+ ],
1367
+ "source": [
1368
+ "generated_text = model.generate(inputs, graph_embedding)"
1369
+ ]
1370
+ },
1371
+ {
1372
+ "cell_type": "code",
1373
+ "execution_count": 5,
1374
+ "id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
1375
+ "metadata": {},
1376
+ "outputs": [
1377
+ {
1378
+ "data": {
1379
+ "text/plain": [
1380
+ "tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
1381
+ " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
1382
+ " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
1383
+ " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
1384
+ " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
1385
+ " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
1386
+ " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
1387
+ " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
1388
+ " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
1389
+ " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
1390
+ " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
1391
+ " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
1392
+ " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
1393
+ " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
1394
+ " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
1395
+ " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
1396
+ " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
1397
+ " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
1398
+ " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
1399
+ " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
1400
+ " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
1401
+ " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
1402
+ " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
1403
+ " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
1404
+ " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
1405
+ " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
1406
+ " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
1407
+ " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
1408
+ " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
1409
+ " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
1410
+ " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
1411
+ " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
1412
+ " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
1413
+ " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
1414
+ " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
1415
+ " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
1416
+ " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
1417
+ " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
1418
+ " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
1419
+ " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
1420
+ " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
1421
+ " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
1422
+ " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
1423
+ " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
1424
+ " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
1425
+ " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
1426
+ " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
1427
+ " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
1428
+ " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
1429
+ " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
1430
+ " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
1431
+ " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
1432
+ " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
1433
+ " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
1434
+ " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
1435
+ " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
1436
+ " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
1437
+ " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
1438
+ " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
1439
+ " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
1440
+ " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
1441
+ " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
1442
+ " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
1443
+ " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
1444
+ " device='cuda:0')"
1445
+ ]
1446
+ },
1447
+ "execution_count": 5,
1448
+ "metadata": {},
1449
+ "output_type": "execute_result"
1450
+ }
1451
+ ],
1452
+ "source": [
1453
+ "graph_embedding"
1454
+ ]
1455
+ },
1456
+ {
1457
+ "cell_type": "code",
1458
+ "execution_count": 15,
1459
+ "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
1460
+ "metadata": {},
1461
+ "outputs": [
1462
+ {
1463
+ "name": "stdout",
1464
+ "output_type": "stream",
1465
+ "text": [
1466
+ "模型回复: How\n"
1467
+ ]
1468
+ }
1469
+ ],
1470
+ "source": [
1471
+ "# ✅ Run the forward pass\n",
1472
+ "with torch.no_grad():\n",
1473
+ " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
1474
+ "\n",
1475
+ "# ✅ Extract the logits and decode greedily\n",
1476
+ "logits = outputs.logits[:, -1, :] # take the logits of the last token\n",
1477
+ "predicted_id = torch.argmax(logits, dim=-1) # pick the highest-probability token\n",
1478
+ "\n",
1479
+ "# ✅ Decode the token id back into text\n",
1480
+ "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
1481
+ "\n",
1482
+ "print(\"模型回复:\", response_text)"
1483
+ ]
1484
+ },
1485
+ {
1486
+ "cell_type": "code",
1487
+ "execution_count": 9,
1488
+ "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
1489
+ "metadata": {},
1490
+ "outputs": [
1491
+ {
1492
+ "name": "stdout",
1493
+ "output_type": "stream",
1494
+ "text": [
1495
+ "Generated Response: What are the signal definitions in the Verilog code for the calculator module, and what are their purposes? The Verilog code defines the inputs A, B, and C, and the output Y. A and B are the operands, C is the carry-in, and Y is the result. The purpose of the module is to perform a 2-bit adder, which adds two 2-bit numbers, and the output is the sum. The inputs A and B are the operands, C is the carry-in, and Y is the result. The module is designed to handle the addition operation of two 2-bit numbers, with a carry-in, and a 3-bit output. The implementation involves using logic gates to perform the addition operation, with the sum output connected to the gates. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, involving basic gates and an adder circuit. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. 
The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with\n"
1496
+ ]
1497
+ }
1498
+ ],
1499
+ "source": [
1500
+ "max_new_tokens = 500\n",
1501
+ "generated_ids = input_ids.clone()\n",
1502
+ "generated_attention_mask = attention_mask.clone()\n",
1503
+ "for _ in range(max_new_tokens):\n",
1504
+ " # ✅ Compute the logits and generate\n",
1505
+ " with torch.no_grad():\n",
1506
+ " outputs = model(\n",
1507
+ " input_ids=generated_ids, # (batch_size, seq_len)\n",
1508
+ " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
1509
+ " graph_embedding=graph_embedding, # (batch_size, 512)\n",
1510
+ " )\n",
1511
+ "\n",
1512
+ "\n",
1513
+ " logits = outputs.logits[:, -1, :] # take the logits of the last token\n",
1514
+ " next_token = torch.argmax(logits, dim=-1) # greedy decoding\n",
1515
+ " # print(next_token)\n",
1516
+ "\n",
1517
+ "\n",
1518
+ " # ✅ **Append to the generated sequence**\n",
1519
+ " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
1520
+ "\n",
1521
+ " # print(generated_ids)\n",
1522
+ "\n",
1523
+ " if next_token.item() == tokenizer.eos_token_id:\n",
1524
+ " break\n",
1525
+ "\n",
1526
+ " generated_attention_mask = torch.cat(\n",
1527
+ " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
1528
+ " ) \n",
1529
+ "\n",
1530
+ "# ✅ Decode the final output\n",
1531
+ "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1532
+ "print(\"Generated Response:\", generated_text)"
1533
+ ]
1534
+ },
1535
+ {
1536
+ "cell_type": "code",
1537
+ "execution_count": 10,
1538
+ "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
1539
+ "metadata": {},
1540
+ "outputs": [
1541
+ {
1542
+ "data": {
1543
+ "text/plain": [
1544
+ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
1545
+ " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
1546
+ " 525, 862, 9895, 30]], device='cuda:0')"
1547
+ ]
1548
+ },
1549
+ "execution_count": 10,
1550
+ "metadata": {},
1551
+ "output_type": "execute_result"
1552
+ }
1553
+ ],
1554
+ "source": [
1555
+ "generated_ids"
1556
+ ]
1557
+ },
1558
+ {
1559
+ "cell_type": "code",
1560
+ "execution_count": null,
1561
+ "id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
1562
+ "metadata": {},
1563
+ "outputs": [],
1564
+ "source": []
1565
+ }
1566
+ ],
1567
+ "metadata": {
1568
+ "kernelspec": {
1569
+ "display_name": "Python 3 (ipykernel)",
1570
+ "language": "python",
1571
+ "name": "python3"
1572
+ },
1573
+ "language_info": {
1574
+ "codemirror_mode": {
1575
+ "name": "ipython",
1576
+ "version": 3
1577
+ },
1578
+ "file_extension": ".py",
1579
+ "mimetype": "text/x-python",
1580
+ "name": "python",
1581
+ "nbconvert_exporter": "python",
1582
+ "pygments_lexer": "ipython3",
1583
+ "version": "3.10.12"
1584
+ }
1585
+ },
1586
+ "nbformat": 4,
1587
+ "nbformat_minor": 5
1588
+ }
train_data.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c09c4c4be57cf268061c0f20f6e6d877359dd683fbff17f4d3a6b7cffee3dae
3
+ size 364711686