jinhai-2012 commited on
Commit
4359d15
·
1 Parent(s): 036a97a

API: create dataset (#1106)

Browse files

### What problem does this PR solve?

This PR have finished 'create dataset' of both HTTP API and Python SDK.
HTTP API:
```
curl --request POST --url http://<HOST_ADDRESS>/api/v1/dataset --header 'Content-Type: application/json' --header 'Authorization: <ACCESS_KEY>' --data-binary '{
"name": "<DATASET_NAME>"
}'
```

Python SDK:
```
from ragflow.ragflow import RAGFLow
ragflow = RAGFLow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
ragflow.create_dataset("dataset1")

```

TODO:
- ACCESS_KEY is the login_token when user login RAGFlow, currently.
RAGFlow should have the function that user can add/delete access_key.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update

---------

Signed-off-by: Jin Hai <[email protected]>

api/apps/__init__.py CHANGED
@@ -63,12 +63,17 @@ login_manager.init_app(app)
63
 
64
 
65
  def search_pages_path(pages_dir):
66
- return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
 
 
 
67
 
68
 
69
  def register_page(page_path):
70
- page_name = page_path.stem.rstrip('_app')
71
- module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name, ))
 
 
72
 
73
  spec = spec_from_file_location(module_name, page_path)
74
  page = module_from_spec(spec)
@@ -76,17 +81,17 @@ def register_page(page_path):
76
  page.manager = Blueprint(page_name, module_name)
77
  sys.modules[module_name] = page
78
  spec.loader.exec_module(page)
79
-
80
  page_name = getattr(page, 'page_name', page_name)
81
- url_prefix = f'/{API_VERSION}/{page_name}'
82
 
83
  app.register_blueprint(page.manager, url_prefix=url_prefix)
 
84
  return url_prefix
85
 
86
 
87
  pages_dir = [
88
  Path(__file__).parent,
89
- Path(__file__).parent.parent / 'api' / 'apps',
90
  ]
91
 
92
  client_urls_prefix = [
 
63
 
64
 
65
  def search_pages_path(pages_dir):
66
+ app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
67
+ api_path_list = [path for path in pages_dir.glob('*_api.py') if not path.name.startswith('.')]
68
+ app_path_list.extend(api_path_list)
69
+ return app_path_list
70
 
71
 
72
  def register_page(page_path):
73
+ path = f'{page_path}'
74
+
75
+ page_name = page_path.stem.rstrip('_api') if "_api" in path else page_path.stem.rstrip('_app')
76
+ module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))
77
 
78
  spec = spec_from_file_location(module_name, page_path)
79
  page = module_from_spec(spec)
 
81
  page.manager = Blueprint(page_name, module_name)
82
  sys.modules[module_name] = page
83
  spec.loader.exec_module(page)
 
84
  page_name = getattr(page, 'page_name', page_name)
85
+ url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}'
86
 
87
  app.register_blueprint(page.manager, url_prefix=url_prefix)
88
+ print(f'API file: {page_path}, URL: {url_prefix}')
89
  return url_prefix
90
 
91
 
92
  pages_dir = [
93
  Path(__file__).parent,
94
+ Path(__file__).parent.parent / 'api' / 'apps', # FIXME: ragflow/api/api/apps, can be remove?
95
  ]
96
 
