Common: Support PostgreSQL database as the metadata db. (#2357)
Browse files
Issue: https://github.com/infiniflow/ragflow/issues/2356
### What problem does this PR solve?
As the title says: add support for PostgreSQL as the metadata database.
### Type of change
- [X] New Feature (non-breaking change which adds functionality)
- api/db/db_models.py +72 -10
- api/db/db_utils.py +6 -1
- api/settings.py +2 -1
- conf/service_conf.yaml +8 -0
- requirements.txt +1 -0
    	
        api/db/db_models.py
    CHANGED
    
    | @@ -18,18 +18,19 @@ import os | |
| 18 | 
             
            import sys
         | 
| 19 | 
             
            import typing
         | 
| 20 | 
             
            import operator
         | 
|  | |
| 21 | 
             
            from functools import wraps
         | 
| 22 | 
             
            from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
         | 
| 23 | 
             
            from flask_login import UserMixin
         | 
| 24 | 
            -
            from playhouse.migrate import MySQLMigrator, migrate
         | 
| 25 | 
             
            from peewee import (
         | 
| 26 | 
             
                BigIntegerField, BooleanField, CharField,
         | 
| 27 | 
             
                CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
         | 
| 28 | 
             
                Field, Model, Metadata
         | 
| 29 | 
             
            )
         | 
| 30 | 
            -
            from playhouse.pool import PooledMySQLDatabase
         | 
| 31 | 
             
            from api.db import SerializedType, ParserType
         | 
| 32 | 
            -
            from api.settings import DATABASE, stat_logger, SECRET_KEY
         | 
| 33 | 
             
            from api.utils.log_utils import getLogger
         | 
| 34 | 
             
            from api import utils
         | 
| 35 |  | 
| @@ -58,8 +59,13 @@ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = { | |
| 58 | 
             
                "write_access"}
         | 
| 59 |  | 
| 60 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 61 | 
             
            class LongTextField(TextField):
         | 
| 62 | 
            -
                field_type =  | 
| 63 |  | 
| 64 |  | 
| 65 | 
             
            class JSONField(LongTextField):
         | 
| @@ -266,18 +272,69 @@ class JsonSerializedField(SerializedField): | |
| 266 | 
             
                    super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
         | 
| 267 | 
             
                                                              object_pairs_hook=object_pairs_hook, **kwargs)
         | 
| 268 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 269 |  | 
| 270 | 
             
            @singleton
         | 
| 271 | 
             
            class BaseDataBase:
         | 
| 272 | 
             
                def __init__(self):
         | 
| 273 | 
             
                    database_config = DATABASE.copy()
         | 
| 274 | 
             
                    db_name = database_config.pop("name")
         | 
| 275 | 
            -
                    self.database_connection =  | 
| 276 | 
            -
             | 
| 277 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 278 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 279 |  | 
| 280 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 281 | 
             
                def __init__(self, lock_name, timeout=10, db=None):
         | 
| 282 | 
             
                    self.lock_name = lock_name
         | 
| 283 | 
             
                    self.timeout = int(timeout)
         | 
| @@ -325,8 +382,13 @@ class DatabaseLock: | |
| 325 | 
             
                    return magic
         | 
| 326 |  | 
| 327 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 328 | 
             
            DB = BaseDataBase().database_connection
         | 
| 329 | 
            -
            DB.lock = DatabaseLock
         | 
| 330 |  | 
| 331 |  | 
| 332 | 
             
            def close_connection():
         | 
| @@ -918,7 +980,7 @@ class CanvasTemplate(DataBaseModel): | |
| 918 |  | 
| 919 | 
             
            def migrate_db():
         | 
| 920 | 
             
                with DB.transaction():
         | 
| 921 | 
            -
                    migrator =  | 
| 922 | 
             
                    try:
         | 
