Spaces:
Running
Running
Upload 3 files
Browse files
- app.py +1 -0
- prompt_generator.py +3 -3
app.py
CHANGED
@@ -48,6 +48,7 @@ logger = logging.getLogger(__name__)
|
|
48 |
|
49 |
# Constants
|
50 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
|
|
51 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
52 |
|
53 |
# PyTorch settings for better performance and determinism
|
|
|
48 |
|
49 |
# Constants
|
50 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
51 |
+ HF_TOKEN = os.getenv("HF_TOKEN")
|
52 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
53 |
|
54 |
# PyTorch settings for better performance and determinism
|
prompt_generator.py
CHANGED
@@ -119,8 +119,8 @@ def load_model():
|
|
119 |
_model = AutoModelForCausalLM.from_pretrained(
|
120 |
model_path,
|
121 |
torch_dtype=torch_dtype,
|
122 |
-
|
123 |
-
|
124 |
low_cpu_mem_usage=True,
|
125 |
)
|
126 |
|
@@ -277,7 +277,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
|
|
277 |
logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
|
278 |
|
279 |
# 生成
|
280 |
- logger.info("before torch.no_grad")
|
281 |
with torch.no_grad():
|
282 |
generated_ids = model.generate(
|
283 |
input_ids=inputs,
|
|
|
119 |
_model = AutoModelForCausalLM.from_pretrained(
|
120 |
model_path,
|
121 |
torch_dtype=torch_dtype,
|
122 |
+ device_map=device_map,
|
123 |
+ use_cache=True,
|
124 |
low_cpu_mem_usage=True,
|
125 |
)
|
126 |
|
|
|
277 |
logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
|
278 |
|
279 |
# 生成
|
280 |
+ logger.info("before torch.no_grad")
|
281 |
with torch.no_grad():
|
282 |
generated_ids = model.generate(
|
283 |
input_ids=inputs,
|