upload code0304
Browse files- .gitattributes +1 -0
- .ipynb_checkpoints/eval-checkpoint.ipynb +395 -0
- .ipynb_checkpoints/graph_train-checkpoint.ipynb +1591 -0
- .ipynb_checkpoints/graph_train2-checkpoint.ipynb +1674 -0
- .ipynb_checkpoints/graph_train3-checkpoint.ipynb +1588 -0
- eval.ipynb +406 -0
- final_Graph.json +3 -0
- graph_train.ipynb +1591 -0
- graph_train2.ipynb +1506 -0
- graph_train3.ipynb +1588 -0
- train_data.pt +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
final_Graph.json filter=lfs diff=lfs merge=lfs -text
|
.ipynb_checkpoints/eval-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 10 |
+
"import torch\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"MODEL_NAME = \"/workspace/model\"\n",
|
| 15 |
+
"model_token = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": 2,
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"import json\n",
|
| 25 |
+
"import torch\n",
|
| 26 |
+
"from transformers import AutoTokenizer\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
|
| 29 |
+
"tokenizer.pad_token = tokenizer.eos_token "
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 3,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"json_path = \"final_Graph.json\"\n",
|
| 39 |
+
"with open(json_path, \"r\") as f:\n",
|
| 40 |
+
" data = json.load(f)\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"test_data = data[0]\n"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 4,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"source": [
|
| 51 |
+
"ROLE_TOKENS = {\n",
|
| 52 |
+
" \"human\": \"<|User|>\", \n",
|
| 53 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 54 |
+
"}\n",
|
| 55 |
+
"GRAPH_LENGTH = 512\n",
|
| 56 |
+
"max_seq_length = 1100 + GRAPH_LENGTH"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": 5,
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 66 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 67 |
+
"\n",
|
| 68 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32)"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"cell_type": "code",
|
| 73 |
+
"execution_count": 6,
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [
|
| 76 |
+
{
|
| 77 |
+
"data": {
|
| 78 |
+
"text/plain": [
|
| 79 |
+
"'What are the signal definitions in the Verilog code for the calculator module, and what are their purposes?'"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
"execution_count": 6,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"output_type": "execute_result"
|
| 85 |
+
}
|
| 86 |
+
],
|
| 87 |
+
"source": [
|
| 88 |
+
"question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 89 |
+
"question1"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": 7,
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"import json\n",
|
| 99 |
+
"import torch\n",
|
| 100 |
+
"import os\n",
|
| 101 |
+
"from transformers import AutoTokenizer\n",
|
| 102 |
+
"# tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
| 103 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 104 |
+
"from torch.utils.data import Dataset\n",
|
| 105 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 106 |
+
"import torch\n",
|
| 107 |
+
"import torch.nn as nn\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 110 |
+
" def __init__(self, config):\n",
|
| 111 |
+
" super().__init__(config)\n",
|
| 112 |
+
" self.model = AutoModelForCausalLM.from_config(config)\n",
|
| 113 |
+
" \n",
|
| 114 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 115 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 116 |
+
"\n",
|
| 117 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 118 |
+
" \"\"\"\n",
|
| 119 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 120 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 121 |
+
" \"\"\"\n",
|
| 122 |
+
" # ✅ 获取 token embedding\n",
|
| 123 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 124 |
+
"\n",
|
| 125 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 126 |
+
" graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 129 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 130 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 131 |
+
"\n",
|
| 132 |
+
" # ✅ 调整 attention mask\n",
|
| 133 |
+
" if attention_mask is not None:\n",
|
| 134 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 135 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 136 |
+
"\n",
|
| 137 |
+
" # ✅ 传入模型\n",
|
| 138 |
+
" outputs = self.model(\n",
|
| 139 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 140 |
+
" attention_mask=attention_mask,\n",
|
| 141 |
+
" labels=labels,\n",
|
| 142 |
+
" )\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" return outputs\n",
|
| 145 |
+
"\n",
|
| 146 |
+
" @classmethod\n",
|
| 147 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 148 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 149 |
+
" model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
|
| 150 |
+
" return model\n",
|
| 151 |
+
"\n"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": 8,
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [
|
| 159 |
+
{
|
| 160 |
+
"name": "stderr",
|
| 161 |
+
"output_type": "stream",
|
| 162 |
+
"text": [
|
| 163 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
|
| 164 |
+
]
|
| 165 |
+
}
|
| 166 |
+
],
|
| 167 |
+
"source": [
|
| 168 |
+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 169 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME).to(device)"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "code",
|
| 174 |
+
"execution_count": 9,
|
| 175 |
+
"metadata": {},
|
| 176 |
+
"outputs": [
|
| 177 |
+
{
|
| 178 |
+
"data": {
|
| 179 |
+
"text/plain": [
|
| 180 |
+
"tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 181 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 182 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 183 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 184 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 185 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 186 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 187 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 188 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 189 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 190 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 191 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 192 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 193 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 194 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 195 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 196 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 197 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 198 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 199 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 200 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 201 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 202 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 203 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 204 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 205 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 206 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 207 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 208 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 209 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 210 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 211 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 212 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 213 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 214 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 215 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 216 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 217 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 218 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 219 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 220 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 221 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 222 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 223 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 224 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 225 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 226 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 227 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 228 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 229 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 230 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 231 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 232 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 233 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 234 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 235 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 236 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 237 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 238 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 239 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 240 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 241 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 242 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 243 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
|
| 244 |
+
" device='cuda:0')"
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
"execution_count": 9,
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"output_type": "execute_result"
|
| 250 |
+
}
|
| 251 |
+
],
|
| 252 |
+
"source": [
|
| 253 |
+
"from transformers import AutoTokenizer\n",
|
| 254 |
+
"\n",
|
| 255 |
+
"# ✅ 加载分词器\n",
|
| 256 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"# ✅ 输入文本\n",
|
| 259 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"graph_embedding.to(device)\n",
|
| 262 |
+
"\n"
|
| 263 |
+
]
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"cell_type": "code",
|
| 267 |
+
"execution_count": 10,
|
| 268 |
+
"metadata": {},
|
| 269 |
+
"outputs": [
|
| 270 |
+
{
|
| 271 |
+
"ename": "RuntimeError",
|
| 272 |
+
"evalue": "The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3",
|
| 273 |
+
"output_type": "error",
|
| 274 |
+
"traceback": [
|
| 275 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 276 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 277 |
+
"Cell \u001b[0;32mIn[10], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_new_tokens):\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# ✅ 计算 logits 并进行生成\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 6\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgenerated_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, 512)\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :] \u001b[38;5;66;03m# 取最后一个 token 的 logits\u001b[39;00m\n\u001b[1;32m 14\u001b[0m next_token \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;66;03m# 贪心解码\u001b[39;00m\n",
|
| 278 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 279 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 280 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[1;32m 170\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 281 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:856\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 853\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 855\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 856\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 857\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 858\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 859\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 860\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 861\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 862\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 863\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 864\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 865\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 866\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 867\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 870\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 871\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
|
| 282 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 283 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 284 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:579\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 567\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 568\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 569\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 576\u001b[0m position_embeddings,\n\u001b[1;32m 577\u001b[0m )\n\u001b[1;32m 578\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 580\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 584\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 585\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 586\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 587\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 588\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
|
| 285 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 286 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 287 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:260\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 257\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 260\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 271\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 273\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
|
| 288 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 289 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 290 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:192\u001b[0m, in \u001b[0;36mQwen2Attention.forward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 190\u001b[0m attention_interface \u001b[38;5;241m=\u001b[39m ALL_ATTENTION_FUNCTIONS[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_attn_implementation]\n\u001b[0;32m--> 192\u001b[0m attn_output, attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattention_interface\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mdropout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mscaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43msliding_window\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msliding_window\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# main diff with Llama\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_output\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m 205\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(attn_output)\n",
|
| 291 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:123\u001b[0m, in \u001b[0;36meager_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, scaling, dropout, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m causal_mask \u001b[38;5;241m=\u001b[39m attention_mask[:, :, :, : key_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m]]\n\u001b[0;32m--> 123\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcausal_mask\u001b[49m\n\u001b[1;32m 125\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(attn_weights, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 126\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(attn_weights, p\u001b[38;5;241m=\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39mmodule\u001b[38;5;241m.\u001b[39mtraining)\n",
|
| 292 |
+
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3"
|
| 293 |
+
]
|
| 294 |
+
}
|
| 295 |
+
],
|
| 296 |
+
"source": [
|
| 297 |
+
"\n",
|
| 298 |
+
"generated_ids = inputs[\"input_ids\"]\n",
|
| 299 |
+
"max_new_tokens = 1024\n",
|
| 300 |
+
"for _ in range(max_new_tokens):\n",
|
| 301 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 302 |
+
" with torch.no_grad():\n",
|
| 303 |
+
" outputs = model(\n",
|
| 304 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 305 |
+
" attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
|
| 306 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 307 |
+
" )\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"\n",
|
| 310 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 311 |
+
" next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 315 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 316 |
+
"\n",
|
| 317 |
+
" if next_token[:, 0] == tokenizer.eos_token_id:\n",
|
| 318 |
+
" break\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"# ✅ 解码最终输出\n",
|
| 321 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 322 |
+
"print(\"Generated Response:\", generated_text)"
|
| 323 |
+
]
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"cell_type": "code",
|
| 327 |
+
"execution_count": null,
|
| 328 |
+
"metadata": {},
|
| 329 |
+
"outputs": [
|
| 330 |
+
{
|
| 331 |
+
"name": "stdout",
|
| 332 |
+
"output_type": "stream",
|
| 333 |
+
"text": [
|
| 334 |
+
"Generated Response: How does the code handle combinational logic? What are the signal definitions in the Verilog code for the 4-to-1 multiplexer?\n",
|
| 335 |
+
"The code uses assign statements to handle combinational logic. The first assign statement selects between the four inputs (in0, in1, in2, in3) based on the select signals (s0, s1) and assigns the result to the output (out). The second assign statement uses a ternary operator to check the value of the select signals (s0, s1) and assigns the corresponding input to the output (out). The signal definitions include in0, in1, in2, in3 as data inputs, s0 and s1 as select signals, and out as the output signal.\n",
|
| 336 |
+
"How does the code handle sequential logic? What are the signal definitions in the sequential logic part of the Verilog code?\n",
|
| 337 |
+
"The sequential logic part of the code uses an always block with a sensitivity list that includes posedge clk, indicating that it is a sequential logic block. The output (out) is updated on the rising edge of the clock signal (clk). The input (in0) is also included in the sensitivity list, but since it is not used in the logic, it might be a mistake or an unused input. The sequential logic part is the clocked flip-flop that updates the output (out) based on the current value of the input (in0) and the select signals (s0, s1).\n",
|
| 338 |
+
"What is the function of the circuit described in the Verilog code?\n",
|
| 339 |
+
"The circuit is a 4-to-1 multiplexer with a registered output. It selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop on the rising edge of the clock signal (clk). The output (out) is the value of the selected input stored in the flip-flop.\n",
|
| 340 |
+
"How can the circuit be implemented in hardware?\n",
|
| 341 |
+
"The circuit can be implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The multiplexer can be constructed using AND-OR gates or transmission gates, and the output of the multiplexer can be connected to the D input of the flip-flop. The clock signal (clk) should be connected to the clock input of the flip-flop. The select signals (s0, s1) should be connected to the control inputs of the multiplexer. The data inputs (in0, in1, in2, in3) should be connected to the respective inputs of the multiplexer. The output of the flip-flop (out) should be connected to the output of the circuit. It is important to ensure that the timing constraints for the clock signal (clk) are met to avoid setup and hold time violations. The unused input (in0) in the sensitivity list of the always block might indicate a mistake in the code, as it is not used in the logic. However, it could be a typo or an oversight in the code. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop. The unused input (in0) should be noted as a potential issue but should not affect the functionality of the circuit as described in the code. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit\n"
|
| 342 |
+
]
|
| 343 |
+
}
|
| 344 |
+
],
|
| 345 |
+
"source": [
|
| 346 |
+
"generated_ids = inputs[\"input_ids\"]\n",
|
| 347 |
+
"max_new_tokens = 1024\n",
|
| 348 |
+
"for _ in range(max_new_tokens):\n",
|
| 349 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 350 |
+
" with torch.no_grad():\n",
|
| 351 |
+
" outputs = model(\n",
|
| 352 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 353 |
+
" attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
|
| 354 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 355 |
+
" )\n",
|
| 356 |
+
"\n",
|
| 357 |
+
"\n",
|
| 358 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 359 |
+
" next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"\n",
|
| 362 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 363 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 364 |
+
"\n",
|
| 365 |
+
" if next_token[:, 0] == tokenizer.eos_token_id:\n",
|
| 366 |
+
" break\n",
|
| 367 |
+
"\n",
|
| 368 |
+
"# ✅ 解码最终输出\n",
|
| 369 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 370 |
+
"print(\"Generated Response:\", generated_text)"
|
| 371 |
+
]
|
| 372 |
+
}
|
| 373 |
+
],
|
| 374 |
+
"metadata": {
|
| 375 |
+
"kernelspec": {
|
| 376 |
+
"display_name": "Python 3 (ipykernel)",
|
| 377 |
+
"language": "python",
|
| 378 |
+
"name": "python3"
|
| 379 |
+
},
|
| 380 |
+
"language_info": {
|
| 381 |
+
"codemirror_mode": {
|
| 382 |
+
"name": "ipython",
|
| 383 |
+
"version": 3
|
| 384 |
+
},
|
| 385 |
+
"file_extension": ".py",
|
| 386 |
+
"mimetype": "text/x-python",
|
| 387 |
+
"name": "python",
|
| 388 |
+
"nbconvert_exporter": "python",
|
| 389 |
+
"pygments_lexer": "ipython3",
|
| 390 |
+
"version": "3.10.12"
|
| 391 |
+
}
|
| 392 |
+
},
|
| 393 |
+
"nbformat": 4,
|
| 394 |
+
"nbformat_minor": 4
|
| 395 |
+
}
|
.ipynb_checkpoints/graph_train-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,1591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"ROLE_TOKENS = {\n",
|
| 11 |
+
" \"human\": \"<|User|>\", \n",
|
| 12 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 13 |
+
"}\n",
|
| 14 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
|
| 15 |
+
"GRAPH_LENGTH = 512\n",
|
| 16 |
+
"HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new\""
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 2,
|
| 22 |
+
"id": "bba6e6db-4b79-4461-ba13-75fd41019358",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"CUDA 可用: True\n",
|
| 30 |
+
"GPU 数量: 1\n",
|
| 31 |
+
"当前 GPU: 0\n",
|
| 32 |
+
"GPU 名称: NVIDIA A100 80GB PCIe\n"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"source": [
|
| 37 |
+
"# !pip install transformers accelerate datasets\n",
|
| 38 |
+
"# !pip install galora\n",
|
| 39 |
+
"# !pip install huggingface_hub\n",
|
| 40 |
+
"import torch\n",
|
| 41 |
+
"print(\"CUDA 可用:\", torch.cuda.is_available())\n",
|
| 42 |
+
"print(\"GPU 数量:\", torch.cuda.device_count())\n",
|
| 43 |
+
"print(\"当前 GPU:\", torch.cuda.current_device())\n",
|
| 44 |
+
"print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 3,
|
| 50 |
+
"id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
+
"id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 75 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 76 |
+
" warnings.warn(\n",
|
| 77 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 78 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"text/html": [
|
| 84 |
+
"Tracking run with wandb version 0.19.7"
|
| 85 |
+
],
|
| 86 |
+
"text/plain": [
|
| 87 |
+
"<IPython.core.display.HTML object>"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"output_type": "display_data"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"data": {
|
| 95 |
+
"text/html": [
|
| 96 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_081255-v0v96nik</code>"
|
| 97 |
+
],
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"<IPython.core.display.HTML object>"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"output_type": "display_data"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"data": {
|
| 107 |
+
"text/html": [
|
| 108 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 109 |
+
],
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"<IPython.core.display.HTML object>"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"output_type": "display_data"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"data": {
|
| 119 |
+
"text/html": [
|
| 120 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 121 |
+
],
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"<IPython.core.display.HTML object>"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"data": {
|
| 131 |
+
"text/html": [
|
| 132 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik</a>"
|
| 133 |
+
],
|
| 134 |
+
"text/plain": [
|
| 135 |
+
"<IPython.core.display.HTML object>"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"output_type": "display_data"
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"data": {
|
| 143 |
+
"text/html": [
|
| 144 |
+
"\n",
|
| 145 |
+
" <div>\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 148 |
+
" [5310/5310 1:23:11, Epoch 3/3]\n",
|
| 149 |
+
" </div>\n",
|
| 150 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 151 |
+
" <thead>\n",
|
| 152 |
+
" <tr style=\"text-align: left;\">\n",
|
| 153 |
+
" <th>Step</th>\n",
|
| 154 |
+
" <th>Training Loss</th>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" </thead>\n",
|
| 157 |
+
" <tbody>\n",
|
| 158 |
+
" <tr>\n",
|
| 159 |
+
" <td>50</td>\n",
|
| 160 |
+
" <td>5.349900</td>\n",
|
| 161 |
+
" </tr>\n",
|
| 162 |
+
" <tr>\n",
|
| 163 |
+
" <td>100</td>\n",
|
| 164 |
+
" <td>5.305900</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <td>150</td>\n",
|
| 168 |
+
" <td>4.849500</td>\n",
|
| 169 |
+
" </tr>\n",
|
| 170 |
+
" <tr>\n",
|
| 171 |
+
" <td>200</td>\n",
|
| 172 |
+
" <td>3.910800</td>\n",
|
| 173 |
+
" </tr>\n",
|
| 174 |
+
" <tr>\n",
|
| 175 |
+
" <td>250</td>\n",
|
| 176 |
+
" <td>3.325600</td>\n",
|
| 177 |
+
" </tr>\n",
|
| 178 |
+
" <tr>\n",
|
| 179 |
+
" <td>300</td>\n",
|
| 180 |
+
" <td>3.144900</td>\n",
|
| 181 |
+
" </tr>\n",
|
| 182 |
+
" <tr>\n",
|
| 183 |
+
" <td>350</td>\n",
|
| 184 |
+
" <td>2.904200</td>\n",
|
| 185 |
+
" </tr>\n",
|
| 186 |
+
" <tr>\n",
|
| 187 |
+
" <td>400</td>\n",
|
| 188 |
+
" <td>2.082100</td>\n",
|
| 189 |
+
" </tr>\n",
|
| 190 |
+
" <tr>\n",
|
| 191 |
+
" <td>450</td>\n",
|
| 192 |
+
" <td>1.214300</td>\n",
|
| 193 |
+
" </tr>\n",
|
| 194 |
+
" <tr>\n",
|
| 195 |
+
" <td>500</td>\n",
|
| 196 |
+
" <td>1.011600</td>\n",
|
| 197 |
+
" </tr>\n",
|
| 198 |
+
" <tr>\n",
|
| 199 |
+
" <td>550</td>\n",
|
| 200 |
+
" <td>0.889300</td>\n",
|
| 201 |
+
" </tr>\n",
|
| 202 |
+
" <tr>\n",
|
| 203 |
+
" <td>600</td>\n",
|
| 204 |
+
" <td>0.907300</td>\n",
|
| 205 |
+
" </tr>\n",
|
| 206 |
+
" <tr>\n",
|
| 207 |
+
" <td>650</td>\n",
|
| 208 |
+
" <td>1.190400</td>\n",
|
| 209 |
+
" </tr>\n",
|
| 210 |
+
" <tr>\n",
|
| 211 |
+
" <td>700</td>\n",
|
| 212 |
+
" <td>1.889100</td>\n",
|
| 213 |
+
" </tr>\n",
|
| 214 |
+
" <tr>\n",
|
| 215 |
+
" <td>750</td>\n",
|
| 216 |
+
" <td>4.505600</td>\n",
|
| 217 |
+
" </tr>\n",
|
| 218 |
+
" <tr>\n",
|
| 219 |
+
" <td>800</td>\n",
|
| 220 |
+
" <td>6.402800</td>\n",
|
| 221 |
+
" </tr>\n",
|
| 222 |
+
" <tr>\n",
|
| 223 |
+
" <td>850</td>\n",
|
| 224 |
+
" <td>6.479300</td>\n",
|
| 225 |
+
" </tr>\n",
|
| 226 |
+
" <tr>\n",
|
| 227 |
+
" <td>900</td>\n",
|
| 228 |
+
" <td>7.337900</td>\n",
|
| 229 |
+
" </tr>\n",
|
| 230 |
+
" <tr>\n",
|
| 231 |
+
" <td>950</td>\n",
|
| 232 |
+
" <td>8.937600</td>\n",
|
| 233 |
+
" </tr>\n",
|
| 234 |
+
" <tr>\n",
|
| 235 |
+
" <td>1000</td>\n",
|
| 236 |
+
" <td>8.938700</td>\n",
|
| 237 |
+
" </tr>\n",
|
| 238 |
+
" <tr>\n",
|
| 239 |
+
" <td>1050</td>\n",
|
| 240 |
+
" <td>8.860100</td>\n",
|
| 241 |
+
" </tr>\n",
|
| 242 |
+
" <tr>\n",
|
| 243 |
+
" <td>1100</td>\n",
|
| 244 |
+
" <td>8.693600</td>\n",
|
| 245 |
+
" </tr>\n",
|
| 246 |
+
" <tr>\n",
|
| 247 |
+
" <td>1150</td>\n",
|
| 248 |
+
" <td>9.234000</td>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" <tr>\n",
|
| 251 |
+
" <td>1200</td>\n",
|
| 252 |
+
" <td>9.347500</td>\n",
|
| 253 |
+
" </tr>\n",
|
| 254 |
+
" <tr>\n",
|
| 255 |
+
" <td>1250</td>\n",
|
| 256 |
+
" <td>8.010300</td>\n",
|
| 257 |
+
" </tr>\n",
|
| 258 |
+
" <tr>\n",
|
| 259 |
+
" <td>1300</td>\n",
|
| 260 |
+
" <td>5.952900</td>\n",
|
| 261 |
+
" </tr>\n",
|
| 262 |
+
" <tr>\n",
|
| 263 |
+
" <td>1350</td>\n",
|
| 264 |
+
" <td>5.205900</td>\n",
|
| 265 |
+
" </tr>\n",
|
| 266 |
+
" <tr>\n",
|
| 267 |
+
" <td>1400</td>\n",
|
| 268 |
+
" <td>4.969600</td>\n",
|
| 269 |
+
" </tr>\n",
|
| 270 |
+
" <tr>\n",
|
| 271 |
+
" <td>1450</td>\n",
|
| 272 |
+
" <td>4.884600</td>\n",
|
| 273 |
+
" </tr>\n",
|
| 274 |
+
" <tr>\n",
|
| 275 |
+
" <td>1500</td>\n",
|
| 276 |
+
" <td>4.934200</td>\n",
|
| 277 |
+
" </tr>\n",
|
| 278 |
+
" <tr>\n",
|
| 279 |
+
" <td>1550</td>\n",
|
| 280 |
+
" <td>5.156900</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <td>1600</td>\n",
|
| 284 |
+
" <td>5.115500</td>\n",
|
| 285 |
+
" </tr>\n",
|
| 286 |
+
" <tr>\n",
|
| 287 |
+
" <td>1650</td>\n",
|
| 288 |
+
" <td>5.373600</td>\n",
|
| 289 |
+
" </tr>\n",
|
| 290 |
+
" <tr>\n",
|
| 291 |
+
" <td>1700</td>\n",
|
| 292 |
+
" <td>4.481800</td>\n",
|
| 293 |
+
" </tr>\n",
|
| 294 |
+
" <tr>\n",
|
| 295 |
+
" <td>1750</td>\n",
|
| 296 |
+
" <td>3.957000</td>\n",
|
| 297 |
+
" </tr>\n",
|
| 298 |
+
" <tr>\n",
|
| 299 |
+
" <td>1800</td>\n",
|
| 300 |
+
" <td>3.092500</td>\n",
|
| 301 |
+
" </tr>\n",
|
| 302 |
+
" <tr>\n",
|
| 303 |
+
" <td>1850</td>\n",
|
| 304 |
+
" <td>1.791000</td>\n",
|
| 305 |
+
" </tr>\n",
|
| 306 |
+
" <tr>\n",
|
| 307 |
+
" <td>1900</td>\n",
|
| 308 |
+
" <td>1.934400</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <td>1950</td>\n",
|
| 312 |
+
" <td>2.176800</td>\n",
|
| 313 |
+
" </tr>\n",
|
| 314 |
+
" <tr>\n",
|
| 315 |
+
" <td>2000</td>\n",
|
| 316 |
+
" <td>2.112400</td>\n",
|
| 317 |
+
" </tr>\n",
|
| 318 |
+
" <tr>\n",
|
| 319 |
+
" <td>2050</td>\n",
|
| 320 |
+
" <td>2.127900</td>\n",
|
| 321 |
+
" </tr>\n",
|
| 322 |
+
" <tr>\n",
|
| 323 |
+
" <td>2100</td>\n",
|
| 324 |
+
" <td>2.390200</td>\n",
|
| 325 |
+
" </tr>\n",
|
| 326 |
+
" <tr>\n",
|
| 327 |
+
" <td>2150</td>\n",
|
| 328 |
+
" <td>3.091400</td>\n",
|
| 329 |
+
" </tr>\n",
|
| 330 |
+
" <tr>\n",
|
| 331 |
+
" <td>2200</td>\n",
|
| 332 |
+
" <td>3.959500</td>\n",
|
| 333 |
+
" </tr>\n",
|
| 334 |
+
" <tr>\n",
|
| 335 |
+
" <td>2250</td>\n",
|
| 336 |
+
" <td>3.905000</td>\n",
|
| 337 |
+
" </tr>\n",
|
| 338 |
+
" <tr>\n",
|
| 339 |
+
" <td>2300</td>\n",
|
| 340 |
+
" <td>3.777500</td>\n",
|
| 341 |
+
" </tr>\n",
|
| 342 |
+
" <tr>\n",
|
| 343 |
+
" <td>2350</td>\n",
|
| 344 |
+
" <td>3.282900</td>\n",
|
| 345 |
+
" </tr>\n",
|
| 346 |
+
" <tr>\n",
|
| 347 |
+
" <td>2400</td>\n",
|
| 348 |
+
" <td>2.630300</td>\n",
|
| 349 |
+
" </tr>\n",
|
| 350 |
+
" <tr>\n",
|
| 351 |
+
" <td>2450</td>\n",
|
| 352 |
+
" <td>3.705000</td>\n",
|
| 353 |
+
" </tr>\n",
|
| 354 |
+
" <tr>\n",
|
| 355 |
+
" <td>2500</td>\n",
|
| 356 |
+
" <td>4.266300</td>\n",
|
| 357 |
+
" </tr>\n",
|
| 358 |
+
" <tr>\n",
|
| 359 |
+
" <td>2550</td>\n",
|
| 360 |
+
" <td>4.285300</td>\n",
|
| 361 |
+
" </tr>\n",
|
| 362 |
+
" <tr>\n",
|
| 363 |
+
" <td>2600</td>\n",
|
| 364 |
+
" <td>4.634000</td>\n",
|
| 365 |
+
" </tr>\n",
|
| 366 |
+
" <tr>\n",
|
| 367 |
+
" <td>2650</td>\n",
|
| 368 |
+
" <td>4.474700</td>\n",
|
| 369 |
+
" </tr>\n",
|
| 370 |
+
" <tr>\n",
|
| 371 |
+
" <td>2700</td>\n",
|
| 372 |
+
" <td>3.591300</td>\n",
|
| 373 |
+
" </tr>\n",
|
| 374 |
+
" <tr>\n",
|
| 375 |
+
" <td>2750</td>\n",
|
| 376 |
+
" <td>2.486800</td>\n",
|
| 377 |
+
" </tr>\n",
|
| 378 |
+
" <tr>\n",
|
| 379 |
+
" <td>2800</td>\n",
|
| 380 |
+
" <td>1.911800</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <td>2850</td>\n",
|
| 384 |
+
" <td>2.088100</td>\n",
|
| 385 |
+
" </tr>\n",
|
| 386 |
+
" <tr>\n",
|
| 387 |
+
" <td>2900</td>\n",
|
| 388 |
+
" <td>2.015400</td>\n",
|
| 389 |
+
" </tr>\n",
|
| 390 |
+
" <tr>\n",
|
| 391 |
+
" <td>2950</td>\n",
|
| 392 |
+
" <td>1.988500</td>\n",
|
| 393 |
+
" </tr>\n",
|
| 394 |
+
" <tr>\n",
|
| 395 |
+
" <td>3000</td>\n",
|
| 396 |
+
" <td>1.976900</td>\n",
|
| 397 |
+
" </tr>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <td>3050</td>\n",
|
| 400 |
+
" <td>2.097700</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <td>3100</td>\n",
|
| 404 |
+
" <td>1.987400</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <td>3150</td>\n",
|
| 408 |
+
" <td>2.065000</td>\n",
|
| 409 |
+
" </tr>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <td>3200</td>\n",
|
| 412 |
+
" <td>2.112100</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <td>3250</td>\n",
|
| 416 |
+
" <td>2.075300</td>\n",
|
| 417 |
+
" </tr>\n",
|
| 418 |
+
" <tr>\n",
|
| 419 |
+
" <td>3300</td>\n",
|
| 420 |
+
" <td>1.983300</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <td>3350</td>\n",
|
| 424 |
+
" <td>2.181900</td>\n",
|
| 425 |
+
" </tr>\n",
|
| 426 |
+
" <tr>\n",
|
| 427 |
+
" <td>3400</td>\n",
|
| 428 |
+
" <td>2.446500</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <td>3450</td>\n",
|
| 432 |
+
" <td>2.434200</td>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" <tr>\n",
|
| 435 |
+
" <td>3500</td>\n",
|
| 436 |
+
" <td>2.357000</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <td>3550</td>\n",
|
| 440 |
+
" <td>2.157400</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <td>3600</td>\n",
|
| 444 |
+
" <td>1.992900</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <td>3650</td>\n",
|
| 448 |
+
" <td>2.018400</td>\n",
|
| 449 |
+
" </tr>\n",
|
| 450 |
+
" <tr>\n",
|
| 451 |
+
" <td>3700</td>\n",
|
| 452 |
+
" <td>2.010200</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" <tr>\n",
|
| 455 |
+
" <td>3750</td>\n",
|
| 456 |
+
" <td>2.009500</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <td>3800</td>\n",
|
| 460 |
+
" <td>2.034900</td>\n",
|
| 461 |
+
" </tr>\n",
|
| 462 |
+
" <tr>\n",
|
| 463 |
+
" <td>3850</td>\n",
|
| 464 |
+
" <td>2.038800</td>\n",
|
| 465 |
+
" </tr>\n",
|
| 466 |
+
" <tr>\n",
|
| 467 |
+
" <td>3900</td>\n",
|
| 468 |
+
" <td>2.007600</td>\n",
|
| 469 |
+
" </tr>\n",
|
| 470 |
+
" <tr>\n",
|
| 471 |
+
" <td>3950</td>\n",
|
| 472 |
+
" <td>1.983200</td>\n",
|
| 473 |
+
" </tr>\n",
|
| 474 |
+
" <tr>\n",
|
| 475 |
+
" <td>4000</td>\n",
|
| 476 |
+
" <td>2.005300</td>\n",
|
| 477 |
+
" </tr>\n",
|
| 478 |
+
" <tr>\n",
|
| 479 |
+
" <td>4050</td>\n",
|
| 480 |
+
" <td>2.014900</td>\n",
|
| 481 |
+
" </tr>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>4100</td>\n",
|
| 484 |
+
" <td>2.018100</td>\n",
|
| 485 |
+
" </tr>\n",
|
| 486 |
+
" <tr>\n",
|
| 487 |
+
" <td>4150</td>\n",
|
| 488 |
+
" <td>2.033900</td>\n",
|
| 489 |
+
" </tr>\n",
|
| 490 |
+
" <tr>\n",
|
| 491 |
+
" <td>4200</td>\n",
|
| 492 |
+
" <td>2.024600</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" <tr>\n",
|
| 495 |
+
" <td>4250</td>\n",
|
| 496 |
+
" <td>1.995300</td>\n",
|
| 497 |
+
" </tr>\n",
|
| 498 |
+
" <tr>\n",
|
| 499 |
+
" <td>4300</td>\n",
|
| 500 |
+
" <td>2.018000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <td>4350</td>\n",
|
| 504 |
+
" <td>1.998300</td>\n",
|
| 505 |
+
" </tr>\n",
|
| 506 |
+
" <tr>\n",
|
| 507 |
+
" <td>4400</td>\n",
|
| 508 |
+
" <td>2.032800</td>\n",
|
| 509 |
+
" </tr>\n",
|
| 510 |
+
" <tr>\n",
|
| 511 |
+
" <td>4450</td>\n",
|
| 512 |
+
" <td>1.985900</td>\n",
|
| 513 |
+
" </tr>\n",
|
| 514 |
+
" <tr>\n",
|
| 515 |
+
" <td>4500</td>\n",
|
| 516 |
+
" <td>1.967700</td>\n",
|
| 517 |
+
" </tr>\n",
|
| 518 |
+
" <tr>\n",
|
| 519 |
+
" <td>4550</td>\n",
|
| 520 |
+
" <td>1.989400</td>\n",
|
| 521 |
+
" </tr>\n",
|
| 522 |
+
" <tr>\n",
|
| 523 |
+
" <td>4600</td>\n",
|
| 524 |
+
" <td>2.004700</td>\n",
|
| 525 |
+
" </tr>\n",
|
| 526 |
+
" <tr>\n",
|
| 527 |
+
" <td>4650</td>\n",
|
| 528 |
+
" <td>2.005800</td>\n",
|
| 529 |
+
" </tr>\n",
|
| 530 |
+
" <tr>\n",
|
| 531 |
+
" <td>4700</td>\n",
|
| 532 |
+
" <td>2.014400</td>\n",
|
| 533 |
+
" </tr>\n",
|
| 534 |
+
" <tr>\n",
|
| 535 |
+
" <td>4750</td>\n",
|
| 536 |
+
" <td>2.009200</td>\n",
|
| 537 |
+
" </tr>\n",
|
| 538 |
+
" <tr>\n",
|
| 539 |
+
" <td>4800</td>\n",
|
| 540 |
+
" <td>2.002200</td>\n",
|
| 541 |
+
" </tr>\n",
|
| 542 |
+
" <tr>\n",
|
| 543 |
+
" <td>4850</td>\n",
|
| 544 |
+
" <td>1.914300</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" <tr>\n",
|
| 547 |
+
" <td>4900</td>\n",
|
| 548 |
+
" <td>2.016900</td>\n",
|
| 549 |
+
" </tr>\n",
|
| 550 |
+
" <tr>\n",
|
| 551 |
+
" <td>4950</td>\n",
|
| 552 |
+
" <td>1.972900</td>\n",
|
| 553 |
+
" </tr>\n",
|
| 554 |
+
" <tr>\n",
|
| 555 |
+
" <td>5000</td>\n",
|
| 556 |
+
" <td>2.010300</td>\n",
|
| 557 |
+
" </tr>\n",
|
| 558 |
+
" <tr>\n",
|
| 559 |
+
" <td>5050</td>\n",
|
| 560 |
+
" <td>2.046600</td>\n",
|
| 561 |
+
" </tr>\n",
|
| 562 |
+
" <tr>\n",
|
| 563 |
+
" <td>5100</td>\n",
|
| 564 |
+
" <td>1.993900</td>\n",
|
| 565 |
+
" </tr>\n",
|
| 566 |
+
" <tr>\n",
|
| 567 |
+
" <td>5150</td>\n",
|
| 568 |
+
" <td>2.084500</td>\n",
|
| 569 |
+
" </tr>\n",
|
| 570 |
+
" <tr>\n",
|
| 571 |
+
" <td>5200</td>\n",
|
| 572 |
+
" <td>2.011900</td>\n",
|
| 573 |
+
" </tr>\n",
|
| 574 |
+
" <tr>\n",
|
| 575 |
+
" <td>5250</td>\n",
|
| 576 |
+
" <td>1.996500</td>\n",
|
| 577 |
+
" </tr>\n",
|
| 578 |
+
" <tr>\n",
|
| 579 |
+
" <td>5300</td>\n",
|
| 580 |
+
" <td>1.997900</td>\n",
|
| 581 |
+
" </tr>\n",
|
| 582 |
+
" </tbody>\n",
|
| 583 |
+
"</table><p>"
|
| 584 |
+
],
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<IPython.core.display.HTML object>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"output_type": "display_data"
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"name": "stderr",
|
| 594 |
+
"output_type": "stream",
|
| 595 |
+
"text": [
|
| 596 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"text/plain": [
|
| 602 |
+
"CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new/commit/231f89403dca9aa67966e4f321e62bdb41076960', commit_message='End of training', commit_description='', oid='231f89403dca9aa67966e4f321e62bdb41076960', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new'), pr_revision=None, pr_num=None)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
"execution_count": 4,
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"output_type": "execute_result"
|
| 608 |
+
}
|
| 609 |
+
],
|
| 610 |
+
"source": [
|
| 611 |
+
"import json\n",
|
| 612 |
+
"import torch\n",
|
| 613 |
+
"import os\n",
|
| 614 |
+
"from transformers import AutoTokenizer\n",
|
| 615 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 616 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 617 |
+
"if 'train_data' not in globals():\n",
|
| 618 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 621 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 622 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 623 |
+
" else:\n",
|
| 624 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 625 |
+
" exit()\n",
|
| 626 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 627 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 628 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"from torch.utils.data import Dataset\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"class GraphDataset(Dataset):\n",
|
| 641 |
+
" def __init__(self, data):\n",
|
| 642 |
+
" self.data = data\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" def __len__(self):\n",
|
| 645 |
+
" return len(self.data)\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" def __getitem__(self, idx):\n",
|
| 648 |
+
" sample = self.data[idx]\n",
|
| 649 |
+
" return {\n",
|
| 650 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 651 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 652 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 653 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 654 |
+
" }\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 657 |
+
"import torch\n",
|
| 658 |
+
"import torch.nn as nn\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 661 |
+
" def __init__(self, config):\n",
|
| 662 |
+
" super().__init__(config)\n",
|
| 663 |
+
"\n",
|
| 664 |
+
" # self.model = AutoModelForCausalLM.from_config(config)\n",
|
| 665 |
+
" \n",
|
| 666 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 667 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 668 |
+
"\n",
|
| 669 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 670 |
+
" \"\"\"\n",
|
| 671 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 672 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 673 |
+
" \"\"\"\n",
|
| 674 |
+
" # ✅ 获取 token embedding\n",
|
| 675 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 676 |
+
"\n",
|
| 677 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 678 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 679 |
+
"\n",
|
| 680 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 681 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 682 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 683 |
+
"\n",
|
| 684 |
+
" # ✅ 调整 attention mask\n",
|
| 685 |
+
" if attention_mask is not None:\n",
|
| 686 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 687 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 688 |
+
"\n",
|
| 689 |
+
" # ✅ 传入模型\n",
|
| 690 |
+
" outputs = self.model(\n",
|
| 691 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 692 |
+
" attention_mask=attention_mask,\n",
|
| 693 |
+
" labels=labels,\n",
|
| 694 |
+
" )\n",
|
| 695 |
+
"\n",
|
| 696 |
+
" return outputs\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"from transformers import Trainer\n",
|
| 699 |
+
"\n",
|
| 700 |
+
"class GraphTrainer(Trainer):\n",
|
| 701 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 702 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 703 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 704 |
+
" labels = inputs[\"labels\"]\n",
|
| 705 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 706 |
+
"\n",
|
| 707 |
+
" if graph_embedding is not None:\n",
|
| 708 |
+
" outputs = model(\n",
|
| 709 |
+
" input_ids=input_ids,\n",
|
| 710 |
+
" attention_mask=attention_mask,\n",
|
| 711 |
+
" labels=labels,\n",
|
| 712 |
+
" graph_embedding=graph_embedding, \n",
|
| 713 |
+
" )\n",
|
| 714 |
+
" else:\n",
|
| 715 |
+
" outputs = model(\n",
|
| 716 |
+
" input_ids=input_ids,\n",
|
| 717 |
+
" attention_mask=attention_mask,\n",
|
| 718 |
+
" labels=labels,\n",
|
| 719 |
+
" )\n",
|
| 720 |
+
"\n",
|
| 721 |
+
" loss = outputs.loss\n",
|
| 722 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 723 |
+
"\n",
|
| 724 |
+
"\n",
|
| 725 |
+
"from transformers import AutoConfig\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"# 1. 加载模型的配置\n",
|
| 728 |
+
"config = AutoConfig.from_pretrained(MODEL_NAME)\n",
|
| 729 |
+
"\n",
|
| 730 |
+
"# 2. 使用配置创建 GraphAwareLM 实例\n",
|
| 731 |
+
"model = GraphAwareLM.from_config(config) \n",
|
| 732 |
+
"\n",
|
| 733 |
+
"pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 734 |
+
"model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
|
| 735 |
+
"\n",
|
| 736 |
+
"# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
|
| 737 |
+
"# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 738 |
+
"# model.config.use_sliding_window_attention = False\n",
|
| 739 |
+
"\n",
|
| 740 |
+
"# ✅ 训练参数\n",
|
| 741 |
+
"training_args = TrainingArguments(\n",
|
| 742 |
+
" output_dir=\"./results\",\n",
|
| 743 |
+
" per_device_train_batch_size=7,\n",
|
| 744 |
+
" eval_strategy=\"no\",\n",
|
| 745 |
+
" save_strategy=\"steps\",\n",
|
| 746 |
+
" save_steps=3000,\n",
|
| 747 |
+
" logging_steps=50,\n",
|
| 748 |
+
" bf16=True,\n",
|
| 749 |
+
" optim=\"galore_adamw\",\n",
|
| 750 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 751 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 752 |
+
" warmup_steps=1000,\n",
|
| 753 |
+
" num_train_epochs=3,\n",
|
| 754 |
+
" push_to_hub=True,\n",
|
| 755 |
+
" hub_model_id=HF_NAME,\n",
|
| 756 |
+
" hub_strategy=\"every_save\",\n",
|
| 757 |
+
" run_name = \"experi0304\"\n",
|
| 758 |
+
")\n",
|
| 759 |
+
"\n",
|
| 760 |
+
"\n",
|
| 761 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 762 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 763 |
+
"\n",
|
| 764 |
+
"# ✅ 训练\n",
|
| 765 |
+
"trainer = GraphTrainer(\n",
|
| 766 |
+
" model=model,\n",
|
| 767 |
+
" args=training_args,\n",
|
| 768 |
+
" train_dataset=train_dataset,\n",
|
| 769 |
+
")\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"trainer.train()\n",
|
| 772 |
+
"trainer.save_model(\"/workspace/model\")\n",
|
| 773 |
+
"trainer.push_to_hub()\n",
|
| 774 |
+
"\n",
|
| 775 |
+
"\n"
|
| 776 |
+
]
|
| 777 |
+
},
|
| 778 |
+
{
|
| 779 |
+
"cell_type": "code",
|
| 780 |
+
"execution_count": 5,
|
| 781 |
+
"id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
|
| 782 |
+
"metadata": {},
|
| 783 |
+
"outputs": [
|
| 784 |
+
{
|
| 785 |
+
"name": "stdout",
|
| 786 |
+
"output_type": "stream",
|
| 787 |
+
"text": [
|
| 788 |
+
"🚀 处理后数据条数: 12384\n",
|
| 789 |
+
"✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 790 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 791 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 792 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 793 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 794 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 795 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 796 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 797 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 798 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 799 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 800 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 801 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 802 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 803 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 804 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 805 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 806 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 807 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 808 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 809 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 810 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 811 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 812 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 813 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 814 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 815 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 816 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 817 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 818 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 819 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 820 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 821 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 822 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 823 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 824 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 825 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 826 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 827 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 828 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 829 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 830 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 831 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 832 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 833 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 834 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 835 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 836 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 837 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 838 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 839 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 840 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 841 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 842 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 843 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 844 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 845 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 846 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 847 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 848 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 849 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 850 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 851 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 852 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
|
| 853 |
+
"✅ train_data 已保存到 train_data.pt\n"
|
| 854 |
+
]
|
| 855 |
+
}
|
| 856 |
+
],
|
| 857 |
+
"source": [
|
| 858 |
+
"import json\n",
|
| 859 |
+
"import torch\n",
|
| 860 |
+
"from transformers import AutoTokenizer\n",
|
| 861 |
+
"\n",
|
| 862 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 863 |
+
"tokenizer.pad_token = tokenizer.eos_token \n",
|
| 864 |
+
"\n",
|
| 865 |
+
"json_path = \"final_Graph.json\"\n",
|
| 866 |
+
"with open(json_path, \"r\") as f:\n",
|
| 867 |
+
" data = json.load(f)\n",
|
| 868 |
+
"\n",
|
| 869 |
+
"train_data = []\n",
|
| 870 |
+
"\n",
|
| 871 |
+
"\n",
|
| 872 |
+
"for sample in data:\n",
|
| 873 |
+
" conversations = sample.get(\"conversations\", [])\n",
|
| 874 |
+
" embeddings = sample.get(\"embedding\", []) \n",
|
| 875 |
+
"\n",
|
| 876 |
+
" if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
|
| 877 |
+
" print(f\"无效的 embedding,跳过样本:{sample}\")\n",
|
| 878 |
+
" continue\n",
|
| 879 |
+
"\n",
|
| 880 |
+
" graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
|
| 881 |
+
"\n",
|
| 882 |
+
" #拼接所有对话\n",
|
| 883 |
+
" dialogue_text = \"\"\n",
|
| 884 |
+
" for conv in conversations:\n",
|
| 885 |
+
" role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
|
| 886 |
+
" content = conv[\"value\"]\n",
|
| 887 |
+
" content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
|
| 888 |
+
" role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
|
| 889 |
+
" dialogue_text += f\"{role_token} {content}\\n\"\n",
|
| 890 |
+
"\n",
|
| 891 |
+
" tokenized = tokenizer(\n",
|
| 892 |
+
" dialogue_text,\n",
|
| 893 |
+
" padding=\"max_length\",\n",
|
| 894 |
+
" truncation=True,\n",
|
| 895 |
+
" max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
|
| 896 |
+
" return_tensors=\"pt\",\n",
|
| 897 |
+
" )\n",
|
| 898 |
+
"\n",
|
| 899 |
+
" input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
|
| 900 |
+
" attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
|
| 901 |
+
"\n",
|
| 902 |
+
" train_data.append({\n",
|
| 903 |
+
" \"input_ids\": input_ids,\n",
|
| 904 |
+
" \"attention_mask\": attention_mask,\n",
|
| 905 |
+
" \"labels\": input_ids.clone(),\n",
|
| 906 |
+
" \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
|
| 907 |
+
" })\n",
|
| 908 |
+
"\n",
|
| 909 |
+
"print(\"🚀 处理后数据条数:\", len(train_data))\n",
|
| 910 |
+
"print(\"✅ 示例数据:\", train_data[0])\n",
|
| 911 |
+
"torch.save(train_data, \"train_data.pt\")\n",
|
| 912 |
+
"print(\"✅ train_data 已保存到 train_data.pt\")\n"
|
| 913 |
+
]
|
| 914 |
+
},
|
| 915 |
+
{
|
| 916 |
+
"cell_type": "code",
|
| 917 |
+
"execution_count": 6,
|
| 918 |
+
"id": "a33bffb9-2ff9-4a4d-af2c-b89b30a69f7d",
|
| 919 |
+
"metadata": {
|
| 920 |
+
"scrolled": true
|
| 921 |
+
},
|
| 922 |
+
"outputs": [
|
| 923 |
+
{
|
| 924 |
+
"name": "stdout",
|
| 925 |
+
"output_type": "stream",
|
| 926 |
+
"text": [
|
| 927 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 928 |
+
]
|
| 929 |
+
},
|
| 930 |
+
{
|
| 931 |
+
"name": "stderr",
|
| 932 |
+
"output_type": "stream",
|
| 933 |
+
"text": [
|
| 934 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 935 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:49: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 936 |
+
" warnings.warn(\n",
|
| 937 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 938 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 939 |
+
]
|
| 940 |
+
},
|
| 941 |
+
{
|
| 942 |
+
"data": {
|
| 943 |
+
"text/html": [
|
| 944 |
+
"Tracking run with wandb version 0.19.7"
|
| 945 |
+
],
|
| 946 |
+
"text/plain": [
|
| 947 |
+
"<IPython.core.display.HTML object>"
|
| 948 |
+
]
|
| 949 |
+
},
|
| 950 |
+
"metadata": {},
|
| 951 |
+
"output_type": "display_data"
|
| 952 |
+
},
|
| 953 |
+
{
|
| 954 |
+
"data": {
|
| 955 |
+
"text/html": [
|
| 956 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_074031-ofm5zhvd</code>"
|
| 957 |
+
],
|
| 958 |
+
"text/plain": [
|
| 959 |
+
"<IPython.core.display.HTML object>"
|
| 960 |
+
]
|
| 961 |
+
},
|
| 962 |
+
"metadata": {},
|
| 963 |
+
"output_type": "display_data"
|
| 964 |
+
},
|
| 965 |
+
{
|
| 966 |
+
"data": {
|
| 967 |
+
"text/html": [
|
| 968 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 969 |
+
],
|
| 970 |
+
"text/plain": [
|
| 971 |
+
"<IPython.core.display.HTML object>"
|
| 972 |
+
]
|
| 973 |
+
},
|
| 974 |
+
"metadata": {},
|
| 975 |
+
"output_type": "display_data"
|
| 976 |
+
},
|
| 977 |
+
{
|
| 978 |
+
"data": {
|
| 979 |
+
"text/html": [
|
| 980 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 981 |
+
],
|
| 982 |
+
"text/plain": [
|
| 983 |
+
"<IPython.core.display.HTML object>"
|
| 984 |
+
]
|
| 985 |
+
},
|
| 986 |
+
"metadata": {},
|
| 987 |
+
"output_type": "display_data"
|
| 988 |
+
},
|
| 989 |
+
{
|
| 990 |
+
"data": {
|
| 991 |
+
"text/html": [
|
| 992 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd</a>"
|
| 993 |
+
],
|
| 994 |
+
"text/plain": [
|
| 995 |
+
"<IPython.core.display.HTML object>"
|
| 996 |
+
]
|
| 997 |
+
},
|
| 998 |
+
"metadata": {},
|
| 999 |
+
"output_type": "display_data"
|
| 1000 |
+
},
|
| 1001 |
+
{
|
| 1002 |
+
"data": {
|
| 1003 |
+
"text/html": [
|
| 1004 |
+
"\n",
|
| 1005 |
+
" <div>\n",
|
| 1006 |
+
" \n",
|
| 1007 |
+
" <progress value='89' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 1008 |
+
" [ 89/5310 01:06 < 1:06:24, 1.31 it/s, Epoch 0.05/3]\n",
|
| 1009 |
+
" </div>\n",
|
| 1010 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 1011 |
+
" <thead>\n",
|
| 1012 |
+
" <tr style=\"text-align: left;\">\n",
|
| 1013 |
+
" <th>Step</th>\n",
|
| 1014 |
+
" <th>Training Loss</th>\n",
|
| 1015 |
+
" </tr>\n",
|
| 1016 |
+
" </thead>\n",
|
| 1017 |
+
" <tbody>\n",
|
| 1018 |
+
" <tr>\n",
|
| 1019 |
+
" <td>50</td>\n",
|
| 1020 |
+
" <td>0.000000</td>\n",
|
| 1021 |
+
" </tr>\n",
|
| 1022 |
+
" </tbody>\n",
|
| 1023 |
+
"</table><p>"
|
| 1024 |
+
],
|
| 1025 |
+
"text/plain": [
|
| 1026 |
+
"<IPython.core.display.HTML object>"
|
| 1027 |
+
]
|
| 1028 |
+
},
|
| 1029 |
+
"metadata": {},
|
| 1030 |
+
"output_type": "display_data"
|
| 1031 |
+
},
|
| 1032 |
+
{
|
| 1033 |
+
"ename": "KeyboardInterrupt",
|
| 1034 |
+
"evalue": "",
|
| 1035 |
+
"output_type": "error",
|
| 1036 |
+
"traceback": [
|
| 1037 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1038 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 1039 |
+
"Cell \u001b[0;32mIn[6], line 150\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# ✅ 训练\u001b[39;00m\n\u001b[1;32m 144\u001b[0m trainer \u001b[38;5;241m=\u001b[39m GraphTrainer(\n\u001b[1;32m 145\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 146\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m 147\u001b[0m train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m 148\u001b[0m )\n\u001b[0;32m--> 150\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m trainer\u001b[38;5;241m.\u001b[39mpush_to_hub()\n\u001b[1;32m 152\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/workspace/model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 1040 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2232\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 2232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2233\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2234\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2235\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2236\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2237\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2238\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2239\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
|
| 1041 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2548\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2541\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2542\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2543\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2544\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m!=\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED\n\u001b[1;32m 2545\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m 2546\u001b[0m )\n\u001b[1;32m 2547\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2548\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2551\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2554\u001b[0m ):\n\u001b[1;32m 2555\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2556\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
| 1042 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3740\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 3737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m==\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED:\n\u001b[1;32m 3738\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscale_wrt_gas\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 3740\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3742\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach()\n",
|
| 1043 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:2325\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 2324\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 2325\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2326\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m learning_rate \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_lomo_optimizer:\n\u001b[1;32m 2327\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n",
|
| 1044 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 484\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 485\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 490\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 492\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1045 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 246\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 252\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 255\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 256\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1046 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 1047 |
+
]
|
| 1048 |
+
}
|
| 1049 |
+
],
|
| 1050 |
+
"source": [
|
| 1051 |
+
"import json\n",
|
| 1052 |
+
"import torch\n",
|
| 1053 |
+
"import os\n",
|
| 1054 |
+
"from transformers import AutoTokenizer\n",
|
| 1055 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 1056 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 1057 |
+
"if 'train_data' not in globals():\n",
|
| 1058 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 1059 |
+
" \n",
|
| 1060 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 1061 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 1062 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 1063 |
+
" else:\n",
|
| 1064 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 1065 |
+
" exit()\n",
|
| 1066 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 1067 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 1068 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 1069 |
+
"\n",
|
| 1070 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1071 |
+
"\n",
|
| 1072 |
+
"\n",
|
| 1073 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 1074 |
+
"\n",
|
| 1075 |
+
"model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 1076 |
+
"\n",
|
| 1077 |
+
"\n",
|
| 1078 |
+
"from torch.utils.data import Dataset\n",
|
| 1079 |
+
"\n",
|
| 1080 |
+
"class GraphDataset(Dataset):\n",
|
| 1081 |
+
" def __init__(self, data):\n",
|
| 1082 |
+
" self.data = data\n",
|
| 1083 |
+
"\n",
|
| 1084 |
+
" def __len__(self):\n",
|
| 1085 |
+
" return len(self.data)\n",
|
| 1086 |
+
"\n",
|
| 1087 |
+
" def __getitem__(self, idx):\n",
|
| 1088 |
+
" sample = self.data[idx]\n",
|
| 1089 |
+
" return {\n",
|
| 1090 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 1091 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 1092 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 1093 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 1094 |
+
" }\n",
|
| 1095 |
+
"\n",
|
| 1096 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 1097 |
+
"import torch\n",
|
| 1098 |
+
"import torch.nn as nn\n",
|
| 1099 |
+
"\n",
|
| 1100 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1101 |
+
" def __init__(self, config):\n",
|
| 1102 |
+
" super().__init__(config)\n",
|
| 1103 |
+
" self.model = AutoModelForCausalLM.from_pretrained(config)\n",
|
| 1104 |
+
" \n",
|
| 1105 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 1106 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 1107 |
+
"\n",
|
| 1108 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1109 |
+
" \"\"\"\n",
|
| 1110 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1111 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1112 |
+
" \"\"\"\n",
|
| 1113 |
+
" # ✅ 获取 token embedding\n",
|
| 1114 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1115 |
+
"\n",
|
| 1116 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 1117 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 1118 |
+
"\n",
|
| 1119 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1120 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1121 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1122 |
+
"\n",
|
| 1123 |
+
" # ✅ 调整 attention mask\n",
|
| 1124 |
+
" if attention_mask is not None:\n",
|
| 1125 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1126 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1127 |
+
"\n",
|
| 1128 |
+
" # ✅ 传入模型\n",
|
| 1129 |
+
" outputs = self.model(\n",
|
| 1130 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1131 |
+
" attention_mask=attention_mask,\n",
|
| 1132 |
+
" labels=labels,\n",
|
| 1133 |
+
" )\n",
|
| 1134 |
+
"\n",
|
| 1135 |
+
" return outputs\n",
|
| 1136 |
+
"\n",
|
| 1137 |
+
"from transformers import Trainer\n",
|
| 1138 |
+
"\n",
|
| 1139 |
+
"class GraphTrainer(Trainer):\n",
|
| 1140 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 1141 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 1142 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 1143 |
+
" labels = inputs[\"labels\"]\n",
|
| 1144 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 1145 |
+
"\n",
|
| 1146 |
+
" if graph_embedding is not None:\n",
|
| 1147 |
+
" outputs = model(\n",
|
| 1148 |
+
" input_ids=input_ids,\n",
|
| 1149 |
+
" attention_mask=attention_mask,\n",
|
| 1150 |
+
" labels=labels,\n",
|
| 1151 |
+
" graph_embedding=graph_embedding, \n",
|
| 1152 |
+
" )\n",
|
| 1153 |
+
" else:\n",
|
| 1154 |
+
" outputs = model(\n",
|
| 1155 |
+
" input_ids=input_ids,\n",
|
| 1156 |
+
" attention_mask=attention_mask,\n",
|
| 1157 |
+
" labels=labels,\n",
|
| 1158 |
+
" )\n",
|
| 1159 |
+
"\n",
|
| 1160 |
+
" loss = outputs.loss\n",
|
| 1161 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 1162 |
+
"\n",
|
| 1163 |
+
"\n",
|
| 1164 |
+
"\n",
|
| 1165 |
+
"# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
|
| 1166 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 1167 |
+
"# model.config.use_sliding_window_attention = False\n",
|
| 1168 |
+
"\n",
|
| 1169 |
+
"# ✅ 训练参数\n",
|
| 1170 |
+
"training_args = TrainingArguments(\n",
|
| 1171 |
+
" output_dir=\"./results\",\n",
|
| 1172 |
+
" per_device_train_batch_size=7,\n",
|
| 1173 |
+
" eval_strategy=\"no\",\n",
|
| 1174 |
+
" save_strategy=\"steps\",\n",
|
| 1175 |
+
" save_steps=3000,\n",
|
| 1176 |
+
" logging_steps=50,\n",
|
| 1177 |
+
" fp16=True,\n",
|
| 1178 |
+
" optim=\"galore_adamw\",\n",
|
| 1179 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 1180 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 1181 |
+
" warmup_steps=1000,\n",
|
| 1182 |
+
" num_train_epochs=3,\n",
|
| 1183 |
+
" push_to_hub=True,\n",
|
| 1184 |
+
" hub_model_id=HF_NAME,\n",
|
| 1185 |
+
" hub_strategy=\"every_save\",\n",
|
| 1186 |
+
" run_name = \"experi0304\"\n",
|
| 1187 |
+
")\n",
|
| 1188 |
+
"\n",
|
| 1189 |
+
"\n",
|
| 1190 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 1191 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 1192 |
+
"\n",
|
| 1193 |
+
"# ✅ 训练\n",
|
| 1194 |
+
"trainer = GraphTrainer(\n",
|
| 1195 |
+
" model=model,\n",
|
| 1196 |
+
" args=training_args,\n",
|
| 1197 |
+
" train_dataset=train_dataset,\n",
|
| 1198 |
+
")\n",
|
| 1199 |
+
"\n",
|
| 1200 |
+
"trainer.train()\n",
|
| 1201 |
+
"trainer.push_to_hub()\n",
|
| 1202 |
+
"trainer.save_model(\"/workspace/model\")\n",
|
| 1203 |
+
"\n"
|
| 1204 |
+
]
|
| 1205 |
+
},
|
| 1206 |
+
{
|
| 1207 |
+
"cell_type": "code",
|
| 1208 |
+
"execution_count": 1,
|
| 1209 |
+
"id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
|
| 1210 |
+
"metadata": {},
|
| 1211 |
+
"outputs": [],
|
| 1212 |
+
"source": [
|
| 1213 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 1214 |
+
"import torch\n",
|
| 1215 |
+
"import torch.nn as nn\n",
|
| 1216 |
+
"\n",
|
| 1217 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1218 |
+
"\n",
|
| 1219 |
+
" \n",
|
| 1220 |
+
" def __init__(self, config):\n",
|
| 1221 |
+
" super().__init__(config)\n",
|
| 1222 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 1223 |
+
"\n",
|
| 1224 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1225 |
+
" \"\"\"\n",
|
| 1226 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1227 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1228 |
+
" \"\"\"\n",
|
| 1229 |
+
" # ✅ 获取 token embedding\n",
|
| 1230 |
+
" inputs_embeds = self.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1231 |
+
"\n",
|
| 1232 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 1233 |
+
" graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
|
| 1234 |
+
"\n",
|
| 1235 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1236 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1237 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1238 |
+
"\n",
|
| 1239 |
+
" # ✅ 调整 attention mask\n",
|
| 1240 |
+
" if attention_mask is not None:\n",
|
| 1241 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1242 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1243 |
+
"\n",
|
| 1244 |
+
" # ✅ 传入模型\n",
|
| 1245 |
+
" outputs = self.model(\n",
|
| 1246 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1247 |
+
" attention_mask=attention_mask,\n",
|
| 1248 |
+
" labels=labels,\n",
|
| 1249 |
+
" )\n",
|
| 1250 |
+
"\n",
|
| 1251 |
+
" return outputs\n",
|
| 1252 |
+
"\n",
|
| 1253 |
+
" @classmethod\n",
|
| 1254 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 1255 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 1256 |
+
" model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
|
| 1257 |
+
" return model\n"
|
| 1258 |
+
]
|
| 1259 |
+
},
|
| 1260 |
+
{
|
| 1261 |
+
"cell_type": "code",
|
| 1262 |
+
"execution_count": 2,
|
| 1263 |
+
"id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
|
| 1264 |
+
"metadata": {},
|
| 1265 |
+
"outputs": [],
|
| 1266 |
+
"source": [
|
| 1267 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
| 1268 |
+
]
|
| 1269 |
+
},
|
| 1270 |
+
{
|
| 1271 |
+
"cell_type": "code",
|
| 1272 |
+
"execution_count": 3,
|
| 1273 |
+
"id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
|
| 1274 |
+
"metadata": {},
|
| 1275 |
+
"outputs": [
|
| 1276 |
+
{
|
| 1277 |
+
"name": "stderr",
|
| 1278 |
+
"output_type": "stream",
|
| 1279 |
+
"text": [
|
| 1280 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
|
| 1281 |
+
]
|
| 1282 |
+
},
|
| 1283 |
+
{
|
| 1284 |
+
"data": {
|
| 1285 |
+
"text/plain": [
|
| 1286 |
+
"Qwen2ForCausalLM(\n",
|
| 1287 |
+
" (model): Qwen2Model(\n",
|
| 1288 |
+
" (embed_tokens): Embedding(151936, 1536)\n",
|
| 1289 |
+
" (layers): ModuleList(\n",
|
| 1290 |
+
" (0-27): 28 x Qwen2DecoderLayer(\n",
|
| 1291 |
+
" (self_attn): Qwen2Attention(\n",
|
| 1292 |
+
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1293 |
+
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1294 |
+
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1295 |
+
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
|
| 1296 |
+
" )\n",
|
| 1297 |
+
" (mlp): Qwen2MLP(\n",
|
| 1298 |
+
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1299 |
+
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1300 |
+
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
|
| 1301 |
+
" (act_fn): SiLU()\n",
|
| 1302 |
+
" )\n",
|
| 1303 |
+
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1304 |
+
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1305 |
+
" )\n",
|
| 1306 |
+
" )\n",
|
| 1307 |
+
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1308 |
+
" (rotary_emb): Qwen2RotaryEmbedding()\n",
|
| 1309 |
+
" )\n",
|
| 1310 |
+
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
|
| 1311 |
+
" (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
|
| 1312 |
+
")"
|
| 1313 |
+
]
|
| 1314 |
+
},
|
| 1315 |
+
"execution_count": 3,
|
| 1316 |
+
"metadata": {},
|
| 1317 |
+
"output_type": "execute_result"
|
| 1318 |
+
}
|
| 1319 |
+
],
|
| 1320 |
+
"source": [
|
| 1321 |
+
"import torch\n",
|
| 1322 |
+
"from transformers import AutoTokenizer\n",
|
| 1323 |
+
"\n",
|
| 1324 |
+
"# 加载 tokenizer\n",
|
| 1325 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 1326 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1327 |
+
"\n",
|
| 1328 |
+
"# 加载训练好的模型\n",
|
| 1329 |
+
"model_path = \"/workspace/model\"\n",
|
| 1330 |
+
"model = GraphAwareLM.from_pretrained(model_path).to(device)\n",
|
| 1331 |
+
"model.eval() # 设置为推理模式\n"
|
| 1332 |
+
]
|
| 1333 |
+
},
|
| 1334 |
+
{
|
| 1335 |
+
"cell_type": "code",
|
| 1336 |
+
"execution_count": 8,
|
| 1337 |
+
"id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
|
| 1338 |
+
"metadata": {},
|
| 1339 |
+
"outputs": [],
|
| 1340 |
+
"source": [
|
| 1341 |
+
"import json\n",
|
| 1342 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1343 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1344 |
+
" data = json.load(f)\n",
|
| 1345 |
+
"\n",
|
| 1346 |
+
"test_data = data[0]\n",
|
| 1347 |
+
"\n",
|
| 1348 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 1349 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 1350 |
+
"\n",
|
| 1351 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32).to(device)\n",
|
| 1352 |
+
"\n",
|
| 1353 |
+
"question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 1354 |
+
"\n",
|
| 1355 |
+
"from transformers import AutoTokenizer\n",
|
| 1356 |
+
"\n",
|
| 1357 |
+
"# ✅ 输入文本\n",
|
| 1358 |
+
"ROLE_TOKENS = {\n",
|
| 1359 |
+
" \"human\": \"<|User|>\", \n",
|
| 1360 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 1361 |
+
"}\n",
|
| 1362 |
+
"GRAPH_LENGTH = 512\n",
|
| 1363 |
+
"max_seq_length = 1100 + GRAPH_LENGTH\n",
|
| 1364 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 1365 |
+
"\n",
|
| 1366 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 1367 |
+
"attention_mask = inputs[\"attention_mask\"]\n"
|
| 1368 |
+
]
|
| 1369 |
+
},
|
| 1370 |
+
{
|
| 1371 |
+
"cell_type": "code",
|
| 1372 |
+
"execution_count": 5,
|
| 1373 |
+
"id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
|
| 1374 |
+
"metadata": {},
|
| 1375 |
+
"outputs": [
|
| 1376 |
+
{
|
| 1377 |
+
"data": {
|
| 1378 |
+
"text/plain": [
|
| 1379 |
+
"tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1380 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1381 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1382 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1383 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1384 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1385 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1386 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1387 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1388 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1389 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1390 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1391 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1392 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1393 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1394 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1395 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1396 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1397 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1398 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1399 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1400 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1401 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1402 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1403 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1404 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1405 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1406 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1407 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1408 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1409 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1410 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1411 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1412 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1413 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1414 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1415 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1416 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1417 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1418 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1419 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1420 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1421 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1422 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1423 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1424 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1425 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1426 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1427 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1428 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1429 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1430 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1431 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1432 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1433 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1434 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1435 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1436 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1437 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1438 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1439 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1440 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1441 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1442 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
|
| 1443 |
+
" device='cuda:0')"
|
| 1444 |
+
]
|
| 1445 |
+
},
|
| 1446 |
+
"execution_count": 5,
|
| 1447 |
+
"metadata": {},
|
| 1448 |
+
"output_type": "execute_result"
|
| 1449 |
+
}
|
| 1450 |
+
],
|
| 1451 |
+
"source": [
|
| 1452 |
+
"graph_embedding"
|
| 1453 |
+
]
|
| 1454 |
+
},
|
| 1455 |
+
{
|
| 1456 |
+
"cell_type": "code",
|
| 1457 |
+
"execution_count": 15,
|
| 1458 |
+
"id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
|
| 1459 |
+
"metadata": {},
|
| 1460 |
+
"outputs": [
|
| 1461 |
+
{
|
| 1462 |
+
"name": "stdout",
|
| 1463 |
+
"output_type": "stream",
|
| 1464 |
+
"text": [
|
| 1465 |
+
"模型回复: How\n"
|
| 1466 |
+
]
|
| 1467 |
+
}
|
| 1468 |
+
],
|
| 1469 |
+
"source": [
|
| 1470 |
+
"# ✅ 进行前向传播\n",
|
| 1471 |
+
"with torch.no_grad():\n",
|
| 1472 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
|
| 1473 |
+
"\n",
|
| 1474 |
+
"# ✅ 提取 logits 并进行贪心解码\n",
|
| 1475 |
+
"logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1476 |
+
"predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
|
| 1477 |
+
"\n",
|
| 1478 |
+
"# ✅ 反向编码为文本\n",
|
| 1479 |
+
"response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
|
| 1480 |
+
"\n",
|
| 1481 |
+
"print(\"模型回复:\", response_text)"
|
| 1482 |
+
]
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"cell_type": "code",
|
| 1486 |
+
"execution_count": 9,
|
| 1487 |
+
"id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
|
| 1488 |
+
"metadata": {},
|
| 1489 |
+
"outputs": [
|
| 1490 |
+
{
|
| 1491 |
+
"name": "stdout",
|
| 1492 |
+
"output_type": "stream",
|
| 1493 |
+
"text": [
|
| 1494 |
+
"Generated Response: Is there any sequential logic in the module, and if so, how is it handled? What are the module's inputs and outputs?\n",
|
| 1495 |
+
"What are the module's inputs and outputs?\n",
|
| 1496 |
+
"What are the module's inputs and outputs?\n",
|
| 1497 |
+
"What are the module's inputs and outputs?\n",
|
| 1498 |
+
"What is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and module's output, and module's input, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module\n"
|
| 1499 |
+
]
|
| 1500 |
+
}
|
| 1501 |
+
],
|
| 1502 |
+
"source": [
|
| 1503 |
+
"max_new_tokens = 1024\n",
|
| 1504 |
+
"generated_ids = input_ids.clone()\n",
|
| 1505 |
+
"generated_attention_mask = attention_mask.clone()\n",
|
| 1506 |
+
"for _ in range(max_new_tokens):\n",
|
| 1507 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 1508 |
+
" with torch.no_grad():\n",
|
| 1509 |
+
" outputs = model(\n",
|
| 1510 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 1511 |
+
" attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
|
| 1512 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 1513 |
+
" )\n",
|
| 1514 |
+
"\n",
|
| 1515 |
+
"\n",
|
| 1516 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1517 |
+
" next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
|
| 1518 |
+
" # print(next_token)\n",
|
| 1519 |
+
"\n",
|
| 1520 |
+
"\n",
|
| 1521 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 1522 |
+
" generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
|
| 1523 |
+
"\n",
|
| 1524 |
+
" # print(generated_ids)\n",
|
| 1525 |
+
"\n",
|
| 1526 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 1527 |
+
" break\n",
|
| 1528 |
+
"\n",
|
| 1529 |
+
" generated_attention_mask = torch.cat(\n",
|
| 1530 |
+
" [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
|
| 1531 |
+
" ) \n",
|
| 1532 |
+
"\n",
|
| 1533 |
+
"# ✅ 解码最终输出\n",
|
| 1534 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 1535 |
+
"print(\"Generated Response:\", generated_text)"
|
| 1536 |
+
]
|
| 1537 |
+
},
|
| 1538 |
+
{
|
| 1539 |
+
"cell_type": "code",
|
| 1540 |
+
"execution_count": 10,
|
| 1541 |
+
"id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
|
| 1542 |
+
"metadata": {},
|
| 1543 |
+
"outputs": [
|
| 1544 |
+
{
|
| 1545 |
+
"data": {
|
| 1546 |
+
"text/plain": [
|
| 1547 |
+
"tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
|
| 1548 |
+
" 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
|
| 1549 |
+
" 525, 862, 9895, 30]], device='cuda:0')"
|
| 1550 |
+
]
|
| 1551 |
+
},
|
| 1552 |
+
"execution_count": 10,
|
| 1553 |
+
"metadata": {},
|
| 1554 |
+
"output_type": "execute_result"
|
| 1555 |
+
}
|
| 1556 |
+
],
|
| 1557 |
+
"source": [
|
| 1558 |
+
"generated_ids"
|
| 1559 |
+
]
|
| 1560 |
+
},
|
| 1561 |
+
{
|
| 1562 |
+
"cell_type": "code",
|
| 1563 |
+
"execution_count": null,
|
| 1564 |
+
"id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
|
| 1565 |
+
"metadata": {},
|
| 1566 |
+
"outputs": [],
|
| 1567 |
+
"source": []
|
| 1568 |
+
}
|
| 1569 |
+
],
|
| 1570 |
+
"metadata": {
|
| 1571 |
+
"kernelspec": {
|
| 1572 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1573 |
+
"language": "python",
|
| 1574 |
+
"name": "python3"
|
| 1575 |
+
},
|
| 1576 |
+
"language_info": {
|
| 1577 |
+
"codemirror_mode": {
|
| 1578 |
+
"name": "ipython",
|
| 1579 |
+
"version": 3
|
| 1580 |
+
},
|
| 1581 |
+
"file_extension": ".py",
|
| 1582 |
+
"mimetype": "text/x-python",
|
| 1583 |
+
"name": "python",
|
| 1584 |
+
"nbconvert_exporter": "python",
|
| 1585 |
+
"pygments_lexer": "ipython3",
|
| 1586 |
+
"version": "3.10.12"
|
| 1587 |
+
}
|
| 1588 |
+
},
|
| 1589 |
+
"nbformat": 4,
|
| 1590 |
+
"nbformat_minor": 5
|
| 1591 |
+
}
|
.ipynb_checkpoints/graph_train2-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,1674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"ROLE_TOKENS = {\n",
|
| 11 |
+
" \"human\": \"<|User|>\", \n",
|
| 12 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 13 |
+
"}\n",
|
| 14 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
|
| 15 |
+
"GRAPH_LENGTH = 512\n",
|
| 16 |
+
"HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new2\""
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 3,
|
| 22 |
+
"id": "bba6e6db-4b79-4461-ba13-75fd41019358",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"CUDA 可用: True\n",
|
| 30 |
+
"GPU 数量: 1\n",
|
| 31 |
+
"当前 GPU: 0\n",
|
| 32 |
+
"GPU 名称: NVIDIA A100 80GB PCIe\n"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"source": [
|
| 37 |
+
"# !pip install transformers accelerate datasets\n",
|
| 38 |
+
"# !pip install galora\n",
|
| 39 |
+
"# !pip install huggingface_hub\n",
|
| 40 |
+
"import torch\n",
|
| 41 |
+
"print(\"CUDA 可用:\", torch.cuda.is_available())\n",
|
| 42 |
+
"print(\"GPU 数量:\", torch.cuda.device_count())\n",
|
| 43 |
+
"print(\"当前 GPU:\", torch.cuda.current_device())\n",
|
| 44 |
+
"print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 4,
|
| 50 |
+
"id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
+
"id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 75 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 76 |
+
" warnings.warn(\n",
|
| 77 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 78 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"text/html": [
|
| 84 |
+
"Tracking run with wandb version 0.19.7"
|
| 85 |
+
],
|
| 86 |
+
"text/plain": [
|
| 87 |
+
"<IPython.core.display.HTML object>"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"output_type": "display_data"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"data": {
|
| 95 |
+
"text/html": [
|
| 96 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_111730-i9v1vlu1</code>"
|
| 97 |
+
],
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"<IPython.core.display.HTML object>"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"output_type": "display_data"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"data": {
|
| 107 |
+
"text/html": [
|
| 108 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">experi030402</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 109 |
+
],
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"<IPython.core.display.HTML object>"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"output_type": "display_data"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"data": {
|
| 119 |
+
"text/html": [
|
| 120 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 121 |
+
],
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"<IPython.core.display.HTML object>"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"data": {
|
| 131 |
+
"text/html": [
|
| 132 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1</a>"
|
| 133 |
+
],
|
| 134 |
+
"text/plain": [
|
| 135 |
+
"<IPython.core.display.HTML object>"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"output_type": "display_data"
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"data": {
|
| 143 |
+
"text/html": [
|
| 144 |
+
"\n",
|
| 145 |
+
" <div>\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 148 |
+
" [5310/5310 1:34:08, Epoch 3/3]\n",
|
| 149 |
+
" </div>\n",
|
| 150 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 151 |
+
" <thead>\n",
|
| 152 |
+
" <tr style=\"text-align: left;\">\n",
|
| 153 |
+
" <th>Step</th>\n",
|
| 154 |
+
" <th>Training Loss</th>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" </thead>\n",
|
| 157 |
+
" <tbody>\n",
|
| 158 |
+
" <tr>\n",
|
| 159 |
+
" <td>50</td>\n",
|
| 160 |
+
" <td>5.319300</td>\n",
|
| 161 |
+
" </tr>\n",
|
| 162 |
+
" <tr>\n",
|
| 163 |
+
" <td>100</td>\n",
|
| 164 |
+
" <td>3.641300</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <td>150</td>\n",
|
| 168 |
+
" <td>1.521800</td>\n",
|
| 169 |
+
" </tr>\n",
|
| 170 |
+
" <tr>\n",
|
| 171 |
+
" <td>200</td>\n",
|
| 172 |
+
" <td>1.027500</td>\n",
|
| 173 |
+
" </tr>\n",
|
| 174 |
+
" <tr>\n",
|
| 175 |
+
" <td>250</td>\n",
|
| 176 |
+
" <td>0.922400</td>\n",
|
| 177 |
+
" </tr>\n",
|
| 178 |
+
" <tr>\n",
|
| 179 |
+
" <td>300</td>\n",
|
| 180 |
+
" <td>0.866900</td>\n",
|
| 181 |
+
" </tr>\n",
|
| 182 |
+
" <tr>\n",
|
| 183 |
+
" <td>350</td>\n",
|
| 184 |
+
" <td>0.800500</td>\n",
|
| 185 |
+
" </tr>\n",
|
| 186 |
+
" <tr>\n",
|
| 187 |
+
" <td>400</td>\n",
|
| 188 |
+
" <td>0.721600</td>\n",
|
| 189 |
+
" </tr>\n",
|
| 190 |
+
" <tr>\n",
|
| 191 |
+
" <td>450</td>\n",
|
| 192 |
+
" <td>0.740400</td>\n",
|
| 193 |
+
" </tr>\n",
|
| 194 |
+
" <tr>\n",
|
| 195 |
+
" <td>500</td>\n",
|
| 196 |
+
" <td>0.737000</td>\n",
|
| 197 |
+
" </tr>\n",
|
| 198 |
+
" <tr>\n",
|
| 199 |
+
" <td>550</td>\n",
|
| 200 |
+
" <td>0.713500</td>\n",
|
| 201 |
+
" </tr>\n",
|
| 202 |
+
" <tr>\n",
|
| 203 |
+
" <td>600</td>\n",
|
| 204 |
+
" <td>0.747000</td>\n",
|
| 205 |
+
" </tr>\n",
|
| 206 |
+
" <tr>\n",
|
| 207 |
+
" <td>650</td>\n",
|
| 208 |
+
" <td>0.869500</td>\n",
|
| 209 |
+
" </tr>\n",
|
| 210 |
+
" <tr>\n",
|
| 211 |
+
" <td>700</td>\n",
|
| 212 |
+
" <td>1.473300</td>\n",
|
| 213 |
+
" </tr>\n",
|
| 214 |
+
" <tr>\n",
|
| 215 |
+
" <td>750</td>\n",
|
| 216 |
+
" <td>0.753000</td>\n",
|
| 217 |
+
" </tr>\n",
|
| 218 |
+
" <tr>\n",
|
| 219 |
+
" <td>800</td>\n",
|
| 220 |
+
" <td>0.741300</td>\n",
|
| 221 |
+
" </tr>\n",
|
| 222 |
+
" <tr>\n",
|
| 223 |
+
" <td>850</td>\n",
|
| 224 |
+
" <td>0.751400</td>\n",
|
| 225 |
+
" </tr>\n",
|
| 226 |
+
" <tr>\n",
|
| 227 |
+
" <td>900</td>\n",
|
| 228 |
+
" <td>0.787600</td>\n",
|
| 229 |
+
" </tr>\n",
|
| 230 |
+
" <tr>\n",
|
| 231 |
+
" <td>950</td>\n",
|
| 232 |
+
" <td>0.783200</td>\n",
|
| 233 |
+
" </tr>\n",
|
| 234 |
+
" <tr>\n",
|
| 235 |
+
" <td>1000</td>\n",
|
| 236 |
+
" <td>0.780200</td>\n",
|
| 237 |
+
" </tr>\n",
|
| 238 |
+
" <tr>\n",
|
| 239 |
+
" <td>1050</td>\n",
|
| 240 |
+
" <td>1.012900</td>\n",
|
| 241 |
+
" </tr>\n",
|
| 242 |
+
" <tr>\n",
|
| 243 |
+
" <td>1100</td>\n",
|
| 244 |
+
" <td>1.411700</td>\n",
|
| 245 |
+
" </tr>\n",
|
| 246 |
+
" <tr>\n",
|
| 247 |
+
" <td>1150</td>\n",
|
| 248 |
+
" <td>1.536400</td>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" <tr>\n",
|
| 251 |
+
" <td>1200</td>\n",
|
| 252 |
+
" <td>0.853800</td>\n",
|
| 253 |
+
" </tr>\n",
|
| 254 |
+
" <tr>\n",
|
| 255 |
+
" <td>1250</td>\n",
|
| 256 |
+
" <td>0.756500</td>\n",
|
| 257 |
+
" </tr>\n",
|
| 258 |
+
" <tr>\n",
|
| 259 |
+
" <td>1300</td>\n",
|
| 260 |
+
" <td>0.750800</td>\n",
|
| 261 |
+
" </tr>\n",
|
| 262 |
+
" <tr>\n",
|
| 263 |
+
" <td>1350</td>\n",
|
| 264 |
+
" <td>0.747400</td>\n",
|
| 265 |
+
" </tr>\n",
|
| 266 |
+
" <tr>\n",
|
| 267 |
+
" <td>1400</td>\n",
|
| 268 |
+
" <td>0.844400</td>\n",
|
| 269 |
+
" </tr>\n",
|
| 270 |
+
" <tr>\n",
|
| 271 |
+
" <td>1450</td>\n",
|
| 272 |
+
" <td>0.858400</td>\n",
|
| 273 |
+
" </tr>\n",
|
| 274 |
+
" <tr>\n",
|
| 275 |
+
" <td>1500</td>\n",
|
| 276 |
+
" <td>1.053400</td>\n",
|
| 277 |
+
" </tr>\n",
|
| 278 |
+
" <tr>\n",
|
| 279 |
+
" <td>1550</td>\n",
|
| 280 |
+
" <td>1.591600</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <td>1600</td>\n",
|
| 284 |
+
" <td>1.498900</td>\n",
|
| 285 |
+
" </tr>\n",
|
| 286 |
+
" <tr>\n",
|
| 287 |
+
" <td>1650</td>\n",
|
| 288 |
+
" <td>1.471700</td>\n",
|
| 289 |
+
" </tr>\n",
|
| 290 |
+
" <tr>\n",
|
| 291 |
+
" <td>1700</td>\n",
|
| 292 |
+
" <td>1.221100</td>\n",
|
| 293 |
+
" </tr>\n",
|
| 294 |
+
" <tr>\n",
|
| 295 |
+
" <td>1750</td>\n",
|
| 296 |
+
" <td>1.802300</td>\n",
|
| 297 |
+
" </tr>\n",
|
| 298 |
+
" <tr>\n",
|
| 299 |
+
" <td>1800</td>\n",
|
| 300 |
+
" <td>1.826000</td>\n",
|
| 301 |
+
" </tr>\n",
|
| 302 |
+
" <tr>\n",
|
| 303 |
+
" <td>1850</td>\n",
|
| 304 |
+
" <td>1.857300</td>\n",
|
| 305 |
+
" </tr>\n",
|
| 306 |
+
" <tr>\n",
|
| 307 |
+
" <td>1900</td>\n",
|
| 308 |
+
" <td>1.561800</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <td>1950</td>\n",
|
| 312 |
+
" <td>1.398800</td>\n",
|
| 313 |
+
" </tr>\n",
|
| 314 |
+
" <tr>\n",
|
| 315 |
+
" <td>2000</td>\n",
|
| 316 |
+
" <td>1.398900</td>\n",
|
| 317 |
+
" </tr>\n",
|
| 318 |
+
" <tr>\n",
|
| 319 |
+
" <td>2050</td>\n",
|
| 320 |
+
" <td>1.381600</td>\n",
|
| 321 |
+
" </tr>\n",
|
| 322 |
+
" <tr>\n",
|
| 323 |
+
" <td>2100</td>\n",
|
| 324 |
+
" <td>0.890300</td>\n",
|
| 325 |
+
" </tr>\n",
|
| 326 |
+
" <tr>\n",
|
| 327 |
+
" <td>2150</td>\n",
|
| 328 |
+
" <td>0.763700</td>\n",
|
| 329 |
+
" </tr>\n",
|
| 330 |
+
" <tr>\n",
|
| 331 |
+
" <td>2200</td>\n",
|
| 332 |
+
" <td>0.753100</td>\n",
|
| 333 |
+
" </tr>\n",
|
| 334 |
+
" <tr>\n",
|
| 335 |
+
" <td>2250</td>\n",
|
| 336 |
+
" <td>0.745500</td>\n",
|
| 337 |
+
" </tr>\n",
|
| 338 |
+
" <tr>\n",
|
| 339 |
+
" <td>2300</td>\n",
|
| 340 |
+
" <td>1.186100</td>\n",
|
| 341 |
+
" </tr>\n",
|
| 342 |
+
" <tr>\n",
|
| 343 |
+
" <td>2350</td>\n",
|
| 344 |
+
" <td>0.862000</td>\n",
|
| 345 |
+
" </tr>\n",
|
| 346 |
+
" <tr>\n",
|
| 347 |
+
" <td>2400</td>\n",
|
| 348 |
+
" <td>1.024600</td>\n",
|
| 349 |
+
" </tr>\n",
|
| 350 |
+
" <tr>\n",
|
| 351 |
+
" <td>2450</td>\n",
|
| 352 |
+
" <td>1.028400</td>\n",
|
| 353 |
+
" </tr>\n",
|
| 354 |
+
" <tr>\n",
|
| 355 |
+
" <td>2500</td>\n",
|
| 356 |
+
" <td>1.008500</td>\n",
|
| 357 |
+
" </tr>\n",
|
| 358 |
+
" <tr>\n",
|
| 359 |
+
" <td>2550</td>\n",
|
| 360 |
+
" <td>0.942800</td>\n",
|
| 361 |
+
" </tr>\n",
|
| 362 |
+
" <tr>\n",
|
| 363 |
+
" <td>2600</td>\n",
|
| 364 |
+
" <td>0.849700</td>\n",
|
| 365 |
+
" </tr>\n",
|
| 366 |
+
" <tr>\n",
|
| 367 |
+
" <td>2650</td>\n",
|
| 368 |
+
" <td>0.771400</td>\n",
|
| 369 |
+
" </tr>\n",
|
| 370 |
+
" <tr>\n",
|
| 371 |
+
" <td>2700</td>\n",
|
| 372 |
+
" <td>0.794100</td>\n",
|
| 373 |
+
" </tr>\n",
|
| 374 |
+
" <tr>\n",
|
| 375 |
+
" <td>2750</td>\n",
|
| 376 |
+
" <td>0.819200</td>\n",
|
| 377 |
+
" </tr>\n",
|
| 378 |
+
" <tr>\n",
|
| 379 |
+
" <td>2800</td>\n",
|
| 380 |
+
" <td>0.937500</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <td>2850</td>\n",
|
| 384 |
+
" <td>1.064500</td>\n",
|
| 385 |
+
" </tr>\n",
|
| 386 |
+
" <tr>\n",
|
| 387 |
+
" <td>2900</td>\n",
|
| 388 |
+
" <td>1.189300</td>\n",
|
| 389 |
+
" </tr>\n",
|
| 390 |
+
" <tr>\n",
|
| 391 |
+
" <td>2950</td>\n",
|
| 392 |
+
" <td>1.071100</td>\n",
|
| 393 |
+
" </tr>\n",
|
| 394 |
+
" <tr>\n",
|
| 395 |
+
" <td>3000</td>\n",
|
| 396 |
+
" <td>1.003300</td>\n",
|
| 397 |
+
" </tr>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <td>3050</td>\n",
|
| 400 |
+
" <td>1.073900</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <td>3100</td>\n",
|
| 404 |
+
" <td>1.043100</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <td>3150</td>\n",
|
| 408 |
+
" <td>1.282600</td>\n",
|
| 409 |
+
" </tr>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <td>3200</td>\n",
|
| 412 |
+
" <td>2.145400</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <td>3250</td>\n",
|
| 416 |
+
" <td>1.925800</td>\n",
|
| 417 |
+
" </tr>\n",
|
| 418 |
+
" <tr>\n",
|
| 419 |
+
" <td>3300</td>\n",
|
| 420 |
+
" <td>2.005600</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <td>3350</td>\n",
|
| 424 |
+
" <td>2.122600</td>\n",
|
| 425 |
+
" </tr>\n",
|
| 426 |
+
" <tr>\n",
|
| 427 |
+
" <td>3400</td>\n",
|
| 428 |
+
" <td>2.163000</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <td>3450</td>\n",
|
| 432 |
+
" <td>2.046600</td>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" <tr>\n",
|
| 435 |
+
" <td>3500</td>\n",
|
| 436 |
+
" <td>2.152200</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <td>3550</td>\n",
|
| 440 |
+
" <td>2.151700</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <td>3600</td>\n",
|
| 444 |
+
" <td>5.394900</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <td>3650</td>\n",
|
| 448 |
+
" <td>4.677800</td>\n",
|
| 449 |
+
" </tr>\n",
|
| 450 |
+
" <tr>\n",
|
| 451 |
+
" <td>3700</td>\n",
|
| 452 |
+
" <td>4.122200</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" <tr>\n",
|
| 455 |
+
" <td>3750</td>\n",
|
| 456 |
+
" <td>3.710200</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <td>3800</td>\n",
|
| 460 |
+
" <td>3.350800</td>\n",
|
| 461 |
+
" </tr>\n",
|
| 462 |
+
" <tr>\n",
|
| 463 |
+
" <td>3850</td>\n",
|
| 464 |
+
" <td>3.126300</td>\n",
|
| 465 |
+
" </tr>\n",
|
| 466 |
+
" <tr>\n",
|
| 467 |
+
" <td>3900</td>\n",
|
| 468 |
+
" <td>2.988700</td>\n",
|
| 469 |
+
" </tr>\n",
|
| 470 |
+
" <tr>\n",
|
| 471 |
+
" <td>3950</td>\n",
|
| 472 |
+
" <td>2.872000</td>\n",
|
| 473 |
+
" </tr>\n",
|
| 474 |
+
" <tr>\n",
|
| 475 |
+
" <td>4000</td>\n",
|
| 476 |
+
" <td>2.848200</td>\n",
|
| 477 |
+
" </tr>\n",
|
| 478 |
+
" <tr>\n",
|
| 479 |
+
" <td>4050</td>\n",
|
| 480 |
+
" <td>2.823900</td>\n",
|
| 481 |
+
" </tr>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>4100</td>\n",
|
| 484 |
+
" <td>2.781200</td>\n",
|
| 485 |
+
" </tr>\n",
|
| 486 |
+
" <tr>\n",
|
| 487 |
+
" <td>4150</td>\n",
|
| 488 |
+
" <td>2.735000</td>\n",
|
| 489 |
+
" </tr>\n",
|
| 490 |
+
" <tr>\n",
|
| 491 |
+
" <td>4200</td>\n",
|
| 492 |
+
" <td>2.725900</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" <tr>\n",
|
| 495 |
+
" <td>4250</td>\n",
|
| 496 |
+
" <td>2.644400</td>\n",
|
| 497 |
+
" </tr>\n",
|
| 498 |
+
" <tr>\n",
|
| 499 |
+
" <td>4300</td>\n",
|
| 500 |
+
" <td>2.700000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <td>4350</td>\n",
|
| 504 |
+
" <td>2.650100</td>\n",
|
| 505 |
+
" </tr>\n",
|
| 506 |
+
" <tr>\n",
|
| 507 |
+
" <td>4400</td>\n",
|
| 508 |
+
" <td>2.704500</td>\n",
|
| 509 |
+
" </tr>\n",
|
| 510 |
+
" <tr>\n",
|
| 511 |
+
" <td>4450</td>\n",
|
| 512 |
+
" <td>2.596700</td>\n",
|
| 513 |
+
" </tr>\n",
|
| 514 |
+
" <tr>\n",
|
| 515 |
+
" <td>4500</td>\n",
|
| 516 |
+
" <td>2.510500</td>\n",
|
| 517 |
+
" </tr>\n",
|
| 518 |
+
" <tr>\n",
|
| 519 |
+
" <td>4550</td>\n",
|
| 520 |
+
" <td>2.515800</td>\n",
|
| 521 |
+
" </tr>\n",
|
| 522 |
+
" <tr>\n",
|
| 523 |
+
" <td>4600</td>\n",
|
| 524 |
+
" <td>2.498100</td>\n",
|
| 525 |
+
" </tr>\n",
|
| 526 |
+
" <tr>\n",
|
| 527 |
+
" <td>4650</td>\n",
|
| 528 |
+
" <td>2.458900</td>\n",
|
| 529 |
+
" </tr>\n",
|
| 530 |
+
" <tr>\n",
|
| 531 |
+
" <td>4700</td>\n",
|
| 532 |
+
" <td>2.449700</td>\n",
|
| 533 |
+
" </tr>\n",
|
| 534 |
+
" <tr>\n",
|
| 535 |
+
" <td>4750</td>\n",
|
| 536 |
+
" <td>2.425000</td>\n",
|
| 537 |
+
" </tr>\n",
|
| 538 |
+
" <tr>\n",
|
| 539 |
+
" <td>4800</td>\n",
|
| 540 |
+
" <td>2.362300</td>\n",
|
| 541 |
+
" </tr>\n",
|
| 542 |
+
" <tr>\n",
|
| 543 |
+
" <td>4850</td>\n",
|
| 544 |
+
" <td>2.232000</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" <tr>\n",
|
| 547 |
+
" <td>4900</td>\n",
|
| 548 |
+
" <td>2.361500</td>\n",
|
| 549 |
+
" </tr>\n",
|
| 550 |
+
" <tr>\n",
|
| 551 |
+
" <td>4950</td>\n",
|
| 552 |
+
" <td>2.302300</td>\n",
|
| 553 |
+
" </tr>\n",
|
| 554 |
+
" <tr>\n",
|
| 555 |
+
" <td>5000</td>\n",
|
| 556 |
+
" <td>2.333900</td>\n",
|
| 557 |
+
" </tr>\n",
|
| 558 |
+
" <tr>\n",
|
| 559 |
+
" <td>5050</td>\n",
|
| 560 |
+
" <td>2.367200</td>\n",
|
| 561 |
+
" </tr>\n",
|
| 562 |
+
" <tr>\n",
|
| 563 |
+
" <td>5100</td>\n",
|
| 564 |
+
" <td>2.288300</td>\n",
|
| 565 |
+
" </tr>\n",
|
| 566 |
+
" <tr>\n",
|
| 567 |
+
" <td>5150</td>\n",
|
| 568 |
+
" <td>2.426100</td>\n",
|
| 569 |
+
" </tr>\n",
|
| 570 |
+
" <tr>\n",
|
| 571 |
+
" <td>5200</td>\n",
|
| 572 |
+
" <td>2.344100</td>\n",
|
| 573 |
+
" </tr>\n",
|
| 574 |
+
" <tr>\n",
|
| 575 |
+
" <td>5250</td>\n",
|
| 576 |
+
" <td>2.283500</td>\n",
|
| 577 |
+
" </tr>\n",
|
| 578 |
+
" <tr>\n",
|
| 579 |
+
" <td>5300</td>\n",
|
| 580 |
+
" <td>2.296500</td>\n",
|
| 581 |
+
" </tr>\n",
|
| 582 |
+
" </tbody>\n",
|
| 583 |
+
"</table><p>"
|
| 584 |
+
],
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<IPython.core.display.HTML object>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"output_type": "display_data"
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"name": "stderr",
|
| 594 |
+
"output_type": "stream",
|
| 595 |
+
"text": [
|
| 596 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"text/plain": [
|
| 602 |
+
"CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2/commit/291285a5f2155c79a0da893645d8df9bbca98f63', commit_message='End of training', commit_description='', oid='291285a5f2155c79a0da893645d8df9bbca98f63', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new2'), pr_revision=None, pr_num=None)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
"execution_count": 4,
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"output_type": "execute_result"
|
| 608 |
+
}
|
| 609 |
+
],
|
| 610 |
+
"source": [
|
| 611 |
+
"import json\n",
|
| 612 |
+
"import torch\n",
|
| 613 |
+
"import os\n",
|
| 614 |
+
"from transformers import AutoTokenizer\n",
|
| 615 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 616 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 617 |
+
"if 'train_data' not in globals():\n",
|
| 618 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 621 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 622 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 623 |
+
" else:\n",
|
| 624 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 625 |
+
" exit()\n",
|
| 626 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 627 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 628 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"from torch.utils.data import Dataset\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"class GraphDataset(Dataset):\n",
|
| 641 |
+
" def __init__(self, data):\n",
|
| 642 |
+
" self.data = data\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" def __len__(self):\n",
|
| 645 |
+
" return len(self.data)\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" def __getitem__(self, idx):\n",
|
| 648 |
+
" sample = self.data[idx]\n",
|
| 649 |
+
" return {\n",
|
| 650 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 651 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 652 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 653 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 654 |
+
" }\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 657 |
+
"import torch\n",
|
| 658 |
+
"import torch.nn as nn\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 661 |
+
" def __init__(self, pretrained_model_name_or_path):\n",
|
| 662 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 663 |
+
" \n",
|
| 664 |
+
" # ✅ 载入 `MODEL_NAME` 预训练模型\n",
|
| 665 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" \n",
|
| 668 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 669 |
+
" self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
|
| 670 |
+
"\n",
|
| 671 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 672 |
+
" \"\"\"\n",
|
| 673 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 674 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 675 |
+
" \"\"\"\n",
|
| 676 |
+
" # ✅ 获取 token embedding\n",
|
| 677 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 680 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 681 |
+
"\n",
|
| 682 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 683 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 684 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 685 |
+
"\n",
|
| 686 |
+
" # ✅ 调整 attention mask\n",
|
| 687 |
+
" if attention_mask is not None:\n",
|
| 688 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 689 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" # ✅ 传入模型\n",
|
| 692 |
+
" outputs = self.model(\n",
|
| 693 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 694 |
+
" attention_mask=attention_mask,\n",
|
| 695 |
+
" labels=labels,\n",
|
| 696 |
+
" )\n",
|
| 697 |
+
"\n",
|
| 698 |
+
" return outputs\n",
|
| 699 |
+
"\n",
|
| 700 |
+
"from transformers import Trainer\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"class GraphTrainer(Trainer):\n",
|
| 703 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 704 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 705 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 706 |
+
" labels = inputs[\"labels\"]\n",
|
| 707 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 708 |
+
"\n",
|
| 709 |
+
" if graph_embedding is not None:\n",
|
| 710 |
+
" outputs = model(\n",
|
| 711 |
+
" input_ids=input_ids,\n",
|
| 712 |
+
" attention_mask=attention_mask,\n",
|
| 713 |
+
" labels=labels,\n",
|
| 714 |
+
" graph_embedding=graph_embedding, \n",
|
| 715 |
+
" )\n",
|
| 716 |
+
" else:\n",
|
| 717 |
+
" outputs = model(\n",
|
| 718 |
+
" input_ids=input_ids,\n",
|
| 719 |
+
" attention_mask=attention_mask,\n",
|
| 720 |
+
" labels=labels,\n",
|
| 721 |
+
" )\n",
|
| 722 |
+
"\n",
|
| 723 |
+
" loss = outputs.loss\n",
|
| 724 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 725 |
+
"\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"from transformers import AutoConfig\n",
|
| 728 |
+
"\n",
|
| 729 |
+
"# ✅ 载入微调模型\n",
|
| 730 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 731 |
+
"\n",
|
| 732 |
+
"# # 1. 加载模型的配置\n",
|
| 733 |
+
"# config = AutoConfig.from_pretrained(MODEL_NAME)\n",
|
| 734 |
+
"\n",
|
| 735 |
+
"# # 2. 使用配置创建 GraphAwareLM 实例\n",
|
| 736 |
+
"# model = GraphAwareLM.from_config(config) \n",
|
| 737 |
+
"\n",
|
| 738 |
+
"# pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 739 |
+
"# model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
|
| 740 |
+
"\n",
|
| 741 |
+
"# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
|
| 742 |
+
"# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 743 |
+
"# model.config.use_sliding_window_attention = False\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"# ✅ 训练参数\n",
|
| 746 |
+
"training_args = TrainingArguments(\n",
|
| 747 |
+
" output_dir=\"./results2\",\n",
|
| 748 |
+
" per_device_train_batch_size=7,\n",
|
| 749 |
+
" eval_strategy=\"no\",\n",
|
| 750 |
+
" save_strategy=\"steps\",\n",
|
| 751 |
+
" save_steps=3000,\n",
|
| 752 |
+
" logging_steps=50,\n",
|
| 753 |
+
" bf16=True,\n",
|
| 754 |
+
" optim=\"galore_adamw\",\n",
|
| 755 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 756 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 757 |
+
" warmup_steps=1000,\n",
|
| 758 |
+
" num_train_epochs=3,\n",
|
| 759 |
+
" push_to_hub=True,\n",
|
| 760 |
+
" hub_model_id=HF_NAME,\n",
|
| 761 |
+
" hub_strategy=\"every_save\",\n",
|
| 762 |
+
" run_name = \"experi030402\"\n",
|
| 763 |
+
")\n",
|
| 764 |
+
"\n",
|
| 765 |
+
"\n",
|
| 766 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 767 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 768 |
+
"\n",
|
| 769 |
+
"# ✅ 训练\n",
|
| 770 |
+
"trainer = GraphTrainer(\n",
|
| 771 |
+
" model=model,\n",
|
| 772 |
+
" args=training_args,\n",
|
| 773 |
+
" train_dataset=train_dataset,\n",
|
| 774 |
+
")\n",
|
| 775 |
+
"\n",
|
| 776 |
+
"trainer.train()\n",
|
| 777 |
+
"trainer.save_model(\"/workspace/model2\")\n",
|
| 778 |
+
"trainer.push_to_hub()\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"\n"
|
| 781 |
+
]
|
| 782 |
+
},
|
| 783 |
+
{
|
| 784 |
+
"cell_type": "code",
|
| 785 |
+
"execution_count": 7,
|
| 786 |
+
"id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
|
| 787 |
+
"metadata": {},
|
| 788 |
+
"outputs": [
|
| 789 |
+
{
|
| 790 |
+
"data": {
|
| 791 |
+
"text/plain": [
|
| 792 |
+
"RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb')"
|
| 793 |
+
]
|
| 794 |
+
},
|
| 795 |
+
"execution_count": 7,
|
| 796 |
+
"metadata": {},
|
| 797 |
+
"output_type": "execute_result"
|
| 798 |
+
}
|
| 799 |
+
],
|
| 800 |
+
"source": [
|
| 801 |
+
"from huggingface_hub import HfApi\n",
|
| 802 |
+
"\n",
|
| 803 |
+
"api = HfApi()\n",
|
| 804 |
+
"repo_name = \"r1q1.5_graph_lora-wandb\" # 你的模型名称\n",
|
| 805 |
+
"api.create_repo(repo_name, exist_ok=True)"
|
| 806 |
+
]
|
| 807 |
+
},
|
| 808 |
+
{
|
| 809 |
+
"cell_type": "code",
|
| 810 |
+
"execution_count": 6,
|
| 811 |
+
"id": "73c434b9-5d58-4819-8526-24aa18ca1010",
|
| 812 |
+
"metadata": {},
|
| 813 |
+
"outputs": [
|
| 814 |
+
{
|
| 815 |
+
"data": {
|
| 816 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 817 |
+
"model_id": "727ca342a20348d38a4a1c6d286963e0",
|
| 818 |
+
"version_major": 2,
|
| 819 |
+
"version_minor": 0
|
| 820 |
+
},
|
| 821 |
+
"text/plain": [
|
| 822 |
+
"optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
|
| 823 |
+
]
|
| 824 |
+
},
|
| 825 |
+
"metadata": {},
|
| 826 |
+
"output_type": "display_data"
|
| 827 |
+
},
|
| 828 |
+
{
|
| 829 |
+
"data": {
|
| 830 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 831 |
+
"model_id": "9d4967e00603441091d4527bccf89e43",
|
| 832 |
+
"version_major": 2,
|
| 833 |
+
"version_minor": 0
|
| 834 |
+
},
|
| 835 |
+
"text/plain": [
|
| 836 |
+
"scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
|
| 837 |
+
]
|
| 838 |
+
},
|
| 839 |
+
"metadata": {},
|
| 840 |
+
"output_type": "display_data"
|
| 841 |
+
},
|
| 842 |
+
{
|
| 843 |
+
"data": {
|
| 844 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 845 |
+
"model_id": "71abe1a9a91a4851bdd941995145dc8e",
|
| 846 |
+
"version_major": 2,
|
| 847 |
+
"version_minor": 0
|
| 848 |
+
},
|
| 849 |
+
"text/plain": [
|
| 850 |
+
"Upload 15 LFS files: 0%| | 0/15 [00:00<?, ?it/s]"
|
| 851 |
+
]
|
| 852 |
+
},
|
| 853 |
+
"metadata": {},
|
| 854 |
+
"output_type": "display_data"
|
| 855 |
+
},
|
| 856 |
+
{
|
| 857 |
+
"data": {
|
| 858 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 859 |
+
"model_id": "59323569e2fb415bbdbb006b3ad3bceb",
|
| 860 |
+
"version_major": 2,
|
| 861 |
+
"version_minor": 0
|
| 862 |
+
},
|
| 863 |
+
"text/plain": [
|
| 864 |
+
"rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
|
| 865 |
+
]
|
| 866 |
+
},
|
| 867 |
+
"metadata": {},
|
| 868 |
+
"output_type": "display_data"
|
| 869 |
+
},
|
| 870 |
+
{
|
| 871 |
+
"data": {
|
| 872 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 873 |
+
"model_id": "8855625556094031ba6b14bbd9a062b1",
|
| 874 |
+
"version_major": 2,
|
| 875 |
+
"version_minor": 0
|
| 876 |
+
},
|
| 877 |
+
"text/plain": [
|
| 878 |
+
"model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
|
| 879 |
+
]
|
| 880 |
+
},
|
| 881 |
+
"metadata": {},
|
| 882 |
+
"output_type": "display_data"
|
| 883 |
+
},
|
| 884 |
+
{
|
| 885 |
+
"data": {
|
| 886 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 887 |
+
"model_id": "c1f85411d4ac43ffa27ed9d24c18f468",
|
| 888 |
+
"version_major": 2,
|
| 889 |
+
"version_minor": 0
|
| 890 |
+
},
|
| 891 |
+
"text/plain": [
|
| 892 |
+
"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
| 893 |
+
]
|
| 894 |
+
},
|
| 895 |
+
"metadata": {},
|
| 896 |
+
"output_type": "display_data"
|
| 897 |
+
},
|
| 898 |
+
{
|
| 899 |
+
"data": {
|
| 900 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 901 |
+
"model_id": "e52bc8f9a5f34887a3844215e4fdfde2",
|
| 902 |
+
"version_major": 2,
|
| 903 |
+
"version_minor": 0
|
| 904 |
+
},
|
| 905 |
+
"text/plain": [
|
| 906 |
+
"training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
|
| 907 |
+
]
|
| 908 |
+
},
|
| 909 |
+
"metadata": {},
|
| 910 |
+
"output_type": "display_data"
|
| 911 |
+
},
|
| 912 |
+
{
|
| 913 |
+
"data": {
|
| 914 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 915 |
+
"model_id": "10d59eeefafe474488b7972ccf3ca70a",
|
| 916 |
+
"version_major": 2,
|
| 917 |
+
"version_minor": 0
|
| 918 |
+
},
|
| 919 |
+
"text/plain": [
|
| 920 |
+
"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
| 921 |
+
]
|
| 922 |
+
},
|
| 923 |
+
"metadata": {},
|
| 924 |
+
"output_type": "display_data"
|
| 925 |
+
},
|
| 926 |
+
{
|
| 927 |
+
"data": {
|
| 928 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 929 |
+
"model_id": "dfc1c91010114f09b760450f1b998aa4",
|
| 930 |
+
"version_major": 2,
|
| 931 |
+
"version_minor": 0
|
| 932 |
+
},
|
| 933 |
+
"text/plain": [
|
| 934 |
+
"model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
|
| 935 |
+
]
|
| 936 |
+
},
|
| 937 |
+
"metadata": {},
|
| 938 |
+
"output_type": "display_data"
|
| 939 |
+
},
|
| 940 |
+
{
|
| 941 |
+
"data": {
|
| 942 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 943 |
+
"model_id": "9ebc84de361240ad9317c70342168308",
|
| 944 |
+
"version_major": 2,
|
| 945 |
+
"version_minor": 0
|
| 946 |
+
},
|
| 947 |
+
"text/plain": [
|
| 948 |
+
"optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
|
| 949 |
+
]
|
| 950 |
+
},
|
| 951 |
+
"metadata": {},
|
| 952 |
+
"output_type": "display_data"
|
| 953 |
+
},
|
| 954 |
+
{
|
| 955 |
+
"data": {
|
| 956 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 957 |
+
"model_id": "4f6e8853b6ff41e88b967ae9c13a81ab",
|
| 958 |
+
"version_major": 2,
|
| 959 |
+
"version_minor": 0
|
| 960 |
+
},
|
| 961 |
+
"text/plain": [
|
| 962 |
+
"rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
|
| 963 |
+
]
|
| 964 |
+
},
|
| 965 |
+
"metadata": {},
|
| 966 |
+
"output_type": "display_data"
|
| 967 |
+
},
|
| 968 |
+
{
|
| 969 |
+
"data": {
|
| 970 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 971 |
+
"model_id": "75eb5a5e71c749b5945b66e2d41c441c",
|
| 972 |
+
"version_major": 2,
|
| 973 |
+
"version_minor": 0
|
| 974 |
+
},
|
| 975 |
+
"text/plain": [
|
| 976 |
+
"scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
|
| 977 |
+
]
|
| 978 |
+
},
|
| 979 |
+
"metadata": {},
|
| 980 |
+
"output_type": "display_data"
|
| 981 |
+
},
|
| 982 |
+
{
|
| 983 |
+
"data": {
|
| 984 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 985 |
+
"model_id": "d748aea8d80c4f81a36633a71ce0a0f7",
|
| 986 |
+
"version_major": 2,
|
| 987 |
+
"version_minor": 0
|
| 988 |
+
},
|
| 989 |
+
"text/plain": [
|
| 990 |
+
"training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
|
| 991 |
+
]
|
| 992 |
+
},
|
| 993 |
+
"metadata": {},
|
| 994 |
+
"output_type": "display_data"
|
| 995 |
+
},
|
| 996 |
+
{
|
| 997 |
+
"data": {
|
| 998 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 999 |
+
"model_id": "d1a84a5ec38f459bab4791065dc8efc3",
|
| 1000 |
+
"version_major": 2,
|
| 1001 |
+
"version_minor": 0
|
| 1002 |
+
},
|
| 1003 |
+
"text/plain": [
|
| 1004 |
+
"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
| 1005 |
+
]
|
| 1006 |
+
},
|
| 1007 |
+
"metadata": {},
|
| 1008 |
+
"output_type": "display_data"
|
| 1009 |
+
},
|
| 1010 |
+
{
|
| 1011 |
+
"data": {
|
| 1012 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1013 |
+
"model_id": "ae50659ac6024e5c8fc43ce898982f0c",
|
| 1014 |
+
"version_major": 2,
|
| 1015 |
+
"version_minor": 0
|
| 1016 |
+
},
|
| 1017 |
+
"text/plain": [
|
| 1018 |
+
"model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
|
| 1019 |
+
]
|
| 1020 |
+
},
|
| 1021 |
+
"metadata": {},
|
| 1022 |
+
"output_type": "display_data"
|
| 1023 |
+
},
|
| 1024 |
+
{
|
| 1025 |
+
"data": {
|
| 1026 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1027 |
+
"model_id": "c109c00c979d41d59fa48d4314bac1e4",
|
| 1028 |
+
"version_major": 2,
|
| 1029 |
+
"version_minor": 0
|
| 1030 |
+
},
|
| 1031 |
+
"text/plain": [
|
| 1032 |
+
"training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
|
| 1033 |
+
]
|
| 1034 |
+
},
|
| 1035 |
+
"metadata": {},
|
| 1036 |
+
"output_type": "display_data"
|
| 1037 |
+
},
|
| 1038 |
+
{
|
| 1039 |
+
"data": {
|
| 1040 |
+
"text/plain": [
|
| 1041 |
+
"CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora-results3/commit/5a14bfa05e9bd78cd030104d0e9ff02638731668', commit_message='upload results3', commit_description='', oid='5a14bfa05e9bd78cd030104d0e9ff02638731668', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-results3', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-results3'), pr_revision=None, pr_num=None)"
|
| 1042 |
+
]
|
| 1043 |
+
},
|
| 1044 |
+
"execution_count": 6,
|
| 1045 |
+
"metadata": {},
|
| 1046 |
+
"output_type": "execute_result"
|
| 1047 |
+
}
|
| 1048 |
+
],
|
| 1049 |
+
"source": [
|
| 1050 |
+
"from huggingface_hub import upload_folder\n",
|
| 1051 |
+
"\n",
|
| 1052 |
+
"upload_folder(\n",
|
| 1053 |
+
" folder_path = \"/workspace/wandb\",\n",
|
| 1054 |
+
" repo_id = \"YiFzhao/r1q1.5_graph_lora-wandb\",\n",
|
| 1055 |
+
" commit_message = \"upload wandb\",\n",
|
| 1056 |
+
")"
|
| 1057 |
+
]
|
| 1058 |
+
},
|
| 1059 |
+
{
|
| 1060 |
+
"cell_type": "code",
|
| 1061 |
+
"execution_count": 5,
|
| 1062 |
+
"id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
|
| 1063 |
+
"metadata": {},
|
| 1064 |
+
"outputs": [
|
| 1065 |
+
{
|
| 1066 |
+
"name": "stdout",
|
| 1067 |
+
"output_type": "stream",
|
| 1068 |
+
"text": [
|
| 1069 |
+
"🚀 处理后数据条数: 12384\n",
|
| 1070 |
+
"✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1071 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1072 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1073 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1074 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1075 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1076 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1077 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1078 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1079 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1080 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1081 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1082 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1083 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1084 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1085 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1086 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1087 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1088 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1089 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1090 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1091 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1092 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1093 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1094 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1095 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1096 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1097 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1098 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1099 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1100 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1101 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1102 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1103 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1104 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1105 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1106 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1107 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1108 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1109 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1110 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1111 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1112 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1113 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1114 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1115 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1116 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1117 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1118 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1119 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1120 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1121 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1122 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1123 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1124 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1125 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1126 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1127 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1128 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1129 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1130 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1131 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1132 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1133 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
|
| 1134 |
+
"✅ train_data 已保存到 train_data.pt\n"
|
| 1135 |
+
]
|
| 1136 |
+
}
|
| 1137 |
+
],
|
| 1138 |
+
"source": [
|
| 1139 |
+
"import json\n",
|
| 1140 |
+
"import torch\n",
|
| 1141 |
+
"from transformers import AutoTokenizer\n",
|
| 1142 |
+
"\n",
|
| 1143 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1144 |
+
"tokenizer.pad_token = tokenizer.eos_token \n",
|
| 1145 |
+
"\n",
|
| 1146 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1147 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1148 |
+
" data = json.load(f)\n",
|
| 1149 |
+
"\n",
|
| 1150 |
+
"train_data = []\n",
|
| 1151 |
+
"\n",
|
| 1152 |
+
"\n",
|
| 1153 |
+
"for sample in data:\n",
|
| 1154 |
+
" conversations = sample.get(\"conversations\", [])\n",
|
| 1155 |
+
" embeddings = sample.get(\"embedding\", []) \n",
|
| 1156 |
+
"\n",
|
| 1157 |
+
" if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
|
| 1158 |
+
" print(f\"无效的 embedding,跳过样本:{sample}\")\n",
|
| 1159 |
+
" continue\n",
|
| 1160 |
+
"\n",
|
| 1161 |
+
" graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
|
| 1162 |
+
"\n",
|
| 1163 |
+
" #拼接���有对话\n",
|
| 1164 |
+
" dialogue_text = \"\"\n",
|
| 1165 |
+
" for conv in conversations:\n",
|
| 1166 |
+
" role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
|
| 1167 |
+
" content = conv[\"value\"]\n",
|
| 1168 |
+
" content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
|
| 1169 |
+
" role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
|
| 1170 |
+
" dialogue_text += f\"{role_token} {content}\\n\"\n",
|
| 1171 |
+
"\n",
|
| 1172 |
+
" tokenized = tokenizer(\n",
|
| 1173 |
+
" dialogue_text,\n",
|
| 1174 |
+
" padding=\"max_length\",\n",
|
| 1175 |
+
" truncation=True,\n",
|
| 1176 |
+
" max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
|
| 1177 |
+
" return_tensors=\"pt\",\n",
|
| 1178 |
+
" )\n",
|
| 1179 |
+
"\n",
|
| 1180 |
+
" input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
|
| 1181 |
+
" attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
|
| 1182 |
+
"\n",
|
| 1183 |
+
" train_data.append({\n",
|
| 1184 |
+
" \"input_ids\": input_ids,\n",
|
| 1185 |
+
" \"attention_mask\": attention_mask,\n",
|
| 1186 |
+
" \"labels\": input_ids.clone(),\n",
|
| 1187 |
+
" \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
|
| 1188 |
+
" })\n",
|
| 1189 |
+
"\n",
|
| 1190 |
+
"print(\"🚀 处理后数据条数:\", len(train_data))\n",
|
| 1191 |
+
"print(\"✅ 示例数据:\", train_data[0])\n",
|
| 1192 |
+
"torch.save(train_data, \"train_data.pt\")\n",
|
| 1193 |
+
"print(\"✅ train_data 已保存到 train_data.pt\")\n"
|
| 1194 |
+
]
|
| 1195 |
+
},
|
| 1196 |
+
{
|
| 1197 |
+
"cell_type": "code",
|
| 1198 |
+
"execution_count": 10,
|
| 1199 |
+
"id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
|
| 1200 |
+
"metadata": {},
|
| 1201 |
+
"outputs": [],
|
| 1202 |
+
"source": [
|
| 1203 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 1204 |
+
"import torch\n",
|
| 1205 |
+
"import torch.nn as nn\n",
|
| 1206 |
+
"\n",
|
| 1207 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1208 |
+
" def __init__(self, pretrained_model_name_or_path):\n",
|
| 1209 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 1210 |
+
" \n",
|
| 1211 |
+
" # ✅ 载入 `MODEL_NAME` 预训练模型\n",
|
| 1212 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 1213 |
+
"\n",
|
| 1214 |
+
" \n",
|
| 1215 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 1216 |
+
" self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
|
| 1217 |
+
"\n",
|
| 1218 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1219 |
+
" \"\"\"\n",
|
| 1220 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1221 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1222 |
+
" \"\"\"\n",
|
| 1223 |
+
" # ✅ 获取 token embedding\n",
|
| 1224 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1225 |
+
"\n",
|
| 1226 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 1227 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 1228 |
+
"\n",
|
| 1229 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1230 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1231 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1232 |
+
"\n",
|
| 1233 |
+
" # ✅ 调整 attention mask\n",
|
| 1234 |
+
" if attention_mask is not None:\n",
|
| 1235 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1236 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1237 |
+
"\n",
|
| 1238 |
+
" # ✅ 传入模型\n",
|
| 1239 |
+
" outputs = self.model(\n",
|
| 1240 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1241 |
+
" attention_mask=attention_mask,\n",
|
| 1242 |
+
" labels=labels,\n",
|
| 1243 |
+
" )\n",
|
| 1244 |
+
"\n",
|
| 1245 |
+
" return outputs\n",
|
| 1246 |
+
"\n",
|
| 1247 |
+
" def generate_with_graph(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
|
| 1248 |
+
" \"\"\"\n",
|
| 1249 |
+
" ✅ 自定义 `generate()`,支持 `graph_embedding`\n",
|
| 1250 |
+
" `input_text`: 需要生成文本的输入\n",
|
| 1251 |
+
" `graph_embedding`: 形状为 (1, 512) 的张量\n",
|
| 1252 |
+
" \"\"\"\n",
|
| 1253 |
+
" # ✅ 2. 处理 `graph_embedding`\n",
|
| 1254 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (1, hidden_size)\n",
|
| 1255 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (1, 1, hidden_size)\n",
|
| 1256 |
+
"\n",
|
| 1257 |
+
" # ✅ 3. 获取 Token Embeddings 并拼接\n",
|
| 1258 |
+
" inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
|
| 1259 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
|
| 1260 |
+
"\n",
|
| 1261 |
+
" # ✅ 4. 调整 `attention_mask`\n",
|
| 1262 |
+
" if \"attention_mask\" in inputs:\n",
|
| 1263 |
+
" graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
|
| 1264 |
+
" attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
|
| 1265 |
+
" else:\n",
|
| 1266 |
+
" attention_mask = None\n",
|
| 1267 |
+
"\n",
|
| 1268 |
+
" # ✅ 5. 进行文本生成\n",
|
| 1269 |
+
" with torch.no_grad():\n",
|
| 1270 |
+
" output_ids = self.model.generate(\n",
|
| 1271 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1272 |
+
" attention_mask=attention_mask,\n",
|
| 1273 |
+
" max_length=max_length,\n",
|
| 1274 |
+
" temperature=temperature,\n",
|
| 1275 |
+
" top_k=top_k,\n",
|
| 1276 |
+
" top_p=top_p,\n",
|
| 1277 |
+
" num_return_sequences=1\n",
|
| 1278 |
+
" )\n",
|
| 1279 |
+
"\n",
|
| 1280 |
+
" # ✅ 6. 解码生成的文本\n",
|
| 1281 |
+
" generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
|
| 1282 |
+
" return generated_text\n",
|
| 1283 |
+
"\n",
|
| 1284 |
+
" @classmethod\n",
|
| 1285 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 1286 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 1287 |
+
" model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
|
| 1288 |
+
" return model"
|
| 1289 |
+
]
|
| 1290 |
+
},
|
| 1291 |
+
{
|
| 1292 |
+
"cell_type": "code",
|
| 1293 |
+
"execution_count": 11,
|
| 1294 |
+
"id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
|
| 1295 |
+
"metadata": {},
|
| 1296 |
+
"outputs": [],
|
| 1297 |
+
"source": [
|
| 1298 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
| 1299 |
+
]
|
| 1300 |
+
},
|
| 1301 |
+
{
|
| 1302 |
+
"cell_type": "code",
|
| 1303 |
+
"execution_count": 12,
|
| 1304 |
+
"id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
|
| 1305 |
+
"metadata": {},
|
| 1306 |
+
"outputs": [
|
| 1307 |
+
{
|
| 1308 |
+
"data": {
|
| 1309 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1310 |
+
"model_id": "7ad289c5523340f39799ad11e3bc1bb5",
|
| 1311 |
+
"version_major": 2,
|
| 1312 |
+
"version_minor": 0
|
| 1313 |
+
},
|
| 1314 |
+
"text/plain": [
|
| 1315 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1316 |
+
]
|
| 1317 |
+
},
|
| 1318 |
+
"metadata": {},
|
| 1319 |
+
"output_type": "display_data"
|
| 1320 |
+
},
|
| 1321 |
+
{
|
| 1322 |
+
"data": {
|
| 1323 |
+
"text/plain": [
|
| 1324 |
+
"Qwen2ForCausalLM(\n",
|
| 1325 |
+
" (model): Qwen2Model(\n",
|
| 1326 |
+
" (embed_tokens): Embedding(151936, 1536)\n",
|
| 1327 |
+
" (layers): ModuleList(\n",
|
| 1328 |
+
" (0-27): 28 x Qwen2DecoderLayer(\n",
|
| 1329 |
+
" (self_attn): Qwen2Attention(\n",
|
| 1330 |
+
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1331 |
+
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1332 |
+
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1333 |
+
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
|
| 1334 |
+
" )\n",
|
| 1335 |
+
" (mlp): Qwen2MLP(\n",
|
| 1336 |
+
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1337 |
+
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1338 |
+
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
|
| 1339 |
+
" (act_fn): SiLU()\n",
|
| 1340 |
+
" )\n",
|
| 1341 |
+
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1342 |
+
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1343 |
+
" )\n",
|
| 1344 |
+
" )\n",
|
| 1345 |
+
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1346 |
+
" (rotary_emb): Qwen2RotaryEmbedding()\n",
|
| 1347 |
+
" )\n",
|
| 1348 |
+
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
|
| 1349 |
+
" (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
|
| 1350 |
+
")"
|
| 1351 |
+
]
|
| 1352 |
+
},
|
| 1353 |
+
"execution_count": 12,
|
| 1354 |
+
"metadata": {},
|
| 1355 |
+
"output_type": "execute_result"
|
| 1356 |
+
}
|
| 1357 |
+
],
|
| 1358 |
+
"source": [
|
| 1359 |
+
"import torch\n",
|
| 1360 |
+
"from transformers import AutoTokenizer\n",
|
| 1361 |
+
"\n",
|
| 1362 |
+
"# 加载 tokenizer\n",
|
| 1363 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 1364 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1365 |
+
"\n",
|
| 1366 |
+
"# 加载训练好的模型\n",
|
| 1367 |
+
"model_path = \"/workspace/model2\"\n",
|
| 1368 |
+
"model = GraphAwareLM.from_pretrained(\"/workspace/results2/checkpoint-5310\").to(device)\n",
|
| 1369 |
+
"model.eval() # 设置为推理模式\n"
|
| 1370 |
+
]
|
| 1371 |
+
},
|
| 1372 |
+
{
|
| 1373 |
+
"cell_type": "code",
|
| 1374 |
+
"execution_count": 13,
|
| 1375 |
+
"id": "51995891-8906-4049-9401-2d22e06a84e8",
|
| 1376 |
+
"metadata": {},
|
| 1377 |
+
"outputs": [
|
| 1378 |
+
{
|
| 1379 |
+
"name": "stdout",
|
| 1380 |
+
"output_type": "stream",
|
| 1381 |
+
"text": [
|
| 1382 |
+
"Parameter containing:\n",
|
| 1383 |
+
"tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
|
| 1384 |
+
" [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
|
| 1385 |
+
" [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
|
| 1386 |
+
" ...,\n",
|
| 1387 |
+
" [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
|
| 1388 |
+
" [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
|
| 1389 |
+
" [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
|
| 1390 |
+
" device='cuda:0', requires_grad=True)\n"
|
| 1391 |
+
]
|
| 1392 |
+
}
|
| 1393 |
+
],
|
| 1394 |
+
"source": [
|
| 1395 |
+
"print(model.graph_proj.weight)\n"
|
| 1396 |
+
]
|
| 1397 |
+
},
|
| 1398 |
+
{
|
| 1399 |
+
"cell_type": "code",
|
| 1400 |
+
"execution_count": 14,
|
| 1401 |
+
"id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
|
| 1402 |
+
"metadata": {},
|
| 1403 |
+
"outputs": [],
|
| 1404 |
+
"source": [
|
| 1405 |
+
"import json\n",
|
| 1406 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1407 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1408 |
+
" data = json.load(f)\n",
|
| 1409 |
+
"\n",
|
| 1410 |
+
"test_data = data[0]\n",
|
| 1411 |
+
"\n",
|
| 1412 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 1413 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 1414 |
+
"\n",
|
| 1415 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
|
| 1416 |
+
"\n",
|
| 1417 |
+
"question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 1418 |
+
"\n",
|
| 1419 |
+
"from transformers import AutoTokenizer\n",
|
| 1420 |
+
"\n",
|
| 1421 |
+
"# ✅ 输入文本\n",
|
| 1422 |
+
"ROLE_TOKENS = {\n",
|
| 1423 |
+
" \"human\": \"<|User|>\", \n",
|
| 1424 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 1425 |
+
"}\n",
|
| 1426 |
+
"GRAPH_LENGTH = 512\n",
|
| 1427 |
+
"max_seq_length = 1100 + GRAPH_LENGTH\n",
|
| 1428 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 1429 |
+
"\n",
|
| 1430 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 1431 |
+
"attention_mask = inputs[\"attention_mask\"]\n"
|
| 1432 |
+
]
|
| 1433 |
+
},
|
| 1434 |
+
{
|
| 1435 |
+
"cell_type": "code",
|
| 1436 |
+
"execution_count": 15,
|
| 1437 |
+
"id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
|
| 1438 |
+
"metadata": {},
|
| 1439 |
+
"outputs": [
|
| 1440 |
+
{
|
| 1441 |
+
"ename": "AttributeError",
|
| 1442 |
+
"evalue": "'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'",
|
| 1443 |
+
"output_type": "error",
|
| 1444 |
+
"traceback": [
|
| 1445 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1446 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
| 1447 |
+
"Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_with_graph\u001b[49m(inputs, graph_embedding)\n",
|
| 1448 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 1449 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'"
|
| 1450 |
+
]
|
| 1451 |
+
}
|
| 1452 |
+
],
|
| 1453 |
+
"source": [
|
| 1454 |
+
"generated_text = model.generate_with_graph(inputs, graph_embedding)"
|
| 1455 |
+
]
|
| 1456 |
+
},
|
| 1457 |
+
{
|
| 1458 |
+
"cell_type": "code",
|
| 1459 |
+
"execution_count": 5,
|
| 1460 |
+
"id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
|
| 1461 |
+
"metadata": {},
|
| 1462 |
+
"outputs": [
|
| 1463 |
+
{
|
| 1464 |
+
"data": {
|
| 1465 |
+
"text/plain": [
|
| 1466 |
+
"tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1467 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1468 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1469 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1470 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1471 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1472 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1473 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1474 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1475 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1476 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1477 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1478 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1479 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1480 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1481 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1482 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1483 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1484 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1485 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1486 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1487 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1488 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1489 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1490 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1491 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1492 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1493 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1494 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1495 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1496 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1497 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1498 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1499 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1500 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1501 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1502 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1503 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1504 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1505 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1506 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1507 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1508 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1509 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1510 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1511 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1512 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1513 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1514 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1515 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1516 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1517 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1518 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1519 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1520 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1521 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1522 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1523 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1524 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1525 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1526 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1527 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1528 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1529 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
|
| 1530 |
+
" device='cuda:0')"
|
| 1531 |
+
]
|
| 1532 |
+
},
|
| 1533 |
+
"execution_count": 5,
|
| 1534 |
+
"metadata": {},
|
| 1535 |
+
"output_type": "execute_result"
|
| 1536 |
+
}
|
| 1537 |
+
],
|
| 1538 |
+
"source": [
|
| 1539 |
+
"graph_embedding"
|
| 1540 |
+
]
|
| 1541 |
+
},
|
| 1542 |
+
{
|
| 1543 |
+
"cell_type": "code",
|
| 1544 |
+
"execution_count": 15,
|
| 1545 |
+
"id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
|
| 1546 |
+
"metadata": {},
|
| 1547 |
+
"outputs": [
|
| 1548 |
+
{
|
| 1549 |
+
"name": "stdout",
|
| 1550 |
+
"output_type": "stream",
|
| 1551 |
+
"text": [
|
| 1552 |
+
"模型回复: How\n"
|
| 1553 |
+
]
|
| 1554 |
+
}
|
| 1555 |
+
],
|
| 1556 |
+
"source": [
|
| 1557 |
+
"# ✅ 进行前向传播\n",
|
| 1558 |
+
"with torch.no_grad():\n",
|
| 1559 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
|
| 1560 |
+
"\n",
|
| 1561 |
+
"# ✅ 提取 logits 并进行贪心解码\n",
|
| 1562 |
+
"logits = outputs.logits[:, -1, :] # ��最后一个 token 的 logits\n",
|
| 1563 |
+
"predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
|
| 1564 |
+
"\n",
|
| 1565 |
+
"# ✅ 反向编码为文本\n",
|
| 1566 |
+
"response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
|
| 1567 |
+
"\n",
|
| 1568 |
+
"print(\"模型回复:\", response_text)"
|
| 1569 |
+
]
|
| 1570 |
+
},
|
| 1571 |
+
{
|
| 1572 |
+
"cell_type": "code",
|
| 1573 |
+
"execution_count": 17,
|
| 1574 |
+
"id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
|
| 1575 |
+
"metadata": {},
|
| 1576 |
+
"outputs": [
|
| 1577 |
+
{
|
| 1578 |
+
"name": "stdout",
|
| 1579 |
+
"output_type": "stream",
|
| 1580 |
+
"text": [
|
| 1581 |
+
"Generated Response: Is there any sequential logic in the module, and if so, how is it handled? `data` is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit data, and the output is the output of the `data` is a 1-bit\n"
|
| 1582 |
+
]
|
| 1583 |
+
}
|
| 1584 |
+
],
|
| 1585 |
+
"source": [
|
| 1586 |
+
"max_new_tokens = 1024\n",
|
| 1587 |
+
"generated_ids = input_ids.clone()\n",
|
| 1588 |
+
"generated_attention_mask = attention_mask.clone()\n",
|
| 1589 |
+
"for _ in range(max_new_tokens):\n",
|
| 1590 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 1591 |
+
" with torch.no_grad():\n",
|
| 1592 |
+
" outputs = model(\n",
|
| 1593 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 1594 |
+
" attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
|
| 1595 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 1596 |
+
" )\n",
|
| 1597 |
+
"\n",
|
| 1598 |
+
"\n",
|
| 1599 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1600 |
+
" next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
|
| 1601 |
+
" # print(next_token)\n",
|
| 1602 |
+
"\n",
|
| 1603 |
+
"\n",
|
| 1604 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 1605 |
+
" generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
|
| 1606 |
+
"\n",
|
| 1607 |
+
" # print(generated_ids)\n",
|
| 1608 |
+
"\n",
|
| 1609 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 1610 |
+
" break\n",
|
| 1611 |
+
"\n",
|
| 1612 |
+
" generated_attention_mask = torch.cat(\n",
|
| 1613 |
+
" [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
|
| 1614 |
+
" ) \n",
|
| 1615 |
+
"\n",
|
| 1616 |
+
"# ✅ 解码最终输出\n",
|
| 1617 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 1618 |
+
"print(\"Generated Response:\", generated_text)"
|
| 1619 |
+
]
|
| 1620 |
+
},
|
| 1621 |
+
{
|
| 1622 |
+
"cell_type": "code",
|
| 1623 |
+
"execution_count": 10,
|
| 1624 |
+
"id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
|
| 1625 |
+
"metadata": {},
|
| 1626 |
+
"outputs": [
|
| 1627 |
+
{
|
| 1628 |
+
"data": {
|
| 1629 |
+
"text/plain": [
|
| 1630 |
+
"tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
|
| 1631 |
+
" 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
|
| 1632 |
+
" 525, 862, 9895, 30]], device='cuda:0')"
|
| 1633 |
+
]
|
| 1634 |
+
},
|
| 1635 |
+
"execution_count": 10,
|
| 1636 |
+
"metadata": {},
|
| 1637 |
+
"output_type": "execute_result"
|
| 1638 |
+
}
|
| 1639 |
+
],
|
| 1640 |
+
"source": [
|
| 1641 |
+
"generated_ids"
|
| 1642 |
+
]
|
| 1643 |
+
},
|
| 1644 |
+
{
|
| 1645 |
+
"cell_type": "code",
|
| 1646 |
+
"execution_count": null,
|
| 1647 |
+
"id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
|
| 1648 |
+
"metadata": {},
|
| 1649 |
+
"outputs": [],
|
| 1650 |
+
"source": []
|
| 1651 |
+
}
|
| 1652 |
+
],
|
| 1653 |
+
"metadata": {
|
| 1654 |
+
"kernelspec": {
|
| 1655 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1656 |
+
"language": "python",
|
| 1657 |
+
"name": "python3"
|
| 1658 |
+
},
|
| 1659 |
+
"language_info": {
|
| 1660 |
+
"codemirror_mode": {
|
| 1661 |
+
"name": "ipython",
|
| 1662 |
+
"version": 3
|
| 1663 |
+
},
|
| 1664 |
+
"file_extension": ".py",
|
| 1665 |
+
"mimetype": "text/x-python",
|
| 1666 |
+
"name": "python",
|
| 1667 |
+
"nbconvert_exporter": "python",
|
| 1668 |
+
"pygments_lexer": "ipython3",
|
| 1669 |
+
"version": "3.10.12"
|
| 1670 |
+
}
|
| 1671 |
+
},
|
| 1672 |
+
"nbformat": 4,
|
| 1673 |
+
"nbformat_minor": 5
|
| 1674 |
+
}
|
.ipynb_checkpoints/graph_train3-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,1588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"ROLE_TOKENS = {\n",
|
| 11 |
+
" \"human\": \"<|User|>\", \n",
|
| 12 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 13 |
+
"}\n",
|
| 14 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
|
| 15 |
+
"GRAPH_LENGTH = 512\n",
|
| 16 |
+
"HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new3\""
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 2,
|
| 22 |
+
"id": "bba6e6db-4b79-4461-ba13-75fd41019358",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"CUDA 可用: True\n",
|
| 30 |
+
"GPU 数量: 1\n",
|
| 31 |
+
"当前 GPU: 0\n",
|
| 32 |
+
"GPU 名称: NVIDIA A100 80GB PCIe\n"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"source": [
|
| 37 |
+
"# !pip install transformers accelerate datasets\n",
|
| 38 |
+
"# !pip install galora\n",
|
| 39 |
+
"# !pip install huggingface_hub\n",
|
| 40 |
+
"import torch\n",
|
| 41 |
+
"print(\"CUDA 可用:\", torch.cuda.is_available())\n",
|
| 42 |
+
"print(\"GPU 数量:\", torch.cuda.device_count())\n",
|
| 43 |
+
"print(\"当前 GPU:\", torch.cuda.current_device())\n",
|
| 44 |
+
"print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 3,
|
| 50 |
+
"id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
+
"id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 75 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 76 |
+
" warnings.warn(\n",
|
| 77 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 78 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"text/html": [
|
| 84 |
+
"Tracking run with wandb version 0.19.7"
|
| 85 |
+
],
|
| 86 |
+
"text/plain": [
|
| 87 |
+
"<IPython.core.display.HTML object>"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"output_type": "display_data"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"data": {
|
| 95 |
+
"text/html": [
|
| 96 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_134403-e0v0giuw</code>"
|
| 97 |
+
],
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"<IPython.core.display.HTML object>"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"output_type": "display_data"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"data": {
|
| 107 |
+
"text/html": [
|
| 108 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">experi030403</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 109 |
+
],
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"<IPython.core.display.HTML object>"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"output_type": "display_data"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"data": {
|
| 119 |
+
"text/html": [
|
| 120 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 121 |
+
],
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"<IPython.core.display.HTML object>"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"data": {
|
| 131 |
+
"text/html": [
|
| 132 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw</a>"
|
| 133 |
+
],
|
| 134 |
+
"text/plain": [
|
| 135 |
+
"<IPython.core.display.HTML object>"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"output_type": "display_data"
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"data": {
|
| 143 |
+
"text/html": [
|
| 144 |
+
"\n",
|
| 145 |
+
" <div>\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 148 |
+
" [5310/5310 1:33:59, Epoch 3/3]\n",
|
| 149 |
+
" </div>\n",
|
| 150 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 151 |
+
" <thead>\n",
|
| 152 |
+
" <tr style=\"text-align: left;\">\n",
|
| 153 |
+
" <th>Step</th>\n",
|
| 154 |
+
" <th>Training Loss</th>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" </thead>\n",
|
| 157 |
+
" <tbody>\n",
|
| 158 |
+
" <tr>\n",
|
| 159 |
+
" <td>50</td>\n",
|
| 160 |
+
" <td>5.319300</td>\n",
|
| 161 |
+
" </tr>\n",
|
| 162 |
+
" <tr>\n",
|
| 163 |
+
" <td>100</td>\n",
|
| 164 |
+
" <td>3.641300</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <td>150</td>\n",
|
| 168 |
+
" <td>1.521800</td>\n",
|
| 169 |
+
" </tr>\n",
|
| 170 |
+
" <tr>\n",
|
| 171 |
+
" <td>200</td>\n",
|
| 172 |
+
" <td>1.027500</td>\n",
|
| 173 |
+
" </tr>\n",
|
| 174 |
+
" <tr>\n",
|
| 175 |
+
" <td>250</td>\n",
|
| 176 |
+
" <td>0.922400</td>\n",
|
| 177 |
+
" </tr>\n",
|
| 178 |
+
" <tr>\n",
|
| 179 |
+
" <td>300</td>\n",
|
| 180 |
+
" <td>0.866900</td>\n",
|
| 181 |
+
" </tr>\n",
|
| 182 |
+
" <tr>\n",
|
| 183 |
+
" <td>350</td>\n",
|
| 184 |
+
" <td>0.800500</td>\n",
|
| 185 |
+
" </tr>\n",
|
| 186 |
+
" <tr>\n",
|
| 187 |
+
" <td>400</td>\n",
|
| 188 |
+
" <td>0.721600</td>\n",
|
| 189 |
+
" </tr>\n",
|
| 190 |
+
" <tr>\n",
|
| 191 |
+
" <td>450</td>\n",
|
| 192 |
+
" <td>0.740400</td>\n",
|
| 193 |
+
" </tr>\n",
|
| 194 |
+
" <tr>\n",
|
| 195 |
+
" <td>500</td>\n",
|
| 196 |
+
" <td>0.737000</td>\n",
|
| 197 |
+
" </tr>\n",
|
| 198 |
+
" <tr>\n",
|
| 199 |
+
" <td>550</td>\n",
|
| 200 |
+
" <td>0.713500</td>\n",
|
| 201 |
+
" </tr>\n",
|
| 202 |
+
" <tr>\n",
|
| 203 |
+
" <td>600</td>\n",
|
| 204 |
+
" <td>0.747000</td>\n",
|
| 205 |
+
" </tr>\n",
|
| 206 |
+
" <tr>\n",
|
| 207 |
+
" <td>650</td>\n",
|
| 208 |
+
" <td>0.869500</td>\n",
|
| 209 |
+
" </tr>\n",
|
| 210 |
+
" <tr>\n",
|
| 211 |
+
" <td>700</td>\n",
|
| 212 |
+
" <td>1.473300</td>\n",
|
| 213 |
+
" </tr>\n",
|
| 214 |
+
" <tr>\n",
|
| 215 |
+
" <td>750</td>\n",
|
| 216 |
+
" <td>0.753000</td>\n",
|
| 217 |
+
" </tr>\n",
|
| 218 |
+
" <tr>\n",
|
| 219 |
+
" <td>800</td>\n",
|
| 220 |
+
" <td>0.741300</td>\n",
|
| 221 |
+
" </tr>\n",
|
| 222 |
+
" <tr>\n",
|
| 223 |
+
" <td>850</td>\n",
|
| 224 |
+
" <td>0.751400</td>\n",
|
| 225 |
+
" </tr>\n",
|
| 226 |
+
" <tr>\n",
|
| 227 |
+
" <td>900</td>\n",
|
| 228 |
+
" <td>0.787600</td>\n",
|
| 229 |
+
" </tr>\n",
|
| 230 |
+
" <tr>\n",
|
| 231 |
+
" <td>950</td>\n",
|
| 232 |
+
" <td>0.783200</td>\n",
|
| 233 |
+
" </tr>\n",
|
| 234 |
+
" <tr>\n",
|
| 235 |
+
" <td>1000</td>\n",
|
| 236 |
+
" <td>0.780200</td>\n",
|
| 237 |
+
" </tr>\n",
|
| 238 |
+
" <tr>\n",
|
| 239 |
+
" <td>1050</td>\n",
|
| 240 |
+
" <td>1.012900</td>\n",
|
| 241 |
+
" </tr>\n",
|
| 242 |
+
" <tr>\n",
|
| 243 |
+
" <td>1100</td>\n",
|
| 244 |
+
" <td>1.411700</td>\n",
|
| 245 |
+
" </tr>\n",
|
| 246 |
+
" <tr>\n",
|
| 247 |
+
" <td>1150</td>\n",
|
| 248 |
+
" <td>1.536400</td>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" <tr>\n",
|
| 251 |
+
" <td>1200</td>\n",
|
| 252 |
+
" <td>0.853800</td>\n",
|
| 253 |
+
" </tr>\n",
|
| 254 |
+
" <tr>\n",
|
| 255 |
+
" <td>1250</td>\n",
|
| 256 |
+
" <td>0.756500</td>\n",
|
| 257 |
+
" </tr>\n",
|
| 258 |
+
" <tr>\n",
|
| 259 |
+
" <td>1300</td>\n",
|
| 260 |
+
" <td>0.750800</td>\n",
|
| 261 |
+
" </tr>\n",
|
| 262 |
+
" <tr>\n",
|
| 263 |
+
" <td>1350</td>\n",
|
| 264 |
+
" <td>0.747400</td>\n",
|
| 265 |
+
" </tr>\n",
|
| 266 |
+
" <tr>\n",
|
| 267 |
+
" <td>1400</td>\n",
|
| 268 |
+
" <td>0.844400</td>\n",
|
| 269 |
+
" </tr>\n",
|
| 270 |
+
" <tr>\n",
|
| 271 |
+
" <td>1450</td>\n",
|
| 272 |
+
" <td>0.858400</td>\n",
|
| 273 |
+
" </tr>\n",
|
| 274 |
+
" <tr>\n",
|
| 275 |
+
" <td>1500</td>\n",
|
| 276 |
+
" <td>1.053400</td>\n",
|
| 277 |
+
" </tr>\n",
|
| 278 |
+
" <tr>\n",
|
| 279 |
+
" <td>1550</td>\n",
|
| 280 |
+
" <td>1.591600</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <td>1600</td>\n",
|
| 284 |
+
" <td>1.498900</td>\n",
|
| 285 |
+
" </tr>\n",
|
| 286 |
+
" <tr>\n",
|
| 287 |
+
" <td>1650</td>\n",
|
| 288 |
+
" <td>1.471700</td>\n",
|
| 289 |
+
" </tr>\n",
|
| 290 |
+
" <tr>\n",
|
| 291 |
+
" <td>1700</td>\n",
|
| 292 |
+
" <td>1.221100</td>\n",
|
| 293 |
+
" </tr>\n",
|
| 294 |
+
" <tr>\n",
|
| 295 |
+
" <td>1750</td>\n",
|
| 296 |
+
" <td>1.802300</td>\n",
|
| 297 |
+
" </tr>\n",
|
| 298 |
+
" <tr>\n",
|
| 299 |
+
" <td>1800</td>\n",
|
| 300 |
+
" <td>1.826000</td>\n",
|
| 301 |
+
" </tr>\n",
|
| 302 |
+
" <tr>\n",
|
| 303 |
+
" <td>1850</td>\n",
|
| 304 |
+
" <td>1.857300</td>\n",
|
| 305 |
+
" </tr>\n",
|
| 306 |
+
" <tr>\n",
|
| 307 |
+
" <td>1900</td>\n",
|
| 308 |
+
" <td>1.561800</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <td>1950</td>\n",
|
| 312 |
+
" <td>1.398800</td>\n",
|
| 313 |
+
" </tr>\n",
|
| 314 |
+
" <tr>\n",
|
| 315 |
+
" <td>2000</td>\n",
|
| 316 |
+
" <td>1.398900</td>\n",
|
| 317 |
+
" </tr>\n",
|
| 318 |
+
" <tr>\n",
|
| 319 |
+
" <td>2050</td>\n",
|
| 320 |
+
" <td>1.381600</td>\n",
|
| 321 |
+
" </tr>\n",
|
| 322 |
+
" <tr>\n",
|
| 323 |
+
" <td>2100</td>\n",
|
| 324 |
+
" <td>0.890300</td>\n",
|
| 325 |
+
" </tr>\n",
|
| 326 |
+
" <tr>\n",
|
| 327 |
+
" <td>2150</td>\n",
|
| 328 |
+
" <td>0.763700</td>\n",
|
| 329 |
+
" </tr>\n",
|
| 330 |
+
" <tr>\n",
|
| 331 |
+
" <td>2200</td>\n",
|
| 332 |
+
" <td>0.753100</td>\n",
|
| 333 |
+
" </tr>\n",
|
| 334 |
+
" <tr>\n",
|
| 335 |
+
" <td>2250</td>\n",
|
| 336 |
+
" <td>0.745500</td>\n",
|
| 337 |
+
" </tr>\n",
|
| 338 |
+
" <tr>\n",
|
| 339 |
+
" <td>2300</td>\n",
|
| 340 |
+
" <td>1.186100</td>\n",
|
| 341 |
+
" </tr>\n",
|
| 342 |
+
" <tr>\n",
|
| 343 |
+
" <td>2350</td>\n",
|
| 344 |
+
" <td>0.862000</td>\n",
|
| 345 |
+
" </tr>\n",
|
| 346 |
+
" <tr>\n",
|
| 347 |
+
" <td>2400</td>\n",
|
| 348 |
+
" <td>1.024600</td>\n",
|
| 349 |
+
" </tr>\n",
|
| 350 |
+
" <tr>\n",
|
| 351 |
+
" <td>2450</td>\n",
|
| 352 |
+
" <td>1.028400</td>\n",
|
| 353 |
+
" </tr>\n",
|
| 354 |
+
" <tr>\n",
|
| 355 |
+
" <td>2500</td>\n",
|
| 356 |
+
" <td>1.008500</td>\n",
|
| 357 |
+
" </tr>\n",
|
| 358 |
+
" <tr>\n",
|
| 359 |
+
" <td>2550</td>\n",
|
| 360 |
+
" <td>0.942800</td>\n",
|
| 361 |
+
" </tr>\n",
|
| 362 |
+
" <tr>\n",
|
| 363 |
+
" <td>2600</td>\n",
|
| 364 |
+
" <td>0.849700</td>\n",
|
| 365 |
+
" </tr>\n",
|
| 366 |
+
" <tr>\n",
|
| 367 |
+
" <td>2650</td>\n",
|
| 368 |
+
" <td>0.771400</td>\n",
|
| 369 |
+
" </tr>\n",
|
| 370 |
+
" <tr>\n",
|
| 371 |
+
" <td>2700</td>\n",
|
| 372 |
+
" <td>0.794100</td>\n",
|
| 373 |
+
" </tr>\n",
|
| 374 |
+
" <tr>\n",
|
| 375 |
+
" <td>2750</td>\n",
|
| 376 |
+
" <td>0.819200</td>\n",
|
| 377 |
+
" </tr>\n",
|
| 378 |
+
" <tr>\n",
|
| 379 |
+
" <td>2800</td>\n",
|
| 380 |
+
" <td>0.937500</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <td>2850</td>\n",
|
| 384 |
+
" <td>1.064500</td>\n",
|
| 385 |
+
" </tr>\n",
|
| 386 |
+
" <tr>\n",
|
| 387 |
+
" <td>2900</td>\n",
|
| 388 |
+
" <td>1.189300</td>\n",
|
| 389 |
+
" </tr>\n",
|
| 390 |
+
" <tr>\n",
|
| 391 |
+
" <td>2950</td>\n",
|
| 392 |
+
" <td>1.071100</td>\n",
|
| 393 |
+
" </tr>\n",
|
| 394 |
+
" <tr>\n",
|
| 395 |
+
" <td>3000</td>\n",
|
| 396 |
+
" <td>1.003300</td>\n",
|
| 397 |
+
" </tr>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <td>3050</td>\n",
|
| 400 |
+
" <td>1.073900</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <td>3100</td>\n",
|
| 404 |
+
" <td>1.043100</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <td>3150</td>\n",
|
| 408 |
+
" <td>1.282600</td>\n",
|
| 409 |
+
" </tr>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <td>3200</td>\n",
|
| 412 |
+
" <td>2.145400</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <td>3250</td>\n",
|
| 416 |
+
" <td>1.925800</td>\n",
|
| 417 |
+
" </tr>\n",
|
| 418 |
+
" <tr>\n",
|
| 419 |
+
" <td>3300</td>\n",
|
| 420 |
+
" <td>2.005600</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <td>3350</td>\n",
|
| 424 |
+
" <td>2.122600</td>\n",
|
| 425 |
+
" </tr>\n",
|
| 426 |
+
" <tr>\n",
|
| 427 |
+
" <td>3400</td>\n",
|
| 428 |
+
" <td>2.163000</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <td>3450</td>\n",
|
| 432 |
+
" <td>2.046600</td>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" <tr>\n",
|
| 435 |
+
" <td>3500</td>\n",
|
| 436 |
+
" <td>2.152200</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <td>3550</td>\n",
|
| 440 |
+
" <td>2.151700</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <td>3600</td>\n",
|
| 444 |
+
" <td>5.394900</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <td>3650</td>\n",
|
| 448 |
+
" <td>4.677800</td>\n",
|
| 449 |
+
" </tr>\n",
|
| 450 |
+
" <tr>\n",
|
| 451 |
+
" <td>3700</td>\n",
|
| 452 |
+
" <td>4.122200</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" <tr>\n",
|
| 455 |
+
" <td>3750</td>\n",
|
| 456 |
+
" <td>3.710200</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <td>3800</td>\n",
|
| 460 |
+
" <td>3.350800</td>\n",
|
| 461 |
+
" </tr>\n",
|
| 462 |
+
" <tr>\n",
|
| 463 |
+
" <td>3850</td>\n",
|
| 464 |
+
" <td>3.126300</td>\n",
|
| 465 |
+
" </tr>\n",
|
| 466 |
+
" <tr>\n",
|
| 467 |
+
" <td>3900</td>\n",
|
| 468 |
+
" <td>2.988700</td>\n",
|
| 469 |
+
" </tr>\n",
|
| 470 |
+
" <tr>\n",
|
| 471 |
+
" <td>3950</td>\n",
|
| 472 |
+
" <td>2.872000</td>\n",
|
| 473 |
+
" </tr>\n",
|
| 474 |
+
" <tr>\n",
|
| 475 |
+
" <td>4000</td>\n",
|
| 476 |
+
" <td>2.848200</td>\n",
|
| 477 |
+
" </tr>\n",
|
| 478 |
+
" <tr>\n",
|
| 479 |
+
" <td>4050</td>\n",
|
| 480 |
+
" <td>2.823900</td>\n",
|
| 481 |
+
" </tr>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>4100</td>\n",
|
| 484 |
+
" <td>2.781200</td>\n",
|
| 485 |
+
" </tr>\n",
|
| 486 |
+
" <tr>\n",
|
| 487 |
+
" <td>4150</td>\n",
|
| 488 |
+
" <td>2.735000</td>\n",
|
| 489 |
+
" </tr>\n",
|
| 490 |
+
" <tr>\n",
|
| 491 |
+
" <td>4200</td>\n",
|
| 492 |
+
" <td>2.725900</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" <tr>\n",
|
| 495 |
+
" <td>4250</td>\n",
|
| 496 |
+
" <td>2.644400</td>\n",
|
| 497 |
+
" </tr>\n",
|
| 498 |
+
" <tr>\n",
|
| 499 |
+
" <td>4300</td>\n",
|
| 500 |
+
" <td>2.700000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <td>4350</td>\n",
|
| 504 |
+
" <td>2.650100</td>\n",
|
| 505 |
+
" </tr>\n",
|
| 506 |
+
" <tr>\n",
|
| 507 |
+
" <td>4400</td>\n",
|
| 508 |
+
" <td>2.704500</td>\n",
|
| 509 |
+
" </tr>\n",
|
| 510 |
+
" <tr>\n",
|
| 511 |
+
" <td>4450</td>\n",
|
| 512 |
+
" <td>2.596700</td>\n",
|
| 513 |
+
" </tr>\n",
|
| 514 |
+
" <tr>\n",
|
| 515 |
+
" <td>4500</td>\n",
|
| 516 |
+
" <td>2.510500</td>\n",
|
| 517 |
+
" </tr>\n",
|
| 518 |
+
" <tr>\n",
|
| 519 |
+
" <td>4550</td>\n",
|
| 520 |
+
" <td>2.515800</td>\n",
|
| 521 |
+
" </tr>\n",
|
| 522 |
+
" <tr>\n",
|
| 523 |
+
" <td>4600</td>\n",
|
| 524 |
+
" <td>2.498100</td>\n",
|
| 525 |
+
" </tr>\n",
|
| 526 |
+
" <tr>\n",
|
| 527 |
+
" <td>4650</td>\n",
|
| 528 |
+
" <td>2.458900</td>\n",
|
| 529 |
+
" </tr>\n",
|
| 530 |
+
" <tr>\n",
|
| 531 |
+
" <td>4700</td>\n",
|
| 532 |
+
" <td>2.449700</td>\n",
|
| 533 |
+
" </tr>\n",
|
| 534 |
+
" <tr>\n",
|
| 535 |
+
" <td>4750</td>\n",
|
| 536 |
+
" <td>2.425000</td>\n",
|
| 537 |
+
" </tr>\n",
|
| 538 |
+
" <tr>\n",
|
| 539 |
+
" <td>4800</td>\n",
|
| 540 |
+
" <td>2.362300</td>\n",
|
| 541 |
+
" </tr>\n",
|
| 542 |
+
" <tr>\n",
|
| 543 |
+
" <td>4850</td>\n",
|
| 544 |
+
" <td>2.232000</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" <tr>\n",
|
| 547 |
+
" <td>4900</td>\n",
|
| 548 |
+
" <td>2.361500</td>\n",
|
| 549 |
+
" </tr>\n",
|
| 550 |
+
" <tr>\n",
|
| 551 |
+
" <td>4950</td>\n",
|
| 552 |
+
" <td>2.302300</td>\n",
|
| 553 |
+
" </tr>\n",
|
| 554 |
+
" <tr>\n",
|
| 555 |
+
" <td>5000</td>\n",
|
| 556 |
+
" <td>2.333900</td>\n",
|
| 557 |
+
" </tr>\n",
|
| 558 |
+
" <tr>\n",
|
| 559 |
+
" <td>5050</td>\n",
|
| 560 |
+
" <td>2.367200</td>\n",
|
| 561 |
+
" </tr>\n",
|
| 562 |
+
" <tr>\n",
|
| 563 |
+
" <td>5100</td>\n",
|
| 564 |
+
" <td>2.288300</td>\n",
|
| 565 |
+
" </tr>\n",
|
| 566 |
+
" <tr>\n",
|
| 567 |
+
" <td>5150</td>\n",
|
| 568 |
+
" <td>2.426100</td>\n",
|
| 569 |
+
" </tr>\n",
|
| 570 |
+
" <tr>\n",
|
| 571 |
+
" <td>5200</td>\n",
|
| 572 |
+
" <td>2.344100</td>\n",
|
| 573 |
+
" </tr>\n",
|
| 574 |
+
" <tr>\n",
|
| 575 |
+
" <td>5250</td>\n",
|
| 576 |
+
" <td>2.283500</td>\n",
|
| 577 |
+
" </tr>\n",
|
| 578 |
+
" <tr>\n",
|
| 579 |
+
" <td>5300</td>\n",
|
| 580 |
+
" <td>2.296500</td>\n",
|
| 581 |
+
" </tr>\n",
|
| 582 |
+
" </tbody>\n",
|
| 583 |
+
"</table><p>"
|
| 584 |
+
],
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<IPython.core.display.HTML object>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"output_type": "display_data"
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"name": "stderr",
|
| 594 |
+
"output_type": "stream",
|
| 595 |
+
"text": [
|
| 596 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"text/plain": [
|
| 602 |
+
"CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3/commit/b9472b66316be8654c6f7c173fa4561889bd3446', commit_message='End of training', commit_description='', oid='b9472b66316be8654c6f7c173fa4561889bd3446', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new3'), pr_revision=None, pr_num=None)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
"execution_count": 4,
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"output_type": "execute_result"
|
| 608 |
+
}
|
| 609 |
+
],
|
| 610 |
+
"source": [
|
| 611 |
+
"import json\n",
|
| 612 |
+
"import torch\n",
|
| 613 |
+
"import os\n",
|
| 614 |
+
"from transformers import AutoTokenizer\n",
|
| 615 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 616 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 617 |
+
"if 'train_data' not in globals():\n",
|
| 618 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 621 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 622 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 623 |
+
" else:\n",
|
| 624 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 625 |
+
" exit()\n",
|
| 626 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 627 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 628 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"from torch.utils.data import Dataset\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"class GraphDataset(Dataset):\n",
|
| 641 |
+
" def __init__(self, data):\n",
|
| 642 |
+
" self.data = data\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" def __len__(self):\n",
|
| 645 |
+
" return len(self.data)\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" def __getitem__(self, idx):\n",
|
| 648 |
+
" sample = self.data[idx]\n",
|
| 649 |
+
" return {\n",
|
| 650 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 651 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 652 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 653 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 654 |
+
" }\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 657 |
+
"import torch\n",
|
| 658 |
+
"import torch.nn as nn\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 661 |
+
" def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
|
| 662 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 663 |
+
" \n",
|
| 664 |
+
" # ✅ 载入 LLM 预训练模型\n",
|
| 665 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
|
| 668 |
+
" self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
|
| 669 |
+
"\n",
|
| 670 |
+
" # ✅ 2. 多头注意力层\n",
|
| 671 |
+
" self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
|
| 672 |
+
"\n",
|
| 673 |
+
" # ✅ 3. 线性变换\n",
|
| 674 |
+
" self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
|
| 675 |
+
"\n",
|
| 676 |
+
" # ✅ 4. 残差连接 + LayerNorm\n",
|
| 677 |
+
" self.norm = nn.LayerNorm(self.config.hidden_size)\n",
|
| 678 |
+
" \n",
|
| 679 |
+
"\n",
|
| 680 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 681 |
+
" \"\"\"\n",
|
| 682 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 683 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 684 |
+
" \"\"\"\n",
|
| 685 |
+
" # ✅ 获取 token embedding\n",
|
| 686 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 687 |
+
"\n",
|
| 688 |
+
" # ✅ 1. 线性变换 `graph_embedding`\n",
|
| 689 |
+
" graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" # ✅ 2. 多头注意力计算(自注意力机制)\n",
|
| 692 |
+
" attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
|
| 693 |
+
" \n",
|
| 694 |
+
" # ✅ 3. 线性层 + 残差连接\n",
|
| 695 |
+
" graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n",
|
| 696 |
+
"\n",
|
| 697 |
+
" # ✅ 4. 归一化\n",
|
| 698 |
+
" graph_embedding_token = self.norm(graph_embedding_token)\n",
|
| 699 |
+
"\n",
|
| 700 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 701 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 702 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 703 |
+
"\n",
|
| 704 |
+
" # ✅ 调整 attention mask\n",
|
| 705 |
+
" if attention_mask is not None:\n",
|
| 706 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 707 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 708 |
+
"\n",
|
| 709 |
+
" # ✅ 传入模型\n",
|
| 710 |
+
" outputs = self.model(\n",
|
| 711 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 712 |
+
" attention_mask=attention_mask,\n",
|
| 713 |
+
" labels=labels,\n",
|
| 714 |
+
" )\n",
|
| 715 |
+
"\n",
|
| 716 |
+
" return outputs\n",
|
| 717 |
+
"\n",
|
| 718 |
+
"from transformers import Trainer\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"class GraphTrainer(Trainer):\n",
|
| 721 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 722 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 723 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 724 |
+
" labels = inputs[\"labels\"]\n",
|
| 725 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 726 |
+
"\n",
|
| 727 |
+
" if graph_embedding is not None:\n",
|
| 728 |
+
" outputs = model(\n",
|
| 729 |
+
" input_ids=input_ids,\n",
|
| 730 |
+
" attention_mask=attention_mask,\n",
|
| 731 |
+
" labels=labels,\n",
|
| 732 |
+
" graph_embedding=graph_embedding, \n",
|
| 733 |
+
" )\n",
|
| 734 |
+
" else:\n",
|
| 735 |
+
" outputs = model(\n",
|
| 736 |
+
" input_ids=input_ids,\n",
|
| 737 |
+
" attention_mask=attention_mask,\n",
|
| 738 |
+
" labels=labels,\n",
|
| 739 |
+
" )\n",
|
| 740 |
+
"\n",
|
| 741 |
+
" loss = outputs.loss\n",
|
| 742 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 743 |
+
"\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"from transformers import AutoConfig\n",
|
| 746 |
+
"\n",
|
| 747 |
+
"# ✅ 载入微调模型\n",
|
| 748 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 749 |
+
"\n",
|
| 750 |
+
"# ✅ 训练参数\n",
|
| 751 |
+
"training_args = TrainingArguments(\n",
|
| 752 |
+
" output_dir=\"./results3\",\n",
|
| 753 |
+
" per_device_train_batch_size=7,\n",
|
| 754 |
+
" eval_strategy=\"no\",\n",
|
| 755 |
+
" save_strategy=\"steps\",\n",
|
| 756 |
+
" save_steps=3000,\n",
|
| 757 |
+
" logging_steps=50,\n",
|
| 758 |
+
" bf16=True,\n",
|
| 759 |
+
" optim=\"galore_adamw\",\n",
|
| 760 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 761 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 762 |
+
" warmup_steps=1000,\n",
|
| 763 |
+
" num_train_epochs=3,\n",
|
| 764 |
+
" push_to_hub=True,\n",
|
| 765 |
+
" hub_model_id=HF_NAME,\n",
|
| 766 |
+
" hub_strategy=\"every_save\",\n",
|
| 767 |
+
" run_name = \"experi030403\"\n",
|
| 768 |
+
")\n",
|
| 769 |
+
"\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 772 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 773 |
+
"\n",
|
| 774 |
+
"# ✅ 训练\n",
|
| 775 |
+
"trainer = GraphTrainer(\n",
|
| 776 |
+
" model=model,\n",
|
| 777 |
+
" args=training_args,\n",
|
| 778 |
+
" train_dataset=train_dataset,\n",
|
| 779 |
+
")\n",
|
| 780 |
+
"\n",
|
| 781 |
+
"trainer.train()\n",
|
| 782 |
+
"trainer.save_model(\"/workspace/model3\")\n",
|
| 783 |
+
"trainer.push_to_hub()\n",
|
| 784 |
+
"\n",
|
| 785 |
+
"\n"
|
| 786 |
+
]
|
| 787 |
+
},
|
| 788 |
+
{
|
| 789 |
+
"cell_type": "code",
|
| 790 |
+
"execution_count": 2,
|
| 791 |
+
"id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
|
| 792 |
+
"metadata": {},
|
| 793 |
+
"outputs": [
|
| 794 |
+
{
|
| 795 |
+
"data": {
|
| 796 |
+
"text/plain": [
|
| 797 |
+
"RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000')"
|
| 798 |
+
]
|
| 799 |
+
},
|
| 800 |
+
"execution_count": 2,
|
| 801 |
+
"metadata": {},
|
| 802 |
+
"output_type": "execute_result"
|
| 803 |
+
}
|
| 804 |
+
],
|
| 805 |
+
"source": [
|
| 806 |
+
"from huggingface_hub import HfApi\n",
|
| 807 |
+
"\n",
|
| 808 |
+
"api = HfApi()\n",
|
| 809 |
+
"repo_name = \"r1q1.5_graph_lora-results3\" # 你的模型名称\n",
|
| 810 |
+
"api.create_repo(repo_name, exist_ok=True)"
|
| 811 |
+
]
|
| 812 |
+
},
|
| 813 |
+
{
|
| 814 |
+
"cell_type": "code",
|
| 815 |
+
"execution_count": 3,
|
| 816 |
+
"id": "73c434b9-5d58-4819-8526-24aa18ca1010",
|
| 817 |
+
"metadata": {},
|
| 818 |
+
"outputs": [
|
| 819 |
+
{
|
| 820 |
+
"data": {
|
| 821 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 822 |
+
"model_id": "8b896f21685e4086b0b59404b2b1a866",
|
| 823 |
+
"version_major": 2,
|
| 824 |
+
"version_minor": 0
|
| 825 |
+
},
|
| 826 |
+
"text/plain": [
|
| 827 |
+
"model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
|
| 828 |
+
]
|
| 829 |
+
},
|
| 830 |
+
"metadata": {},
|
| 831 |
+
"output_type": "display_data"
|
| 832 |
+
},
|
| 833 |
+
{
|
| 834 |
+
"data": {
|
| 835 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 836 |
+
"model_id": "d20bff067ca44c4583378181da817897",
|
| 837 |
+
"version_major": 2,
|
| 838 |
+
"version_minor": 0
|
| 839 |
+
},
|
| 840 |
+
"text/plain": [
|
| 841 |
+
"scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
|
| 842 |
+
]
|
| 843 |
+
},
|
| 844 |
+
"metadata": {},
|
| 845 |
+
"output_type": "display_data"
|
| 846 |
+
},
|
| 847 |
+
{
|
| 848 |
+
"data": {
|
| 849 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 850 |
+
"model_id": "c4b7114a53b341539a3244f2eea8aacf",
|
| 851 |
+
"version_major": 2,
|
| 852 |
+
"version_minor": 0
|
| 853 |
+
},
|
| 854 |
+
"text/plain": [
|
| 855 |
+
"Upload 6 LFS files: 0%| | 0/6 [00:00<?, ?it/s]"
|
| 856 |
+
]
|
| 857 |
+
},
|
| 858 |
+
"metadata": {},
|
| 859 |
+
"output_type": "display_data"
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"data": {
|
| 863 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 864 |
+
"model_id": "74c6045017b640bdba86fe3ed1bb9c92",
|
| 865 |
+
"version_major": 2,
|
| 866 |
+
"version_minor": 0
|
| 867 |
+
},
|
| 868 |
+
"text/plain": [
|
| 869 |
+
"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
| 870 |
+
]
|
| 871 |
+
},
|
| 872 |
+
"metadata": {},
|
| 873 |
+
"output_type": "display_data"
|
| 874 |
+
},
|
| 875 |
+
{
|
| 876 |
+
"data": {
|
| 877 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 878 |
+
"model_id": "97436b084bc4420f8b273ec462c50e61",
|
| 879 |
+
"version_major": 2,
|
| 880 |
+
"version_minor": 0
|
| 881 |
+
},
|
| 882 |
+
"text/plain": [
|
| 883 |
+
"optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
|
| 884 |
+
]
|
| 885 |
+
},
|
| 886 |
+
"metadata": {},
|
| 887 |
+
"output_type": "display_data"
|
| 888 |
+
},
|
| 889 |
+
{
|
| 890 |
+
"data": {
|
| 891 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 892 |
+
"model_id": "d7f10ccff3674e6fa8bcb42553c12b19",
|
| 893 |
+
"version_major": 2,
|
| 894 |
+
"version_minor": 0
|
| 895 |
+
},
|
| 896 |
+
"text/plain": [
|
| 897 |
+
"rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
|
| 898 |
+
]
|
| 899 |
+
},
|
| 900 |
+
"metadata": {},
|
| 901 |
+
"output_type": "display_data"
|
| 902 |
+
},
|
| 903 |
+
{
|
| 904 |
+
"data": {
|
| 905 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 906 |
+
"model_id": "c5b1a010fd0845f9ba9112291afa8f17",
|
| 907 |
+
"version_major": 2,
|
| 908 |
+
"version_minor": 0
|
| 909 |
+
},
|
| 910 |
+
"text/plain": [
|
| 911 |
+
"training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
|
| 912 |
+
]
|
| 913 |
+
},
|
| 914 |
+
"metadata": {},
|
| 915 |
+
"output_type": "display_data"
|
| 916 |
+
},
|
| 917 |
+
{
|
| 918 |
+
"data": {
|
| 919 |
+
"text/plain": [
|
| 920 |
+
"CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000/commit/4088de651a0ce2cc39fcb0c950898e54ce91bdea', commit_message='upload checkpoint-3000', commit_description='', oid='4088de651a0ce2cc39fcb0c950898e54ce91bdea', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000'), pr_revision=None, pr_num=None)"
|
| 921 |
+
]
|
| 922 |
+
},
|
| 923 |
+
"execution_count": 3,
|
| 924 |
+
"metadata": {},
|
| 925 |
+
"output_type": "execute_result"
|
| 926 |
+
}
|
| 927 |
+
],
|
| 928 |
+
"source": [
|
| 929 |
+
"from huggingface_hub import upload_folder\n",
|
| 930 |
+
"\n",
|
| 931 |
+
"upload_folder(\n",
|
| 932 |
+
" folder_path = \"/workspace/results3\",\n",
|
| 933 |
+
" repo_id = \"YiFzhao/r1q1.5_graph_lora-results3\",\n",
|
| 934 |
+
" commit_message = \"upload results2\",\n",
|
| 935 |
+
")"
|
| 936 |
+
]
|
| 937 |
+
},
|
| 938 |
+
{
|
| 939 |
+
"cell_type": "code",
|
| 940 |
+
"execution_count": 5,
|
| 941 |
+
"id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
|
| 942 |
+
"metadata": {},
|
| 943 |
+
"outputs": [
|
| 944 |
+
{
|
| 945 |
+
"name": "stdout",
|
| 946 |
+
"output_type": "stream",
|
| 947 |
+
"text": [
|
| 948 |
+
"🚀 处理后数据条数: 12384\n",
|
| 949 |
+
"✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 950 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 951 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 952 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 953 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 954 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 955 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 956 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 957 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 958 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 959 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 960 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 961 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 962 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 963 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 964 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 965 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 966 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 967 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 968 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 969 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 970 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 971 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 972 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 973 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 974 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 975 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 976 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 977 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 978 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 979 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 980 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 981 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 982 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 983 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 984 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 985 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 986 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 987 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 988 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 989 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 990 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 991 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 992 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 993 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 994 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 995 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 996 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 997 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 998 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 999 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1000 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1001 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1002 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1003 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1004 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1005 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1006 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1007 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1008 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1009 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1010 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1011 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1012 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
|
| 1013 |
+
"✅ train_data 已保存到 train_data.pt\n"
|
| 1014 |
+
]
|
| 1015 |
+
}
|
| 1016 |
+
],
|
| 1017 |
+
"source": [
|
| 1018 |
+
"import json\n",
|
| 1019 |
+
"import torch\n",
|
| 1020 |
+
"from transformers import AutoTokenizer\n",
|
| 1021 |
+
"\n",
|
| 1022 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1023 |
+
"tokenizer.pad_token = tokenizer.eos_token \n",
|
| 1024 |
+
"\n",
|
| 1025 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1026 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1027 |
+
" data = json.load(f)\n",
|
| 1028 |
+
"\n",
|
| 1029 |
+
"train_data = []\n",
|
| 1030 |
+
"\n",
|
| 1031 |
+
"\n",
|
| 1032 |
+
"for sample in data:\n",
|
| 1033 |
+
" conversations = sample.get(\"conversations\", [])\n",
|
| 1034 |
+
" embeddings = sample.get(\"embedding\", []) \n",
|
| 1035 |
+
"\n",
|
| 1036 |
+
" if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
|
| 1037 |
+
" print(f\"无效的 embedding,跳过样本:{sample}\")\n",
|
| 1038 |
+
" continue\n",
|
| 1039 |
+
"\n",
|
| 1040 |
+
" graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
" #拼接所有对话\n",
|
| 1043 |
+
" dialogue_text = \"\"\n",
|
| 1044 |
+
" for conv in conversations:\n",
|
| 1045 |
+
" role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
|
| 1046 |
+
" content = conv[\"value\"]\n",
|
| 1047 |
+
" content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
|
| 1048 |
+
" role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
|
| 1049 |
+
" dialogue_text += f\"{role_token} {content}\\n\"\n",
|
| 1050 |
+
"\n",
|
| 1051 |
+
" tokenized = tokenizer(\n",
|
| 1052 |
+
" dialogue_text,\n",
|
| 1053 |
+
" padding=\"max_length\",\n",
|
| 1054 |
+
" truncation=True,\n",
|
| 1055 |
+
" max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
|
| 1056 |
+
" return_tensors=\"pt\",\n",
|
| 1057 |
+
" )\n",
|
| 1058 |
+
"\n",
|
| 1059 |
+
" input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
|
| 1060 |
+
" attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
|
| 1061 |
+
"\n",
|
| 1062 |
+
" train_data.append({\n",
|
| 1063 |
+
" \"input_ids\": input_ids,\n",
|
| 1064 |
+
" \"attention_mask\": attention_mask,\n",
|
| 1065 |
+
" \"labels\": input_ids.clone(),\n",
|
| 1066 |
+
" \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
|
| 1067 |
+
" })\n",
|
| 1068 |
+
"\n",
|
| 1069 |
+
"print(\"🚀 处理后数据条数:\", len(train_data))\n",
|
| 1070 |
+
"print(\"✅ 示例数据:\", train_data[0])\n",
|
| 1071 |
+
"torch.save(train_data, \"train_data.pt\")\n",
|
| 1072 |
+
"print(\"✅ train_data 已保存到 train_data.pt\")\n"
|
| 1073 |
+
]
|
| 1074 |
+
},
|
| 1075 |
+
{
|
| 1076 |
+
"cell_type": "code",
|
| 1077 |
+
"execution_count": 6,
|
| 1078 |
+
"id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
|
| 1079 |
+
"metadata": {},
|
| 1080 |
+
"outputs": [],
|
| 1081 |
+
"source": [
|
| 1082 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 1083 |
+
"import torch\n",
|
| 1084 |
+
"import torch.nn as nn\n",
|
| 1085 |
+
"\n",
|
| 1086 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1087 |
+
" def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
|
| 1088 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 1089 |
+
" \n",
|
| 1090 |
+
" # ✅ 载入 LLM 预训练模型\n",
|
| 1091 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 1092 |
+
"\n",
|
| 1093 |
+
" # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
|
| 1094 |
+
" self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
|
| 1095 |
+
"\n",
|
| 1096 |
+
" # ✅ 2. 多头注意力层\n",
|
| 1097 |
+
" self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
|
| 1098 |
+
"\n",
|
| 1099 |
+
" # ✅ 3. 线性变换\n",
|
| 1100 |
+
" self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
|
| 1101 |
+
"\n",
|
| 1102 |
+
" # ✅ 4. 残差连接 + LayerNorm\n",
|
| 1103 |
+
" self.norm = nn.LayerNorm(self.config.hidden_size)\n",
|
| 1104 |
+
" \n",
|
| 1105 |
+
"\n",
|
| 1106 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1107 |
+
" \"\"\"\n",
|
| 1108 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1109 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1110 |
+
" \"\"\"\n",
|
| 1111 |
+
" # ✅ 获取 token embedding\n",
|
| 1112 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1113 |
+
"\n",
|
| 1114 |
+
" # ✅ 1. 线性变换 `graph_embedding`\n",
|
| 1115 |
+
" graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n",
|
| 1116 |
+
"\n",
|
| 1117 |
+
" # ✅ 2. 多头注意力计算(自注意力机制)\n",
|
| 1118 |
+
" attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
|
| 1119 |
+
" \n",
|
| 1120 |
+
" # ✅ 3. 线性层 + 残差连接\n",
|
| 1121 |
+
" graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n",
|
| 1122 |
+
"\n",
|
| 1123 |
+
" # ✅ 4. 归一化\n",
|
| 1124 |
+
" graph_embedding_token = self.norm(graph_embedding_token)\n",
|
| 1125 |
+
"\n",
|
| 1126 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1127 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1128 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1129 |
+
"\n",
|
| 1130 |
+
" # ✅ 调整 attention mask\n",
|
| 1131 |
+
" if attention_mask is not None:\n",
|
| 1132 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1133 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1134 |
+
"\n",
|
| 1135 |
+
" # ✅ 传入模型\n",
|
| 1136 |
+
" outputs = self.model(\n",
|
| 1137 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1138 |
+
" attention_mask=attention_mask,\n",
|
| 1139 |
+
" labels=labels,\n",
|
| 1140 |
+
" )\n",
|
| 1141 |
+
"\n",
|
| 1142 |
+
" return outputs\n",
|
| 1143 |
+
"\n",
|
| 1144 |
+
" def generate(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
|
| 1145 |
+
" \"\"\"\n",
|
| 1146 |
+
" ✅ 自定义 `generate()` 方法,支持 `graph_embedding`\n",
|
| 1147 |
+
" `input_text`: 需要生成文本的输入\n",
|
| 1148 |
+
" `graph_embedding`: 形状为 (1, 512) 的张量\n",
|
| 1149 |
+
" \"\"\"\n",
|
| 1150 |
+
"\n",
|
| 1151 |
+
" # ✅ 2. 处理 `graph_embedding`\n",
|
| 1152 |
+
" graph_embedding_token = self.linear1(graph_embedding) # (1, 1, hidden_size)\n",
|
| 1153 |
+
" attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
|
| 1154 |
+
" graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (1, 1, hidden_size)\n",
|
| 1155 |
+
" graph_embedding_token = self.norm(graph_embedding_token)\n",
|
| 1156 |
+
"\n",
|
| 1157 |
+
" # ✅ 3. 获取 Token Embeddings 并拼接\n",
|
| 1158 |
+
" inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
|
| 1159 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
|
| 1160 |
+
"\n",
|
| 1161 |
+
" # ✅ 4. 调整 `attention_mask`\n",
|
| 1162 |
+
" if \"attention_mask\" in inputs:\n",
|
| 1163 |
+
" graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
|
| 1164 |
+
" attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
|
| 1165 |
+
" else:\n",
|
| 1166 |
+
" attention_mask = None\n",
|
| 1167 |
+
"\n",
|
| 1168 |
+
" # ✅ 5. 进行文本生成\n",
|
| 1169 |
+
" with torch.no_grad():\n",
|
| 1170 |
+
" output_ids = self.model.generate(\n",
|
| 1171 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1172 |
+
" attention_mask=attention_mask,\n",
|
| 1173 |
+
" max_length=max_length,\n",
|
| 1174 |
+
" temperature=temperature,\n",
|
| 1175 |
+
" top_k=top_k,\n",
|
| 1176 |
+
" top_p=top_p,\n",
|
| 1177 |
+
" num_return_sequences=1\n",
|
| 1178 |
+
" )\n",
|
| 1179 |
+
"\n",
|
| 1180 |
+
" # ✅ 6. 解码输出\n",
|
| 1181 |
+
" generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
|
| 1182 |
+
" return generated_text\n",
|
| 1183 |
+
"\n",
|
| 1184 |
+
" @classmethod\n",
|
| 1185 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 1186 |
+
" # ✅ 1. 调用 `super().from_pretrained()` 加载 LLM\n",
|
| 1187 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 1188 |
+
"\n",
|
| 1189 |
+
" # ✅ 2. 初始化 `MLP + MultiheadAttention` 结构\n",
|
| 1190 |
+
" model.linear1 = nn.Linear(512, model.config.hidden_size)\n",
|
| 1191 |
+
" model.multihead_attn = nn.MultiheadAttention(embed_dim=model.config.hidden_size, num_heads=8, batch_first=True)\n",
|
| 1192 |
+
" model.linear2 = nn.Linear(model.config.hidden_size, model.config.hidden_size)\n",
|
| 1193 |
+
" model.norm = nn.LayerNorm(model.config.hidden_size)\n",
|
| 1194 |
+
"\n",
|
| 1195 |
+
" return model"
|
| 1196 |
+
]
|
| 1197 |
+
},
|
| 1198 |
+
{
|
| 1199 |
+
"cell_type": "code",
|
| 1200 |
+
"execution_count": 2,
|
| 1201 |
+
"id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
|
| 1202 |
+
"metadata": {},
|
| 1203 |
+
"outputs": [],
|
| 1204 |
+
"source": [
|
| 1205 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
| 1206 |
+
]
|
| 1207 |
+
},
|
| 1208 |
+
{
|
| 1209 |
+
"cell_type": "code",
|
| 1210 |
+
"execution_count": 7,
|
| 1211 |
+
"id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
|
| 1212 |
+
"metadata": {},
|
| 1213 |
+
"outputs": [
|
| 1214 |
+
{
|
| 1215 |
+
"data": {
|
| 1216 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1217 |
+
"model_id": "0b50f0cd6d784f598cc64a40cff40f38",
|
| 1218 |
+
"version_major": 2,
|
| 1219 |
+
"version_minor": 0
|
| 1220 |
+
},
|
| 1221 |
+
"text/plain": [
|
| 1222 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1223 |
+
]
|
| 1224 |
+
},
|
| 1225 |
+
"metadata": {},
|
| 1226 |
+
"output_type": "display_data"
|
| 1227 |
+
},
|
| 1228 |
+
{
|
| 1229 |
+
"data": {
|
| 1230 |
+
"text/plain": [
|
| 1231 |
+
"Qwen2ForCausalLM(\n",
|
| 1232 |
+
" (model): Qwen2Model(\n",
|
| 1233 |
+
" (embed_tokens): Embedding(151936, 1536)\n",
|
| 1234 |
+
" (layers): ModuleList(\n",
|
| 1235 |
+
" (0-27): 28 x Qwen2DecoderLayer(\n",
|
| 1236 |
+
" (self_attn): Qwen2Attention(\n",
|
| 1237 |
+
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1238 |
+
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1239 |
+
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1240 |
+
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
|
| 1241 |
+
" )\n",
|
| 1242 |
+
" (mlp): Qwen2MLP(\n",
|
| 1243 |
+
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1244 |
+
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1245 |
+
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
|
| 1246 |
+
" (act_fn): SiLU()\n",
|
| 1247 |
+
" )\n",
|
| 1248 |
+
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1249 |
+
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1250 |
+
" )\n",
|
| 1251 |
+
" )\n",
|
| 1252 |
+
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1253 |
+
" (rotary_emb): Qwen2RotaryEmbedding()\n",
|
| 1254 |
+
" )\n",
|
| 1255 |
+
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
|
| 1256 |
+
" (linear1): Linear(in_features=512, out_features=1536, bias=True)\n",
|
| 1257 |
+
" (multihead_attn): MultiheadAttention(\n",
|
| 1258 |
+
" (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1259 |
+
" )\n",
|
| 1260 |
+
" (linear2): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1261 |
+
" (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n",
|
| 1262 |
+
")"
|
| 1263 |
+
]
|
| 1264 |
+
},
|
| 1265 |
+
"execution_count": 7,
|
| 1266 |
+
"metadata": {},
|
| 1267 |
+
"output_type": "execute_result"
|
| 1268 |
+
}
|
| 1269 |
+
],
|
| 1270 |
+
"source": [
|
| 1271 |
+
"import torch\n",
|
| 1272 |
+
"from transformers import AutoTokenizer\n",
|
| 1273 |
+
"\n",
|
| 1274 |
+
"# 加载 tokenizer\n",
|
| 1275 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 1276 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1277 |
+
"\n",
|
| 1278 |
+
"# 加载训练好的模型\n",
|
| 1279 |
+
"model_path = \"/workspace/model2\"\n",
|
| 1280 |
+
"model = GraphAwareLM.from_pretrained(\"/workspace/results3/checkpoint-3000\").to(device)\n",
|
| 1281 |
+
"model.eval() # 设置为推理模式\n"
|
| 1282 |
+
]
|
| 1283 |
+
},
|
| 1284 |
+
{
|
| 1285 |
+
"cell_type": "code",
|
| 1286 |
+
"execution_count": 13,
|
| 1287 |
+
"id": "51995891-8906-4049-9401-2d22e06a84e8",
|
| 1288 |
+
"metadata": {},
|
| 1289 |
+
"outputs": [
|
| 1290 |
+
{
|
| 1291 |
+
"name": "stdout",
|
| 1292 |
+
"output_type": "stream",
|
| 1293 |
+
"text": [
|
| 1294 |
+
"Parameter containing:\n",
|
| 1295 |
+
"tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
|
| 1296 |
+
" [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
|
| 1297 |
+
" [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
|
| 1298 |
+
" ...,\n",
|
| 1299 |
+
" [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
|
| 1300 |
+
" [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
|
| 1301 |
+
" [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
|
| 1302 |
+
" device='cuda:0', requires_grad=True)\n"
|
| 1303 |
+
]
|
| 1304 |
+
}
|
| 1305 |
+
],
|
| 1306 |
+
"source": [
|
| 1307 |
+
"print(model.graph_proj.weight)\n"
|
| 1308 |
+
]
|
| 1309 |
+
},
|
| 1310 |
+
{
|
| 1311 |
+
"cell_type": "code",
|
| 1312 |
+
"execution_count": 4,
|
| 1313 |
+
"id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
|
| 1314 |
+
"metadata": {},
|
| 1315 |
+
"outputs": [],
|
| 1316 |
+
"source": [
|
| 1317 |
+
"import json\n",
|
| 1318 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1319 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1320 |
+
" data = json.load(f)\n",
|
| 1321 |
+
"\n",
|
| 1322 |
+
"test_data = data[0]\n",
|
| 1323 |
+
"\n",
|
| 1324 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 1325 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 1326 |
+
"\n",
|
| 1327 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
|
| 1328 |
+
"\n",
|
| 1329 |
+
"question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 1330 |
+
"\n",
|
| 1331 |
+
"from transformers import AutoTokenizer\n",
|
| 1332 |
+
"\n",
|
| 1333 |
+
"# ✅ 输入文本\n",
|
| 1334 |
+
"ROLE_TOKENS = {\n",
|
| 1335 |
+
" \"human\": \"<|User|>\", \n",
|
| 1336 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 1337 |
+
"}\n",
|
| 1338 |
+
"GRAPH_LENGTH = 512\n",
|
| 1339 |
+
"max_seq_length = 1100 + GRAPH_LENGTH\n",
|
| 1340 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 1341 |
+
"\n",
|
| 1342 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 1343 |
+
"attention_mask = inputs[\"attention_mask\"]\n"
|
| 1344 |
+
]
|
| 1345 |
+
},
|
| 1346 |
+
{
|
| 1347 |
+
"cell_type": "code",
|
| 1348 |
+
"execution_count": 8,
|
| 1349 |
+
"id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
|
| 1350 |
+
"metadata": {},
|
| 1351 |
+
"outputs": [
|
| 1352 |
+
{
|
| 1353 |
+
"ename": "AttributeError",
|
| 1354 |
+
"evalue": "'Tensor' object has no attribute 'update'",
|
| 1355 |
+
"output_type": "error",
|
| 1356 |
+
"traceback": [
|
| 1357 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1358 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
| 1359 |
+
"Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1360 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1361 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1982\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1979\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1980\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n\u001b[0;32m-> 1982\u001b[0m generation_config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_generation_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1983\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_model_kwargs(model_kwargs\u001b[38;5;241m.\u001b[39mcopy())\n\u001b[1;32m 1984\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_assistant(assistant_model, tokenizer, assistant_tokenizer)\n",
|
| 1362 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1549\u001b[0m, in \u001b[0;36mGenerationMixin._prepare_generation_config\u001b[0;34m(self, generation_config, **kwargs)\u001b[0m\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 1548\u001b[0m generation_config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(generation_config)\n\u001b[0;32m-> 1549\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1550\u001b[0m \u001b[38;5;66;03m# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model\u001b[39;00m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m using_model_generation_config:\n",
|
| 1363 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'update'"
|
| 1364 |
+
]
|
| 1365 |
+
}
|
| 1366 |
+
],
|
| 1367 |
+
"source": [
|
| 1368 |
+
"generated_text = model.generate(inputs, graph_embedding)"
|
| 1369 |
+
]
|
| 1370 |
+
},
|
| 1371 |
+
{
|
| 1372 |
+
"cell_type": "code",
|
| 1373 |
+
"execution_count": 5,
|
| 1374 |
+
"id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
|
| 1375 |
+
"metadata": {},
|
| 1376 |
+
"outputs": [
|
| 1377 |
+
{
|
| 1378 |
+
"data": {
|
| 1379 |
+
"text/plain": [
|
| 1380 |
+
"tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1381 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1382 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1383 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1384 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1385 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1386 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1387 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1388 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1389 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1390 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1391 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1392 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1393 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1394 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1395 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1396 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1397 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1398 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1399 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1400 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1401 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1402 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1403 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1404 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1405 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1406 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1407 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1408 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1409 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1410 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1411 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1412 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1413 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1414 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1415 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1416 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1417 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1418 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1419 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1420 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1421 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1422 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1423 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1424 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1425 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1426 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1427 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1428 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1429 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1430 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1431 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1432 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1433 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1434 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1435 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1436 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1437 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1438 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1439 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1440 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1441 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1442 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1443 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
|
| 1444 |
+
" device='cuda:0')"
|
| 1445 |
+
]
|
| 1446 |
+
},
|
| 1447 |
+
"execution_count": 5,
|
| 1448 |
+
"metadata": {},
|
| 1449 |
+
"output_type": "execute_result"
|
| 1450 |
+
}
|
| 1451 |
+
],
|
| 1452 |
+
"source": [
|
| 1453 |
+
"graph_embedding"
|
| 1454 |
+
]
|
| 1455 |
+
},
|
| 1456 |
+
{
|
| 1457 |
+
"cell_type": "code",
|
| 1458 |
+
"execution_count": 15,
|
| 1459 |
+
"id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
|
| 1460 |
+
"metadata": {},
|
| 1461 |
+
"outputs": [
|
| 1462 |
+
{
|
| 1463 |
+
"name": "stdout",
|
| 1464 |
+
"output_type": "stream",
|
| 1465 |
+
"text": [
|
| 1466 |
+
"模型回复: How\n"
|
| 1467 |
+
]
|
| 1468 |
+
}
|
| 1469 |
+
],
|
| 1470 |
+
"source": [
|
| 1471 |
+
"# ✅ 进行前向传播\n",
|
| 1472 |
+
"with torch.no_grad():\n",
|
| 1473 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
|
| 1474 |
+
"\n",
|
| 1475 |
+
"# ✅ 提取 logits 并进行贪心解码\n",
|
| 1476 |
+
"logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1477 |
+
"predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
|
| 1478 |
+
"\n",
|
| 1479 |
+
"# ✅ 反向编码为文本\n",
|
| 1480 |
+
"response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
|
| 1481 |
+
"\n",
|
| 1482 |
+
"print(\"模型回复:\", response_text)"
|
| 1483 |
+
]
|
| 1484 |
+
},
|
| 1485 |
+
{
|
| 1486 |
+
"cell_type": "code",
|
| 1487 |
+
"execution_count": 9,
|
| 1488 |
+
"id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
|
| 1489 |
+
"metadata": {},
|
| 1490 |
+
"outputs": [
|
| 1491 |
+
{
|
| 1492 |
+
"name": "stdout",
|
| 1493 |
+
"output_type": "stream",
|
| 1494 |
+
"text": [
|
| 1495 |
+
"Generated Response: What are the signal definitions in the Verilog code for the calculator module, and what are their purposes? The Verilog code defines the inputs A, B, and C, and the output Y. A and B are the operands, C is the carry-in, and Y is the result. The purpose of the module is to perform a 2-bit adder, which adds two 2-bit numbers, and the output is the sum. The inputs A and B are the operands, C is the carry-in, and Y is the result. The module is designed to handle the addition operation of two 2-bit numbers, with a carry-in, and a 3-bit output. The implementation involves using logic gates to perform the addition operation, with the sum output connected to the gates. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, involving basic gates and an adder circuit. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with\n"
|
| 1496 |
+
]
|
| 1497 |
+
}
|
| 1498 |
+
],
|
| 1499 |
+
"source": [
|
| 1500 |
+
"max_new_tokens = 500\n",
|
| 1501 |
+
"generated_ids = input_ids.clone()\n",
|
| 1502 |
+
"generated_attention_mask = attention_mask.clone()\n",
|
| 1503 |
+
"for _ in range(max_new_tokens):\n",
|
| 1504 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 1505 |
+
" with torch.no_grad():\n",
|
| 1506 |
+
" outputs = model(\n",
|
| 1507 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 1508 |
+
" attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
|
| 1509 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 1510 |
+
" )\n",
|
| 1511 |
+
"\n",
|
| 1512 |
+
"\n",
|
| 1513 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1514 |
+
" next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
|
| 1515 |
+
" # print(next_token)\n",
|
| 1516 |
+
"\n",
|
| 1517 |
+
"\n",
|
| 1518 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 1519 |
+
" generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
|
| 1520 |
+
"\n",
|
| 1521 |
+
" # print(generated_ids)\n",
|
| 1522 |
+
"\n",
|
| 1523 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 1524 |
+
" break\n",
|
| 1525 |
+
"\n",
|
| 1526 |
+
" generated_attention_mask = torch.cat(\n",
|
| 1527 |
+
" [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
|
| 1528 |
+
" ) \n",
|
| 1529 |
+
"\n",
|
| 1530 |
+
"# ✅ 解码最终输出\n",
|
| 1531 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 1532 |
+
"print(\"Generated Response:\", generated_text)"
|
| 1533 |
+
]
|
| 1534 |
+
},
|
| 1535 |
+
{
|
| 1536 |
+
"cell_type": "code",
|
| 1537 |
+
"execution_count": 10,
|
| 1538 |
+
"id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
|
| 1539 |
+
"metadata": {},
|
| 1540 |
+
"outputs": [
|
| 1541 |
+
{
|
| 1542 |
+
"data": {
|
| 1543 |
+
"text/plain": [
|
| 1544 |
+
"tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
|
| 1545 |
+
" 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
|
| 1546 |
+
" 525, 862, 9895, 30]], device='cuda:0')"
|
| 1547 |
+
]
|
| 1548 |
+
},
|
| 1549 |
+
"execution_count": 10,
|
| 1550 |
+
"metadata": {},
|
| 1551 |
+
"output_type": "execute_result"
|
| 1552 |
+
}
|
| 1553 |
+
],
|
| 1554 |
+
"source": [
|
| 1555 |
+
"generated_ids"
|
| 1556 |
+
]
|
| 1557 |
+
},
|
| 1558 |
+
{
|
| 1559 |
+
"cell_type": "code",
|
| 1560 |
+
"execution_count": null,
|
| 1561 |
+
"id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
|
| 1562 |
+
"metadata": {},
|
| 1563 |
+
"outputs": [],
|
| 1564 |
+
"source": []
|
| 1565 |
+
}
|
| 1566 |
+
],
|
| 1567 |
+
"metadata": {
|
| 1568 |
+
"kernelspec": {
|
| 1569 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1570 |
+
"language": "python",
|
| 1571 |
+
"name": "python3"
|
| 1572 |
+
},
|
| 1573 |
+
"language_info": {
|
| 1574 |
+
"codemirror_mode": {
|
| 1575 |
+
"name": "ipython",
|
| 1576 |
+
"version": 3
|
| 1577 |
+
},
|
| 1578 |
+
"file_extension": ".py",
|
| 1579 |
+
"mimetype": "text/x-python",
|
| 1580 |
+
"name": "python",
|
| 1581 |
+
"nbconvert_exporter": "python",
|
| 1582 |
+
"pygments_lexer": "ipython3",
|
| 1583 |
+
"version": "3.10.12"
|
| 1584 |
+
}
|
| 1585 |
+
},
|
| 1586 |
+
"nbformat": 4,
|
| 1587 |
+
"nbformat_minor": 5
|
| 1588 |
+
}
|
eval.ipynb
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 10 |
+
"import torch\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"MODEL_NAME = \"/workspace/model\"\n",
|
| 15 |
+
"model_token = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": 2,
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"import json\n",
|
| 25 |
+
"import torch\n",
|
| 26 |
+
"from transformers import AutoTokenizer\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
|
| 29 |
+
"tokenizer.pad_token = tokenizer.eos_token "
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 3,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"json_path = \"final_Graph.json\"\n",
|
| 39 |
+
"with open(json_path, \"r\") as f:\n",
|
| 40 |
+
" data = json.load(f)\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"test_data = data[0]\n"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 4,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"source": [
|
| 51 |
+
"ROLE_TOKENS = {\n",
|
| 52 |
+
" \"human\": \"<|User|>\", \n",
|
| 53 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 54 |
+
"}\n",
|
| 55 |
+
"GRAPH_LENGTH = 512\n",
|
| 56 |
+
"max_seq_length = 1100 + GRAPH_LENGTH"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": 5,
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 66 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 67 |
+
"\n",
|
| 68 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32)"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"cell_type": "code",
|
| 73 |
+
"execution_count": 6,
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [
|
| 76 |
+
{
|
| 77 |
+
"data": {
|
| 78 |
+
"text/plain": [
|
| 79 |
+
"'What are the signal definitions in the Verilog code for the calculator module, and what are their purposes?'"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
"execution_count": 6,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"output_type": "execute_result"
|
| 85 |
+
}
|
| 86 |
+
],
|
| 87 |
+
"source": [
|
| 88 |
+
"question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 89 |
+
"question1"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": 11,
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"import json\n",
|
| 99 |
+
"import torch\n",
|
| 100 |
+
"import os\n",
|
| 101 |
+
"from transformers import AutoTokenizer\n",
|
| 102 |
+
"# tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
| 103 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 104 |
+
"from torch.utils.data import Dataset\n",
|
| 105 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 106 |
+
"import torch\n",
|
| 107 |
+
"import torch.nn as nn\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 110 |
+
" def __init__(self, config):\n",
|
| 111 |
+
" super().__init__(config)\n",
|
| 112 |
+
" self.model = AutoModelForCausalLM.from_config(config)\n",
|
| 113 |
+
" \n",
|
| 114 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 115 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 116 |
+
"\n",
|
| 117 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 118 |
+
" \"\"\"\n",
|
| 119 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 120 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 121 |
+
" \"\"\"\n",
|
| 122 |
+
" # ✅ 获取 token embedding\n",
|
| 123 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 124 |
+
"\n",
|
| 125 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 126 |
+
" graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 129 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 130 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 131 |
+
"\n",
|
| 132 |
+
" # ✅ 调整 attention mask\n",
|
| 133 |
+
" if attention_mask is not None:\n",
|
| 134 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 135 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 136 |
+
"\n",
|
| 137 |
+
" # ✅ 传入模型\n",
|
| 138 |
+
" outputs = self.model(\n",
|
| 139 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 140 |
+
" attention_mask=attention_mask,\n",
|
| 141 |
+
" labels=labels,\n",
|
| 142 |
+
" )\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" return outputs\n",
|
| 145 |
+
"\n",
|
| 146 |
+
" @classmethod\n",
|
| 147 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 148 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 149 |
+
" model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
|
| 150 |
+
" return model\n",
|
| 151 |
+
"\n"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": 12,
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [],
|
| 159 |
+
"source": [
|
| 160 |
+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 161 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME).to(device)"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": 13,
|
| 167 |
+
"metadata": {},
|
| 168 |
+
"outputs": [
|
| 169 |
+
{
|
| 170 |
+
"data": {
|
| 171 |
+
"text/plain": [
|
| 172 |
+
"tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 173 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 174 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 175 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 176 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 177 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 178 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 179 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 180 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 181 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 182 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 183 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 184 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 185 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 186 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 187 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 188 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 189 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 190 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 191 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 192 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 193 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 194 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 195 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 196 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 197 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 198 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 199 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 200 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 201 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 202 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 203 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 204 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 205 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 206 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 207 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 208 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 209 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 210 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 211 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 212 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 213 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 214 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 215 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 216 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 217 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 218 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 219 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 220 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 221 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 222 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 223 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 224 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 225 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 226 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 227 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 228 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 229 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 230 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 231 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 232 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 233 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 234 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 235 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
|
| 236 |
+
" device='cuda:0')"
|
| 237 |
+
]
|
| 238 |
+
},
|
| 239 |
+
"execution_count": 13,
|
| 240 |
+
"metadata": {},
|
| 241 |
+
"output_type": "execute_result"
|
| 242 |
+
}
|
| 243 |
+
],
|
| 244 |
+
"source": [
|
| 245 |
+
"from transformers import AutoTokenizer\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"# ✅ 加载分词器\n",
|
| 248 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_token)\n",
|
| 249 |
+
"\n",
|
| 250 |
+
"# ✅ 输入文本\n",
|
| 251 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"graph_embedding.to(device)\n",
|
| 254 |
+
"\n"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"cell_type": "code",
|
| 259 |
+
"execution_count": 14,
|
| 260 |
+
"metadata": {},
|
| 261 |
+
"outputs": [
|
| 262 |
+
{
|
| 263 |
+
"ename": "RuntimeError",
|
| 264 |
+
"evalue": "The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3",
|
| 265 |
+
"output_type": "error",
|
| 266 |
+
"traceback": [
|
| 267 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 268 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 269 |
+
"Cell \u001b[0;32mIn[14], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_new_tokens):\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# ✅ 计算 logits 并进行生成\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 6\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgenerated_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, 512)\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :] \u001b[38;5;66;03m# 取最后一个 token 的 logits\u001b[39;00m\n\u001b[1;32m 14\u001b[0m next_token \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;66;03m# 贪心解码\u001b[39;00m\n",
|
| 270 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 271 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 272 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[1;32m 170\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 273 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:856\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 853\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 855\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 856\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 857\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 858\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 859\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 860\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 861\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 862\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 863\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 864\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 865\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 866\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 867\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 870\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 871\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
|
| 274 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 275 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 276 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:579\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 567\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 568\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 569\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 576\u001b[0m position_embeddings,\n\u001b[1;32m 577\u001b[0m )\n\u001b[1;32m 578\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 580\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 584\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 585\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 586\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 587\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 588\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
|
| 277 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 278 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 279 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:260\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 257\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 260\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 271\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 273\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
|
| 280 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 281 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 282 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:192\u001b[0m, in \u001b[0;36mQwen2Attention.forward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 190\u001b[0m attention_interface \u001b[38;5;241m=\u001b[39m ALL_ATTENTION_FUNCTIONS[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_attn_implementation]\n\u001b[0;32m--> 192\u001b[0m attn_output, attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattention_interface\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mdropout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mscaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43msliding_window\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msliding_window\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# main diff with Llama\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_output\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m 205\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(attn_output)\n",
|
| 283 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:123\u001b[0m, in \u001b[0;36meager_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, scaling, dropout, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m causal_mask \u001b[38;5;241m=\u001b[39m attention_mask[:, :, :, : key_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m]]\n\u001b[0;32m--> 123\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcausal_mask\u001b[49m\n\u001b[1;32m 125\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(attn_weights, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 126\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(attn_weights, p\u001b[38;5;241m=\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39mmodule\u001b[38;5;241m.\u001b[39mtraining)\n",
|
| 284 |
+
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3"
|
| 285 |
+
]
|
| 286 |
+
}
|
| 287 |
+
],
|
| 288 |
+
"source": [
|
| 289 |
+
"\n",
|
| 290 |
+
"generated_ids = inputs[\"input_ids\"]\n",
|
| 291 |
+
"max_new_tokens = 1024\n",
|
| 292 |
+
"for _ in range(max_new_tokens):\n",
|
| 293 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 294 |
+
" with torch.no_grad():\n",
|
| 295 |
+
" outputs = model(\n",
|
| 296 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 297 |
+
" attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
|
| 298 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 299 |
+
" )\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 303 |
+
" next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"\n",
|
| 306 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 307 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 308 |
+
"\n",
|
| 309 |
+
" if next_token[:, 0] == tokenizer.eos_token_id:\n",
|
| 310 |
+
" break\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"# ✅ 解码最终输出\n",
|
| 313 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 314 |
+
"print(\"Generated Response:\", generated_text)"
|
| 315 |
+
]
|
| 316 |
+
},
|
| 317 |
+
{
|
| 318 |
+
"cell_type": "code",
|
| 319 |
+
"execution_count": null,
|
| 320 |
+
"metadata": {},
|
| 321 |
+
"outputs": [
|
| 322 |
+
{
|
| 323 |
+
"name": "stdout",
|
| 324 |
+
"output_type": "stream",
|
| 325 |
+
"text": [
|
| 326 |
+
"Generated Response: How does the code handle combinational logic? What are the signal definitions in the Verilog code for the 4-to-1 multiplexer?\n",
|
| 327 |
+
"The code uses assign statements to handle combinational logic. The first assign statement selects between the four inputs (in0, in1, in2, in3) based on the select signals (s0, s1) and assigns the result to the output (out). The second assign statement uses a ternary operator to check the value of the select signals (s0, s1) and assigns the corresponding input to the output (out). The signal definitions include in0, in1, in2, in3 as data inputs, s0 and s1 as select signals, and out as the output signal.\n",
|
| 328 |
+
"How does the code handle sequential logic? What are the signal definitions in the sequential logic part of the Verilog code?\n",
|
| 329 |
+
"The sequential logic part of the code uses an always block with a sensitivity list that includes posedge clk, indicating that it is a sequential logic block. The output (out) is updated on the rising edge of the clock signal (clk). The input (in0) is also included in the sensitivity list, but since it is not used in the logic, it might be a mistake or an unused input. The sequential logic part is the clocked flip-flop that updates the output (out) based on the current value of the input (in0) and the select signals (s0, s1).\n",
|
| 330 |
+
"What is the function of the circuit described in the Verilog code?\n",
|
| 331 |
+
"The circuit is a 4-to-1 multiplexer with a registered output. It selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop on the rising edge of the clock signal (clk). The output (out) is the value of the selected input stored in the flip-flop.\n",
|
| 332 |
+
"How can the circuit be implemented in hardware?\n",
|
| 333 |
+
"The circuit can be implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The multiplexer can be constructed using AND-OR gates or transmission gates, and the output of the multiplexer can be connected to the D input of the flip-flop. The clock signal (clk) should be connected to the clock input of the flip-flop. The select signals (s0, s1) should be connected to the control inputs of the multiplexer. The data inputs (in0, in1, in2, in3) should be connected to the respective inputs of the multiplexer. The output of the flip-flop (out) should be connected to the output of the circuit. It is important to ensure that the timing constraints for the clock signal (clk) are met to avoid setup and hold time violations. The unused input (in0) in the sensitivity list of the always block might indicate a mistake in the code, as it is not used in the logic. However, it could be a typo or an oversight in the code. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop. The unused input (in0) should be noted as a potential issue but should not affect the functionality of the circuit as described in the code. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit\n"
|
| 334 |
+
]
|
| 335 |
+
}
|
| 336 |
+
],
|
| 337 |
+
"source": [
|
| 338 |
+
"generated_ids = inputs[\"input_ids\"]\n",
|
| 339 |
+
"max_new_tokens = 1024\n",
|
| 340 |
+
"for _ in range(max_new_tokens):\n",
|
| 341 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 342 |
+
" with torch.no_grad():\n",
|
| 343 |
+
" outputs = model(\n",
|
| 344 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 345 |
+
" attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n",
|
| 346 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 347 |
+
" )\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"\n",
|
| 350 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 351 |
+
" next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n",
|
| 352 |
+
"\n",
|
| 353 |
+
"\n",
|
| 354 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 355 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 356 |
+
"\n",
|
| 357 |
+
" if next_token[:, 0] == tokenizer.eos_token_id:\n",
|
| 358 |
+
" break\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"# ✅ 解码最终输出\n",
|
| 361 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 362 |
+
"print(\"Generated Response:\", generated_text)"
|
| 363 |
+
]
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"cell_type": "code",
|
| 367 |
+
"execution_count": null,
|
| 368 |
+
"metadata": {},
|
| 369 |
+
"outputs": [],
|
| 370 |
+
"source": [
|
| 371 |
+
"import torch\n",
|
| 372 |
+
"from transformers import AutoTokenizer\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"# 加载 tokenizer\n",
|
| 375 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 376 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"# 加载训练好的模型\n",
|
| 379 |
+
"model_path = \"/workspace/model\"\n",
|
| 380 |
+
"model = GraphAwareLM.from_pretrained(model_path)\n",
|
| 381 |
+
"model.eval() # 设置为推理模式\n"
|
| 382 |
+
]
|
| 383 |
+
}
|
| 384 |
+
],
|
| 385 |
+
"metadata": {
|
| 386 |
+
"kernelspec": {
|
| 387 |
+
"display_name": "Python 3 (ipykernel)",
|
| 388 |
+
"language": "python",
|
| 389 |
+
"name": "python3"
|
| 390 |
+
},
|
| 391 |
+
"language_info": {
|
| 392 |
+
"codemirror_mode": {
|
| 393 |
+
"name": "ipython",
|
| 394 |
+
"version": 3
|
| 395 |
+
},
|
| 396 |
+
"file_extension": ".py",
|
| 397 |
+
"mimetype": "text/x-python",
|
| 398 |
+
"name": "python",
|
| 399 |
+
"nbconvert_exporter": "python",
|
| 400 |
+
"pygments_lexer": "ipython3",
|
| 401 |
+
"version": "3.10.12"
|
| 402 |
+
}
|
| 403 |
+
},
|
| 404 |
+
"nbformat": 4,
|
| 405 |
+
"nbformat_minor": 4
|
| 406 |
+
}
|
final_Graph.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5344e1faf783838cb4db7cd8bbbfdd3e4f01189277442d84682bcdaa1e4b9ac3
|
| 3 |
+
size 261383982
|
graph_train.ipynb
ADDED
|
@@ -0,0 +1,1591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"ROLE_TOKENS = {\n",
|
| 11 |
+
" \"human\": \"<|User|>\", \n",
|
| 12 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 13 |
+
"}\n",
|
| 14 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
|
| 15 |
+
"GRAPH_LENGTH = 512\n",
|
| 16 |
+
"HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new\""
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 2,
|
| 22 |
+
"id": "bba6e6db-4b79-4461-ba13-75fd41019358",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"CUDA 可用: True\n",
|
| 30 |
+
"GPU 数量: 1\n",
|
| 31 |
+
"当前 GPU: 0\n",
|
| 32 |
+
"GPU 名称: NVIDIA A100 80GB PCIe\n"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"source": [
|
| 37 |
+
"# !pip install transformers accelerate datasets\n",
|
| 38 |
+
"# !pip install galora\n",
|
| 39 |
+
"# !pip install huggingface_hub\n",
|
| 40 |
+
"import torch\n",
|
| 41 |
+
"print(\"CUDA 可用:\", torch.cuda.is_available())\n",
|
| 42 |
+
"print(\"GPU 数量:\", torch.cuda.device_count())\n",
|
| 43 |
+
"print(\"当前 GPU:\", torch.cuda.current_device())\n",
|
| 44 |
+
"print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 3,
|
| 50 |
+
"id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
+
"id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 75 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 76 |
+
" warnings.warn(\n",
|
| 77 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 78 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"text/html": [
|
| 84 |
+
"Tracking run with wandb version 0.19.7"
|
| 85 |
+
],
|
| 86 |
+
"text/plain": [
|
| 87 |
+
"<IPython.core.display.HTML object>"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"output_type": "display_data"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"data": {
|
| 95 |
+
"text/html": [
|
| 96 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_081255-v0v96nik</code>"
|
| 97 |
+
],
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"<IPython.core.display.HTML object>"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"output_type": "display_data"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"data": {
|
| 107 |
+
"text/html": [
|
| 108 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 109 |
+
],
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"<IPython.core.display.HTML object>"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"output_type": "display_data"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"data": {
|
| 119 |
+
"text/html": [
|
| 120 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 121 |
+
],
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"<IPython.core.display.HTML object>"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"data": {
|
| 131 |
+
"text/html": [
|
| 132 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik</a>"
|
| 133 |
+
],
|
| 134 |
+
"text/plain": [
|
| 135 |
+
"<IPython.core.display.HTML object>"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"output_type": "display_data"
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"data": {
|
| 143 |
+
"text/html": [
|
| 144 |
+
"\n",
|
| 145 |
+
" <div>\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 148 |
+
" [5310/5310 1:23:11, Epoch 3/3]\n",
|
| 149 |
+
" </div>\n",
|
| 150 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 151 |
+
" <thead>\n",
|
| 152 |
+
" <tr style=\"text-align: left;\">\n",
|
| 153 |
+
" <th>Step</th>\n",
|
| 154 |
+
" <th>Training Loss</th>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" </thead>\n",
|
| 157 |
+
" <tbody>\n",
|
| 158 |
+
" <tr>\n",
|
| 159 |
+
" <td>50</td>\n",
|
| 160 |
+
" <td>5.349900</td>\n",
|
| 161 |
+
" </tr>\n",
|
| 162 |
+
" <tr>\n",
|
| 163 |
+
" <td>100</td>\n",
|
| 164 |
+
" <td>5.305900</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <td>150</td>\n",
|
| 168 |
+
" <td>4.849500</td>\n",
|
| 169 |
+
" </tr>\n",
|
| 170 |
+
" <tr>\n",
|
| 171 |
+
" <td>200</td>\n",
|
| 172 |
+
" <td>3.910800</td>\n",
|
| 173 |
+
" </tr>\n",
|
| 174 |
+
" <tr>\n",
|
| 175 |
+
" <td>250</td>\n",
|
| 176 |
+
" <td>3.325600</td>\n",
|
| 177 |
+
" </tr>\n",
|
| 178 |
+
" <tr>\n",
|
| 179 |
+
" <td>300</td>\n",
|
| 180 |
+
" <td>3.144900</td>\n",
|
| 181 |
+
" </tr>\n",
|
| 182 |
+
" <tr>\n",
|
| 183 |
+
" <td>350</td>\n",
|
| 184 |
+
" <td>2.904200</td>\n",
|
| 185 |
+
" </tr>\n",
|
| 186 |
+
" <tr>\n",
|
| 187 |
+
" <td>400</td>\n",
|
| 188 |
+
" <td>2.082100</td>\n",
|
| 189 |
+
" </tr>\n",
|
| 190 |
+
" <tr>\n",
|
| 191 |
+
" <td>450</td>\n",
|
| 192 |
+
" <td>1.214300</td>\n",
|
| 193 |
+
" </tr>\n",
|
| 194 |
+
" <tr>\n",
|
| 195 |
+
" <td>500</td>\n",
|
| 196 |
+
" <td>1.011600</td>\n",
|
| 197 |
+
" </tr>\n",
|
| 198 |
+
" <tr>\n",
|
| 199 |
+
" <td>550</td>\n",
|
| 200 |
+
" <td>0.889300</td>\n",
|
| 201 |
+
" </tr>\n",
|
| 202 |
+
" <tr>\n",
|
| 203 |
+
" <td>600</td>\n",
|
| 204 |
+
" <td>0.907300</td>\n",
|
| 205 |
+
" </tr>\n",
|
| 206 |
+
" <tr>\n",
|
| 207 |
+
" <td>650</td>\n",
|
| 208 |
+
" <td>1.190400</td>\n",
|
| 209 |
+
" </tr>\n",
|
| 210 |
+
" <tr>\n",
|
| 211 |
+
" <td>700</td>\n",
|
| 212 |
+
" <td>1.889100</td>\n",
|
| 213 |
+
" </tr>\n",
|
| 214 |
+
" <tr>\n",
|
| 215 |
+
" <td>750</td>\n",
|
| 216 |
+
" <td>4.505600</td>\n",
|
| 217 |
+
" </tr>\n",
|
| 218 |
+
" <tr>\n",
|
| 219 |
+
" <td>800</td>\n",
|
| 220 |
+
" <td>6.402800</td>\n",
|
| 221 |
+
" </tr>\n",
|
| 222 |
+
" <tr>\n",
|
| 223 |
+
" <td>850</td>\n",
|
| 224 |
+
" <td>6.479300</td>\n",
|
| 225 |
+
" </tr>\n",
|
| 226 |
+
" <tr>\n",
|
| 227 |
+
" <td>900</td>\n",
|
| 228 |
+
" <td>7.337900</td>\n",
|
| 229 |
+
" </tr>\n",
|
| 230 |
+
" <tr>\n",
|
| 231 |
+
" <td>950</td>\n",
|
| 232 |
+
" <td>8.937600</td>\n",
|
| 233 |
+
" </tr>\n",
|
| 234 |
+
" <tr>\n",
|
| 235 |
+
" <td>1000</td>\n",
|
| 236 |
+
" <td>8.938700</td>\n",
|
| 237 |
+
" </tr>\n",
|
| 238 |
+
" <tr>\n",
|
| 239 |
+
" <td>1050</td>\n",
|
| 240 |
+
" <td>8.860100</td>\n",
|
| 241 |
+
" </tr>\n",
|
| 242 |
+
" <tr>\n",
|
| 243 |
+
" <td>1100</td>\n",
|
| 244 |
+
" <td>8.693600</td>\n",
|
| 245 |
+
" </tr>\n",
|
| 246 |
+
" <tr>\n",
|
| 247 |
+
" <td>1150</td>\n",
|
| 248 |
+
" <td>9.234000</td>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" <tr>\n",
|
| 251 |
+
" <td>1200</td>\n",
|
| 252 |
+
" <td>9.347500</td>\n",
|
| 253 |
+
" </tr>\n",
|
| 254 |
+
" <tr>\n",
|
| 255 |
+
" <td>1250</td>\n",
|
| 256 |
+
" <td>8.010300</td>\n",
|
| 257 |
+
" </tr>\n",
|
| 258 |
+
" <tr>\n",
|
| 259 |
+
" <td>1300</td>\n",
|
| 260 |
+
" <td>5.952900</td>\n",
|
| 261 |
+
" </tr>\n",
|
| 262 |
+
" <tr>\n",
|
| 263 |
+
" <td>1350</td>\n",
|
| 264 |
+
" <td>5.205900</td>\n",
|
| 265 |
+
" </tr>\n",
|
| 266 |
+
" <tr>\n",
|
| 267 |
+
" <td>1400</td>\n",
|
| 268 |
+
" <td>4.969600</td>\n",
|
| 269 |
+
" </tr>\n",
|
| 270 |
+
" <tr>\n",
|
| 271 |
+
" <td>1450</td>\n",
|
| 272 |
+
" <td>4.884600</td>\n",
|
| 273 |
+
" </tr>\n",
|
| 274 |
+
" <tr>\n",
|
| 275 |
+
" <td>1500</td>\n",
|
| 276 |
+
" <td>4.934200</td>\n",
|
| 277 |
+
" </tr>\n",
|
| 278 |
+
" <tr>\n",
|
| 279 |
+
" <td>1550</td>\n",
|
| 280 |
+
" <td>5.156900</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <td>1600</td>\n",
|
| 284 |
+
" <td>5.115500</td>\n",
|
| 285 |
+
" </tr>\n",
|
| 286 |
+
" <tr>\n",
|
| 287 |
+
" <td>1650</td>\n",
|
| 288 |
+
" <td>5.373600</td>\n",
|
| 289 |
+
" </tr>\n",
|
| 290 |
+
" <tr>\n",
|
| 291 |
+
" <td>1700</td>\n",
|
| 292 |
+
" <td>4.481800</td>\n",
|
| 293 |
+
" </tr>\n",
|
| 294 |
+
" <tr>\n",
|
| 295 |
+
" <td>1750</td>\n",
|
| 296 |
+
" <td>3.957000</td>\n",
|
| 297 |
+
" </tr>\n",
|
| 298 |
+
" <tr>\n",
|
| 299 |
+
" <td>1800</td>\n",
|
| 300 |
+
" <td>3.092500</td>\n",
|
| 301 |
+
" </tr>\n",
|
| 302 |
+
" <tr>\n",
|
| 303 |
+
" <td>1850</td>\n",
|
| 304 |
+
" <td>1.791000</td>\n",
|
| 305 |
+
" </tr>\n",
|
| 306 |
+
" <tr>\n",
|
| 307 |
+
" <td>1900</td>\n",
|
| 308 |
+
" <td>1.934400</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <td>1950</td>\n",
|
| 312 |
+
" <td>2.176800</td>\n",
|
| 313 |
+
" </tr>\n",
|
| 314 |
+
" <tr>\n",
|
| 315 |
+
" <td>2000</td>\n",
|
| 316 |
+
" <td>2.112400</td>\n",
|
| 317 |
+
" </tr>\n",
|
| 318 |
+
" <tr>\n",
|
| 319 |
+
" <td>2050</td>\n",
|
| 320 |
+
" <td>2.127900</td>\n",
|
| 321 |
+
" </tr>\n",
|
| 322 |
+
" <tr>\n",
|
| 323 |
+
" <td>2100</td>\n",
|
| 324 |
+
" <td>2.390200</td>\n",
|
| 325 |
+
" </tr>\n",
|
| 326 |
+
" <tr>\n",
|
| 327 |
+
" <td>2150</td>\n",
|
| 328 |
+
" <td>3.091400</td>\n",
|
| 329 |
+
" </tr>\n",
|
| 330 |
+
" <tr>\n",
|
| 331 |
+
" <td>2200</td>\n",
|
| 332 |
+
" <td>3.959500</td>\n",
|
| 333 |
+
" </tr>\n",
|
| 334 |
+
" <tr>\n",
|
| 335 |
+
" <td>2250</td>\n",
|
| 336 |
+
" <td>3.905000</td>\n",
|
| 337 |
+
" </tr>\n",
|
| 338 |
+
" <tr>\n",
|
| 339 |
+
" <td>2300</td>\n",
|
| 340 |
+
" <td>3.777500</td>\n",
|
| 341 |
+
" </tr>\n",
|
| 342 |
+
" <tr>\n",
|
| 343 |
+
" <td>2350</td>\n",
|
| 344 |
+
" <td>3.282900</td>\n",
|
| 345 |
+
" </tr>\n",
|
| 346 |
+
" <tr>\n",
|
| 347 |
+
" <td>2400</td>\n",
|
| 348 |
+
" <td>2.630300</td>\n",
|
| 349 |
+
" </tr>\n",
|
| 350 |
+
" <tr>\n",
|
| 351 |
+
" <td>2450</td>\n",
|
| 352 |
+
" <td>3.705000</td>\n",
|
| 353 |
+
" </tr>\n",
|
| 354 |
+
" <tr>\n",
|
| 355 |
+
" <td>2500</td>\n",
|
| 356 |
+
" <td>4.266300</td>\n",
|
| 357 |
+
" </tr>\n",
|
| 358 |
+
" <tr>\n",
|
| 359 |
+
" <td>2550</td>\n",
|
| 360 |
+
" <td>4.285300</td>\n",
|
| 361 |
+
" </tr>\n",
|
| 362 |
+
" <tr>\n",
|
| 363 |
+
" <td>2600</td>\n",
|
| 364 |
+
" <td>4.634000</td>\n",
|
| 365 |
+
" </tr>\n",
|
| 366 |
+
" <tr>\n",
|
| 367 |
+
" <td>2650</td>\n",
|
| 368 |
+
" <td>4.474700</td>\n",
|
| 369 |
+
" </tr>\n",
|
| 370 |
+
" <tr>\n",
|
| 371 |
+
" <td>2700</td>\n",
|
| 372 |
+
" <td>3.591300</td>\n",
|
| 373 |
+
" </tr>\n",
|
| 374 |
+
" <tr>\n",
|
| 375 |
+
" <td>2750</td>\n",
|
| 376 |
+
" <td>2.486800</td>\n",
|
| 377 |
+
" </tr>\n",
|
| 378 |
+
" <tr>\n",
|
| 379 |
+
" <td>2800</td>\n",
|
| 380 |
+
" <td>1.911800</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <td>2850</td>\n",
|
| 384 |
+
" <td>2.088100</td>\n",
|
| 385 |
+
" </tr>\n",
|
| 386 |
+
" <tr>\n",
|
| 387 |
+
" <td>2900</td>\n",
|
| 388 |
+
" <td>2.015400</td>\n",
|
| 389 |
+
" </tr>\n",
|
| 390 |
+
" <tr>\n",
|
| 391 |
+
" <td>2950</td>\n",
|
| 392 |
+
" <td>1.988500</td>\n",
|
| 393 |
+
" </tr>\n",
|
| 394 |
+
" <tr>\n",
|
| 395 |
+
" <td>3000</td>\n",
|
| 396 |
+
" <td>1.976900</td>\n",
|
| 397 |
+
" </tr>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <td>3050</td>\n",
|
| 400 |
+
" <td>2.097700</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <td>3100</td>\n",
|
| 404 |
+
" <td>1.987400</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <td>3150</td>\n",
|
| 408 |
+
" <td>2.065000</td>\n",
|
| 409 |
+
" </tr>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <td>3200</td>\n",
|
| 412 |
+
" <td>2.112100</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <td>3250</td>\n",
|
| 416 |
+
" <td>2.075300</td>\n",
|
| 417 |
+
" </tr>\n",
|
| 418 |
+
" <tr>\n",
|
| 419 |
+
" <td>3300</td>\n",
|
| 420 |
+
" <td>1.983300</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <td>3350</td>\n",
|
| 424 |
+
" <td>2.181900</td>\n",
|
| 425 |
+
" </tr>\n",
|
| 426 |
+
" <tr>\n",
|
| 427 |
+
" <td>3400</td>\n",
|
| 428 |
+
" <td>2.446500</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <td>3450</td>\n",
|
| 432 |
+
" <td>2.434200</td>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" <tr>\n",
|
| 435 |
+
" <td>3500</td>\n",
|
| 436 |
+
" <td>2.357000</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <td>3550</td>\n",
|
| 440 |
+
" <td>2.157400</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <td>3600</td>\n",
|
| 444 |
+
" <td>1.992900</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <td>3650</td>\n",
|
| 448 |
+
" <td>2.018400</td>\n",
|
| 449 |
+
" </tr>\n",
|
| 450 |
+
" <tr>\n",
|
| 451 |
+
" <td>3700</td>\n",
|
| 452 |
+
" <td>2.010200</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" <tr>\n",
|
| 455 |
+
" <td>3750</td>\n",
|
| 456 |
+
" <td>2.009500</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <td>3800</td>\n",
|
| 460 |
+
" <td>2.034900</td>\n",
|
| 461 |
+
" </tr>\n",
|
| 462 |
+
" <tr>\n",
|
| 463 |
+
" <td>3850</td>\n",
|
| 464 |
+
" <td>2.038800</td>\n",
|
| 465 |
+
" </tr>\n",
|
| 466 |
+
" <tr>\n",
|
| 467 |
+
" <td>3900</td>\n",
|
| 468 |
+
" <td>2.007600</td>\n",
|
| 469 |
+
" </tr>\n",
|
| 470 |
+
" <tr>\n",
|
| 471 |
+
" <td>3950</td>\n",
|
| 472 |
+
" <td>1.983200</td>\n",
|
| 473 |
+
" </tr>\n",
|
| 474 |
+
" <tr>\n",
|
| 475 |
+
" <td>4000</td>\n",
|
| 476 |
+
" <td>2.005300</td>\n",
|
| 477 |
+
" </tr>\n",
|
| 478 |
+
" <tr>\n",
|
| 479 |
+
" <td>4050</td>\n",
|
| 480 |
+
" <td>2.014900</td>\n",
|
| 481 |
+
" </tr>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>4100</td>\n",
|
| 484 |
+
" <td>2.018100</td>\n",
|
| 485 |
+
" </tr>\n",
|
| 486 |
+
" <tr>\n",
|
| 487 |
+
" <td>4150</td>\n",
|
| 488 |
+
" <td>2.033900</td>\n",
|
| 489 |
+
" </tr>\n",
|
| 490 |
+
" <tr>\n",
|
| 491 |
+
" <td>4200</td>\n",
|
| 492 |
+
" <td>2.024600</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" <tr>\n",
|
| 495 |
+
" <td>4250</td>\n",
|
| 496 |
+
" <td>1.995300</td>\n",
|
| 497 |
+
" </tr>\n",
|
| 498 |
+
" <tr>\n",
|
| 499 |
+
" <td>4300</td>\n",
|
| 500 |
+
" <td>2.018000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <td>4350</td>\n",
|
| 504 |
+
" <td>1.998300</td>\n",
|
| 505 |
+
" </tr>\n",
|
| 506 |
+
" <tr>\n",
|
| 507 |
+
" <td>4400</td>\n",
|
| 508 |
+
" <td>2.032800</td>\n",
|
| 509 |
+
" </tr>\n",
|
| 510 |
+
" <tr>\n",
|
| 511 |
+
" <td>4450</td>\n",
|
| 512 |
+
" <td>1.985900</td>\n",
|
| 513 |
+
" </tr>\n",
|
| 514 |
+
" <tr>\n",
|
| 515 |
+
" <td>4500</td>\n",
|
| 516 |
+
" <td>1.967700</td>\n",
|
| 517 |
+
" </tr>\n",
|
| 518 |
+
" <tr>\n",
|
| 519 |
+
" <td>4550</td>\n",
|
| 520 |
+
" <td>1.989400</td>\n",
|
| 521 |
+
" </tr>\n",
|
| 522 |
+
" <tr>\n",
|
| 523 |
+
" <td>4600</td>\n",
|
| 524 |
+
" <td>2.004700</td>\n",
|
| 525 |
+
" </tr>\n",
|
| 526 |
+
" <tr>\n",
|
| 527 |
+
" <td>4650</td>\n",
|
| 528 |
+
" <td>2.005800</td>\n",
|
| 529 |
+
" </tr>\n",
|
| 530 |
+
" <tr>\n",
|
| 531 |
+
" <td>4700</td>\n",
|
| 532 |
+
" <td>2.014400</td>\n",
|
| 533 |
+
" </tr>\n",
|
| 534 |
+
" <tr>\n",
|
| 535 |
+
" <td>4750</td>\n",
|
| 536 |
+
" <td>2.009200</td>\n",
|
| 537 |
+
" </tr>\n",
|
| 538 |
+
" <tr>\n",
|
| 539 |
+
" <td>4800</td>\n",
|
| 540 |
+
" <td>2.002200</td>\n",
|
| 541 |
+
" </tr>\n",
|
| 542 |
+
" <tr>\n",
|
| 543 |
+
" <td>4850</td>\n",
|
| 544 |
+
" <td>1.914300</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" <tr>\n",
|
| 547 |
+
" <td>4900</td>\n",
|
| 548 |
+
" <td>2.016900</td>\n",
|
| 549 |
+
" </tr>\n",
|
| 550 |
+
" <tr>\n",
|
| 551 |
+
" <td>4950</td>\n",
|
| 552 |
+
" <td>1.972900</td>\n",
|
| 553 |
+
" </tr>\n",
|
| 554 |
+
" <tr>\n",
|
| 555 |
+
" <td>5000</td>\n",
|
| 556 |
+
" <td>2.010300</td>\n",
|
| 557 |
+
" </tr>\n",
|
| 558 |
+
" <tr>\n",
|
| 559 |
+
" <td>5050</td>\n",
|
| 560 |
+
" <td>2.046600</td>\n",
|
| 561 |
+
" </tr>\n",
|
| 562 |
+
" <tr>\n",
|
| 563 |
+
" <td>5100</td>\n",
|
| 564 |
+
" <td>1.993900</td>\n",
|
| 565 |
+
" </tr>\n",
|
| 566 |
+
" <tr>\n",
|
| 567 |
+
" <td>5150</td>\n",
|
| 568 |
+
" <td>2.084500</td>\n",
|
| 569 |
+
" </tr>\n",
|
| 570 |
+
" <tr>\n",
|
| 571 |
+
" <td>5200</td>\n",
|
| 572 |
+
" <td>2.011900</td>\n",
|
| 573 |
+
" </tr>\n",
|
| 574 |
+
" <tr>\n",
|
| 575 |
+
" <td>5250</td>\n",
|
| 576 |
+
" <td>1.996500</td>\n",
|
| 577 |
+
" </tr>\n",
|
| 578 |
+
" <tr>\n",
|
| 579 |
+
" <td>5300</td>\n",
|
| 580 |
+
" <td>1.997900</td>\n",
|
| 581 |
+
" </tr>\n",
|
| 582 |
+
" </tbody>\n",
|
| 583 |
+
"</table><p>"
|
| 584 |
+
],
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<IPython.core.display.HTML object>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"output_type": "display_data"
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"name": "stderr",
|
| 594 |
+
"output_type": "stream",
|
| 595 |
+
"text": [
|
| 596 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"text/plain": [
|
| 602 |
+
"CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new/commit/231f89403dca9aa67966e4f321e62bdb41076960', commit_message='End of training', commit_description='', oid='231f89403dca9aa67966e4f321e62bdb41076960', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new'), pr_revision=None, pr_num=None)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
"execution_count": 4,
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"output_type": "execute_result"
|
| 608 |
+
}
|
| 609 |
+
],
|
| 610 |
+
"source": [
|
| 611 |
+
"import json\n",
|
| 612 |
+
"import torch\n",
|
| 613 |
+
"import os\n",
|
| 614 |
+
"from transformers import AutoTokenizer\n",
|
| 615 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 616 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 617 |
+
"if 'train_data' not in globals():\n",
|
| 618 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 621 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 622 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 623 |
+
" else:\n",
|
| 624 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 625 |
+
" exit()\n",
|
| 626 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 627 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 628 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"from torch.utils.data import Dataset\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"class GraphDataset(Dataset):\n",
|
| 641 |
+
" def __init__(self, data):\n",
|
| 642 |
+
" self.data = data\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" def __len__(self):\n",
|
| 645 |
+
" return len(self.data)\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" def __getitem__(self, idx):\n",
|
| 648 |
+
" sample = self.data[idx]\n",
|
| 649 |
+
" return {\n",
|
| 650 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 651 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 652 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 653 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 654 |
+
" }\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 657 |
+
"import torch\n",
|
| 658 |
+
"import torch.nn as nn\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 661 |
+
" def __init__(self, config):\n",
|
| 662 |
+
" super().__init__(config)\n",
|
| 663 |
+
"\n",
|
| 664 |
+
" # self.model = AutoModelForCausalLM.from_config(config)\n",
|
| 665 |
+
" \n",
|
| 666 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 667 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 668 |
+
"\n",
|
| 669 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 670 |
+
" \"\"\"\n",
|
| 671 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 672 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 673 |
+
" \"\"\"\n",
|
| 674 |
+
" # ✅ 获取 token embedding\n",
|
| 675 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 676 |
+
"\n",
|
| 677 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 678 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 679 |
+
"\n",
|
| 680 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 681 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 682 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 683 |
+
"\n",
|
| 684 |
+
" # ✅ 调整 attention mask\n",
|
| 685 |
+
" if attention_mask is not None:\n",
|
| 686 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 687 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 688 |
+
"\n",
|
| 689 |
+
" # ✅ 传入模型\n",
|
| 690 |
+
" outputs = self.model(\n",
|
| 691 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 692 |
+
" attention_mask=attention_mask,\n",
|
| 693 |
+
" labels=labels,\n",
|
| 694 |
+
" )\n",
|
| 695 |
+
"\n",
|
| 696 |
+
" return outputs\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"from transformers import Trainer\n",
|
| 699 |
+
"\n",
|
| 700 |
+
"class GraphTrainer(Trainer):\n",
|
| 701 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 702 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 703 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 704 |
+
" labels = inputs[\"labels\"]\n",
|
| 705 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 706 |
+
"\n",
|
| 707 |
+
" if graph_embedding is not None:\n",
|
| 708 |
+
" outputs = model(\n",
|
| 709 |
+
" input_ids=input_ids,\n",
|
| 710 |
+
" attention_mask=attention_mask,\n",
|
| 711 |
+
" labels=labels,\n",
|
| 712 |
+
" graph_embedding=graph_embedding, \n",
|
| 713 |
+
" )\n",
|
| 714 |
+
" else:\n",
|
| 715 |
+
" outputs = model(\n",
|
| 716 |
+
" input_ids=input_ids,\n",
|
| 717 |
+
" attention_mask=attention_mask,\n",
|
| 718 |
+
" labels=labels,\n",
|
| 719 |
+
" )\n",
|
| 720 |
+
"\n",
|
| 721 |
+
" loss = outputs.loss\n",
|
| 722 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 723 |
+
"\n",
|
| 724 |
+
"\n",
|
| 725 |
+
"from transformers import AutoConfig\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"# 1. 加载模型的配置\n",
|
| 728 |
+
"config = AutoConfig.from_pretrained(MODEL_NAME)\n",
|
| 729 |
+
"\n",
|
| 730 |
+
"# 2. 使用配置创建 GraphAwareLM 实例\n",
|
| 731 |
+
"model = GraphAwareLM.from_config(config) \n",
|
| 732 |
+
"\n",
|
| 733 |
+
"pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 734 |
+
"model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
|
| 735 |
+
"\n",
|
| 736 |
+
"# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
|
| 737 |
+
"# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 738 |
+
"# model.config.use_sliding_window_attention = False\n",
|
| 739 |
+
"\n",
|
| 740 |
+
"# ✅ 训练参数\n",
|
| 741 |
+
"training_args = TrainingArguments(\n",
|
| 742 |
+
" output_dir=\"./results\",\n",
|
| 743 |
+
" per_device_train_batch_size=7,\n",
|
| 744 |
+
" eval_strategy=\"no\",\n",
|
| 745 |
+
" save_strategy=\"steps\",\n",
|
| 746 |
+
" save_steps=3000,\n",
|
| 747 |
+
" logging_steps=50,\n",
|
| 748 |
+
" bf16=True,\n",
|
| 749 |
+
" optim=\"galore_adamw\",\n",
|
| 750 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 751 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 752 |
+
" warmup_steps=1000,\n",
|
| 753 |
+
" num_train_epochs=3,\n",
|
| 754 |
+
" push_to_hub=True,\n",
|
| 755 |
+
" hub_model_id=HF_NAME,\n",
|
| 756 |
+
" hub_strategy=\"every_save\",\n",
|
| 757 |
+
" run_name = \"experi0304\"\n",
|
| 758 |
+
")\n",
|
| 759 |
+
"\n",
|
| 760 |
+
"\n",
|
| 761 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 762 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 763 |
+
"\n",
|
| 764 |
+
"# ✅ 训练\n",
|
| 765 |
+
"trainer = GraphTrainer(\n",
|
| 766 |
+
" model=model,\n",
|
| 767 |
+
" args=training_args,\n",
|
| 768 |
+
" train_dataset=train_dataset,\n",
|
| 769 |
+
")\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"trainer.train()\n",
|
| 772 |
+
"trainer.save_model(\"/workspace/model\")\n",
|
| 773 |
+
"trainer.push_to_hub()\n",
|
| 774 |
+
"\n",
|
| 775 |
+
"\n"
|
| 776 |
+
]
|
| 777 |
+
},
|
| 778 |
+
{
|
| 779 |
+
"cell_type": "code",
|
| 780 |
+
"execution_count": 5,
|
| 781 |
+
"id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
|
| 782 |
+
"metadata": {},
|
| 783 |
+
"outputs": [
|
| 784 |
+
{
|
| 785 |
+
"name": "stdout",
|
| 786 |
+
"output_type": "stream",
|
| 787 |
+
"text": [
|
| 788 |
+
"🚀 处理后数据条数: 12384\n",
|
| 789 |
+
"✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 790 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 791 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 792 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 793 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 794 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 795 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 796 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 797 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 798 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 799 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 800 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 801 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 802 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 803 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 804 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 805 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 806 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 807 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 808 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 809 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 810 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 811 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 812 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 813 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 814 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 815 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 816 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 817 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 818 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 819 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 820 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 821 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 822 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 823 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 824 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 825 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 826 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 827 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 828 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 829 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 830 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 831 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 832 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 833 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 834 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 835 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 836 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 837 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 838 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 839 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 840 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 841 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 842 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 843 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 844 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 845 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 846 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 847 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 848 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 849 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 850 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 851 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 852 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
|
| 853 |
+
"✅ train_data 已保存到 train_data.pt\n"
|
| 854 |
+
]
|
| 855 |
+
}
|
| 856 |
+
],
|
| 857 |
+
"source": [
|
| 858 |
+
"import json\n",
|
| 859 |
+
"import torch\n",
|
| 860 |
+
"from transformers import AutoTokenizer\n",
|
| 861 |
+
"\n",
|
| 862 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 863 |
+
"tokenizer.pad_token = tokenizer.eos_token \n",
|
| 864 |
+
"\n",
|
| 865 |
+
"json_path = \"final_Graph.json\"\n",
|
| 866 |
+
"with open(json_path, \"r\") as f:\n",
|
| 867 |
+
" data = json.load(f)\n",
|
| 868 |
+
"\n",
|
| 869 |
+
"train_data = []\n",
|
| 870 |
+
"\n",
|
| 871 |
+
"\n",
|
| 872 |
+
"for sample in data:\n",
|
| 873 |
+
" conversations = sample.get(\"conversations\", [])\n",
|
| 874 |
+
" embeddings = sample.get(\"embedding\", []) \n",
|
| 875 |
+
"\n",
|
| 876 |
+
" if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
|
| 877 |
+
" print(f\"无效的 embedding,跳过样本:{sample}\")\n",
|
| 878 |
+
" continue\n",
|
| 879 |
+
"\n",
|
| 880 |
+
" graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
|
| 881 |
+
"\n",
|
| 882 |
+
" #拼接所有对话\n",
|
| 883 |
+
" dialogue_text = \"\"\n",
|
| 884 |
+
" for conv in conversations:\n",
|
| 885 |
+
" role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
|
| 886 |
+
" content = conv[\"value\"]\n",
|
| 887 |
+
" content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
|
| 888 |
+
" role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
|
| 889 |
+
" dialogue_text += f\"{role_token} {content}\\n\"\n",
|
| 890 |
+
"\n",
|
| 891 |
+
" tokenized = tokenizer(\n",
|
| 892 |
+
" dialogue_text,\n",
|
| 893 |
+
" padding=\"max_length\",\n",
|
| 894 |
+
" truncation=True,\n",
|
| 895 |
+
" max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
|
| 896 |
+
" return_tensors=\"pt\",\n",
|
| 897 |
+
" )\n",
|
| 898 |
+
"\n",
|
| 899 |
+
" input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
|
| 900 |
+
" attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
|
| 901 |
+
"\n",
|
| 902 |
+
" train_data.append({\n",
|
| 903 |
+
" \"input_ids\": input_ids,\n",
|
| 904 |
+
" \"attention_mask\": attention_mask,\n",
|
| 905 |
+
" \"labels\": input_ids.clone(),\n",
|
| 906 |
+
" \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
|
| 907 |
+
" })\n",
|
| 908 |
+
"\n",
|
| 909 |
+
"print(\"🚀 处理后数据条数:\", len(train_data))\n",
|
| 910 |
+
"print(\"✅ 示例数据:\", train_data[0])\n",
|
| 911 |
+
"torch.save(train_data, \"train_data.pt\")\n",
|
| 912 |
+
"print(\"✅ train_data 已保存到 train_data.pt\")\n"
|
| 913 |
+
]
|
| 914 |
+
},
|
| 915 |
+
{
|
| 916 |
+
"cell_type": "code",
|
| 917 |
+
"execution_count": 6,
|
| 918 |
+
"id": "a33bffb9-2ff9-4a4d-af2c-b89b30a69f7d",
|
| 919 |
+
"metadata": {
|
| 920 |
+
"scrolled": true
|
| 921 |
+
},
|
| 922 |
+
"outputs": [
|
| 923 |
+
{
|
| 924 |
+
"name": "stdout",
|
| 925 |
+
"output_type": "stream",
|
| 926 |
+
"text": [
|
| 927 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 928 |
+
]
|
| 929 |
+
},
|
| 930 |
+
{
|
| 931 |
+
"name": "stderr",
|
| 932 |
+
"output_type": "stream",
|
| 933 |
+
"text": [
|
| 934 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 935 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:49: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 936 |
+
" warnings.warn(\n",
|
| 937 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 938 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 939 |
+
]
|
| 940 |
+
},
|
| 941 |
+
{
|
| 942 |
+
"data": {
|
| 943 |
+
"text/html": [
|
| 944 |
+
"Tracking run with wandb version 0.19.7"
|
| 945 |
+
],
|
| 946 |
+
"text/plain": [
|
| 947 |
+
"<IPython.core.display.HTML object>"
|
| 948 |
+
]
|
| 949 |
+
},
|
| 950 |
+
"metadata": {},
|
| 951 |
+
"output_type": "display_data"
|
| 952 |
+
},
|
| 953 |
+
{
|
| 954 |
+
"data": {
|
| 955 |
+
"text/html": [
|
| 956 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_074031-ofm5zhvd</code>"
|
| 957 |
+
],
|
| 958 |
+
"text/plain": [
|
| 959 |
+
"<IPython.core.display.HTML object>"
|
| 960 |
+
]
|
| 961 |
+
},
|
| 962 |
+
"metadata": {},
|
| 963 |
+
"output_type": "display_data"
|
| 964 |
+
},
|
| 965 |
+
{
|
| 966 |
+
"data": {
|
| 967 |
+
"text/html": [
|
| 968 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">experi0304</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 969 |
+
],
|
| 970 |
+
"text/plain": [
|
| 971 |
+
"<IPython.core.display.HTML object>"
|
| 972 |
+
]
|
| 973 |
+
},
|
| 974 |
+
"metadata": {},
|
| 975 |
+
"output_type": "display_data"
|
| 976 |
+
},
|
| 977 |
+
{
|
| 978 |
+
"data": {
|
| 979 |
+
"text/html": [
|
| 980 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 981 |
+
],
|
| 982 |
+
"text/plain": [
|
| 983 |
+
"<IPython.core.display.HTML object>"
|
| 984 |
+
]
|
| 985 |
+
},
|
| 986 |
+
"metadata": {},
|
| 987 |
+
"output_type": "display_data"
|
| 988 |
+
},
|
| 989 |
+
{
|
| 990 |
+
"data": {
|
| 991 |
+
"text/html": [
|
| 992 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd</a>"
|
| 993 |
+
],
|
| 994 |
+
"text/plain": [
|
| 995 |
+
"<IPython.core.display.HTML object>"
|
| 996 |
+
]
|
| 997 |
+
},
|
| 998 |
+
"metadata": {},
|
| 999 |
+
"output_type": "display_data"
|
| 1000 |
+
},
|
| 1001 |
+
{
|
| 1002 |
+
"data": {
|
| 1003 |
+
"text/html": [
|
| 1004 |
+
"\n",
|
| 1005 |
+
" <div>\n",
|
| 1006 |
+
" \n",
|
| 1007 |
+
" <progress value='89' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 1008 |
+
" [ 89/5310 01:06 < 1:06:24, 1.31 it/s, Epoch 0.05/3]\n",
|
| 1009 |
+
" </div>\n",
|
| 1010 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 1011 |
+
" <thead>\n",
|
| 1012 |
+
" <tr style=\"text-align: left;\">\n",
|
| 1013 |
+
" <th>Step</th>\n",
|
| 1014 |
+
" <th>Training Loss</th>\n",
|
| 1015 |
+
" </tr>\n",
|
| 1016 |
+
" </thead>\n",
|
| 1017 |
+
" <tbody>\n",
|
| 1018 |
+
" <tr>\n",
|
| 1019 |
+
" <td>50</td>\n",
|
| 1020 |
+
" <td>0.000000</td>\n",
|
| 1021 |
+
" </tr>\n",
|
| 1022 |
+
" </tbody>\n",
|
| 1023 |
+
"</table><p>"
|
| 1024 |
+
],
|
| 1025 |
+
"text/plain": [
|
| 1026 |
+
"<IPython.core.display.HTML object>"
|
| 1027 |
+
]
|
| 1028 |
+
},
|
| 1029 |
+
"metadata": {},
|
| 1030 |
+
"output_type": "display_data"
|
| 1031 |
+
},
|
| 1032 |
+
{
|
| 1033 |
+
"ename": "KeyboardInterrupt",
|
| 1034 |
+
"evalue": "",
|
| 1035 |
+
"output_type": "error",
|
| 1036 |
+
"traceback": [
|
| 1037 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1038 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 1039 |
+
"Cell \u001b[0;32mIn[6], line 150\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# ✅ 训练\u001b[39;00m\n\u001b[1;32m 144\u001b[0m trainer \u001b[38;5;241m=\u001b[39m GraphTrainer(\n\u001b[1;32m 145\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 146\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m 147\u001b[0m train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m 148\u001b[0m )\n\u001b[0;32m--> 150\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m trainer\u001b[38;5;241m.\u001b[39mpush_to_hub()\n\u001b[1;32m 152\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/workspace/model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 1040 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2232\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 2232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2233\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2234\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2235\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2236\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2237\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2238\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2239\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
|
| 1041 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2548\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2541\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2542\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2543\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2544\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m!=\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED\n\u001b[1;32m 2545\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m 2546\u001b[0m )\n\u001b[1;32m 2547\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2548\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2551\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2554\u001b[0m ):\n\u001b[1;32m 2555\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2556\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
| 1042 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3740\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 3737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m==\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED:\n\u001b[1;32m 3738\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscale_wrt_gas\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 3740\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3742\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach()\n",
|
| 1043 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:2325\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 2324\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 2325\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2326\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m learning_rate \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_lomo_optimizer:\n\u001b[1;32m 2327\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n",
|
| 1044 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 484\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 485\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 490\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 492\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1045 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 246\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 252\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 255\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 256\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1046 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 1047 |
+
]
|
| 1048 |
+
}
|
| 1049 |
+
],
|
| 1050 |
+
"source": [
|
| 1051 |
+
"import json\n",
|
| 1052 |
+
"import torch\n",
|
| 1053 |
+
"import os\n",
|
| 1054 |
+
"from transformers import AutoTokenizer\n",
|
| 1055 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 1056 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 1057 |
+
"if 'train_data' not in globals():\n",
|
| 1058 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 1059 |
+
" \n",
|
| 1060 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 1061 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 1062 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 1063 |
+
" else:\n",
|
| 1064 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 1065 |
+
" exit()\n",
|
| 1066 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 1067 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 1068 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 1069 |
+
"\n",
|
| 1070 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1071 |
+
"\n",
|
| 1072 |
+
"\n",
|
| 1073 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 1074 |
+
"\n",
|
| 1075 |
+
"model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 1076 |
+
"\n",
|
| 1077 |
+
"\n",
|
| 1078 |
+
"from torch.utils.data import Dataset\n",
|
| 1079 |
+
"\n",
|
| 1080 |
+
"class GraphDataset(Dataset):\n",
|
| 1081 |
+
" def __init__(self, data):\n",
|
| 1082 |
+
" self.data = data\n",
|
| 1083 |
+
"\n",
|
| 1084 |
+
" def __len__(self):\n",
|
| 1085 |
+
" return len(self.data)\n",
|
| 1086 |
+
"\n",
|
| 1087 |
+
" def __getitem__(self, idx):\n",
|
| 1088 |
+
" sample = self.data[idx]\n",
|
| 1089 |
+
" return {\n",
|
| 1090 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 1091 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 1092 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 1093 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 1094 |
+
" }\n",
|
| 1095 |
+
"\n",
|
| 1096 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 1097 |
+
"import torch\n",
|
| 1098 |
+
"import torch.nn as nn\n",
|
| 1099 |
+
"\n",
|
| 1100 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1101 |
+
" def __init__(self, config):\n",
|
| 1102 |
+
" super().__init__(config)\n",
|
| 1103 |
+
" self.model = AutoModelForCausalLM.from_pretrained(config)\n",
|
| 1104 |
+
" \n",
|
| 1105 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 1106 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 1107 |
+
"\n",
|
| 1108 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1109 |
+
" \"\"\"\n",
|
| 1110 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1111 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1112 |
+
" \"\"\"\n",
|
| 1113 |
+
" # ✅ 获取 token embedding\n",
|
| 1114 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1115 |
+
"\n",
|
| 1116 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 1117 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 1118 |
+
"\n",
|
| 1119 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1120 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1121 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1122 |
+
"\n",
|
| 1123 |
+
" # ✅ 调整 attention mask\n",
|
| 1124 |
+
" if attention_mask is not None:\n",
|
| 1125 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1126 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1127 |
+
"\n",
|
| 1128 |
+
" # ✅ 传入模型\n",
|
| 1129 |
+
" outputs = self.model(\n",
|
| 1130 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1131 |
+
" attention_mask=attention_mask,\n",
|
| 1132 |
+
" labels=labels,\n",
|
| 1133 |
+
" )\n",
|
| 1134 |
+
"\n",
|
| 1135 |
+
" return outputs\n",
|
| 1136 |
+
"\n",
|
| 1137 |
+
"from transformers import Trainer\n",
|
| 1138 |
+
"\n",
|
| 1139 |
+
"class GraphTrainer(Trainer):\n",
|
| 1140 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 1141 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 1142 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 1143 |
+
" labels = inputs[\"labels\"]\n",
|
| 1144 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 1145 |
+
"\n",
|
| 1146 |
+
" if graph_embedding is not None:\n",
|
| 1147 |
+
" outputs = model(\n",
|
| 1148 |
+
" input_ids=input_ids,\n",
|
| 1149 |
+
" attention_mask=attention_mask,\n",
|
| 1150 |
+
" labels=labels,\n",
|
| 1151 |
+
" graph_embedding=graph_embedding, \n",
|
| 1152 |
+
" )\n",
|
| 1153 |
+
" else:\n",
|
| 1154 |
+
" outputs = model(\n",
|
| 1155 |
+
" input_ids=input_ids,\n",
|
| 1156 |
+
" attention_mask=attention_mask,\n",
|
| 1157 |
+
" labels=labels,\n",
|
| 1158 |
+
" )\n",
|
| 1159 |
+
"\n",
|
| 1160 |
+
" loss = outputs.loss\n",
|
| 1161 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 1162 |
+
"\n",
|
| 1163 |
+
"\n",
|
| 1164 |
+
"\n",
|
| 1165 |
+
"# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
|
| 1166 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 1167 |
+
"# model.config.use_sliding_window_attention = False\n",
|
| 1168 |
+
"\n",
|
| 1169 |
+
"# ✅ 训练参数\n",
|
| 1170 |
+
"training_args = TrainingArguments(\n",
|
| 1171 |
+
" output_dir=\"./results\",\n",
|
| 1172 |
+
" per_device_train_batch_size=7,\n",
|
| 1173 |
+
" eval_strategy=\"no\",\n",
|
| 1174 |
+
" save_strategy=\"steps\",\n",
|
| 1175 |
+
" save_steps=3000,\n",
|
| 1176 |
+
" logging_steps=50,\n",
|
| 1177 |
+
" fp16=True,\n",
|
| 1178 |
+
" optim=\"galore_adamw\",\n",
|
| 1179 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 1180 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 1181 |
+
" warmup_steps=1000,\n",
|
| 1182 |
+
" num_train_epochs=3,\n",
|
| 1183 |
+
" push_to_hub=True,\n",
|
| 1184 |
+
" hub_model_id=HF_NAME,\n",
|
| 1185 |
+
" hub_strategy=\"every_save\",\n",
|
| 1186 |
+
" run_name = \"experi0304\"\n",
|
| 1187 |
+
")\n",
|
| 1188 |
+
"\n",
|
| 1189 |
+
"\n",
|
| 1190 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 1191 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 1192 |
+
"\n",
|
| 1193 |
+
"# ✅ 训练\n",
|
| 1194 |
+
"trainer = GraphTrainer(\n",
|
| 1195 |
+
" model=model,\n",
|
| 1196 |
+
" args=training_args,\n",
|
| 1197 |
+
" train_dataset=train_dataset,\n",
|
| 1198 |
+
")\n",
|
| 1199 |
+
"\n",
|
| 1200 |
+
"trainer.train()\n",
|
| 1201 |
+
"trainer.push_to_hub()\n",
|
| 1202 |
+
"trainer.save_model(\"/workspace/model\")\n",
|
| 1203 |
+
"\n"
|
| 1204 |
+
]
|
| 1205 |
+
},
|
| 1206 |
+
{
|
| 1207 |
+
"cell_type": "code",
|
| 1208 |
+
"execution_count": 1,
|
| 1209 |
+
"id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
|
| 1210 |
+
"metadata": {},
|
| 1211 |
+
"outputs": [],
|
| 1212 |
+
"source": [
|
| 1213 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 1214 |
+
"import torch\n",
|
| 1215 |
+
"import torch.nn as nn\n",
|
| 1216 |
+
"\n",
|
| 1217 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1218 |
+
"\n",
|
| 1219 |
+
" \n",
|
| 1220 |
+
" def __init__(self, config):\n",
|
| 1221 |
+
" super().__init__(config)\n",
|
| 1222 |
+
" self.graph_proj = nn.Linear(512, config.hidden_size)\n",
|
| 1223 |
+
"\n",
|
| 1224 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1225 |
+
" \"\"\"\n",
|
| 1226 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1227 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1228 |
+
" \"\"\"\n",
|
| 1229 |
+
" # ✅ 获取 token embedding\n",
|
| 1230 |
+
" inputs_embeds = self.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1231 |
+
"\n",
|
| 1232 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 1233 |
+
" graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n",
|
| 1234 |
+
"\n",
|
| 1235 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1236 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1237 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1238 |
+
"\n",
|
| 1239 |
+
" # ✅ 调整 attention mask\n",
|
| 1240 |
+
" if attention_mask is not None:\n",
|
| 1241 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1242 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1243 |
+
"\n",
|
| 1244 |
+
" # ✅ 传入模型\n",
|
| 1245 |
+
" outputs = self.model(\n",
|
| 1246 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1247 |
+
" attention_mask=attention_mask,\n",
|
| 1248 |
+
" labels=labels,\n",
|
| 1249 |
+
" )\n",
|
| 1250 |
+
"\n",
|
| 1251 |
+
" return outputs\n",
|
| 1252 |
+
"\n",
|
| 1253 |
+
" @classmethod\n",
|
| 1254 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 1255 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 1256 |
+
" model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
|
| 1257 |
+
" return model\n"
|
| 1258 |
+
]
|
| 1259 |
+
},
|
| 1260 |
+
{
|
| 1261 |
+
"cell_type": "code",
|
| 1262 |
+
"execution_count": 2,
|
| 1263 |
+
"id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
|
| 1264 |
+
"metadata": {},
|
| 1265 |
+
"outputs": [],
|
| 1266 |
+
"source": [
|
| 1267 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
| 1268 |
+
]
|
| 1269 |
+
},
|
| 1270 |
+
{
|
| 1271 |
+
"cell_type": "code",
|
| 1272 |
+
"execution_count": 3,
|
| 1273 |
+
"id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
|
| 1274 |
+
"metadata": {},
|
| 1275 |
+
"outputs": [
|
| 1276 |
+
{
|
| 1277 |
+
"name": "stderr",
|
| 1278 |
+
"output_type": "stream",
|
| 1279 |
+
"text": [
|
| 1280 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
|
| 1281 |
+
]
|
| 1282 |
+
},
|
| 1283 |
+
{
|
| 1284 |
+
"data": {
|
| 1285 |
+
"text/plain": [
|
| 1286 |
+
"Qwen2ForCausalLM(\n",
|
| 1287 |
+
" (model): Qwen2Model(\n",
|
| 1288 |
+
" (embed_tokens): Embedding(151936, 1536)\n",
|
| 1289 |
+
" (layers): ModuleList(\n",
|
| 1290 |
+
" (0-27): 28 x Qwen2DecoderLayer(\n",
|
| 1291 |
+
" (self_attn): Qwen2Attention(\n",
|
| 1292 |
+
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1293 |
+
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1294 |
+
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1295 |
+
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
|
| 1296 |
+
" )\n",
|
| 1297 |
+
" (mlp): Qwen2MLP(\n",
|
| 1298 |
+
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1299 |
+
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1300 |
+
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
|
| 1301 |
+
" (act_fn): SiLU()\n",
|
| 1302 |
+
" )\n",
|
| 1303 |
+
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1304 |
+
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1305 |
+
" )\n",
|
| 1306 |
+
" )\n",
|
| 1307 |
+
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1308 |
+
" (rotary_emb): Qwen2RotaryEmbedding()\n",
|
| 1309 |
+
" )\n",
|
| 1310 |
+
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
|
| 1311 |
+
" (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
|
| 1312 |
+
")"
|
| 1313 |
+
]
|
| 1314 |
+
},
|
| 1315 |
+
"execution_count": 3,
|
| 1316 |
+
"metadata": {},
|
| 1317 |
+
"output_type": "execute_result"
|
| 1318 |
+
}
|
| 1319 |
+
],
|
| 1320 |
+
"source": [
|
| 1321 |
+
"import torch\n",
|
| 1322 |
+
"from transformers import AutoTokenizer\n",
|
| 1323 |
+
"\n",
|
| 1324 |
+
"# 加载 tokenizer\n",
|
| 1325 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 1326 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1327 |
+
"\n",
|
| 1328 |
+
"# 加载训练好的模型\n",
|
| 1329 |
+
"model_path = \"/workspace/model\"\n",
|
| 1330 |
+
"model = GraphAwareLM.from_pretrained(model_path).to(device)\n",
|
| 1331 |
+
"model.eval() # 设置为推理模式\n"
|
| 1332 |
+
]
|
| 1333 |
+
},
|
| 1334 |
+
{
|
| 1335 |
+
"cell_type": "code",
|
| 1336 |
+
"execution_count": 8,
|
| 1337 |
+
"id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
|
| 1338 |
+
"metadata": {},
|
| 1339 |
+
"outputs": [],
|
| 1340 |
+
"source": [
|
| 1341 |
+
"import json\n",
|
| 1342 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1343 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1344 |
+
" data = json.load(f)\n",
|
| 1345 |
+
"\n",
|
| 1346 |
+
"test_data = data[0]\n",
|
| 1347 |
+
"\n",
|
| 1348 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 1349 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 1350 |
+
"\n",
|
| 1351 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32).to(device)\n",
|
| 1352 |
+
"\n",
|
| 1353 |
+
"question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 1354 |
+
"\n",
|
| 1355 |
+
"from transformers import AutoTokenizer\n",
|
| 1356 |
+
"\n",
|
| 1357 |
+
"# ✅ 输入文本\n",
|
| 1358 |
+
"ROLE_TOKENS = {\n",
|
| 1359 |
+
" \"human\": \"<|User|>\", \n",
|
| 1360 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 1361 |
+
"}\n",
|
| 1362 |
+
"GRAPH_LENGTH = 512\n",
|
| 1363 |
+
"max_seq_length = 1100 + GRAPH_LENGTH\n",
|
| 1364 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 1365 |
+
"\n",
|
| 1366 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 1367 |
+
"attention_mask = inputs[\"attention_mask\"]\n"
|
| 1368 |
+
]
|
| 1369 |
+
},
|
| 1370 |
+
{
|
| 1371 |
+
"cell_type": "code",
|
| 1372 |
+
"execution_count": 5,
|
| 1373 |
+
"id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
|
| 1374 |
+
"metadata": {},
|
| 1375 |
+
"outputs": [
|
| 1376 |
+
{
|
| 1377 |
+
"data": {
|
| 1378 |
+
"text/plain": [
|
| 1379 |
+
"tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1380 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1381 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1382 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1383 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1384 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1385 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1386 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1387 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1388 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1389 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1390 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1391 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1392 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1393 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1394 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1395 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1396 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1397 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1398 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1399 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1400 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1401 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1402 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1403 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1404 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1405 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1406 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1407 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1408 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1409 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1410 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1411 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1412 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1413 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1414 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1415 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1416 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1417 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1418 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1419 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1420 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1421 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1422 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1423 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1424 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1425 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1426 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1427 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1428 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1429 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1430 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1431 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1432 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1433 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1434 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1435 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1436 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1437 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1438 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1439 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1440 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1441 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1442 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n",
|
| 1443 |
+
" device='cuda:0')"
|
| 1444 |
+
]
|
| 1445 |
+
},
|
| 1446 |
+
"execution_count": 5,
|
| 1447 |
+
"metadata": {},
|
| 1448 |
+
"output_type": "execute_result"
|
| 1449 |
+
}
|
| 1450 |
+
],
|
| 1451 |
+
"source": [
|
| 1452 |
+
"graph_embedding"
|
| 1453 |
+
]
|
| 1454 |
+
},
|
| 1455 |
+
{
|
| 1456 |
+
"cell_type": "code",
|
| 1457 |
+
"execution_count": 15,
|
| 1458 |
+
"id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
|
| 1459 |
+
"metadata": {},
|
| 1460 |
+
"outputs": [
|
| 1461 |
+
{
|
| 1462 |
+
"name": "stdout",
|
| 1463 |
+
"output_type": "stream",
|
| 1464 |
+
"text": [
|
| 1465 |
+
"模型回复: How\n"
|
| 1466 |
+
]
|
| 1467 |
+
}
|
| 1468 |
+
],
|
| 1469 |
+
"source": [
|
| 1470 |
+
"# ✅ 进行前向传播\n",
|
| 1471 |
+
"with torch.no_grad():\n",
|
| 1472 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
|
| 1473 |
+
"\n",
|
| 1474 |
+
"# ✅ 提取 logits 并进行贪心解码\n",
|
| 1475 |
+
"logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1476 |
+
"predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
|
| 1477 |
+
"\n",
|
| 1478 |
+
"# ✅ 反向编码为文本\n",
|
| 1479 |
+
"response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
|
| 1480 |
+
"\n",
|
| 1481 |
+
"print(\"模型回复:\", response_text)"
|
| 1482 |
+
]
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"cell_type": "code",
|
| 1486 |
+
"execution_count": 9,
|
| 1487 |
+
"id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
|
| 1488 |
+
"metadata": {},
|
| 1489 |
+
"outputs": [
|
| 1490 |
+
{
|
| 1491 |
+
"name": "stdout",
|
| 1492 |
+
"output_type": "stream",
|
| 1493 |
+
"text": [
|
| 1494 |
+
"Generated Response: Is there any sequential logic in the module, and if so, how is it handled? What are the module's inputs and outputs?\n",
|
| 1495 |
+
"What are the module's inputs and outputs?\n",
|
| 1496 |
+
"What are the module's inputs and outputs?\n",
|
| 1497 |
+
"What are the module's inputs and outputs?\n",
|
| 1498 |
+
"What is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and module's output, and module's input, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module\n"
|
| 1499 |
+
]
|
| 1500 |
+
}
|
| 1501 |
+
],
|
| 1502 |
+
"source": [
|
| 1503 |
+
"max_new_tokens = 1024\n",
|
| 1504 |
+
"generated_ids = input_ids.clone()\n",
|
| 1505 |
+
"generated_attention_mask = attention_mask.clone()\n",
|
| 1506 |
+
"for _ in range(max_new_tokens):\n",
|
| 1507 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 1508 |
+
" with torch.no_grad():\n",
|
| 1509 |
+
" outputs = model(\n",
|
| 1510 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 1511 |
+
" attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
|
| 1512 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 1513 |
+
" )\n",
|
| 1514 |
+
"\n",
|
| 1515 |
+
"\n",
|
| 1516 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1517 |
+
" next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
|
| 1518 |
+
" # print(next_token)\n",
|
| 1519 |
+
"\n",
|
| 1520 |
+
"\n",
|
| 1521 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 1522 |
+
" generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
|
| 1523 |
+
"\n",
|
| 1524 |
+
" # print(generated_ids)\n",
|
| 1525 |
+
"\n",
|
| 1526 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 1527 |
+
" break\n",
|
| 1528 |
+
"\n",
|
| 1529 |
+
" generated_attention_mask = torch.cat(\n",
|
| 1530 |
+
" [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
|
| 1531 |
+
" ) \n",
|
| 1532 |
+
"\n",
|
| 1533 |
+
"# ✅ 解码最终输出\n",
|
| 1534 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 1535 |
+
"print(\"Generated Response:\", generated_text)"
|
| 1536 |
+
]
|
| 1537 |
+
},
|
| 1538 |
+
{
|
| 1539 |
+
"cell_type": "code",
|
| 1540 |
+
"execution_count": 10,
|
| 1541 |
+
"id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
|
| 1542 |
+
"metadata": {},
|
| 1543 |
+
"outputs": [
|
| 1544 |
+
{
|
| 1545 |
+
"data": {
|
| 1546 |
+
"text/plain": [
|
| 1547 |
+
"tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
|
| 1548 |
+
" 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
|
| 1549 |
+
" 525, 862, 9895, 30]], device='cuda:0')"
|
| 1550 |
+
]
|
| 1551 |
+
},
|
| 1552 |
+
"execution_count": 10,
|
| 1553 |
+
"metadata": {},
|
| 1554 |
+
"output_type": "execute_result"
|
| 1555 |
+
}
|
| 1556 |
+
],
|
| 1557 |
+
"source": [
|
| 1558 |
+
"generated_ids"
|
| 1559 |
+
]
|
| 1560 |
+
},
|
| 1561 |
+
{
|
| 1562 |
+
"cell_type": "code",
|
| 1563 |
+
"execution_count": null,
|
| 1564 |
+
"id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
|
| 1565 |
+
"metadata": {},
|
| 1566 |
+
"outputs": [],
|
| 1567 |
+
"source": []
|
| 1568 |
+
}
|
| 1569 |
+
],
|
| 1570 |
+
"metadata": {
|
| 1571 |
+
"kernelspec": {
|
| 1572 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1573 |
+
"language": "python",
|
| 1574 |
+
"name": "python3"
|
| 1575 |
+
},
|
| 1576 |
+
"language_info": {
|
| 1577 |
+
"codemirror_mode": {
|
| 1578 |
+
"name": "ipython",
|
| 1579 |
+
"version": 3
|
| 1580 |
+
},
|
| 1581 |
+
"file_extension": ".py",
|
| 1582 |
+
"mimetype": "text/x-python",
|
| 1583 |
+
"name": "python",
|
| 1584 |
+
"nbconvert_exporter": "python",
|
| 1585 |
+
"pygments_lexer": "ipython3",
|
| 1586 |
+
"version": "3.10.12"
|
| 1587 |
+
}
|
| 1588 |
+
},
|
| 1589 |
+
"nbformat": 4,
|
| 1590 |
+
"nbformat_minor": 5
|
| 1591 |
+
}
|
graph_train2.ipynb
ADDED
|
@@ -0,0 +1,1506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"ROLE_TOKENS = {\n",
|
| 11 |
+
" \"human\": \"<|User|>\", \n",
|
| 12 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 13 |
+
"}\n",
|
| 14 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
|
| 15 |
+
"GRAPH_LENGTH = 512\n",
|
| 16 |
+
"HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new2\""
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 3,
|
| 22 |
+
"id": "bba6e6db-4b79-4461-ba13-75fd41019358",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"CUDA 可用: True\n",
|
| 30 |
+
"GPU 数量: 1\n",
|
| 31 |
+
"当前 GPU: 0\n",
|
| 32 |
+
"GPU 名称: NVIDIA A100 80GB PCIe\n"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"source": [
|
| 37 |
+
"# !pip install transformers accelerate datasets\n",
|
| 38 |
+
"# !pip install galora\n",
|
| 39 |
+
"# !pip install huggingface_hub\n",
|
| 40 |
+
"import torch\n",
|
| 41 |
+
"print(\"CUDA 可用:\", torch.cuda.is_available())\n",
|
| 42 |
+
"print(\"GPU 数量:\", torch.cuda.device_count())\n",
|
| 43 |
+
"print(\"当前 GPU:\", torch.cuda.current_device())\n",
|
| 44 |
+
"print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 4,
|
| 50 |
+
"id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
+
"id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 75 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 76 |
+
" warnings.warn(\n",
|
| 77 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 78 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"text/html": [
|
| 84 |
+
"Tracking run with wandb version 0.19.7"
|
| 85 |
+
],
|
| 86 |
+
"text/plain": [
|
| 87 |
+
"<IPython.core.display.HTML object>"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"output_type": "display_data"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"data": {
|
| 95 |
+
"text/html": [
|
| 96 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_111730-i9v1vlu1</code>"
|
| 97 |
+
],
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"<IPython.core.display.HTML object>"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"output_type": "display_data"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"data": {
|
| 107 |
+
"text/html": [
|
| 108 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">experi030402</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 109 |
+
],
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"<IPython.core.display.HTML object>"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"output_type": "display_data"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"data": {
|
| 119 |
+
"text/html": [
|
| 120 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 121 |
+
],
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"<IPython.core.display.HTML object>"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"data": {
|
| 131 |
+
"text/html": [
|
| 132 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1</a>"
|
| 133 |
+
],
|
| 134 |
+
"text/plain": [
|
| 135 |
+
"<IPython.core.display.HTML object>"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"output_type": "display_data"
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"data": {
|
| 143 |
+
"text/html": [
|
| 144 |
+
"\n",
|
| 145 |
+
" <div>\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 148 |
+
" [5310/5310 1:34:08, Epoch 3/3]\n",
|
| 149 |
+
" </div>\n",
|
| 150 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 151 |
+
" <thead>\n",
|
| 152 |
+
" <tr style=\"text-align: left;\">\n",
|
| 153 |
+
" <th>Step</th>\n",
|
| 154 |
+
" <th>Training Loss</th>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" </thead>\n",
|
| 157 |
+
" <tbody>\n",
|
| 158 |
+
" <tr>\n",
|
| 159 |
+
" <td>50</td>\n",
|
| 160 |
+
" <td>5.319300</td>\n",
|
| 161 |
+
" </tr>\n",
|
| 162 |
+
" <tr>\n",
|
| 163 |
+
" <td>100</td>\n",
|
| 164 |
+
" <td>3.641300</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <td>150</td>\n",
|
| 168 |
+
" <td>1.521800</td>\n",
|
| 169 |
+
" </tr>\n",
|
| 170 |
+
" <tr>\n",
|
| 171 |
+
" <td>200</td>\n",
|
| 172 |
+
" <td>1.027500</td>\n",
|
| 173 |
+
" </tr>\n",
|
| 174 |
+
" <tr>\n",
|
| 175 |
+
" <td>250</td>\n",
|
| 176 |
+
" <td>0.922400</td>\n",
|
| 177 |
+
" </tr>\n",
|
| 178 |
+
" <tr>\n",
|
| 179 |
+
" <td>300</td>\n",
|
| 180 |
+
" <td>0.866900</td>\n",
|
| 181 |
+
" </tr>\n",
|
| 182 |
+
" <tr>\n",
|
| 183 |
+
" <td>350</td>\n",
|
| 184 |
+
" <td>0.800500</td>\n",
|
| 185 |
+
" </tr>\n",
|
| 186 |
+
" <tr>\n",
|
| 187 |
+
" <td>400</td>\n",
|
| 188 |
+
" <td>0.721600</td>\n",
|
| 189 |
+
" </tr>\n",
|
| 190 |
+
" <tr>\n",
|
| 191 |
+
" <td>450</td>\n",
|
| 192 |
+
" <td>0.740400</td>\n",
|
| 193 |
+
" </tr>\n",
|
| 194 |
+
" <tr>\n",
|
| 195 |
+
" <td>500</td>\n",
|
| 196 |
+
" <td>0.737000</td>\n",
|
| 197 |
+
" </tr>\n",
|
| 198 |
+
" <tr>\n",
|
| 199 |
+
" <td>550</td>\n",
|
| 200 |
+
" <td>0.713500</td>\n",
|
| 201 |
+
" </tr>\n",
|
| 202 |
+
" <tr>\n",
|
| 203 |
+
" <td>600</td>\n",
|
| 204 |
+
" <td>0.747000</td>\n",
|
| 205 |
+
" </tr>\n",
|
| 206 |
+
" <tr>\n",
|
| 207 |
+
" <td>650</td>\n",
|
| 208 |
+
" <td>0.869500</td>\n",
|
| 209 |
+
" </tr>\n",
|
| 210 |
+
" <tr>\n",
|
| 211 |
+
" <td>700</td>\n",
|
| 212 |
+
" <td>1.473300</td>\n",
|
| 213 |
+
" </tr>\n",
|
| 214 |
+
" <tr>\n",
|
| 215 |
+
" <td>750</td>\n",
|
| 216 |
+
" <td>0.753000</td>\n",
|
| 217 |
+
" </tr>\n",
|
| 218 |
+
" <tr>\n",
|
| 219 |
+
" <td>800</td>\n",
|
| 220 |
+
" <td>0.741300</td>\n",
|
| 221 |
+
" </tr>\n",
|
| 222 |
+
" <tr>\n",
|
| 223 |
+
" <td>850</td>\n",
|
| 224 |
+
" <td>0.751400</td>\n",
|
| 225 |
+
" </tr>\n",
|
| 226 |
+
" <tr>\n",
|
| 227 |
+
" <td>900</td>\n",
|
| 228 |
+
" <td>0.787600</td>\n",
|
| 229 |
+
" </tr>\n",
|
| 230 |
+
" <tr>\n",
|
| 231 |
+
" <td>950</td>\n",
|
| 232 |
+
" <td>0.783200</td>\n",
|
| 233 |
+
" </tr>\n",
|
| 234 |
+
" <tr>\n",
|
| 235 |
+
" <td>1000</td>\n",
|
| 236 |
+
" <td>0.780200</td>\n",
|
| 237 |
+
" </tr>\n",
|
| 238 |
+
" <tr>\n",
|
| 239 |
+
" <td>1050</td>\n",
|
| 240 |
+
" <td>1.012900</td>\n",
|
| 241 |
+
" </tr>\n",
|
| 242 |
+
" <tr>\n",
|
| 243 |
+
" <td>1100</td>\n",
|
| 244 |
+
" <td>1.411700</td>\n",
|
| 245 |
+
" </tr>\n",
|
| 246 |
+
" <tr>\n",
|
| 247 |
+
" <td>1150</td>\n",
|
| 248 |
+
" <td>1.536400</td>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" <tr>\n",
|
| 251 |
+
" <td>1200</td>\n",
|
| 252 |
+
" <td>0.853800</td>\n",
|
| 253 |
+
" </tr>\n",
|
| 254 |
+
" <tr>\n",
|
| 255 |
+
" <td>1250</td>\n",
|
| 256 |
+
" <td>0.756500</td>\n",
|
| 257 |
+
" </tr>\n",
|
| 258 |
+
" <tr>\n",
|
| 259 |
+
" <td>1300</td>\n",
|
| 260 |
+
" <td>0.750800</td>\n",
|
| 261 |
+
" </tr>\n",
|
| 262 |
+
" <tr>\n",
|
| 263 |
+
" <td>1350</td>\n",
|
| 264 |
+
" <td>0.747400</td>\n",
|
| 265 |
+
" </tr>\n",
|
| 266 |
+
" <tr>\n",
|
| 267 |
+
" <td>1400</td>\n",
|
| 268 |
+
" <td>0.844400</td>\n",
|
| 269 |
+
" </tr>\n",
|
| 270 |
+
" <tr>\n",
|
| 271 |
+
" <td>1450</td>\n",
|
| 272 |
+
" <td>0.858400</td>\n",
|
| 273 |
+
" </tr>\n",
|
| 274 |
+
" <tr>\n",
|
| 275 |
+
" <td>1500</td>\n",
|
| 276 |
+
" <td>1.053400</td>\n",
|
| 277 |
+
" </tr>\n",
|
| 278 |
+
" <tr>\n",
|
| 279 |
+
" <td>1550</td>\n",
|
| 280 |
+
" <td>1.591600</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <td>1600</td>\n",
|
| 284 |
+
" <td>1.498900</td>\n",
|
| 285 |
+
" </tr>\n",
|
| 286 |
+
" <tr>\n",
|
| 287 |
+
" <td>1650</td>\n",
|
| 288 |
+
" <td>1.471700</td>\n",
|
| 289 |
+
" </tr>\n",
|
| 290 |
+
" <tr>\n",
|
| 291 |
+
" <td>1700</td>\n",
|
| 292 |
+
" <td>1.221100</td>\n",
|
| 293 |
+
" </tr>\n",
|
| 294 |
+
" <tr>\n",
|
| 295 |
+
" <td>1750</td>\n",
|
| 296 |
+
" <td>1.802300</td>\n",
|
| 297 |
+
" </tr>\n",
|
| 298 |
+
" <tr>\n",
|
| 299 |
+
" <td>1800</td>\n",
|
| 300 |
+
" <td>1.826000</td>\n",
|
| 301 |
+
" </tr>\n",
|
| 302 |
+
" <tr>\n",
|
| 303 |
+
" <td>1850</td>\n",
|
| 304 |
+
" <td>1.857300</td>\n",
|
| 305 |
+
" </tr>\n",
|
| 306 |
+
" <tr>\n",
|
| 307 |
+
" <td>1900</td>\n",
|
| 308 |
+
" <td>1.561800</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <td>1950</td>\n",
|
| 312 |
+
" <td>1.398800</td>\n",
|
| 313 |
+
" </tr>\n",
|
| 314 |
+
" <tr>\n",
|
| 315 |
+
" <td>2000</td>\n",
|
| 316 |
+
" <td>1.398900</td>\n",
|
| 317 |
+
" </tr>\n",
|
| 318 |
+
" <tr>\n",
|
| 319 |
+
" <td>2050</td>\n",
|
| 320 |
+
" <td>1.381600</td>\n",
|
| 321 |
+
" </tr>\n",
|
| 322 |
+
" <tr>\n",
|
| 323 |
+
" <td>2100</td>\n",
|
| 324 |
+
" <td>0.890300</td>\n",
|
| 325 |
+
" </tr>\n",
|
| 326 |
+
" <tr>\n",
|
| 327 |
+
" <td>2150</td>\n",
|
| 328 |
+
" <td>0.763700</td>\n",
|
| 329 |
+
" </tr>\n",
|
| 330 |
+
" <tr>\n",
|
| 331 |
+
" <td>2200</td>\n",
|
| 332 |
+
" <td>0.753100</td>\n",
|
| 333 |
+
" </tr>\n",
|
| 334 |
+
" <tr>\n",
|
| 335 |
+
" <td>2250</td>\n",
|
| 336 |
+
" <td>0.745500</td>\n",
|
| 337 |
+
" </tr>\n",
|
| 338 |
+
" <tr>\n",
|
| 339 |
+
" <td>2300</td>\n",
|
| 340 |
+
" <td>1.186100</td>\n",
|
| 341 |
+
" </tr>\n",
|
| 342 |
+
" <tr>\n",
|
| 343 |
+
" <td>2350</td>\n",
|
| 344 |
+
" <td>0.862000</td>\n",
|
| 345 |
+
" </tr>\n",
|
| 346 |
+
" <tr>\n",
|
| 347 |
+
" <td>2400</td>\n",
|
| 348 |
+
" <td>1.024600</td>\n",
|
| 349 |
+
" </tr>\n",
|
| 350 |
+
" <tr>\n",
|
| 351 |
+
" <td>2450</td>\n",
|
| 352 |
+
" <td>1.028400</td>\n",
|
| 353 |
+
" </tr>\n",
|
| 354 |
+
" <tr>\n",
|
| 355 |
+
" <td>2500</td>\n",
|
| 356 |
+
" <td>1.008500</td>\n",
|
| 357 |
+
" </tr>\n",
|
| 358 |
+
" <tr>\n",
|
| 359 |
+
" <td>2550</td>\n",
|
| 360 |
+
" <td>0.942800</td>\n",
|
| 361 |
+
" </tr>\n",
|
| 362 |
+
" <tr>\n",
|
| 363 |
+
" <td>2600</td>\n",
|
| 364 |
+
" <td>0.849700</td>\n",
|
| 365 |
+
" </tr>\n",
|
| 366 |
+
" <tr>\n",
|
| 367 |
+
" <td>2650</td>\n",
|
| 368 |
+
" <td>0.771400</td>\n",
|
| 369 |
+
" </tr>\n",
|
| 370 |
+
" <tr>\n",
|
| 371 |
+
" <td>2700</td>\n",
|
| 372 |
+
" <td>0.794100</td>\n",
|
| 373 |
+
" </tr>\n",
|
| 374 |
+
" <tr>\n",
|
| 375 |
+
" <td>2750</td>\n",
|
| 376 |
+
" <td>0.819200</td>\n",
|
| 377 |
+
" </tr>\n",
|
| 378 |
+
" <tr>\n",
|
| 379 |
+
" <td>2800</td>\n",
|
| 380 |
+
" <td>0.937500</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <td>2850</td>\n",
|
| 384 |
+
" <td>1.064500</td>\n",
|
| 385 |
+
" </tr>\n",
|
| 386 |
+
" <tr>\n",
|
| 387 |
+
" <td>2900</td>\n",
|
| 388 |
+
" <td>1.189300</td>\n",
|
| 389 |
+
" </tr>\n",
|
| 390 |
+
" <tr>\n",
|
| 391 |
+
" <td>2950</td>\n",
|
| 392 |
+
" <td>1.071100</td>\n",
|
| 393 |
+
" </tr>\n",
|
| 394 |
+
" <tr>\n",
|
| 395 |
+
" <td>3000</td>\n",
|
| 396 |
+
" <td>1.003300</td>\n",
|
| 397 |
+
" </tr>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <td>3050</td>\n",
|
| 400 |
+
" <td>1.073900</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <td>3100</td>\n",
|
| 404 |
+
" <td>1.043100</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <td>3150</td>\n",
|
| 408 |
+
" <td>1.282600</td>\n",
|
| 409 |
+
" </tr>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <td>3200</td>\n",
|
| 412 |
+
" <td>2.145400</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <td>3250</td>\n",
|
| 416 |
+
" <td>1.925800</td>\n",
|
| 417 |
+
" </tr>\n",
|
| 418 |
+
" <tr>\n",
|
| 419 |
+
" <td>3300</td>\n",
|
| 420 |
+
" <td>2.005600</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <td>3350</td>\n",
|
| 424 |
+
" <td>2.122600</td>\n",
|
| 425 |
+
" </tr>\n",
|
| 426 |
+
" <tr>\n",
|
| 427 |
+
" <td>3400</td>\n",
|
| 428 |
+
" <td>2.163000</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <td>3450</td>\n",
|
| 432 |
+
" <td>2.046600</td>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" <tr>\n",
|
| 435 |
+
" <td>3500</td>\n",
|
| 436 |
+
" <td>2.152200</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <td>3550</td>\n",
|
| 440 |
+
" <td>2.151700</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <td>3600</td>\n",
|
| 444 |
+
" <td>5.394900</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <td>3650</td>\n",
|
| 448 |
+
" <td>4.677800</td>\n",
|
| 449 |
+
" </tr>\n",
|
| 450 |
+
" <tr>\n",
|
| 451 |
+
" <td>3700</td>\n",
|
| 452 |
+
" <td>4.122200</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" <tr>\n",
|
| 455 |
+
" <td>3750</td>\n",
|
| 456 |
+
" <td>3.710200</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <td>3800</td>\n",
|
| 460 |
+
" <td>3.350800</td>\n",
|
| 461 |
+
" </tr>\n",
|
| 462 |
+
" <tr>\n",
|
| 463 |
+
" <td>3850</td>\n",
|
| 464 |
+
" <td>3.126300</td>\n",
|
| 465 |
+
" </tr>\n",
|
| 466 |
+
" <tr>\n",
|
| 467 |
+
" <td>3900</td>\n",
|
| 468 |
+
" <td>2.988700</td>\n",
|
| 469 |
+
" </tr>\n",
|
| 470 |
+
" <tr>\n",
|
| 471 |
+
" <td>3950</td>\n",
|
| 472 |
+
" <td>2.872000</td>\n",
|
| 473 |
+
" </tr>\n",
|
| 474 |
+
" <tr>\n",
|
| 475 |
+
" <td>4000</td>\n",
|
| 476 |
+
" <td>2.848200</td>\n",
|
| 477 |
+
" </tr>\n",
|
| 478 |
+
" <tr>\n",
|
| 479 |
+
" <td>4050</td>\n",
|
| 480 |
+
" <td>2.823900</td>\n",
|
| 481 |
+
" </tr>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>4100</td>\n",
|
| 484 |
+
" <td>2.781200</td>\n",
|
| 485 |
+
" </tr>\n",
|
| 486 |
+
" <tr>\n",
|
| 487 |
+
" <td>4150</td>\n",
|
| 488 |
+
" <td>2.735000</td>\n",
|
| 489 |
+
" </tr>\n",
|
| 490 |
+
" <tr>\n",
|
| 491 |
+
" <td>4200</td>\n",
|
| 492 |
+
" <td>2.725900</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" <tr>\n",
|
| 495 |
+
" <td>4250</td>\n",
|
| 496 |
+
" <td>2.644400</td>\n",
|
| 497 |
+
" </tr>\n",
|
| 498 |
+
" <tr>\n",
|
| 499 |
+
" <td>4300</td>\n",
|
| 500 |
+
" <td>2.700000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <td>4350</td>\n",
|
| 504 |
+
" <td>2.650100</td>\n",
|
| 505 |
+
" </tr>\n",
|
| 506 |
+
" <tr>\n",
|
| 507 |
+
" <td>4400</td>\n",
|
| 508 |
+
" <td>2.704500</td>\n",
|
| 509 |
+
" </tr>\n",
|
| 510 |
+
" <tr>\n",
|
| 511 |
+
" <td>4450</td>\n",
|
| 512 |
+
" <td>2.596700</td>\n",
|
| 513 |
+
" </tr>\n",
|
| 514 |
+
" <tr>\n",
|
| 515 |
+
" <td>4500</td>\n",
|
| 516 |
+
" <td>2.510500</td>\n",
|
| 517 |
+
" </tr>\n",
|
| 518 |
+
" <tr>\n",
|
| 519 |
+
" <td>4550</td>\n",
|
| 520 |
+
" <td>2.515800</td>\n",
|
| 521 |
+
" </tr>\n",
|
| 522 |
+
" <tr>\n",
|
| 523 |
+
" <td>4600</td>\n",
|
| 524 |
+
" <td>2.498100</td>\n",
|
| 525 |
+
" </tr>\n",
|
| 526 |
+
" <tr>\n",
|
| 527 |
+
" <td>4650</td>\n",
|
| 528 |
+
" <td>2.458900</td>\n",
|
| 529 |
+
" </tr>\n",
|
| 530 |
+
" <tr>\n",
|
| 531 |
+
" <td>4700</td>\n",
|
| 532 |
+
" <td>2.449700</td>\n",
|
| 533 |
+
" </tr>\n",
|
| 534 |
+
" <tr>\n",
|
| 535 |
+
" <td>4750</td>\n",
|
| 536 |
+
" <td>2.425000</td>\n",
|
| 537 |
+
" </tr>\n",
|
| 538 |
+
" <tr>\n",
|
| 539 |
+
" <td>4800</td>\n",
|
| 540 |
+
" <td>2.362300</td>\n",
|
| 541 |
+
" </tr>\n",
|
| 542 |
+
" <tr>\n",
|
| 543 |
+
" <td>4850</td>\n",
|
| 544 |
+
" <td>2.232000</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" <tr>\n",
|
| 547 |
+
" <td>4900</td>\n",
|
| 548 |
+
" <td>2.361500</td>\n",
|
| 549 |
+
" </tr>\n",
|
| 550 |
+
" <tr>\n",
|
| 551 |
+
" <td>4950</td>\n",
|
| 552 |
+
" <td>2.302300</td>\n",
|
| 553 |
+
" </tr>\n",
|
| 554 |
+
" <tr>\n",
|
| 555 |
+
" <td>5000</td>\n",
|
| 556 |
+
" <td>2.333900</td>\n",
|
| 557 |
+
" </tr>\n",
|
| 558 |
+
" <tr>\n",
|
| 559 |
+
" <td>5050</td>\n",
|
| 560 |
+
" <td>2.367200</td>\n",
|
| 561 |
+
" </tr>\n",
|
| 562 |
+
" <tr>\n",
|
| 563 |
+
" <td>5100</td>\n",
|
| 564 |
+
" <td>2.288300</td>\n",
|
| 565 |
+
" </tr>\n",
|
| 566 |
+
" <tr>\n",
|
| 567 |
+
" <td>5150</td>\n",
|
| 568 |
+
" <td>2.426100</td>\n",
|
| 569 |
+
" </tr>\n",
|
| 570 |
+
" <tr>\n",
|
| 571 |
+
" <td>5200</td>\n",
|
| 572 |
+
" <td>2.344100</td>\n",
|
| 573 |
+
" </tr>\n",
|
| 574 |
+
" <tr>\n",
|
| 575 |
+
" <td>5250</td>\n",
|
| 576 |
+
" <td>2.283500</td>\n",
|
| 577 |
+
" </tr>\n",
|
| 578 |
+
" <tr>\n",
|
| 579 |
+
" <td>5300</td>\n",
|
| 580 |
+
" <td>2.296500</td>\n",
|
| 581 |
+
" </tr>\n",
|
| 582 |
+
" </tbody>\n",
|
| 583 |
+
"</table><p>"
|
| 584 |
+
],
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<IPython.core.display.HTML object>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"output_type": "display_data"
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"name": "stderr",
|
| 594 |
+
"output_type": "stream",
|
| 595 |
+
"text": [
|
| 596 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"text/plain": [
|
| 602 |
+
"CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2/commit/291285a5f2155c79a0da893645d8df9bbca98f63', commit_message='End of training', commit_description='', oid='291285a5f2155c79a0da893645d8df9bbca98f63', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new2'), pr_revision=None, pr_num=None)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
"execution_count": 4,
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"output_type": "execute_result"
|
| 608 |
+
}
|
| 609 |
+
],
|
| 610 |
+
"source": [
|
| 611 |
+
"import json\n",
|
| 612 |
+
"import torch\n",
|
| 613 |
+
"import os\n",
|
| 614 |
+
"from transformers import AutoTokenizer\n",
|
| 615 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 616 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 617 |
+
"if 'train_data' not in globals():\n",
|
| 618 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 621 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 622 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 623 |
+
" else:\n",
|
| 624 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 625 |
+
" exit()\n",
|
| 626 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 627 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 628 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"from torch.utils.data import Dataset\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"class GraphDataset(Dataset):\n",
|
| 641 |
+
" def __init__(self, data):\n",
|
| 642 |
+
" self.data = data\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" def __len__(self):\n",
|
| 645 |
+
" return len(self.data)\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" def __getitem__(self, idx):\n",
|
| 648 |
+
" sample = self.data[idx]\n",
|
| 649 |
+
" return {\n",
|
| 650 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 651 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 652 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 653 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 654 |
+
" }\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 657 |
+
"import torch\n",
|
| 658 |
+
"import torch.nn as nn\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 661 |
+
" def __init__(self, pretrained_model_name_or_path):\n",
|
| 662 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 663 |
+
" \n",
|
| 664 |
+
" # ✅ 载入 `MODEL_NAME` 预训练模型\n",
|
| 665 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" \n",
|
| 668 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 669 |
+
" self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
|
| 670 |
+
"\n",
|
| 671 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 672 |
+
" \"\"\"\n",
|
| 673 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 674 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 675 |
+
" \"\"\"\n",
|
| 676 |
+
" # ✅ 获取 token embedding\n",
|
| 677 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 680 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 681 |
+
"\n",
|
| 682 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 683 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 684 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 685 |
+
"\n",
|
| 686 |
+
" # ✅ 调整 attention mask\n",
|
| 687 |
+
" if attention_mask is not None:\n",
|
| 688 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 689 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" # ✅ 传入模型\n",
|
| 692 |
+
" outputs = self.model(\n",
|
| 693 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 694 |
+
" attention_mask=attention_mask,\n",
|
| 695 |
+
" labels=labels,\n",
|
| 696 |
+
" )\n",
|
| 697 |
+
"\n",
|
| 698 |
+
" return outputs\n",
|
| 699 |
+
"\n",
|
| 700 |
+
"from transformers import Trainer\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"class GraphTrainer(Trainer):\n",
|
| 703 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 704 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 705 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 706 |
+
" labels = inputs[\"labels\"]\n",
|
| 707 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 708 |
+
"\n",
|
| 709 |
+
" if graph_embedding is not None:\n",
|
| 710 |
+
" outputs = model(\n",
|
| 711 |
+
" input_ids=input_ids,\n",
|
| 712 |
+
" attention_mask=attention_mask,\n",
|
| 713 |
+
" labels=labels,\n",
|
| 714 |
+
" graph_embedding=graph_embedding, \n",
|
| 715 |
+
" )\n",
|
| 716 |
+
" else:\n",
|
| 717 |
+
" outputs = model(\n",
|
| 718 |
+
" input_ids=input_ids,\n",
|
| 719 |
+
" attention_mask=attention_mask,\n",
|
| 720 |
+
" labels=labels,\n",
|
| 721 |
+
" )\n",
|
| 722 |
+
"\n",
|
| 723 |
+
" loss = outputs.loss\n",
|
| 724 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 725 |
+
"\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"from transformers import AutoConfig\n",
|
| 728 |
+
"\n",
|
| 729 |
+
"# ✅ 载入微调模型\n",
|
| 730 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 731 |
+
"\n",
|
| 732 |
+
"# # 1. 加载模型的配置\n",
|
| 733 |
+
"# config = AutoConfig.from_pretrained(MODEL_NAME)\n",
|
| 734 |
+
"\n",
|
| 735 |
+
"# # 2. 使用配置创建 GraphAwareLM 实例\n",
|
| 736 |
+
"# model = GraphAwareLM.from_config(config) \n",
|
| 737 |
+
"\n",
|
| 738 |
+
"# pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 739 |
+
"# model.load_state_dict(pretrained_model.state_dict(), strict=False)\n",
|
| 740 |
+
"\n",
|
| 741 |
+
"# ✅ 载入修改后的 `GraphAwareLM` 模型\n",
|
| 742 |
+
"# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 743 |
+
"# model.config.use_sliding_window_attention = False\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"# ✅ 训练参数\n",
|
| 746 |
+
"training_args = TrainingArguments(\n",
|
| 747 |
+
" output_dir=\"./results2\",\n",
|
| 748 |
+
" per_device_train_batch_size=7,\n",
|
| 749 |
+
" eval_strategy=\"no\",\n",
|
| 750 |
+
" save_strategy=\"steps\",\n",
|
| 751 |
+
" save_steps=3000,\n",
|
| 752 |
+
" logging_steps=50,\n",
|
| 753 |
+
" bf16=True,\n",
|
| 754 |
+
" optim=\"galore_adamw\",\n",
|
| 755 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 756 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 757 |
+
" warmup_steps=1000,\n",
|
| 758 |
+
" num_train_epochs=3,\n",
|
| 759 |
+
" push_to_hub=True,\n",
|
| 760 |
+
" hub_model_id=HF_NAME,\n",
|
| 761 |
+
" hub_strategy=\"every_save\",\n",
|
| 762 |
+
" run_name = \"experi030402\"\n",
|
| 763 |
+
")\n",
|
| 764 |
+
"\n",
|
| 765 |
+
"\n",
|
| 766 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 767 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 768 |
+
"\n",
|
| 769 |
+
"# ✅ 训练\n",
|
| 770 |
+
"trainer = GraphTrainer(\n",
|
| 771 |
+
" model=model,\n",
|
| 772 |
+
" args=training_args,\n",
|
| 773 |
+
" train_dataset=train_dataset,\n",
|
| 774 |
+
")\n",
|
| 775 |
+
"\n",
|
| 776 |
+
"trainer.train()\n",
|
| 777 |
+
"trainer.save_model(\"/workspace/model2\")\n",
|
| 778 |
+
"trainer.push_to_hub()\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"\n"
|
| 781 |
+
]
|
| 782 |
+
},
|
| 783 |
+
{
|
| 784 |
+
"cell_type": "code",
|
| 785 |
+
"execution_count": 7,
|
| 786 |
+
"id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
|
| 787 |
+
"metadata": {},
|
| 788 |
+
"outputs": [
|
| 789 |
+
{
|
| 790 |
+
"data": {
|
| 791 |
+
"text/plain": [
|
| 792 |
+
"RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb')"
|
| 793 |
+
]
|
| 794 |
+
},
|
| 795 |
+
"execution_count": 7,
|
| 796 |
+
"metadata": {},
|
| 797 |
+
"output_type": "execute_result"
|
| 798 |
+
}
|
| 799 |
+
],
|
| 800 |
+
"source": [
|
| 801 |
+
"from huggingface_hub import HfApi\n",
|
| 802 |
+
"\n",
|
| 803 |
+
"api = HfApi()\n",
|
| 804 |
+
"repo_name = \"r1q1.5_graph_lora-wandb\" # 你的模型名称\n",
|
| 805 |
+
"api.create_repo(repo_name, exist_ok=True)"
|
| 806 |
+
]
|
| 807 |
+
},
|
| 808 |
+
{
|
| 809 |
+
"cell_type": "code",
|
| 810 |
+
"execution_count": 8,
|
| 811 |
+
"id": "73c434b9-5d58-4819-8526-24aa18ca1010",
|
| 812 |
+
"metadata": {},
|
| 813 |
+
"outputs": [
|
| 814 |
+
{
|
| 815 |
+
"data": {
|
| 816 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 817 |
+
"model_id": "2bf786e437064435b543c4b364404933",
|
| 818 |
+
"version_major": 2,
|
| 819 |
+
"version_minor": 0
|
| 820 |
+
},
|
| 821 |
+
"text/plain": [
|
| 822 |
+
"run-v0v96nik.wandb: 0%| | 0.00/582k [00:00<?, ?B/s]"
|
| 823 |
+
]
|
| 824 |
+
},
|
| 825 |
+
"metadata": {},
|
| 826 |
+
"output_type": "display_data"
|
| 827 |
+
},
|
| 828 |
+
{
|
| 829 |
+
"data": {
|
| 830 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 831 |
+
"model_id": "d8d8867a8978418cbba012ae48c6a461",
|
| 832 |
+
"version_major": 2,
|
| 833 |
+
"version_minor": 0
|
| 834 |
+
},
|
| 835 |
+
"text/plain": [
|
| 836 |
+
"run-i9v1vlu1.wandb: 0%| | 0.00/617k [00:00<?, ?B/s]"
|
| 837 |
+
]
|
| 838 |
+
},
|
| 839 |
+
"metadata": {},
|
| 840 |
+
"output_type": "display_data"
|
| 841 |
+
},
|
| 842 |
+
{
|
| 843 |
+
"data": {
|
| 844 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 845 |
+
"model_id": "aa41d13f7f204554a401f018f535d83a",
|
| 846 |
+
"version_major": 2,
|
| 847 |
+
"version_minor": 0
|
| 848 |
+
},
|
| 849 |
+
"text/plain": [
|
| 850 |
+
"Upload 3 LFS files: 0%| | 0/3 [00:00<?, ?it/s]"
|
| 851 |
+
]
|
| 852 |
+
},
|
| 853 |
+
"metadata": {},
|
| 854 |
+
"output_type": "display_data"
|
| 855 |
+
},
|
| 856 |
+
{
|
| 857 |
+
"data": {
|
| 858 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 859 |
+
"model_id": "4f6a00ed3d4e43c9806cb5050b812bf8",
|
| 860 |
+
"version_major": 2,
|
| 861 |
+
"version_minor": 0
|
| 862 |
+
},
|
| 863 |
+
"text/plain": [
|
| 864 |
+
"run-e0v0giuw.wandb: 0%| | 0.00/616k [00:00<?, ?B/s]"
|
| 865 |
+
]
|
| 866 |
+
},
|
| 867 |
+
"metadata": {},
|
| 868 |
+
"output_type": "display_data"
|
| 869 |
+
},
|
| 870 |
+
{
|
| 871 |
+
"data": {
|
| 872 |
+
"text/plain": [
|
| 873 |
+
"CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb/commit/81d72bb1534aa8769166ca2e2dd6f4a657ab3742', commit_message='upload wandb', commit_description='', oid='81d72bb1534aa8769166ca2e2dd6f4a657ab3742', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb'), pr_revision=None, pr_num=None)"
|
| 874 |
+
]
|
| 875 |
+
},
|
| 876 |
+
"execution_count": 8,
|
| 877 |
+
"metadata": {},
|
| 878 |
+
"output_type": "execute_result"
|
| 879 |
+
}
|
| 880 |
+
],
|
| 881 |
+
"source": [
|
| 882 |
+
"from huggingface_hub import upload_folder\n",
|
| 883 |
+
"\n",
|
| 884 |
+
"upload_folder(\n",
|
| 885 |
+
" folder_path = \"/workspace/wandb\",\n",
|
| 886 |
+
" repo_id = \"YiFzhao/r1q1.5_graph_lora-wandb\",\n",
|
| 887 |
+
" commit_message = \"upload wandb\",\n",
|
| 888 |
+
")"
|
| 889 |
+
]
|
| 890 |
+
},
|
| 891 |
+
{
|
| 892 |
+
"cell_type": "code",
|
| 893 |
+
"execution_count": 5,
|
| 894 |
+
"id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
|
| 895 |
+
"metadata": {},
|
| 896 |
+
"outputs": [
|
| 897 |
+
{
|
| 898 |
+
"name": "stdout",
|
| 899 |
+
"output_type": "stream",
|
| 900 |
+
"text": [
|
| 901 |
+
"🚀 处理后数据条数: 12384\n",
|
| 902 |
+
"✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 903 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 904 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 905 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 906 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 907 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 908 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 909 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 910 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 911 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 912 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 913 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 914 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 915 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 916 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 917 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 918 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 919 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 920 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 921 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 922 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 923 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 924 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 925 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 926 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 927 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 928 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 929 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 930 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 931 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 932 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 933 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 934 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 935 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 936 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 937 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 938 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 939 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 940 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 941 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 942 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 943 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 944 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 945 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 946 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 947 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 948 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 949 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 950 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 951 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 952 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 953 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 954 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 955 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 956 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 957 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 958 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 959 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 960 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 961 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 962 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 963 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 964 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 965 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
|
| 966 |
+
"✅ train_data 已保存到 train_data.pt\n"
|
| 967 |
+
]
|
| 968 |
+
}
|
| 969 |
+
],
|
| 970 |
+
"source": [
|
| 971 |
+
"import json\n",
|
| 972 |
+
"import torch\n",
|
| 973 |
+
"from transformers import AutoTokenizer\n",
|
| 974 |
+
"\n",
|
| 975 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 976 |
+
"tokenizer.pad_token = tokenizer.eos_token \n",
|
| 977 |
+
"\n",
|
| 978 |
+
"json_path = \"final_Graph.json\"\n",
|
| 979 |
+
"with open(json_path, \"r\") as f:\n",
|
| 980 |
+
" data = json.load(f)\n",
|
| 981 |
+
"\n",
|
| 982 |
+
"train_data = []\n",
|
| 983 |
+
"\n",
|
| 984 |
+
"\n",
|
| 985 |
+
"for sample in data:\n",
|
| 986 |
+
" conversations = sample.get(\"conversations\", [])\n",
|
| 987 |
+
" embeddings = sample.get(\"embedding\", []) \n",
|
| 988 |
+
"\n",
|
| 989 |
+
" if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
|
| 990 |
+
" print(f\"无效的 embedding,跳过样本:{sample}\")\n",
|
| 991 |
+
" continue\n",
|
| 992 |
+
"\n",
|
| 993 |
+
" graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
|
| 994 |
+
"\n",
|
| 995 |
+
" #拼接所有对话\n",
|
| 996 |
+
" dialogue_text = \"\"\n",
|
| 997 |
+
" for conv in conversations:\n",
|
| 998 |
+
" role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
|
| 999 |
+
" content = conv[\"value\"]\n",
|
| 1000 |
+
" content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
|
| 1001 |
+
" role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
|
| 1002 |
+
" dialogue_text += f\"{role_token} {content}\\n\"\n",
|
| 1003 |
+
"\n",
|
| 1004 |
+
" tokenized = tokenizer(\n",
|
| 1005 |
+
" dialogue_text,\n",
|
| 1006 |
+
" padding=\"max_length\",\n",
|
| 1007 |
+
" truncation=True,\n",
|
| 1008 |
+
" max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
|
| 1009 |
+
" return_tensors=\"pt\",\n",
|
| 1010 |
+
" )\n",
|
| 1011 |
+
"\n",
|
| 1012 |
+
" input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
|
| 1013 |
+
" attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
|
| 1014 |
+
"\n",
|
| 1015 |
+
" train_data.append({\n",
|
| 1016 |
+
" \"input_ids\": input_ids,\n",
|
| 1017 |
+
" \"attention_mask\": attention_mask,\n",
|
| 1018 |
+
" \"labels\": input_ids.clone(),\n",
|
| 1019 |
+
" \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
|
| 1020 |
+
" })\n",
|
| 1021 |
+
"\n",
|
| 1022 |
+
"print(\"🚀 处理后数据条数:\", len(train_data))\n",
|
| 1023 |
+
"print(\"✅ 示例数据:\", train_data[0])\n",
|
| 1024 |
+
"torch.save(train_data, \"train_data.pt\")\n",
|
| 1025 |
+
"print(\"✅ train_data 已保存到 train_data.pt\")\n"
|
| 1026 |
+
]
|
| 1027 |
+
},
|
| 1028 |
+
{
|
| 1029 |
+
"cell_type": "code",
|
| 1030 |
+
"execution_count": 10,
|
| 1031 |
+
"id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
|
| 1032 |
+
"metadata": {},
|
| 1033 |
+
"outputs": [],
|
| 1034 |
+
"source": [
|
| 1035 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 1036 |
+
"import torch\n",
|
| 1037 |
+
"import torch.nn as nn\n",
|
| 1038 |
+
"\n",
|
| 1039 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1040 |
+
" def __init__(self, pretrained_model_name_or_path):\n",
|
| 1041 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 1042 |
+
" \n",
|
| 1043 |
+
" # ✅ 载入 `MODEL_NAME` 预训练模型\n",
|
| 1044 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 1045 |
+
"\n",
|
| 1046 |
+
" \n",
|
| 1047 |
+
" # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n",
|
| 1048 |
+
" self.graph_proj = nn.Linear(512, self.config.hidden_size)\n",
|
| 1049 |
+
"\n",
|
| 1050 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1051 |
+
" \"\"\"\n",
|
| 1052 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1053 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1054 |
+
" \"\"\"\n",
|
| 1055 |
+
" # ✅ 获取 token embedding\n",
|
| 1056 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1057 |
+
"\n",
|
| 1058 |
+
" # ✅ 变换 graph embedding 到 hidden_size\n",
|
| 1059 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n",
|
| 1060 |
+
"\n",
|
| 1061 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1062 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1063 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1064 |
+
"\n",
|
| 1065 |
+
" # ✅ 调整 attention mask\n",
|
| 1066 |
+
" if attention_mask is not None:\n",
|
| 1067 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1068 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1069 |
+
"\n",
|
| 1070 |
+
" # ✅ 传入模型\n",
|
| 1071 |
+
" outputs = self.model(\n",
|
| 1072 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1073 |
+
" attention_mask=attention_mask,\n",
|
| 1074 |
+
" labels=labels,\n",
|
| 1075 |
+
" )\n",
|
| 1076 |
+
"\n",
|
| 1077 |
+
" return outputs\n",
|
| 1078 |
+
"\n",
|
| 1079 |
+
" def generate_with_graph(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
|
| 1080 |
+
" \"\"\"\n",
|
| 1081 |
+
" ✅ 自定义 `generate()`,支持 `graph_embedding`\n",
|
| 1082 |
+
" `input_text`: 需要生成文本的输入\n",
|
| 1083 |
+
" `graph_embedding`: 形状为 (1, 512) 的张量\n",
|
| 1084 |
+
" \"\"\"\n",
|
| 1085 |
+
" # ✅ 2. 处理 `graph_embedding`\n",
|
| 1086 |
+
" graph_embedding_token = self.graph_proj(graph_embedding) # (1, hidden_size)\n",
|
| 1087 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (1, 1, hidden_size)\n",
|
| 1088 |
+
"\n",
|
| 1089 |
+
" # ��� 3. 获取 Token Embeddings 并拼接\n",
|
| 1090 |
+
" inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
|
| 1091 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
|
| 1092 |
+
"\n",
|
| 1093 |
+
" # ✅ 4. 调整 `attention_mask`\n",
|
| 1094 |
+
" if \"attention_mask\" in inputs:\n",
|
| 1095 |
+
" graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
|
| 1096 |
+
" attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
|
| 1097 |
+
" else:\n",
|
| 1098 |
+
" attention_mask = None\n",
|
| 1099 |
+
"\n",
|
| 1100 |
+
" # ✅ 5. 进行文本生成\n",
|
| 1101 |
+
" with torch.no_grad():\n",
|
| 1102 |
+
" output_ids = self.model.generate(\n",
|
| 1103 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1104 |
+
" attention_mask=attention_mask,\n",
|
| 1105 |
+
" max_length=max_length,\n",
|
| 1106 |
+
" temperature=temperature,\n",
|
| 1107 |
+
" top_k=top_k,\n",
|
| 1108 |
+
" top_p=top_p,\n",
|
| 1109 |
+
" num_return_sequences=1\n",
|
| 1110 |
+
" )\n",
|
| 1111 |
+
"\n",
|
| 1112 |
+
" # ✅ 6. 解码生成的文本\n",
|
| 1113 |
+
" generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
|
| 1114 |
+
" return generated_text\n",
|
| 1115 |
+
"\n",
|
| 1116 |
+
" @classmethod\n",
|
| 1117 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 1118 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 1119 |
+
" model.graph_proj = nn.Linear(512, model.config.hidden_size)\n",
|
| 1120 |
+
" return model"
|
| 1121 |
+
]
|
| 1122 |
+
},
|
| 1123 |
+
{
|
| 1124 |
+
"cell_type": "code",
|
| 1125 |
+
"execution_count": 11,
|
| 1126 |
+
"id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
|
| 1127 |
+
"metadata": {},
|
| 1128 |
+
"outputs": [],
|
| 1129 |
+
"source": [
|
| 1130 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
| 1131 |
+
]
|
| 1132 |
+
},
|
| 1133 |
+
{
|
| 1134 |
+
"cell_type": "code",
|
| 1135 |
+
"execution_count": 12,
|
| 1136 |
+
"id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
|
| 1137 |
+
"metadata": {},
|
| 1138 |
+
"outputs": [
|
| 1139 |
+
{
|
| 1140 |
+
"data": {
|
| 1141 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1142 |
+
"model_id": "7ad289c5523340f39799ad11e3bc1bb5",
|
| 1143 |
+
"version_major": 2,
|
| 1144 |
+
"version_minor": 0
|
| 1145 |
+
},
|
| 1146 |
+
"text/plain": [
|
| 1147 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1148 |
+
]
|
| 1149 |
+
},
|
| 1150 |
+
"metadata": {},
|
| 1151 |
+
"output_type": "display_data"
|
| 1152 |
+
},
|
| 1153 |
+
{
|
| 1154 |
+
"data": {
|
| 1155 |
+
"text/plain": [
|
| 1156 |
+
"Qwen2ForCausalLM(\n",
|
| 1157 |
+
" (model): Qwen2Model(\n",
|
| 1158 |
+
" (embed_tokens): Embedding(151936, 1536)\n",
|
| 1159 |
+
" (layers): ModuleList(\n",
|
| 1160 |
+
" (0-27): 28 x Qwen2DecoderLayer(\n",
|
| 1161 |
+
" (self_attn): Qwen2Attention(\n",
|
| 1162 |
+
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1163 |
+
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1164 |
+
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1165 |
+
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
|
| 1166 |
+
" )\n",
|
| 1167 |
+
" (mlp): Qwen2MLP(\n",
|
| 1168 |
+
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1169 |
+
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1170 |
+
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
|
| 1171 |
+
" (act_fn): SiLU()\n",
|
| 1172 |
+
" )\n",
|
| 1173 |
+
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1174 |
+
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1175 |
+
" )\n",
|
| 1176 |
+
" )\n",
|
| 1177 |
+
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1178 |
+
" (rotary_emb): Qwen2RotaryEmbedding()\n",
|
| 1179 |
+
" )\n",
|
| 1180 |
+
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
|
| 1181 |
+
" (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n",
|
| 1182 |
+
")"
|
| 1183 |
+
]
|
| 1184 |
+
},
|
| 1185 |
+
"execution_count": 12,
|
| 1186 |
+
"metadata": {},
|
| 1187 |
+
"output_type": "execute_result"
|
| 1188 |
+
}
|
| 1189 |
+
],
|
| 1190 |
+
"source": [
|
| 1191 |
+
"import torch\n",
|
| 1192 |
+
"from transformers import AutoTokenizer\n",
|
| 1193 |
+
"\n",
|
| 1194 |
+
"# 加载 tokenizer\n",
|
| 1195 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 1196 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1197 |
+
"\n",
|
| 1198 |
+
"# 加载训练好的模型\n",
|
| 1199 |
+
"model_path = \"/workspace/model2\"\n",
|
| 1200 |
+
"model = GraphAwareLM.from_pretrained(\"/workspace/results2/checkpoint-5310\").to(device)\n",
|
| 1201 |
+
"model.eval() # 设置为推理模式\n"
|
| 1202 |
+
]
|
| 1203 |
+
},
|
| 1204 |
+
{
|
| 1205 |
+
"cell_type": "code",
|
| 1206 |
+
"execution_count": 13,
|
| 1207 |
+
"id": "51995891-8906-4049-9401-2d22e06a84e8",
|
| 1208 |
+
"metadata": {},
|
| 1209 |
+
"outputs": [
|
| 1210 |
+
{
|
| 1211 |
+
"name": "stdout",
|
| 1212 |
+
"output_type": "stream",
|
| 1213 |
+
"text": [
|
| 1214 |
+
"Parameter containing:\n",
|
| 1215 |
+
"tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
|
| 1216 |
+
" [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
|
| 1217 |
+
" [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
|
| 1218 |
+
" ...,\n",
|
| 1219 |
+
" [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
|
| 1220 |
+
" [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
|
| 1221 |
+
" [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
|
| 1222 |
+
" device='cuda:0', requires_grad=True)\n"
|
| 1223 |
+
]
|
| 1224 |
+
}
|
| 1225 |
+
],
|
| 1226 |
+
"source": [
|
| 1227 |
+
"print(model.graph_proj.weight)\n"
|
| 1228 |
+
]
|
| 1229 |
+
},
|
| 1230 |
+
{
|
| 1231 |
+
"cell_type": "code",
|
| 1232 |
+
"execution_count": 14,
|
| 1233 |
+
"id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
|
| 1234 |
+
"metadata": {},
|
| 1235 |
+
"outputs": [],
|
| 1236 |
+
"source": [
|
| 1237 |
+
"import json\n",
|
| 1238 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1239 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1240 |
+
" data = json.load(f)\n",
|
| 1241 |
+
"\n",
|
| 1242 |
+
"test_data = data[0]\n",
|
| 1243 |
+
"\n",
|
| 1244 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 1245 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 1246 |
+
"\n",
|
| 1247 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
|
| 1248 |
+
"\n",
|
| 1249 |
+
"question1 = conversations[4][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 1250 |
+
"\n",
|
| 1251 |
+
"from transformers import AutoTokenizer\n",
|
| 1252 |
+
"\n",
|
| 1253 |
+
"# ✅ 输入文本\n",
|
| 1254 |
+
"ROLE_TOKENS = {\n",
|
| 1255 |
+
" \"human\": \"<|User|>\", \n",
|
| 1256 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 1257 |
+
"}\n",
|
| 1258 |
+
"GRAPH_LENGTH = 512\n",
|
| 1259 |
+
"max_seq_length = 1100 + GRAPH_LENGTH\n",
|
| 1260 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 1261 |
+
"\n",
|
| 1262 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 1263 |
+
"attention_mask = inputs[\"attention_mask\"]\n"
|
| 1264 |
+
]
|
| 1265 |
+
},
|
| 1266 |
+
{
|
| 1267 |
+
"cell_type": "code",
|
| 1268 |
+
"execution_count": 15,
|
| 1269 |
+
"id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
|
| 1270 |
+
"metadata": {},
|
| 1271 |
+
"outputs": [
|
| 1272 |
+
{
|
| 1273 |
+
"ename": "AttributeError",
|
| 1274 |
+
"evalue": "'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'",
|
| 1275 |
+
"output_type": "error",
|
| 1276 |
+
"traceback": [
|
| 1277 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1278 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
| 1279 |
+
"Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_with_graph\u001b[49m(inputs, graph_embedding)\n",
|
| 1280 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 1281 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'"
|
| 1282 |
+
]
|
| 1283 |
+
}
|
| 1284 |
+
],
|
| 1285 |
+
"source": [
|
| 1286 |
+
"generated_text = model.generate_with_graph(inputs, graph_embedding)"
|
| 1287 |
+
]
|
| 1288 |
+
},
|
| 1289 |
+
{
|
| 1290 |
+
"cell_type": "code",
|
| 1291 |
+
"execution_count": 5,
|
| 1292 |
+
"id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
|
| 1293 |
+
"metadata": {},
|
| 1294 |
+
"outputs": [
|
| 1295 |
+
{
|
| 1296 |
+
"data": {
|
| 1297 |
+
"text/plain": [
|
| 1298 |
+
"tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1299 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1300 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1301 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1302 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1303 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1304 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1305 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1306 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1307 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1308 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1309 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1310 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1311 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1312 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1313 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1314 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1315 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1316 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1317 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1318 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1319 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1320 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1321 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1322 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1323 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1324 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1325 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1326 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1327 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1328 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1329 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1330 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1331 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1332 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1333 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1334 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1335 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1336 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1337 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1338 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1339 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1340 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1341 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1342 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1343 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1344 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1345 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1346 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1347 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1348 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1349 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1350 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1351 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1352 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1353 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1354 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1355 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1356 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1357 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1358 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1359 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1360 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1361 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
|
| 1362 |
+
" device='cuda:0')"
|
| 1363 |
+
]
|
| 1364 |
+
},
|
| 1365 |
+
"execution_count": 5,
|
| 1366 |
+
"metadata": {},
|
| 1367 |
+
"output_type": "execute_result"
|
| 1368 |
+
}
|
| 1369 |
+
],
|
| 1370 |
+
"source": [
|
| 1371 |
+
"graph_embedding"
|
| 1372 |
+
]
|
| 1373 |
+
},
|
| 1374 |
+
{
|
| 1375 |
+
"cell_type": "code",
|
| 1376 |
+
"execution_count": 15,
|
| 1377 |
+
"id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
|
| 1378 |
+
"metadata": {},
|
| 1379 |
+
"outputs": [
|
| 1380 |
+
{
|
| 1381 |
+
"name": "stdout",
|
| 1382 |
+
"output_type": "stream",
|
| 1383 |
+
"text": [
|
| 1384 |
+
"模型回复: How\n"
|
| 1385 |
+
]
|
| 1386 |
+
}
|
| 1387 |
+
],
|
| 1388 |
+
"source": [
|
| 1389 |
+
"# ✅ 进行前向传播\n",
|
| 1390 |
+
"with torch.no_grad():\n",
|
| 1391 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
|
| 1392 |
+
"\n",
|
| 1393 |
+
"# ✅ 提取 logits 并进行贪心解码\n",
|
| 1394 |
+
"logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1395 |
+
"predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
|
| 1396 |
+
"\n",
|
| 1397 |
+
"# ✅ 反向编码为文本\n",
|
| 1398 |
+
"response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
|
| 1399 |
+
"\n",
|
| 1400 |
+
"print(\"模型回复:\", response_text)"
|
| 1401 |
+
]
|
| 1402 |
+
},
|
| 1403 |
+
{
|
| 1404 |
+
"cell_type": "code",
|
| 1405 |
+
"execution_count": 17,
|
| 1406 |
+
"id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
|
| 1407 |
+
"metadata": {},
|
| 1408 |
+
"outputs": [
|
| 1409 |
+
{
|
| 1410 |
+
"name": "stdout",
|
| 1411 |
+
"output_type": "stream",
|
| 1412 |
+
"text": [
|
| 1413 |
+
"Generated Response: Is there any sequential logic in the module, and if so, how is it handled? `data` is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit data, and the output is the output of the `data` is a 1-bit\n"
|
| 1414 |
+
]
|
| 1415 |
+
}
|
| 1416 |
+
],
|
| 1417 |
+
"source": [
|
| 1418 |
+
"max_new_tokens = 1024\n",
|
| 1419 |
+
"generated_ids = input_ids.clone()\n",
|
| 1420 |
+
"generated_attention_mask = attention_mask.clone()\n",
|
| 1421 |
+
"for _ in range(max_new_tokens):\n",
|
| 1422 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 1423 |
+
" with torch.no_grad():\n",
|
| 1424 |
+
" outputs = model(\n",
|
| 1425 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 1426 |
+
" attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
|
| 1427 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 1428 |
+
" )\n",
|
| 1429 |
+
"\n",
|
| 1430 |
+
"\n",
|
| 1431 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1432 |
+
" next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
|
| 1433 |
+
" # print(next_token)\n",
|
| 1434 |
+
"\n",
|
| 1435 |
+
"\n",
|
| 1436 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 1437 |
+
" generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
|
| 1438 |
+
"\n",
|
| 1439 |
+
" # print(generated_ids)\n",
|
| 1440 |
+
"\n",
|
| 1441 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 1442 |
+
" break\n",
|
| 1443 |
+
"\n",
|
| 1444 |
+
" generated_attention_mask = torch.cat(\n",
|
| 1445 |
+
" [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
|
| 1446 |
+
" ) \n",
|
| 1447 |
+
"\n",
|
| 1448 |
+
"# ✅ 解码最终输出\n",
|
| 1449 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 1450 |
+
"print(\"Generated Response:\", generated_text)"
|
| 1451 |
+
]
|
| 1452 |
+
},
|
| 1453 |
+
{
|
| 1454 |
+
"cell_type": "code",
|
| 1455 |
+
"execution_count": 10,
|
| 1456 |
+
"id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
|
| 1457 |
+
"metadata": {},
|
| 1458 |
+
"outputs": [
|
| 1459 |
+
{
|
| 1460 |
+
"data": {
|
| 1461 |
+
"text/plain": [
|
| 1462 |
+
"tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
|
| 1463 |
+
" 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
|
| 1464 |
+
" 525, 862, 9895, 30]], device='cuda:0')"
|
| 1465 |
+
]
|
| 1466 |
+
},
|
| 1467 |
+
"execution_count": 10,
|
| 1468 |
+
"metadata": {},
|
| 1469 |
+
"output_type": "execute_result"
|
| 1470 |
+
}
|
| 1471 |
+
],
|
| 1472 |
+
"source": [
|
| 1473 |
+
"generated_ids"
|
| 1474 |
+
]
|
| 1475 |
+
},
|
| 1476 |
+
{
|
| 1477 |
+
"cell_type": "code",
|
| 1478 |
+
"execution_count": null,
|
| 1479 |
+
"id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
|
| 1480 |
+
"metadata": {},
|
| 1481 |
+
"outputs": [],
|
| 1482 |
+
"source": []
|
| 1483 |
+
}
|
| 1484 |
+
],
|
| 1485 |
+
"metadata": {
|
| 1486 |
+
"kernelspec": {
|
| 1487 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1488 |
+
"language": "python",
|
| 1489 |
+
"name": "python3"
|
| 1490 |
+
},
|
| 1491 |
+
"language_info": {
|
| 1492 |
+
"codemirror_mode": {
|
| 1493 |
+
"name": "ipython",
|
| 1494 |
+
"version": 3
|
| 1495 |
+
},
|
| 1496 |
+
"file_extension": ".py",
|
| 1497 |
+
"mimetype": "text/x-python",
|
| 1498 |
+
"name": "python",
|
| 1499 |
+
"nbconvert_exporter": "python",
|
| 1500 |
+
"pygments_lexer": "ipython3",
|
| 1501 |
+
"version": "3.10.12"
|
| 1502 |
+
}
|
| 1503 |
+
},
|
| 1504 |
+
"nbformat": 4,
|
| 1505 |
+
"nbformat_minor": 5
|
| 1506 |
+
}
|
graph_train3.ipynb
ADDED
|
@@ -0,0 +1,1588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"ROLE_TOKENS = {\n",
|
| 11 |
+
" \"human\": \"<|User|>\", \n",
|
| 12 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 13 |
+
"}\n",
|
| 14 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n",
|
| 15 |
+
"GRAPH_LENGTH = 512\n",
|
| 16 |
+
"HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new3\""
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 2,
|
| 22 |
+
"id": "bba6e6db-4b79-4461-ba13-75fd41019358",
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"CUDA 可用: True\n",
|
| 30 |
+
"GPU 数量: 1\n",
|
| 31 |
+
"当前 GPU: 0\n",
|
| 32 |
+
"GPU 名称: NVIDIA A100 80GB PCIe\n"
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"source": [
|
| 37 |
+
"# !pip install transformers accelerate datasets\n",
|
| 38 |
+
"# !pip install galora\n",
|
| 39 |
+
"# !pip install huggingface_hub\n",
|
| 40 |
+
"import torch\n",
|
| 41 |
+
"print(\"CUDA 可用:\", torch.cuda.is_available())\n",
|
| 42 |
+
"print(\"GPU 数量:\", torch.cuda.device_count())\n",
|
| 43 |
+
"print(\"当前 GPU:\", torch.cuda.current_device())\n",
|
| 44 |
+
"print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 3,
|
| 50 |
+
"id": "ef5551ca-89e2-4488-8e68-1c8d964de039",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
+
"id": "8e283f49-fde4-46e2-9891-dbc304058f0a",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"train_data 重新加载成功,数据量: 12384\n"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"name": "stderr",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n",
|
| 75 |
+
"/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
| 76 |
+
" warnings.warn(\n",
|
| 77 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
|
| 78 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"text/html": [
|
| 84 |
+
"Tracking run with wandb version 0.19.7"
|
| 85 |
+
],
|
| 86 |
+
"text/plain": [
|
| 87 |
+
"<IPython.core.display.HTML object>"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"output_type": "display_data"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"data": {
|
| 95 |
+
"text/html": [
|
| 96 |
+
"Run data is saved locally in <code>/workspace/wandb/run-20250304_134403-e0v0giuw</code>"
|
| 97 |
+
],
|
| 98 |
+
"text/plain": [
|
| 99 |
+
"<IPython.core.display.HTML object>"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"output_type": "display_data"
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"data": {
|
| 107 |
+
"text/html": [
|
| 108 |
+
"Syncing run <strong><a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">experi030403</a></strong> to <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
|
| 109 |
+
],
|
| 110 |
+
"text/plain": [
|
| 111 |
+
"<IPython.core.display.HTML object>"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"output_type": "display_data"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"data": {
|
| 119 |
+
"text/html": [
|
| 120 |
+
" View project at <a href='https://wandb.ai/yifang_zhao/huggingface' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface</a>"
|
| 121 |
+
],
|
| 122 |
+
"text/plain": [
|
| 123 |
+
"<IPython.core.display.HTML object>"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "display_data"
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"data": {
|
| 131 |
+
"text/html": [
|
| 132 |
+
" View run at <a href='https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw' target=\"_blank\">https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw</a>"
|
| 133 |
+
],
|
| 134 |
+
"text/plain": [
|
| 135 |
+
"<IPython.core.display.HTML object>"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"output_type": "display_data"
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"data": {
|
| 143 |
+
"text/html": [
|
| 144 |
+
"\n",
|
| 145 |
+
" <div>\n",
|
| 146 |
+
" \n",
|
| 147 |
+
" <progress value='5310' max='5310' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
| 148 |
+
" [5310/5310 1:33:59, Epoch 3/3]\n",
|
| 149 |
+
" </div>\n",
|
| 150 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
| 151 |
+
" <thead>\n",
|
| 152 |
+
" <tr style=\"text-align: left;\">\n",
|
| 153 |
+
" <th>Step</th>\n",
|
| 154 |
+
" <th>Training Loss</th>\n",
|
| 155 |
+
" </tr>\n",
|
| 156 |
+
" </thead>\n",
|
| 157 |
+
" <tbody>\n",
|
| 158 |
+
" <tr>\n",
|
| 159 |
+
" <td>50</td>\n",
|
| 160 |
+
" <td>5.319300</td>\n",
|
| 161 |
+
" </tr>\n",
|
| 162 |
+
" <tr>\n",
|
| 163 |
+
" <td>100</td>\n",
|
| 164 |
+
" <td>3.641300</td>\n",
|
| 165 |
+
" </tr>\n",
|
| 166 |
+
" <tr>\n",
|
| 167 |
+
" <td>150</td>\n",
|
| 168 |
+
" <td>1.521800</td>\n",
|
| 169 |
+
" </tr>\n",
|
| 170 |
+
" <tr>\n",
|
| 171 |
+
" <td>200</td>\n",
|
| 172 |
+
" <td>1.027500</td>\n",
|
| 173 |
+
" </tr>\n",
|
| 174 |
+
" <tr>\n",
|
| 175 |
+
" <td>250</td>\n",
|
| 176 |
+
" <td>0.922400</td>\n",
|
| 177 |
+
" </tr>\n",
|
| 178 |
+
" <tr>\n",
|
| 179 |
+
" <td>300</td>\n",
|
| 180 |
+
" <td>0.866900</td>\n",
|
| 181 |
+
" </tr>\n",
|
| 182 |
+
" <tr>\n",
|
| 183 |
+
" <td>350</td>\n",
|
| 184 |
+
" <td>0.800500</td>\n",
|
| 185 |
+
" </tr>\n",
|
| 186 |
+
" <tr>\n",
|
| 187 |
+
" <td>400</td>\n",
|
| 188 |
+
" <td>0.721600</td>\n",
|
| 189 |
+
" </tr>\n",
|
| 190 |
+
" <tr>\n",
|
| 191 |
+
" <td>450</td>\n",
|
| 192 |
+
" <td>0.740400</td>\n",
|
| 193 |
+
" </tr>\n",
|
| 194 |
+
" <tr>\n",
|
| 195 |
+
" <td>500</td>\n",
|
| 196 |
+
" <td>0.737000</td>\n",
|
| 197 |
+
" </tr>\n",
|
| 198 |
+
" <tr>\n",
|
| 199 |
+
" <td>550</td>\n",
|
| 200 |
+
" <td>0.713500</td>\n",
|
| 201 |
+
" </tr>\n",
|
| 202 |
+
" <tr>\n",
|
| 203 |
+
" <td>600</td>\n",
|
| 204 |
+
" <td>0.747000</td>\n",
|
| 205 |
+
" </tr>\n",
|
| 206 |
+
" <tr>\n",
|
| 207 |
+
" <td>650</td>\n",
|
| 208 |
+
" <td>0.869500</td>\n",
|
| 209 |
+
" </tr>\n",
|
| 210 |
+
" <tr>\n",
|
| 211 |
+
" <td>700</td>\n",
|
| 212 |
+
" <td>1.473300</td>\n",
|
| 213 |
+
" </tr>\n",
|
| 214 |
+
" <tr>\n",
|
| 215 |
+
" <td>750</td>\n",
|
| 216 |
+
" <td>0.753000</td>\n",
|
| 217 |
+
" </tr>\n",
|
| 218 |
+
" <tr>\n",
|
| 219 |
+
" <td>800</td>\n",
|
| 220 |
+
" <td>0.741300</td>\n",
|
| 221 |
+
" </tr>\n",
|
| 222 |
+
" <tr>\n",
|
| 223 |
+
" <td>850</td>\n",
|
| 224 |
+
" <td>0.751400</td>\n",
|
| 225 |
+
" </tr>\n",
|
| 226 |
+
" <tr>\n",
|
| 227 |
+
" <td>900</td>\n",
|
| 228 |
+
" <td>0.787600</td>\n",
|
| 229 |
+
" </tr>\n",
|
| 230 |
+
" <tr>\n",
|
| 231 |
+
" <td>950</td>\n",
|
| 232 |
+
" <td>0.783200</td>\n",
|
| 233 |
+
" </tr>\n",
|
| 234 |
+
" <tr>\n",
|
| 235 |
+
" <td>1000</td>\n",
|
| 236 |
+
" <td>0.780200</td>\n",
|
| 237 |
+
" </tr>\n",
|
| 238 |
+
" <tr>\n",
|
| 239 |
+
" <td>1050</td>\n",
|
| 240 |
+
" <td>1.012900</td>\n",
|
| 241 |
+
" </tr>\n",
|
| 242 |
+
" <tr>\n",
|
| 243 |
+
" <td>1100</td>\n",
|
| 244 |
+
" <td>1.411700</td>\n",
|
| 245 |
+
" </tr>\n",
|
| 246 |
+
" <tr>\n",
|
| 247 |
+
" <td>1150</td>\n",
|
| 248 |
+
" <td>1.536400</td>\n",
|
| 249 |
+
" </tr>\n",
|
| 250 |
+
" <tr>\n",
|
| 251 |
+
" <td>1200</td>\n",
|
| 252 |
+
" <td>0.853800</td>\n",
|
| 253 |
+
" </tr>\n",
|
| 254 |
+
" <tr>\n",
|
| 255 |
+
" <td>1250</td>\n",
|
| 256 |
+
" <td>0.756500</td>\n",
|
| 257 |
+
" </tr>\n",
|
| 258 |
+
" <tr>\n",
|
| 259 |
+
" <td>1300</td>\n",
|
| 260 |
+
" <td>0.750800</td>\n",
|
| 261 |
+
" </tr>\n",
|
| 262 |
+
" <tr>\n",
|
| 263 |
+
" <td>1350</td>\n",
|
| 264 |
+
" <td>0.747400</td>\n",
|
| 265 |
+
" </tr>\n",
|
| 266 |
+
" <tr>\n",
|
| 267 |
+
" <td>1400</td>\n",
|
| 268 |
+
" <td>0.844400</td>\n",
|
| 269 |
+
" </tr>\n",
|
| 270 |
+
" <tr>\n",
|
| 271 |
+
" <td>1450</td>\n",
|
| 272 |
+
" <td>0.858400</td>\n",
|
| 273 |
+
" </tr>\n",
|
| 274 |
+
" <tr>\n",
|
| 275 |
+
" <td>1500</td>\n",
|
| 276 |
+
" <td>1.053400</td>\n",
|
| 277 |
+
" </tr>\n",
|
| 278 |
+
" <tr>\n",
|
| 279 |
+
" <td>1550</td>\n",
|
| 280 |
+
" <td>1.591600</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <td>1600</td>\n",
|
| 284 |
+
" <td>1.498900</td>\n",
|
| 285 |
+
" </tr>\n",
|
| 286 |
+
" <tr>\n",
|
| 287 |
+
" <td>1650</td>\n",
|
| 288 |
+
" <td>1.471700</td>\n",
|
| 289 |
+
" </tr>\n",
|
| 290 |
+
" <tr>\n",
|
| 291 |
+
" <td>1700</td>\n",
|
| 292 |
+
" <td>1.221100</td>\n",
|
| 293 |
+
" </tr>\n",
|
| 294 |
+
" <tr>\n",
|
| 295 |
+
" <td>1750</td>\n",
|
| 296 |
+
" <td>1.802300</td>\n",
|
| 297 |
+
" </tr>\n",
|
| 298 |
+
" <tr>\n",
|
| 299 |
+
" <td>1800</td>\n",
|
| 300 |
+
" <td>1.826000</td>\n",
|
| 301 |
+
" </tr>\n",
|
| 302 |
+
" <tr>\n",
|
| 303 |
+
" <td>1850</td>\n",
|
| 304 |
+
" <td>1.857300</td>\n",
|
| 305 |
+
" </tr>\n",
|
| 306 |
+
" <tr>\n",
|
| 307 |
+
" <td>1900</td>\n",
|
| 308 |
+
" <td>1.561800</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <td>1950</td>\n",
|
| 312 |
+
" <td>1.398800</td>\n",
|
| 313 |
+
" </tr>\n",
|
| 314 |
+
" <tr>\n",
|
| 315 |
+
" <td>2000</td>\n",
|
| 316 |
+
" <td>1.398900</td>\n",
|
| 317 |
+
" </tr>\n",
|
| 318 |
+
" <tr>\n",
|
| 319 |
+
" <td>2050</td>\n",
|
| 320 |
+
" <td>1.381600</td>\n",
|
| 321 |
+
" </tr>\n",
|
| 322 |
+
" <tr>\n",
|
| 323 |
+
" <td>2100</td>\n",
|
| 324 |
+
" <td>0.890300</td>\n",
|
| 325 |
+
" </tr>\n",
|
| 326 |
+
" <tr>\n",
|
| 327 |
+
" <td>2150</td>\n",
|
| 328 |
+
" <td>0.763700</td>\n",
|
| 329 |
+
" </tr>\n",
|
| 330 |
+
" <tr>\n",
|
| 331 |
+
" <td>2200</td>\n",
|
| 332 |
+
" <td>0.753100</td>\n",
|
| 333 |
+
" </tr>\n",
|
| 334 |
+
" <tr>\n",
|
| 335 |
+
" <td>2250</td>\n",
|
| 336 |
+
" <td>0.745500</td>\n",
|
| 337 |
+
" </tr>\n",
|
| 338 |
+
" <tr>\n",
|
| 339 |
+
" <td>2300</td>\n",
|
| 340 |
+
" <td>1.186100</td>\n",
|
| 341 |
+
" </tr>\n",
|
| 342 |
+
" <tr>\n",
|
| 343 |
+
" <td>2350</td>\n",
|
| 344 |
+
" <td>0.862000</td>\n",
|
| 345 |
+
" </tr>\n",
|
| 346 |
+
" <tr>\n",
|
| 347 |
+
" <td>2400</td>\n",
|
| 348 |
+
" <td>1.024600</td>\n",
|
| 349 |
+
" </tr>\n",
|
| 350 |
+
" <tr>\n",
|
| 351 |
+
" <td>2450</td>\n",
|
| 352 |
+
" <td>1.028400</td>\n",
|
| 353 |
+
" </tr>\n",
|
| 354 |
+
" <tr>\n",
|
| 355 |
+
" <td>2500</td>\n",
|
| 356 |
+
" <td>1.008500</td>\n",
|
| 357 |
+
" </tr>\n",
|
| 358 |
+
" <tr>\n",
|
| 359 |
+
" <td>2550</td>\n",
|
| 360 |
+
" <td>0.942800</td>\n",
|
| 361 |
+
" </tr>\n",
|
| 362 |
+
" <tr>\n",
|
| 363 |
+
" <td>2600</td>\n",
|
| 364 |
+
" <td>0.849700</td>\n",
|
| 365 |
+
" </tr>\n",
|
| 366 |
+
" <tr>\n",
|
| 367 |
+
" <td>2650</td>\n",
|
| 368 |
+
" <td>0.771400</td>\n",
|
| 369 |
+
" </tr>\n",
|
| 370 |
+
" <tr>\n",
|
| 371 |
+
" <td>2700</td>\n",
|
| 372 |
+
" <td>0.794100</td>\n",
|
| 373 |
+
" </tr>\n",
|
| 374 |
+
" <tr>\n",
|
| 375 |
+
" <td>2750</td>\n",
|
| 376 |
+
" <td>0.819200</td>\n",
|
| 377 |
+
" </tr>\n",
|
| 378 |
+
" <tr>\n",
|
| 379 |
+
" <td>2800</td>\n",
|
| 380 |
+
" <td>0.937500</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <td>2850</td>\n",
|
| 384 |
+
" <td>1.064500</td>\n",
|
| 385 |
+
" </tr>\n",
|
| 386 |
+
" <tr>\n",
|
| 387 |
+
" <td>2900</td>\n",
|
| 388 |
+
" <td>1.189300</td>\n",
|
| 389 |
+
" </tr>\n",
|
| 390 |
+
" <tr>\n",
|
| 391 |
+
" <td>2950</td>\n",
|
| 392 |
+
" <td>1.071100</td>\n",
|
| 393 |
+
" </tr>\n",
|
| 394 |
+
" <tr>\n",
|
| 395 |
+
" <td>3000</td>\n",
|
| 396 |
+
" <td>1.003300</td>\n",
|
| 397 |
+
" </tr>\n",
|
| 398 |
+
" <tr>\n",
|
| 399 |
+
" <td>3050</td>\n",
|
| 400 |
+
" <td>1.073900</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <td>3100</td>\n",
|
| 404 |
+
" <td>1.043100</td>\n",
|
| 405 |
+
" </tr>\n",
|
| 406 |
+
" <tr>\n",
|
| 407 |
+
" <td>3150</td>\n",
|
| 408 |
+
" <td>1.282600</td>\n",
|
| 409 |
+
" </tr>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <td>3200</td>\n",
|
| 412 |
+
" <td>2.145400</td>\n",
|
| 413 |
+
" </tr>\n",
|
| 414 |
+
" <tr>\n",
|
| 415 |
+
" <td>3250</td>\n",
|
| 416 |
+
" <td>1.925800</td>\n",
|
| 417 |
+
" </tr>\n",
|
| 418 |
+
" <tr>\n",
|
| 419 |
+
" <td>3300</td>\n",
|
| 420 |
+
" <td>2.005600</td>\n",
|
| 421 |
+
" </tr>\n",
|
| 422 |
+
" <tr>\n",
|
| 423 |
+
" <td>3350</td>\n",
|
| 424 |
+
" <td>2.122600</td>\n",
|
| 425 |
+
" </tr>\n",
|
| 426 |
+
" <tr>\n",
|
| 427 |
+
" <td>3400</td>\n",
|
| 428 |
+
" <td>2.163000</td>\n",
|
| 429 |
+
" </tr>\n",
|
| 430 |
+
" <tr>\n",
|
| 431 |
+
" <td>3450</td>\n",
|
| 432 |
+
" <td>2.046600</td>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" <tr>\n",
|
| 435 |
+
" <td>3500</td>\n",
|
| 436 |
+
" <td>2.152200</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <td>3550</td>\n",
|
| 440 |
+
" <td>2.151700</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <td>3600</td>\n",
|
| 444 |
+
" <td>5.394900</td>\n",
|
| 445 |
+
" </tr>\n",
|
| 446 |
+
" <tr>\n",
|
| 447 |
+
" <td>3650</td>\n",
|
| 448 |
+
" <td>4.677800</td>\n",
|
| 449 |
+
" </tr>\n",
|
| 450 |
+
" <tr>\n",
|
| 451 |
+
" <td>3700</td>\n",
|
| 452 |
+
" <td>4.122200</td>\n",
|
| 453 |
+
" </tr>\n",
|
| 454 |
+
" <tr>\n",
|
| 455 |
+
" <td>3750</td>\n",
|
| 456 |
+
" <td>3.710200</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <td>3800</td>\n",
|
| 460 |
+
" <td>3.350800</td>\n",
|
| 461 |
+
" </tr>\n",
|
| 462 |
+
" <tr>\n",
|
| 463 |
+
" <td>3850</td>\n",
|
| 464 |
+
" <td>3.126300</td>\n",
|
| 465 |
+
" </tr>\n",
|
| 466 |
+
" <tr>\n",
|
| 467 |
+
" <td>3900</td>\n",
|
| 468 |
+
" <td>2.988700</td>\n",
|
| 469 |
+
" </tr>\n",
|
| 470 |
+
" <tr>\n",
|
| 471 |
+
" <td>3950</td>\n",
|
| 472 |
+
" <td>2.872000</td>\n",
|
| 473 |
+
" </tr>\n",
|
| 474 |
+
" <tr>\n",
|
| 475 |
+
" <td>4000</td>\n",
|
| 476 |
+
" <td>2.848200</td>\n",
|
| 477 |
+
" </tr>\n",
|
| 478 |
+
" <tr>\n",
|
| 479 |
+
" <td>4050</td>\n",
|
| 480 |
+
" <td>2.823900</td>\n",
|
| 481 |
+
" </tr>\n",
|
| 482 |
+
" <tr>\n",
|
| 483 |
+
" <td>4100</td>\n",
|
| 484 |
+
" <td>2.781200</td>\n",
|
| 485 |
+
" </tr>\n",
|
| 486 |
+
" <tr>\n",
|
| 487 |
+
" <td>4150</td>\n",
|
| 488 |
+
" <td>2.735000</td>\n",
|
| 489 |
+
" </tr>\n",
|
| 490 |
+
" <tr>\n",
|
| 491 |
+
" <td>4200</td>\n",
|
| 492 |
+
" <td>2.725900</td>\n",
|
| 493 |
+
" </tr>\n",
|
| 494 |
+
" <tr>\n",
|
| 495 |
+
" <td>4250</td>\n",
|
| 496 |
+
" <td>2.644400</td>\n",
|
| 497 |
+
" </tr>\n",
|
| 498 |
+
" <tr>\n",
|
| 499 |
+
" <td>4300</td>\n",
|
| 500 |
+
" <td>2.700000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <td>4350</td>\n",
|
| 504 |
+
" <td>2.650100</td>\n",
|
| 505 |
+
" </tr>\n",
|
| 506 |
+
" <tr>\n",
|
| 507 |
+
" <td>4400</td>\n",
|
| 508 |
+
" <td>2.704500</td>\n",
|
| 509 |
+
" </tr>\n",
|
| 510 |
+
" <tr>\n",
|
| 511 |
+
" <td>4450</td>\n",
|
| 512 |
+
" <td>2.596700</td>\n",
|
| 513 |
+
" </tr>\n",
|
| 514 |
+
" <tr>\n",
|
| 515 |
+
" <td>4500</td>\n",
|
| 516 |
+
" <td>2.510500</td>\n",
|
| 517 |
+
" </tr>\n",
|
| 518 |
+
" <tr>\n",
|
| 519 |
+
" <td>4550</td>\n",
|
| 520 |
+
" <td>2.515800</td>\n",
|
| 521 |
+
" </tr>\n",
|
| 522 |
+
" <tr>\n",
|
| 523 |
+
" <td>4600</td>\n",
|
| 524 |
+
" <td>2.498100</td>\n",
|
| 525 |
+
" </tr>\n",
|
| 526 |
+
" <tr>\n",
|
| 527 |
+
" <td>4650</td>\n",
|
| 528 |
+
" <td>2.458900</td>\n",
|
| 529 |
+
" </tr>\n",
|
| 530 |
+
" <tr>\n",
|
| 531 |
+
" <td>4700</td>\n",
|
| 532 |
+
" <td>2.449700</td>\n",
|
| 533 |
+
" </tr>\n",
|
| 534 |
+
" <tr>\n",
|
| 535 |
+
" <td>4750</td>\n",
|
| 536 |
+
" <td>2.425000</td>\n",
|
| 537 |
+
" </tr>\n",
|
| 538 |
+
" <tr>\n",
|
| 539 |
+
" <td>4800</td>\n",
|
| 540 |
+
" <td>2.362300</td>\n",
|
| 541 |
+
" </tr>\n",
|
| 542 |
+
" <tr>\n",
|
| 543 |
+
" <td>4850</td>\n",
|
| 544 |
+
" <td>2.232000</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" <tr>\n",
|
| 547 |
+
" <td>4900</td>\n",
|
| 548 |
+
" <td>2.361500</td>\n",
|
| 549 |
+
" </tr>\n",
|
| 550 |
+
" <tr>\n",
|
| 551 |
+
" <td>4950</td>\n",
|
| 552 |
+
" <td>2.302300</td>\n",
|
| 553 |
+
" </tr>\n",
|
| 554 |
+
" <tr>\n",
|
| 555 |
+
" <td>5000</td>\n",
|
| 556 |
+
" <td>2.333900</td>\n",
|
| 557 |
+
" </tr>\n",
|
| 558 |
+
" <tr>\n",
|
| 559 |
+
" <td>5050</td>\n",
|
| 560 |
+
" <td>2.367200</td>\n",
|
| 561 |
+
" </tr>\n",
|
| 562 |
+
" <tr>\n",
|
| 563 |
+
" <td>5100</td>\n",
|
| 564 |
+
" <td>2.288300</td>\n",
|
| 565 |
+
" </tr>\n",
|
| 566 |
+
" <tr>\n",
|
| 567 |
+
" <td>5150</td>\n",
|
| 568 |
+
" <td>2.426100</td>\n",
|
| 569 |
+
" </tr>\n",
|
| 570 |
+
" <tr>\n",
|
| 571 |
+
" <td>5200</td>\n",
|
| 572 |
+
" <td>2.344100</td>\n",
|
| 573 |
+
" </tr>\n",
|
| 574 |
+
" <tr>\n",
|
| 575 |
+
" <td>5250</td>\n",
|
| 576 |
+
" <td>2.283500</td>\n",
|
| 577 |
+
" </tr>\n",
|
| 578 |
+
" <tr>\n",
|
| 579 |
+
" <td>5300</td>\n",
|
| 580 |
+
" <td>2.296500</td>\n",
|
| 581 |
+
" </tr>\n",
|
| 582 |
+
" </tbody>\n",
|
| 583 |
+
"</table><p>"
|
| 584 |
+
],
|
| 585 |
+
"text/plain": [
|
| 586 |
+
"<IPython.core.display.HTML object>"
|
| 587 |
+
]
|
| 588 |
+
},
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"output_type": "display_data"
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"name": "stderr",
|
| 594 |
+
"output_type": "stream",
|
| 595 |
+
"text": [
|
| 596 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"data": {
|
| 601 |
+
"text/plain": [
|
| 602 |
+
"CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3/commit/b9472b66316be8654c6f7c173fa4561889bd3446', commit_message='End of training', commit_description='', oid='b9472b66316be8654c6f7c173fa4561889bd3446', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new3'), pr_revision=None, pr_num=None)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
"execution_count": 4,
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"output_type": "execute_result"
|
| 608 |
+
}
|
| 609 |
+
],
|
| 610 |
+
"source": [
|
| 611 |
+
"import json\n",
|
| 612 |
+
"import torch\n",
|
| 613 |
+
"import os\n",
|
| 614 |
+
"from transformers import AutoTokenizer\n",
|
| 615 |
+
"train_data = torch.load(\"train_data.pt\",weights_only=False)\n",
|
| 616 |
+
"print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 617 |
+
"if 'train_data' not in globals():\n",
|
| 618 |
+
" train_data_path = \"train_data.pt\"\n",
|
| 619 |
+
" \n",
|
| 620 |
+
" if os.path.exists(train_data_path): #确保文件存在\n",
|
| 621 |
+
" train_data = torch.load(train_data_path, weights_only=False)\n",
|
| 622 |
+
" print(\"train_data 重新加载成功,数据量:\", len(train_data))\n",
|
| 623 |
+
" else:\n",
|
| 624 |
+
" print(f\"未找到 {train_data_path},请检查路径!\")\n",
|
| 625 |
+
" exit()\n",
|
| 626 |
+
"#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n",
|
| 627 |
+
"if \"MODEL_NAME\" not in globals():\n",
|
| 628 |
+
" MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
|
| 636 |
+
"\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"from torch.utils.data import Dataset\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"class GraphDataset(Dataset):\n",
|
| 641 |
+
" def __init__(self, data):\n",
|
| 642 |
+
" self.data = data\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" def __len__(self):\n",
|
| 645 |
+
" return len(self.data)\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" def __getitem__(self, idx):\n",
|
| 648 |
+
" sample = self.data[idx]\n",
|
| 649 |
+
" return {\n",
|
| 650 |
+
" \"input_ids\": sample[\"input_ids\"],\n",
|
| 651 |
+
" \"attention_mask\": sample[\"attention_mask\"],\n",
|
| 652 |
+
" \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n",
|
| 653 |
+
" \"labels\": sample[\"labels\"],\n",
|
| 654 |
+
" }\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 657 |
+
"import torch\n",
|
| 658 |
+
"import torch.nn as nn\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 661 |
+
" def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
|
| 662 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 663 |
+
" \n",
|
| 664 |
+
" # ✅ 载入 LLM 预训练模型\n",
|
| 665 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
|
| 668 |
+
" self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
|
| 669 |
+
"\n",
|
| 670 |
+
" # ✅ 2. 多头注意力层\n",
|
| 671 |
+
" self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
|
| 672 |
+
"\n",
|
| 673 |
+
" # ✅ 3. 线性变换\n",
|
| 674 |
+
" self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
|
| 675 |
+
"\n",
|
| 676 |
+
" # ✅ 4. 残差连接 + LayerNorm\n",
|
| 677 |
+
" self.norm = nn.LayerNorm(self.config.hidden_size)\n",
|
| 678 |
+
" \n",
|
| 679 |
+
"\n",
|
| 680 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 681 |
+
" \"\"\"\n",
|
| 682 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 683 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 684 |
+
" \"\"\"\n",
|
| 685 |
+
" # ✅ 获取 token embedding\n",
|
| 686 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 687 |
+
"\n",
|
| 688 |
+
" # ✅ 1. 线性变换 `graph_embedding`\n",
|
| 689 |
+
" graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" # ✅ 2. 多头注意力计算(自注意力机制)\n",
|
| 692 |
+
" attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
|
| 693 |
+
" \n",
|
| 694 |
+
" # ✅ 3. 线性层 + 残差连接\n",
|
| 695 |
+
" graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n",
|
| 696 |
+
"\n",
|
| 697 |
+
" # ✅ 4. 归一化\n",
|
| 698 |
+
" graph_embedding_token = self.norm(graph_embedding_token)\n",
|
| 699 |
+
"\n",
|
| 700 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 701 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 702 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 703 |
+
"\n",
|
| 704 |
+
" # ✅ 调整 attention mask\n",
|
| 705 |
+
" if attention_mask is not None:\n",
|
| 706 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 707 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 708 |
+
"\n",
|
| 709 |
+
" # ✅ 传入模型\n",
|
| 710 |
+
" outputs = self.model(\n",
|
| 711 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 712 |
+
" attention_mask=attention_mask,\n",
|
| 713 |
+
" labels=labels,\n",
|
| 714 |
+
" )\n",
|
| 715 |
+
"\n",
|
| 716 |
+
" return outputs\n",
|
| 717 |
+
"\n",
|
| 718 |
+
"from transformers import Trainer\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"class GraphTrainer(Trainer):\n",
|
| 721 |
+
" def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
|
| 722 |
+
" input_ids = inputs[\"input_ids\"]\n",
|
| 723 |
+
" attention_mask = inputs[\"attention_mask\"]\n",
|
| 724 |
+
" labels = inputs[\"labels\"]\n",
|
| 725 |
+
" graph_embedding = inputs.get(\"graph_embedding\", None) \n",
|
| 726 |
+
"\n",
|
| 727 |
+
" if graph_embedding is not None:\n",
|
| 728 |
+
" outputs = model(\n",
|
| 729 |
+
" input_ids=input_ids,\n",
|
| 730 |
+
" attention_mask=attention_mask,\n",
|
| 731 |
+
" labels=labels,\n",
|
| 732 |
+
" graph_embedding=graph_embedding, \n",
|
| 733 |
+
" )\n",
|
| 734 |
+
" else:\n",
|
| 735 |
+
" outputs = model(\n",
|
| 736 |
+
" input_ids=input_ids,\n",
|
| 737 |
+
" attention_mask=attention_mask,\n",
|
| 738 |
+
" labels=labels,\n",
|
| 739 |
+
" )\n",
|
| 740 |
+
"\n",
|
| 741 |
+
" loss = outputs.loss\n",
|
| 742 |
+
" return (loss, outputs) if return_outputs else loss\n",
|
| 743 |
+
"\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"from transformers import AutoConfig\n",
|
| 746 |
+
"\n",
|
| 747 |
+
"# ✅ 载入微调模型\n",
|
| 748 |
+
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
|
| 749 |
+
"\n",
|
| 750 |
+
"# ✅ 训练参数\n",
|
| 751 |
+
"training_args = TrainingArguments(\n",
|
| 752 |
+
" output_dir=\"./results3\",\n",
|
| 753 |
+
" per_device_train_batch_size=7,\n",
|
| 754 |
+
" eval_strategy=\"no\",\n",
|
| 755 |
+
" save_strategy=\"steps\",\n",
|
| 756 |
+
" save_steps=3000,\n",
|
| 757 |
+
" logging_steps=50,\n",
|
| 758 |
+
" bf16=True,\n",
|
| 759 |
+
" optim=\"galore_adamw\",\n",
|
| 760 |
+
" optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n",
|
| 761 |
+
" optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n",
|
| 762 |
+
" warmup_steps=1000,\n",
|
| 763 |
+
" num_train_epochs=3,\n",
|
| 764 |
+
" push_to_hub=True,\n",
|
| 765 |
+
" hub_model_id=HF_NAME,\n",
|
| 766 |
+
" hub_strategy=\"every_save\",\n",
|
| 767 |
+
" run_name = \"experi030403\"\n",
|
| 768 |
+
")\n",
|
| 769 |
+
"\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"# ✅ 转换 `train_data` 为 `Dataset`\n",
|
| 772 |
+
"train_dataset = GraphDataset(train_data)\n",
|
| 773 |
+
"\n",
|
| 774 |
+
"# ✅ 训练\n",
|
| 775 |
+
"trainer = GraphTrainer(\n",
|
| 776 |
+
" model=model,\n",
|
| 777 |
+
" args=training_args,\n",
|
| 778 |
+
" train_dataset=train_dataset,\n",
|
| 779 |
+
")\n",
|
| 780 |
+
"\n",
|
| 781 |
+
"trainer.train()\n",
|
| 782 |
+
"trainer.save_model(\"/workspace/model3\")\n",
|
| 783 |
+
"trainer.push_to_hub()\n",
|
| 784 |
+
"\n",
|
| 785 |
+
"\n"
|
| 786 |
+
]
|
| 787 |
+
},
|
| 788 |
+
{
|
| 789 |
+
"cell_type": "code",
|
| 790 |
+
"execution_count": 2,
|
| 791 |
+
"id": "7a72ac3b-561e-41d3-ae93-99f20acf3188",
|
| 792 |
+
"metadata": {},
|
| 793 |
+
"outputs": [
|
| 794 |
+
{
|
| 795 |
+
"data": {
|
| 796 |
+
"text/plain": [
|
| 797 |
+
"RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000')"
|
| 798 |
+
]
|
| 799 |
+
},
|
| 800 |
+
"execution_count": 2,
|
| 801 |
+
"metadata": {},
|
| 802 |
+
"output_type": "execute_result"
|
| 803 |
+
}
|
| 804 |
+
],
|
| 805 |
+
"source": [
|
| 806 |
+
"from huggingface_hub import HfApi\n",
|
| 807 |
+
"\n",
|
| 808 |
+
"api = HfApi()\n",
|
| 809 |
+
"repo_name = \"r1q1.5_graph_lora-results3\" # 你的模型名称\n",
|
| 810 |
+
"api.create_repo(repo_name, exist_ok=True)"
|
| 811 |
+
]
|
| 812 |
+
},
|
| 813 |
+
{
|
| 814 |
+
"cell_type": "code",
|
| 815 |
+
"execution_count": 3,
|
| 816 |
+
"id": "73c434b9-5d58-4819-8526-24aa18ca1010",
|
| 817 |
+
"metadata": {},
|
| 818 |
+
"outputs": [
|
| 819 |
+
{
|
| 820 |
+
"data": {
|
| 821 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 822 |
+
"model_id": "8b896f21685e4086b0b59404b2b1a866",
|
| 823 |
+
"version_major": 2,
|
| 824 |
+
"version_minor": 0
|
| 825 |
+
},
|
| 826 |
+
"text/plain": [
|
| 827 |
+
"model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00<?, ?B/s]"
|
| 828 |
+
]
|
| 829 |
+
},
|
| 830 |
+
"metadata": {},
|
| 831 |
+
"output_type": "display_data"
|
| 832 |
+
},
|
| 833 |
+
{
|
| 834 |
+
"data": {
|
| 835 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 836 |
+
"model_id": "d20bff067ca44c4583378181da817897",
|
| 837 |
+
"version_major": 2,
|
| 838 |
+
"version_minor": 0
|
| 839 |
+
},
|
| 840 |
+
"text/plain": [
|
| 841 |
+
"scheduler.pt: 0%| | 0.00/1.06k [00:00<?, ?B/s]"
|
| 842 |
+
]
|
| 843 |
+
},
|
| 844 |
+
"metadata": {},
|
| 845 |
+
"output_type": "display_data"
|
| 846 |
+
},
|
| 847 |
+
{
|
| 848 |
+
"data": {
|
| 849 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 850 |
+
"model_id": "c4b7114a53b341539a3244f2eea8aacf",
|
| 851 |
+
"version_major": 2,
|
| 852 |
+
"version_minor": 0
|
| 853 |
+
},
|
| 854 |
+
"text/plain": [
|
| 855 |
+
"Upload 6 LFS files: 0%| | 0/6 [00:00<?, ?it/s]"
|
| 856 |
+
]
|
| 857 |
+
},
|
| 858 |
+
"metadata": {},
|
| 859 |
+
"output_type": "display_data"
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"data": {
|
| 863 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 864 |
+
"model_id": "74c6045017b640bdba86fe3ed1bb9c92",
|
| 865 |
+
"version_major": 2,
|
| 866 |
+
"version_minor": 0
|
| 867 |
+
},
|
| 868 |
+
"text/plain": [
|
| 869 |
+
"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
| 870 |
+
]
|
| 871 |
+
},
|
| 872 |
+
"metadata": {},
|
| 873 |
+
"output_type": "display_data"
|
| 874 |
+
},
|
| 875 |
+
{
|
| 876 |
+
"data": {
|
| 877 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 878 |
+
"model_id": "97436b084bc4420f8b273ec462c50e61",
|
| 879 |
+
"version_major": 2,
|
| 880 |
+
"version_minor": 0
|
| 881 |
+
},
|
| 882 |
+
"text/plain": [
|
| 883 |
+
"optimizer.pt: 0%| | 0.00/4.32G [00:00<?, ?B/s]"
|
| 884 |
+
]
|
| 885 |
+
},
|
| 886 |
+
"metadata": {},
|
| 887 |
+
"output_type": "display_data"
|
| 888 |
+
},
|
| 889 |
+
{
|
| 890 |
+
"data": {
|
| 891 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 892 |
+
"model_id": "d7f10ccff3674e6fa8bcb42553c12b19",
|
| 893 |
+
"version_major": 2,
|
| 894 |
+
"version_minor": 0
|
| 895 |
+
},
|
| 896 |
+
"text/plain": [
|
| 897 |
+
"rng_state.pth: 0%| | 0.00/14.2k [00:00<?, ?B/s]"
|
| 898 |
+
]
|
| 899 |
+
},
|
| 900 |
+
"metadata": {},
|
| 901 |
+
"output_type": "display_data"
|
| 902 |
+
},
|
| 903 |
+
{
|
| 904 |
+
"data": {
|
| 905 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 906 |
+
"model_id": "c5b1a010fd0845f9ba9112291afa8f17",
|
| 907 |
+
"version_major": 2,
|
| 908 |
+
"version_minor": 0
|
| 909 |
+
},
|
| 910 |
+
"text/plain": [
|
| 911 |
+
"training_args.bin: 0%| | 0.00/5.37k [00:00<?, ?B/s]"
|
| 912 |
+
]
|
| 913 |
+
},
|
| 914 |
+
"metadata": {},
|
| 915 |
+
"output_type": "display_data"
|
| 916 |
+
},
|
| 917 |
+
{
|
| 918 |
+
"data": {
|
| 919 |
+
"text/plain": [
|
| 920 |
+
"CommitInfo(commit_url='https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000/commit/4088de651a0ce2cc39fcb0c950898e54ce91bdea', commit_message='upload checkpoint-3000', commit_description='', oid='4088de651a0ce2cc39fcb0c950898e54ce91bdea', pr_url=None, repo_url=RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000'), pr_revision=None, pr_num=None)"
|
| 921 |
+
]
|
| 922 |
+
},
|
| 923 |
+
"execution_count": 3,
|
| 924 |
+
"metadata": {},
|
| 925 |
+
"output_type": "execute_result"
|
| 926 |
+
}
|
| 927 |
+
],
|
| 928 |
+
"source": [
|
| 929 |
+
"from huggingface_hub import upload_folder\n",
|
| 930 |
+
"\n",
|
| 931 |
+
"upload_folder(\n",
|
| 932 |
+
" folder_path = \"/workspace/results3\",\n",
|
| 933 |
+
" repo_id = \"YiFzhao/r1q1.5_graph_lora-results3\",\n",
|
| 934 |
+
" commit_message = \"upload results2\",\n",
|
| 935 |
+
")"
|
| 936 |
+
]
|
| 937 |
+
},
|
| 938 |
+
{
|
| 939 |
+
"cell_type": "code",
|
| 940 |
+
"execution_count": 5,
|
| 941 |
+
"id": "8d2ebf87-402e-444d-8599-96c313f1b7fa",
|
| 942 |
+
"metadata": {},
|
| 943 |
+
"outputs": [
|
| 944 |
+
{
|
| 945 |
+
"name": "stdout",
|
| 946 |
+
"output_type": "stream",
|
| 947 |
+
"text": [
|
| 948 |
+
"🚀 处理后数据条数: 12384\n",
|
| 949 |
+
"✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 950 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 951 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 952 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 953 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 954 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 955 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 956 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 957 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 958 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 959 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 960 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 961 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 962 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 963 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 964 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 965 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 966 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 967 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 968 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 969 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 970 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 971 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 972 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 973 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 974 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 975 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 976 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 977 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 978 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 979 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 980 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 981 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 982 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 983 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 984 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 985 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 986 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 987 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 988 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 989 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 990 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 991 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 992 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 993 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 994 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 995 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 996 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 997 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 998 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 999 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1000 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1001 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1002 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1003 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1004 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1005 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1006 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1007 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1008 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1009 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1010 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1011 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1012 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n",
|
| 1013 |
+
"✅ train_data 已保存到 train_data.pt\n"
|
| 1014 |
+
]
|
| 1015 |
+
}
|
| 1016 |
+
],
|
| 1017 |
+
"source": [
|
| 1018 |
+
"import json\n",
|
| 1019 |
+
"import torch\n",
|
| 1020 |
+
"from transformers import AutoTokenizer\n",
|
| 1021 |
+
"\n",
|
| 1022 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1023 |
+
"tokenizer.pad_token = tokenizer.eos_token \n",
|
| 1024 |
+
"\n",
|
| 1025 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1026 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1027 |
+
" data = json.load(f)\n",
|
| 1028 |
+
"\n",
|
| 1029 |
+
"train_data = []\n",
|
| 1030 |
+
"\n",
|
| 1031 |
+
"\n",
|
| 1032 |
+
"for sample in data:\n",
|
| 1033 |
+
" conversations = sample.get(\"conversations\", [])\n",
|
| 1034 |
+
" embeddings = sample.get(\"embedding\", []) \n",
|
| 1035 |
+
"\n",
|
| 1036 |
+
" if not isinstance(embeddings, list) or len(embeddings) == 0:\n",
|
| 1037 |
+
" print(f\"无效的 embedding,跳过样本:{sample}\")\n",
|
| 1038 |
+
" continue\n",
|
| 1039 |
+
"\n",
|
| 1040 |
+
" graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
" #拼接所有对话\n",
|
| 1043 |
+
" dialogue_text = \"\"\n",
|
| 1044 |
+
" for conv in conversations:\n",
|
| 1045 |
+
" role = conv[\"from\"] # \"human\" 或 \"gpt\"\n",
|
| 1046 |
+
" content = conv[\"value\"]\n",
|
| 1047 |
+
" content = content.replace(\"<image>\", \"\") #去掉 <image>\n",
|
| 1048 |
+
" role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n",
|
| 1049 |
+
" dialogue_text += f\"{role_token} {content}\\n\"\n",
|
| 1050 |
+
"\n",
|
| 1051 |
+
" tokenized = tokenizer(\n",
|
| 1052 |
+
" dialogue_text,\n",
|
| 1053 |
+
" padding=\"max_length\",\n",
|
| 1054 |
+
" truncation=True,\n",
|
| 1055 |
+
" max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n",
|
| 1056 |
+
" return_tensors=\"pt\",\n",
|
| 1057 |
+
" )\n",
|
| 1058 |
+
"\n",
|
| 1059 |
+
" input_ids = tokenized[\"input_ids\"].squeeze(0)\n",
|
| 1060 |
+
" attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n",
|
| 1061 |
+
"\n",
|
| 1062 |
+
" train_data.append({\n",
|
| 1063 |
+
" \"input_ids\": input_ids,\n",
|
| 1064 |
+
" \"attention_mask\": attention_mask,\n",
|
| 1065 |
+
" \"labels\": input_ids.clone(),\n",
|
| 1066 |
+
" \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n",
|
| 1067 |
+
" })\n",
|
| 1068 |
+
"\n",
|
| 1069 |
+
"print(\"🚀 处理后数据条数:\", len(train_data))\n",
|
| 1070 |
+
"print(\"✅ 示例数据:\", train_data[0])\n",
|
| 1071 |
+
"torch.save(train_data, \"train_data.pt\")\n",
|
| 1072 |
+
"print(\"✅ train_data 已保存到 train_data.pt\")\n"
|
| 1073 |
+
]
|
| 1074 |
+
},
|
| 1075 |
+
{
|
| 1076 |
+
"cell_type": "code",
|
| 1077 |
+
"execution_count": 6,
|
| 1078 |
+
"id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d",
|
| 1079 |
+
"metadata": {},
|
| 1080 |
+
"outputs": [],
|
| 1081 |
+
"source": [
|
| 1082 |
+
"from transformers import AutoModelForCausalLM, AutoConfig\n",
|
| 1083 |
+
"import torch\n",
|
| 1084 |
+
"import torch.nn as nn\n",
|
| 1085 |
+
"\n",
|
| 1086 |
+
"class GraphAwareLM(AutoModelForCausalLM):\n",
|
| 1087 |
+
" def __init__(self, pretrained_model_name_or_path, num_heads=8):\n",
|
| 1088 |
+
" super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n",
|
| 1089 |
+
" \n",
|
| 1090 |
+
" # ✅ 载入 LLM 预训练模型\n",
|
| 1091 |
+
" self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n",
|
| 1092 |
+
"\n",
|
| 1093 |
+
" # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n",
|
| 1094 |
+
" self.linear1 = nn.Linear(512, self.config.hidden_size)\n",
|
| 1095 |
+
"\n",
|
| 1096 |
+
" # ✅ 2. 多头注意力层\n",
|
| 1097 |
+
" self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n",
|
| 1098 |
+
"\n",
|
| 1099 |
+
" # ✅ 3. 线性变换\n",
|
| 1100 |
+
" self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n",
|
| 1101 |
+
"\n",
|
| 1102 |
+
" # ✅ 4. 残差连接 + LayerNorm\n",
|
| 1103 |
+
" self.norm = nn.LayerNorm(self.config.hidden_size)\n",
|
| 1104 |
+
" \n",
|
| 1105 |
+
"\n",
|
| 1106 |
+
" def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n",
|
| 1107 |
+
" \"\"\"\n",
|
| 1108 |
+
" `graph_embedding` 形状: (batch_size, 512)\n",
|
| 1109 |
+
" `input_ids` 形状: (batch_size, seq_len)\n",
|
| 1110 |
+
" \"\"\"\n",
|
| 1111 |
+
" # ✅ 获取 token embedding\n",
|
| 1112 |
+
" inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n",
|
| 1113 |
+
"\n",
|
| 1114 |
+
" # ✅ 1. 线性变换 `graph_embedding`\n",
|
| 1115 |
+
" graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n",
|
| 1116 |
+
"\n",
|
| 1117 |
+
" # ✅ 2. 多头注意力计算(自注意力机制)\n",
|
| 1118 |
+
" attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
|
| 1119 |
+
" \n",
|
| 1120 |
+
" # ✅ 3. 线性层 + 残差连接\n",
|
| 1121 |
+
" graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n",
|
| 1122 |
+
"\n",
|
| 1123 |
+
" # ✅ 4. 归一化\n",
|
| 1124 |
+
" graph_embedding_token = self.norm(graph_embedding_token)\n",
|
| 1125 |
+
"\n",
|
| 1126 |
+
" # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n",
|
| 1127 |
+
" graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n",
|
| 1128 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n",
|
| 1129 |
+
"\n",
|
| 1130 |
+
" # ✅ 调整 attention mask\n",
|
| 1131 |
+
" if attention_mask is not None:\n",
|
| 1132 |
+
" graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n",
|
| 1133 |
+
" attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n",
|
| 1134 |
+
"\n",
|
| 1135 |
+
" # ✅ 传入模型\n",
|
| 1136 |
+
" outputs = self.model(\n",
|
| 1137 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1138 |
+
" attention_mask=attention_mask,\n",
|
| 1139 |
+
" labels=labels,\n",
|
| 1140 |
+
" )\n",
|
| 1141 |
+
"\n",
|
| 1142 |
+
" return outputs\n",
|
| 1143 |
+
"\n",
|
| 1144 |
+
" def generate(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
|
| 1145 |
+
" \"\"\"\n",
|
| 1146 |
+
" ✅ 自定义 `generate()` 方法,支持 `graph_embedding`\n",
|
| 1147 |
+
" `input_text`: 需要生成文本的输入\n",
|
| 1148 |
+
" `graph_embedding`: 形状为 (1, 512) 的张量\n",
|
| 1149 |
+
" \"\"\"\n",
|
| 1150 |
+
"\n",
|
| 1151 |
+
" # ✅ 2. 处理 `graph_embedding`\n",
|
| 1152 |
+
" graph_embedding_token = self.linear1(graph_embedding) # (1, 1, hidden_size)\n",
|
| 1153 |
+
" attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n",
|
| 1154 |
+
" graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (1, 1, hidden_size)\n",
|
| 1155 |
+
" graph_embedding_token = self.norm(graph_embedding_token)\n",
|
| 1156 |
+
"\n",
|
| 1157 |
+
" # ✅ 3. 获取 Token Embeddings 并拼接\n",
|
| 1158 |
+
" inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n",
|
| 1159 |
+
" inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n",
|
| 1160 |
+
"\n",
|
| 1161 |
+
" # ✅ 4. 调整 `attention_mask`\n",
|
| 1162 |
+
" if \"attention_mask\" in inputs:\n",
|
| 1163 |
+
" graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n",
|
| 1164 |
+
" attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n",
|
| 1165 |
+
" else:\n",
|
| 1166 |
+
" attention_mask = None\n",
|
| 1167 |
+
"\n",
|
| 1168 |
+
" # ✅ 5. 进行文本生成\n",
|
| 1169 |
+
" with torch.no_grad():\n",
|
| 1170 |
+
" output_ids = self.model.generate(\n",
|
| 1171 |
+
" inputs_embeds=inputs_embeds,\n",
|
| 1172 |
+
" attention_mask=attention_mask,\n",
|
| 1173 |
+
" max_length=max_length,\n",
|
| 1174 |
+
" temperature=temperature,\n",
|
| 1175 |
+
" top_k=top_k,\n",
|
| 1176 |
+
" top_p=top_p,\n",
|
| 1177 |
+
" num_return_sequences=1\n",
|
| 1178 |
+
" )\n",
|
| 1179 |
+
"\n",
|
| 1180 |
+
" # ✅ 6. 解码输出\n",
|
| 1181 |
+
" generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
|
| 1182 |
+
" return generated_text\n",
|
| 1183 |
+
"\n",
|
| 1184 |
+
" @classmethod\n",
|
| 1185 |
+
" def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
|
| 1186 |
+
" # ✅ 1. 调用 `super().from_pretrained()` 加载 LLM\n",
|
| 1187 |
+
" model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n",
|
| 1188 |
+
"\n",
|
| 1189 |
+
" # ✅ 2. 初始化 `MLP + MultiheadAttention` 结构\n",
|
| 1190 |
+
" model.linear1 = nn.Linear(512, model.config.hidden_size)\n",
|
| 1191 |
+
" model.multihead_attn = nn.MultiheadAttention(embed_dim=model.config.hidden_size, num_heads=8, batch_first=True)\n",
|
| 1192 |
+
" model.linear2 = nn.Linear(model.config.hidden_size, model.config.hidden_size)\n",
|
| 1193 |
+
" model.norm = nn.LayerNorm(model.config.hidden_size)\n",
|
| 1194 |
+
"\n",
|
| 1195 |
+
" return model"
|
| 1196 |
+
]
|
| 1197 |
+
},
|
| 1198 |
+
{
|
| 1199 |
+
"cell_type": "code",
|
| 1200 |
+
"execution_count": 2,
|
| 1201 |
+
"id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984",
|
| 1202 |
+
"metadata": {},
|
| 1203 |
+
"outputs": [],
|
| 1204 |
+
"source": [
|
| 1205 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
| 1206 |
+
]
|
| 1207 |
+
},
|
| 1208 |
+
{
|
| 1209 |
+
"cell_type": "code",
|
| 1210 |
+
"execution_count": 7,
|
| 1211 |
+
"id": "21c8df04-0dc2-436c-aaaf-74a885f734d9",
|
| 1212 |
+
"metadata": {},
|
| 1213 |
+
"outputs": [
|
| 1214 |
+
{
|
| 1215 |
+
"data": {
|
| 1216 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1217 |
+
"model_id": "0b50f0cd6d784f598cc64a40cff40f38",
|
| 1218 |
+
"version_major": 2,
|
| 1219 |
+
"version_minor": 0
|
| 1220 |
+
},
|
| 1221 |
+
"text/plain": [
|
| 1222 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1223 |
+
]
|
| 1224 |
+
},
|
| 1225 |
+
"metadata": {},
|
| 1226 |
+
"output_type": "display_data"
|
| 1227 |
+
},
|
| 1228 |
+
{
|
| 1229 |
+
"data": {
|
| 1230 |
+
"text/plain": [
|
| 1231 |
+
"Qwen2ForCausalLM(\n",
|
| 1232 |
+
" (model): Qwen2Model(\n",
|
| 1233 |
+
" (embed_tokens): Embedding(151936, 1536)\n",
|
| 1234 |
+
" (layers): ModuleList(\n",
|
| 1235 |
+
" (0-27): 28 x Qwen2DecoderLayer(\n",
|
| 1236 |
+
" (self_attn): Qwen2Attention(\n",
|
| 1237 |
+
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1238 |
+
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1239 |
+
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
|
| 1240 |
+
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
|
| 1241 |
+
" )\n",
|
| 1242 |
+
" (mlp): Qwen2MLP(\n",
|
| 1243 |
+
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1244 |
+
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
|
| 1245 |
+
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
|
| 1246 |
+
" (act_fn): SiLU()\n",
|
| 1247 |
+
" )\n",
|
| 1248 |
+
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1249 |
+
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1250 |
+
" )\n",
|
| 1251 |
+
" )\n",
|
| 1252 |
+
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
|
| 1253 |
+
" (rotary_emb): Qwen2RotaryEmbedding()\n",
|
| 1254 |
+
" )\n",
|
| 1255 |
+
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
|
| 1256 |
+
" (linear1): Linear(in_features=512, out_features=1536, bias=True)\n",
|
| 1257 |
+
" (multihead_attn): MultiheadAttention(\n",
|
| 1258 |
+
" (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1259 |
+
" )\n",
|
| 1260 |
+
" (linear2): Linear(in_features=1536, out_features=1536, bias=True)\n",
|
| 1261 |
+
" (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n",
|
| 1262 |
+
")"
|
| 1263 |
+
]
|
| 1264 |
+
},
|
| 1265 |
+
"execution_count": 7,
|
| 1266 |
+
"metadata": {},
|
| 1267 |
+
"output_type": "execute_result"
|
| 1268 |
+
}
|
| 1269 |
+
],
|
| 1270 |
+
"source": [
|
| 1271 |
+
"import torch\n",
|
| 1272 |
+
"from transformers import AutoTokenizer\n",
|
| 1273 |
+
"\n",
|
| 1274 |
+
"# 加载 tokenizer\n",
|
| 1275 |
+
"MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
|
| 1276 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 1277 |
+
"\n",
|
| 1278 |
+
"# 加载训练好的模型\n",
|
| 1279 |
+
"model_path = \"/workspace/model2\"\n",
|
| 1280 |
+
"model = GraphAwareLM.from_pretrained(\"/workspace/results3/checkpoint-3000\").to(device)\n",
|
| 1281 |
+
"model.eval() # 设置为推理模式\n"
|
| 1282 |
+
]
|
| 1283 |
+
},
|
| 1284 |
+
{
|
| 1285 |
+
"cell_type": "code",
|
| 1286 |
+
"execution_count": 13,
|
| 1287 |
+
"id": "51995891-8906-4049-9401-2d22e06a84e8",
|
| 1288 |
+
"metadata": {},
|
| 1289 |
+
"outputs": [
|
| 1290 |
+
{
|
| 1291 |
+
"name": "stdout",
|
| 1292 |
+
"output_type": "stream",
|
| 1293 |
+
"text": [
|
| 1294 |
+
"Parameter containing:\n",
|
| 1295 |
+
"tensor([[-0.0380, -0.0350, -0.0423, ..., 0.0213, 0.0148, -0.0047],\n",
|
| 1296 |
+
" [ 0.0131, 0.0388, -0.0378, ..., 0.0399, -0.0309, -0.0342],\n",
|
| 1297 |
+
" [ 0.0084, -0.0116, 0.0259, ..., 0.0344, 0.0268, -0.0062],\n",
|
| 1298 |
+
" ...,\n",
|
| 1299 |
+
" [ 0.0080, -0.0073, -0.0023, ..., -0.0120, 0.0387, 0.0209],\n",
|
| 1300 |
+
" [ 0.0277, 0.0326, 0.0270, ..., 0.0124, -0.0348, 0.0389],\n",
|
| 1301 |
+
" [ 0.0184, -0.0410, -0.0415, ..., 0.0255, -0.0429, -0.0386]],\n",
|
| 1302 |
+
" device='cuda:0', requires_grad=True)\n"
|
| 1303 |
+
]
|
| 1304 |
+
}
|
| 1305 |
+
],
|
| 1306 |
+
"source": [
|
| 1307 |
+
"print(model.graph_proj.weight)\n"
|
| 1308 |
+
]
|
| 1309 |
+
},
|
| 1310 |
+
{
|
| 1311 |
+
"cell_type": "code",
|
| 1312 |
+
"execution_count": 4,
|
| 1313 |
+
"id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9",
|
| 1314 |
+
"metadata": {},
|
| 1315 |
+
"outputs": [],
|
| 1316 |
+
"source": [
|
| 1317 |
+
"import json\n",
|
| 1318 |
+
"json_path = \"final_Graph.json\"\n",
|
| 1319 |
+
"with open(json_path, \"r\") as f:\n",
|
| 1320 |
+
" data = json.load(f)\n",
|
| 1321 |
+
"\n",
|
| 1322 |
+
"test_data = data[0]\n",
|
| 1323 |
+
"\n",
|
| 1324 |
+
"conversations = test_data.get(\"conversations\")\n",
|
| 1325 |
+
"embeddings = test_data.get(\"embedding\") \n",
|
| 1326 |
+
"\n",
|
| 1327 |
+
"graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0).to(device)\n",
|
| 1328 |
+
"\n",
|
| 1329 |
+
"question1 = conversations[0][\"value\"].replace(\"<image>\", \"\").strip()\n",
|
| 1330 |
+
"\n",
|
| 1331 |
+
"from transformers import AutoTokenizer\n",
|
| 1332 |
+
"\n",
|
| 1333 |
+
"# ✅ 输入文本\n",
|
| 1334 |
+
"ROLE_TOKENS = {\n",
|
| 1335 |
+
" \"human\": \"<|User|>\", \n",
|
| 1336 |
+
" \"gpt\": \"<|Assistant|>\", \n",
|
| 1337 |
+
"}\n",
|
| 1338 |
+
"GRAPH_LENGTH = 512\n",
|
| 1339 |
+
"max_seq_length = 1100 + GRAPH_LENGTH\n",
|
| 1340 |
+
"inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n",
|
| 1341 |
+
"\n",
|
| 1342 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 1343 |
+
"attention_mask = inputs[\"attention_mask\"]\n"
|
| 1344 |
+
]
|
| 1345 |
+
},
|
| 1346 |
+
{
|
| 1347 |
+
"cell_type": "code",
|
| 1348 |
+
"execution_count": 8,
|
| 1349 |
+
"id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc",
|
| 1350 |
+
"metadata": {},
|
| 1351 |
+
"outputs": [
|
| 1352 |
+
{
|
| 1353 |
+
"ename": "AttributeError",
|
| 1354 |
+
"evalue": "'Tensor' object has no attribute 'update'",
|
| 1355 |
+
"output_type": "error",
|
| 1356 |
+
"traceback": [
|
| 1357 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1358 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
| 1359 |
+
"Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1360 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 1361 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1982\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1979\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1980\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n\u001b[0;32m-> 1982\u001b[0m generation_config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_generation_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1983\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_model_kwargs(model_kwargs\u001b[38;5;241m.\u001b[39mcopy())\n\u001b[1;32m 1984\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_assistant(assistant_model, tokenizer, assistant_tokenizer)\n",
|
| 1362 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1549\u001b[0m, in \u001b[0;36mGenerationMixin._prepare_generation_config\u001b[0;34m(self, generation_config, **kwargs)\u001b[0m\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 1548\u001b[0m generation_config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(generation_config)\n\u001b[0;32m-> 1549\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1550\u001b[0m \u001b[38;5;66;03m# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model\u001b[39;00m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m using_model_generation_config:\n",
|
| 1363 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'update'"
|
| 1364 |
+
]
|
| 1365 |
+
}
|
| 1366 |
+
],
|
| 1367 |
+
"source": [
|
| 1368 |
+
"generated_text = model.generate(inputs, graph_embedding)"
|
| 1369 |
+
]
|
| 1370 |
+
},
|
| 1371 |
+
{
|
| 1372 |
+
"cell_type": "code",
|
| 1373 |
+
"execution_count": 5,
|
| 1374 |
+
"id": "62f40327-f102-4259-80a5-8761d5d7d3c6",
|
| 1375 |
+
"metadata": {},
|
| 1376 |
+
"outputs": [
|
| 1377 |
+
{
|
| 1378 |
+
"data": {
|
| 1379 |
+
"text/plain": [
|
| 1380 |
+
"tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n",
|
| 1381 |
+
" -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n",
|
| 1382 |
+
" -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n",
|
| 1383 |
+
" -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n",
|
| 1384 |
+
" -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n",
|
| 1385 |
+
" 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n",
|
| 1386 |
+
" -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n",
|
| 1387 |
+
" -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n",
|
| 1388 |
+
" -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n",
|
| 1389 |
+
" -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n",
|
| 1390 |
+
" 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n",
|
| 1391 |
+
" 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n",
|
| 1392 |
+
" -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n",
|
| 1393 |
+
" 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n",
|
| 1394 |
+
" -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n",
|
| 1395 |
+
" -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n",
|
| 1396 |
+
" -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n",
|
| 1397 |
+
" -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n",
|
| 1398 |
+
" -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n",
|
| 1399 |
+
" 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n",
|
| 1400 |
+
" 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n",
|
| 1401 |
+
" -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n",
|
| 1402 |
+
" 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n",
|
| 1403 |
+
" 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n",
|
| 1404 |
+
" -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n",
|
| 1405 |
+
" -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n",
|
| 1406 |
+
" -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n",
|
| 1407 |
+
" -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n",
|
| 1408 |
+
" -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n",
|
| 1409 |
+
" 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n",
|
| 1410 |
+
" -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n",
|
| 1411 |
+
" -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n",
|
| 1412 |
+
" -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n",
|
| 1413 |
+
" 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n",
|
| 1414 |
+
" -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n",
|
| 1415 |
+
" -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n",
|
| 1416 |
+
" 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n",
|
| 1417 |
+
" -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n",
|
| 1418 |
+
" -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n",
|
| 1419 |
+
" -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n",
|
| 1420 |
+
" -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n",
|
| 1421 |
+
" -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n",
|
| 1422 |
+
" 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n",
|
| 1423 |
+
" 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n",
|
| 1424 |
+
" -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n",
|
| 1425 |
+
" 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n",
|
| 1426 |
+
" -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n",
|
| 1427 |
+
" 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n",
|
| 1428 |
+
" -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n",
|
| 1429 |
+
" -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n",
|
| 1430 |
+
" 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n",
|
| 1431 |
+
" -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n",
|
| 1432 |
+
" 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n",
|
| 1433 |
+
" -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n",
|
| 1434 |
+
" -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n",
|
| 1435 |
+
" 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n",
|
| 1436 |
+
" 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n",
|
| 1437 |
+
" -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n",
|
| 1438 |
+
" 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n",
|
| 1439 |
+
" -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n",
|
| 1440 |
+
" -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n",
|
| 1441 |
+
" -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n",
|
| 1442 |
+
" -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n",
|
| 1443 |
+
" 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n",
|
| 1444 |
+
" device='cuda:0')"
|
| 1445 |
+
]
|
| 1446 |
+
},
|
| 1447 |
+
"execution_count": 5,
|
| 1448 |
+
"metadata": {},
|
| 1449 |
+
"output_type": "execute_result"
|
| 1450 |
+
}
|
| 1451 |
+
],
|
| 1452 |
+
"source": [
|
| 1453 |
+
"graph_embedding"
|
| 1454 |
+
]
|
| 1455 |
+
},
|
| 1456 |
+
{
|
| 1457 |
+
"cell_type": "code",
|
| 1458 |
+
"execution_count": 15,
|
| 1459 |
+
"id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b",
|
| 1460 |
+
"metadata": {},
|
| 1461 |
+
"outputs": [
|
| 1462 |
+
{
|
| 1463 |
+
"name": "stdout",
|
| 1464 |
+
"output_type": "stream",
|
| 1465 |
+
"text": [
|
| 1466 |
+
"模型回复: How\n"
|
| 1467 |
+
]
|
| 1468 |
+
}
|
| 1469 |
+
],
|
| 1470 |
+
"source": [
|
| 1471 |
+
"# ✅ 进行前向传播\n",
|
| 1472 |
+
"with torch.no_grad():\n",
|
| 1473 |
+
" outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n",
|
| 1474 |
+
"\n",
|
| 1475 |
+
"# ✅ 提取 logits 并进行贪心解码\n",
|
| 1476 |
+
"logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1477 |
+
"predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n",
|
| 1478 |
+
"\n",
|
| 1479 |
+
"# ✅ 反向编码为文本\n",
|
| 1480 |
+
"response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n",
|
| 1481 |
+
"\n",
|
| 1482 |
+
"print(\"模型回复:\", response_text)"
|
| 1483 |
+
]
|
| 1484 |
+
},
|
| 1485 |
+
{
|
| 1486 |
+
"cell_type": "code",
|
| 1487 |
+
"execution_count": 9,
|
| 1488 |
+
"id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef",
|
| 1489 |
+
"metadata": {},
|
| 1490 |
+
"outputs": [
|
| 1491 |
+
{
|
| 1492 |
+
"name": "stdout",
|
| 1493 |
+
"output_type": "stream",
|
| 1494 |
+
"text": [
|
| 1495 |
+
"Generated Response: What are the signal definitions in the Verilog code for the calculator module, and what are their purposes? The Verilog code defines the inputs A, B, and C, and the output Y. A and B are the operands, C is the carry-in, and Y is the result. The purpose of the module is to perform a 2-bit adder, which adds two 2-bit numbers, and the output is the sum. The inputs A and B are the operands, C is the carry-in, and Y is the result. The module is designed to handle the addition operation of two 2-bit numbers, with a carry-in, and a 3-bit output. The implementation involves using logic gates to perform the addition operation, with the sum output connected to the gates. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, involving basic gates and an adder circuit. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with\n"
|
| 1496 |
+
]
|
| 1497 |
+
}
|
| 1498 |
+
],
|
| 1499 |
+
"source": [
|
| 1500 |
+
"max_new_tokens = 500\n",
|
| 1501 |
+
"generated_ids = input_ids.clone()\n",
|
| 1502 |
+
"generated_attention_mask = attention_mask.clone()\n",
|
| 1503 |
+
"for _ in range(max_new_tokens):\n",
|
| 1504 |
+
" # ✅ 计算 logits 并进行生成\n",
|
| 1505 |
+
" with torch.no_grad():\n",
|
| 1506 |
+
" outputs = model(\n",
|
| 1507 |
+
" input_ids=generated_ids, # (batch_size, seq_len)\n",
|
| 1508 |
+
" attention_mask=generated_attention_mask, # (batch_size, seq_len)\n",
|
| 1509 |
+
" graph_embedding=graph_embedding, # (batch_size, 512)\n",
|
| 1510 |
+
" )\n",
|
| 1511 |
+
"\n",
|
| 1512 |
+
"\n",
|
| 1513 |
+
" logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n",
|
| 1514 |
+
" next_token = torch.argmax(logits, dim=-1) # 贪心解码\n",
|
| 1515 |
+
" # print(next_token)\n",
|
| 1516 |
+
"\n",
|
| 1517 |
+
"\n",
|
| 1518 |
+
" # ✅ **拼接到已生成序列**\n",
|
| 1519 |
+
" generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n",
|
| 1520 |
+
"\n",
|
| 1521 |
+
" # print(generated_ids)\n",
|
| 1522 |
+
"\n",
|
| 1523 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 1524 |
+
" break\n",
|
| 1525 |
+
"\n",
|
| 1526 |
+
" generated_attention_mask = torch.cat(\n",
|
| 1527 |
+
" [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n",
|
| 1528 |
+
" ) \n",
|
| 1529 |
+
"\n",
|
| 1530 |
+
"# ✅ 解码最终输出\n",
|
| 1531 |
+
"generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
|
| 1532 |
+
"print(\"Generated Response:\", generated_text)"
|
| 1533 |
+
]
|
| 1534 |
+
},
|
| 1535 |
+
{
|
| 1536 |
+
"cell_type": "code",
|
| 1537 |
+
"execution_count": 10,
|
| 1538 |
+
"id": "803f41fe-f504-4c2a-96b4-afc2cd437d01",
|
| 1539 |
+
"metadata": {},
|
| 1540 |
+
"outputs": [
|
| 1541 |
+
{
|
| 1542 |
+
"data": {
|
| 1543 |
+
"text/plain": [
|
| 1544 |
+
"tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n",
|
| 1545 |
+
" 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n",
|
| 1546 |
+
" 525, 862, 9895, 30]], device='cuda:0')"
|
| 1547 |
+
]
|
| 1548 |
+
},
|
| 1549 |
+
"execution_count": 10,
|
| 1550 |
+
"metadata": {},
|
| 1551 |
+
"output_type": "execute_result"
|
| 1552 |
+
}
|
| 1553 |
+
],
|
| 1554 |
+
"source": [
|
| 1555 |
+
"generated_ids"
|
| 1556 |
+
]
|
| 1557 |
+
},
|
| 1558 |
+
{
|
| 1559 |
+
"cell_type": "code",
|
| 1560 |
+
"execution_count": null,
|
| 1561 |
+
"id": "87d1396b-4d20-4a76-a092-b26a587a76ac",
|
| 1562 |
+
"metadata": {},
|
| 1563 |
+
"outputs": [],
|
| 1564 |
+
"source": []
|
| 1565 |
+
}
|
| 1566 |
+
],
|
| 1567 |
+
"metadata": {
|
| 1568 |
+
"kernelspec": {
|
| 1569 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1570 |
+
"language": "python",
|
| 1571 |
+
"name": "python3"
|
| 1572 |
+
},
|
| 1573 |
+
"language_info": {
|
| 1574 |
+
"codemirror_mode": {
|
| 1575 |
+
"name": "ipython",
|
| 1576 |
+
"version": 3
|
| 1577 |
+
},
|
| 1578 |
+
"file_extension": ".py",
|
| 1579 |
+
"mimetype": "text/x-python",
|
| 1580 |
+
"name": "python",
|
| 1581 |
+
"nbconvert_exporter": "python",
|
| 1582 |
+
"pygments_lexer": "ipython3",
|
| 1583 |
+
"version": "3.10.12"
|
| 1584 |
+
}
|
| 1585 |
+
},
|
| 1586 |
+
"nbformat": 4,
|
| 1587 |
+
"nbformat_minor": 5
|
| 1588 |
+
}
|
train_data.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c09c4c4be57cf268061c0f20f6e6d877359dd683fbff17f4d3a6b7cffee3dae
|
| 3 |
+
size 364711686
|