Commit
·
241b23e
1
Parent(s):
0a519de
Refactor trie load and construct (#4083)
Browse files
### What problem does this PR solve?
1. Fix the initial build and loading of the trie
2. Update comment
### Type of change
- [x] Refactoring
Signed-off-by: jinhai <[email protected]>
- rag/nlp/rag_tokenizer.py +23 -10
rag/nlp/rag_tokenizer.py
CHANGED
|
@@ -36,7 +36,7 @@ class RagTokenizer:
|
|
| 36 |
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
| 37 |
|
| 38 |
def loadDict_(self, fnm):
|
| 39 |
-
logging.info(f"[HUQIE]:Build trie {fnm}")
|
| 40 |
try:
|
| 41 |
of = open(fnm, "r", encoding='utf-8')
|
| 42 |
while True:
|
|
@@ -50,7 +50,10 @@ class RagTokenizer:
|
|
| 50 |
if k not in self.trie_ or self.trie_[k][0] < F:
|
| 51 |
self.trie_[self.key_(line[0])] = (F, line[2])
|
| 52 |
self.trie_[self.rkey_(line[0])] = 1
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
of.close()
|
| 55 |
except Exception:
|
| 56 |
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
|
@@ -58,20 +61,30 @@ class RagTokenizer:
|
|
| 58 |
def __init__(self, debug=False):
|
| 59 |
self.DEBUG = debug
|
| 60 |
self.DENOMINATOR = 1000000
|
| 61 |
-
self.trie_ = datrie.Trie(string.printable)
|
| 62 |
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
|
| 63 |
|
| 64 |
self.stemmer = PorterStemmer()
|
| 65 |
self.lemmatizer = WordNetLemmatizer()
|
| 66 |
|
| 67 |
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
self.trie_ = datrie.Trie(string.printable)
|
| 74 |
|
|
|
|
| 75 |
self.loadDict_(self.DIR_ + ".txt")
|
| 76 |
|
| 77 |
def loadUserDict(self, fnm):
|
|
@@ -86,7 +99,7 @@ class RagTokenizer:
|
|
| 86 |
self.loadDict_(fnm)
|
| 87 |
|
| 88 |
def _strQ2B(self, ustring):
|
| 89 |
-
"""
|
| 90 |
rstring = ""
|
| 91 |
for uchar in ustring:
|
| 92 |
inside_code = ord(uchar)
|
|
@@ -94,7 +107,7 @@ class RagTokenizer:
|
|
| 94 |
inside_code = 0x0020
|
| 95 |
else:
|
| 96 |
inside_code -= 0xfee0
|
| 97 |
-
if inside_code < 0x0020 or inside_code > 0x7e: #
|
| 98 |
rstring += uchar
|
| 99 |
else:
|
| 100 |
rstring += chr(inside_code)
|
|
|
|
| 36 |
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
| 37 |
|
| 38 |
def loadDict_(self, fnm):
|
| 39 |
+
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
| 40 |
try:
|
| 41 |
of = open(fnm, "r", encoding='utf-8')
|
| 42 |
while True:
|
|
|
|
| 50 |
if k not in self.trie_ or self.trie_[k][0] < F:
|
| 51 |
self.trie_[self.key_(line[0])] = (F, line[2])
|
| 52 |
self.trie_[self.rkey_(line[0])] = 1
|
| 53 |
+
|
| 54 |
+
dict_file_cache = fnm + ".trie"
|
| 55 |
+
logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
|
| 56 |
+
self.trie_.save(dict_file_cache)
|
| 57 |
of.close()
|
| 58 |
except Exception:
|
| 59 |
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
|
|
|
| 61 |
def __init__(self, debug=False):
|
| 62 |
self.DEBUG = debug
|
| 63 |
self.DENOMINATOR = 1000000
|
|
|
|
| 64 |
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
|
| 65 |
|
| 66 |
self.stemmer = PorterStemmer()
|
| 67 |
self.lemmatizer = WordNetLemmatizer()
|
| 68 |
|
| 69 |
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
|
| 70 |
+
|
| 71 |
+
trie_file_name = self.DIR_ + ".txt.trie"
|
| 72 |
+
# check if trie file existence
|
| 73 |
+
if os.path.exists(trie_file_name):
|
| 74 |
+
try:
|
| 75 |
+
# load trie from file
|
| 76 |
+
self.trie_ = datrie.Trie.load(trie_file_name)
|
| 77 |
+
return
|
| 78 |
+
except Exception:
|
| 79 |
+
# fail to load trie from file, build default trie
|
| 80 |
+
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
| 81 |
+
self.trie_ = datrie.Trie(string.printable)
|
| 82 |
+
else:
|
| 83 |
+
# file not exist, build default trie
|
| 84 |
+
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
| 85 |
self.trie_ = datrie.Trie(string.printable)
|
| 86 |
|
| 87 |
+
# load data from dict file and save to trie file
|
| 88 |
self.loadDict_(self.DIR_ + ".txt")
|
| 89 |
|
| 90 |
def loadUserDict(self, fnm):
|
|
|
|
| 99 |
self.loadDict_(fnm)
|
| 100 |
|
| 101 |
def _strQ2B(self, ustring):
|
| 102 |
+
"""Convert full-width characters to half-width characters"""
|
| 103 |
rstring = ""
|
| 104 |
for uchar in ustring:
|
| 105 |
inside_code = ord(uchar)
|
|
|
|
| 107 |
inside_code = 0x0020
|
| 108 |
else:
|
| 109 |
inside_code -= 0xfee0
|
| 110 |
+
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
|
| 111 |
rstring += uchar
|
| 112 |
else:
|
| 113 |
rstring += chr(inside_code)
|