rahul7star commited on
Commit
ac8e9ac
·
verified ·
1 Parent(s): 5f734d7

Lora testing

Browse files
Files changed (1) hide show
  1. app.py +80 -113
app.py CHANGED
@@ -13,7 +13,8 @@ from lycoris import create_lycoris_from_weights
13
  # Define model options
14
  MODEL_OPTIONS = {
15
  "Wan2.1-T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
16
- "Wan2.1-T2V-14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 
17
  }
18
 
19
  # Define scheduler options
@@ -23,10 +24,6 @@ SCHEDULER_OPTIONS = {
23
  }
24
 
25
  def download_adapter(repo_id, weight_name=None):
26
- """
27
- Download the adapter file from the Hugging Face Hub.
28
- If weight_name is not provided, attempts to use pytorch_lora_weights.safetensors
29
- """
30
  adapter_filename = weight_name if weight_name else "pytorch_lora_weights.safetensors"
31
  cache_dir = os.environ.get('HF_PATH', os.path.expanduser('~/.cache/huggingface/hub/models'))
32
  cleaned_adapter_path = repo_id.replace("/", "_").replace("\\", "_").replace(":", "_")
@@ -41,7 +38,6 @@ def download_adapter(repo_id, weight_name=None):
41
  )
42
  return path_to_adapter_file
43
  except Exception as e:
44
- # If specific file not found, try to get a list of available safetensors files
45
  if weight_name is None:
46
  raise ValueError(f"Could not download default adapter file: {str(e)}\nPlease specify the exact weight file name.")
47
  else:
@@ -65,55 +61,41 @@ def generate_video(
65
  output_fps,
66
  seed
67
  ):
68
- # Get model ID from selection
69
  model_id = MODEL_OPTIONS[model_choice]
70
 
71
- # Set seed for reproducibility
72
  if seed == -1 or seed is None or seed == "":
73
  seed = random.randint(0, 2147483647)
74
  else:
75
  seed = int(seed)
76
 
77
- # Set the seed
78
  torch.manual_seed(seed)
79
 
80
- # Load model
81
  vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
82
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float16)
83
 
84
- # Set scheduler
85
  if scheduler_type == "UniPCMultistepScheduler":
86
- pipe.scheduler = UniPCMultistepScheduler.from_config(
87
- pipe.scheduler.config,
88
- flow_shift=flow_shift
89
- )
90
  else:
91
  pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
92
 
93
- # Move to GPU
94
  pipe.to("cuda")
95
 
96
- # Load LyCORIS weights if provided
97
  if lycoris_id and lycoris_id.strip():
98
  try:
99
- # Download the adapter file
100
- adapter_file_path = download_adapter(repo_id=lycoris_id, weight_name=lycoris_weight_name if lycoris_weight_name and lycoris_weight_name.strip() else None)
101
-
102
- # Apply LyCORIS adapter
103
  wrapper, *_ = create_lycoris_from_weights(lycoris_scale, adapter_file_path, pipe.transformer)
104
  wrapper.merge_to()
105
-
106
  except ValueError as e:
107
- # Return informative error if there are issues loading the adapter
108
  if "more than one weights file" in str(e) or "Could not download default adapter file" in str(e):
109
- return f"Error: The repository '{lycoris_id}' may contain multiple weight files. Please specify a weight name using the 'LyCORIS Weight Name' field.", seed
110
  else:
111
  return f"Error loading LyCORIS weights: {str(e)}", seed
112
-
113
- # Enable CPU offload for low VRAM
114
  pipe.enable_model_cpu_offload()
115
 