97
  client_urls_prefix = [
api/apps/dataset_api.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import json
18
+ import os
19
+ import re
20
+ from datetime import datetime, timedelta
21
+ from flask import request, Response
22
+ from flask_login import login_required, current_user
23
+
24
+ from api.db import FileType, ParserType, FileSource, StatusEnum
25
+ from api.db.db_models import APIToken, API4Conversation, Task, File
26
+ from api.db.services import duplicate_name
27
+ from api.db.services.api_service import APITokenService, API4ConversationService
28
+ from api.db.services.dialog_service import DialogService, chat
29
+ from api.db.services.document_service import DocumentService
30
+ from api.db.services.file2document_service import File2DocumentService
31
+ from api.db.services.file_service import FileService
32
+ from api.db.services.knowledgebase_service import KnowledgebaseService
33
+ from api.db.services.task_service import queue_tasks, TaskService
34
+ from api.db.services.user_service import UserTenantService, TenantService
35
+ from api.settings import RetCode, retrievaler
36
+ from api.utils import get_uuid, current_timestamp, datetime_format
37
+ # from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
38
+ from itsdangerous import URLSafeTimedSerializer
39
+
40
+ from api.utils.file_utils import filename_type, thumbnail
41
+ from rag.utils.minio_conn import MINIO
42
+
43
+ # import library
44
+ from api.utils.api_utils import construct_json_result, construct_result, construct_error_response, validate_request
45
+ from api.contants import NAME_LENGTH_LIMIT
46
+
47
+ # ------------------------------ create a dataset ---------------------------------------
48
+ @manager.route('/', methods=['POST'])
49
+ @login_required # use login
50
+ @validate_request("name") # check name key
51
+ def create_dataset():
52
+ # Check if Authorization header is present
53
+ authorization_token = request.headers.get('Authorization')
54
+ if not authorization_token:
55
+ return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Authorization header is missing.")
56
+
57
+ # TODO: Login or API key
58
+ # objs = APIToken.query(token=authorization_token)
59
+ #
60
+ # # Authorization error
61
+ # if not objs:
62
+ # return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Token is invalid.")
63
+ #
64
+ # tenant_id = objs[0].tenant_id
65
+
66
+ tenant_id = current_user.id
67
+ request_body = request.json
68
+
69
+ # In case that there's no name
70
+ if "name" not in request_body:
71
+ return construct_json_result(code=RetCode.DATA_ERROR, message="Expected 'name' field in request body")
72
+
73
+ dataset_name = request_body["name"]
74
+
75
+ # empty dataset_name
76
+ if not dataset_name:
77
+ return construct_json_result(code=RetCode.DATA_ERROR, message="Empty dataset name")
78
+
79
+ # In case that there's space in the head or the tail
80
+ dataset_name = dataset_name.strip()
81
+
82
+ # In case that the length of the name exceeds the limit
83
+ dataset_name_length = len(dataset_name)
84
+ if dataset_name_length > NAME_LENGTH_LIMIT:
85
+ return construct_json_result(
86
+ message=f"Dataset name: {dataset_name} with length {dataset_name_length} exceeds {NAME_LENGTH_LIMIT}!")
87
+
88
+ # In case that there are other fields in the data-binary
89
+ if len(request_body.keys()) > 1:
90
+ name_list = []
91
+ for key_name in request_body.keys():
92
+ if key_name != 'name':
93
+ name_list.append(key_name)
94
+ return construct_json_result(code=RetCode.DATA_ERROR,
95
+ message=f"fields: {name_list}, are not allowed in request body.")
96
+
97
+ # If there is a duplicate name, it will modify it to make it unique
98
+ request_body["name"] = duplicate_name(
99
+ KnowledgebaseService.query,
100
+ name=dataset_name,
101
+ tenant_id=tenant_id,
102
+ status=StatusEnum.VALID.value)
103
+ try:
104
+ request_body["id"] = get_uuid()
105
+ request_body["tenant_id"] = tenant_id
106
+ request_body["created_by"] = tenant_id
107
+ e, t = TenantService.get_by_id(tenant_id)
108
+ if not e:
109
+ return construct_result(code=RetCode.AUTHENTICATION_ERROR, message="Tenant not found.")
110
+ request_body["embd_id"] = t.embd_id
111
+ if not KnowledgebaseService.save(**request_body):
112
+ # failed to create new dataset
113
+ return construct_result()
114
+ return construct_json_result(data={"dataset_id": request_body["id"]})
115
+ except Exception as e:
116
+ return construct_error_response(e)
117
+
118
+
119
+ @manager.route('/<dataset_id>', methods=['DELETE'])
120
+ @login_required
121
+ def remove_dataset(dataset_id):
122
+ return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to remove dataset: {dataset_id}")
123
+
124
+
125
+ @manager.route('/<dataset_id>', methods=['PUT'])
126
+ @login_required
127
+ @validate_request("name")
128
+ def update_dataset(dataset_id):
129
+ return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to update dataset: {dataset_id}")
130
+
131
+
132
+ @manager.route('/<dataset_id>', methods=['GET'])
133
+ @login_required
134
+ def get_dataset(dataset_id):
135
+ return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
136
+
137
+
138
+ @manager.route('/', methods=['GET'])
139
+ @login_required
140
+ def list_datasets():
141
+ return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
142
+
api/contants.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ NAME_LENGTH_LIMIT = 2 ** 10
api/settings.py CHANGED
@@ -239,4 +239,5 @@ class RetCode(IntEnum, CustomEnum):
239
  RUNNING = 106
240
  PERMISSION_ERROR = 108
241
  AUTHENTICATION_ERROR = 109
 
242
  SERVER_ERROR = 500
 
239
  RUNNING = 106
240
  PERMISSION_ERROR = 108
241
  AUTHENTICATION_ERROR = 109
242
+ UNAUTHORIZED = 401
243
  SERVER_ERROR = 500
api/utils/api_utils.py CHANGED
@@ -38,7 +38,6 @@ from base64 import b64encode
38
  from hmac import HMAC
39
  from urllib.parse import quote, urlencode
40
 
41
-
42
  requests.models.complexjson.dumps = functools.partial(
43
  json.dumps, cls=CustomJSONEncoder)
44
 
@@ -235,3 +234,35 @@ def cors_reponse(retcode=RetCode.SUCCESS,
235
  response.headers["Access-Control-Allow-Headers"] = "*"
236
  response.headers["Access-Control-Expose-Headers"] = "Authorization"
237
  return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  from hmac import HMAC
39
  from urllib.parse import quote, urlencode
40
 
 
41
  requests.models.complexjson.dumps = functools.partial(
42
  json.dumps, cls=CustomJSONEncoder)
43
 
 
234
  response.headers["Access-Control-Allow-Headers"] = "*"
235
  response.headers["Access-Control-Expose-Headers"] = "Authorization"
236
  return response
237
+
238
+ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
239
+ import re
240
+ result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
241
+ response = {}
242
+ for key, value in result_dict.items():
243
+ if value is None and key != "code":
244
+ continue
245
+ else:
246
+ response[key] = value
247
+ return jsonify(response)
248
+
249
+
250
+ def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
251
+ if data == None:
252
+ return jsonify({"code": code, "message": message})
253
+ else:
254
+ return jsonify({"code": code, "message": message, "data": data})
255
+
256
+ def construct_error_response(e):
257
+ stat_logger.exception(e)
258
+ try:
259
+ if e.code == 401:
260
+ return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
261
+ except BaseException:
262
+ pass
263
+ if len(e.args) > 1:
264
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
265
+ if repr(e).find("index_not_found_exception") >=0:
266
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
267
+
268
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
sdk/python/README.md CHANGED
@@ -1 +1,41 @@
1
- # ragflow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python-ragflow
2
+
3
+ # update python client
4
+
5
+ - Update "version" field of [project] chapter
6
+ - build new python SDK
7
+ - upload to pypi.org
8
+ - install new python SDK
9
+
10
+ # build python SDK
11
+
12
+ ```shell
13
+ rm -f dist/* && python setup.py sdist bdist_wheel
14
+ ```
15
+
16
+ # install python SDK
17
+ ```shell
18
+ pip uninstall -y ragflow && pip install dist/*.whl
19
+ ```
20
+
21
+ This will install ragflow-sdk and its dependencies.
22
+
23
+ # upload to pypi.org
24
+ ```shell
25
+ twine upload dist/*.whl
26
+ ```
27
+
28
+ Enter your pypi API token according to the prompt.
29
+
30
+ Note that pypi allow a version of a package [be uploaded only once](https://pypi.org/help/#file-name-reuse). You need to change the `version` inside the `pyproject.toml` before build and upload.
31
+
32
+ # using
33
+
34
+ ```python
35
+
36
+ ```
37
+
38
+ # For developer
39
+ ```shell
40
+ pip install -e .
41
+ ```
sdk/python/ragflow/dataset.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ class DataSet:
17
+ def __init__(self, user_key, dataset_url, uuid, name):
18
+ self.user_key = user_key
19
+ self.dataset_url = dataset_url
20
+ self.uuid = uuid
21
+ self.name = name
sdk/python/ragflow/ragflow.py CHANGED
@@ -12,33 +12,43 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- #
16
  import os
17
- from abc import ABC
18
  import requests
 
19
 
20
-
21
- class RAGFLow(ABC):
22
- def __init__(self, user_key, base_url):
 
 
 
23
  self.user_key = user_key
24
- self.base_url = base_url
 
 
25
 
26
- def create_dataset(self, name):
27
- return name
 
 
 
 
 
28
 
29
- def delete_dataset(self, name):
30
- return name
31
 
32
  def list_dataset(self):
33
- endpoint = f"{self.base_url}/api/v1/dataset"
34
- response = requests.get(endpoint)
35
  if response.status_code == 200:
36
  return response.json()['datasets']
37
  else:
38
  return None
39
 
40
  def get_dataset(self, dataset_id):
41
- endpoint = f"{self.base_url}/api/v1/dataset/{dataset_id}"
42
  response = requests.get(endpoint)
43
  if response.status_code == 200:
44
  return response.json()
@@ -46,7 +56,7 @@ class RAGFLow(ABC):
46
  return None
47
 
48
  def update_dataset(self, dataset_id, params):
49
- endpoint = f"{self.base_url}/api/v1/dataset/{dataset_id}"
50
  response = requests.put(endpoint, json=params)
51
  if response.status_code == 200:
52
  return True
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+
16
  import os
 
17
  import requests
18
+ import json
19
 
20
+ class RAGFLow:
21
+ def __init__(self, user_key, base_url, version = 'v1'):
22
+ '''
23
+ api_url: http://<host_address>/api/v1
24
+ dataset_url: http://<host_address>/api/v1/dataset
25
+ '''
26
  self.user_key = user_key
27
+ self.api_url = f"{base_url}/api/{version}"
28
+ self.dataset_url = f"{self.api_url}/dataset"
29
+ self.authorization_header = {"Authorization": "{}".format(self.user_key)}
30
 
31
+ def create_dataset(self, dataset_name):
32
+ """
33
+ name: dataset name
34
+ """
35
+ res = requests.post(url=self.dataset_url, json={"name": dataset_name}, headers=self.authorization_header)
36
+ result_dict = json.loads(res.text)
37
+ return result_dict
38
 
39
+ def delete_dataset(self, dataset_name = None, dataset_id = None):
40
+ return dataset_name
41
 
42
  def list_dataset(self):
43
+ response = requests.get(self.dataset_url)
44
+ print(response)
45
  if response.status_code == 200:
46
  return response.json()['datasets']
47
  else:
48
  return None
49
 
50
  def get_dataset(self, dataset_id):
51
+ endpoint = f"{self.dataset_url}/{dataset_id}"
52
  response = requests.get(endpoint)
53
  if response.status_code == 200:
54
  return response.json()
 
56
  return None
57
 
58
  def update_dataset(self, dataset_id, params):
59
+ endpoint = f"{self.dataset_url}/{dataset_id}"
60
  response = requests.put(endpoint, json=params)
61
  if response.status_code == 200:
62
  return True
sdk/python/test/common.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+
3
+ API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE'
4
+ HOST_ADDRESS = 'http://127.0.0.1:9380'
sdk/python/test/test_basic.py CHANGED
@@ -3,49 +3,46 @@ import ragflow
3
  from ragflow.ragflow import RAGFLow
4
  import pytest
5
  from unittest.mock import MagicMock
 
6
 
7
 
8
- class TestCase(TestSdk):
9
-
10
- @pytest.fixture
11
- def ragflow_instance(self):
12
- # Here we create a mock instance of RAGFlow for testing
13
- return ragflow.ragflow.RAGFLow('123', 'url')
14
 
15
  def test_version(self):
16
  print(ragflow.__version__)
17
 
18
- def test_create_dataset(self):
19
- assert ragflow.ragflow.RAGFLow('123', 'url').create_dataset('abc') == 'abc'
20
-
21
- def test_delete_dataset(self):
22
- assert ragflow.ragflow.RAGFLow('123', 'url').delete_dataset('abc') == 'abc'
23
-
24
- def test_list_dataset_success(self, ragflow_instance, monkeypatch):
25
- # Mocking the response of requests.get method
26
- mock_response = MagicMock()
27
- mock_response.status_code = 200
28
- mock_response.json.return_value = {'datasets': [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]}
29
-
30
- # Patching requests.get to return the mock_response
31
- monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response))
32
-
33
- # Call the method under test
34
- result = ragflow_instance.list_dataset()
35
-
36
- # Assertion
37
- assert result == [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]
38
-
39
- def test_list_dataset_failure(self, ragflow_instance, monkeypatch):
40
- # Mocking the response of requests.get method
41
- mock_response = MagicMock()
42
- mock_response.status_code = 404 # Simulating a failed request
43
-
44
- # Patching requests.get to return the mock_response
45
- monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response))
46
-
47
- # Call the method under test
48
- result = ragflow_instance.list_dataset()
49
-
50
- # Assertion
51
- assert result is None
 
 
3
  from ragflow.ragflow import RAGFLow
