Spaces:
Running
Running
Update script.py
Browse files
script.py
CHANGED
@@ -4,8 +4,8 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
|
|
4 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
|
5 |
|
6 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
7 |
-
|
8 |
-
HF_DATASET = "DevWild/autotrain-pr0b0rk"
|
9 |
repo_id = os.environ.get("MODEL_REPO_ID")
|
10 |
|
11 |
from huggingface_hub import snapshot_download, delete_repo, metadata_update
|
@@ -126,7 +126,7 @@ def run_training(hf_dataset_path: str):
|
|
126 |
commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
|
127 |
subprocess.run(commands, shell=True)
|
128 |
|
129 |
-
patch_ai_toolkit_typing()
|
130 |
commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
|
131 |
process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,)
|
132 |
|
@@ -137,19 +137,19 @@ def run_training(hf_dataset_path: str):
|
|
137 |
|
138 |
return process, dataset_dir
|
139 |
|
140 |
-
def patch_ai_toolkit_typing():
|
141 |
-
config_path = "ai-toolkit/toolkit/config_modules.py"
|
142 |
-
if os.path.exists(config_path):
|
143 |
-
with open(config_path, "r") as f:
|
144 |
-
content = f.read()
|
145 |
|
146 |
-
content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
|
147 |
|
148 |
-
with open(config_path, "w") as f:
|
149 |
-
f.write(content)
|
150 |
-
print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
|
151 |
-
else:
|
152 |
-
print("⚠️ Could not patch config_modules.py — file not found")
|
153 |
|
154 |
|
155 |
if __name__ == "__main__":
|
|
|
4 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
|
5 |
|
6 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
7 |
+
HF_DATASET = os.environ.get("HF_DATASET")
|
8 |
+
#HF_DATASET = "DevWild/autotrain-pr0b0rk"
|
9 |
repo_id = os.environ.get("MODEL_REPO_ID")
|
10 |
|
11 |
from huggingface_hub import snapshot_download, delete_repo, metadata_update
|
|
|
126 |
commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
|
127 |
subprocess.run(commands, shell=True)
|
128 |
|
129 |
+
# patch_ai_toolkit_typing()
|
130 |
commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
|
131 |
process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,)
|
132 |
|
|
|
137 |
|
138 |
return process, dataset_dir
|
139 |
|
140 |
+
#def patch_ai_toolkit_typing():
|
141 |
+
# config_path = "ai-toolkit/toolkit/config_modules.py"
|
142 |
+
# if os.path.exists(config_path):
|
143 |
+
# with open(config_path, "r") as f:
|
144 |
+
# content = f.read()
|
145 |
|
146 |
+
# content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
|
147 |
|
148 |
+
# with open(config_path, "w") as f:
|
149 |
+
# f.write(content)
|
150 |
+
# print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
|
151 |
+
# else:
|
152 |
+
# print("⚠️ Could not patch config_modules.py — file not found")
|
153 |
|
154 |
|
155 |
if __name__ == "__main__":
|