Upload fine-tuning-gpt-oss-derma.ipynb
Browse files- fine-tuning-gpt-oss-derma.ipynb +1277 -0
fine-tuning-gpt-oss-derma.ipynb
ADDED
@@ -0,0 +1,1277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## Setting Up"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"%%capture\n",
|
17 |
+
"%pip install -U accelerate \n",
|
18 |
+
"%pip install -U peft \n",
|
19 |
+
"%pip install -U trl \n",
|
20 |
+
"%pip install -U bitsandbytes\n",
|
21 |
+
"%pip install -U transformers\n",
|
22 |
+
"%pip install -U tensorboard\n",
|
23 |
+
"%pip install -U openai-harmony\n",
|
24 |
+
"%pip install -U tiktoken\n",
|
25 |
+
"%pip install -U pyctcdecode"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 2,
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [
|
33 |
+
{
|
34 |
+
"name": "stderr",
|
35 |
+
"output_type": "stream",
|
36 |
+
"text": [
|
37 |
+
"Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.\n"
|
38 |
+
]
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"source": [
|
42 |
+
"from huggingface_hub import login\n",
|
43 |
+
"import os\n",
|
44 |
+
"HF_TOKEN= os.getenv(\"HF_TOKEN\")\n",
|
45 |
+
"\n",
|
46 |
+
"login(HF_TOKEN)"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "markdown",
|
51 |
+
"metadata": {},
|
52 |
+
"source": [
|
53 |
+
"## Configs"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 3,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"BASE_MODEL_ID = \"openai/gpt-oss-20b\"\n",
|
63 |
+
"SAVED_MODEL_ID = \"gpt-oss-20b-dermatology-qa\"\n",
|
64 |
+
"DATASET_NAME = \"kingabzpro/dermatology-qa-firecrawl-dataset\""
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "markdown",
|
69 |
+
"metadata": {},
|
70 |
+
"source": [
|
71 |
+
"## Loading the model and tokenizer"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 4,
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"data": {
|
81 |
+
"application/vnd.jupyter.widget-view+json": {
|
82 |
+
"model_id": "1a21fc931ff7479896898c2bfaefaaa3",
|
83 |
+
"version_major": 2,
|
84 |
+
"version_minor": 0
|
85 |
+
},
|
86 |
+
"text/plain": [
|
87 |
+
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
"metadata": {},
|
91 |
+
"output_type": "display_data"
|
92 |
+
}
|
93 |
+
],
|
94 |
+
"source": [
|
95 |
+
"import torch\n",
|
96 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config\n",
|
97 |
+
"\n",
|
98 |
+
"quantization_config = Mxfp4Config(dequantize=True)\n",
|
99 |
+
"model_kwargs = dict(\n",
|
100 |
+
" attn_implementation=\"eager\",\n",
|
101 |
+
" torch_dtype=torch.bfloat16,\n",
|
102 |
+
" quantization_config=quantization_config,\n",
|
103 |
+
" use_cache=False,\n",
|
104 |
+
" device_map=\"auto\",\n",
|
105 |
+
")\n",
|
106 |
+
"\n",
|
107 |
+
"model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **model_kwargs)\n",
|
108 |
+
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "markdown",
|
113 |
+
"metadata": {},
|
114 |
+
"source": [
|
115 |
+
"## Loading and processing the dataset"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": 5,
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"from openai_harmony import (\n",
|
125 |
+
" Conversation,\n",
|
126 |
+
" DeveloperContent,\n",
|
127 |
+
" HarmonyEncodingName,\n",
|
128 |
+
" Message,\n",
|
129 |
+
" Role,\n",
|
130 |
+
" load_harmony_encoding,\n",
|
131 |
+
")\n",
|
132 |
+
"\n",
|
133 |
+
"# Load the Harmony encoder once\n",
|
134 |
+
"enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": 6,
|
140 |
+
"metadata": {},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"DERM_DEV_INSTRUCTIONS = (\n",
|
144 |
+
" \"You are a board-certified dermatologist answering various dermatology questions.\"\n",
|
145 |
+
" \" Answer clearly in 1–3 sentences. No speculation.\"\n",
|
146 |
+
")\n",
|
147 |
+
"\n",
|
148 |
+
"\n",
|
149 |
+
"def render_pair_harmony(question: str, answer: str) -> str:\n",
|
150 |
+
" \"\"\"Harmony-formatted prompt for training.\"\"\"\n",
|
151 |
+
" convo = Conversation.from_messages(\n",
|
152 |
+
" [\n",
|
153 |
+
" Message.from_role_and_content(\n",
|
154 |
+
" Role.DEVELOPER,\n",
|
155 |
+
" DeveloperContent.new().with_instructions(DERM_DEV_INSTRUCTIONS),\n",
|
156 |
+
" ),\n",
|
157 |
+
" Message.from_role_and_content(Role.USER, question.strip()),\n",
|
158 |
+
" Message.from_role_and_content(Role.ASSISTANT, answer.strip()),\n",
|
159 |
+
" ]\n",
|
160 |
+
" )\n",
|
161 |
+
" tokens = enc.render_conversation(convo)\n",
|
162 |
+
" return enc.decode(tokens)\n"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 7,
|
168 |
+
"metadata": {},
|
169 |
+
"outputs": [
|
170 |
+
{
|
171 |
+
"name": "stderr",
|
172 |
+
"output_type": "stream",
|
173 |
+
"text": [
|
174 |
+
"Parameter 'function'=<function to_harmony_batch at 0x7aa95e0afd80> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"data": {
|
179 |
+
"application/vnd.jupyter.widget-view+json": {
|
180 |
+
"model_id": "afb999e8af5849d4955eaf0a26fbe17f",
|
181 |
+
"version_major": 2,
|
182 |
+
"version_minor": 0
|
183 |
+
},
|
184 |
+
"text/plain": [
|
185 |
+
"Map: 0%| | 0/1001 [00:00<?, ? examples/s]"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
"metadata": {},
|
189 |
+
"output_type": "display_data"
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"name": "stdout",
|
193 |
+
"output_type": "stream",
|
194 |
+
"text": [
|
195 |
+
"DatasetDict({\n",
|
196 |
+
" train: Dataset({\n",
|
197 |
+
" features: ['question', 'answer', 'condition', 'difficulty', 'source_url', 'text'],\n",
|
198 |
+
" num_rows: 900\n",
|
199 |
+
" })\n",
|
200 |
+
" test: Dataset({\n",
|
201 |
+
" features: ['question', 'answer', 'condition', 'difficulty', 'source_url', 'text'],\n",
|
202 |
+
" num_rows: 101\n",
|
203 |
+
" })\n",
|
204 |
+
"})\n",
|
205 |
+
"<|start|>developer<|message|># Instructions\n",
|
206 |
+
"\n",
|
207 |
+
"You are a board-certified dermatologist answering various dermatology questions. Answer clearly in 1–3 sentences. No speculation.<|end|><|start|>user<|message|>What type of skin changes accompany the pustules in GPP?<|end|><|start|>assistant<|message|>The skin surrounding the pustules becomes erythematous, which means it appears red and inflamed. The affected skin is also painful. These changes occur during the recurrent flares of the disease.<|end|>\n"
|
208 |
+
]
|
209 |
+
}
|
210 |
+
],
|
211 |
+
"source": [
|
212 |
+
"from datasets import load_dataset\n",
|
213 |
+
"\n",
|
214 |
+
"# Load dataset\n",
|
215 |
+
"dataset = load_dataset(\"kingabzpro/dermatology-qa-firecrawl-dataset\", split=\"train\")\n",
|
216 |
+
"\n",
|
217 |
+
"\n",
|
218 |
+
"def to_harmony_batch(examples: dict) -> dict:\n",
|
219 |
+
" \"\"\"Convert batch of dermatology Q&A pairs to harmony format.\"\"\"\n",
|
220 |
+
" questions = examples[\"question\"]\n",
|
221 |
+
" answers = examples[\"answer\"]\n",
|
222 |
+
"\n",
|
223 |
+
" formatted_texts = []\n",
|
224 |
+
" for question, answer in zip(questions, answers):\n",
|
225 |
+
" formatted_text = render_pair_harmony(question.strip(), answer.strip())\n",
|
226 |
+
" formatted_texts.append(formatted_text)\n",
|
227 |
+
"\n",
|
228 |
+
" return {\"text\": formatted_texts}\n",
|
229 |
+
"\n",
|
230 |
+
"\n",
|
231 |
+
"# Process dataset\n",
|
232 |
+
"dataset = dataset.map(to_harmony_batch, batched=True)\n",
|
233 |
+
"dataset = dataset.train_test_split(test_size=0.1, seed=42)\n",
|
234 |
+
"\n",
|
235 |
+
"print(dataset)\n",
|
236 |
+
"\n",
|
237 |
+
"print(dataset[\"train\"][0][\"text\"])\n"
|
238 |
+
]
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"cell_type": "markdown",
|
242 |
+
"metadata": {},
|
243 |
+
"source": [
|
244 |
+
"## Model inference before fine-tuning"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": 8,
|
250 |
+
"metadata": {},
|
251 |
+
"outputs": [],
|
252 |
+
"source": [
|
253 |
+
"def render_inference_harmony(question: str) -> str:\n",
|
254 |
+
" \"\"\"Harmony-formatted prompt for inference.\"\"\"\n",
|
255 |
+
" convo = Conversation.from_messages(\n",
|
256 |
+
" [\n",
|
257 |
+
" Message.from_role_and_content(\n",
|
258 |
+
" Role.DEVELOPER,\n",
|
259 |
+
" DeveloperContent.new().with_instructions(DERM_DEV_INSTRUCTIONS),\n",
|
260 |
+
" ),\n",
|
261 |
+
" Message.from_role_and_content(Role.USER, question.strip()),\n",
|
262 |
+
" ]\n",
|
263 |
+
" )\n",
|
264 |
+
" tokens = enc.render_conversation_for_completion(convo, Role.ASSISTANT)\n",
|
265 |
+
" return enc.decode(tokens)\n"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": 9,
|
271 |
+
"metadata": {},
|
272 |
+
"outputs": [
|
273 |
+
{
|
274 |
+
"name": "stdout",
|
275 |
+
"output_type": "stream",
|
276 |
+
"text": [
|
277 |
+
"<|start|>developer<|message|># Instructions\n",
|
278 |
+
"\n",
|
279 |
+
"You are a board-certified dermatologist answering various dermatology questions. Answer clearly in 1–3 sentences. No speculation.<|end|><|start|>user<|message|>Why might winter be a problematic season for some people with eczema?<|end|><|start|>assistant<|channel|>analysis<|message|>They ask: \"Why might winter be a problematic season for some people with eczema?\" A dermatologist must answer succinctly. We'll provide reasons: cold, dry air, indoor heating increases dryness, reduces skin barrier, triggers flare-ups. Also less humidity helps dryness, exposure to indoor allergens, etc. Provide 1-3 sentences. Must be no speculation. Provide factual explanation. Must answer clearly.<|end|><|start|>assistant<|channel|>final<|message|>Winter can trigger eczema flare‑ups because cold, dry air and indoor heating strip the skin of moisture, compromising the skin barrier and making it more prone to irritation and infection. Lower humidity also increases scratching and can worsen inflammation, so maintaining skin hydration is particularly important during the colder months.<|return|>\n"
|
280 |
+
]
|
281 |
+
}
|
282 |
+
],
|
283 |
+
"source": [
|
284 |
+
"question = dataset[\"test\"][20][\"question\"]\n",
|
285 |
+
"\n",
|
286 |
+
"text = render_inference_harmony(question)\n",
|
287 |
+
"\n",
|
288 |
+
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
|
289 |
+
"outputs = model.generate(\n",
|
290 |
+
" input_ids=inputs.input_ids,\n",
|
291 |
+
" attention_mask=inputs.attention_mask,\n",
|
292 |
+
" max_new_tokens=200,\n",
|
293 |
+
" eos_token_id=tokenizer.eos_token_id,\n",
|
294 |
+
" use_cache=True,\n",
|
295 |
+
")\n",
|
296 |
+
"response = tokenizer.batch_decode(outputs)\n",
|
297 |
+
"print(response[0])\n"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"cell_type": "code",
|
302 |
+
"execution_count": 10,
|
303 |
+
"metadata": {},
|
304 |
+
"outputs": [
|
305 |
+
{
|
306 |
+
"name": "stdout",
|
307 |
+
"output_type": "stream",
|
308 |
+
"text": [
|
309 |
+
"Winter can trigger eczema flare‑ups because cold, dry air and indoor heating strip the skin of moisture, compromising the skin barrier and making it more prone to irritation and infection. Lower humidity also increases scratching and can worsen inflammation, so maintaining skin hydration is particularly important during the colder months.\n"
|
310 |
+
]
|
311 |
+
}
|
312 |
+
],
|
313 |
+
"source": [
|
314 |
+
"start_idx = response[0].find(\"<|start|>assistant<|channel|>final<|message|>\") + \\\n",
|
315 |
+
" len(\"<|start|>assistant<|channel|>final<|message|>\")\n",
|
316 |
+
"end_idx = response[0].rfind(\"<|return|>\") if \"<|return|>\" in response[0] else len(response[0])\n",
|
317 |
+
"final_answer = response[0][start_idx:end_idx].strip()\n",
|
318 |
+
"print(final_answer)\n"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"execution_count": 11,
|
324 |
+
"metadata": {},
|
325 |
+
"outputs": [
|
326 |
+
{
|
327 |
+
"data": {
|
328 |
+
"text/plain": [
|
329 |
+
"'During winter, indoor air tends to be dry, which can trigger eczema flare‑ups for some individuals. The dryness of indoor environments in winter is a known trigger for these patients.'"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
"execution_count": 11,
|
333 |
+
"metadata": {},
|
334 |
+
"output_type": "execute_result"
|
335 |
+
}
|
336 |
+
],
|
337 |
+
"source": [
|
338 |
+
"dataset[\"test\"][20][\"answer\"]"
|
339 |
+
]
|
340 |
+
},
|
341 |
+
{
|
342 |
+
"cell_type": "markdown",
|
343 |
+
"metadata": {},
|
344 |
+
"source": [
|
345 |
+
"## Setting up the model"
|
346 |
+
]
|
347 |
+
},
|
348 |
+
{
|
349 |
+
"cell_type": "code",
|
350 |
+
"execution_count": 12,
|
351 |
+
"metadata": {},
|
352 |
+
"outputs": [
|
353 |
+
{
|
354 |
+
"name": "stdout",
|
355 |
+
"output_type": "stream",
|
356 |
+
"text": [
|
357 |
+
"trainable params: 15,040,512 || all params: 20,929,797,696 || trainable%: 0.0719\n"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"name": "stderr",
|
362 |
+
"output_type": "stream",
|
363 |
+
"text": [
|
364 |
+
"/usr/local/lib/python3.11/dist-packages/peft/tuners/lora/layer.py:159: UserWarning: Unsupported layer type '<class 'transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts'>' encountered, proceed at your own risk.\n",
|
365 |
+
" warnings.warn(\n"
|
366 |
+
]
|
367 |
+
}
|
368 |
+
],
|
369 |
+
"source": [
|
370 |
+
"from peft import LoraConfig, get_peft_model\n",
|
371 |
+
"\n",
|
372 |
+
"peft_config = LoraConfig(\n",
|
373 |
+
" r=8,\n",
|
374 |
+
" lora_alpha=16,\n",
|
375 |
+
" target_modules=\"all-linear\",\n",
|
376 |
+
" target_parameters=[\n",
|
377 |
+
" \"7.mlp.experts.gate_up_proj\",\n",
|
378 |
+
" \"7.mlp.experts.down_proj\",\n",
|
379 |
+
" \"15.mlp.experts.gate_up_proj\",\n",
|
380 |
+
" \"15.mlp.experts.down_proj\",\n",
|
381 |
+
" \"23.mlp.experts.gate_up_proj\",\n",
|
382 |
+
" \"23.mlp.experts.down_proj\",\n",
|
383 |
+
" ],\n",
|
384 |
+
")\n",
|
385 |
+
"peft_model = get_peft_model(model, peft_config)\n",
|
386 |
+
"peft_model.print_trainable_parameters()"
|
387 |
+
]
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"execution_count": 13,
|
392 |
+
"metadata": {},
|
393 |
+
"outputs": [],
|
394 |
+
"source": [
|
395 |
+
"from trl import SFTConfig\n",
|
396 |
+
"\n",
|
397 |
+
"training_args = SFTConfig(\n",
|
398 |
+
" learning_rate=2e-4,\n",
|
399 |
+
" gradient_checkpointing=True,\n",
|
400 |
+
" num_train_epochs=1,\n",
|
401 |
+
" logging_steps=10,\n",
|
402 |
+
" bf16=True,\n",
|
403 |
+
" per_device_train_batch_size=8,\n",
|
404 |
+
" per_device_eval_batch_size=8,\n",
|
405 |
+
" gradient_accumulation_steps=2,\n",
|
406 |
+
" max_length=2048,\n",
|
407 |
+
" warmup_ratio=0.03,\n",
|
408 |
+
" eval_strategy=\"steps\",\n",
|
409 |
+
" eval_steps=10,\n",
|
410 |
+
" lr_scheduler_type=\"cosine_with_min_lr\",\n",
|
411 |
+
" lr_scheduler_kwargs={\"min_lr_rate\": 0.1},\n",
|
412 |
+
" output_dir=SAVED_MODEL_ID,\n",
|
413 |
+
" report_to=\"tensorboard\",\n",
|
414 |
+
" push_to_hub=True,\n",
|
415 |
+
")"
|
416 |
+
]
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"cell_type": "markdown",
|
420 |
+
"metadata": {},
|
421 |
+
"source": [
|
422 |
+
"## Model Training"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "code",
|
427 |
+
"execution_count": 14,
|
428 |
+
"metadata": {},
|
429 |
+
"outputs": [
|
430 |
+
{
|
431 |
+
"data": {
|
432 |
+
"application/vnd.jupyter.widget-view+json": {
|
433 |
+
"model_id": "78071bc85662408ebcc727e8de0f9ec7",
|
434 |
+
"version_major": 2,
|
435 |
+
"version_minor": 0
|
436 |
+
},
|
437 |
+
"text/plain": [
|
438 |
+
"Adding EOS to train dataset: 0%| | 0/900 [00:00<?, ? examples/s]"
|
439 |
+
]
|
440 |
+
},
|
441 |
+
"metadata": {},
|
442 |
+
"output_type": "display_data"
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"data": {
|
446 |
+
"application/vnd.jupyter.widget-view+json": {
|
447 |
+
"model_id": "91c71a17c98b4465b5573711b65aaca3",
|
448 |
+
"version_major": 2,
|
449 |
+
"version_minor": 0
|
450 |
+
},
|
451 |
+
"text/plain": [
|
452 |
+
"Tokenizing train dataset: 0%| | 0/900 [00:00<?, ? examples/s]"
|
453 |
+
]
|
454 |
+
},
|
455 |
+
"metadata": {},
|
456 |
+
"output_type": "display_data"
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"data": {
|
460 |
+
"application/vnd.jupyter.widget-view+json": {
|
461 |
+
"model_id": "c1de460d7e7b4fbc96c446cfa20ec591",
|
462 |
+
"version_major": 2,
|
463 |
+
"version_minor": 0
|
464 |
+
},
|
465 |
+
"text/plain": [
|
466 |
+
"Truncating train dataset: 0%| | 0/900 [00:00<?, ? examples/s]"
|
467 |
+
]
|
468 |
+
},
|
469 |
+
"metadata": {},
|
470 |
+
"output_type": "display_data"
|
471 |
+
},
|
472 |
+
{
|
473 |
+
"data": {
|
474 |
+
"application/vnd.jupyter.widget-view+json": {
|
475 |
+
"model_id": "e5f57bfd84064385b63a4df24da2c93d",
|
476 |
+
"version_major": 2,
|
477 |
+
"version_minor": 0
|
478 |
+
},
|
479 |
+
"text/plain": [
|
480 |
+
"Adding EOS to eval dataset: 0%| | 0/101 [00:00<?, ? examples/s]"
|
481 |
+
]
|
482 |
+
},
|
483 |
+
"metadata": {},
|
484 |
+
"output_type": "display_data"
|
485 |
+
},
|
486 |
+
{
|
487 |
+
"data": {
|
488 |
+
"application/vnd.jupyter.widget-view+json": {
|
489 |
+
"model_id": "cbea58e54f7949c39deea56ceb637266",
|
490 |
+
"version_major": 2,
|
491 |
+
"version_minor": 0
|
492 |
+
},
|
493 |
+
"text/plain": [
|
494 |
+
"Tokenizing eval dataset: 0%| | 0/101 [00:00<?, ? examples/s]"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
"metadata": {},
|
498 |
+
"output_type": "display_data"
|
499 |
+
},
|
500 |
+
{
|
501 |
+
"data": {
|
502 |
+
"application/vnd.jupyter.widget-view+json": {
|
503 |
+
"model_id": "3d3f75508723415f870bdbac32c88ea3",
|
504 |
+
"version_major": 2,
|
505 |
+
"version_minor": 0
|
506 |
+
},
|
507 |
+
"text/plain": [
|
508 |
+
"Truncating eval dataset: 0%| | 0/101 [00:00<?, ? examples/s]"
|
509 |
+
]
|
510 |
+
},
|
511 |
+
"metadata": {},
|
512 |
+
"output_type": "display_data"
|
513 |
+
},
|
514 |
+
{
|
515 |
+
"data": {
|
516 |
+
"text/html": [
|
517 |
+
"\n",
|
518 |
+
" <div>\n",
|
519 |
+
" \n",
|
520 |
+
" <progress value='57' max='57' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
521 |
+
" [57/57 06:35, Epoch 1/1]\n",
|
522 |
+
" </div>\n",
|
523 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
524 |
+
" <thead>\n",
|
525 |
+
" <tr style=\"text-align: left;\">\n",
|
526 |
+
" <th>Step</th>\n",
|
527 |
+
" <th>Training Loss</th>\n",
|
528 |
+
" <th>Validation Loss</th>\n",
|
529 |
+
" </tr>\n",
|
530 |
+
" </thead>\n",
|
531 |
+
" <tbody>\n",
|
532 |
+
" <tr>\n",
|
533 |
+
" <td>10</td>\n",
|
534 |
+
" <td>4.970200</td>\n",
|
535 |
+
" <td>2.089929</td>\n",
|
536 |
+
" </tr>\n",
|
537 |
+
" <tr>\n",
|
538 |
+
" <td>20</td>\n",
|
539 |
+
" <td>1.454900</td>\n",
|
540 |
+
" <td>0.981939</td>\n",
|
541 |
+
" </tr>\n",
|
542 |
+
" <tr>\n",
|
543 |
+
" <td>30</td>\n",
|
544 |
+
" <td>0.871900</td>\n",
|
545 |
+
" <td>0.867009</td>\n",
|
546 |
+
" </tr>\n",
|
547 |
+
" <tr>\n",
|
548 |
+
" <td>40</td>\n",
|
549 |
+
" <td>0.830900</td>\n",
|
550 |
+
" <td>0.836862</td>\n",
|
551 |
+
" </tr>\n",
|
552 |
+
" <tr>\n",
|
553 |
+
" <td>50</td>\n",
|
554 |
+
" <td>0.845000</td>\n",
|
555 |
+
" <td>0.823363</td>\n",
|
556 |
+
" </tr>\n",
|
557 |
+
" </tbody>\n",
|
558 |
+
"</table><p>"
|
559 |
+
],
|
560 |
+
"text/plain": [
|
561 |
+
"<IPython.core.display.HTML object>"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
"metadata": {},
|
565 |
+
"output_type": "display_data"
|
566 |
+
},
|
567 |
+
{
|
568 |
+
"data": {
|
569 |
+
"text/plain": [
|
570 |
+
"TrainOutput(global_step=57, training_loss=1.6722588371812253, metrics={'train_runtime': 402.1596, 'train_samples_per_second': 2.238, 'train_steps_per_second': 0.142, 'total_flos': 1.1781569356397568e+16, 'train_loss': 1.6722588371812253})"
|
571 |
+
]
|
572 |
+
},
|
573 |
+
"execution_count": 14,
|
574 |
+
"metadata": {},
|
575 |
+
"output_type": "execute_result"
|
576 |
+
}
|
577 |
+
],
|
578 |
+
"source": [
|
579 |
+
"from trl import SFTTrainer\n",
|
580 |
+
"\n",
|
581 |
+
"trainer = SFTTrainer(\n",
|
582 |
+
" model=peft_model,\n",
|
583 |
+
" args=training_args,\n",
|
584 |
+
" train_dataset=dataset[\"train\"],\n",
|
585 |
+
" eval_dataset=dataset[\"test\"],\n",
|
586 |
+
" processing_class=tokenizer,\n",
|
587 |
+
")\n",
|
588 |
+
"trainer.train()"
|
589 |
+
]
|
590 |
+
},
|
591 |
+
{
|
592 |
+
"cell_type": "code",
|
593 |
+
"execution_count": 15,
|
594 |
+
"metadata": {},
|
595 |
+
"outputs": [
|
596 |
+
{
|
597 |
+
"data": {
|
598 |
+
"application/vnd.jupyter.widget-view+json": {
|
599 |
+
"model_id": "217e9aaff6574d159a1a988ff583ff3c",
|
600 |
+
"version_major": 2,
|
601 |
+
"version_minor": 0
|
602 |
+
},
|
603 |
+
"text/plain": [
|
604 |
+
"Processing Files (0 / 0) : | | 0.00B / 0.00B "
|
605 |
+
]
|
606 |
+
},
|
607 |
+
"metadata": {},
|
608 |
+
"output_type": "display_data"
|
609 |
+
},
|
610 |
+
{
|
611 |
+
"data": {
|
612 |
+
"application/vnd.jupyter.widget-view+json": {
|
613 |
+
"model_id": "9a751f3aff01463499df49d9975252a8",
|
614 |
+
"version_major": 2,
|
615 |
+
"version_minor": 0
|
616 |
+
},
|
617 |
+
"text/plain": [
|
618 |
+
"New Data Upload : | | 0.00B / 0.00B "
|
619 |
+
]
|
620 |
+
},
|
621 |
+
"metadata": {},
|
622 |
+
"output_type": "display_data"
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"data": {
|
626 |
+
"application/vnd.jupyter.widget-view+json": {
|
627 |
+
"model_id": "a37515d0d5444c4194d40b60517f89ac",
|
628 |
+
"version_major": 2,
|
629 |
+
"version_minor": 0
|
630 |
+
},
|
631 |
+
"text/plain": [
|
632 |
+
" ...events.1756035182.51336aae7b3d.93.0: 100%|##########| 6.77kB / 6.77kB "
|
633 |
+
]
|
634 |
+
},
|
635 |
+
"metadata": {},
|
636 |
+
"output_type": "display_data"
|
637 |
+
},
|
638 |
+
{
|
639 |
+
"data": {
|
640 |
+
"application/vnd.jupyter.widget-view+json": {
|
641 |
+
"model_id": "08b221ee56ce42c7ae1e4b7c97e54b08",
|
642 |
+
"version_major": 2,
|
643 |
+
"version_minor": 0
|
644 |
+
},
|
645 |
+
"text/plain": [
|
646 |
+
" ...events.1756035572.51336aae7b3d.93.1: 100%|##########| 10.3kB / 10.3kB "
|
647 |
+
]
|
648 |
+
},
|
649 |
+
"metadata": {},
|
650 |
+
"output_type": "display_data"
|
651 |
+
},
|
652 |
+
{
|
653 |
+
"data": {
|
654 |
+
"application/vnd.jupyter.widget-view+json": {
|
655 |
+
"model_id": "a1dd80827c8240cca48b2f62624678ec",
|
656 |
+
"version_major": 2,
|
657 |
+
"version_minor": 0
|
658 |
+
},
|
659 |
+
"text/plain": [
|
660 |
+
" ...ents.1756037571.51336aae7b3d.2438.0: 100%|##########| 6.45kB / 6.45kB "
|
661 |
+
]
|
662 |
+
},
|
663 |
+
"metadata": {},
|
664 |
+
"output_type": "display_data"
|
665 |
+
},
|
666 |
+
{
|
667 |
+
"data": {
|
668 |
+
"application/vnd.jupyter.widget-view+json": {
|
669 |
+
"model_id": "be6333ce8ef3415db35a755e930aa85c",
|
670 |
+
"version_major": 2,
|
671 |
+
"version_minor": 0
|
672 |
+
},
|
673 |
+
"text/plain": [
|
674 |
+
" ...ents.1756037458.51336aae7b3d.2174.0: 100%|##########| 6.45kB / 6.45kB "
|
675 |
+
]
|
676 |
+
},
|
677 |
+
"metadata": {},
|
678 |
+
"output_type": "display_data"
|
679 |
+
},
|
680 |
+
{
|
681 |
+
"data": {
|
682 |
+
"application/vnd.jupyter.widget-view+json": {
|
683 |
+
"model_id": "e37affcd1ca64ee0b4e206cbe244a117",
|
684 |
+
"version_major": 2,
|
685 |
+
"version_minor": 0
|
686 |
+
},
|
687 |
+
"text/plain": [
|
688 |
+
" ...ents.1756037652.51336aae7b3d.2641.0: 100%|##########| 10.4kB / 10.4kB "
|
689 |
+
]
|
690 |
+
},
|
691 |
+
"metadata": {},
|
692 |
+
"output_type": "display_data"
|
693 |
+
},
|
694 |
+
{
|
695 |
+
"data": {
|
696 |
+
"application/vnd.jupyter.widget-view+json": {
|
697 |
+
"model_id": "4062cbde5d2a46c1bc3b55556c10bcf2",
|
698 |
+
"version_major": 2,
|
699 |
+
"version_minor": 0
|
700 |
+
},
|
701 |
+
"text/plain": [
|
702 |
+
" ...0b-dermatology-qa/training_args.bin: 100%|##########| 6.16kB / 6.16kB "
|
703 |
+
]
|
704 |
+
},
|
705 |
+
"metadata": {},
|
706 |
+
"output_type": "display_data"
|
707 |
+
},
|
708 |
+
{
|
709 |
+
"data": {
|
710 |
+
"application/vnd.jupyter.widget-view+json": {
|
711 |
+
"model_id": "5bbb5e92548e430c98d9fde415d36c82",
|
712 |
+
"version_major": 2,
|
713 |
+
"version_minor": 0
|
714 |
+
},
|
715 |
+
"text/plain": [
|
716 |
+
" ...s-20b-dermatology-qa/tokenizer.json: 100%|##########| 27.9MB / 27.9MB "
|
717 |
+
]
|
718 |
+
},
|
719 |
+
"metadata": {},
|
720 |
+
"output_type": "display_data"
|
721 |
+
},
|
722 |
+
{
|
723 |
+
"data": {
|
724 |
+
"application/vnd.jupyter.widget-view+json": {
|
725 |
+
"model_id": "53c63643ee9643f4af1d8016c8dbc32f",
|
726 |
+
"version_major": 2,
|
727 |
+
"version_minor": 0
|
728 |
+
},
|
729 |
+
"text/plain": [
|
730 |
+
" ...tology-qa/adapter_model.safetensors: 70%|######9 | 41.8MB / 60.2MB "
|
731 |
+
]
|
732 |
+
},
|
733 |
+
"metadata": {},
|
734 |
+
"output_type": "display_data"
|
735 |
+
},
|
736 |
+
{
|
737 |
+
"name": "stderr",
|
738 |
+
"output_type": "stream",
|
739 |
+
"text": [
|
740 |
+
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
|
741 |
+
]
|
742 |
+
},
|
743 |
+
{
|
744 |
+
"data": {
|
745 |
+
"application/vnd.jupyter.widget-view+json": {
|
746 |
+
"model_id": "29576445acd942f499a65f3687e25eaf",
|
747 |
+
"version_major": 2,
|
748 |
+
"version_minor": 0
|
749 |
+
},
|
750 |
+
"text/plain": [
|
751 |
+
"Processing Files (0 / 0) : | | 0.00B / 0.00B "
|
752 |
+
]
|
753 |
+
},
|
754 |
+
"metadata": {},
|
755 |
+
"output_type": "display_data"
|
756 |
+
},
|
757 |
+
{
|
758 |
+
"data": {
|
759 |
+
"application/vnd.jupyter.widget-view+json": {
|
760 |
+
"model_id": "7c133dd9961643bd9ed042dfe7f67a12",
|
761 |
+
"version_major": 2,
|
762 |
+
"version_minor": 0
|
763 |
+
},
|
764 |
+
"text/plain": [
|
765 |
+
"New Data Upload : | | 0.00B / 0.00B "
|
766 |
+
]
|
767 |
+
},
|
768 |
+
"metadata": {},
|
769 |
+
"output_type": "display_data"
|
770 |
+
},
|
771 |
+
{
|
772 |
+
"data": {
|
773 |
+
"application/vnd.jupyter.widget-view+json": {
|
774 |
+
"model_id": "016b81a696b3490589db33a5a92dd7eb",
|
775 |
+
"version_major": 2,
|
776 |
+
"version_minor": 0
|
777 |
+
},
|
778 |
+
"text/plain": [
|
779 |
+
" ...events.1756035182.51336aae7b3d.93.0: 100%|##########| 6.77kB / 6.77kB "
|
780 |
+
]
|
781 |
+
},
|
782 |
+
"metadata": {},
|
783 |
+
"output_type": "display_data"
|
784 |
+
},
|
785 |
+
{
|
786 |
+
"data": {
|
787 |
+
"application/vnd.jupyter.widget-view+json": {
|
788 |
+
"model_id": "60e93471be1144f0a056dee1f2dbdcbe",
|
789 |
+
"version_major": 2,
|
790 |
+
"version_minor": 0
|
791 |
+
},
|
792 |
+
"text/plain": [
|
793 |
+
" ...events.1756035572.51336aae7b3d.93.1: 100%|##########| 10.3kB / 10.3kB "
|
794 |
+
]
|
795 |
+
},
|
796 |
+
"metadata": {},
|
797 |
+
"output_type": "display_data"
|
798 |
+
},
|
799 |
+
{
|
800 |
+
"data": {
|
801 |
+
"application/vnd.jupyter.widget-view+json": {
|
802 |
+
"model_id": "e7bbd3b9fc604f9289257cbe876ccda6",
|
803 |
+
"version_major": 2,
|
804 |
+
"version_minor": 0
|
805 |
+
},
|
806 |
+
"text/plain": [
|
807 |
+
" ...ents.1756037458.51336aae7b3d.2174.0: 100%|##########| 6.45kB / 6.45kB "
|
808 |
+
]
|
809 |
+
},
|
810 |
+
"metadata": {},
|
811 |
+
"output_type": "display_data"
|
812 |
+
},
|
813 |
+
{
|
814 |
+
"data": {
|
815 |
+
"application/vnd.jupyter.widget-view+json": {
|
816 |
+
"model_id": "ee7018041db44537b4d9e16909768f5f",
|
817 |
+
"version_major": 2,
|
818 |
+
"version_minor": 0
|
819 |
+
},
|
820 |
+
"text/plain": [
|
821 |
+
" ...ents.1756037571.51336aae7b3d.2438.0: 100%|##########| 6.45kB / 6.45kB "
|
822 |
+
]
|
823 |
+
},
|
824 |
+
"metadata": {},
|
825 |
+
"output_type": "display_data"
|
826 |
+
},
|
827 |
+
{
|
828 |
+
"data": {
|
829 |
+
"application/vnd.jupyter.widget-view+json": {
|
830 |
+
"model_id": "9375111063c740509aa2046131404269",
|
831 |
+
"version_major": 2,
|
832 |
+
"version_minor": 0
|
833 |
+
},
|
834 |
+
"text/plain": [
|
835 |
+
" ...ents.1756037652.51336aae7b3d.2641.0: 100%|##########| 10.4kB / 10.4kB "
|
836 |
+
]
|
837 |
+
},
|
838 |
+
"metadata": {},
|
839 |
+
"output_type": "display_data"
|
840 |
+
},
|
841 |
+
{
|
842 |
+
"data": {
|
843 |
+
"application/vnd.jupyter.widget-view+json": {
|
844 |
+
"model_id": "cfe5771010214e9197c21f990e2840da",
|
845 |
+
"version_major": 2,
|
846 |
+
"version_minor": 0
|
847 |
+
},
|
848 |
+
"text/plain": [
|
849 |
+
" ...0b-dermatology-qa/training_args.bin: 100%|##########| 6.16kB / 6.16kB "
|
850 |
+
]
|
851 |
+
},
|
852 |
+
"metadata": {},
|
853 |
+
"output_type": "display_data"
|
854 |
+
},
|
855 |
+
{
|
856 |
+
"data": {
|
857 |
+
"application/vnd.jupyter.widget-view+json": {
|
858 |
+
"model_id": "da6a258f9d464998a5b1ae264121f97a",
|
859 |
+
"version_major": 2,
|
860 |
+
"version_minor": 0
|
861 |
+
},
|
862 |
+
"text/plain": [
|
863 |
+
" ...tology-qa/adapter_model.safetensors: 56%|#####5 | 33.6MB / 60.2MB "
|
864 |
+
]
|
865 |
+
},
|
866 |
+
"metadata": {},
|
867 |
+
"output_type": "display_data"
|
868 |
+
},
|
869 |
+
{
|
870 |
+
"data": {
|
871 |
+
"application/vnd.jupyter.widget-view+json": {
|
872 |
+
"model_id": "931788c4fdfb4a47ac8b8cefe36cd54b",
|
873 |
+
"version_major": 2,
|
874 |
+
"version_minor": 0
|
875 |
+
},
|
876 |
+
"text/plain": [
|
877 |
+
" ...s-20b-dermatology-qa/tokenizer.json: 100%|##########| 27.9MB / 27.9MB "
|
878 |
+
]
|
879 |
+
},
|
880 |
+
"metadata": {},
|
881 |
+
"output_type": "display_data"
|
882 |
+
},
|
883 |
+
{
|
884 |
+
"data": {
|
885 |
+
"text/plain": [
|
886 |
+
"CommitInfo(commit_url='https://huggingface.co/kingabzpro/gpt-oss-20b-dermatology-qa/commit/b1706fde1cbc1942ccf763061fa31b22e5b61cc6', commit_message='End of training', commit_description='', oid='b1706fde1cbc1942ccf763061fa31b22e5b61cc6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/kingabzpro/gpt-oss-20b-dermatology-qa', endpoint='https://huggingface.co', repo_type='model', repo_id='kingabzpro/gpt-oss-20b-dermatology-qa'), pr_revision=None, pr_num=None)"
|
887 |
+
]
|
888 |
+
},
|
889 |
+
"execution_count": 15,
|
890 |
+
"metadata": {},
|
891 |
+
"output_type": "execute_result"
|
892 |
+
}
|
893 |
+
],
|
894 |
+
"source": [
|
895 |
+
"trainer.save_model(SAVED_MODEL_ID)\n",
|
896 |
+
"trainer.push_to_hub(dataset_name=SAVED_MODEL_ID)"
|
897 |
+
]
|
898 |
+
},
|
899 |
+
{
|
900 |
+
"cell_type": "markdown",
|
901 |
+
"metadata": {},
|
902 |
+
"source": [
|
903 |
+
"## Model inference after fine-tuning"
|
904 |
+
]
|
905 |
+
},
|
906 |
+
{
|
907 |
+
"cell_type": "code",
|
908 |
+
"execution_count": 1,
|
909 |
+
"metadata": {},
|
910 |
+
"outputs": [
|
911 |
+
{
|
912 |
+
"name": "stderr",
|
913 |
+
"output_type": "stream",
|
914 |
+
"text": [
|
915 |
+
"MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16\n"
|
916 |
+
]
|
917 |
+
},
|
918 |
+
{
|
919 |
+
"data": {
|
920 |
+
"application/vnd.jupyter.widget-view+json": {
|
921 |
+
"model_id": "dc0cfff4f43a4675b78186cdf5625377",
|
922 |
+
"version_major": 2,
|
923 |
+
"version_minor": 0
|
924 |
+
},
|
925 |
+
"text/plain": [
|
926 |
+
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
927 |
+
]
|
928 |
+
},
|
929 |
+
"metadata": {},
|
930 |
+
"output_type": "display_data"
|
931 |
+
},
|
932 |
+
{
|
933 |
+
"name": "stderr",
|
934 |
+
"output_type": "stream",
|
935 |
+
"text": [
|
936 |
+
"/usr/local/lib/python3.11/dist-packages/peft/tuners/lora/layer.py:159: UserWarning: Unsupported layer type '<class 'transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts'>' encountered, proceed at your own risk.\n",
|
937 |
+
" warnings.warn(\n"
|
938 |
+
]
|
939 |
+
}
|
940 |
+
],
|
941 |
+
"source": [
|
942 |
+
"from peft import PeftModel\n",
|
943 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
944 |
+
"\n",
|
945 |
+
"BASE_MODEL_ID = \"openai/gpt-oss-20b\"\n",
|
946 |
+
"SAVED_LORA_MODEL_ID = \"gpt-oss-20b-dermatology-qa\"\n",
|
947 |
+
"\n",
|
948 |
+
"# Load the tokenizer\n",
|
949 |
+
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)\n",
|
950 |
+
"\n",
|
951 |
+
"# Load the original model first\n",
|
952 |
+
"model_kwargs = dict(\n",
|
953 |
+
" attn_implementation=\"eager\", torch_dtype=\"auto\", use_cache=True, device_map=\"cuda\"\n",
|
954 |
+
")\n",
|
955 |
+
"base_model = AutoModelForCausalLM.from_pretrained(\n",
|
956 |
+
" BASE_MODEL_ID, **model_kwargs\n",
|
957 |
+
")\n",
|
958 |
+
"\n",
|
959 |
+
"# Merge fine-tuned weights with the base model\n",
|
960 |
+
"model = PeftModel.from_pretrained(base_model, SAVED_LORA_MODEL_ID)\n",
|
961 |
+
"model = model.merge_and_unload()"
|
962 |
+
]
|
963 |
+
},
|
964 |
+
{
|
965 |
+
"cell_type": "code",
|
966 |
+
"execution_count": 2,
|
967 |
+
"metadata": {},
|
968 |
+
"outputs": [],
|
969 |
+
"source": [
|
970 |
+
"from datasets import load_dataset\n",
|
971 |
+
"\n",
|
972 |
+
"# Load dataset\n",
|
973 |
+
"dataset = load_dataset(\"kingabzpro/dermatology-qa-firecrawl-dataset\", split=\"train\")\n",
|
974 |
+
"dataset = dataset.train_test_split(test_size=0.1, seed=42)"
|
975 |
+
]
|
976 |
+
},
|
977 |
+
{
|
978 |
+
"cell_type": "code",
|
979 |
+
"execution_count": 9,
|
980 |
+
"metadata": {},
|
981 |
+
"outputs": [],
|
982 |
+
"source": [
|
983 |
+
"from openai_harmony import (\n",
|
984 |
+
" Conversation,\n",
|
985 |
+
" DeveloperContent,\n",
|
986 |
+
" HarmonyEncodingName,\n",
|
987 |
+
" Message,\n",
|
988 |
+
" Role,\n",
|
989 |
+
" load_harmony_encoding,\n",
|
990 |
+
")\n",
|
991 |
+
"\n",
|
992 |
+
"# Load the Harmony encoder once\n",
|
993 |
+
"enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n",
|
994 |
+
"\n",
|
995 |
+
"DERM_DEV_INSTRUCTIONS = (\n",
|
996 |
+
" \"You are a board-certified dermatologist answering various dermatology questions.\"\n",
|
997 |
+
" \" Answer clearly in 1–3 sentences. No speculation.\"\n",
|
998 |
+
")\n",
|
999 |
+
"\n",
|
1000 |
+
"def render_inference_harmony(question: str) -> str:\n",
|
1001 |
+
" \"\"\"Harmony-formatted prompt for inference.\"\"\"\n",
|
1002 |
+
" convo = Conversation.from_messages(\n",
|
1003 |
+
" [\n",
|
1004 |
+
" Message.from_role_and_content(\n",
|
1005 |
+
" Role.DEVELOPER,\n",
|
1006 |
+
" DeveloperContent.new().with_instructions(DERM_DEV_INSTRUCTIONS),\n",
|
1007 |
+
" ),\n",
|
1008 |
+
" Message.from_role_and_content(Role.USER, question.strip()),\n",
|
1009 |
+
" ]\n",
|
1010 |
+
" )\n",
|
1011 |
+
" tokens = enc.render_conversation_for_completion(convo, Role.ASSISTANT)\n",
|
1012 |
+
" return enc.decode(tokens)\n",
|
1013 |
+
"\n",
|
1014 |
+
"def extract_final_answer(text):\n",
|
1015 |
+
" # Find the start of the assistant's final message\n",
|
1016 |
+
" start_marker = \"<|start|>assistant<|message|>\"\n",
|
1017 |
+
" start_idx = text.find(start_marker)\n",
|
1018 |
+
" \n",
|
1019 |
+
" if start_idx == -1:\n",
|
1020 |
+
" return \"No answer found in the text\"\n",
|
1021 |
+
" \n",
|
1022 |
+
" # Move to the beginning of the actual answer\n",
|
1023 |
+
" start_idx += len(start_marker)\n",
|
1024 |
+
" \n",
|
1025 |
+
" # Find the end of the answer (either next tag or end of text)\n",
|
1026 |
+
" end_idx = text.find(\"<|end|>\", start_idx)\n",
|
1027 |
+
" if end_idx == -1:\n",
|
1028 |
+
" end_idx = len(text)\n",
|
1029 |
+
" \n",
|
1030 |
+
" # Extract and clean the answer\n",
|
1031 |
+
" answer = text[start_idx:end_idx].strip()\n",
|
1032 |
+
" \n",
|
1033 |
+
" return answer"
|
1034 |
+
]
|
1035 |
+
},
|
1036 |
+
{
|
1037 |
+
"cell_type": "code",
|
1038 |
+
"execution_count": 10,
|
1039 |
+
"metadata": {},
|
1040 |
+
"outputs": [
|
1041 |
+
{
|
1042 |
+
"name": "stdout",
|
1043 |
+
"output_type": "stream",
|
1044 |
+
"text": [
|
1045 |
+
"During winter, the dry skin that is typical in eczema can become even more dry, which may worsen eczema symptoms. The lack of moisture in the air can increase skin dryness.\n"
|
1046 |
+
]
|
1047 |
+
}
|
1048 |
+
],
|
1049 |
+
"source": [
|
1050 |
+
"question = dataset[\"test\"][20][\"question\"]\n",
|
1051 |
+
"\n",
|
1052 |
+
"text = render_inference_harmony(question)\n",
|
1053 |
+
"\n",
|
1054 |
+
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
|
1055 |
+
"outputs = model.generate(\n",
|
1056 |
+
" input_ids=inputs.input_ids,\n",
|
1057 |
+
" attention_mask=inputs.attention_mask,\n",
|
1058 |
+
" max_new_tokens=200,\n",
|
1059 |
+
" eos_token_id=tokenizer.eos_token_id,\n",
|
1060 |
+
" use_cache=True,\n",
|
1061 |
+
")\n",
|
1062 |
+
"response = tokenizer.batch_decode(outputs)\n",
|
1063 |
+
"final_answer = extract_final_answer(response[0])\n",
|
1064 |
+
"print(final_answer)"
|
1065 |
+
]
|
1066 |
+
},
|
1067 |
+
{
|
1068 |
+
"cell_type": "code",
|
1069 |
+
"execution_count": 14,
|
1070 |
+
"metadata": {},
|
1071 |
+
"outputs": [
|
1072 |
+
{
|
1073 |
+
"data": {
|
1074 |
+
"text/plain": [
|
1075 |
+
"'<|start|>developer<|message|># Instructions\\n\\nYou are a board-certified dermatologist answering various dermatology questions. Answer clearly in 1–3 sentences. No speculation.<|end|><|start|>user<|message|>How does the source suggest clinicians approach the diagnosis of rosacea?<|end|><|start|>assistant'"
|
1076 |
+
]
|
1077 |
+
},
|
1078 |
+
"execution_count": 14,
|
1079 |
+
"metadata": {},
|
1080 |
+
"output_type": "execute_result"
|
1081 |
+
}
|
1082 |
+
],
|
1083 |
+
"source": [
|
1084 |
+
"text"
|
1085 |
+
]
|
1086 |
+
},
|
1087 |
+
{
|
1088 |
+
"cell_type": "code",
|
1089 |
+
"execution_count": 11,
|
1090 |
+
"metadata": {},
|
1091 |
+
"outputs": [
|
1092 |
+
{
|
1093 |
+
"data": {
|
1094 |
+
"text/plain": [
|
1095 |
+
"'During winter, indoor air tends to be dry, which can trigger eczema flare‑ups for some individuals. The dryness of indoor environments in winter is a known trigger for these patients.'"
|
1096 |
+
]
|
1097 |
+
},
|
1098 |
+
"execution_count": 11,
|
1099 |
+
"metadata": {},
|
1100 |
+
"output_type": "execute_result"
|
1101 |
+
}
|
1102 |
+
],
|
1103 |
+
"source": [
|
1104 |
+
"dataset[\"test\"][20][\"answer\"]"
|
1105 |
+
]
|
1106 |
+
},
|
1107 |
+
{
|
1108 |
+
"cell_type": "code",
|
1109 |
+
"execution_count": 12,
|
1110 |
+
"metadata": {},
|
1111 |
+
"outputs": [
|
1112 |
+
{
|
1113 |
+
"name": "stdout",
|
1114 |
+
"output_type": "stream",
|
1115 |
+
"text": [
|
1116 |
+
"The source suggests that clinicians approach the diagnosis of rosacea as a dynamic process that requires ongoing reassessment as patients develop new symptoms or present new presentations, rather than a one-time determination. This approach is intended to address changing manifestations within patients and evolving presentations in the population.\n"
|
1117 |
+
]
|
1118 |
+
}
|
1119 |
+
],
|
1120 |
+
"source": [
|
1121 |
+
"question = dataset[\"test\"][50][\"question\"]\n",
|
1122 |
+
"\n",
|
1123 |
+
"text = render_inference_harmony(question)\n",
|
1124 |
+
"\n",
|
1125 |
+
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
|
1126 |
+
"outputs = model.generate(\n",
|
1127 |
+
" input_ids=inputs.input_ids,\n",
|
1128 |
+
" attention_mask=inputs.attention_mask,\n",
|
1129 |
+
" max_new_tokens=200,\n",
|
1130 |
+
" eos_token_id=tokenizer.eos_token_id,\n",
|
1131 |
+
" use_cache=True,\n",
|
1132 |
+
")\n",
|
1133 |
+
"response = tokenizer.batch_decode(outputs)\n",
|
1134 |
+
"final_answer = extract_final_answer(response[0])\n",
|
1135 |
+
"print(final_answer)"
|
1136 |
+
]
|
1137 |
+
},
|
1138 |
+
{
|
1139 |
+
"cell_type": "code",
|
1140 |
+
"execution_count": 13,
|
1141 |
+
"metadata": {},
|
1142 |
+
"outputs": [
|
1143 |
+
{
|
1144 |
+
"data": {
|
1145 |
+
"text/plain": [
|
1146 |
+
"'The source suggests using a stepped approach, which typically involves evaluating the patient and then progressing through treatment options as needed. The text also mentions a differential diagnosis list to aid in distinguishing rosacea from other similar conditions.'"
|
1147 |
+
]
|
1148 |
+
},
|
1149 |
+
"execution_count": 13,
|
1150 |
+
"metadata": {},
|
1151 |
+
"output_type": "execute_result"
|
1152 |
+
}
|
1153 |
+
],
|
1154 |
+
"source": [
|
1155 |
+
"dataset[\"test\"][50][\"answer\"]"
|
1156 |
+
]
|
1157 |
+
},
|
1158 |
+
{
|
1159 |
+
"cell_type": "code",
|
1160 |
+
"execution_count": 4,
|
1161 |
+
"metadata": {},
|
1162 |
+
"outputs": [
|
1163 |
+
{
|
1164 |
+
"data": {
|
1165 |
+
"application/vnd.jupyter.widget-view+json": {
|
1166 |
+
"model_id": "1ea55abf1ec6411f88a9e1c7bcf90446",
|
1167 |
+
"version_major": 2,
|
1168 |
+
"version_minor": 0
|
1169 |
+
},
|
1170 |
+
"text/plain": [
|
1171 |
+
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
1172 |
+
]
|
1173 |
+
},
|
1174 |
+
"metadata": {},
|
1175 |
+
"output_type": "display_data"
|
1176 |
+
},
|
1177 |
+
{
|
1178 |
+
"name": "stderr",
|
1179 |
+
"output_type": "stream",
|
1180 |
+
"text": [
|
1181 |
+
"Device set to use cuda\n"
|
1182 |
+
]
|
1183 |
+
}
|
1184 |
+
],
|
1185 |
+
"source": [
|
1186 |
+
"from transformers import pipeline\n",
|
1187 |
+
"\n",
|
1188 |
+
"# Load pipeline\n",
|
1189 |
+
"generator = pipeline(\n",
|
1190 |
+
" \"text-generation\",\n",
|
1191 |
+
" model=\"kingabzpro/gpt-oss-20b-dermatology-qa\",\n",
|
1192 |
+
" device=\"cuda\" # or device=0\n",
|
1193 |
+
")\n"
|
1194 |
+
]
|
1195 |
+
},
|
1196 |
+
{
|
1197 |
+
"cell_type": "code",
|
1198 |
+
"execution_count": 5,
|
1199 |
+
"metadata": {},
|
1200 |
+
"outputs": [
|
1201 |
+
{
|
1202 |
+
"name": "stdout",
|
1203 |
+
"output_type": "stream",
|
1204 |
+
"text": [
|
1205 |
+
"The source advises that clinicians should not rely solely on clinical presentation to diagnose rosacea. Instead, they should use a standardized, validated diagnostic tool such as the 2016 International Rosacea Consensus (IRC) criteria to confirm the diagnosis. This approach ensures a consistent and evidence‑based assessment rather than a subjective interpretation of symptoms.\n"
|
1206 |
+
]
|
1207 |
+
}
|
1208 |
+
],
|
1209 |
+
"source": [
|
1210 |
+
"question = \"How does the source suggest clinicians approach the diagnosis of rosacea?\"\n",
|
1211 |
+
"\n",
|
1212 |
+
"output = generator(\n",
|
1213 |
+
" [{\"role\": \"user\", \"content\": question}],\n",
|
1214 |
+
" max_new_tokens=200,\n",
|
1215 |
+
" return_full_text=False\n",
|
1216 |
+
")[0]\n",
|
1217 |
+
"\n",
|
1218 |
+
"print(output[\"generated_text\"])"
|
1219 |
+
]
|
1220 |
+
},
|
1221 |
+
{
|
1222 |
+
"cell_type": "code",
|
1223 |
+
"execution_count": 6,
|
1224 |
+
"metadata": {},
|
1225 |
+
"outputs": [
|
1226 |
+
{
|
1227 |
+
"name": "stdout",
|
1228 |
+
"output_type": "stream",
|
1229 |
+
"text": [
|
1230 |
+
"The source indicates that clinicians should consider rosacea when patients present with erythematous facial skin and may need to differentiate it from other conditions such as acne. Recognizing these features helps in identifying rosacea.\n"
|
1231 |
+
]
|
1232 |
+
}
|
1233 |
+
],
|
1234 |
+
"source": [
|
1235 |
+
"prompt = \"<|start|>developer<|message|># Instructions\\n\\nYou are a board-certified dermatologist answering various dermatology questions. Answer clearly in 1–3 sentences. No speculation.<|end|><|start|>user<|message|>How does the source suggest clinicians approach the diagnosis of rosacea?<|end|><|start|>assistant\"\n",
|
1236 |
+
"\n",
|
1237 |
+
"output = generator(\n",
|
1238 |
+
" prompt,\n",
|
1239 |
+
" max_new_tokens=200,\n",
|
1240 |
+
" return_full_text=False\n",
|
1241 |
+
")[0]\n",
|
1242 |
+
"\n",
|
1243 |
+
"print(output[\"generated_text\"])"
|
1244 |
+
]
|
1245 |
+
}
|
1246 |
+
],
|
1247 |
+
"metadata": {
|
1248 |
+
"kaggle": {
|
1249 |
+
"accelerator": "nvidiaTeslaT4",
|
1250 |
+
"dataSources": [],
|
1251 |
+
"dockerImageVersionId": 31011,
|
1252 |
+
"isGpuEnabled": true,
|
1253 |
+
"isInternetEnabled": true,
|
1254 |
+
"language": "python",
|
1255 |
+
"sourceType": "notebook"
|
1256 |
+
},
|
1257 |
+
"kernelspec": {
|
1258 |
+
"display_name": "Python 3 (ipykernel)",
|
1259 |
+
"language": "python",
|
1260 |
+
"name": "python3"
|
1261 |
+
},
|
1262 |
+
"language_info": {
|
1263 |
+
"codemirror_mode": {
|
1264 |
+
"name": "ipython",
|
1265 |
+
"version": 3
|
1266 |
+
},
|
1267 |
+
"file_extension": ".py",
|
1268 |
+
"mimetype": "text/x-python",
|
1269 |
+
"name": "python",
|
1270 |
+
"nbconvert_exporter": "python",
|
1271 |
+
"pygments_lexer": "ipython3",
|
1272 |
+
"version": "3.11.11"
|
1273 |
+
}
|
1274 |
+
},
|
1275 |
+
"nbformat": 4,
|
1276 |
+
"nbformat_minor": 4
|
1277 |
+
}
|