import os import json from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL from typing import Optional # we don't init the database here db = SqliteDatabase(None) class _TranslationCache(Model): id = AutoField() translate_engine = CharField(max_length=20) translate_engine_params = TextField() original_text = TextField() translation = TextField() class Meta: database = db constraints = [ SQL( """ UNIQUE ( translate_engine, translate_engine_params, original_text ) ON CONFLICT REPLACE """ ) ] class TranslationCache: @staticmethod def _sort_dict_recursively(obj): if isinstance(obj, dict): return { k: TranslationCache._sort_dict_recursively(v) for k in sorted(obj.keys()) for v in [obj[k]] } elif isinstance(obj, list): return [TranslationCache._sort_dict_recursively(item) for item in obj] return obj def __init__(self, translate_engine: str, translate_engine_params: dict = None): assert ( len(translate_engine) < 20 ), "current cache require translate engine name less than 20 characters" self.translate_engine = translate_engine self.replace_params(translate_engine_params) # The program typically starts multi-threaded translation # only after cache parameters are fully configured, # so thread safety doesn't need to be considered here. def replace_params(self, params: dict = None): if params is None: params = {} self.params = params params = self._sort_dict_recursively(params) self.translate_engine_params = json.dumps(params) def update_params(self, params: dict = None): if params is None: params = {} self.params.update(params) self.replace_params(self.params) def add_params(self, k: str, v): self.params[k] = v self.replace_params(self.params) # Since peewee and the underlying sqlite are thread-safe, # get and set operations don't need locks. def get(self, original_text: str) -> Optional[str]: result = _TranslationCache.get_or_none( translate_engine=self.translate_engine, translate_engine_params=self.translate_engine_params, original_text=original_text, ) return result.translation if result else None def set(self, original_text: str, translation: str): _TranslationCache.create( translate_engine=self.translate_engine, translate_engine_params=self.translate_engine_params, original_text=original_text, translation=translation, ) def init_db(remove_exists=False): cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh") os.makedirs(cache_folder, exist_ok=True) # The current version does not support database migration, so add the version number to the file name. cache_db_path = os.path.join(cache_folder, "cache.v1.db") if remove_exists and os.path.exists(cache_db_path): os.remove(cache_db_path) db.init( cache_db_path, pragmas={ "journal_mode": "wal", "busy_timeout": 1000, }, ) db.create_tables([_TranslationCache], safe=True) def init_test_db(): import tempfile cache_db_path = tempfile.mktemp(suffix=".db") test_db = SqliteDatabase( cache_db_path, pragmas={ "journal_mode": "wal", "busy_timeout": 1000, }, ) test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False) test_db.connect() test_db.create_tables([_TranslationCache], safe=True) return test_db def clean_test_db(test_db): test_db.drop_tables([_TranslationCache]) test_db.close() db_path = test_db.database if os.path.exists(db_path): os.remove(test_db.database) wal_path = db_path + "-wal" if os.path.exists(wal_path): os.remove(wal_path) shm_path = db_path + "-shm" if os.path.exists(shm_path): os.remove(shm_path) init_db()