LiuHua root Kevin Hu commited on
Commit
0874636
·
1 Parent(s): 11e3284

create and update dataset (#2110)

Browse files

### What problem does this PR solve?

Added the ability to create and update datasets via the SDK

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: root <root@xwg>
Co-authored-by: Kevin Hu <[email protected]>

api/apps/sdk/dataset.py CHANGED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from flask import request
17
+
18
+ from api.db import StatusEnum
19
+ from api.db.db_models import APIToken
20
+ from api.db.services.knowledgebase_service import KnowledgebaseService
21
+ from api.db.services.user_service import TenantService
22
+ from api.settings import RetCode
23
+ from api.utils import get_uuid
24
+ from api.utils.api_utils import get_data_error_result
25
+ from api.utils.api_utils import get_json_result
26
+
27
+
28
+ @manager.route('/save', methods=['POST'])
29
+ def save():
30
+ req = request.json
31
+ token = request.headers.get('Authorization').split()[1]
32
+ objs = APIToken.query(token=token)
33
+ if not objs:
34
+ return get_json_result(
35
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
36
+ tenant_id = objs[0].tenant_id
37
+ e, t = TenantService.get_by_id(tenant_id)
38
+ if not e:
39
+ return get_data_error_result(retmsg="Tenant not found.")
40
+ if "id" not in req:
41
+ req['id'] = get_uuid()
42
+ req["name"] = req["name"].strip()
43
+ if req["name"] == "":
44
+ return get_data_error_result(
45
+ retmsg="Name is not empty")
46
+ if KnowledgebaseService.query(name=req["name"]):
47
+ return get_data_error_result(
48
+ retmsg="Duplicated knowledgebase name")
49
+ req["tenant_id"] = tenant_id
50
+ req['created_by'] = tenant_id
51
+ req['embd_id'] = t.embd_id
52
+ if not KnowledgebaseService.save(**req):
53
+ return get_data_error_result(retmsg="Data saving error")
54
+ req.pop('created_by')
55
+ keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
56
+ 'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
57
+ for old_key,new_key in keys_to_rename.items():
58
+ if old_key in req:
59
+ req[new_key]=req.pop(old_key)
60
+ return get_json_result(data=req)
61
+ else:
62
+ if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
63
+ return get_data_error_result(
64
+ retmsg="Can't change tenant_id or embedding_model")
65
+
66
+ e, kb = KnowledgebaseService.get_by_id(req["id"])
67
+ if not e:
68
+ return get_data_error_result(
69
+ retmsg="Can't find this knowledgebase!")
70
+
71
+ if not KnowledgebaseService.query(
72
+ created_by=tenant_id, id=req["id"]):
73
+ return get_json_result(
74
+ data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
75
+ retcode=RetCode.OPERATING_ERROR)
76
+
77
+ if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
78
+ return get_data_error_result(
79
+ retmsg="Can't change document_count or chunk_count ")
80
+
81
+ if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
82
+ return get_data_error_result(
83
+ retmsg="if chunk count is not 0, parser method is not changable. ")
84
+
85
+
86
+ if req["name"].lower() != kb.name.lower() \
87
+ and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
88
+ status=StatusEnum.VALID.value)) > 0:
89
+ return get_data_error_result(
90
+ retmsg="Duplicated knowledgebase name.")
91
+
92
+ del req["id"]
93
+ req['created_by'] = tenant_id
94
+ if not KnowledgebaseService.update_by_id(kb.id, req):
95
+ return get_data_error_result(retmsg="Data update error ")
96
+ return get_json_result(data=True)
sdk/python/ragflow/__init__.py CHANGED
@@ -3,3 +3,4 @@ import importlib.metadata
3
  __version__ = importlib.metadata.version("ragflow")
4
 
5
  from .ragflow import RAGFlow
 
 
3
  __version__ = importlib.metadata.version("ragflow")
4
 
5
  from .ragflow import RAGFlow
6
+ from .modules.dataset import DataSet
sdk/python/ragflow/modules/dataset.py CHANGED
@@ -2,7 +2,7 @@ from .base import Base
2
 
3
 
4
  class DataSet(Base):
5
- class ParseConfig(Base):
6
  def __init__(self, rag, res_dict):
7
  self.chunk_token_count = 128
8
  self.layout_recognize = True
@@ -21,13 +21,18 @@ class DataSet(Base):
21
  self.permission = "me"
22
  self.document_count = 0
23
  self.chunk_count = 0
