|
import unittest |
|
from pdf2zh import cache |
|
import threading |
|
import multiprocessing |
|
import random |
|
import string |
|
|
|
|
|
class TestCache(unittest.TestCase): |
|
def setUp(self): |
|
self.test_db = cache.init_test_db() |
|
|
|
def tearDown(self): |
|
|
|
cache.clean_test_db(self.test_db) |
|
|
|
def test_basic_set_get(self): |
|
"""Test basic set and get operations""" |
|
cache_instance = cache.TranslationCache("test_engine") |
|
|
|
|
|
result = cache_instance.get("hello") |
|
self.assertIsNone(result) |
|
|
|
|
|
cache_instance.set("hello", "你好") |
|
result = cache_instance.get("hello") |
|
self.assertEqual(result, "你好") |
|
|
|
def test_cache_overwrite(self): |
|
"""Test that cache entries can be overwritten""" |
|
cache_instance = cache.TranslationCache("test_engine") |
|
|
|
|
|
cache_instance.set("hello", "你好") |
|
|
|
|
|
cache_instance.set("hello", "您好") |
|
|
|
|
|
result = cache_instance.get("hello") |
|
self.assertEqual(result, "您好") |
|
|
|
def test_non_string_params(self): |
|
"""Test that non-string parameters are automatically converted to JSON""" |
|
params = {"model": "gpt-3.5", "temperature": 0.7} |
|
cache_instance = cache.TranslationCache("test_engine", params) |
|
|
|
|
|
cache_instance.set("hello", "你好") |
|
result = cache_instance.get("hello") |
|
self.assertEqual(result, "你好") |
|
|
|
|
|
array_params = ["param1", "param2"] |
|
cache_instance2 = cache.TranslationCache("test_engine", array_params) |
|
cache_instance2.set("hello", "你好2") |
|
self.assertEqual(cache_instance2.get("hello"), "你好2") |
|
|
|
|
|
nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}} |
|
cache_instance3 = cache.TranslationCache("test_engine", nested_params) |
|
cache_instance3.set("hello", "你好3") |
|
self.assertEqual(cache_instance3.get("hello"), "你好3") |
|
|
|
def test_engine_distinction(self): |
|
"""Test that cache distinguishes between different translation engines""" |
|
cache1 = cache.TranslationCache("engine1") |
|
cache2 = cache.TranslationCache("engine2") |
|
|
|
|
|
cache1.set("hello", "你好 1") |
|
cache2.set("hello", "你好 2") |
|
|
|
|
|
self.assertEqual(cache1.get("hello"), "你好 1") |
|
self.assertEqual(cache2.get("hello"), "你好 2") |
|
|
|
def test_params_distinction(self): |
|
"""Test that cache distinguishes between different engine parameters""" |
|
params1 = {"param": "value1"} |
|
params2 = {"param": "value2"} |
|
cache1 = cache.TranslationCache("test_engine", params1) |
|
cache2 = cache.TranslationCache("test_engine", params2) |
|
|
|
|
|
cache1.set("hello", "你好 1") |
|
cache2.set("hello", "你好 2") |
|
|
|
|
|
self.assertEqual(cache1.get("hello"), "你好 1") |
|
self.assertEqual(cache2.get("hello"), "你好 2") |
|
|
|
def test_consistent_param_serialization(self): |
|
"""Test that dictionary parameters are consistently serialized regardless of key order""" |
|
|
|
params1 = {"b": 1, "a": 2} |
|
params2 = {"a": 2, "b": 1} |
|
cache1 = cache.TranslationCache("test_engine", params1) |
|
cache2 = cache.TranslationCache("test_engine", params2) |
|
self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) |
|
|
|
|
|
params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3} |
|
params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}} |
|
cache1 = cache.TranslationCache("test_engine", params1) |
|
cache2 = cache.TranslationCache("test_engine", params2) |
|
self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) |
|
|
|
|
|
params1 = {"b": [{"y": 1, "x": 2}], "a": 3} |
|
params2 = {"a": 3, "b": [{"x": 2, "y": 1}]} |
|
cache1 = cache.TranslationCache("test_engine", params1) |
|
cache2 = cache.TranslationCache("test_engine", params2) |
|
self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) |
|
|
|
|
|
params1 = {"a": 1, "b": 2} |
|
params2 = {"a": 2, "b": 1} |
|
cache1 = cache.TranslationCache("test_engine", params1) |
|
cache2 = cache.TranslationCache("test_engine", params2) |
|
self.assertNotEqual( |
|
cache1.translate_engine_params, cache2.translate_engine_params |
|
) |
|
|
|
def test_cache_with_sorted_params(self): |
|
"""Test that cache works correctly with sorted parameters""" |
|
params1 = {"b": [{"y": 1, "x": 2}], "a": 3} |
|
params2 = {"a": 3, "b": [{"x": 2, "y": 1}]} |
|
|
|
|
|
cache1 = cache.TranslationCache("test_engine", params1) |
|
cache1.set("hello", "你好") |
|
|
|
cache2 = cache.TranslationCache("test_engine", params2) |
|
self.assertEqual(cache2.get("hello"), "你好") |
|
|
|
def test_append_params(self): |
|
"""Test the append_params method""" |
|
cache_instance = cache.TranslationCache("test_engine", {"initial": "value"}) |
|
|
|
|
|
cache_instance.add_params("new_param", "new_value") |
|
self.assertEqual( |
|
cache_instance.params, {"initial": "value", "new_param": "new_value"} |
|
) |
|
|
|
|
|
cache_instance.set("hello", "你好") |
|
self.assertEqual(cache_instance.get("hello"), "你好") |
|
|
|
|
|
cache_instance.add_params("initial", "new_value") |
|
self.assertEqual( |
|
cache_instance.params, {"initial": "new_value", "new_param": "new_value"} |
|
) |
|
|
|
|
|
cache_instance.set("hello2", "你好2") |
|
self.assertEqual(cache_instance.get("hello2"), "你好2") |
|
|
|
def test_thread_safety(self): |
|
"""Test thread safety of cache operations""" |
|
cache_instance = cache.TranslationCache("test_engine") |
|
lock = threading.Lock() |
|
results = [] |
|
num_threads = multiprocessing.cpu_count() |
|
items_per_thread = 100 |
|
|
|
def generate_random_text(length=10): |
|
return "".join( |
|
random.choices(string.ascii_letters + string.digits, k=length) |
|
) |
|
|
|
def worker(): |
|
thread_results = [] |
|
for _ in range(items_per_thread): |
|
text = generate_random_text() |
|
translation = f"翻译_{text}" |
|
|
|
|
|
cache_instance.set(text, translation) |
|
|
|
|
|
result = cache_instance.get(text) |
|
thread_results.append((text, result)) |
|
|
|
|
|
with lock: |
|
results.extend(thread_results) |
|
|
|
|
|
threads = [] |
|
for _ in range(num_threads): |
|
thread = threading.Thread(target=worker) |
|
threads.append(thread) |
|
thread.start() |
|
|
|
|
|
for thread in threads: |
|
thread.join() |
|
|
|
|
|
expected_total = num_threads * items_per_thread |
|
self.assertEqual(len(results), expected_total) |
|
|
|
|
|
for text, result in results: |
|
expected = f"翻译_{text}" |
|
self.assertEqual(result, expected) |
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|