dahara1 commited on
Commit
53a3cc2
·
verified ·
1 Parent(s): 334a452

Upload 2 files

Browse files
Files changed (1) hide show
  1. utils.py +12 -14
utils.py CHANGED
@@ -7,7 +7,6 @@ import torch
7
  import uuid
8
  from PIL import Image, PngImagePlugin
9
  from datetime import datetime
10
- from dataclasses import dataclass
11
  from typing import Callable, Dict, Optional, Tuple, Any, List
12
  from diffusers import (
13
  DDIMScheduler,
@@ -22,18 +21,8 @@ import logging
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
 
25
-
26
- @dataclass
27
- class StyleConfig:
28
- prompt: str
29
- negative_prompt: str
30
-
31
-
32
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
33
- if randomize_seed:
34
- seed = random.randint(0, MAX_SEED)
35
- return seed
36
-
37
 
38
  def seed_everything(seed: int) -> torch.Generator:
39
  torch.manual_seed(seed)
@@ -217,10 +206,19 @@ def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str]
217
  add_watermarker=False
218
  )
219
 
220
- pipe.to(device)
 
 
 
 
 
 
 
 
221
  logging.info("Pipeline loaded successfully!")
222
  return pipe
223
  except Exception as e:
224
  logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
225
  raise
226
 
 
 
7
  import uuid
8
  from PIL import Image, PngImagePlugin
9
  from datetime import datetime
 
10
  from typing import Callable, Dict, Optional, Tuple, Any, List
11
  from diffusers import (
12
  DDIMScheduler,
 
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
 
24
+ def is_space_environment():
25
+ return "SPACE_ID" in os.environ and os.environ.get("SYSTEM") == "spaces"
 
 
 
 
 
 
 
 
 
 
26
 
27
  def seed_everything(seed: int) -> torch.Generator:
28
  torch.manual_seed(seed)
 
206
  add_watermarker=False
207
  )
208
 
209
+ # デバイス移動の部分を修正
210
+ if "SPACE_ID" in os.environ and os.environ.get("SYSTEM") == "spaces":
211
+ # Stateless GPU環境ではデバイス移動を特別に扱う
212
+ return pipe
213
+ else:
214
+ # 通常の環境では以前のコードを使用
215
+ pipe.to(device)
216
+ return pipe
217
+
218
  logging.info("Pipeline loaded successfully!")
219
  return pipe
220
  except Exception as e:
221
  logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
222
  raise
223
 
224
+