Spaces:
Runtime error
Runtime error
Update oryx/model/builder.py
Browse files- oryx/model/builder.py +2 -1
oryx/model/builder.py
CHANGED
|
@@ -76,7 +76,8 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
| 76 |
if not vision_tower.is_loaded:
|
| 77 |
vision_tower.load_model(device_map=device_map)
|
| 78 |
if device_map != "auto":
|
| 79 |
-
vision_tower
|
|
|
|
| 80 |
else:
|
| 81 |
vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
|
| 82 |
image_processor = vision_tower.image_processor
|
|
|
|
| 76 |
if not vision_tower.is_loaded:
|
| 77 |
vision_tower.load_model(device_map=device_map)
|
| 78 |
if device_map != "auto":
|
| 79 |
+
vision_tower = vision_tower.bfloat16()
|
| 80 |
+
vision_tower = vision_tower.to("cuda")
|
| 81 |
else:
|
| 82 |
vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
|
| 83 |
image_processor = vision_tower.image_processor
|