dahara1 committed on
Commit
6b0e3c1
·
verified ·
1 Parent(s): 29e82b9

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. prompt_generator.py +3 -3
app.py CHANGED
@@ -48,6 +48,7 @@ logger = logging.getLogger(__name__)
48
 
49
  # Constants
50
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
 
51
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
52
 
53
  # PyTorch settings for better performance and determinism
 
48
 
49
  # Constants
50
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
51
+ HF_TOKEN = os.getenv("HF_TOKEN")
52
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
53
 
54
  # PyTorch settings for better performance and determinism
prompt_generator.py CHANGED
@@ -119,8 +119,8 @@ def load_model():
119
  _model = AutoModelForCausalLM.from_pretrained(
120
  model_path,
121
  torch_dtype=torch_dtype,
122
- # device_map=device_map,
123
- # use_cache=True,
124
  low_cpu_mem_usage=True,
125
  )
126
 
@@ -277,7 +277,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
277
  logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
278
 
279
  # 生成
280
- logger.info("before ttorch.no_grad")
281
  with torch.no_grad():
282
  generated_ids = model.generate(
283
  input_ids=inputs,
 
119
  _model = AutoModelForCausalLM.from_pretrained(
120
  model_path,
121
  torch_dtype=torch_dtype,
122
+ device_map=device_map,
123
+ use_cache=True,
124
  low_cpu_mem_usage=True,
125
  )
126
 
 
277
  logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
278
 
279
  # 生成
280
+ logger.info("before torch.no_grad")
281
  with torch.no_grad():
282
  generated_ids = model.generate(
283
  input_ids=inputs,