DevWild commited on
Commit
f905585
·
verified ·
1 Parent(s): b1ec5b4

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +14 -14
script.py CHANGED
@@ -4,8 +4,8 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
4
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
5
 
6
  HF_TOKEN = os.environ.get("HF_TOKEN")
7
- #HF_DATASET = os.environ.get("HF_DATASET")
8
- HF_DATASET = "DevWild/autotrain-pr0b0rk"
9
  repo_id = os.environ.get("MODEL_REPO_ID")
10
 
11
  from huggingface_hub import snapshot_download, delete_repo, metadata_update
@@ -126,7 +126,7 @@ def run_training(hf_dataset_path: str):
126
  commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
127
  subprocess.run(commands, shell=True)
128
 
129
- patch_ai_toolkit_typing()
130
  commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
131
  process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,)
132
 
@@ -137,19 +137,19 @@ def run_training(hf_dataset_path: str):
137
 
138
  return process, dataset_dir
139
 
140
- def patch_ai_toolkit_typing():
141
- config_path = "ai-toolkit/toolkit/config_modules.py"
142
- if os.path.exists(config_path):
143
- with open(config_path, "r") as f:
144
- content = f.read()
145
 
146
- content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
147
 
148
- with open(config_path, "w") as f:
149
- f.write(content)
150
- print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
151
- else:
152
- print("⚠️ Could not patch config_modules.py — file not found")
153
 
154
 
155
  if __name__ == "__main__":
 
4
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
5
 
6
  HF_TOKEN = os.environ.get("HF_TOKEN")
7
+ HF_DATASET = os.environ.get("HF_DATASET")
8
+ #HF_DATASET = "DevWild/autotrain-pr0b0rk"
9
  repo_id = os.environ.get("MODEL_REPO_ID")
10
 
11
  from huggingface_hub import snapshot_download, delete_repo, metadata_update
 
126
  commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
127
  subprocess.run(commands, shell=True)
128
 
129
+ # patch_ai_toolkit_typing()
130
  commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
131
  process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,)
132
 
 
137
 
138
  return process, dataset_dir
139
 
140
+ #def patch_ai_toolkit_typing():
141
+ # config_path = "ai-toolkit/toolkit/config_modules.py"
142
+ # if os.path.exists(config_path):
143
+ # with open(config_path, "r") as f:
144
+ # content = f.read()
145
 
146
+ # content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
147
 
148
+ # with open(config_path, "w") as f:
149
+ # f.write(content)
150
+ # print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
151
+ # else:
152
+ # print("⚠️ Could not patch config_modules.py — file not found")
153
 
154
 
155
  if __name__ == "__main__":