Kevin Hu commited on
Commit
da58b16
·
1 Parent(s): 0a9da14

accelerate tokenize (#3244)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Performance Improvement

Files changed (1) hide show
  1. rag/nlp/rag_tokenizer.py +40 -25
rag/nlp/rag_tokenizer.py CHANGED
@@ -281,34 +281,49 @@ class RagTokenizer:
281
  print("[FW]", tks, s)
282
  print("[BW]", tks1, s1)
283
 
284
- diff = [0 for _ in range(max(len(tks1), len(tks)))]
285
- for i in range(min(len(tks1), len(tks))):
286
- if tks[i] != tks1[i]:
287
- diff[i] = 1
288
-
289
- if s1 > s:
290
- tks = tks1
291
-
292
- i = 0
293
- while i < len(tks):
294
- s = i
295
- while s < len(tks) and diff[s] == 0:
296
- s += 1
297
- if s == len(tks):
298
- res.append(" ".join(tks[i:]))
299
- break
300
- if s > i:
301
- res.append(" ".join(tks[i:s]))
302
-
303
- e = s
304
- while e < len(tks) and e - s < 5 and diff[e] == 1:
305
- e += 1
306
-
 
307
  tkslist = []
308
- self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
309
  res.append(" ".join(self.sortTks_(tkslist)[0][0]))
310
 
311
- i = e + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  res = " ".join(self.english_normalize_(res))
314
  if self.DEBUG:
 
281
  print("[FW]", tks, s)
282
  print("[BW]", tks1, s1)
283
 
284
+ i, j, _i, _j = 0, 0, 0, 0
285
+ same = 0
286
+ while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
287
+ same += 1
288
+ if same > 0: res.append(" ".join(tks[j: j + same]))
289
+ _i = i + same
290
+ _j = j + same
291
+ j = _j + 1
292
+ i = _i + 1
293
+
294
+ while i < len(tks1) and j < len(tks):
295
+ tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
296
+ if tk1 != tk:
297
+ if len(tk1) > len(tk):
298
+ j += 1
299
+ else:
300
+ i += 1
301
+ continue
302
+
303
+ if tks1[i] != tks[j]:
304
+ i += 1
305
+ j += 1
306
+ continue
307
+ # backward tokens from_i to i are different from forward tokens from _j to j.
308
  tkslist = []
309
+ self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
310
  res.append(" ".join(self.sortTks_(tkslist)[0][0]))
311
 
312
+ same = 1
313
+ while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
314
+ same += 1
315
+ res.append(" ".join(tks[j: j + same]))
316
+ _i = i + same
317
+ _j = j + same
318
+ j = _j + 1
319
+ i = _i + 1
320
+
321
+ if _i < len(tks1):
322
+ assert _j < len(tks)
323
+ assert "".join(tks1[_i:]) == "".join(tks[_j:])
324
+ tkslist = []
325
+ self.dfs_("".join(tks[_j:]), 0, [], tkslist)
326
+ res.append(" ".join(self.sortTks_(tkslist)[0][0]))
327
 
328
  res = " ".join(self.english_normalize_(res))
329
  if self.DEBUG: