fredbai committed on
Commit
da0bc38
·
1 Parent(s): 5b3c777

Common: Support PostgreSQL as the metadata database. (#2357)

Browse files

https://github.com/infiniflow/ragflow/issues/2356

### What problem does this PR solve?

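As the title says: this PR adds PostgreSQL as an alternative to MySQL for the metadata database. The backend is selected with the `DB_TYPE` environment variable (default: `mysql`).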

### Type of change

- [X] New Feature (non-breaking change which adds functionality)

api/db/db_models.py CHANGED
@@ -18,18 +18,19 @@ import os
18
  import sys
19
  import typing
20
  import operator
 
21
  from functools import wraps
22
  from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
23
  from flask_login import UserMixin
24
- from playhouse.migrate import MySQLMigrator, migrate
25
  from peewee import (
26
  BigIntegerField, BooleanField, CharField,
27
  CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
28
  Field, Model, Metadata
29
  )
30
- from playhouse.pool import PooledMySQLDatabase
31
  from api.db import SerializedType, ParserType
32
- from api.settings import DATABASE, stat_logger, SECRET_KEY
33
  from api.utils.log_utils import getLogger
34
  from api import utils
35
 
@@ -58,8 +59,13 @@ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
58
  "write_access"}
59
 
60
 
 
 
 
 
 
61
  class LongTextField(TextField):
62
- field_type = 'LONGTEXT'
63
 
64
 
65
  class JSONField(LongTextField):
@@ -266,18 +272,69 @@ class JsonSerializedField(SerializedField):
266
  super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
267
  object_pairs_hook=object_pairs_hook, **kwargs)
268
 
 
 
 
 
 
 
 
 
 
269
 
270
  @singleton
271
  class BaseDataBase:
272
  def __init__(self):
273
  database_config = DATABASE.copy()
274
  db_name = database_config.pop("name")
275
- self.database_connection = PooledMySQLDatabase(
276
- db_name, **database_config)
277
- stat_logger.info('init mysql database on cluster mode successfully')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
 
 
 
 
 
 
 
 
279
 
280
- class DatabaseLock:
 
 
 
 
 
 
 
 
281
  def __init__(self, lock_name, timeout=10, db=None):
282
  self.lock_name = lock_name
283
  self.timeout = int(timeout)
@@ -325,8 +382,13 @@ class DatabaseLock:
325
  return magic
326
 
327
 
 
 
 
 
 
328
  DB = BaseDataBase().database_connection
329
- DB.lock = DatabaseLock
330
 
331
 
332
  def close_connection():
@@ -918,7 +980,7 @@ class CanvasTemplate(DataBaseModel):
918
 
919
  def migrate_db():
920
  with DB.transaction():
921
- migrator = MySQLMigrator(DB)
922
  try:
923
  migrate(
924
  migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
 
18
  import sys
19
  import typing
20
  import operator
21
+ from enum import Enum
22
  from functools import wraps
23
  from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
24
  from flask_login import UserMixin
25
+ from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
26
  from peewee import (
27
  BigIntegerField, BooleanField, CharField,
28
  CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
29
  Field, Model, Metadata
30
  )
31
+ from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
32
  from api.db import SerializedType, ParserType
33
+ from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
34
  from api.utils.log_utils import getLogger
35
  from api import utils
36
 
 
59
  "write_access"}
60
 
61
 
62
class TextFieldType(Enum):
    """Column type used for large text fields on each supported backend.

    Members are looked up by the upper-cased database type name.
    """

    MYSQL = "LONGTEXT"
    POSTGRES = "TEXT"
65
+
66
+
67
class LongTextField(TextField):
    """Text field whose column type follows the configured database backend."""

    # Resolved once, when the class body runs, from DATABASE_TYPE
    # (MYSQL -> LONGTEXT, POSTGRES -> TEXT).
    field_type = TextFieldType[DATABASE_TYPE.upper()].value
69
 
70
 
71
  class JSONField(LongTextField):
 
272
  super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
273
  object_pairs_hook=object_pairs_hook, **kwargs)
274
 
275
class PooledDatabase(Enum):
    """Pooled database class to instantiate for each supported backend.

    Members are looked up by the upper-cased database type name; ``.value`` is
    the class itself, not an instance.
    """

    MYSQL = PooledMySQLDatabase
    POSTGRES = PooledPostgresqlDatabase
278
+
279
+
280
class DatabaseMigrator(Enum):
    """Schema migrator class to use for each supported backend.

    Members are looked up by the upper-cased database type name; ``.value`` is
    the migrator class, which is instantiated with the database handle.
    """

    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator
283
+
284
 
285
@singleton
class BaseDataBase:
    """Creates and holds the pooled connection to the metadata database."""

    def __init__(self):
        # Work on a copy: "name" is popped below and DATABASE itself must not
        # be modified.
        settings = DATABASE.copy()
        schema_name = settings.pop("name")
        # Pick the pool implementation that matches the configured backend;
        # the remaining settings are forwarded as keyword arguments.
        pool_cls = PooledDatabase[DATABASE_TYPE.upper()].value
        self.database_connection = pool_cls(schema_name, **settings)
        stat_logger.info('init database on cluster mode successfully')
292
+
293
class PostgresDatabaseLock:
    """Named lock built on PostgreSQL advisory locks.

    Can be used as a context manager (``with PostgresDatabaseLock("x"):``) or
    as a function decorator (``@PostgresDatabaseLock("x")``).

    Args:
        lock_name: Application-level name of the lock. Advisory locks are
            keyed by a 64-bit integer, so the name is hashed into one.
        timeout: Maximum number of seconds to keep retrying the acquisition
            before giving up.
        db: Database object used to run the lock statements; falls back to
            the module-level ``DB`` when omitted.
    """

    # Pause (seconds) between acquisition attempts while the lock is busy.
    RETRY_INTERVAL = 0.1

    def __init__(self, lock_name, timeout=10, db=None):
        self.lock_name = lock_name
        self.timeout = int(timeout)
        self.db = db if db else DB

    def _lock_key(self):
        """Return the signed 64-bit advisory-lock key derived from lock_name."""
        # hashlib rather than hash(): the built-in string hash is salted per
        # process, and every process must map a name to the same key.
        import hashlib

        digest = hashlib.sha256(str(self.lock_name).encode("utf-8")).digest()
        return int.from_bytes(digest[:8], byteorder="big", signed=True)

    def lock(self):
        """Acquire the lock, retrying until ``timeout`` seconds have passed.

        Returns:
            True once the lock is held.

        Raises:
            Exception: if the lock is still busy when the timeout expires, or
                if the server returns an unexpected result.
        """
        import time

        deadline = time.monotonic() + self.timeout
        while True:
            # pg_try_advisory_lock never blocks, so the wait is implemented
            # here by polling. Parameters must be passed as a sequence.
            cursor = self.db.execute_sql(
                "SELECT pg_try_advisory_lock(%s)", (self._lock_key(),))
            ret = cursor.fetchone()
            if ret[0] == 1:
                return True
            if ret[0] != 0:
                raise Exception(f'failed to acquire lock {self.lock_name}')
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise Exception(
                    f'acquire postgres lock {self.lock_name} timeout')
            time.sleep(min(self.RETRY_INTERVAL, remaining))

    def unlock(self):
        """Release the lock.

        Returns:
            True when the lock was released.

        Raises:
            Exception: if the lock was not held by this session or the server
                returns an unexpected result.
        """
        cursor = self.db.execute_sql(
            "SELECT pg_advisory_unlock(%s)", (self._lock_key(),))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(
                f'postgres lock {self.lock_name} was not established by this thread')
        elif ret[0] == 1:
            return True
        else:
            raise Exception(f'postgres lock {self.lock_name} does not exist')

    def __enter__(self):
        # Only take a server-side lock when the handle really is a PostgreSQL
        # database; the check must be against the database class, not against
        # this lock class.
        if isinstance(self.db, PooledPostgresqlDatabase):
            self.lock()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if isinstance(self.db, PooledPostgresqlDatabase):
            self.unlock()

    def __call__(self, func):
        """Decorate ``func`` so that every call runs while holding the lock."""
        @wraps(func)
        def magic(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return magic
336
+
337
+ class MysqlDatabaseLock:
338
  def __init__(self, lock_name, timeout=10, db=None):
339
  self.lock_name = lock_name
340
  self.timeout = int(timeout)
 
382
  return magic
383
 
384
 
385
class DatabaseLock(Enum):
    """Lock implementation to use for each supported backend.

    Members are looked up by the upper-cased database type name; ``.value`` is
    the lock class itself.
    """

    MYSQL = MysqlDatabaseLock
    POSTGRES = PostgresDatabaseLock
388
+
389
+
390
  DB = BaseDataBase().database_connection
391
+ DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
392
 
393
 
394
  def close_connection():
 
980
 
981
  def migrate_db():
982
  with DB.transaction():
983
+ migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
984
  try:
985
  migrate(
986
  migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
api/db/db_utils.py CHANGED
@@ -17,6 +17,8 @@ import operator
17
  from functools import reduce
18
  from typing import Dict, Type, Union
19
 
 
 
20
  from api.utils import current_timestamp, timestamp_to_date
21
 
22
  from api.db.db_models import DB, DataBaseModel
@@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
49
  with DB.atomic():
50
  query = model.insert_many(data_source[i:i + batch_size])
51
  if replace_on_conflict:
52
- query = query.on_conflict(preserve=preserve)
 
 
 
53
  query.execute()
54
 
55
 
 
17
  from functools import reduce
18
  from typing import Dict, Type, Union
19
 
20
+ from playhouse.pool import PooledMySQLDatabase
21
+
22
  from api.utils import current_timestamp, timestamp_to_date
23
 
24
  from api.db.db_models import DB, DataBaseModel
 
51
  with DB.atomic():
52
  query = model.insert_many(data_source[i:i + batch_size])
53
  if replace_on_conflict:
54
+ if isinstance(DB, PooledMySQLDatabase):
55
+ query = query.on_conflict(preserve=preserve)
56
+ else:
57
+ query = query.on_conflict(conflict_target="id", preserve=preserve)
58
  query.execute()
59
 
60
 
api/settings.py CHANGED
@@ -164,7 +164,8 @@ RANDOM_INSTANCE_ID = get_base_config(
164
  PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
165
  PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
166
 
167
- DATABASE = decrypt_database_config(name="mysql")
 
168
 
169
  # Switch
170
  # upload
 
164
  PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
165
  PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
166
 
167
# Backend of the metadata database, selected with the DB_TYPE environment
# variable ("mysql" when unset or blank). The value is trimmed and lower-cased
# because it is passed on as the configuration name, and the database sections
# in conf/service_conf.yaml ("mysql", "postgres") are lower case.
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql").strip().lower() or "mysql"
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
169
 
170
  # Switch
171
  # upload
conf/service_conf.yaml CHANGED
@@ -9,6 +9,14 @@ mysql:
9
  port: 3306
10
  max_connections: 100
11
  stale_timeout: 30
 
 
 
 
 
 
 
 
12
  minio:
13
  user: 'rag_flow'
14
  password: 'infini_rag_flow'
 
9
  port: 3306
10
  max_connections: 100
11
  stale_timeout: 30
12
+ postgres:
13
+ name: 'rag_flow'
14
+ user: 'rag_flow'
15
+ password: 'infini_rag_flow'
16
+ host: 'postgres'
17
+ port: 5432
18
+ max_connections: 100
19
+ stale_timeout: 30
20
  minio:
21
  user: 'rag_flow'
22
  password: 'infini_rag_flow'
requirements.txt CHANGED
@@ -31,6 +31,7 @@ Flask==3.0.3
31
  Flask_Cors==5.0.0
32
  Flask_Login==0.6.3
33
  flask_session==0.8.0
 
34
  google_search_results==2.4.2
35
  groq==0.9.0
36
  hanziconv==0.3.2
 
31
  Flask_Cors==5.0.0
32
  Flask_Login==0.6.3
33
  flask_session==0.8.0
34
+ psycopg2==2.9.9
35
  google_search_results==2.4.2
36
  groq==0.9.0
37
  hanziconv==0.3.2