import unittest from pdf2zh import cache import threading import multiprocessing import random import string class TestCache(unittest.TestCase): def setUp(self): self.test_db = cache.init_test_db() def tearDown(self): # Clean up cache.clean_test_db(self.test_db) def test_basic_set_get(self): """Test basic set and get operations""" cache_instance = cache.TranslationCache("test_engine") # Test get with non-existent entry result = cache_instance.get("hello") self.assertIsNone(result) # Test set and get cache_instance.set("hello", "你好") result = cache_instance.get("hello") self.assertEqual(result, "你好") def test_cache_overwrite(self): """Test that cache entries can be overwritten""" cache_instance = cache.TranslationCache("test_engine") # Set initial translation cache_instance.set("hello", "你好") # Overwrite with new translation cache_instance.set("hello", "您好") # Verify the new translation is returned result = cache_instance.get("hello") self.assertEqual(result, "您好") def test_non_string_params(self): """Test that non-string parameters are automatically converted to JSON""" params = {"model": "gpt-3.5", "temperature": 0.7} cache_instance = cache.TranslationCache("test_engine", params) # Test that params are converted to JSON string internally cache_instance.set("hello", "你好") result = cache_instance.get("hello") self.assertEqual(result, "你好") # Test with different param types array_params = ["param1", "param2"] cache_instance2 = cache.TranslationCache("test_engine", array_params) cache_instance2.set("hello", "你好2") self.assertEqual(cache_instance2.get("hello"), "你好2") # Test with nested structures nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}} cache_instance3 = cache.TranslationCache("test_engine", nested_params) cache_instance3.set("hello", "你好3") self.assertEqual(cache_instance3.get("hello"), "你好3") def test_engine_distinction(self): """Test that cache distinguishes between different translation engines""" cache1 = cache.TranslationCache("engine1") cache2 = cache.TranslationCache("engine2") # Set same text with different engines cache1.set("hello", "你好 1") cache2.set("hello", "你好 2") # Verify each engine gets its own translation self.assertEqual(cache1.get("hello"), "你好 1") self.assertEqual(cache2.get("hello"), "你好 2") def test_params_distinction(self): """Test that cache distinguishes between different engine parameters""" params1 = {"param": "value1"} params2 = {"param": "value2"} cache1 = cache.TranslationCache("test_engine", params1) cache2 = cache.TranslationCache("test_engine", params2) # Set same text with different parameters cache1.set("hello", "你好 1") cache2.set("hello", "你好 2") # Verify each parameter set gets its own translation self.assertEqual(cache1.get("hello"), "你好 1") self.assertEqual(cache2.get("hello"), "你好 2") def test_consistent_param_serialization(self): """Test that dictionary parameters are consistently serialized regardless of key order""" # Test simple dictionary params1 = {"b": 1, "a": 2} params2 = {"a": 2, "b": 1} cache1 = cache.TranslationCache("test_engine", params1) cache2 = cache.TranslationCache("test_engine", params2) self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) # Test nested dictionary params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3} params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}} cache1 = cache.TranslationCache("test_engine", params1) cache2 = cache.TranslationCache("test_engine", params2) self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) # Test dictionary with list of dictionaries params1 = {"b": [{"y": 1, "x": 2}], "a": 3} params2 = {"a": 3, "b": [{"x": 2, "y": 1}]} cache1 = cache.TranslationCache("test_engine", params1) cache2 = cache.TranslationCache("test_engine", params2) self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params) # Test that different values still produce different results params1 = {"a": 1, "b": 2} params2 = {"a": 2, "b": 1} cache1 = cache.TranslationCache("test_engine", params1) cache2 = cache.TranslationCache("test_engine", params2) self.assertNotEqual( cache1.translate_engine_params, cache2.translate_engine_params ) def test_cache_with_sorted_params(self): """Test that cache works correctly with sorted parameters""" params1 = {"b": [{"y": 1, "x": 2}], "a": 3} params2 = {"a": 3, "b": [{"x": 2, "y": 1}]} # Both caches should work with the same key cache1 = cache.TranslationCache("test_engine", params1) cache1.set("hello", "你好") cache2 = cache.TranslationCache("test_engine", params2) self.assertEqual(cache2.get("hello"), "你好") def test_append_params(self): """Test the append_params method""" cache_instance = cache.TranslationCache("test_engine", {"initial": "value"}) # Test appending new parameter cache_instance.add_params("new_param", "new_value") self.assertEqual( cache_instance.params, {"initial": "value", "new_param": "new_value"} ) # Test that cache with appended params works correctly cache_instance.set("hello", "你好") self.assertEqual(cache_instance.get("hello"), "你好") # Test overwriting existing parameter cache_instance.add_params("initial", "new_value") self.assertEqual( cache_instance.params, {"initial": "new_value", "new_param": "new_value"} ) # Cache should work with updated params cache_instance.set("hello2", "你好2") self.assertEqual(cache_instance.get("hello2"), "你好2") def test_thread_safety(self): """Test thread safety of cache operations""" cache_instance = cache.TranslationCache("test_engine") lock = threading.Lock() results = [] num_threads = multiprocessing.cpu_count() items_per_thread = 100 def generate_random_text(length=10): return "".join( random.choices(string.ascii_letters + string.digits, k=length) ) def worker(): thread_results = [] # 线程本地存储结果 for _ in range(items_per_thread): text = generate_random_text() translation = f"翻译_{text}" # Write operation cache_instance.set(text, translation) # Read operation - verify our own write result = cache_instance.get(text) thread_results.append((text, result)) # 所有操作完成后,一次性加锁并追加结果 with lock: results.extend(thread_results) # Create threads equal to CPU core count threads = [] for _ in range(num_threads): thread = threading.Thread(target=worker) threads.append(thread) thread.start() # Wait for all threads to complete for thread in threads: thread.join() # Verify all operations were successful expected_total = num_threads * items_per_thread self.assertEqual(len(results), expected_total) # Verify each thread got its correct value for text, result in results: expected = f"翻译_{text}" self.assertEqual(result, expected) if __name__ == "__main__": unittest.main()