4
  import pytest
5
  from unittest.mock import MagicMock
6
+ from common import API_KEY, HOST_ADDRESS
7
 
8
 
9
+ class TestBasic(TestSdk):
 
 
 
 
 
10
 
11
  def test_version(self):
12
  print(ragflow.__version__)
13
 
14
+ # def test_create_dataset(self):
15
+ # res = RAGFLow(API_KEY, HOST_ADDRESS).create_dataset('abc')
16
+ # print(res)
17
+ #
18
+ # def test_delete_dataset(self):
19
+ # assert RAGFLow('123', 'url').delete_dataset('abc') == 'abc'
20
+ #
21
+ # def test_list_dataset_success(self, ragflow_instance, monkeypatch):
22
+ # # Mocking the response of requests.get method
23
+ # mock_response = MagicMock()
24
+ # mock_response.status_code = 200
25
+ # mock_response.json.return_value = {'datasets': [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]}
26
+ #
27
+ # # Patching requests.get to return the mock_response
28
+ # monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response))
29
+ #
30
+ # # Call the method under test
31
+ # result = ragflow_instance.list_dataset()
32
+ #
33
+ # # Assertion
34
+ # assert result == [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]
35
+ #
36
+ # def test_list_dataset_failure(self, ragflow_instance, monkeypatch):
37
+ # # Mocking the response of requests.get method
38
+ # mock_response = MagicMock()
39
+ # mock_response.status_code = 404 # Simulating a failed request
40
+ #
41
+ # # Patching requests.get to return the mock_response
42
+ # monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response))
43
+ #
44
+ # # Call the method under test
45
+ # result = ragflow_instance.list_dataset()
46
+ #
47
+ # # Assertion
48
+ # assert result is None
sdk/python/test/test_dataset.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from test_sdkbase import TestSdk
2
+ import ragflow
3
+ from ragflow.ragflow import RAGFLow
4
+ import pytest
5
+ from unittest.mock import MagicMock
6
+ from common import API_KEY, HOST_ADDRESS
7
+
8
+ class TestDataset(TestSdk):
9
+
10
+ def test_create_dataset(self):
11
+ '''
12
+ 1. create a kb
13
+ 2. list the kb
14
+ 3. get the detail info according to the kb id
15
+ 4. update the kb
16
+ 5. delete the kb
17
+ '''
18
+ ragflow = RAGFLow(API_KEY, HOST_ADDRESS)
19
+
20
+ # create a kb
21
+ res = ragflow.create_dataset("kb1")
22
+ assert res['code'] == 0 and res['message'] == 'success'
23
+ dataset_id = res['data']['dataset_id']
24
+ print(dataset_id)
25
+
26
+ # TODO: list the kb