116
- # Generate video
117
  output = pipe(
118
  prompt=prompt,
119
  negative_prompt=negative_prompt,
@@ -125,7 +107,6 @@ def generate_video(
125
  generator=torch.Generator("cuda").manual_seed(seed)
126
  ).frames[0]
127
 
128
- # Export to video
129
  temp_file = "output.mp4"
130
  export_to_video(output, temp_file, fps=output_fps)
131
 
@@ -134,116 +115,103 @@ def generate_video(
134
  # Create the Gradio interface
135
  with gr.Blocks() as demo:
136
 
137
- gr.Markdown("# Wan 2.1 T2V")
138
-
139
  with gr.Row():
140
  with gr.Column(scale=1):
141
  model_choice = gr.Dropdown(
142
  choices=list(MODEL_OPTIONS.keys()),
143
- value="Wan2.1-T2V-1.3B",
144
  label="Model"
145
  )
146
 
147
- prompt = gr.Textbox(
148
- label="Prompt",
149
- value="",
150
- lines=3
151
- )
152
-
153
  negative_prompt = gr.Textbox(
154
  label="Negative Prompt",
155
  value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体��灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿",
156
  lines=3
157
  )
158
 
159
- with gr.Row():
160
- lycoris_id = gr.Textbox(
161
- label="Adapter Repo (e.g., markury/wan-st)",
162
- value="markury/wan-st"
163
- )
164
 
165
  with gr.Row():
166
  lycoris_weight_name = gr.Textbox(
167
- label="Adapter Path in Repo",
168
- value="pytorch_lora_weights.safetensors",
169
- info="Specify for repos with multiple .safetensors files, e.g.: adapter_model.safetensors, pytorch_lora_weights.safetensors, etc."
170
  )
171
  lycoris_scale = gr.Slider(
172
  label="Adapter Scale",
173
  minimum=0.0,
174
  maximum=2.0,
175
- value=1.00,
176
  step=0.05
177
  )
178
 
179
- with gr.Row():
180
- scheduler_type = gr.Dropdown(
181
- choices=list(SCHEDULER_OPTIONS.keys()),
182
- value="UniPCMultistepScheduler",
183
- label="Scheduler"
184
- )
185
- flow_shift = gr.Slider(
186
- label="Flow Shift",
187
- minimum=1.0,
188
- maximum=12.0,
189
- value=3.0,
190
- step=0.5,
191
- info="2.0-5.0 for smaller videos, 7.0-12.0 for larger videos"
192
- )
193
 
194
- with gr.Row():
195
- height = gr.Slider(
196
- label="Height",
197
- minimum=256,
198
- maximum=1024,
199
- value=832,
200
- step=32
201
- )
202
- width = gr.Slider(
203
- label="Width",
204
- minimum=256,
205
- maximum=1792,
206
- value=480,
207
- step=30
208
- )
209
 
210
- with gr.Row():
211
- num_frames = gr.Slider(
212
- label="Number of Frames (4k+1 is recommended, e.g. 33)",
213
- minimum=17,
214
- maximum=129,
215
- value=33,
216
- step=4
217
- )
218
- output_fps = gr.Slider(
219
- label="Output FPS",
220
- minimum=8,
221
- maximum=30,
222
- value=16,
223
- step=1
224
- )
225
 
226
- with gr.Row():
227
- guidance_scale = gr.Slider(
228
- label="Guidance Scale (CFG)",
229
- minimum=1.0,
230
- maximum=15.0,
231
- value=4.0,
232
- step=0.5
233
- )
234
- num_inference_steps = gr.Slider(
235
- label="Inference Steps",
236
- minimum=10,
237
- maximum=100,
238
- value=20,
239
- step=1
240
- )
241
 
242
  seed = gr.Number(
243
  label="Seed (-1 for random)",
244
  value=-1,
245
- precision=0,
246
- info="Set a specific seed for deterministic results"
247
  )
248
 
249
  generate_btn = gr.Button("Generate Video")
@@ -276,11 +244,10 @@ with gr.Blocks() as demo:
276
 
277
  gr.Markdown("""
278
  ## Tips for best results:
279
- - For smaller resolution videos, try lower values of flow shift (2.0-5.0)
280
- - For larger resolution videos, try higher values of flow shift (7.0-12.0)
281
- - Number of frames should be of the form 4k+1 (e.g., 33, 81)
282
- - Stick to lower frame counts. Even at 480p, an 81 frame sequence at 30 steps will nearly time out the request in this ZeroGPU space.
283
-
284
  """)
285
 
286
- demo.launch()
 
13
  # Define model options
14
  MODEL_OPTIONS = {
15
  "Wan2.1-T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
16
+ "Wan2.1-T2V-14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
17
+ "Wan2.1-Fun-Reward-1.3B": "alibaba-pai/Wan2.1-Fun-Reward-LoRAs"
18
  }
19
 
20
  # Define scheduler options
 
24
  }
25
 
26
  def download_adapter(repo_id, weight_name=None):
 
 
 
 
27
  adapter_filename = weight_name if weight_name else "pytorch_lora_weights.safetensors"
28
  cache_dir = os.environ.get('HF_PATH', os.path.expanduser('~/.cache/huggingface/hub/models'))
29
  cleaned_adapter_path = repo_id.replace("/", "_").replace("\\", "_").replace(":", "_")
 
38
  )
39
  return path_to_adapter_file
40
  except Exception as e:
 
41
  if weight_name is None:
42
  raise ValueError(f"Could not download default adapter file: {str(e)}\nPlease specify the exact weight file name.")
43
  else:
 
61
  output_fps,
62
  seed
63
  ):
 
64
  model_id = MODEL_OPTIONS[model_choice]
65
 
 
66
  if seed == -1 or seed is None or seed == "":
67
  seed = random.randint(0, 2147483647)
68
  else:
69
  seed = int(seed)
70
 
 
71
  torch.manual_seed(seed)
72
 
 
73
  vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
74
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float16)
75
 
 
76
  if scheduler_type == "UniPCMultistepScheduler":
77
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
 
 
 
78
  else:
79
  pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
80
 
 
81
  pipe.to("cuda")