| 923 | 
             
                        migrate(
         | 
| 924 | 
             
                            migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
         | 
|  | |
| 18 | 
             
            import sys
         | 
| 19 | 
             
            import typing
         | 
| 20 | 
             
            import operator
         | 
| 21 | 
            +
            from enum import Enum
         | 
| 22 | 
             
            from functools import wraps
         | 
| 23 | 
             
            from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
         | 
| 24 | 
             
            from flask_login import UserMixin
         | 
| 25 | 
            +
            from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
         | 
| 26 | 
             
            from peewee import (
         | 
| 27 | 
             
                BigIntegerField, BooleanField, CharField,
         | 
| 28 | 
             
                CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
         | 
| 29 | 
             
                Field, Model, Metadata
         | 
| 30 | 
             
            )
         | 
| 31 | 
            +
            from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
         | 
| 32 | 
             
            from api.db import SerializedType, ParserType
         | 
| 33 | 
            +
            from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
         | 
| 34 | 
             
            from api.utils.log_utils import getLogger
         | 
| 35 | 
             
            from api import utils
         | 
| 36 |  | 
|  | |
| 59 | 
             
                "write_access"}
         | 
| 60 |  | 
| 61 |  | 
| 62 | 
            +
            class TextFieldType(Enum):
         | 
| 63 | 
            +
                MYSQL = 'LONGTEXT'
         | 
| 64 | 
            +
                POSTGRES = 'TEXT'
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
             
            class LongTextField(TextField):
         | 
| 68 | 
            +
                field_type = TextFieldType[DATABASE_TYPE.upper()].value
         | 
| 69 |  | 
| 70 |  | 
| 71 | 
             
            class JSONField(LongTextField):
         | 
|  | |
| 272 | 
             
                    super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
         | 
| 273 | 
             
                                                              object_pairs_hook=object_pairs_hook, **kwargs)
         | 
| 274 |  | 
| 275 | 
            +
            class PooledDatabase(Enum):
         | 
| 276 | 
            +
                MYSQL = PooledMySQLDatabase
         | 
| 277 | 
            +
                POSTGRES = PooledPostgresqlDatabase
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            class DatabaseMigrator(Enum):
         | 
| 281 | 
            +
                MYSQL = MySQLMigrator
         | 
| 282 | 
            +
                POSTGRES = PostgresqlMigrator
         | 
| 283 | 
            +
             | 
| 284 |  | 
| 285 | 
             
            @singleton
         | 
| 286 | 
             
            class BaseDataBase:
         | 
| 287 | 
             
                def __init__(self):
         | 
| 288 | 
             
                    database_config = DATABASE.copy()
         | 
| 289 | 
             
                    db_name = database_config.pop("name")
         | 
| 290 | 
            +
                    self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
         | 
| 291 | 
            +
                    stat_logger.info('init database on cluster mode successfully')
         | 
| 292 | 
            +
             | 
