# quin-general/models/sparse_retriever.py
import hashlib
import logging
import math
import re
from multiprocessing import Pool
from typing import List

from nltk import WordNetLemmatizer, pos_tag
from nltk.corpus import stopwords as nltk_stopwords
from nltk.corpus import wordnet
from tqdm import tqdm

# Requires the NLTK data packages 'stopwords', 'wordnet' and 'averaged_perceptron_tagger'
# (download them once with nltk.download(...)).
lemmatizer = WordNetLemmatizer()
stopwords = set(nltk_stopwords.words('english'))


def get_ngrams(text_tokens: List[str], min_length=1, max_length=4) -> List[str]:
    """
    Gets word-level ngrams from a list of tokens.
    :param text_tokens: the tokens used to generate ngrams
    :param min_length: the minimum length of the generated ngrams, in words
    :param max_length: the maximum length of the generated ngrams, in words
    :return: list of ngrams (strings)
    """
max_length = min(max_length, len(text_tokens))
all_ngrams = []
for n in range(min_length - 1, max_length):
ngrams = [" ".join(ngram) for ngram in zip(*[text_tokens[i:] for i in range(n + 1)])]
        for ngram in ngrams:
            # drop ngrams containing the '#' placeholder that tokenize() uses for stopwords
            if '#' not in ngram:
                all_ngrams.append(ngram)
return all_ngrams
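# For illustration (a sketch of the expected output, not part of the original module):
#   get_ngrams(['fast', 'sparse', 'retrieval'], max_length=2)
#   -> ['fast', 'sparse', 'retrieval', 'fast sparse', 'sparse retrieval']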


def clean_text(text):
    """
    Removes non-alphanumeric symbols from text, mapping hyphens to spaces
    and collapsing repeated whitespace.
    :param text: the input string
    :return: the cleaned text
    """
text = text.replace('-', ' ')
text = re.sub('[^a-zA-Z0-9 ]+', '', text)
text = re.sub(' +', ' ', text)
return text
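# For illustration (not part of the original module):
#   clean_text('state-of-the-art (SOTA)!') -> 'state of the art SOTA'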


def string_hash(string):
    """
    Returns a deterministic hash value for a string.
    :param string: the string to hash
    :return: an integer hash derived from the string's MD5 digest
    """
return int(hashlib.md5(string.encode('utf8')).hexdigest(), 16)
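# Unlike the built-in hash(), the MD5-based hash above is stable across interpreter runs
# and worker processes (it is not affected by hash randomisation), so the ngram-to-bucket
# mapping in the inverted index stays reproducible.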


def tokenize(text, lemmatize=True, ngrams_length=2):
    """
    Cleans, lowercases, optionally lemmatizes, and converts text into ngrams.
    Stopwords (from the module-level set) are replaced with a '#' placeholder,
    which get_ngrams later filters out.
    :param text: the input string
    :param lemmatize: whether to lemmatize tokens using their POS tags
    :param ngrams_length: the maximum number of tokens per ngram
    :return: list of ngram strings
    """
tokens = clean_text(text).lower().split(' ')
tokens = [t for t in tokens if t != '']
if lemmatize:
lemmatized_words = []
pos_labels = pos_tag(tokens)
pos_labels = [pos[1][0].lower() for pos in pos_labels]
for i, word in enumerate(tokens):
if word in stopwords:
word = '#'
if pos_labels[i] == 'j':
pos_labels[i] = 'a' # 'j' <--> 'a' reassignment
if pos_labels[i] in ['r']: # For adverbs it's a bit different
            try:
                # adverbs: fall back to the pertainym of the first adverb sense, e.g. 'quickly' -> 'quick'
                lemma = wordnet.synset(word + '.r.1').lemmas()[0].pertainyms()[0].name()
            except Exception:
                lemma = word
lemmatized_words.append(lemma)
elif pos_labels[i] in ['a', 's', 'v']: # For adjectives and verbs
lemmatized_words.append(lemmatizer.lemmatize(word, pos=pos_labels[i]))
else: # For nouns and everything else as it is the default kwarg
lemmatized_words.append(lemmatizer.lemmatize(word))
tokens = lemmatized_words
ngrams = get_ngrams(tokens, max_length=ngrams_length)
return ngrams
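# For illustration (approximate output, assuming the NLTK resources above are installed):
#   tokenize('The dogs were running fast')
#   -> ['dog', 'run', 'fast', 'run fast']
# ('the' and 'were' are stopwords, so they become '#' and are excluded from the ngrams)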


class SparseRetriever:
    """
    BM25-based sparse retriever over a hashed ngram vocabulary and an in-memory inverted index.
    """

    def __init__(self, ngram_buckets=16777216, k1=1.5, b=0.75, epsilon=0.25,
                 max_relative_freq=0.5, workers=4):
        self.ngram_buckets = ngram_buckets  # ngrams are hashed into this many buckets
        self.k1 = k1  # BM25 term-frequency saturation parameter
        self.b = b  # BM25 length-normalisation parameter
        self.epsilon = epsilon  # floor for negative IDF values, as a fraction of the average IDF
        self.max_relative_freq = max_relative_freq  # ignore terms occurring in more than this fraction of documents
        self.corpus_size = 0
        self.avgdl = 0  # average document length, in ngrams
        self.inverted_index = {}  # hashed ngram -> list of (doc_id, term frequency)
        self.idf = {}  # hashed ngram -> inverse document frequency
        self.doc_len = []
        self.workers = workers  # number of processes used for tokenization

    def index_documents(self, documents):
        """
        Tokenizes the documents in parallel, then builds the inverted index and the IDF table.
        :param documents: list of document strings
        """
with Pool(self.workers) as p:
tokenized_documents = list(tqdm(p.imap(tokenize, documents), total=len(documents), desc='tokenized'))
logging.info('Building inverted index...')
self.corpus_size = len(tokenized_documents)
nd = self._create_inverted_index(tokenized_documents)
self._calc_idf(nd)
logging.info('Built inverted index')

    def _create_inverted_index(self, documents):
nd = {} # word -> number of documents with word
num_doc = 0
for doc_id, document in enumerate(tqdm(documents, desc='indexed')):
self.doc_len.append(len(document))
num_doc += len(document)
frequencies = {}
for word in document:
if word not in frequencies:
frequencies[word] = 0
frequencies[word] += 1
            for word, freq in frequencies.items():
                hashed_word = string_hash(word) % self.ngram_buckets
                # postings list: (document id, term frequency within the document)
                if hashed_word not in self.inverted_index:
                    self.inverted_index[hashed_word] = [(doc_id, freq)]
                else:
                    self.inverted_index[hashed_word].append((doc_id, freq))
                # document frequency of each hashed ngram
                if hashed_word not in nd:
                    nd[hashed_word] = 0
                nd[hashed_word] += 1
self.avgdl = num_doc / self.corpus_size
return nd

    def _calc_idf(self, nd):
        """
        Calculates the inverse document frequency of every hashed ngram in the corpus.
        Terms that appear in more than max_relative_freq of the documents are dropped,
        and negative IDF values are floored at epsilon * average_idf.
        """
# collect idf sum to calculate an average idf for epsilon value
idf_sum = 0
# collect words with negative idf to set them a special epsilon value.
# idf can be negative if word is contained in more than half of documents
negative_idfs = []
for word, freq in nd.items():
idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
r_freq = float(freq) / self.corpus_size
if r_freq > self.max_relative_freq:
continue
self.idf[word] = idf
idf_sum += idf
if idf < 0:
negative_idfs.append(word)
self.average_idf = idf_sum / len(self.idf)
eps = self.epsilon * self.average_idf
for word in negative_idfs:
self.idf[word] = eps

    def _get_scores(self, query):
        """
        The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores,
        this algorithm also floors negative idf values at epsilon * average_idf.
        See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info.
        :param query: the query string
        :return: list of (doc_id, BM25 score) pairs, sorted by descending score
        """
query = tokenize(query)
scores = {}
for q in query:
hashed_word = string_hash(q) % self.ngram_buckets
            idf = self.idf.get(hashed_word)
            if idf is not None:
                doc_freqs = self.inverted_index[hashed_word]
                for doc_id, freq in doc_freqs:
                    # BM25: idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avgdl))
                    score = idf * (freq * (self.k1 + 1) /
                                   (freq + self.k1 * (1 - self.b + self.b * self.doc_len[doc_id] / self.avgdl)))
if doc_id in scores:
scores[doc_id] += score
else:
scores[doc_id] = score
return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    def search(self, queries, topk=100):
        """
        Scores every query against the indexed corpus.
        :param queries: list of query strings
        :param topk: number of top-scoring documents to keep per query
        :return: for each query, a list of up to topk (doc_id, score) pairs sorted by descending score
        """
results = [self._get_scores(q) for q in tqdm(queries, desc='searched')]
results = [r[:topk] for r in results]
logging.info('Done searching')
return results
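

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module); the corpus and query below
    # are made up for illustration, and the NLTK data noted above must already be downloaded.
    logging.basicConfig(level=logging.INFO)
    corpus = [
        'Sparse retrieval ranks documents with BM25 over hashed ngrams.',
        'Dense retrieval encodes queries and documents into vectors.',
        'An inverted index maps terms to the documents that contain them.',
    ]
    retriever = SparseRetriever(workers=2)
    retriever.index_documents(corpus)
    results = retriever.search(['how does an inverted index work'], topk=2)
    for doc_id, score in results[0]:
        print(doc_id, round(score, 3), corpus[doc_id])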