Spaces:
Runtime error
Runtime error
Update trellis/pipelines/trellis_text_to_3d.py
Browse files
trellis/pipelines/trellis_text_to_3d.py
CHANGED
|
@@ -9,7 +9,6 @@ from .base import Pipeline
|
|
| 9 |
from . import samplers
|
| 10 |
from ..modules import sparse as sp
|
| 11 |
|
| 12 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
|
| 14 |
class TrellisTextTo3DPipeline(Pipeline):
|
| 15 |
"""
|
|
@@ -70,9 +69,10 @@ class TrellisTextTo3DPipeline(Pipeline):
|
|
| 70 |
Initialize the text conditioning model.
|
| 71 |
"""
|
| 72 |
# load model
|
| 73 |
-
model = CLIPTextModel.from_pretrained(name)
|
| 74 |
-
tokenizer = AutoTokenizer.from_pretrained(name)
|
| 75 |
model.eval()
|
|
|
|
| 76 |
self.text_cond_model = {
|
| 77 |
'model': model,
|
| 78 |
'tokenizer': tokenizer,
|
|
@@ -86,7 +86,7 @@ class TrellisTextTo3DPipeline(Pipeline):
|
|
| 86 |
"""
|
| 87 |
assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
|
| 88 |
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
| 89 |
-
tokens = encoding['input_ids'].
|
| 90 |
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
|
| 91 |
|
| 92 |
return embeddings
|
|
@@ -225,4 +225,4 @@ class TrellisTextTo3DPipeline(Pipeline):
|
|
| 225 |
torch.manual_seed(seed)
|
| 226 |
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
|
| 227 |
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
| 228 |
-
return self.decode_slat(slat, formats)
|
|
|
|
| 9 |
from . import samplers
|
| 10 |
from ..modules import sparse as sp
|
| 11 |
|
|
|
|
| 12 |
|
| 13 |
class TrellisTextTo3DPipeline(Pipeline):
|
| 14 |
"""
|
|
|
|
| 69 |
Initialize the text conditioning model.
|
| 70 |
"""
|
| 71 |
# load model
|
| 72 |
+
model = CLIPTextModel.from_pretrained(name)
|
| 73 |
+
tokenizer = AutoTokenizer.from_pretrained(name)
|
| 74 |
model.eval()
|
| 75 |
+
model = model.cuda()
|
| 76 |
self.text_cond_model = {
|
| 77 |
'model': model,
|
| 78 |
'tokenizer': tokenizer,
|
|
|
|
| 86 |
"""
|
| 87 |
assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
|
| 88 |
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
| 89 |
+
tokens = encoding['input_ids'].cuda()
|
| 90 |
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
|
| 91 |
|
| 92 |
return embeddings
|
|
|
|
| 225 |
torch.manual_seed(seed)
|
| 226 |
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
|
| 227 |
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
| 228 |
+
return self.decode_slat(slat, formats)
|