| 293 | 
            +
            class PostgresDatabaseLock:
                """Named lock built on PostgreSQL advisory locks.

                Can be used as a context manager (``with PostgresDatabaseLock("x"):``)
                or as a decorator (``@PostgresDatabaseLock("x")``); in both cases the
                lock is taken on entry and released on exit.

                Args:
                    lock_name: Logical name of the lock. Two instances created with
                        the same name contend for the same advisory lock.
                    timeout: Stored as ``int(timeout)`` for signature parity with the
                        other lock classes. ``pg_try_advisory_lock`` does not wait,
                        so acquisition fails immediately if the lock is held.
                    db: Object exposing ``execute_sql(sql, params)``; falls back to
                        the module-level ``DB`` when omitted or falsy.

                NOTE(review): advisory locks belong to the database session, so this
                assumes ``lock()`` and ``unlock()`` run on the same pooled
                connection -- confirm against the pool configuration.
                """

                def __init__(self, lock_name, timeout=10, db=None):
                    # Function-scope import keeps this block self-contained.
                    import hashlib

                    self.lock_name = lock_name
                    self.timeout = int(timeout)
                    self.db = db if db else DB
                    # Advisory locks are keyed by a signed 64-bit integer, not a
                    # string. Derive the key from the name with a digest that is
                    # stable across processes (the built-in hash() is salted per
                    # interpreter, so it would give each worker a different key).
                    digest = hashlib.sha256(str(lock_name).encode("utf-8")).digest()
                    self.lock_id = int.from_bytes(digest[:8], byteorder="big", signed=True)

                def lock(self):
                    """Try to acquire the lock once.

                    Returns:
                        True when the lock was acquired.

                    Raises:
                        Exception: if the lock is held elsewhere, or the server
                            returned an unexpected value.
                    """
                    # Parameters must be a sequence; the key is the name-derived id,
                    # not the timeout.
                    cursor = self.db.execute_sql(
                        "SELECT pg_try_advisory_lock(%s)", (self.lock_id,))
                    ret = cursor.fetchone()
                    if ret[0] == 0:
                        raise Exception(f'acquire postgres lock {self.lock_name} timeout')
                    elif ret[0] == 1:
                        return True
                    else:
                        raise Exception(f'failed to acquire lock {self.lock_name}')

                def unlock(self):
                    """Release the lock.

                    Returns:
                        True when the lock was released.

                    Raises:
                        Exception: if the lock was not held by this session, or the
                            server returned an unexpected value.
                    """
                    cursor = self.db.execute_sql(
                        "SELECT pg_advisory_unlock(%s)", (self.lock_id,))
                    ret = cursor.fetchone()
                    if ret[0] == 0:
                        raise Exception(
                            f'postgres lock {self.lock_name} was not established by this thread')
                    elif ret[0] == 1:
                        return True
                    else:
                        raise Exception(f'postgres lock {self.lock_name} does not exist')

                def __enter__(self):
                    # The previous guard compared self.db against this lock class,
                    # which is never true for a database handle, so the lock was
                    # silently skipped. This class is Postgres-specific by
                    # construction, so always acquire.
                    self.lock()
                    return self

                def __exit__(self, exc_type, exc_val, exc_tb):
                    # Only reached when __enter__ succeeded, i.e. the lock is held.
                    self.unlock()

                def __call__(self, func):
                    """Decorate ``func`` so every call runs while holding the lock."""
                    @wraps(func)
                    def magic(*args, **kwargs):
                        with self:
                            return func(*args, **kwargs)

                    return magic
         | 
| 336 | 
            +
             | 
| 337 | 
            +
            class MysqlDatabaseLock:
         | 
| 338 | 
             
                def __init__(self, lock_name, timeout=10, db=None):
         | 
| 339 | 
             
                    self.lock_name = lock_name
         | 
| 340 | 
             
                    self.timeout = int(timeout)
         | 
|  | |
| 382 | 
             
                    return magic
         | 
| 383 |  | 
| 384 |  | 
| 385 | 
            +
            class DatabaseLock(Enum):
         | 
| 386 | 
            +
                MYSQL = MysqlDatabaseLock
         | 
| 387 | 
            +
                POSTGRES = PostgresDatabaseLock
         | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
             
            DB = BaseDataBase().database_connection
         | 
| 391 | 
            +
            DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
         | 
| 392 |  | 
| 393 |  | 
| 394 | 
             
            def close_connection():
         | 
|  | |
| 980 |  | 
| 981 | 
             
            def migrate_db():
         | 
| 982 | 
             
                with DB.transaction():
         | 
| 983 | 
            +
                    migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
         | 
| 984 | 
             
                    try:
         | 
| 985 | 
             
                        migrate(
         | 
| 986 | 
             
                            migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
         | 
    	
        api/db/db_utils.py
    CHANGED
    
    | @@ -17,6 +17,8 @@ import operator | |
| 17 | 
             
            from functools import reduce
         | 
| 18 | 
             
            from typing import Dict, Type, Union
         | 
| 19 |  | 
|  | |
|  | |
| 20 | 
             
            from api.utils import current_timestamp, timestamp_to_date
         | 
| 21 |  | 
| 22 | 
             
            from api.db.db_models import DB, DataBaseModel
         | 
| @@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): | |
| 49 | 
             
                    with DB.atomic():
         | 
| 50 | 
             
                        query = model.insert_many(data_source[i:i + batch_size])
         | 
| 51 | 
             
                        if replace_on_conflict:
         | 
| 52 | 
            -
                             | 
|  | |
|  | |
|  | |
| 53 | 
             
                        query.execute()
         | 
| 54 |  | 
| 55 |  | 
|  | |
| 17 | 
             
            from functools import reduce
         | 
