Commit 287a0bc · 1 parent: fbbc97b

feat: chroma initial deploy

This view is limited to 50 files because it contains too many changes.
- .dockerignore +10 -0
- .gitattributes +1 -35
- .github/ISSUE_TEMPLATE/bug_report.yaml +43 -0
- .github/ISSUE_TEMPLATE/config.yml +5 -0
- .github/ISSUE_TEMPLATE/feature_request.yaml +46 -0
- .github/ISSUE_TEMPLATE/installation_trouble.yaml +41 -0
- .github/actions/bandit-scan/Dockerfile +7 -0
- .github/actions/bandit-scan/action.yaml +26 -0
- .github/actions/bandit-scan/entrypoint.sh +13 -0
- .github/workflows/chroma-client-integration-test.yml +31 -0
- .github/workflows/chroma-cluster-test.yml +42 -0
- .github/workflows/chroma-coordinator-test.yaml +23 -0
- .github/workflows/chroma-integration-test.yml +40 -0
- .github/workflows/chroma-js-release.yml +42 -0
- .github/workflows/chroma-release-python-client.yml +58 -0
- .github/workflows/chroma-release.yml +179 -0
- .github/workflows/chroma-test.yml +65 -0
- .github/workflows/chroma-worker-test.yml +36 -0
- .github/workflows/pr-review-checklist.yml +37 -0
- .github/workflows/python-vuln.yaml +28 -0
- .gitignore +34 -0
- .pre-commit-config.yaml +36 -0
- .vscode/settings.json +131 -0
- Cargo.lock +0 -0
- Cargo.toml +5 -0
- DEVELOP.md +111 -0
- Dockerfile +39 -0
- LICENSE +201 -0
- README.md +106 -11
- RELEASE_PROCESS.md +22 -0
- Tiltfile +30 -0
- bandit.yaml +4 -0
- bin/cluster-test.sh +62 -0
- bin/docker_entrypoint.sh +15 -0
- bin/generate_cloudformation.py +198 -0
- bin/integration-test +75 -0
- bin/reset.sh +13 -0
- bin/templates/docker-compose.yml +21 -0
- bin/test-package.sh +24 -0
- bin/test-remote +16 -0
- bin/test.py +7 -0
- bin/version +8 -0
- bin/windows_upgrade_sqlite.py +20 -0
- chromadb/__init__.py +257 -0
- chromadb/api/__init__.py +596 -0
- chromadb/api/client.py +496 -0
- chromadb/api/fastapi.py +654 -0
- chromadb/api/models/Collection.py +633 -0
- chromadb/api/segment.py +914 -0
- chromadb/api/types.py +509 -0
.dockerignore
ADDED
@@ -0,0 +1,10 @@
+venv
+.conda
+.git
+examples
+clients
+.hypothesis
+__pycache__
+.vscode
+*.egg-info
+.pytest_cache
.gitattributes
CHANGED
@@ -1,35 +1 @@
-
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*_pb2.py* linguist-generated
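With the LFS filters removed, the only remaining attribute marks generated protobuf stubs as `linguist-generated`. As a quick sketch, the standard `git check-attr` command can confirm the rule resolves on a matching file; the path below is assumed from the pre-commit `exclude` pattern later in this diff, not stated here:

```bash
# verify that the new .gitattributes rule applies to a generated stub
git check-attr linguist-generated -- chromadb/proto/chroma_pb2.py
# expected: chromadb/proto/chroma_pb2.py: linguist-generated: set
```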
.github/ISSUE_TEMPLATE/bug_report.yaml
ADDED
@@ -0,0 +1,43 @@
+name: 🐛 Bug Report
+description: File a bug report to help us improve Chroma
+title: "[Bug]: "
+labels: ["bug", "triage"]
+# assignees:
+#   - octocat
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks for taking the time to fill out this bug report!
+  - type: textarea
+    id: what-happened
+    attributes:
+      label: What happened?
+      description: Also tell us, what did you expect to happen?
+      placeholder: Tell us what you see!
+      # value: "A bug happened!"
+    validations:
+      required: true
+  - type: textarea
+    id: versions
+    attributes:
+      label: Versions
+      description: Your Chroma, Python, and OS versions, as well as whatever else you think relevant. Check that you have [the latest Chroma](https://github.com/chroma-core/chroma/pkgs/container/chroma) as we are a fast moving pre v1.0 project.
+      placeholder: Chroma v0.3.22, Python 3.9.6, MacOS 12.5
+      # value: "A bug happened!"
+    validations:
+      required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Relevant log output
+      description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
+      render: shell
+  # - type: checkboxes
+  #   id: terms
+  #   attributes:
+  #     label: Code of Conduct
+  #     description: By submitting this issue, you agree to follow our [Code of Conduct](https://example.com)
+  #     options:
+  #       - label: I agree to follow this project's Code of Conduct
+  #         required: true
.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1,5 @@
+blank_issues_enabled: true
+contact_links:
+  - name: 🤷🏻♀️ Questions
+    url: https://discord.com/invite/MMeYNTmh3x
+    about: Interact with the Chroma community here by asking for help, discussing and more!
.github/ISSUE_TEMPLATE/feature_request.yaml
ADDED
@@ -0,0 +1,46 @@
+name: 🚀 Feature request
+description: Suggest an idea for Chroma
+title: "[Feature Request]: "
+labels: ["enhancement"]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks for taking the time to request this feature!
+  - type: textarea
+    id: problem
+    attributes:
+      label: Describe the problem
+      description: Please provide a clear and concise description of the problem this feature would solve. The more information you can provide here, the better.
+      placeholder: I prefer if...
+    validations:
+      required: true
+  - type: textarea
+    id: solution
+    attributes:
+      label: Describe the proposed solution
+      description: Please provide a clear and concise description of what you would like to happen.
+      placeholder: I would like to see...
+    validations:
+      required: true
+  - type: textarea
+    id: alternatives
+    attributes:
+      label: Alternatives considered
+      description: "Please provide a clear and concise description of any alternative solutions or features you've considered."
+  - type: dropdown
+    id: importance
+    attributes:
+      label: Importance
+      description: How important is this feature to you?
+      options:
+        - nice to have
+        - would make my life easier
+        - i cannot use Chroma without it
+    validations:
+      required: true
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional Information
+      description: Add any other context or screenshots about the feature request here.
.github/ISSUE_TEMPLATE/installation_trouble.yaml
ADDED
@@ -0,0 +1,41 @@
+name: Installation Issue
+description: Request for install help with Chroma
+title: "[Install issue]: "
+labels: ["installation trouble"]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks for taking the time to fill out this issue report!
+  - type: textarea
+    id: what-happened
+    attributes:
+      label: What happened?
+      description: Also tell us, what did you expect to happen?
+      placeholder: Tell us what you see!
+      # value: "A bug happened!"
+    validations:
+      required: true
+  - type: textarea
+    id: versions
+    attributes:
+      label: Versions
+      description: We need your Chroma, Python, and OS versions, as well as whatever else you think relevant.
+      placeholder: Chroma v0.3.14, Python 3.9.6, MacOS 12.5
+      # value: "A bug happened!"
+    validations:
+      required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Relevant log output
+      description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
+      render: shell
+  # - type: checkboxes
+  #   id: terms
+  #   attributes:
+  #     label: Code of Conduct
+  #     description: By submitting this issue, you agree to follow our [Code of Conduct](https://example.com)
+  #     options:
+  #       - label: I agree to follow this project's Code of Conduct
+  #         required: true
.github/actions/bandit-scan/Dockerfile
ADDED
@@ -0,0 +1,7 @@
+FROM python:3.10-alpine AS base-action
+
+RUN pip3 install -U setuptools pip bandit
+
+COPY entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
+ENTRYPOINT ["sh","/entrypoint.sh"]
.github/actions/bandit-scan/action.yaml
ADDED
@@ -0,0 +1,26 @@
+name: 'Bandit Scan'
+description: 'This action performs a security vulnerability scan of Python code using the Bandit library.'
+inputs:
+  bandit-config:
+    description: 'Bandit configuration file'
+    required: false
+  input-dir:
+    description: 'Directory to scan'
+    required: false
+    default: '.'
+  format:
+    description: 'Output format (txt, csv, json, xml, yaml). Default: json'
+    required: false
+    default: 'json'
+  output-file:
+    description: "The report file to produce. Make sure to align your format with the file extension to avoid confusion."
+    required: false
+    default: "bandit-scan.json"
+runs:
+  using: 'docker'
+  image: 'Dockerfile'
+  args:
+    - ${{ inputs.format }}
+    - ${{ inputs.bandit-config }}
+    - ${{ inputs.input-dir }}
+    - ${{ inputs.output-file }}
.github/actions/bandit-scan/entrypoint.sh
ADDED
@@ -0,0 +1,13 @@
+#!/bin/bash
+CFG="-c $2"
+if [ -z "$1" ]; then
+  echo "No format provided"
+  exit 1
+fi
+
+if [ -z "$2" ]; then
+  CFG=""
+fi
+
+bandit -f "$1" ${CFG} -r "$3" -o "$4"
+exit 0 # we want to ignore the exit code of bandit (for now)
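Reading the positional arguments back against action.yaml ($1 = format, $2 = bandit-config, $3 = input-dir, $4 = output-file), the action's default inputs reduce to a single Bandit invocation. A minimal sketch of reproducing the same scan locally, assuming Bandit is installed:

```bash
pip install bandit
# equivalent of: bandit -f "$1" ${CFG} -r "$3" -o "$4" with the defaults from action.yaml
bandit -f json -c bandit.yaml -r . -o bandit-scan.json
```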
.github/workflows/chroma-client-integration-test.yml
ADDED
@@ -0,0 +1,31 @@
+name: Chroma Client Integration Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+      - '**'
+  workflow_dispatch:
+
+jobs:
+  test:
+    timeout-minutes: 90
+    strategy:
+      matrix:
+        python: ['3.8', '3.9', '3.10', '3.11']
+        platform: [ubuntu-latest, windows-latest]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python }}
+      - name: Install test dependencies
+        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+      - name: Test
+        run: clients/python/integration-test.sh
.github/workflows/chroma-cluster-test.yml
ADDED
@@ -0,0 +1,42 @@
+name: Chroma Cluster Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+      - '**'
+  workflow_dispatch:
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        python: ['3.8']
+        platform: ['16core-64gb-ubuntu-latest']
+        testfile: ["chromadb/test/ingest/test_producer_consumer.py",
+                   "chromadb/test/db/test_system.py",
+                   "chromadb/test/segment/distributed/test_memberlist_provider.py",]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python }}
+      - name: Install test dependencies
+        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+      - name: Start minikube
+        id: minikube
+        uses: medyagh/setup-minikube@latest
+        with:
+          minikube-version: latest
+          kubernetes-version: latest
+          driver: docker
+          addons: ingress, ingress-dns
+          start-args: '--profile chroma-test'
+      - name: Integration Test
+        run: bin/cluster-test.sh ${{ matrix.testfile }}
.github/workflows/chroma-coordinator-test.yaml
ADDED
@@ -0,0 +1,23 @@
+name: Chroma Coordinator Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+      - '**'
+  workflow_dispatch:
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        platform: [ubuntu-latest]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Build and test
+        run: cd go/coordinator && make test
.github/workflows/chroma-integration-test.yml
ADDED
@@ -0,0 +1,40 @@
+name: Chroma Integration Tests
+
+on:
+  push:
+    branches:
+      - main
+      - team/hypothesis-tests
+  pull_request:
+    branches:
+      - main
+      - '**'
+  workflow_dispatch:
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        python: ['3.8']
+        platform: [ubuntu-latest, windows-latest]
+        testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'",
+                   "chromadb/test/property/test_add.py",
+                   "chromadb/test/test_cli.py",
+                   "chromadb/test/auth/test_simple_rbac_authz.py",
+                   "chromadb/test/property/test_collections.py",
+                   "chromadb/test/property/test_cross_version_persist.py",
+                   "chromadb/test/property/test_embeddings.py",
+                   "chromadb/test/property/test_filtering.py",
+                   "chromadb/test/property/test_persist.py"]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python }}
+      - name: Install test dependencies
+        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+      - name: Integration Test
+        run: bin/integration-test ${{ matrix.testfile }}
.github/workflows/chroma-js-release.yml
ADDED
@@ -0,0 +1,42 @@
+name: Chroma Release JS Client
+
+on:
+  push:
+    tags:
+      - 'js_release_*.*.*' # Match tags in the form js_release_X.Y.Z
+      - 'js_release_alpha_*.*.*' # Match tags in the form js_release_alpha_X.Y.Z
+
+jobs:
+  build-and-release:
+    runs-on: ubuntu-latest
+    permissions: write-all
+    steps:
+      - name: Check if tag matches the pattern
+        run: |
+          if [[ "${{ github.ref }}" =~ ^refs/tags/js_release_alpha_[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            echo "Tag matches the pattern js_release_alpha_X.Y.Z"
+            echo "NPM_SCRIPT=release_alpha" >> "$GITHUB_ENV"
+          elif [[ "${{ github.ref }}" =~ ^refs/tags/js_release_[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            echo "Tag matches the pattern js_release_X.Y.Z"
+            echo "NPM_SCRIPT=release" >> "$GITHUB_ENV"
+          else
+            echo "Tag does not match the release tag pattern, exiting workflow"
+            exit 1
+          fi
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - name: Set up JS
+        uses: actions/setup-node@v3
+        with:
+          node-version: '16.x'
+          registry-url: 'https://registry.npmjs.org'
+      - name: Install Client Dev Dependencies
+        run: npm install
+        working-directory: ./clients/js/
+      - name: npm Test & Publish
+        run: npm run db:run && PORT=8001 npm run $NPM_SCRIPT
+        working-directory: ./clients/js/
+        env:
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
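For reference, this workflow is driven entirely by the pushed tag name; a sketch of cutting a release (the version number is illustrative, not a real release):

```bash
# stable: sets NPM_SCRIPT=release
git tag js_release_1.2.3 && git push origin js_release_1.2.3

# alpha: sets NPM_SCRIPT=release_alpha
git tag js_release_alpha_1.2.3 && git push origin js_release_alpha_1.2.3
```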
.github/workflows/chroma-release-python-client.yml
ADDED
@@ -0,0 +1,58 @@
+name: Chroma Release Python Client
+
+on:
+  push:
+    tags:
+      - '[0-9]+.[0-9]+.[0-9]+' # Match tags in the form X.Y.Z
+    branches:
+      - main
+      - hammad/thin_client
+
+jobs:
+  check_tag:
+    runs-on: ubuntu-latest
+    outputs:
+      tag_matches: ${{ steps.check-tag.outputs.tag_matches }}
+    steps:
+      - name: Check Tag
+        id: check-tag
+        run: |
+          if [[ ${{ github.event.ref }} =~ ^refs/tags/[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            echo "tag_matches=true" >> $GITHUB_OUTPUT
+          fi
+  build-and-release:
+    runs-on: ubuntu-latest
+    needs: check_tag
+    if: needs.check_tag.outputs.tag_matches == 'true'
+    permissions: write-all
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install Client Dev Dependencies
+        run: python -m pip install -r ./clients/python/requirements.txt && python -m pip install -r ./clients/python/requirements_dev.txt
+      - name: Build Client
+        run: ./clients/python/build_python_thin_client.sh
+      - name: Install setuptools_scm
+        run: python -m pip install setuptools_scm
+      - name: Get Release Version
+        id: version
+        run: echo "version=$(python -m setuptools_scm)" >> $GITHUB_OUTPUT
+      - name: Get current date
+        id: builddate
+        run: echo "builddate=$(date +'%Y-%m-%dT%H:%M')" >> $GITHUB_OUTPUT
+      - name: Publish to Test PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          password: ${{ secrets.TEST_PYPI_PYTHON_CLIENT_PUBLISH_KEY }}
+          repository-url: https://test.pypi.org/legacy/
+      - name: Publish to PyPI
+        if: startsWith(github.ref, 'refs/tags')
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          password: ${{ secrets.PYPI_PYTHON_CLIENT_PUBLISH_KEY }}
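Both this workflow and the main release workflow below gate full publishing on the same bash regex over the pushed ref. A minimal sketch of how that gate behaves, with illustrative refs:

```bash
# refs matching ^refs/tags/X.Y.Z$ set tag_matches=true; everything else stays on the prerelease path
for ref in refs/tags/0.4.0 refs/tags/latest refs/heads/main; do
  if [[ $ref =~ ^refs/tags/[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
    echo "$ref -> tag_matches=true (full release)"
  else
    echo "$ref -> no match (prerelease)"
  fi
done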
.github/workflows/chroma-release.yml
ADDED
@@ -0,0 +1,179 @@
+name: Chroma Release
+
+on:
+  push:
+    tags:
+      - "*"
+    branches:
+      - main
+
+env:
+  GHCR_IMAGE_NAME: "ghcr.io/chroma-core/chroma"
+  DOCKERHUB_IMAGE_NAME: "chromadb/chroma"
+  PLATFORMS: linux/amd64,linux/arm64 #linux/riscv64, linux/arm/v7
+
+jobs:
+  check_tag:
+    runs-on: ubuntu-latest
+    outputs:
+      tag_matches: ${{ steps.check-tag.outputs.tag_matches }}
+    steps:
+      - name: Check Tag
+        id: check-tag
+        run: |
+          if [[ ${{ github.event.ref }} =~ ^refs/tags/[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            echo "tag_matches=true" >> $GITHUB_OUTPUT
+          fi
+  build-and-release:
+    runs-on: ubuntu-latest
+    needs: check_tag
+    permissions: write-all
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      # https://github.com/docker/setup-qemu-action - for multiplatform builds
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v2
+      # https://github.com/docker/setup-buildx-action - for multiplatform builds
+      - name: Set up Docker Buildx
+        id: buildx
+        uses: docker/setup-buildx-action@v2
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install Client Dev Dependencies
+        run: python -m pip install -r requirements_dev.txt
+      - name: Build Client
+        run: python -m build
+      - name: Test Client Package
+        run: bin/test-package.sh dist/*.tar.gz
+      - name: Log in to the Github Container registry
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Login to DockerHub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      - name: Install setuptools_scm
+        run: python -m pip install setuptools_scm
+      - name: Get Release Version
+        id: version
+        run: echo "version=$(python -m setuptools_scm)" >> $GITHUB_OUTPUT
+      - name: Build and push prerelease Docker image
+        if: "needs.check_tag.outputs.tag_matches != 'true'"
+        uses: docker/build-push-action@v3
+        with:
+          context: .
+          platforms: ${{ env.PLATFORMS }}
+          push: true
+          tags: "${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}"
+      - name: Build and push release Docker image
+        if: "needs.check_tag.outputs.tag_matches == 'true'"
+        uses: docker/build-push-action@v3
+        with:
+          context: .
+          platforms: ${{ env.PLATFORMS }}
+          push: true
+          tags: "${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.GHCR_IMAGE_NAME }}:latest,${{ env.DOCKERHUB_IMAGE_NAME }}:latest"
+      - name: Get current date
+        id: builddate
+        run: echo "builddate=$(date +'%Y-%m-%dT%H:%M')" >> $GITHUB_OUTPUT
+      - name: Publish to Test PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          password: ${{ secrets.TEST_PYPI_API_TOKEN }}
+          repository_url: https://test.pypi.org/legacy/
+      - name: Publish to PyPI
+        if: "needs.check_tag.outputs.tag_matches == 'true'"
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          password: ${{ secrets.PYPI_API_TOKEN }}
+      - name: Login to AWS
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          role-to-assume: arn:aws:iam::369178033109:role/github-action-generate-cf-template
+          aws-region: us-east-1
+      - name: Generate CloudFormation template
+        id: generate-cf
+        if: "needs.check_tag.outputs.tag_matches == 'true'"
+        run: "pip install boto3 && python bin/generate_cloudformation.py"
+      - name: Release Tagged Version
+        uses: ncipollo/release-action@v1
+        if: "needs.check_tag.outputs.tag_matches == 'true'"
+        with:
+          body: |
+            Version: `${{steps.version.outputs.version}}`
+            Git ref: `${{github.ref}}`
+            Build Date: `${{steps.builddate.outputs.builddate}}`
+            PIP Package: `chroma-${{steps.version.outputs.version}}.tar.gz`
+            Github Container Registry Image: `${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+            DockerHub Image: `${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+          artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz"
+          prerelease: true
+          generateReleaseNotes: true
+      - name: Update Tag
+        uses: richardsimko/update-tag@v1
+        if: "needs.check_tag.outputs.tag_matches != 'true'"
+        with:
+          tag_name: latest
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      - name: Release Latest
+        uses: ncipollo/release-action@v1
+        if: "needs.check_tag.outputs.tag_matches != 'true'"
+        with:
+          tag: "latest"
+          name: "Latest"
+          body: |
+            Version: `${{steps.version.outputs.version}}`
+            Git ref: `${{github.ref}}`
+            Build Date: `${{steps.builddate.outputs.builddate}}`
+            PIP Package: `chroma-${{steps.version.outputs.version}}.tar.gz`
+            Github Container Registry Image: `${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+            DockerHub Image: `${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+          artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz"
+          allowUpdates: true
+          prerelease: true
+      - name: Trigger Hosted Chroma FE Release
+        uses: actions/github-script@v6
+        with:
+          github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
+          script: |
+            const result = await github.rest.actions.createWorkflowDispatch({
+              owner: 'chroma-core',
+              repo: 'hosted-chroma',
+              workflow_id: 'build-and-publish-frontend.yaml',
+              ref: 'main'
+            })
+            console.log(result)
+      - name: Trigger Hosted Chroma Coordinator Release
+        uses: actions/github-script@v6
+        with:
+          github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
+          script: |
+            const result = await github.rest.actions.createWorkflowDispatch({
+              owner: 'chroma-core',
+              repo: 'hosted-chroma',
+              workflow_id: 'build-and-deploy-coordinator.yaml',
+              ref: 'main'
+            })
+            console.log(result)
+      - name: Trigger Hosted Worker Release
+        uses: actions/github-script@v6
+        with:
+          github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
+          script: |
+            const result = await github.rest.actions.createWorkflowDispatch({
+              owner: 'chroma-core',
+              repo: 'hosted-chroma',
+              workflow_id: 'build-and-deploy-worker.yaml',
+              ref: 'main'
+            })
+            console.log(result)
.github/workflows/chroma-test.yml
ADDED
@@ -0,0 +1,65 @@
+name: Chroma Tests
+
+on:
+  push:
+    branches:
+      - main
+      - team/hypothesis-tests
+  pull_request:
+    branches:
+      - main
+      - '**'
+  workflow_dispatch:
+
+jobs:
+  test:
+    timeout-minutes: 90
+    strategy:
+      matrix:
+        python: ['3.8', '3.9', '3.10', '3.11']
+        platform: [ubuntu-latest, windows-latest]
+        testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'",
+                   "chromadb/test/auth/test_simple_rbac_authz.py",
+                   "chromadb/test/property/test_add.py",
+                   "chromadb/test/property/test_collections.py",
+                   "chromadb/test/property/test_cross_version_persist.py",
+                   "chromadb/test/property/test_embeddings.py",
+                   "chromadb/test/property/test_filtering.py",
+                   "chromadb/test/property/test_persist.py"]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python }}
+      - name: Install test dependencies
+        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+      - name: Upgrade SQLite
+        run: python bin/windows_upgrade_sqlite.py
+        if: runner.os == 'Windows'
+      - name: Test
+        run: python -m pytest ${{ matrix.testfile }}
+  stress-test:
+    timeout-minutes: 90
+    strategy:
+      matrix:
+        python: ['3.8']
+        platform: ['16core-64gb-ubuntu-latest', '16core-64gb-windows-latest']
+        testfile: ["'chromadb/test/stress/'"]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python }}
+      - name: Install test dependencies
+        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+      - name: Upgrade SQLite
+        run: python bin/windows_upgrade_sqlite.py
+        if: runner.os == 'Windows'
+      - name: Test
+        run: python -m pytest ${{ matrix.testfile }}
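Each `testfile` matrix entry is passed verbatim to pytest, so any single CI shard can be reproduced locally. A sketch, using entries from the matrix above:

```bash
# the "everything else" shard
python -m pytest --ignore-glob 'chromadb/test/property/*' \
  --ignore-glob 'chromadb/test/stress/*' \
  --ignore='chromadb/test/auth/test_simple_rbac_authz.py'

# a single property-test shard
python -m pytest chromadb/test/property/test_add.py
```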
.github/workflows/chroma-worker-test.yml
ADDED
@@ -0,0 +1,36 @@
+name: Chroma Worker Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+      - '**'
+  workflow_dispatch:
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        platform: [ubuntu-latest]
+    runs-on: ${{ matrix.platform }}
+    steps:
+      - name: Checkout chroma-hnswlib
+        uses: actions/checkout@v3
+        with:
+          repository: chroma-core/hnswlib
+          path: hnswlib
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          path: chroma
+      - name: Install Protoc
+        uses: arduino/setup-protoc@v2
+      - name: Build
+        run: cargo build --verbose
+        working-directory: chroma
+      - name: Test
+        run: cargo test --verbose
+        working-directory: chroma
.github/workflows/pr-review-checklist.yml
ADDED
@@ -0,0 +1,37 @@
+name: PR Review Checklist
+
+on:
+  pull_request_target:
+    types:
+      - opened
+
+jobs:
+  PR-Comment:
+    runs-on: ubuntu-latest
+    steps:
+      - name: PR Comment
+        uses: actions/github-script@v2
+        with:
+          github-token: ${{secrets.GITHUB_TOKEN}}
+          script: |
+            github.issues.createComment({
+              issue_number: ${{ github.event.number }},
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              body: `# Reviewer Checklist
+            Please leverage this checklist to ensure your code review is thorough before approving
+            ## Testing, Bugs, Errors, Logs, Documentation
+            - [ ] Can you think of any use case in which the code does not behave as intended? Have they been tested?
+            - [ ] Can you think of any inputs or external events that could break the code? Is user input validated and safe? Have they been tested?
+            - [ ] If appropriate, are there adequate property based tests?
+            - [ ] If appropriate, are there adequate unit tests?
+            - [ ] Should any logging, debugging, tracing information be added or removed?
+            - [ ] Are error messages user-friendly?
+            - [ ] Have all documentation changes needed been made?
+            - [ ] Have all non-obvious changes been commented?
+            ## System Compatibility
+            - [ ] Are there any potential impacts on other parts of the system or backward compatibility?
+            - [ ] Does this change intersect with any items on our roadmap, and if so, is there a plan for fitting them together?
+            ## Quality
+            - [ ] Is this code of an unexpectedly high quality (Readability, Modularity, Intuitiveness)`
+            })
.github/workflows/python-vuln.yaml
ADDED
@@ -0,0 +1,28 @@
+name: Python Vulnerability Scan
+on:
+  push:
+    branches:
+      - '*'
+      - '*/**'
+    paths:
+      - chromadb/**
+      - clients/python/**
+  workflow_dispatch:
+jobs:
+  bandit-scan:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - uses: ./.github/actions/bandit-scan/
+        with:
+          input-dir: '.'
+          format: 'json'
+          bandit-config: 'bandit.yaml'
+          output-file: 'bandit-report.json'
+      - name: Upload Bandit Report
+        uses: actions/upload-artifact@v3
+        with:
+          name: bandit-artifact
+          path: |
+            bandit-report.json
.gitignore
ADDED
@@ -0,0 +1,34 @@
+# ignore mac created DS_Store files
+**/.DS_Store
+
+**/__pycache__
+
+go/coordinator/bin/
+go/coordinator/**/testdata/
+
+*.log
+
+**/data__nogit
+
+**/.ipynb_checkpoints
+
+index_data
+
+# Default configuration for persist_directory in chromadb/config.py
+# Currently it's located in "./chroma/"
+chroma/
+chroma_test_data/
+server.htpasswd
+
+.venv
+venv
+.env
+.chroma
+*.egg-info
+dist
+
+.terraform/
+.terraform.lock.hcl
+terraform.tfstate
+.hypothesis/
+.idea
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,36 @@
+exclude: 'chromadb/proto/(chroma_pb2|coordinator_pb2)\.(py|pyi|py_grpc\.py)' # Generated files
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: trailing-whitespace
+      - id: mixed-line-ending
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
+      - id: check-yaml
+        args: ["--allow-multiple-documents"]
+      - id: check-xml
+      - id: check-merge-conflict
+      - id: check-case-conflict
+      - id: check-docstring-first
+
+  - repo: https://github.com/psf/black
+    # https://github.com/psf/black/issues/2493
+    rev: "refs/tags/23.3.0:refs/tags/23.3.0"
+    hooks:
+      - id: black
+
+  - repo: https://github.com/PyCQA/flake8
+    rev: 6.1.0
+    hooks:
+      - id: flake8
+        args:
+          - "--extend-ignore=E203,E501,E503"
+          - "--max-line-length=88"
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: "v1.2.0"
+    hooks:
+      - id: mypy
+        args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract, --config-file=./pyproject.toml]
+        additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf", "kubernetes"]
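DEVELOP.md below wires these hooks up with `pre-commit install`; as a sketch, the same checks can also be run on demand with the standard pre-commit CLI:

```bash
pre-commit install          # run the configured hooks on every commit
pre-commit run --all-files  # run every hook once across the whole repo
```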
.vscode/settings.json
ADDED
@@ -0,0 +1,131 @@
+{
+  "git.ignoreLimitWarning": true,
+  "editor.rulers": [
+    88
+  ],
+  "editor.formatOnSave": true,
+  "python.formatting.provider": "black",
+  "files.exclude": {
+    "**/__pycache__": true,
+    "**/.ipynb_checkpoints": true,
+    "**/.pytest_cache": true,
+    "**/chroma.egg-info": true
+  },
+  "python.analysis.typeCheckingMode": "basic",
+  "python.linting.flake8Enabled": true,
+  "python.linting.enabled": true,
+  "python.linting.flake8Args": [
+    "--extend-ignore=E203",
+    "--extend-ignore=E501",
+    "--extend-ignore=E503",
+    "--max-line-length=88"
+  ],
+  "python.testing.pytestArgs": [
+    "."
+  ],
+  "python.testing.unittestEnabled": false,
+  "python.testing.pytestEnabled": true,
+  "editor.formatOnPaste": true,
+  "python.linting.mypyEnabled": true,
+  "python.linting.mypyCategorySeverity.note": "Error",
+  "python.linting.mypyArgs": [
+    "--follow-imports=silent",
+    "--ignore-missing-imports",
+    "--show-column-numbers",
+    "--no-pretty",
+    "--strict",
+    "--disable-error-code=type-abstract"
+  ],
+  "protoc": {
+    "options": [
+      "--proto_path=idl/",
+    ]
+  },
+  "rust-analyzer.cargo.buildScripts.enable": true,
+  "files.associations": {
+    "fstream": "cpp",
+    "iosfwd": "cpp",
+    "__hash_table": "cpp",
+    "__locale": "cpp",
+    "atomic": "cpp",
+    "deque": "cpp",
+    "filesystem": "cpp",
+    "future": "cpp",
+    "locale": "cpp",
+    "random": "cpp",
+    "regex": "cpp",
+    "string": "cpp",
+    "tuple": "cpp",
+    "type_traits": "cpp",
+    "unordered_map": "cpp",
+    "valarray": "cpp",
+    "variant": "cpp",
+    "vector": "cpp",
+    "__string": "cpp",
+    "istream": "cpp",
+    "memory": "cpp",
+    "optional": "cpp",
+    "string_view": "cpp",
+    "__bit_reference": "cpp",
+    "__bits": "cpp",
+    "__config": "cpp",
+    "__debug": "cpp",
+    "__errc": "cpp",
+    "__mutex_base": "cpp",
+    "__node_handle": "cpp",
+    "__nullptr": "cpp",
+    "__split_buffer": "cpp",
+    "__threading_support": "cpp",
+    "__tree": "cpp",
+    "__tuple": "cpp",
+    "array": "cpp",
+    "bit": "cpp",
+    "bitset": "cpp",
+    "cctype": "cpp",
+    "charconv": "cpp",
+    "chrono": "cpp",
+    "cinttypes": "cpp",
+    "clocale": "cpp",
+    "cmath": "cpp",
+    "compare": "cpp",
+    "complex": "cpp",
+    "concepts": "cpp",
+    "condition_variable": "cpp",
+    "csignal": "cpp",
+    "cstdarg": "cpp",
+    "cstddef": "cpp",
+    "cstdint": "cpp",
+    "cstdio": "cpp",
+    "cstdlib": "cpp",
+    "cstring": "cpp",
+    "ctime": "cpp",
+    "cwchar": "cpp",
+    "cwctype": "cpp",
+    "exception": "cpp",
+    "format": "cpp",
+    "forward_list": "cpp",
+    "initializer_list": "cpp",
+    "iomanip": "cpp",
+    "ios": "cpp",
+    "iostream": "cpp",
+    "limits": "cpp",
+    "list": "cpp",
+    "map": "cpp",
+    "mutex": "cpp",
+    "new": "cpp",
+    "numeric": "cpp",
+    "ostream": "cpp",
+    "queue": "cpp",
+    "ratio": "cpp",
+    "set": "cpp",
+    "sstream": "cpp",
+    "stack": "cpp",
+    "stdexcept": "cpp",
+    "streambuf": "cpp",
+    "system_error": "cpp",
+    "typeindex": "cpp",
+    "typeinfo": "cpp",
+    "unordered_set": "cpp",
+    "algorithm": "cpp"
+  },
+}
Cargo.lock
ADDED
The diff for this file is too large to render.
Cargo.toml
ADDED
@@ -0,0 +1,5 @@
+[workspace]
+
+members = [
+    "rust/worker/"
+]
DEVELOP.md
ADDED
@@ -0,0 +1,111 @@
+# Development Instructions
+
+This project uses the testing, build and release standards specified
+by the PyPA organization and documented at
+https://packaging.python.org.
+
+## Setup
+
+Because of the dependencies it relies on (like `pytorch`), this project does not support Python version >3.10.0.
+
+Set up a virtual environment and install the project's requirements
+and dev requirements:
+
+```
+python3 -m venv venv             # Only need to do this once
+source venv/bin/activate         # Do this each time you use a new shell for the project
+pip install -r requirements.txt
+pip install -r requirements_dev.txt
+pre-commit install               # install the precommit hooks
+```
+
+You can also install the `chromadb` pypi package locally and in editable mode with `pip install -e .`.
+
+## Running Chroma
+
+Chroma can be run in 3 modes:
+1. Standalone and in-memory:
+```python
+import chromadb
+api = chromadb.Client()
+print(api.heartbeat())
+```
+
+2. Standalone and in-memory with persistence:
+
+This by default saves your db and your indexes to a `.chroma` directory and can also load from them.
+```python
+import chromadb
+api = chromadb.PersistentClient(path="/path/to/persist/directory")
+print(api.heartbeat())
+```
+
+3. With a persistent backend and a small frontend client:
+
+Run `chroma run --path /chroma_db_path`
+```python
+import chromadb
+api = chromadb.HttpClient(host="localhost", port="8000")
+print(api.heartbeat())
+```
+
+## Local dev setup for distributed chroma
+
+We use Tilt, an open source tool for local Kubernetes development, to provide the local dev setup.
+
+##### Requirements
+- Docker
+- Local Kubernetes cluster (Recommended: [OrbStack](https://orbstack.dev/) for mac, [Kind](https://kind.sigs.k8s.io/) for linux)
+- [Tilt](https://docs.tilt.dev/)
+
+To start the distributed Chroma in the workspace, use `tilt up`. It will create all the required resources and build the necessary Docker images in the current kubectl context.
+Once done, it will expose Chroma on port 8000. You can also visit the Tilt dashboard UI at http://localhost:10350/. To clean up and remove all the resources created by Tilt, use `tilt down`.
+
+## Testing
+
+Unit tests are in the `/chromadb/test` directory.
+
+To run unit tests using your current environment, run `pytest`.
+
+## Manual Build
+
+To manually build a distribution, run `python -m build`.
+
+The project's source and wheel distributions will be placed in the `dist` directory.
+
+## Manual Release
+
+Not yet implemented.
+
+## Versioning
+
+This project uses PyPA's `setuptools_scm` module to determine the
+version number for build artifacts, meaning the version number is
+derived from Git rather than hardcoded in the repository. For full
+details, see the
+[documentation for setuptools_scm](https://github.com/pypa/setuptools_scm/).
+
+In brief, version numbers are generated as follows:
+
+- If the current git head is tagged, the version number is exactly the
+  tag (e.g., `0.0.1`).
+- If the current git head is a clean checkout, but is not tagged,
+  the version number is a patch version increment of the most recent
+  tag, plus `devN` where N is the number of commits since the most
+  recent tag. For example, if there have been 5 commits since the
+  `0.0.1` tag, the generated version will be `0.0.2-dev5`.
+- If the current head is not a clean checkout, a `+dirty` local
+  version will be appended to the version number. For example,
+  `0.0.2-dev5+dirty`.
+
+At any point, you can manually run `python -m setuptools_scm` to see
+what version would be assigned given your current state.
+
+## Continuous Integration
+
+This project uses Github Actions to run unit tests automatically upon
+every commit to the main branch. See the documentation for Github
+Actions and the flow definitions in `.github/workflows` for details.
+
+## Continuous Delivery
+
+Not yet implemented.
Dockerfile
ADDED
@@ -0,0 +1,39 @@
+FROM python:3.11-slim-bookworm AS builder
+ARG REBUILD_HNSWLIB
+RUN apt-get update --fix-missing && apt-get install -y --fix-missing \
+    build-essential \
+    gcc \
+    g++ \
+    cmake \
+    autoconf && \
+    rm -rf /var/lib/apt/lists/* && \
+    mkdir /install
+
+WORKDIR /install
+
+COPY ./requirements.txt requirements.txt
+
+RUN pip install --no-cache-dir --upgrade --prefix="/install" -r requirements.txt
+RUN if [ "$REBUILD_HNSWLIB" = "true" ]; then pip install --no-binary :all: --force-reinstall --no-cache-dir --prefix="/install" chroma-hnswlib; fi
+
+FROM python:3.11-slim-bookworm AS final
+
+RUN mkdir /chroma
+WORKDIR /chroma
+
+COPY --from=builder /install /usr/local
+COPY ./bin/docker_entrypoint.sh /docker_entrypoint.sh
+COPY ./ /chroma
+
+RUN chmod +x /docker_entrypoint.sh
+
+ENV CHROMA_HOST_ADDR "0.0.0.0"
+ENV CHROMA_HOST_PORT 7860
+ENV CHROMA_WORKERS 1
+ENV CHROMA_LOG_CONFIG "chromadb/log_config.yml"
+ENV CHROMA_TIMEOUT_KEEP_ALIVE 30
+
+EXPOSE 7860
+
+ENTRYPOINT ["/docker_entrypoint.sh"]
+CMD [ "--workers ${CHROMA_WORKERS} --host ${CHROMA_HOST_ADDR} --port ${CHROMA_HOST_PORT} --proxy-headers --log-config ${CHROMA_LOG_CONFIG} --timeout-keep-alive ${CHROMA_TIMEOUT_KEEP_ALIVE}"]
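A minimal sketch of building and running this image locally; the tag name is illustrative, while port 7860 and the `REBUILD_HNSWLIB` build arg come from the Dockerfile above:

```bash
docker build -t chroma:local .
# optionally force chroma-hnswlib to be rebuilt from source
docker build --build-arg REBUILD_HNSWLIB=true -t chroma:local .
# map the container's port 7860 to localhost:8000
docker run -p 8000:7860 chroma:local
```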
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,11 +1,106 @@
<p align="center">
  <a href="https://trychroma.com"><img src="https://user-images.githubusercontent.com/891664/227103090-6624bf7d-9524-4e05-9d2c-c28d5d451481.png" alt="Chroma logo"></a>
</p>

<p align="center">
    <b>Chroma - the open-source embedding database</b>. <br />
    The fastest way to build Python or JavaScript LLM apps with memory!
</p>

<p align="center">
  <a href="https://discord.gg/MMeYNTmh3x" target="_blank">
      <img src="https://img.shields.io/discord/1073293645303795742" alt="Discord">
  </a> |
  <a href="https://github.com/chroma-core/chroma/blob/master/LICENSE" target="_blank">
      <img src="https://img.shields.io/static/v1?label=license&message=Apache 2.0&color=white" alt="License">
  </a> |
  <a href="https://docs.trychroma.com/" target="_blank">
      Docs
  </a> |
  <a href="https://www.trychroma.com/" target="_blank">
      Homepage
  </a>
</p>

<p align="center">
  <a href="https://github.com/chroma-core/chroma/actions/workflows/chroma-integration-test.yml" target="_blank">
      <img src="https://github.com/chroma-core/chroma/actions/workflows/chroma-integration-test.yml/badge.svg?branch=main" alt="Integration Tests">
  </a> |
  <a href="https://github.com/chroma-core/chroma/actions/workflows/chroma-test.yml" target="_blank">
      <img src="https://github.com/chroma-core/chroma/actions/workflows/chroma-test.yml/badge.svg?branch=main" alt="Tests">
  </a>
</p>

```bash
pip install chromadb # python client
# for javascript, npm install chromadb!
# for client-server mode, chroma run --path /chroma_db_path
```
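For the client-server mode mentioned above, the Python side connects with `chromadb.HttpClient`, defined in `chromadb/__init__.py` in this commit; a minimal sketch, assuming a server is already listening on the defaults:

```python
import chromadb

# Connects to a server started with e.g. `chroma run --path /chroma_db_path`.
# host/port below are the defaults; adjust to your deployment.
client = chromadb.HttpClient(host="localhost", port="8000")
print(client.heartbeat())  # nanoseconds since epoch when the server is reachable
```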

The core API is only 4 functions (run our [💡 Google Colab](https://colab.research.google.com/drive/1QEzFyqnoFxq7LUGyP1vzR4iLt9PpCDXv?usp=sharing) or [Replit template](https://replit.com/@swyx/BasicChromaStarter?v=1)):

```python
import chromadb
# setup Chroma in-memory, for easy prototyping. Can add persistence easily!
client = chromadb.Client()

# Create collection. get_collection, get_or_create_collection, delete_collection also available!
collection = client.create_collection("all-my-documents")

# Add docs to the collection. Can also update and delete. Row-based API coming soon!
collection.add(
    documents=["This is document1", "This is document2"], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
    metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on these!
    ids=["doc1", "doc2"], # unique for each doc
)

# Query/search 2 most similar results. You can also .get by id
results = collection.query(
    query_texts=["This is a query document"],
    n_results=2,
    # where={"metadata_field": "is_equal_to_this"}, # optional filter
    # where_document={"$contains":"search_string"} # optional filter
)
```
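The in-memory client above is ephemeral; for the persistence the first comment mentions, a minimal sketch swapping in the `PersistentClient` factory defined in `chromadb/__init__.py`:

```python
import chromadb

# Identical collection API, but data survives restarts; "./chroma" is the default path.
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_or_create_collection("all-my-documents")
```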

## Features
- __Simple__: Fully-typed, fully-tested, fully-documented == happiness
- __Integrations__: [`🦜️🔗 LangChain`](https://blog.langchain.dev/langchain-chroma/) (python and js), [`🦙 LlamaIndex`](https://twitter.com/atroyn/status/1628557389762007040) and more soon
- __Dev, Test, Prod__: the same API that runs in your python notebook, scales to your cluster
- __Feature-rich__: Queries, filtering, density estimation and more
- __Free & Open Source__: Apache 2.0 Licensed

## Use case: ChatGPT for ______

For example, the `"Chat your data"` use case:
1. Add documents to your database. You can pass in your own embeddings, embedding function, or let Chroma embed them for you.
2. Query relevant documents with natural language.
3. Compose documents into the context window of an LLM like `GPT3` for additional summarization or analysis.

## Embeddings?

What are embeddings?

- [Read the guide from OpenAI](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
- __Literal__: Embedding something turns it from image/text/audio into a list of numbers. 🖼️ or 📄 => `[1.2, 2.1, ....]`. This process makes documents "understandable" to a machine learning model.
- __By analogy__: An embedding represents the essence of a document. This enables documents and queries with the same essence to be "near" each other and therefore easy to find.
- __Technical__: An embedding is the latent-space position of a document at a layer of a deep neural network. For models trained specifically to embed data, this is the last layer.
- __A small example__: If you search your photos for "famous bridge in San Francisco", embedding the query and comparing it to the embeddings of your photos and their metadata should return photos of the Golden Gate Bridge.

Embeddings databases (also known as **vector databases**) store embeddings and allow you to search by nearest neighbors rather than by substrings like a traditional database. By default, Chroma uses [Sentence Transformers](https://docs.trychroma.com/embeddings#sentence-transformers) to embed for you but you can also use OpenAI embeddings, Cohere (multilingual) embeddings, or your own.

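A toy illustration of the nearest-neighbor idea behind that paragraph (plain numpy; the vectors are made up, and real embeddings have hundreds of dimensions):

```python
import numpy as np

# Toy 4-dimensional "embeddings" for three documents (values are invented).
docs = np.array([
    [0.90, 0.10, 0.00, 0.20],  # doc1
    [0.10, 0.80, 0.30, 0.00],  # doc2
    [0.85, 0.15, 0.10, 0.10],  # doc3, semantically close to doc1
])
query = np.array([0.88, 0.12, 0.05, 0.15])

# Cosine similarity: higher means "nearer" in embedding space.
sims = docs @ query / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
print(sims.argsort()[::-1])  # document indices ranked most- to least-similar
```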
## Get involved

Chroma is a rapidly developing project. We welcome PR contributors and ideas for how to improve the project.
- [Join the conversation on Discord](https://discord.gg/MMeYNTmh3x) - `#contributing` channel
- [Review the 🛣️ Roadmap and contribute your ideas](https://docs.trychroma.com/roadmap)
- [Grab an issue and open a PR](https://github.com/chroma-core/chroma/issues) - [`Good first issue tag`](https://github.com/chroma-core/chroma/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
- [Read our contributing guide](https://docs.trychroma.com/contributing)

**Release Cadence**
We currently release new tagged versions of the `pypi` and `npm` packages on Mondays. Hotfixes go out at any time during the week.

## License

[Apache 2.0](./LICENSE)
RELEASE_PROCESS.md
ADDED
@@ -0,0 +1,22 @@
## Release Process

This guide covers how to release chroma to PyPI.

#### Increase the version number
1. Create a new PR for the release that upgrades the version in code. Name it `release/A.B.C`. In [this file](https://github.com/chroma-core/chroma/blob/main/chromadb/__init__.py), update `__version__`.
```
__version__ = "A.B.C"
```
2. Add the "release" label to this PR
3. Once the PR is merged, tag your commit SHA with the release version
```
git tag A.B.C <SHA>
```
4. Wait for the `chroma release` and `chroma client release` GitHub Actions on main to go green before proceeding. Skipping this wait can result in a race condition.

#### Perform the release
1. Push your tag to origin to create the release
```
git push origin A.B.C
```
2. This will trigger a GitHub Action which performs the release
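Once the release action completes, a quick sanity check of the published artifact (a sketch; run in a fresh virtualenv after `pip install --upgrade chromadb`, with `A.B.C` being the placeholder version from step 1):

```python
# Post-release smoke test: the freshly installed package should report
# the version that was just tagged.
import chromadb

print(chromadb.__version__)  # expect "A.B.C"
```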
Tiltfile
ADDED
@@ -0,0 +1,30 @@
docker_build('coordinator',
             context='.',
             dockerfile='./go/coordinator/Dockerfile'
)

docker_build('server',
             context='.',
             dockerfile='./Dockerfile',
)

docker_build('worker',
             context='.',
             dockerfile='./rust/worker/Dockerfile'
)


k8s_yaml(['k8s/dev/setup.yaml'])
k8s_resource(
  objects=['chroma:Namespace', 'memberlist-reader:ClusterRole', 'memberlist-reader:ClusterRoleBinding', 'pod-list-role:Role', 'pod-list-role-binding:RoleBinding', 'memberlists.chroma.cluster:CustomResourceDefinition','worker-memberlist:MemberList'],
  new_name='k8s_setup',
  labels=["infrastructure"]
)
k8s_yaml(['k8s/dev/pulsar.yaml'])
k8s_resource('pulsar', resource_deps=['k8s_setup'], labels=["infrastructure"])
k8s_yaml(['k8s/dev/server.yaml'])
k8s_resource('server', resource_deps=['k8s_setup'], labels=["chroma"], port_forwards=8000)
k8s_yaml(['k8s/dev/coordinator.yaml'])
k8s_resource('coordinator', resource_deps=['pulsar', 'server'], labels=["chroma"])
k8s_yaml(['k8s/dev/worker.yaml'])
k8s_resource('worker', resource_deps=['coordinator'], labels=["chroma"])
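Usage note (an assumption, since no docs accompany the Tiltfile): with Tilt installed and a local Kubernetes cluster selected, `tilt up` from the repo root builds the `coordinator`, `server`, and `worker` images above and applies the k8s/dev manifests in the declared dependency order; `tilt down` tears them back out.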
bandit.yaml
ADDED
@@ -0,0 +1,4 @@
# FILE: bandit.yaml
exclude_dirs: [ 'chromadb/test', 'bin', 'build', '.git', '.venv', 'venv', 'env', '.github', 'examples', 'clients/js', '.vscode' ]
tests: [ ]
skips: [ ]
bin/cluster-test.sh
ADDED
@@ -0,0 +1,62 @@
#!/usr/bin/env bash

set -e

function cleanup {
  # Restore the previous kube context
  kubectl config use-context $PREV_CHROMA_KUBE_CONTEXT
  # Kill the tunnel process
  kill $TUNNEL_PID
  minikube delete -p chroma-test
}

trap cleanup EXIT

# Save the current kube context into a variable
export PREV_CHROMA_KUBE_CONTEXT=$(kubectl config current-context)

# Create a new minikube cluster for the test
minikube start -p chroma-test

# Add the ingress addons to the cluster
minikube addons enable ingress -p chroma-test
minikube addons enable ingress-dns -p chroma-test

# Setup docker to build inside the minikube cluster and build the images
eval $(minikube -p chroma-test docker-env)
docker build -t server:latest -f Dockerfile .
docker build -t chroma-coordinator:latest -f go/coordinator/Dockerfile .
docker build -t worker -f rust/worker/Dockerfile . --build-arg CHROMA_KUBERNETES_INTEGRATION=1

# Apply the kubernetes manifests
kubectl apply -f k8s/deployment
kubectl apply -f k8s/crd
kubectl apply -f k8s/cr
kubectl apply -f k8s/test

# Wait for the pods in the chroma namespace to be ready
kubectl wait --namespace chroma --for=condition=Ready pods --all --timeout=400s

# Run minikube tunnel in the background to expose the services
minikube tunnel -c true -p chroma-test &
TUNNEL_PID=$!

# Wait for the tunnel to be ready. There isn't an easy way to check if the
# tunnel is ready, so we just wait for 10 seconds.
sleep 10

export CHROMA_CLUSTER_TEST_ONLY=1
export CHROMA_SERVER_HOST=$(kubectl get svc server -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}')
export PULSAR_BROKER_URL=$(kubectl get svc pulsar-lb -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}')
export CHROMA_COORDINATOR_HOST=$(kubectl get svc coordinator-lb -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}')
export CHROMA_SERVER_GRPC_PORT="50051"

echo "Chroma Server is running at $CHROMA_SERVER_HOST"
echo "Pulsar Broker is running at $PULSAR_BROKER_URL"
echo "Chroma Coordinator is running at $CHROMA_COORDINATOR_HOST"

echo testing: python -m pytest "$@"
python -m pytest "$@"

export CHROMA_KUBERNETES_INTEGRATION=1
cd go/coordinator
go test -timeout 30s -run ^TestNodeWatcher$ github.com/chroma/chroma-coordinator/internal/memberlist_manager
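Arguments are passed straight through to `python -m pytest`, so an invocation such as `bin/cluster-test.sh chromadb/test -k cluster` (the test selection here is illustrative) runs only the chosen tests; the `trap cleanup EXIT` at the top deletes the `chroma-test` minikube profile even if the tests fail.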
bin/docker_entrypoint.sh
ADDED
@@ -0,0 +1,15 @@
#!/bin/bash
set -e

export IS_PERSISTENT=1
export CHROMA_SERVER_NOFILE=65535
args="$@"

if [[ $args =~ ^uvicorn.* ]]; then
  echo "Starting server with args: $(eval echo "$args")"
  echo -e "\033[31mWARNING: Please remove 'uvicorn chromadb.app:app' from your command line arguments. This is now handled by the entrypoint script.\033[0m"
  exec $(eval echo "$args")
else
  echo "Starting 'uvicorn chromadb.app:app' with args: $(eval echo "$args")"
  exec uvicorn chromadb.app:app $(eval echo "$args")
fi
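The effect is that container arguments become uvicorn arguments: for example, `docker run ghcr.io/chroma-core/chroma --workers 2 --host 0.0.0.0 --port 8000` (image name as used in bin/templates/docker-compose.yml; the flags are standard uvicorn options) ends up executing `uvicorn chromadb.app:app --workers 2 --host 0.0.0.0 --port 8000`.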
bin/generate_cloudformation.py
ADDED
@@ -0,0 +1,198 @@
import boto3
import json
import subprocess
import os
import re


def b64text(txt):
    """Generate Base 64 encoded CF json for a multiline string, subbing in values where appropriate"""
    lines = []
    for line in txt.splitlines(True):
        if "${" in line:
            lines.append({"Fn::Sub": line})
        else:
            lines.append(line)
    return {"Fn::Base64": {"Fn::Join": ["", lines]}}


path = os.path.dirname(os.path.realpath(__file__))
version = subprocess.check_output(f"{path}/version").decode("ascii").strip()

with open(f"{path}/templates/docker-compose.yml") as f:
    docker_compose_file = str(f.read())


cloud_config_script = """
#cloud-config
cloud_final_modules:
- [scripts-user, always]
"""

cloud_init_script = f"""
#!/bin/bash
amazon-linux-extras install docker
usermod -a -G docker ec2-user
curl -L https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m) -o /usr/local/bin/docker-compose
chmod +x /usr/local/bin/docker-compose
ln -s /usr/local/bin/docker-compose /usr/bin/docker-compose
systemctl enable docker
systemctl start docker

cat << EOF > /home/ec2-user/docker-compose.yml
{docker_compose_file}
EOF

mkdir /home/ec2-user/config

docker-compose -f /home/ec2-user/docker-compose.yml up -d
"""

userdata = f"""Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0

--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"

{cloud_config_script}

--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"

{cloud_init_script}
--//--
"""

cf = {
    "AWSTemplateFormatVersion": "2010-09-09",
    "Description": "Create a stack that runs Chroma hosted on a single instance",
    "Parameters": {
        "KeyName": {
            "Description": "Name of an existing EC2 KeyPair to enable SSH access to the instance",
            "Type": "String",
            "ConstraintDescription": "If present, must be the name of an existing EC2 KeyPair.",
            "Default": "",
        },
        "InstanceType": {
            "Description": "EC2 instance type",
            "Type": "String",
            "Default": "t3.small",
        },
        "ChromaVersion": {
            "Description": "Chroma version to install",
            "Type": "String",
            "Default": version,
        },
    },
    "Conditions": {
        "HasKeyName": {"Fn::Not": [{"Fn::Equals": [{"Ref": "KeyName"}, ""]}]},
    },
    "Resources": {
        "ChromaInstance": {
            "Type": "AWS::EC2::Instance",
            "Properties": {
                "ImageId": {
                    "Fn::FindInMap": ["Region2AMI", {"Ref": "AWS::Region"}, "AMI"]
                },
                "InstanceType": {"Ref": "InstanceType"},
                "UserData": b64text(userdata),
                "SecurityGroupIds": [{"Ref": "ChromaInstanceSecurityGroup"}],
                "KeyName": {
                    "Fn::If": [
                        "HasKeyName",
                        {"Ref": "KeyName"},
                        {"Ref": "AWS::NoValue"},
                    ]
                },
                "BlockDeviceMappings": [
                    {
                        "DeviceName": {
                            "Fn::FindInMap": [
                                "Region2AMI",
                                {"Ref": "AWS::Region"},
                                "RootDeviceName",
                            ]
                        },
                        "Ebs": {"VolumeSize": 24},
                    }
                ],
            },
        },
        "ChromaInstanceSecurityGroup": {
            "Type": "AWS::EC2::SecurityGroup",
            "Properties": {
                "GroupDescription": "Chroma Instance Security Group",
                "SecurityGroupIngress": [
                    {
                        "IpProtocol": "tcp",
                        "FromPort": "22",
                        "ToPort": "22",
                        "CidrIp": "0.0.0.0/0",
                    },
                    {
                        "IpProtocol": "tcp",
                        "FromPort": "8000",
                        "ToPort": "8000",
                        "CidrIp": "0.0.0.0/0",
                    },
                ],
            },
        },
    },
    "Outputs": {
        "ServerIp": {
            "Description": "IP address of the Chroma server",
            "Value": {"Fn::GetAtt": ["ChromaInstance", "PublicIp"]},
        }
    },
    "Mappings": {"Region2AMI": {}},
}

# Populate the Region2AMI mappings
regions = boto3.client("ec2", region_name="us-east-1").describe_regions()["Regions"]
for region in regions:
    region_name = region["RegionName"]
    ami_result = boto3.client("ec2", region_name=region_name).describe_images(
        Owners=["137112412989"],
        Filters=[
            {"Name": "name", "Values": ["amzn2-ami-kernel-5.10-hvm-*-x86_64-gp2"]},
            {"Name": "root-device-type", "Values": ["ebs"]},
            {"Name": "virtualization-type", "Values": ["hvm"]},
        ],
    )
    img = ami_result["Images"][0]
    ami_id = img["ImageId"]
    root_device_name = img["BlockDeviceMappings"][0]["DeviceName"]
    cf["Mappings"]["Region2AMI"][region_name] = {
        "AMI": ami_id,
        "RootDeviceName": root_device_name,
    }


# Write the CF json to a file
with open("/tmp/chroma.cf.json", "w") as out:
    json.dump(cf, out, indent=4)

# upload to S3
s3 = boto3.client("s3", region_name="us-east-1")
s3.upload_file(
    "/tmp/chroma.cf.json",
    "public.trychroma.com",
    f"cloudformation/{version}/chroma.cf.json",
)

# Upload to s3 under /latest version only if this is a release
pattern = re.compile(r"^\d+\.\d+\.\d+$")
if pattern.match(version):
    s3.upload_file(
        "/tmp/chroma.cf.json",
        "public.trychroma.com",
        "cloudformation/latest/chroma.cf.json",
    )
else:
    print(f"Version {version} is not a 3-part semver, not uploading to /latest")
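For intuition about `b64text` above: lines containing `${...}` are wrapped in `Fn::Sub` so CloudFormation substitutes parameters such as `${ChromaVersion}` into the user data, while plain lines pass through untouched. A self-contained sketch (the function body mirrors the script; the sample input echoes the docker-compose template):

```python
# Standalone copy of b64text from bin/generate_cloudformation.py, for illustration.
def b64text(txt):
    lines = [{"Fn::Sub": line} if "${" in line else line for line in txt.splitlines(True)]
    return {"Fn::Base64": {"Fn::Join": ["", lines]}}

print(b64text("image: ghcr.io/chroma-core/chroma:${ChromaVersion}\nports:\n"))
# {'Fn::Base64': {'Fn::Join': ['', [{'Fn::Sub': 'image: ghcr.io/chroma-core/chroma:${ChromaVersion}\n'}, 'ports:\n']]}}
```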
bin/integration-test
ADDED
@@ -0,0 +1,75 @@
#!/usr/bin/env bash

set -e

export CHROMA_PORT=8000

function cleanup {
  docker compose -f docker-compose.test.yml down --rmi local --volumes
  rm server.htpasswd .chroma_env
}

function setup_auth {
  local auth_type="$1"
  case "$auth_type" in
  basic)
    docker run --rm --entrypoint htpasswd httpd:2 -Bbn admin admin > server.htpasswd
    cat <<EOF > .chroma_env
CHROMA_SERVER_AUTH_CREDENTIALS_FILE="/chroma/server.htpasswd"
CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider"
CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.basic.BasicAuthServerProvider"
EOF
    ;;
  token)
    cat <<EOF > .chroma_env
CHROMA_SERVER_AUTH_CREDENTIALS="test-token"
CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER="AUTHORIZATION"
CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider"
EOF
    ;;
  xtoken)
    cat <<EOF > .chroma_env
CHROMA_SERVER_AUTH_CREDENTIALS="test-token"
CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER="X_CHROMA_TOKEN"
CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider"
EOF
    ;;
  *)
    echo "Unknown auth type: $auth_type"
    exit 1
    ;;
  esac
}

trap cleanup EXIT

docker compose -f docker-compose.test.yml up --build -d

export CHROMA_INTEGRATION_TEST_ONLY=1
export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI
export CHROMA_SERVER_HOST=localhost
export CHROMA_SERVER_HTTP_PORT=8000
export CHROMA_SERVER_NOFILE=65535

echo testing: python -m pytest "$@"
python -m pytest "$@"

cd clients/js

# moved off of yarn to npm to fix issues with jackspeak/cliui/string-width versions #1314
npm install
npm run test:run

docker compose down
cd ../..
for auth_type in basic token xtoken; do
  echo "Testing $auth_type auth"
  setup_auth "$auth_type"
  cd clients/js
  docker compose --env-file ../../.chroma_env -f ../../docker-compose.test-auth.yml up --build -d
  yarn test:run-auth-"$auth_type"
  cd ../..
  docker compose down
done
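The `token` configuration above makes the server expect `test-token` in the Authorization header; a minimal sketch of a matching Python client (field names mirror what CloudClient sets in chromadb/__init__.py elsewhere in this commit; treat the exact recipe as an assumption, not a verified one):

```python
import chromadb
from chromadb.config import Settings

# Client-side counterpart of the "token" server config generated above.
settings = Settings()
settings.chroma_client_auth_provider = "chromadb.auth.token.TokenAuthClientProvider"
settings.chroma_client_auth_credentials = "test-token"

client = chromadb.HttpClient(host="localhost", port="8000", settings=settings)
print(client.heartbeat())  # should succeed only with the matching token
```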
bin/reset.sh
ADDED
@@ -0,0 +1,13 @@
#!/usr/bin/env bash

eval $(minikube -p chroma-test docker-env)

docker build -t chroma-coordinator:latest -f go/coordinator/Dockerfile .

kubectl delete deployment coordinator -n chroma

# Apply the kubernetes manifests
kubectl apply -f k8s/deployment
kubectl apply -f k8s/crd
kubectl apply -f k8s/cr
kubectl apply -f k8s/test
bin/templates/docker-compose.yml
ADDED
@@ -0,0 +1,21 @@
version: '3.9'

networks:
  net:
    driver: bridge

services:
  server:
    image: ghcr.io/chroma-core/chroma:${ChromaVersion}
    volumes:
      - index_data:/index_data
    ports:
      - 8000:8000
    networks:
      - net

volumes:
  index_data:
    driver: local
  backups:
    driver: local
bin/test-package.sh
ADDED
@@ -0,0 +1,24 @@
#!/bin/bash

# Verify PIP tarball
tarball=$(readlink -f $1)
if [ -f "$tarball" ]; then
  echo "Testing PIP package from tarball: $tarball"
else
  echo "Could not find PIP package: $tarball"
  exit 1
fi

# Create temporary project dir
dir=$(mktemp -d)

echo "Building python project dir at $dir ..."

cd $dir

python3 -m venv venv

source venv/bin/activate

pip install $tarball

python -c "import chromadb; api = chromadb.Client(); print(api.heartbeat())"
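Typical usage (the tarball path is illustrative): `bin/test-package.sh dist/chromadb-0.4.22.tar.gz`, which installs the artifact into a throwaway virtualenv and runs the same heartbeat smoke test as bin/test.py.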
bin/test-remote
ADDED
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

set -e

# Assert first argument is present
if [ -z "$1" ]; then
  echo "Usage: bin/test-remote <remote-host>"
  exit 1
fi

export CHROMA_INTEGRATION_TEST_ONLY=1
export CHROMA_SERVER_HOST=$1
export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI
export CHROMA_SERVER_HTTP_PORT=8000

python -m pytest
bin/test.py
ADDED
@@ -0,0 +1,7 @@
# Sanity check script to ensure that the Chroma client can connect
# and is capable of receiving data.
import chromadb

# run in in-memory mode
chroma_api = chromadb.Client()
print(chroma_api.heartbeat())
bin/version
ADDED
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
export VERSION=$(python -m setuptools_scm)

if [[ -n $(git status --porcelain) ]]; then
  VERSION=$VERSION-dirty
fi

echo $VERSION
bin/windows_upgrade_sqlite.py
ADDED
@@ -0,0 +1,20 @@
import requests
import zipfile
import io
import os
import sys
import shutil

# Used by GitHub Action runners to upgrade sqlite version to 3.42.0
DLL_URL = "https://www.sqlite.org/2023/sqlite-dll-win64-x64-3420000.zip"

if __name__ == "__main__":
    # Download and extract the DLL
    r = requests.get(DLL_URL)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall(".")
    # Locate the DLLs folder of the running Python interpreter
    exec_path = os.path.dirname(sys.executable)
    dlls_path = os.path.join(exec_path, "DLLs")
    # Copy the DLL to the Python DLLs folder
    shutil.copy("sqlite3.dll", dlls_path)
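A quick check, in a new interpreter session on the runner, that the swap took effect:

```python
# The stdlib sqlite3 module should now report the upgraded library version.
import sqlite3

print(sqlite3.sqlite_version)
assert sqlite3.sqlite_version_info >= (3, 42, 0)
```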
chromadb/__init__.py
ADDED
@@ -0,0 +1,257 @@
from typing import Dict, Optional
import logging
from chromadb.api.client import Client as ClientCreator
from chromadb.api.client import AdminClient as AdminClientCreator
from chromadb.auth.token import TokenTransportHeader
import chromadb.config
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from chromadb.api import AdminAPI, ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
    CollectionMetadata,
    Documents,
    EmbeddingFunction,
    Embeddings,
    IDs,
    Include,
    Metadata,
    Where,
    QueryResult,
    GetResult,
    WhereDocument,
    UpdateCollectionMetadata,
)

# Re-export types from chromadb.types
__all__ = [
    "Collection",
    "Metadata",
    "Where",
    "WhereDocument",
    "Documents",
    "IDs",
    "Embeddings",
    "EmbeddingFunction",
    "Include",
    "CollectionMetadata",
    "UpdateCollectionMetadata",
    "QueryResult",
    "GetResult",
]

logger = logging.getLogger(__name__)

__settings = Settings()

__version__ = "0.4.22"

# Workaround to deal with Colab's old sqlite3 version
try:
    import google.colab  # noqa: F401

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

is_client = False
try:
    from chromadb.is_thin_client import is_thin_client

    is_client = is_thin_client
except ImportError:
    is_client = False

if not is_client:
    import sqlite3

    if sqlite3.sqlite_version_info < (3, 35, 0):
        if IN_COLAB:
            # In Colab, hotswap to pysqlite-binary if it's too old
            import subprocess
            import sys

            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "pysqlite3-binary"]
            )
            __import__("pysqlite3")
            sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
        else:
            raise RuntimeError(
                "\033[91mYour system has an unsupported version of sqlite3. Chroma \
                    requires sqlite3 >= 3.35.0.\033[0m\n"
                "\033[94mPlease visit \
                    https://docs.trychroma.com/troubleshooting#sqlite to learn how \
                    to upgrade.\033[0m"
            )


def configure(**kwargs) -> None:  # type: ignore
    """Override Chroma's default settings, environment variables or .env files"""
    global __settings
    __settings = chromadb.config.Settings(**kwargs)


def get_settings() -> Settings:
    return __settings


def EphemeralClient(
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates an in-memory instance of Chroma. This is useful for testing and
    development, but not recommended for production use.

    Args:
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """
    if settings is None:
        settings = Settings()
    settings.is_persistent = False

    return ClientCreator(settings=settings, tenant=tenant, database=database)


def PersistentClient(
    path: str = "./chroma",
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates a persistent instance of Chroma that saves to disk. This is useful for
    testing and development, but not recommended for production use.

    Args:
        path: The directory to save Chroma's data to. Defaults to "./chroma".
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """
    if settings is None:
        settings = Settings()
    settings.persist_directory = path
    settings.is_persistent = True

    return ClientCreator(tenant=tenant, database=database, settings=settings)


def HttpClient(
    host: str = "localhost",
    port: str = "8000",
    ssl: bool = False,
    headers: Optional[Dict[str, str]] = None,
    settings: Optional[Settings] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Creates a client that connects to a remote Chroma server. This supports
    many clients connecting to the same server, and is the recommended way to
    use Chroma in production.

    Args:
        host: The hostname of the Chroma server. Defaults to "localhost".
        port: The port of the Chroma server. Defaults to "8000".
        ssl: Whether to use SSL to connect to the Chroma server. Defaults to False.
        headers: A dictionary of headers to send to the Chroma server. Defaults to {}.
        settings: A dictionary of settings to communicate with the chroma server.
        tenant: The tenant to use for this client. Defaults to the default tenant.
        database: The database to use for this client. Defaults to the default database.
    """

    if settings is None:
        settings = Settings()

    settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
    if settings.chroma_server_host and settings.chroma_server_host != host:
        raise ValueError(
            f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]"
        )
    settings.chroma_server_host = host
    if settings.chroma_server_http_port and settings.chroma_server_http_port != port:
        raise ValueError(
            f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]"
        )
    settings.chroma_server_http_port = port
    settings.chroma_server_ssl_enabled = ssl
    settings.chroma_server_headers = headers

    return ClientCreator(tenant=tenant, database=database, settings=settings)


def CloudClient(
    tenant: str,
    database: str,
    api_key: Optional[str] = None,
    settings: Optional[Settings] = None,
    *,  # Following arguments are keyword-only, intended for testing only.
    cloud_host: str = "api.trychroma.com",
    cloud_port: str = "8000",
    enable_ssl: bool = True,
) -> ClientAPI:
    """
    Creates a client to connect to a tenant and database on the Chroma cloud.

    Args:
        tenant: The tenant to use for this client.
        database: The database to use for this client.
        api_key: The api key to use for this client.
    """

    # If no API key is provided, try to load it from the environment variable
    if api_key is None:
        import os

        api_key = os.environ.get("CHROMA_API_KEY")

    # If the API key is still not provided, prompt the user
    if api_key is None:
        print(
            "\033[93mDon't have an API key?\033[0m Get one at https://app.trychroma.com"
        )
        api_key = input("Please enter your Chroma API key: ")

    if settings is None:
        settings = Settings()

    settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
    settings.chroma_server_host = cloud_host
    settings.chroma_server_http_port = cloud_port
    # Always use SSL for cloud
    settings.chroma_server_ssl_enabled = enable_ssl

    settings.chroma_client_auth_provider = "chromadb.auth.token.TokenAuthClientProvider"
    settings.chroma_client_auth_credentials = api_key
    settings.chroma_client_auth_token_transport_header = (
        TokenTransportHeader.X_CHROMA_TOKEN.name
    )

    return ClientCreator(tenant=tenant, database=database, settings=settings)


def Client(
    settings: Settings = __settings,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> ClientAPI:
    """
    Return a running chroma.API instance

    tenant: The tenant to use for this client. Defaults to the default tenant.
    database: The database to use for this client. Defaults to the default database.
    """

    return ClientCreator(tenant=tenant, database=database, settings=settings)


def AdminClient(settings: Settings = Settings()) -> AdminAPI:
    """
    Creates an admin client that can be used to create tenants and databases.
    """
    return AdminClientCreator(settings=settings)
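Taken together, the factories above cover the main deployment modes in one line each; a minimal sketch (arguments shown are the defaults defined above, except the CloudClient tenant/database placeholders):

```python
import chromadb

ephemeral = chromadb.EphemeralClient()                       # in-memory, for tests
persistent = chromadb.PersistentClient(path="./chroma")      # on-disk, single node
remote = chromadb.HttpClient(host="localhost", port="8000")  # client-server
# Hosted; tenant/database/api_key values below are placeholders:
# cloud = chromadb.CloudClient(tenant="my-tenant", database="my-db", api_key="...")
```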
chromadb/api/__init__.py
ADDED
@@ -0,0 +1,596 @@
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Sequence, Optional
|
3 |
+
from uuid import UUID
|
4 |
+
|
5 |
+
from overrides import override
|
6 |
+
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
|
7 |
+
from chromadb.api.models.Collection import Collection
|
8 |
+
from chromadb.api.types import (
|
9 |
+
CollectionMetadata,
|
10 |
+
Documents,
|
11 |
+
Embeddable,
|
12 |
+
EmbeddingFunction,
|
13 |
+
DataLoader,
|
14 |
+
Embeddings,
|
15 |
+
IDs,
|
16 |
+
Include,
|
17 |
+
Loadable,
|
18 |
+
Metadatas,
|
19 |
+
URIs,
|
20 |
+
Where,
|
21 |
+
QueryResult,
|
22 |
+
GetResult,
|
23 |
+
WhereDocument,
|
24 |
+
)
|
25 |
+
from chromadb.config import Component, Settings
|
26 |
+
from chromadb.types import Database, Tenant
|
27 |
+
import chromadb.utils.embedding_functions as ef
|
28 |
+
|
29 |
+
|
30 |
+
class BaseAPI(ABC):
|
31 |
+
@abstractmethod
|
32 |
+
def heartbeat(self) -> int:
|
33 |
+
"""Get the current time in nanoseconds since epoch.
|
34 |
+
Used to check if the server is alive.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
int: The current time in nanoseconds since epoch
|
38 |
+
|
39 |
+
"""
|
40 |
+
pass
|
41 |
+
|
42 |
+
#
|
43 |
+
# COLLECTION METHODS
|
44 |
+
#
|
45 |
+
|
46 |
+
@abstractmethod
|
47 |
+
def list_collections(
|
48 |
+
self,
|
49 |
+
limit: Optional[int] = None,
|
50 |
+
offset: Optional[int] = None,
|
51 |
+
) -> Sequence[Collection]:
|
52 |
+
"""List all collections.
|
53 |
+
Args:
|
54 |
+
limit: The maximum number of entries to return. Defaults to None.
|
55 |
+
offset: The number of entries to skip before returning. Defaults to None.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
Sequence[Collection]: A list of collections
|
59 |
+
|
60 |
+
Examples:
|
61 |
+
```python
|
62 |
+
client.list_collections()
|
63 |
+
# [collection(name="my_collection", metadata={})]
|
64 |
+
```
|
65 |
+
"""
|
66 |
+
pass
|
67 |
+
|
68 |
+
@abstractmethod
|
69 |
+
def count_collections(self) -> int:
|
70 |
+
"""Count the number of collections.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
int: The number of collections.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
```python
|
77 |
+
client.count_collections()
|
78 |
+
# 1
|
79 |
+
```
|
80 |
+
"""
|
81 |
+
pass
|
82 |
+
|
83 |
+
@abstractmethod
|
84 |
+
def create_collection(
|
85 |
+
self,
|
86 |
+
name: str,
|
87 |
+
metadata: Optional[CollectionMetadata] = None,
|
88 |
+
embedding_function: Optional[
|
89 |
+
EmbeddingFunction[Embeddable]
|
90 |
+
] = ef.DefaultEmbeddingFunction(), # type: ignore
|
91 |
+
data_loader: Optional[DataLoader[Loadable]] = None,
|
92 |
+
get_or_create: bool = False,
|
93 |
+
) -> Collection:
|
94 |
+
"""Create a new collection with the given name and metadata.
|
95 |
+
Args:
|
96 |
+
name: The name of the collection to create.
|
97 |
+
metadata: Optional metadata to associate with the collection.
|
98 |
+
embedding_function: Optional function to use to embed documents.
|
99 |
+
Uses the default embedding function if not provided.
|
100 |
+
get_or_create: If True, return the existing collection if it exists.
|
101 |
+
data_loader: Optional function to use to load records (documents, images, etc.)
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Collection: The newly created collection.
|
105 |
+
|
106 |
+
Raises:
|
107 |
+
ValueError: If the collection already exists and get_or_create is False.
|
108 |
+
ValueError: If the collection name is invalid.
|
109 |
+
|
110 |
+
Examples:
|
111 |
+
```python
|
112 |
+
client.create_collection("my_collection")
|
113 |
+
# collection(name="my_collection", metadata={})
|
114 |
+
|
115 |
+
client.create_collection("my_collection", metadata={"foo": "bar"})
|
116 |
+
# collection(name="my_collection", metadata={"foo": "bar"})
|
117 |
+
```
|
118 |
+
"""
|
119 |
+
pass
|
120 |
+
|
121 |
+
@abstractmethod
|
122 |
+
def get_collection(
|
123 |
+
self,
|
124 |
+
name: str,
|
125 |
+
id: Optional[UUID] = None,
|
126 |
+
embedding_function: Optional[
|
127 |
+
EmbeddingFunction[Embeddable]
|
128 |
+
] = ef.DefaultEmbeddingFunction(), # type: ignore
|
129 |
+
data_loader: Optional[DataLoader[Loadable]] = None,
|
130 |
+
) -> Collection:
|
131 |
+
"""Get a collection with the given name.
|
132 |
+
Args:
|
133 |
+
id: The UUID of the collection to get. Id and Name are simultaneously used for lookup if provided.
|
134 |
+
name: The name of the collection to get
|
135 |
+
embedding_function: Optional function to use to embed documents.
|
136 |
+
Uses the default embedding function if not provided.
|
137 |
+
data_loader: Optional function to use to load records (documents, images, etc.)
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
Collection: The collection
|
141 |
+
|
142 |
+
Raises:
|
143 |
+
ValueError: If the collection does not exist
|
144 |
+
|
145 |
+
Examples:
|
146 |
+
```python
|
147 |
+
client.get_collection("my_collection")
|
148 |
+
# collection(name="my_collection", metadata={})
|
149 |
+
```
|
150 |
+
"""
|
151 |
+
pass
|
152 |
+
|
153 |
+
@abstractmethod
|
154 |
+
def get_or_create_collection(
|
155 |
+
self,
|
156 |
+
name: str,
|
157 |
+
metadata: Optional[CollectionMetadata] = None,
|
158 |
+
embedding_function: Optional[
|
159 |
+
EmbeddingFunction[Embeddable]
|
160 |
+
] = ef.DefaultEmbeddingFunction(), # type: ignore
|
161 |
+
data_loader: Optional[DataLoader[Loadable]] = None,
|
162 |
+
) -> Collection:
|
163 |
+
"""Get or create a collection with the given name and metadata.
|
164 |
+
Args:
|
165 |
+
name: The name of the collection to get or create
|
166 |
+
metadata: Optional metadata to associate with the collection. If
|
167 |
+
the collection alredy exists, the metadata will be updated if
|
168 |
+
provided and not None. If the collection does not exist, the
|
169 |
+
new collection will be created with the provided metadata.
|
170 |
+
embedding_function: Optional function to use to embed documents
|
171 |
+
data_loader: Optional function to use to load records (documents, images, etc.)
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
The collection
|
175 |
+
|
176 |
+
Examples:
|
177 |
+
```python
|
178 |
+
client.get_or_create_collection("my_collection")
|
179 |
+
# collection(name="my_collection", metadata={})
|
180 |
+
```
|
181 |
+
"""
|
182 |
+
pass
|
183 |
+
|
184 |
+
def _modify(
|
185 |
+
self,
|
186 |
+
id: UUID,
|
187 |
+
new_name: Optional[str] = None,
|
188 |
+
new_metadata: Optional[CollectionMetadata] = None,
|
189 |
+
) -> None:
|
190 |
+
"""[Internal] Modify a collection by UUID. Can update the name and/or metadata.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
id: The internal UUID of the collection to modify.
|
194 |
+
new_name: The new name of the collection.
|
195 |
+
If None, the existing name will remain. Defaults to None.
|
196 |
+
new_metadata: The new metadata to associate with the collection.
|
197 |
+
Defaults to None.
|
198 |
+
"""
|
199 |
+
pass
|
200 |
+
|
201 |
+
@abstractmethod
|
202 |
+
def delete_collection(
|
203 |
+
self,
|
204 |
+
name: str,
|
205 |
+
) -> None:
|
206 |
+
"""Delete a collection with the given name.
|
207 |
+
Args:
|
208 |
+
name: The name of the collection to delete.
|
209 |
+
|
210 |
+
Raises:
|
211 |
+
ValueError: If the collection does not exist.
|
212 |
+
|
213 |
+
Examples:
|
214 |
+
```python
|
215 |
+
client.delete_collection("my_collection")
|
216 |
+
```
|
217 |
+
"""
|
218 |
+
pass

    #
    # ITEM METHODS
    #

    @abstractmethod
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Add embeddings to a collection specified by UUID.
        If (some) ids already exist, only the new embeddings will be added.

        Args:
            ids: The ids to associate with the embeddings.
            collection_id: The UUID of the collection to add the embeddings to.
            embeddings: The sequence of embeddings to add.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.

        Returns:
            True if the embeddings were added successfully.
        """
        pass

    @abstractmethod
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Update entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to update the embeddings in.
            ids: The IDs of the entries to update.
            embeddings: The sequence of embeddings to update. Defaults to None.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.

        Returns:
            True if the embeddings were updated successfully.
        """
        pass

    @abstractmethod
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """[Internal] Add or update entries in a collection specified by UUID.
        If an entry with the same id already exists, it will be updated;
        otherwise it will be added.

        Args:
            collection_id: The collection to add the embeddings to.
            ids: The ids to associate with the embeddings.
            embeddings: The sequence of embeddings to add.
            metadatas: The metadata to associate with the embeddings. Defaults to None.
            documents: The documents to associate with the embeddings. Defaults to None.
            uris: URIs of data sources for each embedding. Defaults to None.
        """
        pass

    @abstractmethod
    def _count(self, collection_id: UUID) -> int:
        """[Internal] Returns the number of entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to count the embeddings in.

        Returns:
            int: The number of embeddings in the collection.
        """
        pass

    @abstractmethod
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """[Internal] Returns the first n entries in a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to peek into.
            n: The number of entries to peek. Defaults to 10.

        Returns:
            GetResult: The first n entries in the collection.
        """
        pass

    @abstractmethod
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],
    ) -> GetResult:
        """[Internal] Returns entries from a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to get the entries from.
            ids: The IDs of the entries to get. Defaults to None.
            where: Conditional filtering on metadata. Defaults to {}.
            sort: The column to sort the entries by. Defaults to None.
            limit: The maximum number of entries to return. Defaults to None.
            offset: The number of entries to skip before returning. Defaults to None.
            page: The page number to return. Defaults to None.
            page_size: The number of entries to return per page. Defaults to None.
            where_document: Conditional filtering on documents. Defaults to {}.
            include: The fields to include in the response.
                Defaults to ["embeddings", "metadatas", "documents"].

        Returns:
            GetResult: The entries in the collection that match the query.
        """
        pass

    @abstractmethod
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> IDs:
        """[Internal] Deletes entries from a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to delete the entries from.
            ids: The IDs of the entries to delete. Defaults to None.
            where: Conditional filtering on metadata. Defaults to {}.
            where_document: Conditional filtering on documents. Defaults to {}.

        Returns:
            IDs: The list of IDs of the entries that were deleted.
        """
        pass

    @abstractmethod
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["embeddings", "metadatas", "documents", "distances"],
    ) -> QueryResult:
        """[Internal] Performs a nearest neighbors query on a collection specified by UUID.

        Args:
            collection_id: The UUID of the collection to query.
            query_embeddings: The embeddings to use as the query.
            n_results: The number of results to return. Defaults to 10.
            where: Conditional filtering on metadata. Defaults to {}.
            where_document: Conditional filtering on documents. Defaults to {}.
            include: The fields to include in the response.
                Defaults to ["embeddings", "metadatas", "documents", "distances"].

        Returns:
            QueryResult: The results of the query.
        """
        pass

    @abstractmethod
    def reset(self) -> bool:
        """Resets the database. This will delete all collections and entries.

        Returns:
            bool: True if the database was reset successfully.
        """
        pass

    @abstractmethod
    def get_version(self) -> str:
        """Get the version of Chroma.

        Returns:
            str: The version of Chroma.
        """
        pass

    @abstractmethod
    def get_settings(self) -> Settings:
        """Get the settings used to initialize.

        Returns:
            Settings: The settings used to initialize.
        """
        pass

    @property
    @abstractmethod
    def max_batch_size(self) -> int:
        """Return the maximum number of records that can be submitted in a single call
        to submit_embeddings."""
        pass
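`max_batch_size` exists so callers can split writes that exceed what one request may carry. A sketch of client-side chunking; `chunked_add` is a hypothetical helper, not part of this commit:

```python
from uuid import UUID

def chunked_add(api, collection_id: UUID, ids, embeddings) -> None:
    # Hypothetical helper: slice the batch so each _add call stays within
    # the backend's advertised max_batch_size.
    step = api.max_batch_size
    for i in range(0, len(ids), step):
        api._add(
            ids=ids[i : i + step],
            collection_id=collection_id,
            embeddings=embeddings[i : i + step],
        )
```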


class ClientAPI(BaseAPI, ABC):
    tenant: str
    database: str

    @abstractmethod
    def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        """Set the tenant and database for the client. Raises an error if the tenant or
        database does not exist.

        Args:
            tenant: The tenant to set.
            database: The database to set.
        """
        pass

    @abstractmethod
    def set_database(self, database: str) -> None:
        """Set the database for the client. Raises an error if the database does not exist.

        Args:
            database: The database to set.
        """
        pass

    @staticmethod
    @abstractmethod
    def clear_system_cache() -> None:
        """Clear the system cache so that new systems can be created for an existing path.
        This should only be used for testing purposes."""
        pass
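Because clients share cached `System` instances per identifier (see `SharedSystemClient` in `chromadb/api/client.py` below), `clear_system_cache` is the hook test suites use between cases. A sketch, assuming pytest:

```python
import pytest
from chromadb.api.client import SharedSystemClient

@pytest.fixture(autouse=True)
def fresh_chroma_cache():
    # Drop cached System instances so each test can rebuild a client for the
    # same path without hitting the "already exists with different settings"
    # ValueError raised by _create_system_if_not_exists.
    SharedSystemClient.clear_system_cache()
    yield
```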


class AdminAPI(ABC):
    @abstractmethod
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a new database. Raises an error if the database already exists.

        Args:
            name: The name of the database to create.
        """
        pass

    @abstractmethod
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Get a database. Raises an error if the database does not exist.

        Args:
            name: The name of the database to get.
            tenant: The tenant of the database to get.
        """
        pass

    @abstractmethod
    def create_tenant(self, name: str) -> None:
        """Create a new tenant. Raises an error if the tenant already exists.

        Args:
            name: The name of the tenant to create.
        """
        pass

    @abstractmethod
    def get_tenant(self, name: str) -> Tenant:
        """Get a tenant. Raises an error if the tenant does not exist.

        Args:
            name: The name of the tenant to get.
        """
        pass
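A minimal provisioning sketch against the `AdminAPI` surface, assuming the `AdminClient` implementation from `chromadb/api/client.py` below:

```python
from chromadb.api.client import AdminClient

admin = AdminClient()  # default Settings()
admin.create_tenant("acme")
admin.create_database("prod", tenant="acme")

# Getters raise if the resource is missing, so this doubles as a check.
assert admin.get_database("prod", tenant="acme")["name"] == "prod"
```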


class ServerAPI(BaseAPI, AdminAPI, Component):
    """An API instance that extends the relevant Base API methods by passing
    in a tenant and database. This is the root component of the Chroma System."""

    @abstractmethod
    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[Collection]:
        pass

    @abstractmethod
    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        pass

    @abstractmethod
    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        pass

    @abstractmethod
    @override
    def get_collection(
        self,
        name: str,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        pass

    @abstractmethod
    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        pass

    @abstractmethod
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        pass
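The pattern in `ServerAPI` is uniform: every override adds `tenant` and `database` keyword arguments, and the concrete `Client` below pins them from its own state, so callers re-scope once instead of threading the pair through every call. A sketch, assuming the tenant and database were provisioned as in the `AdminClient` example above:

```python
import chromadb

client = chromadb.Client()  # starts scoped to DEFAULT_TENANT / DEFAULT_DATABASE
client.set_tenant("acme", database="prod")  # validates both exist, then re-scopes
client.create_collection("docs")  # created under acme/prod from here on
```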
chromadb/api/client.py
ADDED
@@ -0,0 +1,496 @@
from typing import ClassVar, Dict, Optional, Sequence
from uuid import UUID
import uuid

from overrides import override
import requests
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api.types import (
    CollectionMetadata,
    DataLoader,
    Documents,
    Embeddable,
    EmbeddingFunction,
    Embeddings,
    GetResult,
    IDs,
    Include,
    Loadable,
    Metadatas,
    QueryResult,
    URIs,
)
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.errors import ChromaError
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import ClientStartEvent
from chromadb.types import Database, Tenant, Where, WhereDocument
import chromadb.utils.embedding_functions as ef


class SharedSystemClient:
    _identifer_to_system: ClassVar[Dict[str, System]] = {}
    _identifier: str

    # region Initialization
    def __init__(
        self,
        settings: Settings = Settings(),
    ) -> None:
        self._identifier = SharedSystemClient._get_identifier_from_settings(settings)
        SharedSystemClient._create_system_if_not_exists(self._identifier, settings)

    @classmethod
    def _create_system_if_not_exists(
        cls, identifier: str, settings: Settings
    ) -> System:
        if identifier not in cls._identifer_to_system:
            new_system = System(settings)
            cls._identifer_to_system[identifier] = new_system

            new_system.instance(ProductTelemetryClient)
            new_system.instance(ServerAPI)

            new_system.start()
        else:
            previous_system = cls._identifer_to_system[identifier]

            # For now, the settings must match
            if previous_system.settings != settings:
                raise ValueError(
                    f"An instance of Chroma already exists for {identifier} with different settings"
                )

        return cls._identifer_to_system[identifier]

    @staticmethod
    def _get_identifier_from_settings(settings: Settings) -> str:
        identifier = ""
        api_impl = settings.chroma_api_impl

        if api_impl is None:
            raise ValueError("Chroma API implementation must be set in settings")
        elif api_impl == "chromadb.api.segment.SegmentAPI":
            if settings.is_persistent:
                identifier = settings.persist_directory
            else:
                identifier = (
                    "ephemeral"  # TODO: support pathing and multiple ephemeral clients
                )
        elif api_impl == "chromadb.api.fastapi.FastAPI":
            # FastAPI clients can all use unique system identifiers since their configurations can be independent, e.g. different auth tokens
            identifier = str(uuid.uuid4())
        else:
            raise ValueError(f"Unsupported Chroma API implementation {api_impl}")

        return identifier

    @staticmethod
    def _populate_data_from_system(system: System) -> str:
        identifier = SharedSystemClient._get_identifier_from_settings(system.settings)
        SharedSystemClient._identifer_to_system[identifier] = system
        return identifier

    @classmethod
    def from_system(cls, system: System) -> "SharedSystemClient":
        """Create a client from an existing system. This is useful for testing and debugging."""

        SharedSystemClient._populate_data_from_system(system)
        instance = cls(system.settings)
        return instance

    @staticmethod
    def clear_system_cache() -> None:
        SharedSystemClient._identifer_to_system = {}

    @property
    def _system(self) -> System:
        return SharedSystemClient._identifer_to_system[self._identifier]

    # endregion


class Client(SharedSystemClient, ClientAPI):
    """A client for Chroma. This is the main entrypoint for interacting with Chroma.
    A client internally stores its tenant and database and proxies calls to a
    Server API instance of Chroma. It treats the Server API and corresponding System
    as a singleton, so multiple clients connecting to the same resource will share the
    same API instance.

    Client implementations should implement their own API-caching strategies.
    """

    tenant: str = DEFAULT_TENANT
    database: str = DEFAULT_DATABASE

    _server: ServerAPI
    # An internal admin client for verifying that databases and tenants exist
    _admin_client: AdminAPI

    # region Initialization
    def __init__(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        settings: Settings = Settings(),
    ) -> None:
        super().__init__(settings=settings)
        self.tenant = tenant
        self.database = database
        # Create an admin client for verifying that databases and tenants exist
        self._admin_client = AdminClient.from_system(self._system)
        self._validate_tenant_database(tenant=tenant, database=database)

        # Get the root system component we want to interact with
        self._server = self._system.instance(ServerAPI)

        # Submit event for a client start
        telemetry_client = self._system.instance(ProductTelemetryClient)
        telemetry_client.capture(ClientStartEvent())

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "Client":
        SharedSystemClient._populate_data_from_system(system)
        instance = cls(tenant=tenant, database=database, settings=system.settings)
        return instance

    # endregion

    # region BaseAPI Methods
    # Note - we could do this in less verbose ways, but they break type checking
    @override
    def heartbeat(self) -> int:
        return self._server.heartbeat()

    @override
    def list_collections(
        self, limit: Optional[int] = None, offset: Optional[int] = None
    ) -> Sequence[Collection]:
        return self._server.list_collections(
            limit, offset, tenant=self.tenant, database=self.database
        )

    @override
    def count_collections(self) -> int:
        return self._server.count_collections(
            tenant=self.tenant, database=self.database
        )

    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> Collection:
        return self._server.create_collection(
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
            data_loader=data_loader,
            tenant=self.tenant,
            database=self.database,
            get_or_create=get_or_create,
        )

    @override
    def get_collection(
        self,
        name: str,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        return self._server.get_collection(
            id=id,
            name=name,
            embedding_function=embedding_function,
            data_loader=data_loader,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        return self._server.get_or_create_collection(
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
            data_loader=data_loader,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        return self._server._modify(
            id=id,
            new_name=new_name,
            new_metadata=new_metadata,
        )

    @override
    def delete_collection(
        self,
        name: str,
    ) -> None:
        return self._server.delete_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )

    #
    # ITEM METHODS
    #

    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        return self._server._add(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        return self._server._update(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        return self._server._upsert(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _count(self, collection_id: UUID) -> int:
        return self._server._count(
            collection_id=collection_id,
        )

    @override
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        return self._server._peek(
            collection_id=collection_id,
            n=n,
        )

    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],
    ) -> GetResult:
        return self._server._get(
            collection_id=collection_id,
            ids=ids,
            where=where,
            sort=sort,
            limit=limit,
            offset=offset,
            page=page,
            page_size=page_size,
            where_document=where_document,
            include=include,
        )

    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> IDs:
        return self._server._delete(
            collection_id=collection_id,
            ids=ids,
            where=where,
            where_document=where_document,
        )

    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["embeddings", "metadatas", "documents", "distances"],
    ) -> QueryResult:
        return self._server._query(
            collection_id=collection_id,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include,
        )

    @override
    def reset(self) -> bool:
        return self._server.reset()

    @override
    def get_version(self) -> str:
        return self._server.get_version()

    @override
    def get_settings(self) -> Settings:
        return self._server.get_settings()

    @property
    @override
    def max_batch_size(self) -> int:
        return self._server.max_batch_size

    # endregion

    # region ClientAPI Methods

    @override
    def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        self._validate_tenant_database(tenant=tenant, database=database)
        self.tenant = tenant
        self.database = database

    @override
    def set_database(self, database: str) -> None:
        self._validate_tenant_database(tenant=self.tenant, database=database)
        self.database = database

    def _validate_tenant_database(self, tenant: str, database: str) -> None:
        try:
            self._admin_client.get_tenant(name=tenant)
        except requests.exceptions.ConnectionError:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            )
        # Propagate ChromaErrors
        except ChromaError as e:
            raise e
        except Exception:
            raise ValueError(
                f"Could not connect to tenant {tenant}. Are you sure it exists?"
            )

        try:
            self._admin_client.get_database(name=database, tenant=tenant)
        except requests.exceptions.ConnectionError:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            )
        except Exception:
            raise ValueError(
                f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?"
            )

    # endregion
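One consequence of the cache above worth spelling out: two `Client` objects built from identical settings share a single `System` (and therefore one `ServerAPI`), while a second configuration for the same identifier fails fast. A sketch, assuming a persistent segment-API setup:

```python
from chromadb.api.client import Client
from chromadb.config import Settings

settings = Settings(
    chroma_api_impl="chromadb.api.segment.SegmentAPI",
    is_persistent=True,
    persist_directory="/tmp/chroma-demo",
)

a = Client(settings=settings)
b = Client(settings=settings)
assert a._system is b._system  # one cached System per identifier

# Constructing a Client with different settings but the same persist_directory
# raises ValueError("An instance of Chroma already exists for ...").
```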
class AdminClient(SharedSystemClient, AdminAPI):
    _server: ServerAPI

    def __init__(self, settings: Settings = Settings()) -> None:
        super().__init__(settings)
        self._server = self._system.instance(ServerAPI)

    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        return self._server.create_database(name=name, tenant=tenant)

    @override
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        return self._server.get_database(name=name, tenant=tenant)

    @override
    def create_tenant(self, name: str) -> None:
        return self._server.create_tenant(name=name)

    @override
    def get_tenant(self, name: str) -> Tenant:
        return self._server.get_tenant(name=name)

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
    ) -> "AdminClient":
        SharedSystemClient._populate_data_from_system(system)
        instance = cls(settings=system.settings)
        return instance
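`from_system` is the testing hook: it registers an already-constructed `System` in the cache and attaches a client to it, so `AdminClient` and `Client` can share state a fixture built. A rough sketch (in the test suite the `System` is normally constructed by fixtures):

```python
from chromadb.api.client import AdminClient, Client
from chromadb.config import Settings, System

system = System(Settings(chroma_api_impl="chromadb.api.segment.SegmentAPI"))
system.start()

admin = AdminClient.from_system(system)
client = Client.from_system(system)  # both proxy the same ServerAPI instance
```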
chromadb/api/fastapi.py
ADDED
@@ -0,0 +1,654 @@
import json
import logging
from typing import Optional, cast, Tuple
from typing import Sequence
from uuid import UUID

import requests
from overrides import override

import chromadb.errors as errors
from chromadb.types import Database, Tenant
import chromadb.utils.embedding_functions as ef
from chromadb.api import ServerAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
    DataLoader,
    Documents,
    Embeddable,
    Embeddings,
    EmbeddingFunction,
    IDs,
    Include,
    Loadable,
    Metadatas,
    URIs,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    CollectionMetadata,
    validate_batch,
)
from chromadb.auth import (
    ClientAuthProvider,
)
from chromadb.auth.providers import RequestsClientAuthProtocolAdapter
from chromadb.auth.registry import resolve_provider
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from urllib.parse import urlparse, urlunparse, quote

logger = logging.getLogger(__name__)


class FastAPI(ServerAPI):
    _settings: Settings
    _max_batch_size: int = -1

    @staticmethod
    def _validate_host(host: str) -> None:
        parsed = urlparse(host)
        if "/" in host and parsed.scheme not in {"http", "https"}:
            raise ValueError(
                "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
            )
        if "/" in host and (not host.startswith("http")):
            raise ValueError(
                "Invalid URL. "
                "Seems that you are trying to pass URL as a host but without \
                specifying the protocol. "
                "Please add http:// or https:// to the host."
            )

    @staticmethod
    def resolve_url(
        chroma_server_host: str,
        chroma_server_ssl_enabled: Optional[bool] = False,
        default_api_path: Optional[str] = "",
        chroma_server_http_port: Optional[int] = 8000,
    ) -> str:
        _skip_port = False
        _chroma_server_host = chroma_server_host
        FastAPI._validate_host(_chroma_server_host)
        if _chroma_server_host.startswith("http"):
            logger.debug("Skipping port as the user is passing a full URL")
            _skip_port = True
        parsed = urlparse(_chroma_server_host)

        scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
        net_loc = parsed.netloc or parsed.hostname or chroma_server_host
        port = (
            ":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
        )
        path = parsed.path or default_api_path

        if not path or path == net_loc:
            path = default_api_path if default_api_path else ""
        if not path.endswith(default_api_path or ""):
            path = path + default_api_path if default_api_path else ""
        full_url = urlunparse(
            (scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
        )

        return full_url

    def __init__(self, system: System):
        super().__init__(system)
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")

        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._settings = system.settings

        self._api_url = FastAPI.resolve_url(
            chroma_server_host=str(system.settings.chroma_server_host),
            chroma_server_http_port=int(str(system.settings.chroma_server_http_port)),
            chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
            default_api_path=system.settings.chroma_server_api_default_path,
        )

        self._header = system.settings.chroma_server_headers
        if (
            system.settings.chroma_client_auth_provider
            and system.settings.chroma_client_auth_protocol_adapter
        ):
            self._auth_provider = self.require(
                resolve_provider(
                    system.settings.chroma_client_auth_provider, ClientAuthProvider
                )
            )
            self._adapter = cast(
                RequestsClientAuthProtocolAdapter,
                system.require(
                    resolve_provider(
                        system.settings.chroma_client_auth_protocol_adapter,
                        RequestsClientAuthProtocolAdapter,
                    )
                ),
            )
            self._session = self._adapter.session
        else:
            self._session = requests.Session()
        if self._header is not None:
            self._session.headers.update(self._header)
        if self._settings.chroma_server_ssl_verify is not None:
            self._session.verify = self._settings.chroma_server_ssl_verify

    @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
    @override
    def heartbeat(self) -> int:
        """Returns the current server time in nanoseconds to check if the server is alive"""
        resp = self._session.get(self._api_url)
        raise_chroma_error(resp)
        return int(resp.json()["nanosecond heartbeat"])

    @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
    @override
    def create_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Creates a database"""
        resp = self._session.post(
            self._api_url + "/databases",
            data=json.dumps({"name": name}),
            params={"tenant": tenant},
        )
        raise_chroma_error(resp)

    @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
    @override
    def get_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> Database:
        """Returns a database"""
        resp = self._session.get(
            self._api_url + "/databases/" + name,
            params={"tenant": tenant},
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Database(
            id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
        )

    @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def create_tenant(self, name: str) -> None:
        resp = self._session.post(
            self._api_url + "/tenants",
            data=json.dumps({"name": name}),
        )
        raise_chroma_error(resp)

    @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def get_tenant(self, name: str) -> Tenant:
        resp = self._session.get(
            self._api_url + "/tenants/" + name,
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Tenant(name=resp_json["name"])

    @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[Collection]:
        """Returns a list of all collections"""
        resp = self._session.get(
            self._api_url + "/collections",
            params={
                "tenant": tenant,
                "database": database,
                "limit": limit,
                "offset": offset,
            },
        )
        raise_chroma_error(resp)
        json_collections = resp.json()
        collections = []
        for json_collection in json_collections:
            collections.append(Collection(self, **json_collection))

        return collections

    @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        """Returns a count of collections"""
        resp = self._session.get(
            self._api_url + "/count_collections",
            params={"tenant": tenant, "database": database},
        )
        raise_chroma_error(resp)
        return cast(int, resp.json())

    @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Creates a collection"""
        resp = self._session.post(
            self._api_url + "/collections",
            data=json.dumps(
                {
                    "name": name,
                    "metadata": metadata,
                    "get_or_create": get_or_create,
                }
            ),
            params={"tenant": tenant, "database": database},
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            id=resp_json["id"],
            name=resp_json["name"],
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata=resp_json["metadata"],
        )

    @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def get_collection(
        self,
        name: str,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        """Returns a collection"""
        if (name is None and id is None) or (name is not None and id is not None):
            raise ValueError("Name or id must be specified, but not both")

        _params = {"tenant": tenant, "database": database}
        if id is not None:
            _params["type"] = str(id)
        # Parenthesized so the id fallback still targets the collections route.
        resp = self._session.get(
            self._api_url + "/collections/" + (name if name else str(id)),
            params=_params,
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            name=resp_json["name"],
            id=resp_json["id"],
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata=resp_json["metadata"],
        )

    @trace_method(
        "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
    )
    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        return cast(
            Collection,
            self.create_collection(
                name=name,
                metadata=metadata,
                embedding_function=embedding_function,
                data_loader=data_loader,
                get_or_create=True,
                tenant=tenant,
                database=database,
            ),
        )

    @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Updates a collection"""
        resp = self._session.put(
            self._api_url + "/collections/" + str(id),
            data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
        )
        raise_chroma_error(resp)

    @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Deletes a collection"""
        resp = self._session.delete(
            self._api_url + "/collections/" + name,
            params={"tenant": tenant, "database": database},
        )
        raise_chroma_error(resp)

    @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
    @override
    def _count(
        self,
        collection_id: UUID,
    ) -> int:
        """Returns the number of embeddings in the database"""
        resp = self._session.get(
            self._api_url + "/collections/" + str(collection_id) + "/count"
        )
        raise_chroma_error(resp)
        return cast(int, resp.json())

    @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
    ) -> GetResult:
        return cast(
            GetResult,
            self._get(
                collection_id,
                limit=n,
                include=["embeddings", "documents", "metadatas"],
            ),
        )

    @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents"],
    ) -> GetResult:
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        resp = self._session.post(
            self._api_url + "/collections/" + str(collection_id) + "/get",
            data=json.dumps(
                {
                    "ids": ids,
                    "where": where,
                    "sort": sort,
                    "limit": limit,
                    "offset": offset,
                    "where_document": where_document,
                    "include": include,
                }
            ),
        )

        raise_chroma_error(resp)
        body = resp.json()
        return GetResult(
            ids=body["ids"],
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
            data=None,
            uris=body.get("uris", None),
        )

    @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> IDs:
        """Deletes embeddings from the database"""
        resp = self._session.post(
            self._api_url + "/collections/" + str(collection_id) + "/delete",
            data=json.dumps(
                {"where": where, "ids": ids, "where_document": where_document}
            ),
        )

        raise_chroma_error(resp)
        return cast(IDs, resp.json())

    @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
    def _submit_batch(
        self,
        batch: Tuple[
            IDs,
            Optional[Embeddings],
            Optional[Metadatas],
            Optional[Documents],
            Optional[URIs],
        ],
        url: str,
    ) -> requests.Response:
        """
        Submits a batch of embeddings to the database
        """
        resp = self._session.post(
            self._api_url + url,
            data=json.dumps(
                {
                    "ids": batch[0],
                    "embeddings": batch[1],
                    "metadatas": batch[2],
                    "documents": batch[3],
                    "uris": batch[4],
                }
            ),
        )
        return resp

    @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """
        Adds a batch of embeddings to the database
        - pass in column oriented data lists
        """
        batch = (ids, embeddings, metadatas, documents, uris)
        validate_batch(batch, {"max_batch_size": self.max_batch_size})
        resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
        raise_chroma_error(resp)
        return True

    @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """
        Updates a batch of embeddings in the database
        - pass in column oriented data lists
        """
        batch = (ids, embeddings, metadatas, documents, uris)
        validate_batch(batch, {"max_batch_size": self.max_batch_size})
        resp = self._submit_batch(
            batch, "/collections/" + str(collection_id) + "/update"
        )
        raise_chroma_error(resp)
        return True

    @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """
        Upserts a batch of embeddings in the database
        - pass in column oriented data lists
        """
        batch = (ids, embeddings, metadatas, documents, uris)
        validate_batch(batch, {"max_batch_size": self.max_batch_size})
        resp = self._submit_batch(
            batch, "/collections/" + str(collection_id) + "/upsert"
        )
        raise_chroma_error(resp)
        return True

    @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents", "distances"],
    ) -> QueryResult:
        """Gets the nearest neighbors of a single embedding"""
        resp = self._session.post(
            self._api_url + "/collections/" + str(collection_id) + "/query",
            data=json.dumps(
                {
                    "query_embeddings": query_embeddings,
                    "n_results": n_results,
                    "where": where,
                    "where_document": where_document,
                    "include": include,
                }
            ),
        )

        raise_chroma_error(resp)
        body = resp.json()

        return QueryResult(
            ids=body["ids"],
            distances=body.get("distances", None),
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
            uris=body.get("uris", None),
            data=None,
        )

    @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
    @override
    def reset(self) -> bool:
        """Resets the database"""
        resp = self._session.post(self._api_url + "/reset")
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
    @override
    def get_version(self) -> str:
        """Returns the version of the server"""
        resp = self._session.get(self._api_url + "/version")
        raise_chroma_error(resp)
        return cast(str, resp.json())

    @override
    def get_settings(self) -> Settings:
        """Returns the settings of the client"""
        return self._settings

    @property
    @trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION)
    @override
    def max_batch_size(self) -> int:
        # Lazily fetched from the server's pre-flight checks and cached.
        if self._max_batch_size == -1:
            resp = self._session.get(self._api_url + "/pre-flight-checks")
            raise_chroma_error(resp)
            self._max_batch_size = cast(int, resp.json()["max_batch_size"])
        return self._max_batch_size
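To make `resolve_url` concrete, two hand-traced cases from the logic above (illustrative, using the defaults shown in the signature):

```python
# Bare host: scheme and port are filled in, and the default path is appended.
FastAPI.resolve_url("localhost", default_api_path="/api/v1")
# -> "http://localhost:8000/api/v1"

# Full URL: port handling is skipped, and SSL can still force https.
FastAPI.resolve_url("http://example.com/api/v1", chroma_server_ssl_enabled=True)
# -> "https://example.com/api/v1"
```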

def raise_chroma_error(resp: requests.Response) -> None:
    """Raises an error if the response is not ok, using a ChromaError if possible"""
    if resp.ok:
        return

    chroma_error = None
    try:
        body = resp.json()
        if "error" in body:
            if body["error"] in errors.error_types:
                chroma_error = errors.error_types[body["error"]](body["message"])

    except BaseException:
        pass

    if chroma_error:
        raise chroma_error

    try:
        resp.raise_for_status()
    except requests.HTTPError:
        raise Exception(resp.text)
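Client-side, `raise_chroma_error` means a structured error payload surfaces as its typed `ChromaError` subclass, while anything unparseable degrades to a plain `Exception` carrying the response text. A sketch of what callers see, assuming `client` is a FastAPI-backed client from the earlier sketches:

```python
from chromadb.errors import ChromaError

try:
    client.get_collection("does_not_exist")
except ChromaError as err:
    # Typed server errors round-trip through errors.error_types.
    print(type(err).__name__, err)
except Exception as err:
    # Non-JSON failures arrive as Exception(resp.text).
    print(err)
```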
chromadb/api/models/Collection.py
ADDED
@@ -0,0 +1,633 @@
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union

import numpy as np
from pydantic import BaseModel, PrivateAttr

from uuid import UUID
import chromadb.utils.embedding_functions as ef

from chromadb.api.types import (
    URI,
    CollectionMetadata,
    DataLoader,
    Embedding,
    Embeddings,
    Embeddable,
    Include,
    Loadable,
    Metadata,
    Metadatas,
    Document,
    Documents,
    Image,
    Images,
    URIs,
    Where,
    IDs,
    EmbeddingFunction,
    GetResult,
    QueryResult,
    ID,
    OneOrMany,
    WhereDocument,
    maybe_cast_one_to_many_ids,
    maybe_cast_one_to_many_embedding,
    maybe_cast_one_to_many_metadata,
    maybe_cast_one_to_many_document,
    maybe_cast_one_to_many_image,
    maybe_cast_one_to_many_uri,
    validate_ids,
    validate_include,
    validate_metadata,
    validate_metadatas,
    validate_where,
    validate_where_document,
    validate_n_results,
    validate_embeddings,
    validate_embedding_function,
)
import logging

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from chromadb.api import ServerAPI


class Collection(BaseModel):
    name: str
    id: UUID
    metadata: Optional[CollectionMetadata] = None
    tenant: Optional[str] = None
    database: Optional[str] = None
    _client: "ServerAPI" = PrivateAttr()
    _embedding_function: Optional[EmbeddingFunction[Embeddable]] = PrivateAttr()
    _data_loader: Optional[DataLoader[Loadable]] = PrivateAttr()

    def __init__(
        self,
        client: "ServerAPI",
        name: str,
        id: UUID,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: Optional[str] = None,
        database: Optional[str] = None,
        metadata: Optional[CollectionMetadata] = None,
    ):
        super().__init__(
            name=name, metadata=metadata, id=id, tenant=tenant, database=database
        )
        self._client = client

        # Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol
        if embedding_function is not None:
            validate_embedding_function(embedding_function)

        self._embedding_function = embedding_function
        self._data_loader = data_loader

    def __repr__(self) -> str:
        return f"Collection(name={self.name})"

    def count(self) -> int:
        """The total number of embeddings added to the database

        Returns:
            int: The total number of embeddings added to the database

        """
        return self._client._count(collection_id=self.id)

    def add(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> None:
        """Add embeddings to the data store.
        Args:
            ids: The ids of the embeddings you wish to add
            embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
            metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
            uris: The uris of the images to associate with the embeddings. Optional.

        Returns:
            None

        Raises:
            ValueError: If you don't provide either embeddings or documents
            ValueError: If the length of ids, embeddings, metadatas, or documents don't match
            ValueError: If you don't provide an embedding function and don't provide embeddings
            ValueError: If you provide both embeddings and documents
            ValueError: If you provide an id that already exists

        """

        (
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
        ) = self._validate_embedding_set(
            ids, embeddings, metadatas, documents, images, uris
        )

        # We need to compute the embeddings if they're not provided
        if embeddings is None:
            # At this point, we know that one of documents or images is provided from the validation above
            if documents is not None:
                embeddings = self._embed(input=documents)
            elif images is not None:
                embeddings = self._embed(input=images)
            else:
                if uris is None:
                    raise ValueError(
                        "You must provide either embeddings, documents, images, or uris."
                    )
                if self._data_loader is None:
                    raise ValueError(
                        "You must set a data loader on the collection if loading from URIs."
                    )
                embeddings = self._embed(self._data_loader(uris))

        self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
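As a usage sketch of the branching above: when no embeddings are passed, add() embeds the documents (or images, or URI-loaded data) with the collection's embedding function. A hypothetical example using the default in-process client (collection name and texts are made up):

import chromadb

client = chromadb.Client()  # ephemeral in-process client, for illustration only
collection = client.create_collection(name="my_docs")

# With no `embeddings` argument, add() computes them from `documents`
# using the collection's embedding function (the default one here).
collection.add(
    ids=["doc1", "doc2"],
    documents=["chroma is a vector database", "embeddings map text to vectors"],
    metadatas=[{"source": "notes"}, {"source": "notes"}],
)
print(collection.count())  # 2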
    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["metadatas", "documents"],
    ) -> GetResult:
        """Get embeddings and their associated data from the data store. If no ids or where filter is provided, returns
        all embeddings up to limit starting at offset.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by. E.g. `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from. Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
            include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional.

        Returns:
            GetResult: A GetResult object containing the results.

        """

        valid_where = validate_where(where) if where else None
        valid_where_document = (
            validate_where_document(where_document) if where_document else None
        )
        valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
        valid_include = validate_include(include, allow_distances=False)

        if "data" in include and self._data_loader is None:
            raise ValueError(
                "You must set a data loader on the collection if loading from URIs."
            )

        # We need to include uris in the result from the API to load data
        if "data" in include and "uris" not in include:
            valid_include.append("uris")

        get_results = self._client._get(
            self.id,
            valid_ids,
            valid_where,
            None,
            limit,
            offset,
            where_document=valid_where_document,
            include=valid_include,
        )

        if (
            "data" in include
            and self._data_loader is not None
            and get_results["uris"] is not None
        ):
            get_results["data"] = self._data_loader(get_results["uris"])

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            get_results["uris"] = None

        return get_results
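Continuing the sketch from add() above, get() combines id lookup, metadata filtering, and the include list (filter values here are illustrative):

result = collection.get(
    ids=["doc1"],
    where={"source": "notes"},           # metadata filter
    include=["metadatas", "documents"],  # ids are always returned
)
print(result["ids"], result["documents"])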
    def peek(self, limit: int = 10) -> GetResult:
        """Get the first few results in the database up to limit

        Args:
            limit: The number of results to return.

        Returns:
            GetResult: A GetResult object containing the results.
        """
        return self._client._peek(self.id, limit)

    def query(
        self,
        query_embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        query_texts: Optional[OneOrMany[Document]] = None,
        query_images: Optional[OneOrMany[Image]] = None,
        query_uris: Optional[OneOrMany[URI]] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["metadatas", "documents", "distances"],
    ) -> QueryResult:
        """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.

        Args:
            query_embeddings: The embeddings to get the closest neighbors of. Optional.
            query_texts: The document texts to get the closest neighbors of. Optional.
            query_images: The images to get the closest neighbors of. Optional.
            n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
            where: A Where type dict used to filter results by. E.g. `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
            where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
            include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.

        Returns:
            QueryResult: A QueryResult object containing the results.

        Raises:
            ValueError: If you don't provide either query_embeddings, query_texts, or query_images
            ValueError: If you provide both query_embeddings and query_texts
            ValueError: If you provide both query_embeddings and query_images
            ValueError: If you provide both query_texts and query_images

        """

        # Users must provide only one of query_embeddings, query_texts, query_images, or query_uris
        if not (
            (query_embeddings is not None)
            ^ (query_texts is not None)
            ^ (query_images is not None)
            ^ (query_uris is not None)
        ):
            raise ValueError(
                "You must provide one of query_embeddings, query_texts, query_images, or query_uris."
            )

        valid_where = validate_where(where) if where else {}
        valid_where_document = (
            validate_where_document(where_document) if where_document else {}
        )
        valid_query_embeddings = (
            validate_embeddings(
                self._normalize_embeddings(
                    maybe_cast_one_to_many_embedding(query_embeddings)
                )
            )
            if query_embeddings is not None
            else None
        )
        valid_query_texts = (
            maybe_cast_one_to_many_document(query_texts)
            if query_texts is not None
            else None
        )
        valid_query_images = (
            maybe_cast_one_to_many_image(query_images)
            if query_images is not None
            else None
        )
        valid_query_uris = (
            maybe_cast_one_to_many_uri(query_uris) if query_uris is not None else None
        )
        valid_include = validate_include(include, allow_distances=True)
        valid_n_results = validate_n_results(n_results)

        # If query_embeddings are not provided, we need to compute them from the inputs
        if valid_query_embeddings is None:
            if query_texts is not None:
                valid_query_embeddings = self._embed(input=valid_query_texts)
            elif query_images is not None:
                valid_query_embeddings = self._embed(input=valid_query_images)
            else:
                if valid_query_uris is None:
                    raise ValueError(
                        "You must provide either query_embeddings, query_texts, query_images, or query_uris."
                    )
                if self._data_loader is None:
                    raise ValueError(
                        "You must set a data loader on the collection if loading from URIs."
                    )
                valid_query_embeddings = self._embed(
                    self._data_loader(valid_query_uris)
                )

        if "data" in include and "uris" not in include:
            valid_include.append("uris")
        query_results = self._client._query(
            collection_id=self.id,
            query_embeddings=valid_query_embeddings,
            n_results=valid_n_results,
            where=valid_where,
            where_document=valid_where_document,
            include=include,
        )

        if (
            "data" in include
            and self._data_loader is not None
            and query_results["uris"] is not None
        ):
            query_results["data"] = [
                self._data_loader(uris) for uris in query_results["uris"]
            ]

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            query_results["uris"] = None

        return query_results
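The XOR check at the top of query() means exactly one of the four query inputs may be set per call. Continuing the hypothetical collection from earlier:

results = collection.query(
    query_texts=["what is a vector database?"],  # embedded on the fly
    n_results=2,
    include=["documents", "distances"],
)
# Results are nested: one list per query input
print(results["ids"][0], results["distances"][0])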
    def modify(
        self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None
    ) -> None:
        """Modify the collection name or metadata

        Args:
            name: The updated name for the collection. Optional.
            metadata: The updated metadata for the collection. Optional.

        Returns:
            None
        """
        if metadata is not None:
            validate_metadata(metadata)
            if "hnsw:space" in metadata:
                raise ValueError(
                    "Changing the distance function of a collection once it is created is not supported currently."
                )

        self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
        if name:
            self.name = name
        if metadata:
            self.metadata = metadata

    def update(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> None:
        """Update the embeddings, metadatas or documents for provided ids.

        Args:
            ids: The ids of the embeddings to update
            embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
            metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.
            images: The images to associate with the embeddings. Optional.
        Returns:
            None
        """

        (
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
        ) = self._validate_embedding_set(
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
            require_embeddings_or_data=False,
        )

        if embeddings is None:
            if documents is not None:
                embeddings = self._embed(input=documents)
            elif images is not None:
                embeddings = self._embed(input=images)

        self._client._update(self.id, ids, embeddings, metadatas, documents, uris)

    def upsert(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> None:
        """Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.

        Args:
            ids: The ids of the embeddings to update
            embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
            metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
            documents: The documents to associate with the embeddings. Optional.

        Returns:
            None
        """

        (
            ids,
            embeddings,
            metadatas,
            documents,
            images,
            uris,
        ) = self._validate_embedding_set(
            ids, embeddings, metadatas, documents, images, uris
        )

        if embeddings is None:
            if documents is not None:
                embeddings = self._embed(input=documents)
            else:
                embeddings = self._embed(input=images)

        self._client._upsert(
            collection_id=self.id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )
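upsert() shares add()'s validation path but writes through _upsert, so existing ids are overwritten and unknown ids are created in one call. Sketch, reusing the hypothetical collection:

collection.upsert(
    ids=["doc1", "doc9"],  # doc1 already exists and is overwritten; doc9 is created
    documents=["chroma is an embedding database", "a brand-new document"],
)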
    def delete(
        self,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        """Delete the embeddings based on ids and/or a where filter

        Args:
            ids: The ids of the embeddings to delete
            where: A Where type dict used to filter the deletion by. E.g. `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
            where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. `{"$contains": "hello"}`. Optional.

        Returns:
            None

        Raises:
            ValueError: If you don't provide either ids, where, or where_document
        """
        ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
        where = validate_where(where) if where else None
        where_document = (
            validate_where_document(where_document) if where_document else None
        )

        self._client._delete(self.id, ids, where, where_document)
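delete() accepts any combination of ids and filters; the server resolves filters down to concrete ids before deleting. Sketch:

collection.delete(ids=["doc9"])               # delete by id
collection.delete(where={"source": "notes"})  # or by metadata filter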
    def _validate_embedding_set(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[np.ndarray],
            ]
        ],
        metadatas: Optional[OneOrMany[Metadata]],
        documents: Optional[OneOrMany[Document]],
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
        require_embeddings_or_data: bool = True,
    ) -> Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[Images],
        Optional[URIs],
    ]:
        valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids))
        valid_embeddings = (
            validate_embeddings(
                self._normalize_embeddings(maybe_cast_one_to_many_embedding(embeddings))
            )
            if embeddings is not None
            else None
        )
        valid_metadatas = (
            validate_metadatas(maybe_cast_one_to_many_metadata(metadatas))
            if metadatas is not None
            else None
        )
        valid_documents = (
            maybe_cast_one_to_many_document(documents)
            if documents is not None
            else None
        )
        valid_images = (
            maybe_cast_one_to_many_image(images) if images is not None else None
        )

        valid_uris = maybe_cast_one_to_many_uri(uris) if uris is not None else None

        # Check that one of embeddings or documents or images is provided
        if require_embeddings_or_data:
            if (
                valid_embeddings is None
                and valid_documents is None
                and valid_images is None
                and valid_uris is None
            ):
                raise ValueError(
                    "You must provide embeddings, documents, images, or uris."
                )

        # Only one of documents or images can be provided
        if valid_documents is not None and valid_images is not None:
            raise ValueError("You can only provide documents or images, not both.")

        # Check that, if they're provided, the lengths of the arrays match the length of ids
        if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids):
            raise ValueError(
                f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}"
            )
        if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids):
            raise ValueError(
                f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
            )
        if valid_documents is not None and len(valid_documents) != len(valid_ids):
            raise ValueError(
                f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
            )
        if valid_images is not None and len(valid_images) != len(valid_ids):
            raise ValueError(
                f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
            )
        if valid_uris is not None and len(valid_uris) != len(valid_ids):
            raise ValueError(
                f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
            )

        return (
            valid_ids,
            valid_embeddings,
            valid_metadatas,
            valid_documents,
            valid_images,
            valid_uris,
        )

    @staticmethod
    def _normalize_embeddings(
        embeddings: Union[
            OneOrMany[Embedding],
            OneOrMany[np.ndarray],
        ]
    ) -> Embeddings:
        if isinstance(embeddings, np.ndarray):
            return embeddings.tolist()
        return embeddings

    def _embed(self, input: Any) -> Embeddings:
        if self._embedding_function is None:
            raise ValueError(
                "You must provide an embedding function to compute embeddings. "
                "https://docs.trychroma.com/embeddings"
            )
        return self._embedding_function(input=input)
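Note that _normalize_embeddings only converts a single numpy array via tolist(); anything else (plain lists, or a list of numpy arrays) passes through untouched, which is worth keeping in mind when mixing numpy inputs. A quick illustration of the code above:

import numpy as np

matrix = np.array([[1.0, 2.0], [3.0, 4.0]])
print(Collection._normalize_embeddings(matrix))        # [[1.0, 2.0], [3.0, 4.0]]
print(Collection._normalize_embeddings([[1.0, 2.0]]))  # passed through unchanged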
chromadb/api/segment.py
ADDED
@@ -0,0 +1,914 @@
from chromadb.api import ServerAPI
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.telemetry.opentelemetry import (
    add_attributes_to_current_span,
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.ingest import Producer
from chromadb.api.models.Collection import Collection
from chromadb import __version__
from chromadb.errors import InvalidDimensionException, InvalidCollectionException
import chromadb.utils.embedding_functions as ef

from chromadb.api.types import (
    URI,
    CollectionMetadata,
    Embeddable,
    Document,
    EmbeddingFunction,
    DataLoader,
    IDs,
    Embeddings,
    Embedding,
    Loadable,
    Metadatas,
    Documents,
    URIs,
    Where,
    WhereDocument,
    Include,
    GetResult,
    QueryResult,
    validate_metadata,
    validate_update_metadata,
    validate_where,
    validate_where_document,
    validate_batch,
)
from chromadb.telemetry.product.events import (
    CollectionAddEvent,
    CollectionDeleteEvent,
    CollectionGetEvent,
    CollectionUpdateEvent,
    CollectionQueryEvent,
    ClientCreateCollectionEvent,
)

import chromadb.types as t

from typing import Any, Optional, Sequence, Generator, List, cast, Set, Dict
from overrides import override
from uuid import UUID, uuid4
import time
import logging
import re


logger = logging.getLogger(__name__)


# mimics s3 bucket requirements for naming
def check_index_name(index_name: str) -> None:
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    if len(index_name) < 3 or len(index_name) > 63:
        raise ValueError(msg)
    if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name):
        raise ValueError(msg)
    if ".." in index_name:
        raise ValueError(msg)
    if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
        raise ValueError(msg)
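A few illustrative inputs for the naming rules above (the example names are made up):

check_index_name("my-collection")     # passes: 3-63 chars, alphanumeric at both ends
check_index_name("My_Collection.v2")  # passes: dots and underscores allowed inside

for bad in ["ab", "-leading-hyphen", "two..dots", "192.168.0.1"]:
    try:
        check_index_name(bad)
    except ValueError:
        print(f"rejected: {bad}")  # too short / bad edge char / ".." / IPv4-shaped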
class SegmentAPI(ServerAPI):
    """API implementation utilizing the new segment-based internal architecture"""

    _settings: Settings
    _sysdb: SysDB
    _manager: SegmentManager
    _producer: Producer
    _product_telemetry_client: ProductTelemetryClient
    _opentelemetry_client: OpenTelemetryClient
    _tenant_id: str
    _topic_ns: str
    _collection_cache: Dict[UUID, t.Collection]

    def __init__(self, system: System):
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._producer = self.require(Producer)
        self._collection_cache = {}

    @override
    def heartbeat(self) -> int:
        return int(time.time_ns())

    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        if len(name) < 3:
            raise ValueError("Database name must be at least 3 characters long")

        self._sysdb.create_database(
            id=uuid4(),
            name=name,
            tenant=tenant,
        )

    @override
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
        return self._sysdb.get_database(name=name, tenant=tenant)

    @override
    def create_tenant(self, name: str) -> None:
        if len(name) < 3:
            raise ValueError("Tenant name must be at least 3 characters long")

        self._sysdb.create_tenant(
            name=name,
        )

    @override
    def get_tenant(self, name: str) -> t.Tenant:
        return self._sysdb.get_tenant(name=name)

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
    @trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Any]
        ] = ef.DefaultEmbeddingFunction(),
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        if metadata is not None:
            validate_metadata(metadata)

        # TODO: remove backwards compatibility in naming requirements
        check_index_name(name)

        id = uuid4()

        coll, created = self._sysdb.create_collection(
            id=id,
            name=name,
            metadata=metadata,
            dimension=None,
            get_or_create=get_or_create,
            tenant=tenant,
            database=database,
        )

        if created:
            segments = self._manager.create_segments(coll)
            for segment in segments:
                self._sysdb.create_segment(segment)

        # TODO: This event doesn't capture the get_or_create case appropriately
        self._product_telemetry_client.capture(
            ClientCreateCollectionEvent(
                collection_uuid=str(id),
                embedding_function=embedding_function.__class__.__name__,
            )
        )
        add_attributes_to_current_span({"collection_uuid": str(id)})

        return Collection(
            client=self,
            id=coll["id"],
            name=name,
            metadata=coll["metadata"],  # type: ignore
            embedding_function=embedding_function,
            data_loader=data_loader,
            tenant=tenant,
            database=database,
        )

    @trace_method(
        "SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
    )
    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        return self.create_collection(  # type: ignore
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
            data_loader=data_loader,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )
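create_collection drives sysdb and segment creation together, and get_or_create_collection is a thin wrapper that sets get_or_create=True. From the public client this looks like the sketch below, assuming the default local configuration routes chromadb.Client() to this SegmentAPI (collection name and metadata are made up):

import chromadb

client = chromadb.Client()
c1 = client.create_collection(name="articles", metadata={"hnsw:space": "cosine"})
c2 = client.get_or_create_collection(name="articles")  # finds the existing one
assert c1.id == c2.id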
    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings
    @trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def get_collection(
        self,
        name: Optional[str] = None,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        if id is None and name is None or (id is not None and name is not None):
            raise ValueError("Name or id must be specified, but not both")
        existing = self._sysdb.get_collections(
            id=id, name=name, tenant=tenant, database=database
        )

        if existing:
            return Collection(
                client=self,
                id=existing[0]["id"],
                name=existing[0]["name"],
                metadata=existing[0]["metadata"],  # type: ignore
                embedding_function=embedding_function,
                data_loader=data_loader,
                tenant=existing[0]["tenant"],
                database=existing[0]["database"],
            )
        else:
            raise ValueError(f"Collection {name} does not exist.")

    @trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[Collection]:
        collections = []
        db_collections = self._sysdb.get_collections(
            limit=limit, offset=offset, tenant=tenant, database=database
        )
        for db_collection in db_collections:
            collections.append(
                Collection(
                    client=self,
                    id=db_collection["id"],
                    name=db_collection["name"],
                    metadata=db_collection["metadata"],  # type: ignore
                    tenant=db_collection["tenant"],
                    database=db_collection["database"],
                )
            )
        return collections

    @trace_method("SegmentAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        collection_count = len(
            self._sysdb.get_collections(tenant=tenant, database=database)
        )

        return collection_count

    @trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION)
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        if new_name:
            # backwards compatibility in naming requirements (for now)
            check_index_name(new_name)

        if new_metadata:
            validate_update_metadata(new_metadata)

        # TODO eventually we'll want to use OptionalArgument and Unspecified in the
        # signature of `_modify` but not changing the API right now.
        if new_name and new_metadata:
            self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
        elif new_name:
            self._sysdb.update_collection(id, name=new_name)
        elif new_metadata:
            self._sysdb.update_collection(id, metadata=new_metadata)

    @trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        existing = self._sysdb.get_collections(
            name=name, tenant=tenant, database=database
        )

        if existing:
            self._sysdb.delete_collection(
                existing[0]["id"], tenant=tenant, database=database
            )
            for s in self._manager.delete_segments(existing[0]["id"]):
                self._sysdb.delete_segment(s)
            if existing and existing[0]["id"] in self._collection_cache:
                del self._collection_cache[existing[0]["id"]]
        else:
            raise ValueError(f"Collection {name} does not exist.")
    @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION)
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.max_batch_size},
        )
        records_to_submit = []
        for r in _records(
            t.Operation.ADD,
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)

        self._product_telemetry_client.capture(
            CollectionAddEvent(
                collection_uuid=str(collection_id),
                add_amount=len(ids),
                with_metadata=len(ids) if metadatas is not None else 0,
                with_documents=len(ids) if documents is not None else 0,
                with_uris=len(ids) if uris is not None else 0,
            )
        )
        return True

    @trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION)
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.max_batch_size},
        )
        records_to_submit = []
        for r in _records(
            t.Operation.UPDATE,
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)

        self._product_telemetry_client.capture(
            CollectionUpdateEvent(
                collection_uuid=str(collection_id),
                update_amount=len(ids),
                with_embeddings=len(embeddings) if embeddings else 0,
                with_metadata=len(metadatas) if metadatas else 0,
                with_documents=len(documents) if documents else 0,
                with_uris=len(uris) if uris else 0,
            )
        )

        return True

    @trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION)
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.max_batch_size},
        )
        records_to_submit = []
        for r in _records(
            t.Operation.UPSERT,
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)

        return True
    @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION)
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],
    ) -> GetResult:
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")

        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        records = metadata_segment.get_metadata(
            where=where,
            where_document=where_document,
            ids=ids,
            limit=limit,
            offset=offset,
        )

        if len(records) == 0:
            # Nothing to return if there are no records
            return GetResult(
                ids=[],
                embeddings=[] if "embeddings" in include else None,
                metadatas=[] if "metadatas" in include else None,
                documents=[] if "documents" in include else None,
                uris=[] if "uris" in include else None,
                data=[] if "data" in include else None,
            )

        vectors: Sequence[t.VectorEmbeddingRecord] = []
        if "embeddings" in include:
            vector_ids = [r["id"] for r in records]
            vector_segment = self._manager.get_segment(collection_id, VectorReader)
            vectors = vector_segment.get_vectors(ids=vector_ids)

        # TODO: Fix type so we don't need to ignore
        # It is possible to have a set of records, some with metadata and some without
        # Same with documents

        metadatas = [r["metadata"] for r in records]

        if "documents" in include:
            documents = [_doc(m) for m in metadatas]

        if "uris" in include:
            uris = [_uri(m) for m in metadatas]

        ids_amount = len(ids) if ids else 0
        self._product_telemetry_client.capture(
            CollectionGetEvent(
                collection_uuid=str(collection_id),
                ids_count=ids_amount,
                limit=limit if limit else 0,
                include_metadata=ids_amount if "metadatas" in include else 0,
                include_documents=ids_amount if "documents" in include else 0,
                include_uris=ids_amount if "uris" in include else 0,
            )
        )

        return GetResult(
            ids=[r["id"] for r in records],
            embeddings=[r["embedding"] for r in vectors]
            if "embeddings" in include
            else None,
            metadatas=_clean_metadatas(metadatas)
            if "metadatas" in include
            else None,  # type: ignore
            documents=documents if "documents" in include else None,  # type: ignore
            uris=uris if "uris" in include else None,  # type: ignore
            data=None,
        )
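Note the pagination shim in _get above: page and page_size are translated into the offset/limit that the metadata segment understands, i.e. offset = (page - 1) * page_size. A trivial worked example for page 3 with a page size of 20:

page, page_size = 3, 20
offset = (page - 1) * page_size  # 40: skip the first two pages of records
limit = page_size                # 20: return exactly one page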
    @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        # You must have at least one of non-empty ids, where, or where_document.
        if (
            (ids is None or (ids is not None and len(ids) == 0))
            and (where is None or (where is not None and len(where) == 0))
            and (
                where_document is None
                or (where_document is not None and len(where_document) == 0)
            )
        ):
            raise ValueError(
                """
                You must provide either ids, where, or where_document to delete. If
                you want to delete all data in a collection you can delete the
                collection itself using the delete_collection method. Or alternatively,
                you can get() all the relevant ids and then delete them.
                """
            )

        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

        if (where or where_document) or not ids:
            metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_segment.get_metadata(
                where=where, where_document=where_document, ids=ids
            )
            ids_to_delete = [r["id"] for r in records]
        else:
            ids_to_delete = ids

        if len(ids_to_delete) == 0:
            return []

        records_to_submit = []
        for r in _records(
            operation=t.Operation.DELETE, ids=ids_to_delete, collection_id=collection_id
        ):
            self._validate_embedding_record(coll, r)
            records_to_submit.append(r)
        self._producer.submit_embeddings(coll["topic"], records_to_submit)

        self._product_telemetry_client.capture(
            CollectionDeleteEvent(
                collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
            )
        )
        return ids_to_delete

    @trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION)
    @override
    def _count(self, collection_id: UUID) -> int:
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
        return metadata_segment.count()
645 |
+
@trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
|
646 |
+
@override
|
647 |
+
def _query(
|
648 |
+
self,
|
649 |
+
collection_id: UUID,
|
650 |
+
query_embeddings: Embeddings,
|
651 |
+
n_results: int = 10,
|
652 |
+
where: Where = {},
|
653 |
+
where_document: WhereDocument = {},
|
654 |
+
        include: Include = ["documents", "metadatas", "distances"],
    ) -> QueryResult:
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "n_results": n_results,
                "where": str(where),
            }
        )
        where = validate_where(where) if where is not None and len(where) > 0 else where
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else where_document
        )

        allowed_ids = None

        coll = self._get_collection(collection_id)
        for embedding in query_embeddings:
            self._validate_dimension(coll, len(embedding), update=False)

        metadata_reader = self._manager.get_segment(collection_id, MetadataReader)

        if where or where_document:
            records = metadata_reader.get_metadata(
                where=where, where_document=where_document
            )
            allowed_ids = [r["id"] for r in records]

        query = t.VectorQuery(
            vectors=query_embeddings,
            k=n_results,
            allowed_ids=allowed_ids,
            include_embeddings="embeddings" in include,
            options=None,
        )

        vector_reader = self._manager.get_segment(collection_id, VectorReader)
        results = vector_reader.query_vectors(query)

        ids: List[List[str]] = []
        distances: List[List[float]] = []
        embeddings: List[List[Embedding]] = []
        documents: List[List[Document]] = []
        uris: List[List[URI]] = []
        metadatas: List[List[t.Metadata]] = []

        for result in results:
            ids.append([r["id"] for r in result])
            if "distances" in include:
                distances.append([r["distance"] for r in result])
            if "embeddings" in include:
                embeddings.append([cast(Embedding, r["embedding"]) for r in result])

        if "documents" in include or "metadatas" in include or "uris" in include:
            all_ids: Set[str] = set()
            for id_list in ids:
                all_ids.update(id_list)
            records = metadata_reader.get_metadata(ids=list(all_ids))
            metadata_by_id = {r["id"]: r["metadata"] for r in records}
            for id_list in ids:
                # In the segment based architecture, it is possible for one segment
                # to have a record that another segment does not have. This results in
                # data inconsistency. For the case of the local segments and the
                # local segment manager, there is a case where a thread writes
                # a record to the vector segment but not the metadata segment.
                # Then a querying thread reads from the vector segment and
                # queries the metadata segment. The metadata segment does not have
                # the record. In this case we choose to return potentially
                # incorrect data in the form of None.
                metadata_list = [metadata_by_id.get(id, None) for id in id_list]
                if "metadatas" in include:
                    metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
                if "documents" in include:
                    doc_list = [_doc(m) for m in metadata_list]
                    documents.append(doc_list)  # type: ignore
                if "uris" in include:
                    uri_list = [_uri(m) for m in metadata_list]
                    uris.append(uri_list)  # type: ignore

        query_amount = len(query_embeddings)
        self._product_telemetry_client.capture(
            CollectionQueryEvent(
                collection_uuid=str(collection_id),
                query_amount=query_amount,
                n_results=n_results,
                with_metadata_filter=query_amount if where is not None else 0,
                with_document_filter=query_amount if where_document is not None else 0,
                include_metadatas=query_amount if "metadatas" in include else 0,
                include_documents=query_amount if "documents" in include else 0,
                include_uris=query_amount if "uris" in include else 0,
                include_distances=query_amount if "distances" in include else 0,
            )
        )

        return QueryResult(
            ids=ids,
            distances=distances if distances else None,
            metadatas=metadatas if metadatas else None,
            embeddings=embeddings if embeddings else None,
            documents=documents if documents else None,
            uris=uris if uris else None,
            data=None,
        )
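Note: a minimal sketch (not part of this file) of the QueryResult shape the method above produces; the IDs, distances and metadata values are hypothetical, and fields missing from `include` come back as None:

# Hypothetical illustration of the QueryResult returned by SegmentAPI._query
# with the default include of ["documents", "metadatas", "distances"].
result = {
    "ids": [["id1", "id2"]],             # one inner list per query embedding
    "distances": [[0.12, 0.57]],         # present: "distances" is in include
    "metadatas": [[{"page": 1}, None]],  # None when the metadata segment has no record
    "documents": [["doc one", "doc two"]],
    "embeddings": None,                  # absent: "embeddings" not in include
    "uris": None,
    "data": None,
}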
    @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
    @override
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        return self._get(collection_id, limit=n)  # type: ignore

    @override
    def get_version(self) -> str:
        return __version__

    @override
    def reset_state(self) -> None:
        self._collection_cache = {}

    @override
    def reset(self) -> bool:
        self._system.reset_state()
        return True

    @override
    def get_settings(self) -> Settings:
        return self._settings

    @property
    @override
    def max_batch_size(self) -> int:
        return self._producer.max_batch_size

    # TODO: This could potentially cause race conditions in a distributed version of the
    # system, since the cache is only local.
    # TODO: promote collection -> topic to a base class method so that it can be
    # used for channel assignment in the distributed version of the system.
    @trace_method("SegmentAPI._validate_embedding_record", OpenTelemetryGranularity.ALL)
    def _validate_embedding_record(
        self, collection: t.Collection, record: t.SubmitEmbeddingRecord
    ) -> None:
        """Validate the dimension of an embedding record before submitting it to the system."""
        add_attributes_to_current_span({"collection_id": str(collection["id"])})
        if record["embedding"]:
            self._validate_dimension(collection, len(record["embedding"]), update=True)

    @trace_method("SegmentAPI._validate_dimension", OpenTelemetryGranularity.ALL)
    def _validate_dimension(
        self, collection: t.Collection, dim: int, update: bool
    ) -> None:
        """Validate that a collection supports records of the given dimension. If update
        is true, update the collection if the collection doesn't already have a
        dimension."""
        if collection["dimension"] is None:
            if update:
                id = collection["id"]
                self._sysdb.update_collection(id=id, dimension=dim)
                self._collection_cache[id]["dimension"] = dim
        elif collection["dimension"] != dim:
            raise InvalidDimensionException(
                f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
            )
        else:
            return  # all is well

    @trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL)
    def _get_collection(self, collection_id: UUID) -> t.Collection:
        """Read-through cache for collection data"""
        if collection_id not in self._collection_cache:
            collections = self._sysdb.get_collections(id=collection_id)
            if not collections:
                raise InvalidCollectionException(
                    f"Collection {collection_id} does not exist."
                )
            self._collection_cache[collection_id] = collections[0]
        return self._collection_cache[collection_id]
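Note: a self-contained sketch of the read-through cache pattern `_get_collection` uses; `FakeSysDB` and `get_collection` here are hypothetical stand-ins, not the real sysdb API:

from typing import Dict
from uuid import UUID

class FakeSysDB:
    """Hypothetical stand-in: returns a list of matching collection dicts."""
    def __init__(self) -> None:
        self._rows: Dict[UUID, dict] = {}

    def get_collections(self, id: UUID) -> list:
        return [self._rows[id]] if id in self._rows else []

_cache: Dict[UUID, dict] = {}

def get_collection(sysdb: FakeSysDB, collection_id: UUID) -> dict:
    # Read-through: consult the cache first, fall back to the system DB,
    # then populate the cache so later lookups skip the DB entirely.
    if collection_id not in _cache:
        rows = sysdb.get_collections(id=collection_id)
        if not rows:
            raise KeyError(f"Collection {collection_id} does not exist.")
        _cache[collection_id] = rows[0]
    return _cache[collection_id]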

def _records(
    operation: t.Operation,
    ids: IDs,
    collection_id: UUID,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> Generator[t.SubmitEmbeddingRecord, None, None]:
    """Convert parallel lists of embeddings, metadatas and documents to a sequence of
    SubmitEmbeddingRecords"""

    # Presumes that callers were invoked via Collection model, which means
    # that we know that the embeddings, metadatas and documents have already been
    # normalized and are guaranteed to be consistently named lists.

    for i, id in enumerate(ids):
        metadata = None
        if metadatas:
            metadata = metadatas[i]

        if documents:
            document = documents[i]
            if metadata:
                metadata = {**metadata, "chroma:document": document}
            else:
                metadata = {"chroma:document": document}

        if uris:
            uri = uris[i]
            if metadata:
                metadata = {**metadata, "chroma:uri": uri}
            else:
                metadata = {"chroma:uri": uri}

        record = t.SubmitEmbeddingRecord(
            id=id,
            embedding=embeddings[i] if embeddings else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=metadata,
            operation=operation,
            collection_id=collection_id,
        )
        yield record


def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Retrieve the document (if any) from a Metadata map"""

    if metadata and "chroma:document" in metadata:
        return str(metadata["chroma:document"])
    return None


def _uri(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Retrieve the uri (if any) from a Metadata map"""

    if metadata and "chroma:uri" in metadata:
        return str(metadata["chroma:uri"])
    return None


def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Remove any chroma-specific metadata keys that the client shouldn't see from a
    list of metadata maps."""
    return [_clean_metadata(m) for m in metadata]


def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
    """Remove any chroma-specific metadata keys that the client shouldn't see from a
    metadata map."""
    if not metadata:
        return None
    result = {}
    for k, v in metadata.items():
        if not k.startswith("chroma:"):
            result[k] = v
    if len(result) == 0:
        return None
    return result
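Note: a hypothetical round-trip showing how the reserved `chroma:` keys flow through these helpers: `_records` folds documents and URIs into metadata on the way in, while `_doc`, `_uri` and `_clean_metadata` split them back apart on the way out:

# Hypothetical round-trip for the reserved "chroma:" metadata keys,
# using the same key conventions as _records, _doc and _uri above.
stored = {"page": 1, "chroma:document": "hello world", "chroma:uri": "s3://bucket/a"}

# _doc / _uri pull the reserved keys back out:
document = str(stored["chroma:document"])  # -> "hello world"
uri = str(stored["chroma:uri"])            # -> "s3://bucket/a"

# _clean_metadata strips every "chroma:"-prefixed key before the client sees it:
client_metadata = {k: v for k, v in stored.items() if not k.startswith("chroma:")}
assert client_metadata == {"page": 1}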
chromadb/api/types.py
ADDED
@@ -0,0 +1,509 @@
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
    Metadata,
    UpdateMetadata,
    Vector,
    LiteralValue,
    LogicalOperator,
    WhereOperator,
    OperatorExpression,
    Where,
    WhereDocumentOperator,
    WhereDocument,
)
from inspect import signature
from tenacity import retry

# Re-export types from chromadb.types
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]

T = TypeVar("T")
OneOrMany = Union[T, List[T]]

# URIs
URI = str
URIs = List[URI]


def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
    if isinstance(target, str):
        # One URI
        return cast(URIs, [target])
    # Already a sequence
    return cast(URIs, target)


# IDs
ID = str
IDs = List[ID]


def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
    if isinstance(target, str):
        # One ID
        return cast(IDs, [target])
    # Already a sequence
    return cast(IDs, target)


# Embeddings
Embedding = Vector
Embeddings = List[Embedding]


def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
    if isinstance(target, List):
        # One Embedding (guard against an empty list before inspecting the first element)
        if len(target) != 0 and isinstance(target[0], (int, float)):
            return cast(Embeddings, [target])
    # Already a sequence
    return cast(Embeddings, target)


# Metadatas
Metadatas = List[Metadata]


def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
    # One Metadata dict
    if isinstance(target, dict):
        return cast(Metadatas, [target])
    # Already a sequence
    return cast(Metadatas, target)


CollectionMetadata = Dict[str, Any]
UpdateCollectionMetadata = UpdateMetadata

# Documents
Document = str
Documents = List[Document]


def is_document(target: Any) -> bool:
    if not isinstance(target, str):
        return False
    return True


def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
    # One Document
    if is_document(target):
        return cast(Documents, [target])
    # Already a sequence
    return cast(Documents, target)
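Note: a few hypothetical calls (assuming the definitions above) showing how the maybe_cast_one_to_many_* helpers normalize one-or-many inputs to lists:

# Each helper accepts either one item or a list and always returns a list.
assert maybe_cast_one_to_many_ids("id1") == ["id1"]
assert maybe_cast_one_to_many_ids(["id1", "id2"]) == ["id1", "id2"]
assert maybe_cast_one_to_many_embedding([0.1, 0.2]) == [[0.1, 0.2]]        # one embedding
assert maybe_cast_one_to_many_embedding([[0.1], [0.2]]) == [[0.1], [0.2]]  # already many
assert maybe_cast_one_to_many_metadata({"k": "v"}) == [{"k": "v"}]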

# Images
ImageDType = Union[np.uint, np.int_, np.float_]
Image = NDArray[ImageDType]
Images = List[Image]


def is_image(target: Any) -> bool:
    if not isinstance(target, np.ndarray):
        return False
    if len(target.shape) < 2:
        return False
    return True


def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
    if is_image(target):
        return cast(Images, [target])
    # Already a sequence
    return cast(Images, target)


Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)

# This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
# However, this provokes an incompatibility with the Overrides library and Python 3.7
Include = List[
    Union[
        Literal["documents"],
        Literal["embeddings"],
        Literal["metadatas"],
        Literal["distances"],
        Literal["uris"],
        Literal["data"],
    ]
]

# Re-export types from chromadb.types
LiteralValue = LiteralValue
LogicalOperator = LogicalOperator
WhereOperator = WhereOperator
OperatorExpression = OperatorExpression
Where = Where
WhereDocumentOperator = WhereDocumentOperator

Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)


Loadable = List[Optional[Image]]
L = TypeVar("L", covariant=True, bound=Loadable)


class GetResult(TypedDict):
    ids: List[ID]
    embeddings: Optional[List[Embedding]]
    documents: Optional[List[Document]]
    uris: Optional[URIs]
    data: Optional[Loadable]
    metadatas: Optional[List[Metadata]]


class QueryResult(TypedDict):
    ids: List[IDs]
    embeddings: Optional[List[List[Embedding]]]
    documents: Optional[List[List[Document]]]
    uris: Optional[List[List[URI]]]
    data: Optional[List[Loadable]]
    metadatas: Optional[List[List[Metadata]]]
    distances: Optional[List[List[float]]]


class IndexMetadata(TypedDict):
    dimensionality: int
    # The current number of elements in the index (total = additions - deletes)
    curr_elements: int
    # The auto-incrementing ID of the last inserted element, never decreases so
    # can be used as a count of total historical size. Should increase by 1 every add.
    # Assume cannot overflow
    total_elements_added: int
    time_created: float


class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D) -> Embeddings:
        ...

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        # Raise an exception if __call__ is not defined since it is expected to be defined
        call = getattr(cls, "__call__")

        def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = call(self, input)
            return validate_embeddings(maybe_cast_one_to_many_embedding(result))

        setattr(cls, "__call__", __call__)

    def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
        return retry(**retry_kwargs)(self.__call__)(input)


def validate_embedding_function(
    embedding_function: EmbeddingFunction[Embeddable],
) -> None:
    function_signature = signature(
        embedding_function.__class__.__call__
    ).parameters.keys()
    protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys()

    if not function_signature == protocol_signature:
        raise ValueError(
            f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n"
            "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
            "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
        )


class DataLoader(Protocol[L]):
    def __call__(self, uris: URIs) -> L:
        ...
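Note: a minimal sketch of a conforming embedding function; the placeholder vectors are hypothetical, not a real model. Subclassing triggers __init_subclass__ above, which wraps __call__ so its output is validated:

class StubEmbeddingFunction(EmbeddingFunction[Documents]):
    """Hypothetical embedding function returning fixed-size placeholder vectors."""

    def __call__(self, input: Documents) -> Embeddings:
        # The __init_subclass__ hook wraps this method, so the result passes
        # through maybe_cast_one_to_many_embedding and validate_embeddings.
        return [[float(len(doc)), 0.0] for doc in input]

# Passes: the subclass keeps the (self, input) signature the protocol requires.
validate_embedding_function(StubEmbeddingFunction())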
def validate_ids(ids: IDs) -> IDs:
    """Validates ids to ensure it is a list of strings"""
    if not isinstance(ids, list):
        raise ValueError(f"Expected IDs to be a list, got {ids}")
    if len(ids) == 0:
        raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")
    seen = set()
    dups = set()
    for id_ in ids:
        if not isinstance(id_, str):
            raise ValueError(f"Expected ID to be a str, got {id_}")
        if id_ in seen:
            dups.add(id_)
        else:
            seen.add(id_)
    if dups:
        n_dups = len(dups)
        if n_dups < 10:
            example_string = ", ".join(dups)
            message = (
                f"Expected IDs to be unique, found duplicates of: {example_string}"
            )
        else:
            examples = []
            for idx, dup in enumerate(dups):
                examples.append(dup)
                if idx == 10:
                    break
            example_string = (
                f"{', '.join(examples[:5])}, ..., {', '.join(examples[-5:])}"
            )
            message = f"Expected IDs to be unique, found {n_dups} duplicated IDs: {example_string}"
        raise errors.DuplicateIDError(message)
    return ids


def validate_metadata(metadata: Metadata) -> Metadata:
    """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools"""
    if not isinstance(metadata, dict) and metadata is not None:
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if metadata is None:
        return metadata
    if len(metadata) == 0:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
    for key, value in metadata.items():
        if not isinstance(key, str):
            raise TypeError(
                f"Expected metadata key to be a str, got {key} which is a {type(key)}"
            )
        # isinstance(True, int) evaluates to True, so we need to check for bools separately
        if not isinstance(value, bool) and not isinstance(value, (str, int, float)):
            raise ValueError(
                f"Expected metadata value to be a str, int, float or bool, got {value} which is a {type(value)}"
            )
    return metadata


def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
    """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats, bools or None"""
    if not isinstance(metadata, dict) and metadata is not None:
        raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
    if metadata is None:
        return metadata
    if len(metadata) == 0:
        raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
    for key, value in metadata.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected metadata key to be a str, got {key}")
        # isinstance(True, int) evaluates to True, so we need to check for bools separately
        if not isinstance(value, bool) and not isinstance(
            value, (str, int, float, type(None))
        ):
            raise ValueError(
                f"Expected metadata value to be a str, int, float, bool or None, got {value}"
            )
    return metadata


def validate_metadatas(metadatas: Metadatas) -> Metadatas:
    """Validates metadatas to ensure it is a list of dictionaries of strings to strings, ints, floats or bools"""
    if not isinstance(metadatas, list):
        raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
    for metadata in metadatas:
        validate_metadata(metadata)
    return metadatas
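Note: a hypothetical invocation showing the duplicate check; validate_ids returns the list unchanged when IDs are unique and raises DuplicateIDError naming the offenders otherwise:

validate_ids(["a", "b"])  # returns ["a", "b"] unchanged
try:
    validate_ids(["a", "a", "b"])
except errors.DuplicateIDError as e:
    print(e)  # Expected IDs to be unique, found duplicates of: a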
def validate_where(where: Where) -> Where:
    """
    Validates where to ensure it is a dictionary of strings to strings, ints, floats or operator expressions,
    or in the case of $and and $or, a list of where expressions
    """
    if not isinstance(where, dict):
        raise ValueError(f"Expected where to be a dict, got {where}")
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")
    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")
        if (
            key != "$and"
            and key != "$or"
            and key != "$in"
            and key != "$nin"
            and not isinstance(value, (str, int, float, dict))
        ):
            raise ValueError(
                f"Expected where value to be a str, int, float, or operator expression, got {value}"
            )
        if key == "$and" or key == "$or":
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {value}"
                )
            if len(value) <= 1:
                raise ValueError(
                    f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}"
                )
            for where_expression in value:
                validate_where(where_expression)
        # Value is an operator expression
        if isinstance(value, dict):
            # Ensure there is only one operator
            if len(value) != 1:
                raise ValueError(
                    f"Expected operator expression to have exactly one operator, got {value}"
                )

            for operator, operand in value.items():
                # Only numbers can be compared with gt, gte, lt, lte
                if operator in ["$gt", "$gte", "$lt", "$lte"]:
                    if not isinstance(operand, (int, float)):
                        raise ValueError(
                            f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
                        )
                if operator in ["$in", "$nin"]:
                    if not isinstance(operand, list):
                        raise ValueError(
                            f"Expected operand value to be a list for operator {operator}, got {operand}"
                        )
                if operator not in [
                    "$gt",
                    "$gte",
                    "$lt",
                    "$lte",
                    "$ne",
                    "$eq",
                    "$in",
                    "$nin",
                ]:
                    raise ValueError(
                        f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
                        f"got {operator}"
                    )

                if not isinstance(operand, (str, int, float, list)):
                    raise ValueError(
                        f"Expected where operand value to be a str, int, float, or list of those types, got {operand}"
                    )
                if isinstance(operand, list) and (
                    len(operand) == 0
                    or not all(isinstance(x, type(operand[0])) for x in operand)
                ):
                    raise ValueError(
                        f"Expected where operand value to be a non-empty list, and all values to be of the same type, "
                        f"got {operand}"
                    )
    return where
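Note: hypothetical filters illustrating what validate_where accepts and rejects, assuming the definition above:

validate_where({"color": "red"})         # direct equality
validate_where({"price": {"$gte": 10}})  # operator expression
validate_where({"$and": [{"color": "red"}, {"size": {"$in": ["S", "M"]}}]})

try:
    validate_where({"color": "red", "size": "M"})  # two top-level keys
except ValueError as e:
    print(e)  # Expected where to have exactly one operator, ...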
def validate_where_document(where_document: WhereDocument) -> WhereDocument:
    """
    Validates where_document to ensure it is a dictionary of WhereDocumentOperator to strings, or in the case of $and and $or,
    a list of where_document expressions
    """
    if not isinstance(where_document, dict):
        raise ValueError(
            f"Expected where document to be a dictionary, got {where_document}"
        )
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where document to have exactly one operator, got {where_document}"
        )
    for operator, operand in where_document.items():
        if operator not in ["$contains", "$not_contains", "$and", "$or"]:
            raise ValueError(
                f"Expected where document operator to be one of $contains, $not_contains, $and, $or, got {operator}"
            )
        if operator == "$and" or operator == "$or":
            if not isinstance(operand, list):
                raise ValueError(
                    f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
                )
            if len(operand) <= 1:
                raise ValueError(
                    f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
                )
            for where_document_expression in operand:
                validate_where_document(where_document_expression)
        # Value is a $contains operator
        elif not isinstance(operand, str):
            raise ValueError(
                f"Expected where document operand value for operator $contains to be a str, got {operand}"
            )
        elif len(operand) == 0:
            raise ValueError(
                "Expected where document operand value for operator $contains to be a non-empty str"
            )
    return where_document


def validate_include(include: Include, allow_distances: bool) -> Include:
    """Validates include to ensure it is a list of strings. Since get does not allow distances, allow_distances is used
    to control if distances is allowed"""

    if not isinstance(include, list):
        raise ValueError(f"Expected include to be a list, got {include}")
    for item in include:
        if not isinstance(item, str):
            raise ValueError(f"Expected include item to be a str, got {item}")
        allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
        if allow_distances:
            allowed_values.append("distances")
        if item not in allowed_values:
            raise ValueError(
                f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
            )
    return include


def validate_n_results(n_results: int) -> int:
    """Validates n_results to ensure it is a positive integer, since hnswlib does not allow n_results to be negative."""
    # Check number of requested results
    if not isinstance(n_results, int):
        raise ValueError(
            f"Expected requested number of results to be an int, got {n_results}"
        )
    if n_results <= 0:
        raise TypeError(
            f"Number of requested results {n_results} cannot be negative or zero."
        )
    return n_results


def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validates embeddings to ensure it is a list of lists of ints or floats"""
    if not isinstance(embeddings, list):
        raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {embeddings}"
        )
    if not all([isinstance(e, list) for e in embeddings]):
        raise ValueError(
            f"Expected each embedding in the embeddings to be a list, got {embeddings}"
        )
    for i, embedding in enumerate(embeddings):
        if len(embedding) == 0:
            raise ValueError(
                f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {i}"
            )
        if not all(
            [
                isinstance(value, (int, float)) and not isinstance(value, bool)
                for value in embedding
            ]
        ):
            raise ValueError(
                f"Expected each value in the embedding to be an int or float, got {embeddings}"
            )
    return embeddings


def validate_batch(
    batch: Tuple[
        IDs,
        Optional[Embeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    limits: Dict[str, Any],
) -> None:
    if len(batch[0]) > limits["max_batch_size"]:
        raise ValueError(
            f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}"
        )
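Note: a hypothetical use of the batch guard; callers pass the parallel record lists plus the producer's limits dict:

ids = ["id1", "id2"]
embeddings = [[0.1, 0.2], [0.3, 0.4]]
batch = (ids, embeddings, None, None, None)

validate_batch(batch, {"max_batch_size": 100})  # ok: 2 <= 100
try:
    validate_batch(batch, {"max_batch_size": 1})
except ValueError as e:
    print(e)  # Batch size 2 exceeds maximum batch size 1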