24
- self.parse_method = 0
25
  self.parser_config = None
26
  super().__init__(rag, res_dict)
27
 
28
- def delete(self):
29
- try:
30
- self.post("/rm", {"kb_id": self.id})
31
- return True
32
- except Exception:
33
- return False
 
 
 
 
 
 
2
 
3
 
4
  class DataSet(Base):
5
+ class ParserConfig(Base):
6
  def __init__(self, rag, res_dict):
7
  self.chunk_token_count = 128
8
  self.layout_recognize = True
 
21
  self.permission = "me"
22
  self.document_count = 0
23
  self.chunk_count = 0
24
+ self.parser_method = "naive"
25
  self.parser_config = None
26
  super().__init__(rag, res_dict)
27
 
28
def save(self):
    """Persist this dataset's current field values to the server.

    Maps SDK attribute names back to the server-side keys and POSTs them
    to /dataset/save. Returns True on success; raises Exception with the
    server's message otherwise.
    """
    payload = {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
               "description": self.description, "language": self.language, "embd_id": self.embedding_model,
               "permission": self.permission,
               "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
               "parser_config": self.parser_config.to_json()}
    body = self.post('/dataset/save', payload).json()
    if body.get("retmsg"):
        raise Exception(body["retmsg"])
    return True
sdk/python/ragflow/ragflow.py CHANGED
@@ -21,180 +21,34 @@ from .modules.dataset import DataSet
21
  class RAGFlow:
22
  def __init__(self, user_key, base_url, version='v1'):
23
  """
24
- api_url: http://<host_address>/v1
25
- dataset_url: http://<host_address>/v1/kb
26
- document_url: http://<host_address>/v1/dataset/{dataset_id}/documents
27
  """
28
  self.user_key = user_key
29
- self.api_url = f"{base_url}/{version}"
30
- self.dataset_url = f"{self.api_url}/kb"
31
- self.authorization_header = {"Authorization": "{}".format(self.user_key)}
32
- self.base_url = base_url
33
 
34
  def post(self, path, param):
35
- res = requests.post(url=self.dataset_url + path, json=param, headers=self.authorization_header)
36
  return res
37
 
38
  def get(self, path, params=''):
39
- res = requests.get(self.dataset_url + path, params=params, headers=self.authorization_header)
40
  return res
41
 
