zhijianma committed on
Commit
d637696
·
verified ·
1 Parent(s): 4938d99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -11,6 +11,7 @@ from datasets import Dataset
11
 
12
  from data_juicer.ops.base_op import OPERATORS
13
  from data_juicer.utils.constant import Fields
 
14
 
15
  demo_path = os.path.dirname(os.path.abspath(__file__))
16
  project_path = os.path.dirname(os.path.dirname(demo_path))
@@ -156,15 +157,23 @@ def copy_func(file):
156
 
157
  def encode_sample(input_text, input_image, input_video, input_audio):
158
  sample = dict()
159
- sample[text_key]=input_text
160
  sample[image_key]= [input_image] if input_image else []
161
  sample[video_key]=[input_video] if input_video else []
162
  sample[audio_key]=[input_audio] if input_audio else []
 
 
 
 
 
 
 
 
 
163
  return sample
164
 
165
 
166
  def decode_sample(output_sample):
167
- output_text = output_sample[text_key]
168
  output_image = output_sample[image_key][0] if output_sample[image_key] else None
169
  output_video = output_sample[video_key][0] if output_sample[video_key] else None
170
  output_audio = output_sample[audio_key][0] if output_sample[audio_key] else None
 
11
 
12
  from data_juicer.ops.base_op import OPERATORS
13
  from data_juicer.utils.constant import Fields
14
+ from data_juicer.utils.mm_utils import SpecialTokens, remove_special_tokens
15
 
16
  demo_path = os.path.dirname(os.path.abspath(__file__))
17
  project_path = os.path.dirname(os.path.dirname(demo_path))
 
157
 
158
  def encode_sample(input_text, input_image, input_video, input_audio):
159
  sample = dict()
 
160
  sample[image_key]= [input_image] if input_image else []
161
  sample[video_key]=[input_video] if input_video else []
162
  sample[audio_key]=[input_audio] if input_audio else []
163
+
164
+ if input_image:
165
+ input_text += SpecialTokens.image
166
+ if input_video:
167
+ input_text += SpecialTokens.video
168
+ if input_audio:
169
+ input_text += SpecialTokens.audio
170
+ sample[text_key]=input_text
171
+
172
  return sample
173
 
174
 
175
  def decode_sample(output_sample):
176
+ output_text = remove_special_tokens(output_sample[text_key])
177
  output_image = output_sample[image_key][0] if output_sample[image_key] else None
178
  output_video = output_sample[video_key][0] if output_sample[video_key] else None
179
  output_audio = output_sample[audio_key][0] if output_sample[audio_key] else None