82
 
 
83
  if lycoris_id and lycoris_id.strip():
84
  try:
85
+ adapter_file_path = download_adapter(
86
+ repo_id=lycoris_id,
87
+ weight_name=lycoris_weight_name if lycoris_weight_name and lycoris_weight_name.strip() else None
88
+ )
89
  wrapper, *_ = create_lycoris_from_weights(lycoris_scale, adapter_file_path, pipe.transformer)
90
  wrapper.merge_to()
 
91
  except ValueError as e:
 
92
  if "more than one weights file" in str(e) or "Could not download default adapter file" in str(e):
93
+ return f"Error: The repository '{lycoris_id}' may contain multiple weight files. Please specify a weight name.", seed
94
  else:
95
  return f"Error loading LyCORIS weights: {str(e)}", seed
96
+
 
97
  pipe.enable_model_cpu_offload()
98
 
 
99
  output = pipe(
100
  prompt=prompt,
101
  negative_prompt=negative_prompt,
 
107
  generator=torch.Generator("cuda").manual_seed(seed)
108
  ).frames[0]
109
 
 
110
  temp_file = "output.mp4"
111
  export_to_video(output, temp_file, fps=output_fps)
112
 
 
115
  # Create the Gradio interface
116
  with gr.Blocks() as demo:
117
 
118
+ gr.Markdown("# Wan 2.1 T2V with Custom LoRA")
119
+
120
  with gr.Row():
121
  with gr.Column(scale=1):
122
  model_choice = gr.Dropdown(
123
  choices=list(MODEL_OPTIONS.keys()),
124
+ value="Wan2.1-Fun-Reward-1.3B",
125
  label="Model"
126
  )
127
 
128
+ prompt = gr.Textbox(label="Prompt", value="", lines=3)
 
 
 
 
 
129
  negative_prompt = gr.Textbox(
130
  label="Negative Prompt",
131
  value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体��灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿",
132
  lines=3
133
  )
134
 
135
+ lycoris_id = gr.Textbox(
136
+ label="Adapter Repo",
137
+ value="alibaba-pai/Wan2.1-Fun-Reward-LoRAs"
138
+ )
 
139
 
140
  with gr.Row():
141
  lycoris_weight_name = gr.Textbox(
142
+ label="Adapter File Name",
143
+ value="Wan2.1-Fun-1.3B-InP-HPS2.1.safetensors"
 
144
  )
145
  lycoris_scale = gr.Slider(
146
  label="Adapter Scale",
147
  minimum=0.0,
148
  maximum=2.0,
149
+ value=1.0,
150
  step=0.05
151
  )
152
 
153
+ scheduler_type = gr.Dropdown(
154
+ choices=list(SCHEDULER_OPTIONS.keys()),
155
+ value="UniPCMultistepScheduler",
156
+ label="Scheduler"
157
+ )
158
+ flow_shift = gr.Slider(
159
+ label="Flow Shift",
160
+ minimum=1.0,
161
+ maximum=12.0,
162
+ value=3.0,
163
+ step=0.5
164
+ )
 
 
165
 
166
+ height = gr.Slider(
167
+ label="Height",
168
+ minimum=256,
169
+ maximum=1024,
170
+ value=832,
171
+ step=32
172
+ )
173
+ width = gr.Slider(
174
+ label="Width",
175
+ minimum=256,
176
+ maximum=1792,
177
+ value=480,
178
+ step=30
179
+ )
 
180
 
181
+ num_frames = gr.Slider(
182
+ label="Number of Frames",
183
+ minimum=17,
184
+ maximum=129,
185
+ value=33,
186
+ step=4
187
+ )
188
+ output_fps = gr.Slider(
189
+ label="Output FPS",
190
+ minimum=8,
191
+ maximum=30,
192
+ value=16,
193
+ step=1
194
+ )
 
195
 
196
+ guidance_scale = gr.Slider(
197
+ label="Guidance Scale (CFG)",
198
+ minimum=1.0,
199
+ maximum=15.0,
200
+ value=4.0,
201
+ step=0.5
202
+ )
203
+ num_inference_steps = gr.Slider(
204
+ label="Inference Steps",
205
+ minimum=10,
206
+ maximum=100,
207
+ value=20,
208
+ step=1
209
+ )
 
210
 
211
  seed = gr.Number(
212
  label="Seed (-1 for random)",
213
  value=-1,
214
+ precision=0
 
215
  )
216
 
217
  generate_btn = gr.Button("Generate Video")
 
244
 
245
  gr.Markdown("""
246
  ## Tips for best results:
247
+ - Smaller videos: Flow shift 2.05.0
248
+ - Larger videos: Flow shift 7.012.0
249
+ - Use frame count in 4k+1 form (e.g., 33, 65)
250
+ - Limit frame count and resolution to avoid timeout
 
251
  """)
252
 
253
+ demo.launch()