Spaces: Running on Zero
Update trellis/pipelines/trellis_text_to_3d.py
trellis/pipelines/trellis_text_to_3d.py CHANGED
@@ -9,7 +9,6 @@ from .base import Pipeline
 from . import samplers
 from ..modules import sparse as sp

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 class TrellisTextTo3DPipeline(Pipeline):
     """
@@ -70,9 +69,10 @@ class TrellisTextTo3DPipeline(Pipeline):
         Initialize the text conditioning model.
         """
         # load model
-        model = CLIPTextModel.from_pretrained(name)
-        tokenizer = AutoTokenizer.from_pretrained(name)
+        model = CLIPTextModel.from_pretrained(name)
+        tokenizer = AutoTokenizer.from_pretrained(name)
         model.eval()
+        model = model.cuda()
         self.text_cond_model = {
             'model': model,
             'tokenizer': tokenizer,
@@ -86,7 +86,7 @@ class TrellisTextTo3DPipeline(Pipeline):
         """
         assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
         encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
-        tokens = encoding['input_ids'].
+        tokens = encoding['input_ids'].cuda()
         embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state

         return embeddings
@@ -225,4 +225,4 @@ class TrellisTextTo3DPipeline(Pipeline):
         torch.manual_seed(seed)
         coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
         slat = self.sample_slat(cond, coords, slat_sampler_params)
-        return self.decode_slat(slat, formats)
+        return self.decode_slat(slat, formats)
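Net effect of the change: the module-level torch.device("cuda" if torch.cuda.is_available() else "cpu") fallback is gone, and the CLIP text encoder plus its input ids are placed on CUDA explicitly, so the text-conditioning path now assumes a GPU. For reference, that path can be exercised on its own. A minimal sketch, assuming a CUDA device, an installed transformers package, and using openai/clip-vit-large-patch14 purely as a stand-in for the name argument (the checkpoint TRELLIS actually passes is configured elsewhere):

import torch
from transformers import AutoTokenizer, CLIPTextModel

name = "openai/clip-vit-large-patch14"  # stand-in checkpoint, not taken from this diff
model = CLIPTextModel.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
model.eval()
model = model.cuda()  # explicit CUDA placement, as in the updated load step

text = ["a wooden chair"]  # encode_text expects a list of strings
encoding = tokenizer(text, max_length=77, padding='max_length',
                     truncation=True, return_tensors='pt')
tokens = encoding['input_ids'].cuda()  # tokens follow the model onto the GPU
with torch.no_grad():
    embeddings = model(input_ids=tokens).last_hidden_state
print(embeddings.shape)  # e.g. torch.Size([1, 77, 768]) for this stand-in model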
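The Space itself reports "Running on Zero" (ZeroGPU), where GPU-side work is normally performed inside a function decorated with spaces.GPU. A rough sketch of that wrapping, not the Space's actual app code; the spaces usage pattern, the from_pretrained checkpoint id, and the run() signature beyond the seed, sampler params, and formats visible in the diff are all assumptions here:

import spaces
from trellis.pipelines import TrellisTextTo3DPipeline

# Assumed loader and checkpoint id (README-style usage, not taken from this diff).
pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")

@spaces.GPU  # ZeroGPU attaches a GPU for the duration of this call
def generate(prompt: str, seed: int = 0):
    # The pipeline's CUDA-resident text encoder and samplers do their work here;
    # the exact run() signature is assumed apart from the arguments seen in the diff.
    return pipeline.run(prompt, seed=seed)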