| 18 | 
             
            from typing import Dict, Type, Union
         | 
| 19 |  | 
| 20 | 
            +
            from playhouse.pool import PooledMySQLDatabase
         | 
| 21 | 
            +
             | 
| 22 | 
             
            from api.utils import current_timestamp, timestamp_to_date
         | 
| 23 |  | 
| 24 | 
             
            from api.db.db_models import DB, DataBaseModel
         | 
|  | |
| 51 | 
             
                    with DB.atomic():
         | 
| 52 | 
             
                        query = model.insert_many(data_source[i:i + batch_size])
         | 
| 53 | 
             
                        if replace_on_conflict:
         | 
| 54 | 
            +
                            if isinstance(DB, PooledMySQLDatabase):
         | 
| 55 | 
            +
                                query = query.on_conflict(preserve=preserve)
         | 
| 56 | 
            +
                            else:
         | 
| 57 | 
            +
                                query = query.on_conflict(conflict_target="id", preserve=preserve)
         | 
| 58 | 
             
                        query.execute()
         | 
| 59 |  | 
| 60 |  | 
    	
        api/settings.py
    CHANGED
    
    | @@ -164,7 +164,8 @@ RANDOM_INSTANCE_ID = get_base_config( | |
| 164 | 
             
            PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
         | 
| 165 | 
             
            PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
         | 
| 166 |  | 
| 167 | 
            -
             | 
|  | |
| 168 |  | 
| 169 | 
             
            # Switch
         | 
| 170 | 
             
            # upload
         | 
|  | |
| 164 | 
             
            PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
         | 
| 165 | 
             
            PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
         | 
| 166 |  | 
| 167 | 
            +
            DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
         | 
| 168 | 
            +
            DATABASE = decrypt_database_config(name=DATABASE_TYPE)
         | 
| 169 |  | 
| 170 | 
             
            # Switch
         | 
| 171 | 
             
            # upload
         | 
    	
        conf/service_conf.yaml
    CHANGED
    
    | @@ -9,6 +9,14 @@ mysql: | |
| 9 | 
             
              port: 3306
         | 
| 10 | 
             
              max_connections: 100
         | 
| 11 | 
             
              stale_timeout: 30
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 12 | 
             
            minio:
         | 
| 13 | 
             
              user: 'rag_flow'
         | 
| 14 | 
             
              password: 'infini_rag_flow'
         | 
|  | |
| 9 | 
             
              port: 3306
         | 
| 10 | 
             
              max_connections: 100
         | 
| 11 | 
             
              stale_timeout: 30
         | 
| 12 | 
            +
            postgres:
         | 
| 13 | 
            +
              name: 'rag_flow'
         | 
| 14 | 
            +
              user: 'rag_flow'
         | 
| 15 | 
            +
              password: 'infini_rag_flow'
         | 
| 16 | 
            +
              host: 'postgres'
         | 
| 17 | 
            +
              port: 5432
         | 
| 18 | 
            +
              max_connections: 100
         | 
| 19 | 
            +
              stale_timeout: 30
         | 
| 20 | 
             
            minio:
         | 
| 21 | 
             
              user: 'rag_flow'
         | 
| 22 | 
             
              password: 'infini_rag_flow'
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -31,6 +31,7 @@ Flask==3.0.3 | |
| 31 | 
             
            Flask_Cors==5.0.0
         | 
| 32 | 
             
            Flask_Login==0.6.3
         | 
| 33 | 
             
            flask_session==0.8.0
         | 
|  | |
| 34 | 
             
            google_search_results==2.4.2
         | 
| 35 | 
             
            groq==0.9.0
         | 
| 36 | 
             
            hanziconv==0.3.2
         | 
|  | |
| 31 | 
             
            Flask_Cors==5.0.0
         | 
| 32 | 
             
            Flask_Login==0.6.3
         | 
| 33 | 
             
            flask_session==0.8.0
         | 
| 34 | 
            +
            psycopg2==2.9.9
         | 
| 35 | 
             
            google_search_results==2.4.2
         | 
| 36 | 
             
            groq==0.9.0
         | 
| 37 | 
             
            hanziconv==0.3.2
         |