aopstudio commited on
Commit
37e8b27
·
1 Parent(s): 4825b73

Add pdf support for QA parser (#1155)

Browse files

### What problem does this PR solve?

Support extracting questions and answers from PDF files

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

rag/app/qa.py CHANGED
@@ -13,13 +13,13 @@
13
  import re
14
  from copy import deepcopy
15
  from io import BytesIO
 
16
  from nltk import word_tokenize
17
  from openpyxl import load_workbook
18
- from rag.nlp import is_english, random_choices, find_codec
19
- from rag.nlp import rag_tokenizer
20
- from deepdoc.parser import ExcelParser
21
-
22
-
23
  class Excel(ExcelParser):
24
  def __call__(self, fnm, binary=None, callback=None):
25
  if not binary:
@@ -62,12 +62,80 @@ class Excel(ExcelParser):
62
  [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
63
  return res
64
 
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def rmPrefix(txt):
67
  return re.sub(
68
  r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  def beAdoc(d, q, a, eng):
72
  qprefix = "Question: " if eng else "问题:"
73
  aprefix = "Answer: " if eng else "回答:"
@@ -145,6 +213,19 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
145
  f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
146
 
147
  return res
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  raise NotImplementedError(
150
  "Excel and csv(txt) format files are supported.")
@@ -153,6 +234,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
153
  if __name__ == "__main__":
154
  import sys
155
 
156
- def dummy(a, b):
157
  pass
158
- chunk(sys.argv[1], callback=dummy)
 
 
 
13
  import re
14
  from copy import deepcopy
15
  from io import BytesIO
16
+ from timeit import default_timer as timer
17
  from nltk import word_tokenize
18
  from openpyxl import load_workbook
19
+ from rag.nlp import is_english, random_choices, find_codec, qbullets_category, add_positions, has_qbullet
20
+ from rag.nlp import rag_tokenizer, tokenize_table
21
+ from rag.settings import cron_logger
22
+ from deepdoc.parser import PdfParser, ExcelParser
 
23
  class Excel(ExcelParser):
24
  def __call__(self, fnm, binary=None, callback=None):
25
  if not binary:
 
62
  [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
63
  return res
64
 
65
+ class Pdf(PdfParser):
66
+ def __call__(self, filename, binary=None, from_page=0,
67
+ to_page=100000, zoomin=3, callback=None):
68
+ start = timer()
69
+ callback(msg="OCR is running...")
70
+ self.__images__(
71
+ filename if not binary else binary,
72
+ zoomin,
73
+ from_page,
74
+ to_page,
75
+ callback
76
+ )
77
+ callback(msg="OCR finished")
78
+ cron_logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start))
79
+ start = timer()
80
+ self._layouts_rec(zoomin, drop=False)
81
+ callback(0.63, "Layout analysis finished.")
82
+ self._table_transformer_job(zoomin)
83
+ callback(0.65, "Table analysis finished.")
84
+ self._text_merge()
85
+ callback(0.67, "Text merging finished")
86
+ tbls = self._extract_table_figure(True, zoomin, True, True)
87
+ #self._naive_vertical_merge()
88
+ # self._concat_downward()
89
+ #self._filter_forpages()
90
+ cron_logger.info("layouts: {}".format(timer() - start))
91
+ sections = [b["text"] for b in self.boxes]
92
+ bull_x0_list = []
93
+ q_bull, reg = qbullets_category(sections)
94
+ if q_bull == -1:
95
+ raise ValueError("Unable to recognize Q&A structure.")
96
+ qai_list = []
97
+ last_q, last_a, last_tag = '', '', ''
98
+ last_index = -1
99
+ last_box = {'text':''}
100
+ last_bull = None
101
+ for box in self.boxes:
102
+ section, line_tag = box['text'], self._line_tag(box, zoomin)
103
+ has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list)
104
+ last_box, last_index, last_bull = box, index, has_bull
105
+ if not has_bull: # No question bullet
106
+ if not last_q:
107
+ continue
108
+ else:
109
+ last_a = f'{last_a}{section}'
110
+ last_tag = f'{last_tag}{line_tag}'
111
+ else:
112
+ if last_q:
113
+ qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True)))
114
+ last_q, last_a, last_tag = '', '', ''
115
+ last_q = has_bull.group()
116
+ _, end = has_bull.span()
117
+ last_a = section[end:]
118
+ last_tag = line_tag
119
+ if last_q:
120
+ qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True)))
121
+ return qai_list, tbls
122
+
123
  def rmPrefix(txt):