42
- def create_dataset(self, dataset_name):
43
- """
44
- name: dataset name
45
- """
46
- res_create = self.post("/create", {"name": dataset_name})
47
- res_create_data = res_create.json()['data']
48
- res_detail = self.get("/detail", {"kb_id": res_create_data["kb_id"]})
49
- res_detail_data = res_detail.json()['data']
50
- result = {}
51
- result['id'] = res_detail_data['id']
52
- result['name'] = res_detail_data['name']
53
- result['avatar'] = res_detail_data['avatar']
54
- result['description'] = res_detail_data['description']
55
- result['language'] = res_detail_data['language']
56
- result['embedding_model'] = res_detail_data['embd_id']
57
- result['permission'] = res_detail_data['permission']
58
- result['document_count'] = res_detail_data['doc_num']
59
- result['chunk_count'] = res_detail_data['chunk_num']
60
- result['parser_config'] = res_detail_data['parser_config']
61
- dataset = DataSet(self, result)
62
- return dataset
63
-
64
- """
65
- def delete_dataset(self, dataset_name):
66
- dataset_id = self.find_dataset_id_by_name(dataset_name)
67
-
68
- endpoint = f"{self.dataset_url}/{dataset_id}"
69
- res = requests.delete(endpoint, headers=self.authorization_header)
70
- return res.json()
71
-
72
- def find_dataset_id_by_name(self, dataset_name):
73
- res = requests.get(self.dataset_url, headers=self.authorization_header)
74
- for dataset in res.json()["data"]:
75
- if dataset["name"] == dataset_name:
76
- return dataset["id"]
77
- return None
78
-
79
- def get_dataset(self, dataset_name):
80
- dataset_id = self.find_dataset_id_by_name(dataset_name)
81
- endpoint = f"{self.dataset_url}/{dataset_id}"
82
- response = requests.get(endpoint, headers=self.authorization_header)
83
- return response.json()
84
-
85
- def update_dataset(self, dataset_name, **params):
86
- dataset_id = self.find_dataset_id_by_name(dataset_name)
87
-
88
- endpoint = f"{self.dataset_url}/{dataset_id}"
89
- response = requests.put(endpoint, json=params, headers=self.authorization_header)
90
- return response.json()
91
-
92
- # ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------
93
-
94
- # ----------------------------upload local files-----------------------------------------------------
95
- def upload_local_file(self, dataset_id, file_paths):
96
- files = []
97
-
98
- for file_path in file_paths:
99
- if not isinstance(file_path, str):
100
- return {"code": RetCode.ARGUMENT_ERROR, "message": f"{file_path} is not string."}
101
- if "http" in file_path:
102
- return {"code": RetCode.ARGUMENT_ERROR, "message": "Remote files have not unsupported."}
103
- if os.path.isfile(file_path):
104
- files.append(("file", open(file_path, "rb")))
105
- else:
106
- return {"code": RetCode.DATA_ERROR, "message": f"The file {file_path} does not exist"}
107
-
108
- res = requests.request("POST", url=f"{self.dataset_url}/{dataset_id}/documents", files=files,
109
- headers=self.authorization_header)
110
-
111
- result_dict = json.loads(res.text)
112
- return result_dict
113
-
114
- # ----------------------------delete a file-----------------------------------------------------
115
- def delete_files(self, document_id, dataset_id):
116
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
117
- res = requests.delete(endpoint, headers=self.authorization_header)
118
- return res.json()
119
-
120
- # ----------------------------list files-----------------------------------------------------
121
- def list_files(self, dataset_id, offset=0, count=-1, order_by="create_time", descend=True, keywords=""):
122
- params = {
123
- "offset": offset,
124
- "count": count,
125
- "order_by": order_by,
126
- "descend": descend,
127
- "keywords": keywords
128
- }
129
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/"
130
- res = requests.get(endpoint, params=params, headers=self.authorization_header)
131
- return res.json()
132
-
133
- # ----------------------------update files: enable, rename, template_type-------------------------------------------
134
- def update_file(self, dataset_id, document_id, **params):
135
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
136
- response = requests.put(endpoint, json=params, headers=self.authorization_header)
137
- return response.json()
138
-
139
- # ----------------------------download a file-----------------------------------------------------
140
- def download_file(self, dataset_id, document_id):
141
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
142
- res = requests.get(endpoint, headers=self.authorization_header)
143
-
144
- content = res.content # binary data
145
- # decode the binary data
146
- try:
147
- decoded_content = content.decode("utf-8")
148
- json_data = json.loads(decoded_content)
149
- return json_data # message
150
- except json.JSONDecodeError: # binary data
151
- _, document = DocumentService.get_by_id(document_id)
152
- file_path = os.path.join(os.getcwd(), document.name)
153
- with open(file_path, "wb") as file:
154
- file.write(content)
155
- return {"code": RetCode.SUCCESS, "data": content}
156
-
157
- # ----------------------------start parsing-----------------------------------------------------
158
- def start_parsing_document(self, dataset_id, document_id):
159
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
160
- res = requests.post(endpoint, headers=self.authorization_header)
161
-
162
- return res.json()
163
-
164
- def start_parsing_documents(self, dataset_id, doc_ids=None):
165
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/status"
166
- res = requests.post(endpoint, headers=self.authorization_header, json={"doc_ids": doc_ids})
167
-
168
- return res.json()
169
-
170
- # ----------------------------stop parsing-----------------------------------------------------
171
- def stop_parsing_document(self, dataset_id, document_id):
172
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
173
- res = requests.delete(endpoint, headers=self.authorization_header)
174
-
175
- return res.json()
176
-
177
- def stop_parsing_documents(self, dataset_id, doc_ids=None):
178
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/status"
179
- res = requests.delete(endpoint, headers=self.authorization_header, json={"doc_ids": doc_ids})
180
-
181
- return res.json()
182
-
183
- # ----------------------------show the status of the file-----------------------------------------------------
184
- def show_parsing_status(self, dataset_id, document_id):
185
- endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
186
- res = requests.get(endpoint, headers=self.authorization_header)
187
-
188
- return res.json()
189
- # ----------------------------list the chunks of the file-----------------------------------------------------
190
-
191
- # ----------------------------delete the chunk-----------------------------------------------------
192
-
193
- # ----------------------------edit the status of the chunk-----------------------------------------------------
194
-
195
- # ----------------------------insert a new chunk-----------------------------------------------------
196
 
197
- # ----------------------------get a specific chunk-----------------------------------------------------
198
 
199
- # ----------------------------retrieval test-----------------------------------------------------
200
- """
 
