zhijianma commited on
Commit
d135ca6
·
verified ·
1 Parent(s): d637696

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -155,7 +155,7 @@ def copy_func(file):
155
  return cache_file
156
 
157
 
158
- def encode_sample(input_text, input_image, input_video, input_audio):
159
  sample = dict()
160
  sample[image_key]= [input_image] if input_image else []
161
  sample[video_key]=[input_video] if input_video else []
@@ -169,10 +169,17 @@ def encode_sample(input_text, input_image, input_video, input_audio):
169
  input_text += SpecialTokens.audio
170
  sample[text_key]=input_text
171
 
 
 
 
172
  return sample
173
 
174
 
175
- def decode_sample(output_sample):
 
 
 
 
176
  output_text = remove_special_tokens(output_sample[text_key])
177
  output_image = output_sample[image_key][0] if output_sample[image_key] else None
178
  output_video = output_sample[video_key][0] if output_sample[video_key] else None
@@ -199,7 +206,7 @@ def create_tab_layout(op_tab, op_type, run_op, has_stats=False):
199
  with gr.Group('Inputs'):
200
  gr.Markdown(" **Inputs**")
201
  with gr.Row():
202
- input_text = gr.TextArea(label="Text",interactive=True,)
203
  input_image = gr.Image(label='Image', type='filepath', visible=multimodal_visible)
204
  input_video = gr.Video(label='Video', visible=multimodal_visible)
205
  input_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal_visible)
@@ -207,11 +214,11 @@ def create_tab_layout(op_tab, op_type, run_op, has_stats=False):
207
  with gr.Group('Outputs'):
208
  gr.Markdown(" **Outputs**")
209
  with gr.Row():
210
- output_text = gr.TextArea(label="Text",interactive=False,)
211
  output_image = gr.Image(label='Image', type='filepath', visible=multimodal_visible)
212
- output_video = gr.Video(label='Video', visible=multimodal_visible)
213
  output_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal_visible)
214
-
215
  with gr.Row():
216
  if has_stats:
217
  output_stats = gr.Json(label='Stats')
@@ -254,9 +261,10 @@ def create_mapper_tab(op_type, op_tab):
254
  def run_op(input_text, input_image, input_video, input_audio, op_name, op_params):
255
  op_class = OPERATORS.modules[op_name]
256
  op = op_class(**op_params)
257
- sample = encode_sample(input_text, input_image, input_video, input_audio)
 
258
  output_sample = op.process(copy.deepcopy(sample))
259
- return decode_sample(output_sample)
260
  create_tab_layout(op_tab, op_type, run_op)
261
 
262
 
 
155
  return cache_file
156
 
157
 
158
+ def encode_sample(input_text, input_image, input_video, input_audio, is_batched_op=False):
159
  sample = dict()
160
  sample[image_key]= [input_image] if input_image else []
161
  sample[video_key]=[input_video] if input_video else []
 
169
  input_text += SpecialTokens.audio
170
  sample[text_key]=input_text
171
 
172
+ if is_batched_op:
173
+ for k, v in sample.items():
174
+ sample[k] = [v]
175
  return sample
176
 
177
 
178
+ def decode_sample(output_sample, is_batched_op=False):
179
+ if is_batched_op:
180
+ for k, v in output_sample.items():
181
+ output_sample[k] = v[-1]
182
+
183
  output_text = remove_special_tokens(output_sample[text_key])
184
  output_image = output_sample[image_key][0] if output_sample[image_key] else None
185
  output_video = output_sample[video_key][0] if output_sample[video_key] else None
 
206
  with gr.Group('Inputs'):
207
  gr.Markdown(" **Inputs**")
208
  with gr.Row():
209
+ input_text = gr.TextArea(label="Text",interactive=True,scale=2)
210
  input_image = gr.Image(label='Image', type='filepath', visible=multimodal_visible)
211
  input_video = gr.Video(label='Video', visible=multimodal_visible)
212
  input_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal_visible)
 
214
  with gr.Group('Outputs'):
215
  gr.Markdown(" **Outputs**")
216
  with gr.Row():
217
+ output_text = gr.TextArea(label="Text",interactive=False,scale=2)
218
  output_image = gr.Image(label='Image', type='filepath', visible=multimodal_visible)
219
+ output_video = gr.Video(label='Video', visible=multimodal_visible,)
220
  output_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal_visible)
221
+
222
  with gr.Row():
223
  if has_stats:
224
  output_stats = gr.Json(label='Stats')
 
261
  def run_op(input_text, input_image, input_video, input_audio, op_name, op_params):
262
  op_class = OPERATORS.modules[op_name]
263
  op = op_class(**op_params)
264
+ is_batched_op = op.is_batched_op()
265
+ sample = encode_sample(input_text, input_image, input_video, input_audio, is_batched_op)
266
  output_sample = op.process(copy.deepcopy(sample))
267
+ return decode_sample(output_sample, is_batched_op)
268
  create_tab_layout(op_tab, op_type, run_op)
269
 
270