blanchon committed on
Commit 763c766 · 1 Parent(s): ad8bea2

Add spaces

Files changed (1)
  1. demo/app_januspro.py +101 -83
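The commit enables the Hugging Face `spaces` package and the `@spaces.GPU(...)` decorators, which let a ZeroGPU Space attach a GPU only while a decorated call runs; the remaining changes are import cleanup and code reformatting. A minimal sketch of that pattern, assuming a ZeroGPU Space with the `spaces` package available (`run_demo` is a hypothetical placeholder, not part of app_januspro.py):

import spaces  # Hugging Face ZeroGPU helper
import torch


@spaces.GPU(duration=120)  # request a GPU for up to ~120 s per call
def run_demo(prompt: str) -> str:
    # CUDA-dependent work goes inside the decorated function,
    # since the GPU is only attached while it runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"[{device}] {prompt}"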
demo/app_januspro.py CHANGED
@@ -1,24 +1,19 @@
 import gradio as gr
+import numpy as np
+import spaces  # Import spaces for ZeroGPU compatibility
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM
-from janus.models import MultiModalityCausalLM, VLChatProcessor
-from janus.utils.io import load_pil_images
+from janus.models import VLChatProcessor
 from PIL import Image
-
-import numpy as np
-import os
-import time
-# import spaces # Import spaces for ZeroGPU compatibility
-
+from transformers import AutoConfig, AutoModelForCausalLM

 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
+language_config._attn_implementation = "eager"
+vl_gpt = AutoModelForCausalLM.from_pretrained(
+    model_path, language_config=language_config, trust_remote_code=True
+)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
@@ -26,20 +21,21 @@ else:

 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = "cuda" if torch.cuda.is_available() else "cpu"
+

 @torch.inference_mode()
-# @spaces.GPU(duration=120)
+@spaces.GPU(duration=120)
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
-
+
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)
-
+
     conversation = [
         {
             "role": "<|User|>",
@@ -48,15 +44,17 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         },
         {"role": "<|Assistant|>", "content": ""},
     ]
-
+
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
-
+    ).to(
+        cuda_device,
+        dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
+    )
+
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
+
     outputs = vl_gpt.language_model.generate(
         inputs_embeds=inputs_embeds,
         attention_mask=prepare_inputs.attention_mask,
@@ -69,36 +67,42 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         temperature=temperature,
         top_p=top_p,
     )
-
+
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
     return answer


-def generate(input_ids,
-             width,
-             height,
-             temperature: float = 1,
-             parallel_size: int = 5,
-             cfg_weight: float = 5,
-             image_token_num_per_image: int = 576,
-             patch_size: int = 16):
+def generate(
+    input_ids,
+    width,
+    height,
+    temperature: float = 1,
+    parallel_size: int = 5,
+    cfg_weight: float = 5,
+    image_token_num_per_image: int = 576,
+    patch_size: int = 16,
+):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
-
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(
+        cuda_device
+    )
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    generated_tokens = torch.zeros(
+        (parallel_size, image_token_num_per_image), dtype=torch.int
+    ).to(cuda_device)

     pkv = None
     for i in range(image_token_num_per_image):
         with torch.no_grad():
-            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                                                  use_cache=True,
-                                                  past_key_values=pkv)
+            outputs = vl_gpt.language_model.model(
+                inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv
+            )
             pkv = outputs.past_key_values
             hidden_states = outputs.last_hidden_state
             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
@@ -108,18 +112,21 @@ def generate(input_ids,
         probs = torch.softmax(logits / temperature, dim=-1)
         next_token = torch.multinomial(probs, num_samples=1)
         generated_tokens[:, i] = next_token.squeeze(dim=-1)
-        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+        next_token = torch.cat(
+            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
+        ).view(-1)

         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
         inputs_embeds = img_embeds.unsqueeze(dim=1)

-
-
-    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
+    patches = vl_gpt.gen_vision_model.decode_code(
+        generated_tokens.to(dtype=torch.int),
+        shape=[parallel_size, 8, width // patch_size, height // patch_size],
+    )

     return generated_tokens.to(dtype=torch.int), patches

+
 def unpack(dec, width, height, parallel_size=5):
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
@@ -130,13 +137,9 @@ def unpack(dec, width, height, parallel_size=5):
     return visual_img


-
 @torch.inference_mode()
-# @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0):
+@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
+def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -147,29 +150,37 @@ def generate_image(prompt,
     width = 384
     height = 384
     parallel_size = 5
-
+
     with torch.no_grad():
-        messages = [{'role': '<|User|>', 'content': prompt},
-                    {'role': '<|Assistant|>', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                           sft_format=vl_chat_processor.sft_format,
-                                                                           system_prompt='')
+        messages = [
+            {"role": "<|User|>", "content": prompt},
+            {"role": "<|Assistant|>", "content": ""},
+        ]
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=messages,
+            sft_format=vl_chat_processor.sft_format,
+            system_prompt="",
+        )
         text = text + vl_chat_processor.image_start_tag
-
+
         input_ids = torch.LongTensor(tokenizer.encode(text))
-        output, patches = generate(input_ids,
-                                   width // 16 * 16,
-                                   height // 16 * 16,
-                                   cfg_weight=guidance,
-                                   parallel_size=parallel_size,
-                                   temperature=t2i_temperature)
-        images = unpack(patches,
-                        width // 16 * 16,
-                        height // 16 * 16,
-                        parallel_size=parallel_size)
-
-        return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
-
+        output, patches = generate(
+            input_ids,
+            width // 16 * 16,
+            height // 16 * 16,
+            cfg_weight=guidance,
+            parallel_size=parallel_size,
+            temperature=t2i_temperature,
+        )
+        images = unpack(
+            patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size
+        )
+
+        return [
+            Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS)
+            for i in range(parallel_size)
+        ]
+

 # Gradio interface
 with gr.Blocks() as demo:
@@ -179,9 +190,13 @@ with gr.Blocks() as demo:
         with gr.Column():
             question_input = gr.Textbox(label="Question")
             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
+            top_p = gr.Slider(
+                minimum=0, maximum=1, value=0.95, step=0.05, label="top_p"
+            )
+            temperature = gr.Slider(
+                minimum=0, maximum=1, value=0.1, step=0.05, label="temperature"
+            )
+
             understanding_button = gr.Button("Chat")
             understanding_output = gr.Textbox(label="Response")

@@ -199,17 +214,20 @@ with gr.Blocks() as demo:
         ],
         inputs=[question_input, image_input],
     )
-
-
+
     gr.Markdown(value="# Text-to-Image Generation")

-
-
     with gr.Row():
-        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+        cfg_weight_input = gr.Slider(
+            minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight"
+        )
+        t2i_temperature = gr.Slider(
+            minimum=0, maximum=1, value=1.0, step=0.05, label="temperature"
+        )

-    prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
+    prompt_input = gr.Textbox(
+        label="Prompt. (Prompt in more detail can help produce better images!)"
+    )
     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

     generation_button = gr.Button("Generate Images")
@@ -228,18 +246,18 @@ with gr.Blocks() as demo:
         ],
         inputs=prompt_input,
     )
-
+
     understanding_button.click(
         multimodal_understanding,
         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
-        outputs=understanding_output
+        outputs=understanding_output,
     )
-
+
     generation_button.click(
         fn=generate_image,
         inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
-        outputs=image_output
+        outputs=image_output,
     )

demo.launch(share=True)
-# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
+# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")