21
  class RAGFlow:
22
  def __init__(self, user_key, base_url, version='v1'):
23
  """
24
+ api_url: http://<host_address>/api/v1
 
 
25
  """
26
  self.user_key = user_key
27
+ self.api_url = f"{base_url}/api/{version}"
28
+ self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)}
 
 
29
 
30
def post(self, path, param):
    """POST `param` as a JSON body to `path` under the API root, with the auth header."""
    return requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
33
 
34
def get(self, path, params=''):
    """GET `path` under the API root with optional query `params` and the auth header."""
    return requests.get(self.api_url + path, params=params, headers=self.authorization_header)
37
 
38
def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
                   permission: str = "me", document_count: int = 0, chunk_count: int = 0,
                   parser_method: str = "naive", parser_config: DataSet.ParserConfig = None):
    """Create a dataset (knowledgebase) on the server.

    When `parser_config` is omitted, a default ParserConfig is built with the
    same literal values the server would use. Returns the created dataset as
    a DataSet object; raises Exception with the server's message on failure.
    """
    if parser_config is None:
        parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
                                                    "delimiter": "\n!?。;!?", "task_page_size": 12})
    payload = {"name": name, "avatar": avatar, "description": description, "language": language,
               "permission": permission,
               "doc_num": document_count, "chunk_num": chunk_count, "parser_id": parser_method,
               "parser_config": parser_config.to_json()}
    body = self.post("/dataset/save", payload).json()
    if body.get("retmsg"):
        raise Exception(body["retmsg"])
    return DataSet(self, body["data"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
 
54
 
 
 
sdk/python/test/common.py CHANGED
@@ -1,4 +1,4 @@
1
 
2
 
3
- API_KEY = 'IjUxNGM0MmM4NWY5MzExZWY5MDhhMDI0MmFjMTIwMDA2Ig.ZsWebA.mV1NKdSPPllgowiH-7vz36tMWyI'
4
  HOST_ADDRESS = 'http://127.0.0.1:9380'
 
1
 
2
 
3
+ API_KEY = 'ragflow-k0N2I1MzQwNjNhMzExZWY5ODg1MDI0Mm'
4
  HOST_ADDRESS = 'http://127.0.0.1:9380'
sdk/python/test/t_dataset.py CHANGED
@@ -1,4 +1,4 @@
1
- from ragflow import RAGFlow
2
 
3
  from common import API_KEY, HOST_ADDRESS
4
  from test_sdkbase import TestSdk
@@ -6,18 +6,27 @@ from test_sdkbase import TestSdk
6
 
7
  class TestDataset(TestSdk):
8
  def test_create_dataset_with_success(self):
 
 
 
9
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
10
  ds = rag.create_dataset("God")
11
- assert ds is not None, "The dataset creation failed, returned None."
12
- assert ds.name == "God", "Dataset name does not match."
 
 
13
 
14
- def test_delete_one_file(self):
15
  """
16
- Test deleting one file with success.
17
  """
18
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
19
  ds = rag.create_dataset("ABC")
20
- assert ds is not None, "Failed to create dataset"
21
- assert ds.name == "ABC", "Dataset name mismatch"
22
- delete_result = ds.delete()
23
- assert delete_result is True, "Failed to delete dataset"
 
 
 
 
 
1
+ from ragflow import RAGFlow, DataSet
2
 
3
  from common import API_KEY, HOST_ADDRESS
4
  from test_sdkbase import TestSdk
 
6
 
7
class TestDataset(TestSdk):
    """SDK integration tests for dataset creation and update (require a live server)."""

    def test_create_dataset_with_success(self):
        """
        Test creating dataset with success
        """
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        ds = rag.create_dataset("God")
        # Guard clause instead of if/else: fail fast when creation did not
        # yield a DataSet, then assert on the created object.
        if not isinstance(ds, DataSet):
            assert False, f"Failed to create dataset, error: {ds}"
        assert ds.name == "God", "Name does not match."

    def test_update_dataset_with_success(self):
        """
        Test updating dataset with success.
        """
        rag = RAGFlow(API_KEY, HOST_ADDRESS)
        ds = rag.create_dataset("ABC")
        if not isinstance(ds, DataSet):
            assert False, f"Failed to create dataset, error: {ds}"
        assert ds.name == "ABC", "Name does not match."
        ds.name = 'DEF'
        res = ds.save()
        assert res is True, f"Failed to update dataset, error: {res}"