Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Float8DynamicActivation quantization
Browse files
- optimization.py +4 -0
- requirements.txt +1 -0
    	
        optimization.py
    CHANGED
    
    | @@ -8,6 +8,8 @@ from typing import ParamSpec | |
| 8 | 
             
            import spaces
         | 
| 9 | 
             
            import torch
         | 
| 10 | 
             
            from torch.utils._pytree import tree_map_only
         | 
|  | |
|  | |
| 11 |  | 
| 12 | 
             
            from optimization_utils import capture_component_call
         | 
| 13 | 
             
            from optimization_utils import aoti_compile
         | 
| @@ -46,6 +48,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw | |
| 46 |  | 
| 47 | 
             
                    pipeline.transformer.fuse_qkv_projections()
         | 
| 48 |  | 
|  | |
|  | |
| 49 | 
             
                    exported = torch.export.export(
         | 
| 50 | 
             
                        mod=pipeline.transformer,
         | 
| 51 | 
             
                        args=call.args,
         | 
|  | |
| 8 | 
             
            import spaces
         | 
| 9 | 
             
            import torch
         | 
| 10 | 
             
            from torch.utils._pytree import tree_map_only
         | 
| 11 | 
            +
            from torchao.quantization import quantize_
         | 
| 12 | 
            +
            from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
         | 
| 13 |  | 
| 14 | 
             
            from optimization_utils import capture_component_call
         | 
| 15 | 
             
            from optimization_utils import aoti_compile
         | 
|  | |
| 48 |  | 
| 49 | 
             
                    pipeline.transformer.fuse_qkv_projections()
         | 
| 50 |  | 
| 51 | 
            +
                    quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
         | 
| 52 | 
            +
                    
         | 
| 53 | 
             
                    exported = torch.export.export(
         | 
| 54 | 
             
                        mod=pipeline.transformer,
         | 
| 55 | 
             
                        args=call.args,
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
|  | |
| 1 | 
             
            transformers
         | 
| 2 | 
             
            git+https://github.com/huggingface/diffusers.git
         | 
| 3 | 
             
            accelerate
         | 
|  | |
| 1 | 
            +
            torchao
         | 
| 2 | 
             
            transformers
         | 
| 3 | 
             
            git+https://github.com/huggingface/diffusers.git
         | 
| 4 | 
             
            accelerate
         | 
 
			
