cavargas10 commited on
Commit
b3c1677
·
verified ·
1 Parent(s): e1088fe

Update trellis/models/__init__.py

Browse files
Files changed (1) hide show
  1. trellis/models/__init__.py +96 -70
trellis/models/__init__.py CHANGED
@@ -1,70 +1,96 @@
1
- import importlib
2
-
3
- __attributes = {
4
- 'SparseStructureEncoder': 'sparse_structure_vae',
5
- 'SparseStructureDecoder': 'sparse_structure_vae',
6
- 'SparseStructureFlowModel': 'sparse_structure_flow',
7
- 'SLatEncoder': 'structured_latent_vae',
8
- 'SLatGaussianDecoder': 'structured_latent_vae',
9
- 'SLatRadianceFieldDecoder': 'structured_latent_vae',
10
- 'SLatMeshDecoder': 'structured_latent_vae',
11
- 'SLatFlowModel': 'structured_latent_flow',
12
- }
13
-
14
- __submodules = []
15
-
16
- __all__ = list(__attributes.keys()) + __submodules
17
-
18
- def __getattr__(name):
19
- if name not in globals():
20
- if name in __attributes:
21
- module_name = __attributes[name]
22
- module = importlib.import_module(f".{module_name}", __name__)
23
- globals()[name] = getattr(module, name)
24
- elif name in __submodules:
25
- module = importlib.import_module(f".{name}", __name__)
26
- globals()[name] = module
27
- else:
28
- raise AttributeError(f"module {__name__} has no attribute {name}")
29
- return globals()[name]
30
-
31
-
32
- def from_pretrained(path: str, **kwargs):
33
- """
34
- Load a model from a pretrained checkpoint.
35
-
36
- Args:
37
- path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
38
- NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
39
- **kwargs: Additional arguments for the model constructor.
40
- """
41
- import os
42
- import json
43
- from safetensors.torch import load_file
44
- is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
45
-
46
- if is_local:
47
- config_file = f"{path}.json"
48
- model_file = f"{path}.safetensors"
49
- else:
50
- from huggingface_hub import hf_hub_download
51
- path_parts = path.split('/')
52
- repo_id = f'{path_parts[0]}/{path_parts[1]}'
53
- model_name = '/'.join(path_parts[2:])
54
- config_file = hf_hub_download(repo_id, f"{model_name}.json")
55
- model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
56
-
57
- with open(config_file, 'r') as f:
58
- config = json.load(f)
59
- model = __getattr__(config['name'])(**config['args'], **kwargs)
60
- model.load_state_dict(load_file(model_file))
61
-
62
- return model
63
-
64
-
65
- # For Pylance
66
- if __name__ == '__main__':
67
- from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
68
- from .sparse_structure_flow import SparseStructureFlowModel
69
- from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
70
- from .structured_latent_flow import SLatFlowModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ __attributes = {
4
+ 'SparseStructureEncoder': 'sparse_structure_vae',
5
+ 'SparseStructureDecoder': 'sparse_structure_vae',
6
+
7
+ 'SparseStructureFlowModel': 'sparse_structure_flow',
8
+
9
+ 'SLatEncoder': 'structured_latent_vae',
10
+ 'SLatGaussianDecoder': 'structured_latent_vae',
11
+ 'SLatRadianceFieldDecoder': 'structured_latent_vae',
12
+ 'SLatMeshDecoder': 'structured_latent_vae',
13
+ 'ElasticSLatEncoder': 'structured_latent_vae',
14
+ 'ElasticSLatGaussianDecoder': 'structured_latent_vae',
15
+ 'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
16
+ 'ElasticSLatMeshDecoder': 'structured_latent_vae',
17
+
18
+ 'SLatFlowModel': 'structured_latent_flow',
19
+ 'ElasticSLatFlowModel': 'structured_latent_flow',
20
+ }
21
+
22
+ __submodules = []
23
+
24
+ __all__ = list(__attributes.keys()) + __submodules
25
+
26
+ def __getattr__(name):
27
+ if name not in globals():
28
+ if name in __attributes:
29
+ module_name = __attributes[name]
30
+ module = importlib.import_module(f".{module_name}", __name__)
31
+ globals()[name] = getattr(module, name)
32
+ elif name in __submodules:
33
+ module = importlib.import_module(f".{name}", __name__)
34
+ globals()[name] = module
35
+ else:
36
+ raise AttributeError(f"module {__name__} has no attribute {name}")
37
+ return globals()[name]
38
+
39
+
40
+ def from_pretrained(path: str, **kwargs):
41
+ """
42
+ Load a model from a pretrained checkpoint.
43
+
44
+ Args:
45
+ path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
46
+ NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
47
+ **kwargs: Additional arguments for the model constructor.
48
+ """
49
+ import os
50
+ import json
51
+ from safetensors.torch import load_file
52
+ is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
53
+
54
+ if is_local:
55
+ config_file = f"{path}.json"
56
+ model_file = f"{path}.safetensors"
57
+ else:
58
+ from huggingface_hub import hf_hub_download
59
+ path_parts = path.split('/')
60
+ repo_id = f'{path_parts[0]}/{path_parts[1]}'
61
+ model_name = '/'.join(path_parts[2:])
62
+ config_file = hf_hub_download(repo_id, f"{model_name}.json")
63
+ model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
64
+
65
+ with open(config_file, 'r') as f:
66
+ config = json.load(f)
67
+ model = __getattr__(config['name'])(**config['args'], **kwargs)
68
+ model.load_state_dict(load_file(model_file))
69
+
70
+ return model
71
+
72
+
73
+ # For Pylance
74
+ if __name__ == '__main__':
75
+ from .sparse_structure_vae import (
76
+ SparseStructureEncoder,
77
+ SparseStructureDecoder,
78
+ )
79
+
80
+ from .sparse_structure_flow import SparseStructureFlowModel
81
+
82
+ from .structured_latent_vae import (
83
+ SLatEncoder,
84
+ SLatGaussianDecoder,
85
+ SLatRadianceFieldDecoder,
86
+ SLatMeshDecoder,
87
+ ElasticSLatEncoder,
88
+ ElasticSLatGaussianDecoder,
89
+ ElasticSLatRadianceFieldDecoder,
90
+ ElasticSLatMeshDecoder,
91
+ )
92
+
93
+ from .structured_latent_flow import (
94
+ SLatFlowModel,
95
+ ElasticSLatFlowModel,
96
+ )