"""spaCy/WordNet-backed helpers for finding related words (synonyms, antonyms,
hypernyms, etc.) in context, with inflection matching via pyinflect."""
from datetime import datetime
from typing import Any, Optional

import nltk
import pyinflect  # noqa: F401  # importing pyinflect registers the token._.inflect extension used in inflect_synonym()
import spacy
from fastapi import HTTPException
from nltk.corpus.reader import Synset

from my_ghost_writer.constants import ELIGIBLE_POS, NLTK_DATA, SPACY_MODEL_NAME, app_logger
from my_ghost_writer.custom_synonym_handler import CustomSynonymHandler
from my_ghost_writer.thesaurus import wn
from my_ghost_writer.type_hints import ContextInfo, RelatedWordGroup, RelatedWordOption, RelatedWordWordResult, \
    TermRelationships

custom_synonym_handler = CustomSynonymHandler()

# Load the spaCy pipeline once at import time; nlp stays None when the model is
# missing, so request handlers can degrade gracefully (see is_nlp_available()).
nlp = None
try:
    nlp = spacy.load(SPACY_MODEL_NAME)
    app_logger.info(f"spacy model {SPACY_MODEL_NAME} has type: '{type(nlp)}'")
except (OSError, IOError) as io_ex:
    app_logger.error(io_ex)
    app_logger.error(
        f"spaCy model '{SPACY_MODEL_NAME}' not found. Please install it with: 'python -m spacy download {SPACY_MODEL_NAME}'"
    )

try:
    app_logger.info(f"Downloading NLTK data to the folder: '{NLTK_DATA}'")
    nltk.download('punkt_tab', quiet=False, download_dir=NLTK_DATA)
    nltk.download('wordnet', quiet=False, download_dir=NLTK_DATA)
    nltk.download('wordnet31', quiet=False, download_dir=NLTK_DATA)
except Exception as e:
    app_logger.error(f"Failed to download NLTK data: {e}")

def is_nlp_available() -> bool:
    """Check if the spaCy model is available"""
    return nlp is not None

def find_synonyms_for_phrase(text: str, start_idx: int, end_idx: int) -> list[RelatedWordWordResult]:
    """
    Finds related words for all eligible words within a selected text span.
    It analyzes the span, filters for meaningful words (nouns, verbs, etc.),
    and returns a list of related-word results for each.

    Args:
        text: The input text (str).
        start_idx: The start index of the phrase within the text (int).
        end_idx: The end index of the phrase within the text (int).

    Returns:
        A list of RelatedWordWordResult objects, one per eligible word (list[RelatedWordWordResult]).

    Raises:
        HTTPException: If the spaCy model is unavailable.
    """
    if nlp is None:
        app_logger.error(
            f"spaCy model '{SPACY_MODEL_NAME}' not found. Please install it with: 'python -m spacy download {SPACY_MODEL_NAME}'"
        )
        raise HTTPException(status_code=503, detail="NLP service is unavailable")

    doc = nlp(text)
    # "expand" snaps the character indices outward to the nearest token boundaries
    span = doc.char_span(start_idx, end_idx, alignment_mode="expand")
    if span is None:
        app_logger.warning(f"Could not create a valid token span from indices {start_idx}-{end_idx}.")
        return []

    results: list[RelatedWordWordResult] = []
    for token in span:
        # Only look up content words (per ELIGIBLE_POS), skipping stopwords and punctuation
        if token.pos_ in ELIGIBLE_POS and not token.is_stop and not token.is_punct:
            try:
                context_info_dict = extract_contextual_info_by_indices(
                    text, token.idx, token.idx + len(token.text), token.text
                )
                related_word_groups_list = process_synonym_groups(context_info_dict["lemma"], context_info_dict)
                if related_word_groups_list:
                    context_info_model = ContextInfo(
                        pos=context_info_dict["pos"],
                        sentence=context_info_dict["context_sentence"],
                        grammatical_form=context_info_dict["tag"],
                        context_words=context_info_dict["context_words"],
                        dependency=context_info_dict["dependency"],
                    )
                    # Convert the token's document-level offsets to offsets within the selected span
                    local_start_idx = token.idx - start_idx
                    local_end_idx = local_start_idx + len(token.text)
                    sliced_sentence = text[start_idx:end_idx]
                    sliced_word = sliced_sentence[local_start_idx:local_end_idx]
                    assert sliced_word == token.text, (
                        f"Mismatch! sliced_word ({sliced_word}) != token.text ({token.text}), but these substrings should be equal.\n"
                        f" start_idx: {start_idx}, end_idx: {end_idx}, local_start_idx: {local_start_idx}, local_end_idx: {local_end_idx}."
                    )
                    word_result = RelatedWordWordResult(
                        original_word=token.text,
                        original_indices={"start": local_start_idx, "end": local_end_idx},
                        context_info=context_info_model,
                        related_word_groups=related_word_groups_list,
                        debug_info={
                            "spacy_token_indices": {
                                "start": context_info_dict["char_start"],
                                "end": context_info_dict["char_end"],
                            },
                            "lemma": context_info_dict["lemma"]
                        }
                    )
                    results.append(word_result)
            except HTTPException as http_ex:
                app_logger.warning(f"Could not process token '{token.text}': '{http_ex.detail}'")
            except Exception as synonym_ex:
                app_logger.error(f"Unexpected error processing token '{token.text}': '{synonym_ex}'", exc_info=True)

    return results

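# Illustrative usage (a sketch, not executed at import time; assumes the spaCy model
# and WordNet data are installed and that ADJ is among the tags in ELIGIBLE_POS):
#
#   text = "Harry saw a magnificent owl"
#   start = text.index("magnificent")
#   results = find_synonyms_for_phrase(text, start, start + len("magnificent"))
#   for word_result in results:
#       print(word_result.original_word, len(word_result.related_word_groups))
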
def extract_contextual_info_by_indices(text: str, start_idx: int, end_idx: int, target_word: str) -> dict[str, Any]:
    """
    Extract grammatical and contextual information using character indices.

    Args:
        text: The input text (str).
        start_idx: The start index of the word within the text (int).
        end_idx: The end index of the word within the text (int).
        target_word: The target word (str).

    Returns:
        A dictionary containing contextual information about the word (dict[str, Any]).

    Raises:
        HTTPException: If the spaCy model is unavailable or if the indices are invalid.
    """
    if nlp is None:
        # 503, consistent with find_synonyms_for_phrase
        raise HTTPException(status_code=503, detail="spaCy model not available")

    if start_idx < 0 or end_idx > len(text) or start_idx >= end_idx:
        raise HTTPException(status_code=400, detail="Invalid start/end indices")

    try:
        doc = nlp(text)

        # Find the token that overlaps the requested character range
        target_token = None
        for token in doc:
            if (token.idx <= start_idx < token.idx + len(token.text) or
                    start_idx <= token.idx < end_idx):
                target_token = token
                break

        if target_token is None or target_token.text != target_word:
            raise HTTPException(
                status_code=400,
                detail=f"Could not find token for word '{target_word}' at indices {start_idx}-{end_idx}"
            )

        # Locate the token's position within its sentence (ignoring whitespace tokens)
        sentence_tokens = [t for t in target_token.sent if not t.is_space]
        target_position_in_sentence = None
        for i, token in enumerate(sentence_tokens):
            if token == target_token:
                target_position_in_sentence = i
                break

        # Collect a window of up to 5 tokens on each side of the target word.
        # The explicit "is not None" check matters: position 0 is a valid match.
        if target_position_in_sentence is not None:
            context_start = max(0, target_position_in_sentence - 5)
            context_end = min(len(sentence_tokens), target_position_in_sentence + 6)
        else:
            context_start = 0
            context_end = len(sentence_tokens)
        context_words = [t.text for t in sentence_tokens[context_start:context_end]]

        return {
            "word": target_token.text,
            "lemma": target_token.lemma_,
            "pos": target_token.pos_,
            "tag": target_token.tag_,
            "is_title": target_token.is_title,
            "is_upper": target_token.is_upper,
            "is_lower": target_token.is_lower,
            "dependency": target_token.dep_,
            "context_sentence": target_token.sent.text,
            "context_words": context_words,
            "sentence_position": target_position_in_sentence,
            "char_start": target_token.idx,
            "char_end": target_token.idx + len(target_token.text),
            "original_indices": {"start": start_idx, "end": end_idx},
        }

    except HTTPException:
        # Don't wrap the deliberate 4xx errors above in a generic 500
        raise
    except Exception as indices_ex:
        app_logger.error(f"Error in contextual analysis: {indices_ex}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error analyzing context: {str(indices_ex)}")

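# Illustrative result (a sketch; actual values depend on the installed spaCy model):
#
#   extract_contextual_info_by_indices("The owls flew", 4, 8, "owls")
#   # -> {"word": "owls", "lemma": "owl", "pos": "NOUN", "tag": "NNS",
#   #     "dependency": "nsubj", "context_sentence": "The owls flew", ...}
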
def get_wordnet_synonyms(word: str, pos_tag: Optional[str] = None) -> list[dict[str, Any]]:
    """
    Gets related words from WordNet and the custom synonym handler,
    returning a list of dictionaries containing the raw data, grouped by relation type.

    Args:
        word: The word to get related words for (str).
        pos_tag: An optional part-of-speech tag to filter WordNet results (Optional[str]).

    Returns:
        A list of dictionaries, where each dictionary represents a group of related words (list[dict[str, Any]]).
    """
    related_word_groups_raw: list[dict[str, Any]] = []
    word_lower = word.lower()

    # User-defined entries come first, ahead of the WordNet groups
    _extract_related_word_groups_custom(related_word_groups_raw, word_lower)

    try:
        # Map spaCy coarse POS tags to WordNet POS constants
        pos_map = {
            "NOUN": wn.NOUN,
            "VERB": wn.VERB,
            "ADJ": wn.ADJ,
            "ADV": wn.ADV,
        }

        synsets = wn.synsets(word)

        # Caveat: filtering on wn.ADJ ('a') excludes satellite adjectives ('s')
        if pos_tag and pos_tag in pos_map:
            synsets = [s for s in synsets if s.pos() == pos_map[pos_tag]]

        for synset in synsets:
            result = _get_related_words(synset, TermRelationships.SYNONYM, word_lower)
            related_word_groups_raw.append(result)
            # Antonyms are defined on lemmas, not on the synset itself
            for lemma in synset.lemmas():
                result = _get_related_words(lemma, TermRelationships.ANTONYM, word_lower)
                related_word_groups_raw.append(result)
            for rel_type in [
                TermRelationships.HYPERNYM, TermRelationships.HYPONYM, TermRelationships.MERONYM,
                TermRelationships.HOLONYM, TermRelationships.ALSO_SEE, TermRelationships.CAUSE,
                TermRelationships.SIMILAR_TO
            ]:
                result = _get_related_words(synset, rel_type, word_lower)
                related_word_groups_raw.append(result)

    except Exception as ex1:
        app_logger.error(f"Error getting wn synonyms: '{ex1}' with: word: {type(word)}, '{word}', pos_tag: {type(pos_tag)}, '{pos_tag}'")
        raise HTTPException(status_code=500, detail=f"Error retrieving related words: '{str(ex1)}'")

    # _get_related_words returns None for empty groups; drop those
    return [related_words for related_words in related_word_groups_raw if related_words is not None]

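# Illustrative call (a sketch; the exact groups depend on the installed WordNet data):
#
#   groups = get_wordnet_synonyms("big", "ADJ")
#   # -> [{"relation_type": TermRelationships.SYNONYM, "source": "wordnet",
#   #      "definition": "above average in size or number or ...", "examples": [...],
#   #      "wordnet_pos": "a", "related_words": [{"base_form": "large"}, ...]}, ...]
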
def _extract_related_word_groups_custom(related_word_groups_raw: list[dict[str, Any]], word_lower: str) -> None:
    """Append any user-defined related-word groups to related_word_groups_raw, in place."""
    for rel_type in TermRelationships:
        custom_groups = custom_synonym_handler.get_related(word_lower, rel_type)
        if custom_groups:
            for related in custom_groups:
                words = related["words"]
                definition = related.get("definition", "")
                related_word_options = []
                for word_from_related_words in words:
                    related_word_options.append({
                        "base_form": word_from_related_words,
                        "is_custom": True,
                        "definition": definition,
                    })
                related_word_groups_raw.append({
                    "relation_type": rel_type,
                    "source": "custom",
                    "definition": definition,
                    "examples": [],
                    "wordnet_pos": None,
                    "related_words": related_word_options,
                })

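# Each custom group mirrors the WordNet group shape, e.g. (hypothetical entry):
#
#   {"relation_type": TermRelationships.SYNONYM, "source": "custom",
#    "definition": "a user-supplied meaning", "examples": [], "wordnet_pos": None,
#    "related_words": [{"base_form": "wizardly", "is_custom": True, "definition": "..."}]}
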
def _get_base_form_by_synset_type(local_lemma: str, inner_word_lower: str, related_words: list[dict]) -> list[dict]:
    """Append the human-readable form of a lemma name, skipping the looked-up word itself."""
    lemma_name = local_lemma.replace("_", " ")
    if lemma_name.lower() != inner_word_lower:
        related_words.append({
            "base_form": lemma_name
        })
    return related_words

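# E.g. the WordNet lemma name "rocking_horse" becomes the base form "rocking horse",
# while a lemma equal to the original word is skipped rather than suggested back.
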
def _get_related_words(related_object, relation_type: TermRelationships, inner_word_lower: str) -> dict | None:
    """Collect related words for one synset/lemma, or None if the group is empty."""
    related_words = []

    if relation_type == TermRelationships.SYNONYM:
        # Synonyms are the other lemmas of the same synset
        for local_lemma in related_object.lemmas():
            _get_base_form_by_synset_type(local_lemma.name(), inner_word_lower, related_words)
    elif relation_type == TermRelationships.ANTONYM:
        # Here related_object is a Lemma, since antonyms are defined per lemma
        for ant in related_object.antonyms():
            _get_base_form_by_synset_type(ant.name(), inner_word_lower, related_words)
    else:
        # The remaining relations are synset-to-synset lookups
        relation_methods = {
            TermRelationships.HYPERNYM: related_object.hypernyms,
            TermRelationships.HYPONYM: related_object.hyponyms,
            TermRelationships.MERONYM: lambda: related_object.member_meronyms() + related_object.substance_meronyms() + related_object.part_meronyms(),
            TermRelationships.HOLONYM: lambda: related_object.member_holonyms() + related_object.substance_holonyms() + related_object.part_holonyms(),
            TermRelationships.ALSO_SEE: related_object.also_sees,
            TermRelationships.CAUSE: related_object.causes,
            TermRelationships.SIMILAR_TO: related_object.similar_tos,
        }
        get_words_fn = relation_methods.get(relation_type)
        if get_words_fn:
            for related_synset in get_words_fn():
                _extract_lemmas_or_names_from_synset(inner_word_lower, related_synset, related_words)

    if related_words:
        return {
            "relation_type": relation_type,
            "source": "wordnet",
            "definition": _get_related_object_definition(related_object),
            "examples": _get_related_object_examples(related_object),
            "wordnet_pos": _get_related_wordnet_pos(related_object),
            "related_words": related_words,
        }
    return None

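# Illustrative output for a hypernym lookup (a sketch; depends on the WordNet data):
#
#   _get_related_words(wn.synsets("dog")[0], TermRelationships.HYPERNYM, "dog")
#   # -> {"relation_type": TermRelationships.HYPERNYM, "source": "wordnet",
#   #     "definition": "a member of the genus Canis ...", "examples": [...],
#   #     "wordnet_pos": "n", "related_words": [{"base_form": "canine"}, ...]}
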
def _extract_lemmas_or_names_from_synset(inner_word_lower, related_synset, related_words):
    """Append base forms from a Synset's lemmas, or from a Lemma's name as a fallback."""
    if hasattr(related_synset, "lemmas"):
        for local_lemma in related_synset.lemmas():
            _get_base_form_by_synset_type(local_lemma.name(), inner_word_lower, related_words)
    elif hasattr(related_synset, "name"):
        _get_base_form_by_synset_type(related_synset.name(), inner_word_lower, related_words)

def _get_related_wordnet_pos(related_object: Synset):
    return related_object.pos() if hasattr(related_object, "pos") else None


def _get_related_object_examples(related_object: Synset, n: int = 2) -> list[str]:
    # Cap the examples at n to keep API payloads small
    return related_object.examples()[:n] if hasattr(related_object, "examples") else []


def _get_related_object_definition(related_object: Synset) -> str:
    return related_object.definition() if hasattr(related_object, "definition") else ""

def inflect_synonym(synonym: str, original_token_info: dict[str, Any]) -> str:
    """Adapt the input synonym arg to match the original word's grammatical form"""
    if nlp is None:
        return synonym

    pos = original_token_info.get("pos")
    tag = original_token_info.get("tag")

    # Match the original word's casing first
    if original_token_info.get("is_title"):
        synonym = synonym.title()
    elif original_token_info.get("is_upper"):
        synonym = synonym.upper()
    elif original_token_info.get("is_lower", True):
        synonym = synonym.lower()

    try:
        # Only these fine-grained tags need re-inflection; base forms pass through
        inflection_tags = {
            "NOUN": ["NNS", "NNPS"],
            "VERB": ["VBD", "VBN", "VBZ", "VBG"],
            "ADJ": ["JJR", "JJS"],
        }

        if pos in inflection_tags and tag in inflection_tags.get(pos, []):
            doc = nlp(synonym)
            if doc and len(doc) > 0:
                inflected = doc[0]._.inflect(tag)  # provided by the pyinflect extension
                if inflected:
                    # For multi-word synonyms, inflect the first token and keep the rest
                    return inflected + synonym[len(doc[0].text):]
        return synonym

    except Exception as ex2:
        app_logger.warning(f"Inflection error for '{synonym}': '{ex2}'")

    return synonym

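# Illustrative calls (a sketch; the inflections come from pyinflect and may vary):
#
#   inflect_synonym("sprint", {"pos": "VERB", "tag": "VBD", "is_lower": True})  # -> "sprinted"
#   inflect_synonym("pony", {"pos": "NOUN", "tag": "NNS", "is_lower": True})    # -> "ponies"
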
def process_synonym_groups(word: str, context_info: dict[str, Any]) -> list[RelatedWordGroup]:
    """Process related-word groups for a word, inflecting each option to match the original context

    Args:
        word (str): the word being looked up (used for logging)
        context_info (dict[str, Any]): contextual data from extract_contextual_info_by_indices()

    Returns:
        list[RelatedWordGroup]: list of the processed related-word groups
    """
    t0 = datetime.now()
    # Look up by lemma so that inflected forms ("running") still find their groups
    related_words_raw = get_wordnet_synonyms(context_info["lemma"], context_info["pos"])
    t1 = datetime.now()
    duration = (t1 - t0).total_seconds()
    app_logger.info(f"# 1/Got get_wordnet_synonyms result with '{word}' word in {duration:.3f}s.")

    if not related_words_raw:
        return []

    processed_groups: list[RelatedWordGroup] = []
    for related_group in related_words_raw:
        app_logger.info(f"related_group: '{related_group}'")
        relation_type = related_group["relation_type"]
        definition = related_group.get("definition", "")
        examples = related_group.get("examples", [])
        wordnet_pos = related_group.get("wordnet_pos")
        related_words = related_group["related_words"]
        processed_options: list[RelatedWordOption] = []

        for related_word in related_words:
            base_form = related_word["base_form"]
            inflected_form = inflect_synonym(base_form, context_info)

            related_word_option = RelatedWordOption(
                base_form=base_form,
                inflected_form=inflected_form,
                # True when the inflected form differs from the base form,
                # i.e. an inflection was applied to match the original context
                matches_context=inflected_form.lower() != base_form.lower()
            )
            if "is_custom" in related_word:
                related_word_option.is_custom = related_word["is_custom"]
            processed_options.append(related_word_option)
        app_logger.info(f"wordnet_pos: {type(wordnet_pos)}, '{wordnet_pos}'")
        processed_groups.append(
            RelatedWordGroup(
                relation_type=relation_type,
                definition=definition,
                examples=examples,
                related_words=processed_options,
                wordnet_pos=wordnet_pos
            )
        )
    return processed_groups

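# End-to-end sketch (assumes the spaCy model and WordNet data are installed):
#
#   info = extract_contextual_info_by_indices("The owls flew", 4, 8, "owls")
#   groups = process_synonym_groups(info["lemma"], info)
#   # Each option carries both a base_form (e.g. "bird of night") and an
#   # inflected_form matched to the original tag (NNS -> "birds of night").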
|