cavargas10 commited on
Commit
04b25c9
·
verified ·
1 Parent(s): 68bd713

Update trellis/pipelines/trellis_text_to_3d.py

Browse files
trellis/pipelines/trellis_text_to_3d.py CHANGED
@@ -9,7 +9,6 @@ from .base import Pipeline
9
  from . import samplers
10
  from ..modules import sparse as sp
11
 
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  class TrellisTextTo3DPipeline(Pipeline):
15
  """
@@ -70,9 +69,10 @@ class TrellisTextTo3DPipeline(Pipeline):
70
  Initialize the text conditioning model.
71
  """
72
  # load model
73
- model = CLIPTextModel.from_pretrained(name).to(self.device)
74
- tokenizer = AutoTokenizer.from_pretrained(name).to(self.device)
75
  model.eval()
 
76
  self.text_cond_model = {
77
  'model': model,
78
  'tokenizer': tokenizer,
@@ -86,7 +86,7 @@ class TrellisTextTo3DPipeline(Pipeline):
86
  """
87
  assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
88
  encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
89
- tokens = encoding['input_ids'].to(self.device)
90
  embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
91
 
92
  return embeddings
@@ -225,4 +225,4 @@ class TrellisTextTo3DPipeline(Pipeline):
225
  torch.manual_seed(seed)
226
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
227
  slat = self.sample_slat(cond, coords, slat_sampler_params)
228
- return self.decode_slat(slat, formats)
 
9
  from . import samplers
10
  from ..modules import sparse as sp
11
 
 
12
 
13
  class TrellisTextTo3DPipeline(Pipeline):
14
  """
 
69
  Initialize the text conditioning model.
70
  """
71
  # load model
72
+ model = CLIPTextModel.from_pretrained(name)
73
+ tokenizer = AutoTokenizer.from_pretrained(name)
74
  model.eval()
75
+ model = model.cuda()
76
  self.text_cond_model = {
77
  'model': model,
78
  'tokenizer': tokenizer,
 
86
  """
87
  assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
88
  encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
89
+ tokens = encoding['input_ids'].cuda()
90
  embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
91
 
92
  return embeddings
 
225
  torch.manual_seed(seed)
226
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
227
  slat = self.sample_slat(cond, coords, slat_sampler_params)
228
+ return self.decode_slat(slat, formats)