124
  return re.sub(
125
  r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
126
 
127
 
128
+ def beAdocPdf(d, q, a, eng, image, poss):
129
+ qprefix = "Question: " if eng else "问题:"
130
+ aprefix = "Answer: " if eng else "回答:"
131
+ d["content_with_weight"] = "\t".join(
132
+ [qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
133
+ d["content_ltks"] = rag_tokenizer.tokenize(q)
134
+ d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
135
+ d["image"] = image
136
+ add_positions(d, poss)
137
+ return d
138
+
139
  def beAdoc(d, q, a, eng):
140
  qprefix = "Question: " if eng else "问题:"
141
  aprefix = "Answer: " if eng else "回答:"
 
213
  f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
214
 
215
  return res
216
+ elif re.search(r"\.pdf$", filename, re.IGNORECASE):
217
+ pdf_parser = Pdf()
218
+ count = 0
219
+ qai_list, tbls = pdf_parser(filename if not binary else binary,
220
+ from_page=0, to_page=10000, callback=callback)
221
+
222
+ res = tokenize_table(tbls, doc, eng)
223
+
224
+ for q, a, image, poss in qai_list:
225
+ count += 1
226
+ res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
227
+ return res
228
+
229
 
230
  raise NotImplementedError(
231
  "Excel and csv(txt) format files are supported.")
 
234
  if __name__ == "__main__":
235
  import sys
236
 
237
+ def dummy(prog=None, msg=""):
238
  pass
239
+ import json
240
+
241
+ chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
rag/nlp/__init__.py CHANGED
@@ -21,6 +21,9 @@ from rag.utils import num_tokens_from_string
21
  from . import rag_tokenizer
22
  import re
23
  import copy
 
 
 
24
 
25
  all_codecs = [
26
  'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
@@ -57,6 +60,95 @@ def find_codec(blob):
57
 
58
  return "utf-8"
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  BULLET_PATTERN = [[
62
  r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
 
21
  from . import rag_tokenizer
22
  import re
23
  import copy
24
+ import roman_numbers as r
25
+ from word2number import w2n
26
+ from cn2an import cn2an
27
 
28
  all_codecs = [
29
  'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
 
60
 
61
  return "utf-8"
62
 
63
+ QUESTION_PATTERN = [
64
+ r"第([零一二三四五六七八九十百0-9]+)问",
65
+ r"第([零一二三四五六七八九十百0-9]+)条",
66
+ r"[\((]([零一二三四五六七八九十百]+)[\))]",
67
+ r"第([0-9]+)问",
68
+ r"第([0-9]+)条",
69
+ r"([0-9]{1,2})[\. 、]",
70
+ r"([零一二三四五六七八九十百]+)[ 、]",
71
+ r"[\((]([0-9]{1,2})[\))]",
72
+ r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
73
+ r"QUESTION (I+V?|VI*|XI|IX|X)",
74
+ r"QUESTION ([0-9]+)",
75
+ ]
76
+
77
+ def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
78
+ section, last_section = box['text'], last_box['text']
79
+ q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+'
80
+ full_reg = reg + q_reg
81
+ has_bull = re.match(full_reg, section)
82
+ index_str = None
83
+ if has_bull:
84
+ if 'x0' not in last_box:
85
+ last_box['x0'] = box['x0']
86
+ if 'top' not in last_box:
87
+ last_box['top'] = box['top']
88
+ if last_bull and box['x0']-last_box['x0']>10:
89
+ return None, last_index
90
+ if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
91
+ return None, last_index
92
+ avg_bull_x0 = 0
93
+ if bull_x0_list:
94
+ avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
95
+ else:
96
+ avg_bull_x0 = box['x0']
97
+ if box['x0'] - avg_bull_x0 > 10:
98
+ return None, last_index
99
+ index_str = has_bull.group(1)
100
+ index = index_int(index_str)
101
+ if last_section[-1] == ':' or last_section[-1] == ':':
102
+ return None, last_index
103
+ if not last_index or index >= last_index:
104
+ bull_x0_list.append(box['x0'])
105
+ return has_bull, index
106
+ if section[-1] == '?' or section[-1] == '?':
107
+ bull_x0_list.append(box['x0'])
108
+ return has_bull, index
109
+ if box['layout_type'] == 'title':
110
+ bull_x0_list.append(box['x0'])
111
+ return has_bull, index
112
+ pure_section = section.lstrip(re.match(reg, section).group()).lower()
113
+ ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
114
+ if re.match(ask_reg, pure_section):
115
+ bull_x0_list.append(box['x0'])
116
+ return has_bull, index
117
+ return None, last_index
118
+
119
+ def index_int(index_str):
120
+ res = -1
121
+ try:
122
+ res=int(index_str)
123
+ except ValueError:
124
+ try:
125
+ res=w2n.word_to_num(index_str)
126
+ except ValueError:
127
+ try:
128
+ res = cn2an(index_str)
129
+ except ValueError:
130
+ try:
131
+ res = r.number(index_str)
132
+ except ValueError:
133
+ return -1
134
+ return res
135
+
136
+ def qbullets_category(sections):
137
+ global QUESTION_PATTERN
138
+ hits = [0] * len(QUESTION_PATTERN)
139
+ for i, pro in enumerate(QUESTION_PATTERN):
140
+ for sec in sections:
141
+ if re.match(pro, sec) and not not_bullet(sec):
142
+ hits[i] += 1
143
+ break
144
+ maxium = 0
145
+ res = -1
146
+ for i, h in enumerate(hits):
147
+ if h <= maxium:
148
+ continue
149
+ res = i
150
+ maxium = h
151
+ return res, QUESTION_PATTERN[res]
152
 
153
  BULLET_PATTERN = [[
154
  r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
requirements.txt CHANGED
@@ -141,3 +141,6 @@ readability-lxml==0.8.1
141
  html_text==0.6.2
142
  selenium==4.21.0
143
  webdriver-manager==4.0.1
 
 
 
 
141
  html_text==0.6.2
142
  selenium==4.21.0
143
  webdriver-manager==4.0.1
144
+ cn2an==0.5.22
145
+ roman-numbers==1.0.2
146
+ word2number==1.1
requirements_arm.txt CHANGED
@@ -139,4 +139,7 @@ fasttext==0.9.2
139
  volcengine==1.0.141
140
  opencv-python-headless==4.9.0.80
141
  readability-lxml==0.8.1
142
- html_text==0.6.2
 
 
 
 
139
  volcengine==1.0.141
140
  opencv-python-headless==4.9.0.80
141
  readability-lxml==0.8.1
142
+ html_text==0.6.2
143
+ cn2an==0.5.22
144
+ roman-numbers==1.0.2
145
+ word2number==1.1
requirements_dev.txt CHANGED
@@ -126,4 +126,7 @@ fasttext==0.9.2
126
  umap-learn
127
  volcengine==1.0.141
128
  readability-lxml==0.8.1
129
- html_text==0.6.2
 
 
 
 
126
  umap-learn
127
  volcengine==1.0.141
128
  readability-lxml==0.8.1
129
+ html_text==0.6.2
130
+ cn2an==0.5.22
131
+ roman-numbers==1.0.2
132
+ word2number==1.1