Spaces:
Runtime error
Runtime error
Update oryx/model/builder.py
Browse files- oryx/model/builder.py +5 -5
oryx/model/builder.py
CHANGED
|
@@ -75,11 +75,11 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
| 75 |
print("Loading vision tower...")
|
| 76 |
if not vision_tower.is_loaded:
|
| 77 |
vision_tower.load_model(device_map=device_map)
|
| 78 |
-
if device_map != "auto":
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
else:
|
| 82 |
-
|
| 83 |
image_processor = vision_tower.image_processor
|
| 84 |
print("Loading vision tower succeeded.")
|
| 85 |
if hasattr(model.config, "max_sequence_length"):
|
|
|
|
| 75 |
print("Loading vision tower...")
|
| 76 |
if not vision_tower.is_loaded:
|
| 77 |
vision_tower.load_model(device_map=device_map)
|
| 78 |
+
# if device_map != "auto":
|
| 79 |
+
# vision_tower = vision_tower.bfloat16()
|
| 80 |
+
# vision_tower = vision_tower.to("cuda")
|
| 81 |
+
# else:
|
| 82 |
+
# vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
|
| 83 |
image_processor = vision_tower.image_processor
|
| 84 |
print("Loading vision tower succeeded.")
|
| 85 |
if hasattr(model.config, "max_sequence_length"):
|