Commit 
							
							·
						
						004a144
	
1
								Parent(s):
							
							5fe8039
								
Update base/app.py
Browse files- base/app.py +6 -0
    	
        base/app.py
    CHANGED
    
    | @@ -8,9 +8,15 @@ import cv2 | |
| 8 | 
             
            import pandas as pd
         | 
| 9 | 
             
            import torchvision
         | 
| 10 | 
             
            import random
         | 
|  | |
|  | |
| 11 | 
             
            config_path = "./base/configs/sample.yaml"
         | 
| 12 | 
             
            args = OmegaConf.load("./base/configs/sample.yaml")
         | 
| 13 | 
             
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
|  | |
|  | |
|  | |
|  | |
| 14 | 
             
            # ------- get model ---------------
         | 
| 15 | 
             
            model_t2V = model_t2v_fun(args)
         | 
| 16 | 
             
            model_t2V.to(device)
         | 
|  | |
| 8 | 
             
            import pandas as pd
         | 
| 9 | 
             
            import torchvision
         | 
| 10 | 
             
            import random
         | 
| 11 | 
            +
            from huggingface_hub import snapshot_download
         | 
| 12 | 
            +
             | 
| 13 | 
             
            config_path = "./base/configs/sample.yaml"
         | 
| 14 | 
             
            args = OmegaConf.load("./base/configs/sample.yaml")
         | 
| 15 | 
             
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 16 | 
            +
            ### download models
         | 
| 17 | 
            +
            snapshot_download('Vchitect/LaVie',cache_dir='./pretrained_models')
         | 
| 18 | 
            +
            snapshot_download('CompVis/stable-diffusion-v1-4',cache_dir='./pretrained_models')
         | 
| 19 | 
            +
             | 
| 20 | 
             
            # ------- get model ---------------
         | 
| 21 | 
             
            model_t2V = model_t2v_fun(args)
         | 
| 22 | 
             
            model_t2V.to(device)
         | 
 
			
