badalsahani committed

Commit 287a0bc · Parent: fbbc97b

feat: chroma initial deploy

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .dockerignore +10 -0
  2. .gitattributes +1 -35
  3. .github/ISSUE_TEMPLATE/bug_report.yaml +43 -0
  4. .github/ISSUE_TEMPLATE/config.yml +5 -0
  5. .github/ISSUE_TEMPLATE/feature_request.yaml +46 -0
  6. .github/ISSUE_TEMPLATE/installation_trouble.yaml +41 -0
  7. .github/actions/bandit-scan/Dockerfile +7 -0
  8. .github/actions/bandit-scan/action.yaml +26 -0
  9. .github/actions/bandit-scan/entrypoint.sh +13 -0
  10. .github/workflows/chroma-client-integration-test.yml +31 -0
  11. .github/workflows/chroma-cluster-test.yml +42 -0
  12. .github/workflows/chroma-coordinator-test.yaml +23 -0
  13. .github/workflows/chroma-integration-test.yml +40 -0
  14. .github/workflows/chroma-js-release.yml +42 -0
  15. .github/workflows/chroma-release-python-client.yml +58 -0
  16. .github/workflows/chroma-release.yml +179 -0
  17. .github/workflows/chroma-test.yml +65 -0
  18. .github/workflows/chroma-worker-test.yml +36 -0
  19. .github/workflows/pr-review-checklist.yml +37 -0
  20. .github/workflows/python-vuln.yaml +28 -0
  21. .gitignore +34 -0
  22. .pre-commit-config.yaml +36 -0
  23. .vscode/settings.json +131 -0
  24. Cargo.lock +0 -0
  25. Cargo.toml +5 -0
  26. DEVELOP.md +111 -0
  27. Dockerfile +39 -0
  28. LICENSE +201 -0
  29. README.md +106 -11
  30. RELEASE_PROCESS.md +22 -0
  31. Tiltfile +30 -0
  32. bandit.yaml +4 -0
  33. bin/cluster-test.sh +62 -0
  34. bin/docker_entrypoint.sh +15 -0
  35. bin/generate_cloudformation.py +198 -0
  36. bin/integration-test +75 -0
  37. bin/reset.sh +13 -0
  38. bin/templates/docker-compose.yml +21 -0
  39. bin/test-package.sh +24 -0
  40. bin/test-remote +16 -0
  41. bin/test.py +7 -0
  42. bin/version +8 -0
  43. bin/windows_upgrade_sqlite.py +20 -0
  44. chromadb/__init__.py +257 -0
  45. chromadb/api/__init__.py +596 -0
  46. chromadb/api/client.py +496 -0
  47. chromadb/api/fastapi.py +654 -0
  48. chromadb/api/models/Collection.py +633 -0
  49. chromadb/api/segment.py +914 -0
  50. chromadb/api/types.py +509 -0
.dockerignore ADDED
@@ -0,0 +1,10 @@
+ venv
+ .conda
+ .git
+ examples
+ clients
+ .hypothesis
+ __pycache__
+ .vscode
+ *.egg-info
+ .pytest_cache
.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *_pb2.py* linguist-generated
.github/ISSUE_TEMPLATE/bug_report.yaml ADDED
@@ -0,0 +1,43 @@
+ name: 🐛 Bug Report
+ description: File a bug report to help us improve Chroma
+ title: "[Bug]: "
+ labels: ["bug", "triage"]
+ # assignees:
+ #   - octocat
+ body:
+   - type: markdown
+     attributes:
+       value: |
+         Thanks for taking the time to fill out this bug report!
+   - type: textarea
+     id: what-happened
+     attributes:
+       label: What happened?
+       description: Also tell us, what did you expect to happen?
+       placeholder: Tell us what you see!
+       # value: "A bug happened!"
+     validations:
+       required: true
+   - type: textarea
+     id: versions
+     attributes:
+       label: Versions
+       description: Your Chroma, Python, and OS versions, as well as whatever else you think relevant. Check that you have [the latest Chroma](https://github.com/chroma-core/chroma/pkgs/container/chroma), as we are a fast-moving pre-v1.0 project.
+       placeholder: Chroma v0.3.22, Python 3.9.6, MacOS 12.5
+       # value: "A bug happened!"
+     validations:
+       required: true
+   - type: textarea
+     id: logs
+     attributes:
+       label: Relevant log output
+       description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
+       render: shell
+   # - type: checkboxes
+   #   id: terms
+   #   attributes:
+   #     label: Code of Conduct
+   #     description: By submitting this issue, you agree to follow our [Code of Conduct](https://example.com)
+   #     options:
+   #       - label: I agree to follow this project's Code of Conduct
+   #         required: true
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,5 @@
+ blank_issues_enabled: true
+ contact_links:
+   - name: 🤷🏻‍♀️ Questions
+     url: https://discord.com/invite/MMeYNTmh3x
+     about: Interact with the Chroma community here by asking for help, discussing and more!
.github/ISSUE_TEMPLATE/feature_request.yaml ADDED
@@ -0,0 +1,46 @@
+ name: 🚀 Feature request
+ description: Suggest an idea for Chroma
+ title: "[Feature Request]: "
+ labels: ["enhancement"]
+ body:
+   - type: markdown
+     attributes:
+       value: |
+         Thanks for taking the time to request this feature!
+   - type: textarea
+     id: problem
+     attributes:
+       label: Describe the problem
+       description: Please provide a clear and concise description of the problem this feature would solve. The more information you can provide here, the better.
+       placeholder: I prefer if...
+     validations:
+       required: true
+   - type: textarea
+     id: solution
+     attributes:
+       label: Describe the proposed solution
+       description: Please provide a clear and concise description of what you would like to happen.
+       placeholder: I would like to see...
+     validations:
+       required: true
+   - type: textarea
+     id: alternatives
+     attributes:
+       label: Alternatives considered
+       description: "Please provide a clear and concise description of any alternative solutions or features you've considered."
+   - type: dropdown
+     id: importance
+     attributes:
+       label: Importance
+       description: How important is this feature to you?
+       options:
+         - nice to have
+         - would make my life easier
+         - i cannot use Chroma without it
+     validations:
+       required: true
+   - type: textarea
+     id: additional-context
+     attributes:
+       label: Additional Information
+       description: Add any other context or screenshots about the feature request here.
.github/ISSUE_TEMPLATE/installation_trouble.yaml ADDED
@@ -0,0 +1,41 @@
+ name: Installation Issue
+ description: Request for install help with Chroma
+ title: "[Install issue]: "
+ labels: ["installation trouble"]
+ body:
+   - type: markdown
+     attributes:
+       value: |
+         Thanks for taking the time to fill out this issue report!
+   - type: textarea
+     id: what-happened
+     attributes:
+       label: What happened?
+       description: Also tell us, what did you expect to happen?
+       placeholder: Tell us what you see!
+       # value: "A bug happened!"
+     validations:
+       required: true
+   - type: textarea
+     id: versions
+     attributes:
+       label: Versions
+       description: We need your Chroma, Python, and OS versions, as well as whatever else you think relevant.
+       placeholder: Chroma v0.3.14, Python 3.9.6, MacOS 12.5
+       # value: "A bug happened!"
+     validations:
+       required: true
+   - type: textarea
+     id: logs
+     attributes:
+       label: Relevant log output
+       description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
+       render: shell
+   # - type: checkboxes
+   #   id: terms
+   #   attributes:
+   #     label: Code of Conduct
+   #     description: By submitting this issue, you agree to follow our [Code of Conduct](https://example.com)
+   #     options:
+   #       - label: I agree to follow this project's Code of Conduct
+   #         required: true
.github/actions/bandit-scan/Dockerfile ADDED
@@ -0,0 +1,7 @@
+ FROM python:3.10-alpine AS base-action
+
+ RUN pip3 install -U setuptools pip bandit
+
+ COPY entrypoint.sh /entrypoint.sh
+ RUN chmod +x /entrypoint.sh
+ ENTRYPOINT ["sh","/entrypoint.sh"]
.github/actions/bandit-scan/action.yaml ADDED
@@ -0,0 +1,26 @@
+ name: 'Bandit Scan'
+ description: 'This action performs a security vulnerability scan of Python code using the bandit library.'
+ inputs:
+   bandit-config:
+     description: 'Bandit configuration file'
+     required: false
+   input-dir:
+     description: 'Directory to scan'
+     required: false
+     default: '.'
+   format:
+     description: 'Output format (txt, csv, json, xml, yaml). Default: json'
+     required: false
+     default: 'json'
+   output-file:
+     description: "The report file to produce. Make sure to align your format with the file extension to avoid confusion."
+     required: false
+     default: "bandit-scan.json"
+ runs:
+   using: 'docker'
+   image: 'Dockerfile'
+   args:
+     - ${{ inputs.format }}
+     - ${{ inputs.bandit-config }}
+     - ${{ inputs.input-dir }}
+     - ${{ inputs.output-file }}
.github/actions/bandit-scan/entrypoint.sh ADDED
@@ -0,0 +1,13 @@
+ #!/bin/bash
+ CFG="-c $2"
+ if [ -z "$1" ]; then
+   echo "No path to scan provided"
+   exit 1
+ fi
+
+ if [ -z "$2" ]; then
+   CFG=""
+ fi
+
+ bandit -f "$1" ${CFG} -r "$3" -o "$4"
+ exit 0 # we want to ignore the exit code of bandit (for now)
.github/workflows/chroma-client-integration-test.yml ADDED
@@ -0,0 +1,31 @@
+ name: Chroma Client Integration Tests
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+       - '**'
+   workflow_dispatch:
+
+ jobs:
+   test:
+     timeout-minutes: 90
+     strategy:
+       matrix:
+         python: ['3.8', '3.9', '3.10', '3.11']
+         platform: [ubuntu-latest, windows-latest]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - name: Set up Python ${{ matrix.python }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python }}
+       - name: Install test dependencies
+         run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+       - name: Test
+         run: clients/python/integration-test.sh
.github/workflows/chroma-cluster-test.yml ADDED
@@ -0,0 +1,42 @@
+ name: Chroma Cluster Tests
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+       - '**'
+   workflow_dispatch:
+
+ jobs:
+   test:
+     strategy:
+       matrix:
+         python: ['3.8']
+         platform: ['16core-64gb-ubuntu-latest']
+         testfile: ["chromadb/test/ingest/test_producer_consumer.py",
+                    "chromadb/test/db/test_system.py",
+                    "chromadb/test/segment/distributed/test_memberlist_provider.py"]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - name: Set up Python ${{ matrix.python }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python }}
+       - name: Install test dependencies
+         run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+       - name: Start minikube
+         id: minikube
+         uses: medyagh/setup-minikube@latest
+         with:
+           minikube-version: latest
+           kubernetes-version: latest
+           driver: docker
+           addons: ingress, ingress-dns
+           start-args: '--profile chroma-test'
+       - name: Integration Test
+         run: bin/cluster-test.sh ${{ matrix.testfile }}
.github/workflows/chroma-coordinator-test.yaml ADDED
@@ -0,0 +1,23 @@
+ name: Chroma Coordinator Tests
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+       - '**'
+   workflow_dispatch:
+
+ jobs:
+   test:
+     strategy:
+       matrix:
+         platform: [ubuntu-latest]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - name: Build and test
+         run: cd go/coordinator && make test
.github/workflows/chroma-integration-test.yml ADDED
@@ -0,0 +1,40 @@
+ name: Chroma Integration Tests
+
+ on:
+   push:
+     branches:
+       - main
+       - team/hypothesis-tests
+   pull_request:
+     branches:
+       - main
+       - '**'
+   workflow_dispatch:
+
+ jobs:
+   test:
+     strategy:
+       matrix:
+         python: ['3.8']
+         platform: [ubuntu-latest, windows-latest]
+         testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore='chromadb/test/test_cli.py' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'",
+                    "chromadb/test/property/test_add.py",
+                    "chromadb/test/test_cli.py",
+                    "chromadb/test/auth/test_simple_rbac_authz.py",
+                    "chromadb/test/property/test_collections.py",
+                    "chromadb/test/property/test_cross_version_persist.py",
+                    "chromadb/test/property/test_embeddings.py",
+                    "chromadb/test/property/test_filtering.py",
+                    "chromadb/test/property/test_persist.py"]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - name: Set up Python ${{ matrix.python }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python }}
+       - name: Install test dependencies
+         run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+       - name: Integration Test
+         run: bin/integration-test ${{ matrix.testfile }}
.github/workflows/chroma-js-release.yml ADDED
@@ -0,0 +1,42 @@
+ name: Chroma Release JS Client
+
+ on:
+   push:
+     tags:
+       - 'js_release_*.*.*' # Match tags in the form js_release_X.Y.Z
+       - 'js_release_alpha_*.*.*' # Match tags in the form js_release_alpha_X.Y.Z
+
+ jobs:
+   build-and-release:
+     runs-on: ubuntu-latest
+     permissions: write-all
+     steps:
+       - name: Check if tag matches the pattern
+         run: |
+           if [[ "${{ github.ref }}" =~ ^refs/tags/js_release_alpha_[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+             echo "Tag matches the pattern js_release_alpha_X.Y.Z"
+             echo "NPM_SCRIPT=release_alpha" >> "$GITHUB_ENV"
+           elif [[ "${{ github.ref }}" =~ ^refs/tags/js_release_[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+             echo "Tag matches the pattern js_release_X.Y.Z"
+             echo "NPM_SCRIPT=release" >> "$GITHUB_ENV"
+           else
+             echo "Tag does not match the release tag pattern, exiting workflow"
+             exit 1
+           fi
+       - name: Checkout
+         uses: actions/checkout@v3
+         with:
+           fetch-depth: 0
+       - name: Set up JS
+         uses: actions/setup-node@v3
+         with:
+           node-version: '16.x'
+           registry-url: 'https://registry.npmjs.org'
+       - name: Install Client Dev Dependencies
+         run: npm install
+         working-directory: ./clients/js/
+       - name: npm Test & Publish
+         run: npm run db:run && PORT=8001 npm run $NPM_SCRIPT
+         working-directory: ./clients/js/
+         env:
+           NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
.github/workflows/chroma-release-python-client.yml ADDED
@@ -0,0 +1,58 @@
+ name: Chroma Release Python Client
+
+ on:
+   push:
+     tags:
+       - '[0-9]+.[0-9]+.[0-9]+' # Match tags in the form X.Y.Z
+     branches:
+       - main
+       - hammad/thin_client
+
+ jobs:
+   check_tag:
+     runs-on: ubuntu-latest
+     outputs:
+       tag_matches: ${{ steps.check-tag.outputs.tag_matches }}
+     steps:
+       - name: Check Tag
+         id: check-tag
+         run: |
+           if [[ ${{ github.event.ref }} =~ ^refs/tags/[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+             echo "tag_matches=true" >> $GITHUB_OUTPUT
+           fi
+   build-and-release:
+     runs-on: ubuntu-latest
+     needs: check_tag
+     if: needs.check_tag.outputs.tag_matches == 'true'
+     permissions: write-all
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+         with:
+           fetch-depth: 0
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: '3.10'
+       - name: Install Client Dev Dependencies
+         run: python -m pip install -r ./clients/python/requirements.txt && python -m pip install -r ./clients/python/requirements_dev.txt
+       - name: Build Client
+         run: ./clients/python/build_python_thin_client.sh
+       - name: Install setuptools_scm
+         run: python -m pip install setuptools_scm
+       - name: Get Release Version
+         id: version
+         run: echo "version=$(python -m setuptools_scm)" >> $GITHUB_OUTPUT
+       - name: Get current date
+         id: builddate
+         run: echo "builddate=$(date +'%Y-%m-%dT%H:%M')" >> $GITHUB_OUTPUT
+       - name: Publish to Test PyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+         with:
+           password: ${{ secrets.TEST_PYPI_PYTHON_CLIENT_PUBLISH_KEY }}
+           repository-url: https://test.pypi.org/legacy/
+       - name: Publish to PyPI
+         if: startsWith(github.ref, 'refs/tags')
+         uses: pypa/gh-action-pypi-publish@release/v1
+         with:
+           password: ${{ secrets.PYPI_PYTHON_CLIENT_PUBLISH_KEY }}
.github/workflows/chroma-release.yml ADDED
@@ -0,0 +1,179 @@
+ name: Chroma Release
+
+ on:
+   push:
+     tags:
+       - "*"
+     branches:
+       - main
+
+ env:
+   GHCR_IMAGE_NAME: "ghcr.io/chroma-core/chroma"
+   DOCKERHUB_IMAGE_NAME: "chromadb/chroma"
+   PLATFORMS: linux/amd64,linux/arm64 # linux/riscv64, linux/arm/v7
+
+ jobs:
+   check_tag:
+     runs-on: ubuntu-latest
+     outputs:
+       tag_matches: ${{ steps.check-tag.outputs.tag_matches }}
+     steps:
+       - name: Check Tag
+         id: check-tag
+         run: |
+           if [[ ${{ github.event.ref }} =~ ^refs/tags/[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+             echo "tag_matches=true" >> $GITHUB_OUTPUT
+           fi
+   build-and-release:
+     runs-on: ubuntu-latest
+     needs: check_tag
+     permissions: write-all
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+         with:
+           fetch-depth: 0
+       # https://github.com/docker/setup-qemu-action - for multiplatform builds
+       - name: Set up QEMU
+         uses: docker/setup-qemu-action@v2
+       # https://github.com/docker/setup-buildx-action - for multiplatform builds
+       - name: Set up Docker Buildx
+         id: buildx
+         uses: docker/setup-buildx-action@v2
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.10"
+       - name: Install Client Dev Dependencies
+         run: python -m pip install -r requirements_dev.txt
+       - name: Build Client
+         run: python -m build
+       - name: Test Client Package
+         run: bin/test-package.sh dist/*.tar.gz
+       - name: Log in to the Github Container registry
+         uses: docker/login-action@v2.1.0
+         with:
+           registry: ghcr.io
+           username: ${{ github.actor }}
+           password: ${{ secrets.GITHUB_TOKEN }}
+       - name: Login to DockerHub
+         uses: docker/login-action@v2.1.0
+         with:
+           username: ${{ secrets.DOCKERHUB_USERNAME }}
+           password: ${{ secrets.DOCKERHUB_TOKEN }}
+       - name: Install setuptools_scm
+         run: python -m pip install setuptools_scm
+       - name: Get Release Version
+         id: version
+         run: echo "version=$(python -m setuptools_scm)" >> $GITHUB_OUTPUT
+       - name: Build and push prerelease Docker image
+         if: "needs.check_tag.outputs.tag_matches != 'true'"
+         uses: docker/build-push-action@v3.2.0
+         with:
+           context: .
+           platforms: ${{ env.PLATFORMS }}
+           push: true
+           tags: "${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}"
+       - name: Build and push release Docker image
+         if: "needs.check_tag.outputs.tag_matches == 'true'"
+         uses: docker/build-push-action@v3.2.0
+         with:
+           context: .
+           platforms: ${{ env.PLATFORMS }}
+           push: true
+           tags: "${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }},${{ env.GHCR_IMAGE_NAME }}:latest,${{ env.DOCKERHUB_IMAGE_NAME }}:latest"
+       - name: Get current date
+         id: builddate
+         run: echo "builddate=$(date +'%Y-%m-%dT%H:%M')" >> $GITHUB_OUTPUT
+       - name: Publish to Test PyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+         with:
+           password: ${{ secrets.TEST_PYPI_API_TOKEN }}
+           repository_url: https://test.pypi.org/legacy/
+       - name: Publish to PyPI
+         if: "needs.check_tag.outputs.tag_matches == 'true'"
+         uses: pypa/gh-action-pypi-publish@release/v1
+         with:
+           password: ${{ secrets.PYPI_API_TOKEN }}
+       - name: Login to AWS
+         uses: aws-actions/configure-aws-credentials@v1
+         with:
+           role-to-assume: arn:aws:iam::369178033109:role/github-action-generate-cf-template
+           aws-region: us-east-1
+       - name: Generate CloudFormation template
+         id: generate-cf
+         if: "needs.check_tag.outputs.tag_matches == 'true'"
+         run: "pip install boto3 && python bin/generate_cloudformation.py"
+       - name: Release Tagged Version
+         uses: ncipollo/release-action@v1.12.0
+         if: "needs.check_tag.outputs.tag_matches == 'true'"
+         with:
+           body: |
+             Version: `${{steps.version.outputs.version}}`
+             Git ref: `${{github.ref}}`
+             Build Date: `${{steps.builddate.outputs.builddate}}`
+             PIP Package: `chroma-${{steps.version.outputs.version}}.tar.gz`
+             Github Container Registry Image: `${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+             DockerHub Image: `${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+           artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz"
+           prerelease: true
+           generateReleaseNotes: true
+       - name: Update Tag
+         uses: richardsimko/update-tag@v1.0.5
+         if: "needs.check_tag.outputs.tag_matches != 'true'"
+         with:
+           tag_name: latest
+         env:
+           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+       - name: Release Latest
+         uses: ncipollo/release-action@v1.12.0
+         if: "needs.check_tag.outputs.tag_matches != 'true'"
+         with:
+           tag: "latest"
+           name: "Latest"
+           body: |
+             Version: `${{steps.version.outputs.version}}`
+             Git ref: `${{github.ref}}`
+             Build Date: `${{steps.builddate.outputs.builddate}}`
+             PIP Package: `chroma-${{steps.version.outputs.version}}.tar.gz`
+             Github Container Registry Image: `${{ env.GHCR_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+             DockerHub Image: `${{ env.DOCKERHUB_IMAGE_NAME }}:${{ steps.version.outputs.version }}`
+           artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz"
+           allowUpdates: true
+           prerelease: true
+       - name: Trigger Hosted Chroma FE Release
+         uses: actions/github-script@v6
+         with:
+           github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
+           script: |
+             const result = await github.rest.actions.createWorkflowDispatch({
+               owner: 'chroma-core',
+               repo: 'hosted-chroma',
+               workflow_id: 'build-and-publish-frontend.yaml',
+               ref: 'main'
+             })
+             console.log(result)
+       - name: Trigger Hosted Chroma Coordinator Release
+         uses: actions/github-script@v6
+         with:
+           github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
+           script: |
+             const result = await github.rest.actions.createWorkflowDispatch({
+               owner: 'chroma-core',
+               repo: 'hosted-chroma',
+               workflow_id: 'build-and-deploy-coordinator.yaml',
+               ref: 'main'
+             })
+             console.log(result)
+       - name: Trigger Hosted Worker Release
+         uses: actions/github-script@v6
+         with:
+           github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
+           script: |
+             const result = await github.rest.actions.createWorkflowDispatch({
+               owner: 'chroma-core',
+               repo: 'hosted-chroma',
+               workflow_id: 'build-and-deploy-worker.yaml',
+               ref: 'main'
+             })
+             console.log(result)
.github/workflows/chroma-test.yml ADDED
@@ -0,0 +1,65 @@
+ name: Chroma Tests
+
+ on:
+   push:
+     branches:
+       - main
+       - team/hypothesis-tests
+   pull_request:
+     branches:
+       - main
+       - '**'
+   workflow_dispatch:
+
+ jobs:
+   test:
+     timeout-minutes: 90
+     strategy:
+       matrix:
+         python: ['3.8', '3.9', '3.10', '3.11']
+         platform: [ubuntu-latest, windows-latest]
+         testfile: ["--ignore-glob 'chromadb/test/property/*' --ignore-glob 'chromadb/test/stress/*' --ignore='chromadb/test/auth/test_simple_rbac_authz.py'",
+                    "chromadb/test/auth/test_simple_rbac_authz.py",
+                    "chromadb/test/property/test_add.py",
+                    "chromadb/test/property/test_collections.py",
+                    "chromadb/test/property/test_cross_version_persist.py",
+                    "chromadb/test/property/test_embeddings.py",
+                    "chromadb/test/property/test_filtering.py",
+                    "chromadb/test/property/test_persist.py"]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - name: Set up Python ${{ matrix.python }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python }}
+       - name: Install test dependencies
+         run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+       - name: Upgrade SQLite
+         run: python bin/windows_upgrade_sqlite.py
+         if: runner.os == 'Windows'
+       - name: Test
+         run: python -m pytest ${{ matrix.testfile }}
+   stress-test:
+     timeout-minutes: 90
+     strategy:
+       matrix:
+         python: ['3.8']
+         platform: ['16core-64gb-ubuntu-latest', '16core-64gb-windows-latest']
+         testfile: ["'chromadb/test/stress/'"]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - name: Set up Python ${{ matrix.python }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python }}
+       - name: Install test dependencies
+         run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+       - name: Upgrade SQLite
+         run: python bin/windows_upgrade_sqlite.py
+         if: runner.os == 'Windows'
+       - name: Test
+         run: python -m pytest ${{ matrix.testfile }}
.github/workflows/chroma-worker-test.yml ADDED
@@ -0,0 +1,36 @@
+ name: Chroma Worker Tests
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+       - '**'
+   workflow_dispatch:
+
+ jobs:
+   test:
+     strategy:
+       matrix:
+         platform: [ubuntu-latest]
+     runs-on: ${{ matrix.platform }}
+     steps:
+       - name: Checkout chroma-hnswlib
+         uses: actions/checkout@v3
+         with:
+           repository: chroma-core/hnswlib
+           path: hnswlib
+       - name: Checkout
+         uses: actions/checkout@v3
+         with:
+           path: chroma
+       - name: Install Protoc
+         uses: arduino/setup-protoc@v2
+       - name: Build
+         run: cargo build --verbose
+         working-directory: chroma
+       - name: Test
+         run: cargo test --verbose
+         working-directory: chroma
.github/workflows/pr-review-checklist.yml ADDED
@@ -0,0 +1,37 @@
+ name: PR Review Checklist
+
+ on:
+   pull_request_target:
+     types:
+       - opened
+
+ jobs:
+   PR-Comment:
+     runs-on: ubuntu-latest
+     steps:
+       - name: PR Comment
+         uses: actions/github-script@v2
+         with:
+           github-token: ${{secrets.GITHUB_TOKEN}}
+           script: |
+             github.issues.createComment({
+               issue_number: ${{ github.event.number }},
+               owner: context.repo.owner,
+               repo: context.repo.repo,
+               body: `# Reviewer Checklist
+             Please leverage this checklist to ensure your code review is thorough before approving.
+             ## Testing, Bugs, Errors, Logs, Documentation
+             - [ ] Can you think of any use case in which the code does not behave as intended? Have they been tested?
+             - [ ] Can you think of any inputs or external events that could break the code? Is user input validated and safe? Have they been tested?
+             - [ ] If appropriate, are there adequate property based tests?
+             - [ ] If appropriate, are there adequate unit tests?
+             - [ ] Should any logging, debugging, tracing information be added or removed?
+             - [ ] Are error messages user-friendly?
+             - [ ] Have all documentation changes needed been made?
+             - [ ] Have all non-obvious changes been commented?
+             ## System Compatibility
+             - [ ] Are there any potential impacts on other parts of the system or backward compatibility?
+             - [ ] Does this change intersect with any items on our roadmap, and if so, is there a plan for fitting them together?
+             ## Quality
+             - [ ] Is this code of an unexpectedly high quality? (Readability, Modularity, Intuitiveness)`
+             })
.github/workflows/python-vuln.yaml ADDED
@@ -0,0 +1,28 @@
+ name: Python Vulnerability Scan
+ on:
+   push:
+     branches:
+       - '*'
+       - '*/**'
+     paths:
+       - chromadb/**
+       - clients/python/**
+   workflow_dispatch:
+ jobs:
+   bandit-scan:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+       - uses: ./.github/actions/bandit-scan/
+         with:
+           input-dir: '.'
+           format: 'json'
+           bandit-config: 'bandit.yaml'
+           output-file: 'bandit-report.json'
+       - name: Upload Bandit Report
+         uses: actions/upload-artifact@v3
+         with:
+           name: bandit-artifact
+           path: |
+             bandit-report.json
.gitignore ADDED
@@ -0,0 +1,34 @@
+ # ignore mac created DS_Store files
+ **/.DS_Store
+
+ **/__pycache__
+
+ go/coordinator/bin/
+ go/coordinator/**/testdata/
+
+ *.log
+
+ **/data__nogit
+
+ **/.ipynb_checkpoints
+
+ index_data
+
+ # Default configuration for persist_directory in chromadb/config.py
+ # Currently it's located in "./chroma/"
+ chroma/
+ chroma_test_data/
+ server.htpasswd
+
+ .venv
+ venv
+ .env
+ .chroma
+ *.egg-info
+ dist
+
+ .terraform/
+ .terraform.lock.hcl
+ terraform.tfstate
+ .hypothesis/
+ .idea
.pre-commit-config.yaml ADDED
@@ -0,0 +1,36 @@
+ exclude: 'chromadb/proto/(chroma_pb2|coordinator_pb2)\.(py|pyi|py_grpc\.py)' # Generated files
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.5.0
+     hooks:
+       - id: trailing-whitespace
+       - id: mixed-line-ending
+       - id: end-of-file-fixer
+       - id: requirements-txt-fixer
+       - id: check-yaml
+         args: ["--allow-multiple-documents"]
+       - id: check-xml
+       - id: check-merge-conflict
+       - id: check-case-conflict
+       - id: check-docstring-first
+
+   - repo: https://github.com/psf/black
+     # https://github.com/psf/black/issues/2493
+     rev: "refs/tags/23.3.0:refs/tags/23.3.0"
+     hooks:
+       - id: black
+
+   - repo: https://github.com/PyCQA/flake8
+     rev: 6.1.0
+     hooks:
+       - id: flake8
+         args:
+           - "--extend-ignore=E203,E501,E503"
+           - "--max-line-length=88"
+
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: "v1.2.0"
+     hooks:
+       - id: mypy
+         args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract, --config-file=./pyproject.toml]
+         additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf", "kubernetes"]
.vscode/settings.json ADDED
@@ -0,0 +1,131 @@
+ {
+     "git.ignoreLimitWarning": true,
+     "editor.rulers": [
+         88
+     ],
+     "editor.formatOnSave": true,
+     "python.formatting.provider": "black",
+     "files.exclude": {
+         "**/__pycache__": true,
+         "**/.ipynb_checkpoints": true,
+         "**/.pytest_cache": true,
+         "**/chroma.egg-info": true
+     },
+     "python.analysis.typeCheckingMode": "basic",
+     "python.linting.flake8Enabled": true,
+     "python.linting.enabled": true,
+     "python.linting.flake8Args": [
+         "--extend-ignore=E203",
+         "--extend-ignore=E501",
+         "--extend-ignore=E503",
+         "--max-line-length=88"
+     ],
+     "python.testing.pytestArgs": [
+         "."
+     ],
+     "python.testing.unittestEnabled": false,
+     "python.testing.pytestEnabled": true,
+     "editor.formatOnPaste": true,
+     "python.linting.mypyEnabled": true,
+     "python.linting.mypyCategorySeverity.note": "Error",
+     "python.linting.mypyArgs": [
+         "--follow-imports=silent",
+         "--ignore-missing-imports",
+         "--show-column-numbers",
+         "--no-pretty",
+         "--strict",
+         "--disable-error-code=type-abstract"
+     ],
+     "protoc": {
+         "options": [
+             "--proto_path=idl/",
+         ]
+     },
+     "rust-analyzer.cargo.buildScripts.enable": true,
+     "files.associations": {
+         "fstream": "cpp",
+         "iosfwd": "cpp",
+         "__hash_table": "cpp",
+         "__locale": "cpp",
+         "atomic": "cpp",
+         "deque": "cpp",
+         "filesystem": "cpp",
+         "future": "cpp",
+         "locale": "cpp",
+         "random": "cpp",
+         "regex": "cpp",
+         "string": "cpp",
+         "tuple": "cpp",
+         "type_traits": "cpp",
+         "unordered_map": "cpp",
+         "valarray": "cpp",
+         "variant": "cpp",
+         "vector": "cpp",
+         "__string": "cpp",
+         "istream": "cpp",
+         "memory": "cpp",
+         "optional": "cpp",
+         "string_view": "cpp",
+         "__bit_reference": "cpp",
+         "__bits": "cpp",
+         "__config": "cpp",
+         "__debug": "cpp",
+         "__errc": "cpp",
+         "__mutex_base": "cpp",
+         "__node_handle": "cpp",
+         "__nullptr": "cpp",
+         "__split_buffer": "cpp",
+         "__threading_support": "cpp",
+         "__tree": "cpp",
+         "__tuple": "cpp",
+         "array": "cpp",
+         "bit": "cpp",
+         "bitset": "cpp",
+         "cctype": "cpp",
+         "charconv": "cpp",
+         "chrono": "cpp",
+         "cinttypes": "cpp",
+         "clocale": "cpp",
+         "cmath": "cpp",
+         "compare": "cpp",
+         "complex": "cpp",
+         "concepts": "cpp",
+         "condition_variable": "cpp",
+         "csignal": "cpp",
+         "cstdarg": "cpp",
+         "cstddef": "cpp",
+         "cstdint": "cpp",
+         "cstdio": "cpp",
+         "cstdlib": "cpp",
+         "cstring": "cpp",
+         "ctime": "cpp",
+         "cwchar": "cpp",
+         "cwctype": "cpp",
+         "exception": "cpp",
+         "format": "cpp",
+         "forward_list": "cpp",
+         "initializer_list": "cpp",
+         "iomanip": "cpp",
+         "ios": "cpp",
+         "iostream": "cpp",
+         "limits": "cpp",
+         "list": "cpp",
+         "map": "cpp",
+         "mutex": "cpp",
+         "new": "cpp",
+         "numeric": "cpp",
+         "ostream": "cpp",
+         "queue": "cpp",
+         "ratio": "cpp",
+         "set": "cpp",
+         "sstream": "cpp",
+         "stack": "cpp",
+         "stdexcept": "cpp",
+         "streambuf": "cpp",
+         "system_error": "cpp",
+         "typeindex": "cpp",
+         "typeinfo": "cpp",
+         "unordered_set": "cpp",
+         "algorithm": "cpp"
+     },
+ }
Cargo.lock ADDED
The diff for this file is too large to render.
 
Cargo.toml ADDED
@@ -0,0 +1,5 @@
+ [workspace]
+
+ members = [
+     "rust/worker/"
+ ]
DEVELOP.md ADDED
@@ -0,0 +1,111 @@
+ # Development Instructions
+
+ This project uses the testing, build and release standards specified
+ by the PyPA organization and documented at
+ https://packaging.python.org.
+
+ ## Setup
+
+ Because of the dependencies it relies on (like `pytorch`), this project does not support Python version >3.10.0.
+
+ Set up a virtual environment and install the project's requirements
+ and dev requirements:
+
+ ```
+ python3 -m venv venv # Only need to do this once
+ source venv/bin/activate # Do this each time you use a new shell for the project
+ pip install -r requirements.txt
+ pip install -r requirements_dev.txt
+ pre-commit install # install the precommit hooks
+ ```
+
+ You can also install the `chromadb` `pypi` package locally in editable mode with `pip install -e .`.
+
+ ## Running Chroma
+
+ Chroma can be run in 3 modes:
+ 1. Standalone and in-memory:
+ ```python
+ import chromadb
+ api = chromadb.Client()
+ print(api.heartbeat())
+ ```
+
+ 2. Standalone and in-memory with persistence:
+
+ By default this saves your db and your indexes to a `.chroma` directory and can also load from them.
+ ```python
+ import chromadb
+ api = chromadb.PersistentClient(path="/path/to/persist/directory")
+ print(api.heartbeat())
+ ```
+
+ 3. With a persistent backend and a small frontend client:
+
+ Run `chroma run --path /chroma_db_path`
+ ```python
+ import chromadb
+ api = chromadb.HttpClient(host="localhost", port="8000")
+
+ print(api.heartbeat())
+ ```
+
+ ## Local dev setup for distributed chroma
+
+ We use [Tilt](https://docs.tilt.dev/), an open-source tool, to provide the local dev setup.
+
+ ##### Requirements
+ - Docker
+ - Local Kubernetes cluster (Recommended: [OrbStack](https://orbstack.dev/) for mac, [Kind](https://kind.sigs.k8s.io/) for linux)
+ - [Tilt](https://docs.tilt.dev/)
+
+ To start distributed Chroma in the workspace, use `tilt up`. It will create all the required resources and build the necessary Docker image in the current kubectl context.
+ Once done, it will expose Chroma on port 8000. You can also visit the Tilt dashboard UI at http://localhost:10350/. To clean up and remove all the resources created by Tilt, use `tilt down`.
+
+ ## Testing
+
+ Unit tests are in the `/chromadb/test` directory.
+
+ To run unit tests using your current environment, run `pytest`.
+
+ ## Manual Build
+
+ To manually build a distribution, run `python -m build`.
+
+ The project's source and wheel distributions will be placed in the `dist` directory.
+
+ ## Manual Release
+
+ Not yet implemented.
+
+ ## Versioning
+
+ This project uses PyPA's `setuptools_scm` module to determine the
+ version number for build artifacts, meaning the version number is
+ derived from Git rather than hardcoded in the repository. For full
+ details, see the
+ [documentation for setuptools_scm](https://github.com/pypa/setuptools_scm/).
+
+ In brief, version numbers are generated as follows:
+
+ - If the current git head is tagged, the version number is exactly the
+   tag (e.g., `0.0.1`).
+ - If the current git head is a clean checkout, but is not tagged,
+   the version number is a patch version increment of the most recent
+   tag, plus `devN` where N is the number of commits since the most
+   recent tag. For example, if there have been 5 commits since the
+   `0.0.1` tag, the generated version will be `0.0.2-dev5`.
+ - If the current head is not a clean checkout, a `+dirty` local
+   version will be appended to the version number. For example,
+   `0.0.2-dev5+dirty`.
+
+ At any point, you can manually run `python -m setuptools_scm` to see
+ what version would be assigned given your current state.
+
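The same check is also available from Python. A minimal sketch, assuming `setuptools_scm` is installed (it is in `requirements_dev.txt` context above) and that you run it from the repository root:

```python
# Minimal sketch: ask setuptools_scm for the Git-derived version from Python.
# Equivalent to running `python -m setuptools_scm` on the command line.
from setuptools_scm import get_version

version = get_version()  # inspects the current git checkout
print(version)  # e.g. "0.0.2.dev5" on a clean, untagged checkout
```

Note that PEP 440 normalization may render the dev suffix as `.dev5` rather than `-dev5`.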
+ ## Continuous Integration
+
+ This project uses Github Actions to run unit tests automatically upon
+ every commit to the main branch. See the documentation for Github
+ Actions and the flow definitions in `.github/workflows` for details.
+
+ ## Continuous Delivery
+
+ Not yet implemented.
Dockerfile ADDED
@@ -0,0 +1,39 @@
+ FROM python:3.11-slim-bookworm AS builder
+ ARG REBUILD_HNSWLIB
+ RUN apt-get update --fix-missing && apt-get install -y --fix-missing \
+     build-essential \
+     gcc \
+     g++ \
+     cmake \
+     autoconf && \
+     rm -rf /var/lib/apt/lists/* && \
+     mkdir /install
+
+ WORKDIR /install
+
+ COPY ./requirements.txt requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade --prefix="/install" -r requirements.txt
+ RUN if [ "$REBUILD_HNSWLIB" = "true" ]; then pip install --no-binary :all: --force-reinstall --no-cache-dir --prefix="/install" chroma-hnswlib; fi
+
+ FROM python:3.11-slim-bookworm AS final
+
+ RUN mkdir /chroma
+ WORKDIR /chroma
+
+ COPY --from=builder /install /usr/local
+ COPY ./bin/docker_entrypoint.sh /docker_entrypoint.sh
+ COPY ./ /chroma
+
+ RUN chmod +x /docker_entrypoint.sh
+
+ ENV CHROMA_HOST_ADDR "0.0.0.0"
+ ENV CHROMA_HOST_PORT 7860
+ ENV CHROMA_WORKERS 1
+ ENV CHROMA_LOG_CONFIG "chromadb/log_config.yml"
+ ENV CHROMA_TIMEOUT_KEEP_ALIVE 30
+
+ EXPOSE 7860
+
+ ENTRYPOINT ["/docker_entrypoint.sh"]
+ CMD [ "--workers ${CHROMA_WORKERS} --host ${CHROMA_HOST_ADDR} --port ${CHROMA_HOST_PORT} --proxy-headers --log-config ${CHROMA_LOG_CONFIG} --timeout-keep-alive ${CHROMA_TIMEOUT_KEEP_ALIVE}"]
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,11 +1,106 @@
1
- ---
2
- title: Chroma
3
- emoji: 🏢
4
- colorFrom: indigo
5
- colorTo: indigo
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <a href="https://trychroma.com"><img src="https://user-images.githubusercontent.com/891664/227103090-6624bf7d-9524-4e05-9d2c-c28d5d451481.png" alt="Chroma logo"></a>
3
+ </p>
4
+
5
+ <p align="center">
6
+ <b>Chroma - the open-source embedding database</b>. <br />
7
+ The fastest way to build Python or JavaScript LLM apps with memory!
8
+ </p>
9
+
10
+ <p align="center">
11
+ <a href="https://discord.gg/MMeYNTmh3x" target="_blank">
12
+ <img src="https://img.shields.io/discord/1073293645303795742" alt="Discord">
13
+ </a> |
14
+ <a href="https://github.com/chroma-core/chroma/blob/master/LICENSE" target="_blank">
15
+ <img src="https://img.shields.io/static/v1?label=license&message=Apache 2.0&color=white" alt="License">
16
+ </a> |
17
+ <a href="https://docs.trychroma.com/" target="_blank">
18
+ Docs
19
+ </a> |
20
+ <a href="https://www.trychroma.com/" target="_blank">
21
+ Homepage
22
+ </a>
23
+ </p>
24
+
25
+
26
+ <p align="center">
27
+ <a href="https://github.com/chroma-core/chroma/actions/workflows/chroma-integration-test.yml" target="_blank">
28
+ <img src="https://github.com/chroma-core/chroma/actions/workflows/chroma-integration-test.yml/badge.svg?branch=main" alt="Integration Tests">
29
+ </a> |
30
+ <a href="https://github.com/chroma-core/chroma/actions/workflows/chroma-test.yml" target="_blank">
31
+ <img src="https://github.com/chroma-core/chroma/actions/workflows/chroma-test.yml/badge.svg?branch=main" alt="Tests">
32
+ </a>
33
+ </p>
34
+
35
+ ```bash
36
+ pip install chromadb # python client
37
+ # for javascript, npm install chromadb!
38
+ # for client-server mode, chroma run --path /chroma_db_path
39
+ ```
40
+
41
+ The core API is only 4 functions (run our [💡 Google Colab](https://colab.research.google.com/drive/1QEzFyqnoFxq7LUGyP1vzR4iLt9PpCDXv?usp=sharing) or [Replit template](https://replit.com/@swyx/BasicChromaStarter?v=1)):
42
+
43
+ ```python
44
+ import chromadb
45
+ # setup Chroma in-memory, for easy prototyping. Can add persistence easily!
46
+ client = chromadb.Client()
47
+
48
+ # Create collection. get_collection, get_or_create_collection, delete_collection also available!
49
+ collection = client.create_collection("all-my-documents")
50
+
51
+ # Add docs to the collection. Can also update and delete. Row-based API coming soon!
52
+ collection.add(
53
+ documents=["This is document1", "This is document2"], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
54
+ metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on these!
55
+ ids=["doc1", "doc2"], # unique for each doc
56
+ )
57
+
58
+ # Query/search the 2 most similar results. You can also .get by id
59
+ results = collection.query(
60
+ query_texts=["This is a query document"],
61
+ n_results=2,
62
+ # where={"metadata_field": "is_equal_to_this"}, # optional filter
63
+ # where_document={"$contains":"search_string"} # optional filter
64
+ )
65
+ ```
66
+
67
+ ## Features
68
+ - __Simple__: Fully-typed, fully-tested, fully-documented == happiness
69
+ - __Integrations__: [`🦜️🔗 LangChain`](https://blog.langchain.dev/langchain-chroma/) (python and js), [`🦙 LlamaIndex`](https://twitter.com/atroyn/status/1628557389762007040) and more soon
70
+ - __Dev, Test, Prod__: the same API that runs in your python notebook, scales to your cluster
71
+ - __Feature-rich__: Queries, filtering, density estimation and more
72
+ - __Free & Open Source__: Apache 2.0 Licensed
73
+
74
+ ## Use case: ChatGPT for ______
75
+
76
+ For example, the `"Chat your data"` use case (a minimal sketch follows these steps):
77
+ 1. Add documents to your database. You can pass in your own embeddings, embedding function, or let Chroma embed them for you.
78
+ 2. Query relevant documents with natural language.
79
+ 3. Compose documents into the context window of an LLM like `GPT3` for additional summarization or analysis.
80
+
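+ A minimal sketch of these steps (illustrative only; the collection name, example documents, and prompt wiring below are assumptions, not part of Chroma's API):
+
+ ```python
+ import chromadb
+
+ client = chromadb.Client()
+ notes = client.create_collection("my-notes")  # hypothetical collection name
+
+ # 1. Add documents - Chroma embeds them with the default embedding function
+ notes.add(documents=["Chroma releases ship on Mondays."], ids=["note1"])
+
+ # 2. Query with natural language
+ results = notes.query(query_texts=["When do releases happen?"], n_results=1)
+
+ # 3. Compose the retrieved documents into an LLM prompt (the LLM call is omitted)
+ context = "\n".join(results["documents"][0])
+ prompt = f"Answer using this context:\n{context}\n\nQ: When do releases happen?"
+ ```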
81
+ ## Embeddings?
82
+
83
+ What are embeddings?
84
+
85
+ - [Read the guide from OpenAI](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
86
+ - __Literal__: Embedding something turns it from image/text/audio into a list of numbers. 🖼️ or 📄 => `[1.2, 2.1, ....]`. This process makes documents "understandable" to a machine learning model.
87
+ - __By analogy__: An embedding represents the essence of a document. This enables documents and queries with the same essence to be "near" each other and therefore easy to find.
88
+ - __Technical__: An embedding is the latent-space position of a document at a layer of a deep neural network. For models trained specifically to embed data, this is the last layer.
89
+ - __A small example__: Suppose you search your photos for "famous bridge in San Francisco". By embedding this query and comparing it to the embeddings of your photos and their metadata, Chroma should return photos of the Golden Gate Bridge.
90
+
91
+ Embedding databases (also known as **vector databases**) store embeddings and allow you to search by nearest neighbors rather than by substrings like a traditional database. By default, Chroma uses [Sentence Transformers](https://docs.trychroma.com/embeddings#sentence-transformers) to embed for you, but you can also use OpenAI embeddings, Cohere (multilingual) embeddings, or your own.
92
+
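+ If you'd rather supply your own vectors, they can be passed to Chroma directly. A minimal sketch (the toy 3-dimensional vectors are purely illustrative; real embedding models produce much longer vectors):
+
+ ```python
+ import chromadb
+
+ client = chromadb.Client()
+ collection = client.create_collection("custom-embeddings")
+
+ # Pre-computed embeddings are accepted alongside documents and ids
+ collection.add(
+     ids=["a", "b"],
+     documents=["first doc", "second doc"],
+     embeddings=[[0.1, 0.2, 0.3], [0.2, 0.1, 0.0]],
+ )
+
+ # Query by embedding instead of by text
+ results = collection.query(query_embeddings=[[0.1, 0.2, 0.25]], n_results=1)
+ ```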
93
+ ## Get involved
94
+
95
+ Chroma is a rapidly developing project. We welcome PR contributors and ideas for how to improve the project.
96
+ - [Join the conversation on Discord](https://discord.gg/MMeYNTmh3x) - `#contributing` channel
97
+ - [Review the 🛣️ Roadmap and contribute your ideas](https://docs.trychroma.com/roadmap)
98
+ - [Grab an issue and open a PR](https://github.com/chroma-core/chroma/issues) - [`Good first issue tag`](https://github.com/chroma-core/chroma/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
99
+ - [Read our contributing guide](https://docs.trychroma.com/contributing)
100
+
101
+ **Release Cadence**
102
+ We currently release new tagged versions of the `pypi` and `npm` packages on Mondays. Hotfixes go out at any time during the week.
103
+
104
+ ## License
105
+
106
+ [Apache 2.0](./LICENSE)
RELEASE_PROCESS.md ADDED
@@ -0,0 +1,22 @@
1
+ ## Release Process
2
+
3
+ This guide covers how to release Chroma to PyPI.
4
+
5
+ #### Increase the version number
6
+ 1. Create a new PR for the release that upgrades the version in code. Name it `release/A.B.C`. In [this file](https://github.com/chroma-core/chroma/blob/main/chromadb/__init__.py), update `__version__`.
7
+ ```
8
+ __version__ = "A.B.C"
9
+ ```
10
+ 2. Add the "release" label to this PR
11
+ 3. Once the PR is merged, tag your commit SHA with the release version
12
+ ```
13
+ git tag A.B.C <SHA>
14
+ ```
15
+ 4. Wait for the `chroma release` and `chroma client release` GitHub Actions on `main` to go green before proceeding. Skipping this wait can cause a race condition.
16
+
17
+ #### Perform the release
18
+ 1. Push your tag to origin to create the release
19
+ ```
20
+ git push origin A.B.C
21
+ ```
22
+ 2. This will trigger a GitHub Action that performs the release.
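+
+ For example, releasing a hypothetical version `0.4.23` end to end (substitute the real version and commit SHA):
+ ```
+ git tag 0.4.23 <SHA>
+ git push origin 0.4.23
+ ```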
Tiltfile ADDED
@@ -0,0 +1,30 @@
1
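+ # Illustrative usage note: run `tilt up` from the repo root to build these
+ # images and deploy the services below to the local Kubernetes cluster.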
+ docker_build('coordinator',
2
+ context='.',
3
+ dockerfile='./go/coordinator/Dockerfile'
4
+ )
5
+
6
+ docker_build('server',
7
+ context='.',
8
+ dockerfile='./Dockerfile',
9
+ )
10
+
11
+ docker_build('worker',
12
+ context='.',
13
+ dockerfile='./rust/worker/Dockerfile'
14
+ )
15
+
16
+
17
+ k8s_yaml(['k8s/dev/setup.yaml'])
18
+ k8s_resource(
19
+ objects=['chroma:Namespace', 'memberlist-reader:ClusterRole', 'memberlist-reader:ClusterRoleBinding', 'pod-list-role:Role', 'pod-list-role-binding:RoleBinding', 'memberlists.chroma.cluster:CustomResourceDefinition','worker-memberlist:MemberList'],
20
+ new_name='k8s_setup',
21
+ labels=["infrastructure"]
22
+ )
23
+ k8s_yaml(['k8s/dev/pulsar.yaml'])
24
+ k8s_resource('pulsar', resource_deps=['k8s_setup'], labels=["infrastructure"])
25
+ k8s_yaml(['k8s/dev/server.yaml'])
26
+ k8s_resource('server', resource_deps=['k8s_setup'],labels=["chroma"], port_forwards=8000 )
27
+ k8s_yaml(['k8s/dev/coordinator.yaml'])
28
+ k8s_resource('coordinator', resource_deps=['pulsar', 'server'], labels=["chroma"])
29
+ k8s_yaml(['k8s/dev/worker.yaml'])
30
+ k8s_resource('worker', resource_deps=['coordinator'],labels=["chroma"])
bandit.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # FILE: bandit.yaml
2
+ exclude_dirs: [ 'chromadb/test', 'bin', 'build', '.git', '.venv', 'venv', 'env', '.github', 'examples', 'clients/js', '.vscode' ]
3
+ tests: [ ]
4
+ skips: [ ]
bin/cluster-test.sh ADDED
@@ -0,0 +1,62 @@
1
+ #!/usr/bin/env bash
2
+
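+ # Spins up a throwaway minikube cluster, builds and deploys the Chroma services,
+ # runs the cluster tests against it, and tears everything down on exit.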
3
+ set -e
4
+
5
+ function cleanup {
6
+ # Restore the previous kube context
7
+ kubectl config use-context $PREV_CHROMA_KUBE_CONTEXT
8
+ # Kill the tunnel process
9
+ kill $TUNNEL_PID
10
+ minikube delete -p chroma-test
11
+ }
12
+
13
+ trap cleanup EXIT
14
+
15
+ # Save the current kube context into a variable
16
+ export PREV_CHROMA_KUBE_CONTEXT=$(kubectl config current-context)
17
+
18
+ # Create a new minikube cluster for the test
19
+ minikube start -p chroma-test
20
+
21
+ # Add the ingress addon to the cluster
22
+ minikube addons enable ingress -p chroma-test
23
+ minikube addons enable ingress-dns -p chroma-test
24
+
25
+ # Setup docker to build inside the minikube cluster and build the image
26
+ eval $(minikube -p chroma-test docker-env)
27
+ docker build -t server:latest -f Dockerfile .
28
+ docker build -t chroma-coordinator:latest -f go/coordinator/Dockerfile .
29
+ docker build -t worker -f rust/worker/Dockerfile . --build-arg CHROMA_KUBERNETES_INTEGRATION=1
30
+
31
+ # Apply the kubernetes manifests
32
+ kubectl apply -f k8s/deployment
33
+ kubectl apply -f k8s/crd
34
+ kubectl apply -f k8s/cr
35
+ kubectl apply -f k8s/test
36
+
37
+ # Wait for the pods in the chroma namespace to be ready
38
+ kubectl wait --namespace chroma --for=condition=Ready pods --all --timeout=400s
39
+
40
+ # Run mini kube tunnel in the background to expose the service
41
+ minikube tunnel -c true -p chroma-test &
42
+ TUNNEL_PID=$!
43
+
44
+ # Wait for the tunnel to be ready. There isn't an easy way to check readiness, so we just wait 10 seconds.
45
+ sleep 10
46
+
47
+ export CHROMA_CLUSTER_TEST_ONLY=1
48
+ export CHROMA_SERVER_HOST=$(kubectl get svc server -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}')
49
+ export PULSAR_BROKER_URL=$(kubectl get svc pulsar-lb -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}')
50
+ export CHROMA_COORDINATOR_HOST=$(kubectl get svc coordinator-lb -n chroma -o=jsonpath='{.status.loadBalancer.ingress[0].ip}')
51
+ export CHROMA_SERVER_GRPC_PORT="50051"
52
+
53
+ echo "Chroma server host: $CHROMA_SERVER_HOST"
54
+ echo "Pulsar broker host: $PULSAR_BROKER_URL"
55
+ echo "Chroma coordinator host: $CHROMA_COORDINATOR_HOST"
56
+
57
+ echo testing: python -m pytest "$@"
58
+ python -m pytest "$@"
59
+
60
+ export CHROMA_KUBERNETES_INTEGRATION=1
61
+ cd go/coordinator
62
+ go test -timeout 30s -run ^TestNodeWatcher$ github.com/chroma/chroma-coordinator/internal/memberlist_manager
bin/docker_entrypoint.sh ADDED
@@ -0,0 +1,15 @@
1
+ #!/bin/bash
2
+ set -e
3
+
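+ # If the container args already start with "uvicorn", run them as-is (with a
+ # deprecation warning); otherwise treat the args as options for
+ # "uvicorn chromadb.app:app".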
4
+ export IS_PERSISTENT=1
5
+ export CHROMA_SERVER_NOFILE=65535
6
+ args="$@"
7
+
8
+ if [[ $args =~ ^uvicorn.* ]]; then
9
+ echo "Starting server with args: $(eval echo "$args")"
10
+ echo -e "\033[31mWARNING: Please remove 'uvicorn chromadb.app:app' from your command line arguments. This is now handled by the entrypoint script."
11
+ exec $(eval echo "$args")
12
+ else
13
+ echo "Starting 'uvicorn chromadb.app:app' with args: $(eval echo "$args")"
14
+ exec uvicorn chromadb.app:app $(eval echo "$args")
15
+ fi
bin/generate_cloudformation.py ADDED
@@ -0,0 +1,198 @@
1
+ import boto3
2
+ import json
3
+ import subprocess
4
+ import os
5
+ import re
6
+
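+ # Builds a CloudFormation template that runs Chroma on a single EC2 instance
+ # via docker-compose, then uploads the template to S3 (and to /latest when the
+ # version is a tagged release).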
7
+
8
+ def b64text(txt):
9
+ """Generate Base 64 encoded CF json for a multiline string, subbing in values where appropriate"""
10
+ lines = []
11
+ for line in txt.splitlines(True):
12
+ if "${" in line:
13
+ lines.append({"Fn::Sub": line})
14
+ else:
15
+ lines.append(line)
16
+ return {"Fn::Base64": {"Fn::Join": ["", lines]}}
17
+
18
+
19
+ path = os.path.dirname(os.path.realpath(__file__))
20
+ version = subprocess.check_output(f"{path}/version").decode("ascii").strip()
21
+
22
+ with open(f"{path}/templates/docker-compose.yml") as f:
23
+ docker_compose_file = str(f.read())
24
+
25
+
26
+ cloud_config_script = """
27
+ #cloud-config
28
+ cloud_final_modules:
29
+ - [scripts-user, always]
30
+ """
31
+
32
+ cloud_init_script = f"""
33
+ #!/bin/bash
34
+ amazon-linux-extras install docker
35
+ usermod -a -G docker ec2-user
36
+ curl -L https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m) -o /usr/local/bin/docker-compose
37
+ chmod +x /usr/local/bin/docker-compose
38
+ ln -s /usr/local/bin/docker-compose /usr/bin/docker-compose
39
+ systemctl enable docker
40
+ systemctl start docker
41
+
42
+ cat << EOF > /home/ec2-user/docker-compose.yml
43
+ {docker_compose_file}
44
+ EOF
45
+
46
+ mkdir /home/ec2-user/config
47
+
48
+ docker-compose -f /home/ec2-user/docker-compose.yml up -d
49
+ """
50
+
51
+ userdata = f"""Content-Type: multipart/mixed; boundary="//"
52
+ MIME-Version: 1.0
53
+
54
+ --//
55
+ Content-Type: text/cloud-config; charset="us-ascii"
56
+ MIME-Version: 1.0
57
+ Content-Transfer-Encoding: 7bit
58
+ Content-Disposition: attachment; filename="cloud-config.txt"
59
+
60
+ {cloud_config_script}
61
+
62
+ --//
63
+ Content-Type: text/x-shellscript; charset="us-ascii"
64
+ MIME-Version: 1.0
65
+ Content-Transfer-Encoding: 7bit
66
+ Content-Disposition: attachment; filename="userdata.txt"
67
+
68
+ {cloud_init_script}
69
+ --//--
70
+ """
71
+
72
+ cf = {
73
+ "AWSTemplateFormatVersion": "2010-09-09",
74
+ "Description": "Create a stack that runs Chroma hosted on a single instance",
75
+ "Parameters": {
76
+ "KeyName": {
77
+ "Description": "Name of an existing EC2 KeyPair to enable SSH access to the instance",
78
+ "Type": "String",
79
+ "ConstraintDescription": "If present, must be the name of an existing EC2 KeyPair.",
80
+ "Default": "",
81
+ },
82
+ "InstanceType": {
83
+ "Description": "EC2 instance type",
84
+ "Type": "String",
85
+ "Default": "t3.small",
86
+ },
87
+ "ChromaVersion": {
88
+ "Description": "Chroma version to install",
89
+ "Type": "String",
90
+ "Default": version,
91
+ },
92
+ },
93
+ "Conditions": {
94
+ "HasKeyName": {"Fn::Not": [{"Fn::Equals": [{"Ref": "KeyName"}, ""]}]},
95
+ },
96
+ "Resources": {
97
+ "ChromaInstance": {
98
+ "Type": "AWS::EC2::Instance",
99
+ "Properties": {
100
+ "ImageId": {
101
+ "Fn::FindInMap": ["Region2AMI", {"Ref": "AWS::Region"}, "AMI"]
102
+ },
103
+ "InstanceType": {"Ref": "InstanceType"},
104
+ "UserData": b64text(userdata),
105
+ "SecurityGroupIds": [{"Ref": "ChromaInstanceSecurityGroup"}],
106
+ "KeyName": {
107
+ "Fn::If": [
108
+ "HasKeyName",
109
+ {"Ref": "KeyName"},
110
+ {"Ref": "AWS::NoValue"},
111
+ ]
112
+ },
113
+ "BlockDeviceMappings": [
114
+ {
115
+ "DeviceName": {
116
+ "Fn::FindInMap": [
117
+ "Region2AMI",
118
+ {"Ref": "AWS::Region"},
119
+ "RootDeviceName",
120
+ ]
121
+ },
122
+ "Ebs": {"VolumeSize": 24},
123
+ }
124
+ ],
125
+ },
126
+ },
127
+ "ChromaInstanceSecurityGroup": {
128
+ "Type": "AWS::EC2::SecurityGroup",
129
+ "Properties": {
130
+ "GroupDescription": "Chroma Instance Security Group",
131
+ "SecurityGroupIngress": [
132
+ {
133
+ "IpProtocol": "tcp",
134
+ "FromPort": "22",
135
+ "ToPort": "22",
136
+ "CidrIp": "0.0.0.0/0",
137
+ },
138
+ {
139
+ "IpProtocol": "tcp",
140
+ "FromPort": "8000",
141
+ "ToPort": "8000",
142
+ "CidrIp": "0.0.0.0/0",
143
+ },
144
+ ],
145
+ },
146
+ },
147
+ },
148
+ "Outputs": {
149
+ "ServerIp": {
150
+ "Description": "IP address of the Chroma server",
151
+ "Value": {"Fn::GetAtt": ["ChromaInstance", "PublicIp"]},
152
+ }
153
+ },
154
+ "Mappings": {"Region2AMI": {}},
155
+ }
156
+
157
+ # Populate the Region2AMI mappings
158
+ regions = boto3.client("ec2", region_name="us-east-1").describe_regions()["Regions"]
159
+ for region in regions:
160
+ region_name = region["RegionName"]
161
+ ami_result = boto3.client("ec2", region_name=region_name).describe_images(
162
+ Owners=["137112412989"],
163
+ Filters=[
164
+ {"Name": "name", "Values": ["amzn2-ami-kernel-5.10-hvm-*-x86_64-gp2"]},
165
+ {"Name": "root-device-type", "Values": ["ebs"]},
166
+ {"Name": "virtualization-type", "Values": ["hvm"]},
167
+ ],
168
+ )
169
+ img = ami_result["Images"][0]
170
+ ami_id = img["ImageId"]
171
+ root_device_name = img["BlockDeviceMappings"][0]["DeviceName"]
172
+ cf["Mappings"]["Region2AMI"][region_name] = {
173
+ "AMI": ami_id,
174
+ "RootDeviceName": root_device_name,
175
+ }
176
+
177
+
178
+ # Write the CF json to a file
179
+ json.dump(cf, open("/tmp/chroma.cf.json", "w"), indent=4)
180
+
181
+ # upload to S3
182
+ s3 = boto3.client("s3", region_name="us-east-1")
183
+ s3.upload_file(
184
+ "/tmp/chroma.cf.json",
185
+ "public.trychroma.com",
186
+ f"cloudformation/{version}/chroma.cf.json",
187
+ )
188
+
189
+ # Upload to s3 under /latest version only if this is a release
190
+ pattern = re.compile(r"^\d+\.\d+\.\d+$")
191
+ if pattern.match(version):
192
+ s3.upload_file(
193
+ "/tmp/chroma.cf.json",
194
+ "public.trychroma.com",
195
+ "cloudformation/latest/chroma.cf.json",
196
+ )
197
+ else:
198
+ print(f"Version {version} is not a 3-part semver, not uploading to /latest")
bin/integration-test ADDED
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env bash
2
+
3
+ set -e
4
+
5
+ export CHROMA_PORT=8000
6
+
7
+ function cleanup {
8
+ docker compose -f docker-compose.test.yml down --rmi local --volumes
9
+ rm -f server.htpasswd .chroma_env  # -f: the files may not exist if tests fail early
10
+ }
11
+
12
+ function setup_auth {
13
+ local auth_type="$1"
14
+ case "$auth_type" in
15
+ basic)
16
+ docker run --rm --entrypoint htpasswd httpd:2 -Bbn admin admin > server.htpasswd
17
+ cat <<EOF > .chroma_env
18
+ CHROMA_SERVER_AUTH_CREDENTIALS_FILE="/chroma/server.htpasswd"
19
+ CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider"
20
+ CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.basic.BasicAuthServerProvider"
21
+ EOF
22
+ ;;
23
+ token)
24
+ cat <<EOF > .chroma_env
25
+ CHROMA_SERVER_AUTH_CREDENTIALS="test-token"
26
+ CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER="AUTHORIZATION"
27
+ CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
28
+ CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider"
29
+ EOF
30
+ ;;
31
+ xtoken)
32
+ cat <<EOF > .chroma_env
33
+ CHROMA_SERVER_AUTH_CREDENTIALS="test-token"
34
+ CHROMA_SERVER_AUTH_TOKEN_TRANSPORT_HEADER="X_CHROMA_TOKEN"
35
+ CHROMA_SERVER_AUTH_CREDENTIALS_PROVIDER="chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"
36
+ CHROMA_SERVER_AUTH_PROVIDER="chromadb.auth.token.TokenAuthServerProvider"
37
+ EOF
38
+ ;;
39
+ *)
40
+ echo "Unknown auth type: $auth_type"
41
+ exit 1
42
+ ;;
43
+ esac
44
+ }
45
+
46
+ trap cleanup EXIT
47
+
48
+ docker compose -f docker-compose.test.yml up --build -d
49
+
50
+ export CHROMA_INTEGRATION_TEST_ONLY=1
51
+ export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI
52
+ export CHROMA_SERVER_HOST=localhost
53
+ export CHROMA_SERVER_HTTP_PORT=8000
54
+ export CHROMA_SERVER_NOFILE=65535
55
+
56
+ echo testing: python -m pytest "$@"
57
+ python -m pytest "$@"
58
+
59
+ cd clients/js
60
+
61
+ # moved off of yarn to npm to fix issues with jackspeak/cliui/string-width versions #1314
62
+ npm install
63
+ npm run test:run
64
+
65
+ docker compose down
66
+ cd ../..
67
+ for auth_type in basic token xtoken; do
68
+ echo "Testing $auth_type auth"
69
+ setup_auth "$auth_type"
70
+ cd clients/js
71
+ docker compose --env-file ../../.chroma_env -f ../../docker-compose.test-auth.yml up --build -d
72
+ yarn test:run-auth-"$auth_type"
73
+ cd ../..
74
+ docker compose down
75
+ done
bin/reset.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/usr/bin/env bash
2
+
3
+ eval $(minikube -p chroma-test docker-env)
4
+
5
+ docker build -t chroma-coordinator:latest -f go/coordinator/Dockerfile .
6
+
7
+ kubectl delete deployment coordinator -n chroma
8
+
9
+ # Apply the kubernetes manifests
10
+ kubectl apply -f k8s/deployment
11
+ kubectl apply -f k8s/crd
12
+ kubectl apply -f k8s/cr
13
+ kubectl apply -f k8s/test
bin/templates/docker-compose.yml ADDED
@@ -0,0 +1,21 @@
1
+ version: '3.9'
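+ # Note: the ChromaVersion placeholder below is filled in by CloudFormation
+ # (via the Fn::Sub calls emitted in bin/generate_cloudformation.py).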
2
+
3
+ networks:
4
+ net:
5
+ driver: bridge
6
+
7
+ services:
8
+ server:
9
+ image: ghcr.io/chroma-core/chroma:${ChromaVersion}
10
+ volumes:
11
+ - index_data:/index_data
12
+ ports:
13
+ - 8000:8000
14
+ networks:
15
+ - net
16
+
17
+ volumes:
18
+ index_data:
19
+ driver: local
20
+ backups:
21
+ driver: local
bin/test-package.sh ADDED
@@ -0,0 +1,24 @@
1
+ #!/bin/bash
2
+
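+ # Usage (illustrative): bin/test-package.sh dist/chromadb-<version>.tar.gz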
3
+ # Verify PIP tarball
4
+ tarball=$(readlink -f $1)
5
+ if [ -f "$tarball" ]; then
6
+ echo "Testing PIP package from tarball: $tarball"
7
+ else
8
+ echo "Could not find PIP package: $tarball"
+ exit 1
9
+ fi
10
+
11
+ # Create temporary project dir
12
+ dir=$(mktemp -d)
13
+
14
+ echo "Building python project dir at $dir ..."
15
+
16
+ cd $dir
17
+
18
+ python3 -m venv venv
19
+
20
+ source venv/bin/activate
21
+
22
+ pip install $tarball
23
+
24
+ python -c "import chromadb; api = chromadb.Client(); print(api.heartbeat())"
bin/test-remote ADDED
@@ -0,0 +1,16 @@
1
+ #!/usr/bin/env bash
2
+
3
+ set -e
4
+
5
+ # Assert first argument is present
6
+ if [ -z "$1" ]; then
7
+ echo "Usage: bin/test-remote <remote-host>"
8
+ exit 1
9
+ fi
10
+
11
+ export CHROMA_INTEGRATION_TEST_ONLY=1
12
+ export CHROMA_SERVER_HOST=$1
13
+ export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI
14
+ export CHROMA_SERVER_HTTP_PORT=8000
15
+
16
+ python -m pytest
bin/test.py ADDED
@@ -0,0 +1,7 @@
1
+ # Sanity check script to ensure that the Chroma client can connect
2
+ # and is capable of receiving data.
3
+ import chromadb
4
+
5
+ # run in in-memory mode
6
+ chroma_api = chromadb.Client()
7
+ print(chroma_api.heartbeat())
bin/version ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
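+ # Derives the version from git metadata via setuptools_scm and appends
+ # "-dirty" when the working tree has uncommitted changes.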
2
+ export VERSION=`python -m setuptools_scm`
3
+
4
+ if [[ -n `git status --porcelain` ]]; then
5
+ VERSION=$VERSION-dirty
6
+ fi
7
+
8
+ echo $VERSION
bin/windows_upgrade_sqlite.py ADDED
@@ -0,0 +1,20 @@
1
+ import requests
2
+ import zipfile
3
+ import io
4
+ import os
5
+ import sys
6
+ import shutil
7
+
8
+ # Used by GitHub Actions runners to upgrade the bundled sqlite3 to 3.42.0
9
+ DLL_URL = "https://www.sqlite.org/2023/sqlite-dll-win64-x64-3420000.zip"
10
+
11
+ if __name__ == "__main__":
12
+ # Download and extract the DLL
13
+ r = requests.get(DLL_URL)
14
+ z = zipfile.ZipFile(io.BytesIO(r.content))
15
+ z.extractall(".")
16
+ # Locate the DLLs folder of the current Python installation
17
+ exec_path = os.path.dirname(sys.executable)
18
+ dlls_path = os.path.join(exec_path, "DLLs")
19
+ # Copy the DLL to the Python DLLs folder
20
+ shutil.copy("sqlite3.dll", dlls_path)
chromadb/__init__.py ADDED
@@ -0,0 +1,257 @@
1
+ from typing import Dict, Optional
2
+ import logging
3
+ from chromadb.api.client import Client as ClientCreator
4
+ from chromadb.api.client import AdminClient as AdminClientCreator
5
+ from chromadb.auth.token import TokenTransportHeader
6
+ import chromadb.config
7
+ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
8
+ from chromadb.api import AdminAPI, ClientAPI
9
+ from chromadb.api.models.Collection import Collection
10
+ from chromadb.api.types import (
11
+ CollectionMetadata,
12
+ Documents,
13
+ EmbeddingFunction,
14
+ Embeddings,
15
+ IDs,
16
+ Include,
17
+ Metadata,
18
+ Where,
19
+ QueryResult,
20
+ GetResult,
21
+ WhereDocument,
22
+ UpdateCollectionMetadata,
23
+ )
24
+
25
+ # Re-export types from chromadb.types
26
+ __all__ = [
27
+ "Collection",
28
+ "Metadata",
29
+ "Where",
30
+ "WhereDocument",
31
+ "Documents",
32
+ "IDs",
33
+ "Embeddings",
34
+ "EmbeddingFunction",
35
+ "Include",
36
+ "CollectionMetadata",
37
+ "UpdateCollectionMetadata",
38
+ "QueryResult",
39
+ "GetResult",
40
+ ]
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+ __settings = Settings()
45
+
46
+ __version__ = "0.4.22"
47
+
48
+ # Workaround to deal with Colab's old sqlite3 version
49
+ try:
50
+ import google.colab # noqa: F401
51
+
52
+ IN_COLAB = True
53
+ except ImportError:
54
+ IN_COLAB = False
55
+
56
+ is_client = False
57
+ try:
58
+ from chromadb.is_thin_client import is_thin_client
59
+
60
+ is_client = is_thin_client
61
+ except ImportError:
62
+ is_client = False
63
+
64
+ if not is_client:
65
+ import sqlite3
66
+
67
+ if sqlite3.sqlite_version_info < (3, 35, 0):
68
+ if IN_COLAB:
69
+ # In Colab, hotswap to pysqlite-binary if it's too old
70
+ import subprocess
71
+ import sys
72
+
73
+ subprocess.check_call(
74
+ [sys.executable, "-m", "pip", "install", "pysqlite3-binary"]
75
+ )
76
+ __import__("pysqlite3")
77
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
78
+ else:
79
+ raise RuntimeError(
80
+ "\033[91mYour system has an unsupported version of sqlite3. Chroma \
81
+ requires sqlite3 >= 3.35.0.\033[0m\n"
82
+ "\033[94mPlease visit \
83
+ https://docs.trychroma.com/troubleshooting#sqlite to learn how \
84
+ to upgrade.\033[0m"
85
+ )
86
+
87
+
88
+ def configure(**kwargs) -> None: # type: ignore
89
+ """Override Chroma's default settings, environment variables or .env files"""
90
+ global __settings
91
+ __settings = chromadb.config.Settings(**kwargs)
92
+
93
+
94
+ def get_settings() -> Settings:
95
+ return __settings
96
+
97
+
98
+ def EphemeralClient(
99
+ settings: Optional[Settings] = None,
100
+ tenant: str = DEFAULT_TENANT,
101
+ database: str = DEFAULT_DATABASE,
102
+ ) -> ClientAPI:
103
+ """
104
+ Creates an in-memory instance of Chroma. This is useful for testing and
105
+ development, but not recommended for production use.
106
+
107
+ Args:
108
+ tenant: The tenant to use for this client. Defaults to the default tenant.
109
+ database: The database to use for this client. Defaults to the default database.
110
+ """
111
+ if settings is None:
112
+ settings = Settings()
113
+ settings.is_persistent = False
114
+
115
+ return ClientCreator(settings=settings, tenant=tenant, database=database)
116
+
117
+
118
+ def PersistentClient(
119
+ path: str = "./chroma",
120
+ settings: Optional[Settings] = None,
121
+ tenant: str = DEFAULT_TENANT,
122
+ database: str = DEFAULT_DATABASE,
123
+ ) -> ClientAPI:
124
+ """
125
+ Creates a persistent instance of Chroma that saves to disk. This is useful for
126
+ testing and development, but not recommended for production use.
127
+
128
+ Args:
129
+ path: The directory to save Chroma's data to. Defaults to "./chroma".
130
+ tenant: The tenant to use for this client. Defaults to the default tenant.
131
+ database: The database to use for this client. Defaults to the default database.
132
+ """
133
+ if settings is None:
134
+ settings = Settings()
135
+ settings.persist_directory = path
136
+ settings.is_persistent = True
137
+
138
+ return ClientCreator(tenant=tenant, database=database, settings=settings)
139
+
140
+
141
+ def HttpClient(
142
+ host: str = "localhost",
143
+ port: str = "8000",
144
+ ssl: bool = False,
145
+ headers: Optional[Dict[str, str]] = None,
146
+ settings: Optional[Settings] = None,
147
+ tenant: str = DEFAULT_TENANT,
148
+ database: str = DEFAULT_DATABASE,
149
+ ) -> ClientAPI:
150
+ """
151
+ Creates a client that connects to a remote Chroma server. This supports
152
+ many clients connecting to the same server, and is the recommended way to
153
+ use Chroma in production.
154
+
155
+ Args:
156
+ host: The hostname of the Chroma server. Defaults to "localhost".
157
+ port: The port of the Chroma server. Defaults to "8000".
158
+ ssl: Whether to use SSL to connect to the Chroma server. Defaults to False.
159
+ headers: A dictionary of headers to send to the Chroma server. Defaults to {}.
160
+ settings: A dictionary of settings to communicate with the chroma server.
161
+ tenant: The tenant to use for this client. Defaults to the default tenant.
162
+ database: The database to use for this client. Defaults to the default database.
163
+ """
164
+
165
+ if settings is None:
166
+ settings = Settings()
167
+
168
+ settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
169
+ if settings.chroma_server_host and settings.chroma_server_host != host:
170
+ raise ValueError(
171
+ f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]"
172
+ )
173
+ settings.chroma_server_host = host
174
+ if settings.chroma_server_http_port and settings.chroma_server_http_port != port:
175
+ raise ValueError(
176
+ f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]"
177
+ )
178
+ settings.chroma_server_http_port = port
179
+ settings.chroma_server_ssl_enabled = ssl
180
+ settings.chroma_server_headers = headers
181
+
182
+ return ClientCreator(tenant=tenant, database=database, settings=settings)
183
+
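+ # Example (illustrative; assumes a server started with `chroma run --path /chroma_db_path`):
+ # client = chromadb.HttpClient(host="localhost", port="8000")
+ # client.heartbeat()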
184
+
185
+ def CloudClient(
186
+ tenant: str,
187
+ database: str,
188
+ api_key: Optional[str] = None,
189
+ settings: Optional[Settings] = None,
190
+ *, # Following arguments are keyword-only, intended for testing only.
191
+ cloud_host: str = "api.trychroma.com",
192
+ cloud_port: str = "8000",
193
+ enable_ssl: bool = True,
194
+ ) -> ClientAPI:
195
+ """
196
+ Creates a client to connect to a tenant and database on Chroma Cloud.
197
+
198
+ Args:
199
+ tenant: The tenant to use for this client.
200
+ database: The database to use for this client.
201
+ api_key: The api key to use for this client.
202
+ """
203
+
204
+ # If no API key is provided, try to load it from the environment variable
205
+ if api_key is None:
206
+ import os
207
+
208
+ api_key = os.environ.get("CHROMA_API_KEY")
209
+
210
+ # If the API key is still not provided, prompt the user
211
+ if api_key is None:
212
+ print(
213
+ "\033[93mDon't have an API key?\033[0m Get one at https://app.trychroma.com"
214
+ )
215
+ api_key = input("Please enter your Chroma API key: ")
216
+
217
+ if settings is None:
218
+ settings = Settings()
219
+
220
+ settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
221
+ settings.chroma_server_host = cloud_host
222
+ settings.chroma_server_http_port = cloud_port
223
+ # Always use SSL for cloud
224
+ settings.chroma_server_ssl_enabled = enable_ssl
225
+
226
+ settings.chroma_client_auth_provider = "chromadb.auth.token.TokenAuthClientProvider"
227
+ settings.chroma_client_auth_credentials = api_key
228
+ settings.chroma_client_auth_token_transport_header = (
229
+ TokenTransportHeader.X_CHROMA_TOKEN.name
230
+ )
231
+
232
+ return ClientCreator(tenant=tenant, database=database, settings=settings)
233
+
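+ # Example (illustrative; the tenant and database names are placeholders):
+ # client = chromadb.CloudClient(tenant="my-tenant", database="my-db", api_key="...")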
234
+
235
+ def Client(
236
+ settings: Settings = __settings,
237
+ tenant: str = DEFAULT_TENANT,
238
+ database: str = DEFAULT_DATABASE,
239
+ ) -> ClientAPI:
240
+ """
241
+ Return a running chroma.API instance
242
+
243
+ tenant: The tenant to use for this client. Defaults to the default tenant.
244
+ database: The database to use for this client. Defaults to the default database.
245
+
246
+ """
247
+
248
+ return ClientCreator(tenant=tenant, database=database, settings=settings)
249
+
250
+
251
+ def AdminClient(settings: Settings = Settings()) -> AdminAPI:
252
+ """Creates an admin client that can be used to create tenants and databases."""
257
+ return AdminClientCreator(settings=settings)
chromadb/api/__init__.py ADDED
@@ -0,0 +1,596 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Sequence, Optional
3
+ from uuid import UUID
4
+
5
+ from overrides import override
6
+ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
7
+ from chromadb.api.models.Collection import Collection
8
+ from chromadb.api.types import (
9
+ CollectionMetadata,
10
+ Documents,
11
+ Embeddable,
12
+ EmbeddingFunction,
13
+ DataLoader,
14
+ Embeddings,
15
+ IDs,
16
+ Include,
17
+ Loadable,
18
+ Metadatas,
19
+ URIs,
20
+ Where,
21
+ QueryResult,
22
+ GetResult,
23
+ WhereDocument,
24
+ )
25
+ from chromadb.config import Component, Settings
26
+ from chromadb.types import Database, Tenant
27
+ import chromadb.utils.embedding_functions as ef
28
+
29
+
30
+ class BaseAPI(ABC):
31
+ @abstractmethod
32
+ def heartbeat(self) -> int:
33
+ """Get the current time in nanoseconds since epoch.
34
+ Used to check if the server is alive.
35
+
36
+ Returns:
37
+ int: The current time in nanoseconds since epoch
38
+
39
+ """
40
+ pass
41
+
42
+ #
43
+ # COLLECTION METHODS
44
+ #
45
+
46
+ @abstractmethod
47
+ def list_collections(
48
+ self,
49
+ limit: Optional[int] = None,
50
+ offset: Optional[int] = None,
51
+ ) -> Sequence[Collection]:
52
+ """List all collections.
53
+ Args:
54
+ limit: The maximum number of entries to return. Defaults to None.
55
+ offset: The number of entries to skip before returning. Defaults to None.
56
+
57
+ Returns:
58
+ Sequence[Collection]: A list of collections
59
+
60
+ Examples:
61
+ ```python
62
+ client.list_collections()
63
+ # [collection(name="my_collection", metadata={})]
64
+ ```
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def count_collections(self) -> int:
70
+ """Count the number of collections.
71
+
72
+ Returns:
73
+ int: The number of collections.
74
+
75
+ Examples:
76
+ ```python
77
+ client.count_collections()
78
+ # 1
79
+ ```
80
+ """
81
+ pass
82
+
83
+ @abstractmethod
84
+ def create_collection(
85
+ self,
86
+ name: str,
87
+ metadata: Optional[CollectionMetadata] = None,
88
+ embedding_function: Optional[
89
+ EmbeddingFunction[Embeddable]
90
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
91
+ data_loader: Optional[DataLoader[Loadable]] = None,
92
+ get_or_create: bool = False,
93
+ ) -> Collection:
94
+ """Create a new collection with the given name and metadata.
95
+ Args:
96
+ name: The name of the collection to create.
97
+ metadata: Optional metadata to associate with the collection.
98
+ embedding_function: Optional function to use to embed documents.
99
+ Uses the default embedding function if not provided.
100
+ get_or_create: If True, return the existing collection if it exists.
101
+ data_loader: Optional function to use to load records (documents, images, etc.)
102
+
103
+ Returns:
104
+ Collection: The newly created collection.
105
+
106
+ Raises:
107
+ ValueError: If the collection already exists and get_or_create is False.
108
+ ValueError: If the collection name is invalid.
109
+
110
+ Examples:
111
+ ```python
112
+ client.create_collection("my_collection")
113
+ # collection(name="my_collection", metadata={})
114
+
115
+ client.create_collection("my_collection", metadata={"foo": "bar"})
116
+ # collection(name="my_collection", metadata={"foo": "bar"})
117
+ ```
118
+ """
119
+ pass
120
+
121
+ @abstractmethod
122
+ def get_collection(
123
+ self,
124
+ name: str,
125
+ id: Optional[UUID] = None,
126
+ embedding_function: Optional[
127
+ EmbeddingFunction[Embeddable]
128
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
129
+ data_loader: Optional[DataLoader[Loadable]] = None,
130
+ ) -> Collection:
131
+ """Get a collection with the given name.
132
+ Args:
133
+ id: The UUID of the collection to get. If provided, it is used together with the name for the lookup.
134
+ name: The name of the collection to get
135
+ embedding_function: Optional function to use to embed documents.
136
+ Uses the default embedding function if not provided.
137
+ data_loader: Optional function to use to load records (documents, images, etc.)
138
+
139
+ Returns:
140
+ Collection: The collection
141
+
142
+ Raises:
143
+ ValueError: If the collection does not exist
144
+
145
+ Examples:
146
+ ```python
147
+ client.get_collection("my_collection")
148
+ # collection(name="my_collection", metadata={})
149
+ ```
150
+ """
151
+ pass
152
+
153
+ @abstractmethod
154
+ def get_or_create_collection(
155
+ self,
156
+ name: str,
157
+ metadata: Optional[CollectionMetadata] = None,
158
+ embedding_function: Optional[
159
+ EmbeddingFunction[Embeddable]
160
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
161
+ data_loader: Optional[DataLoader[Loadable]] = None,
162
+ ) -> Collection:
163
+ """Get or create a collection with the given name and metadata.
164
+ Args:
165
+ name: The name of the collection to get or create
166
+ metadata: Optional metadata to associate with the collection. If
167
+ the collection already exists, the metadata will be updated if
168
+ provided and not None. If the collection does not exist, the
169
+ new collection will be created with the provided metadata.
170
+ embedding_function: Optional function to use to embed documents
171
+ data_loader: Optional function to use to load records (documents, images, etc.)
172
+
173
+ Returns:
174
+ The collection
175
+
176
+ Examples:
177
+ ```python
178
+ client.get_or_create_collection("my_collection")
179
+ # collection(name="my_collection", metadata={})
180
+ ```
181
+ """
182
+ pass
183
+
184
+ def _modify(
185
+ self,
186
+ id: UUID,
187
+ new_name: Optional[str] = None,
188
+ new_metadata: Optional[CollectionMetadata] = None,
189
+ ) -> None:
190
+ """[Internal] Modify a collection by UUID. Can update the name and/or metadata.
191
+
192
+ Args:
193
+ id: The internal UUID of the collection to modify.
194
+ new_name: The new name of the collection.
195
+ If None, the existing name will remain. Defaults to None.
196
+ new_metadata: The new metadata to associate with the collection.
197
+ Defaults to None.
198
+ """
199
+ pass
200
+
201
+ @abstractmethod
202
+ def delete_collection(
203
+ self,
204
+ name: str,
205
+ ) -> None:
206
+ """Delete a collection with the given name.
207
+ Args:
208
+ name: The name of the collection to delete.
209
+
210
+ Raises:
211
+ ValueError: If the collection does not exist.
212
+
213
+ Examples:
214
+ ```python
215
+ client.delete_collection("my_collection")
216
+ ```
217
+ """
218
+ pass
219
+
220
+ #
221
+ # ITEM METHODS
222
+ #
223
+
224
+ @abstractmethod
225
+ def _add(
226
+ self,
227
+ ids: IDs,
228
+ collection_id: UUID,
229
+ embeddings: Embeddings,
230
+ metadatas: Optional[Metadatas] = None,
231
+ documents: Optional[Documents] = None,
232
+ uris: Optional[URIs] = None,
233
+ ) -> bool:
234
+ """[Internal] Add embeddings to a collection specified by UUID.
235
+ If (some) ids already exist, only the new embeddings will be added.
236
+
237
+ Args:
238
+ ids: The ids to associate with the embeddings.
239
+ collection_id: The UUID of the collection to add the embeddings to.
240
+ embeddings: The sequence of embeddings to add.
241
+ metadatas: The metadata to associate with the embeddings. Defaults to None.
242
+ documents: The documents to associate with the embeddings. Defaults to None.
243
+ uris: URIs of data sources for each embedding. Defaults to None.
244
+
245
+ Returns:
246
+ True if the embeddings were added successfully.
247
+ """
248
+ pass
249
+
250
+ @abstractmethod
251
+ def _update(
252
+ self,
253
+ collection_id: UUID,
254
+ ids: IDs,
255
+ embeddings: Optional[Embeddings] = None,
256
+ metadatas: Optional[Metadatas] = None,
257
+ documents: Optional[Documents] = None,
258
+ uris: Optional[URIs] = None,
259
+ ) -> bool:
260
+ """[Internal] Update entries in a collection specified by UUID.
261
+
262
+ Args:
263
+ collection_id: The UUID of the collection to update the embeddings in.
264
+ ids: The IDs of the entries to update.
265
+ embeddings: The sequence of embeddings to update. Defaults to None.
266
+ metadatas: The metadata to associate with the embeddings. Defaults to None.
267
+ documents: The documents to associate with the embeddings. Defaults to None.
268
+ uris: URIs of data sources for each embedding. Defaults to None.
269
+ Returns:
270
+ True if the embeddings were updated successfully.
271
+ """
272
+ pass
273
+
274
+ @abstractmethod
275
+ def _upsert(
276
+ self,
277
+ collection_id: UUID,
278
+ ids: IDs,
279
+ embeddings: Embeddings,
280
+ metadatas: Optional[Metadatas] = None,
281
+ documents: Optional[Documents] = None,
282
+ uris: Optional[URIs] = None,
283
+ ) -> bool:
284
+ """[Internal] Add or update entries in a collection specified by UUID.
285
+ If an entry with the same id already exists, it will be updated,
286
+ otherwise it will be added.
287
+
288
+ Args:
289
+ collection_id: The collection to add the embeddings to
290
+ ids: The ids to associate with the embeddings.
291
+ embeddings: The sequence of embeddings to add
292
+ metadatas: The metadata to associate with the embeddings. Defaults to None.
293
+ documents: The documents to associate with the embeddings. Defaults to None.
294
+ uris: URIs of data sources for each embedding. Defaults to None.
295
+ """
296
+ pass
297
+
298
+ @abstractmethod
299
+ def _count(self, collection_id: UUID) -> int:
300
+ """[Internal] Returns the number of entries in a collection specified by UUID.
301
+
302
+ Args:
303
+ collection_id: The UUID of the collection to count the embeddings in.
304
+
305
+ Returns:
306
+ int: The number of embeddings in the collection
307
+
308
+ """
309
+ pass
310
+
311
+ @abstractmethod
312
+ def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
313
+ """[Internal] Returns the first n entries in a collection specified by UUID.
314
+
315
+ Args:
316
+ collection_id: The UUID of the collection to peek into.
317
+ n: The number of entries to peek. Defaults to 10.
318
+
319
+ Returns:
320
+ GetResult: The first n entries in the collection.
321
+
322
+ """
323
+
324
+ pass
325
+
326
+ @abstractmethod
327
+ def _get(
328
+ self,
329
+ collection_id: UUID,
330
+ ids: Optional[IDs] = None,
331
+ where: Optional[Where] = {},
332
+ sort: Optional[str] = None,
333
+ limit: Optional[int] = None,
334
+ offset: Optional[int] = None,
335
+ page: Optional[int] = None,
336
+ page_size: Optional[int] = None,
337
+ where_document: Optional[WhereDocument] = {},
338
+ include: Include = ["embeddings", "metadatas", "documents"],
339
+ ) -> GetResult:
340
+ """[Internal] Returns entries from a collection specified by UUID.
341
+
342
+ Args:
343
+ collection_id: The UUID of the collection to get entries from.
+ ids: The IDs of the entries to get. Defaults to None.
344
+ where: Conditional filtering on metadata. Defaults to {}.
345
+ sort: The column to sort the entries by. Defaults to None.
346
+ limit: The maximum number of entries to return. Defaults to None.
347
+ offset: The number of entries to skip before returning. Defaults to None.
348
+ page: The page number to return. Defaults to None.
349
+ page_size: The number of entries to return per page. Defaults to None.
350
+ where_document: Conditional filtering on documents. Defaults to {}.
351
+ include: The fields to include in the response.
352
+ Defaults to ["embeddings", "metadatas", "documents"].
353
+ Returns:
354
+ GetResult: The entries in the collection that match the query.
355
+
356
+ """
357
+ pass
358
+
359
+ @abstractmethod
360
+ def _delete(
361
+ self,
362
+ collection_id: UUID,
363
+ ids: Optional[IDs],
364
+ where: Optional[Where] = {},
365
+ where_document: Optional[WhereDocument] = {},
366
+ ) -> IDs:
367
+ """[Internal] Deletes entries from a collection specified by UUID.
368
+
369
+ Args:
370
+ collection_id: The UUID of the collection to delete the entries from.
371
+ ids: The IDs of the entries to delete. Defaults to None.
372
+ where: Conditional filtering on metadata. Defaults to {}.
373
+ where_document: Conditional filtering on documents. Defaults to {}.
374
+
375
+ Returns:
376
+ IDs: The list of IDs of the entries that were deleted.
377
+ """
378
+ pass
379
+
380
+ @abstractmethod
381
+ def _query(
382
+ self,
383
+ collection_id: UUID,
384
+ query_embeddings: Embeddings,
385
+ n_results: int = 10,
386
+ where: Where = {},
387
+ where_document: WhereDocument = {},
388
+ include: Include = ["embeddings", "metadatas", "documents", "distances"],
389
+ ) -> QueryResult:
390
+ """[Internal] Performs a nearest neighbors query on a collection specified by UUID.
391
+
392
+ Args:
393
+ collection_id: The UUID of the collection to query.
394
+ query_embeddings: The embeddings to use as the query.
395
+ n_results: The number of results to return. Defaults to 10.
396
+ where: Conditional filtering on metadata. Defaults to {}.
397
+ where_document: Conditional filtering on documents. Defaults to {}.
398
+ include: The fields to include in the response.
399
+ Defaults to ["embeddings", "metadatas", "documents", "distances"].
400
+
401
+ Returns:
402
+ QueryResult: The results of the query.
403
+ """
404
+ pass
405
+
406
+ @abstractmethod
407
+ def reset(self) -> bool:
408
+ """Resets the database. This will delete all collections and entries.
409
+
410
+ Returns:
411
+ bool: True if the database was reset successfully.
412
+ """
413
+ pass
414
+
415
+ @abstractmethod
416
+ def get_version(self) -> str:
417
+ """Get the version of Chroma.
418
+
419
+ Returns:
420
+ str: The version of Chroma
421
+
422
+ """
423
+ pass
424
+
425
+ @abstractmethod
426
+ def get_settings(self) -> Settings:
427
+ """Get the settings used to initialize.
428
+
429
+ Returns:
430
+ Settings: The settings used to initialize.
431
+
432
+ """
433
+ pass
434
+
435
+ @property
436
+ @abstractmethod
437
+ def max_batch_size(self) -> int:
438
+ """Return the maximum number of records that can be submitted in a single call
439
+ to submit_embeddings."""
440
+ pass
441
+
442
+
443
+ class ClientAPI(BaseAPI, ABC):
444
+ tenant: str
445
+ database: str
446
+
447
+ @abstractmethod
448
+ def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
449
+ """Set the tenant and database for the client. Raises an error if the tenant or
450
+ database does not exist.
451
+
452
+ Args:
453
+ tenant: The tenant to set.
454
+ database: The database to set.
455
+
456
+ """
457
+ pass
458
+
459
+ @abstractmethod
460
+ def set_database(self, database: str) -> None:
461
+ """Set the database for the client. Raises an error if the database does not exist.
462
+
463
+ Args:
464
+ database: The database to set.
465
+
466
+ """
467
+ pass
468
+
469
+ @staticmethod
470
+ @abstractmethod
471
+ def clear_system_cache() -> None:
472
+ """Clear the system cache so that new systems can be created for an existing path.
473
+ This should only be used for testing purposes."""
474
+ pass
475
+
476
+
477
+ class AdminAPI(ABC):
478
+ @abstractmethod
479
+ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
480
+ """Create a new database. Raises an error if the database already exists.
481
+
482
+ Args:
483
+ name: The name of the database to create.
484
+
485
+ """
486
+ pass
487
+
488
+ @abstractmethod
489
+ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
490
+ """Get a database. Raises an error if the database does not exist.
491
+
492
+ Args:
493
+ name: The name of the database to get.
494
+ tenant: The tenant of the database to get.
495
+
496
+ """
497
+ pass
498
+
499
+ @abstractmethod
500
+ def create_tenant(self, name: str) -> None:
501
+ """Create a new tenant. Raises an error if the tenant already exists.
502
+
503
+ Args:
504
+ name: The name of the tenant to create.
505
+
506
+ """
507
+ pass
508
+
509
+ @abstractmethod
510
+ def get_tenant(self, name: str) -> Tenant:
511
+ """Get a tenant. Raises an error if the tenant does not exist.
512
+
513
+ Args:
514
+ name: The name of the tenant to get.
515
+
516
+ """
517
+ pass
518
+
519
+
520
+ class ServerAPI(BaseAPI, AdminAPI, Component):
521
+ """An API instance that extends the relevant Base API methods by passing
522
+ in a tenant and database. This is the root component of the Chroma System"""
523
+
524
+ @abstractmethod
525
+ @override
526
+ def list_collections(
527
+ self,
528
+ limit: Optional[int] = None,
529
+ offset: Optional[int] = None,
530
+ tenant: str = DEFAULT_TENANT,
531
+ database: str = DEFAULT_DATABASE,
532
+ ) -> Sequence[Collection]:
533
+ pass
534
+
535
+ @abstractmethod
536
+ @override
537
+ def count_collections(
538
+ self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
539
+ ) -> int:
540
+ pass
541
+
542
+ @abstractmethod
543
+ @override
544
+ def create_collection(
545
+ self,
546
+ name: str,
547
+ metadata: Optional[CollectionMetadata] = None,
548
+ embedding_function: Optional[
549
+ EmbeddingFunction[Embeddable]
550
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
551
+ data_loader: Optional[DataLoader[Loadable]] = None,
552
+ get_or_create: bool = False,
553
+ tenant: str = DEFAULT_TENANT,
554
+ database: str = DEFAULT_DATABASE,
555
+ ) -> Collection:
556
+ pass
557
+
558
+ @abstractmethod
559
+ @override
560
+ def get_collection(
561
+ self,
562
+ name: str,
563
+ id: Optional[UUID] = None,
564
+ embedding_function: Optional[
565
+ EmbeddingFunction[Embeddable]
566
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
567
+ data_loader: Optional[DataLoader[Loadable]] = None,
568
+ tenant: str = DEFAULT_TENANT,
569
+ database: str = DEFAULT_DATABASE,
570
+ ) -> Collection:
571
+ pass
572
+
573
+ @abstractmethod
574
+ @override
575
+ def get_or_create_collection(
576
+ self,
577
+ name: str,
578
+ metadata: Optional[CollectionMetadata] = None,
579
+ embedding_function: Optional[
580
+ EmbeddingFunction[Embeddable]
581
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
582
+ data_loader: Optional[DataLoader[Loadable]] = None,
583
+ tenant: str = DEFAULT_TENANT,
584
+ database: str = DEFAULT_DATABASE,
585
+ ) -> Collection:
586
+ pass
587
+
588
+ @abstractmethod
589
+ @override
590
+ def delete_collection(
591
+ self,
592
+ name: str,
593
+ tenant: str = DEFAULT_TENANT,
594
+ database: str = DEFAULT_DATABASE,
595
+ ) -> None:
596
+ pass
chromadb/api/client.py ADDED
@@ -0,0 +1,496 @@
1
+ from typing import ClassVar, Dict, Optional, Sequence
2
+ from uuid import UUID
3
+ import uuid
4
+
5
+ from overrides import override
6
+ import requests
7
+ from chromadb.api import AdminAPI, ClientAPI, ServerAPI
8
+ from chromadb.api.types import (
9
+ CollectionMetadata,
10
+ DataLoader,
11
+ Documents,
12
+ Embeddable,
13
+ EmbeddingFunction,
14
+ Embeddings,
15
+ GetResult,
16
+ IDs,
17
+ Include,
18
+ Loadable,
19
+ Metadatas,
20
+ QueryResult,
21
+ URIs,
22
+ )
23
+ from chromadb.config import Settings, System
24
+ from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
25
+ from chromadb.api.models.Collection import Collection
26
+ from chromadb.errors import ChromaError
27
+ from chromadb.telemetry.product import ProductTelemetryClient
28
+ from chromadb.telemetry.product.events import ClientStartEvent
29
+ from chromadb.types import Database, Tenant, Where, WhereDocument
30
+ import chromadb.utils.embedding_functions as ef
31
+
32
+
33
+ class SharedSystemClient:
34
+ _identifer_to_system: ClassVar[Dict[str, System]] = {}
35
+ _identifier: str
36
+
37
+ # region Initialization
38
+ def __init__(
39
+ self,
40
+ settings: Settings = Settings(),
41
+ ) -> None:
42
+ self._identifier = SharedSystemClient._get_identifier_from_settings(settings)
43
+ SharedSystemClient._create_system_if_not_exists(self._identifier, settings)
44
+
45
+ @classmethod
46
+ def _create_system_if_not_exists(
47
+ cls, identifier: str, settings: Settings
48
+ ) -> System:
49
+ if identifier not in cls._identifer_to_system:
50
+ new_system = System(settings)
51
+ cls._identifer_to_system[identifier] = new_system
52
+
53
+ new_system.instance(ProductTelemetryClient)
54
+ new_system.instance(ServerAPI)
55
+
56
+ new_system.start()
57
+ else:
58
+ previous_system = cls._identifer_to_system[identifier]
59
+
60
+ # For now, the settings must match
61
+ if previous_system.settings != settings:
62
+ raise ValueError(
63
+ f"An instance of Chroma already exists for {identifier} with different settings"
64
+ )
65
+
66
+ return cls._identifer_to_system[identifier]
67
+
68
+ @staticmethod
69
+ def _get_identifier_from_settings(settings: Settings) -> str:
70
+ identifier = ""
71
+ api_impl = settings.chroma_api_impl
72
+
73
+ if api_impl is None:
74
+ raise ValueError("Chroma API implementation must be set in settings")
75
+ elif api_impl == "chromadb.api.segment.SegmentAPI":
76
+ if settings.is_persistent:
77
+ identifier = settings.persist_directory
78
+ else:
79
+ identifier = (
80
+ "ephemeral" # TODO: support pathing and multiple ephemeral clients
81
+ )
82
+ elif api_impl == "chromadb.api.fastapi.FastAPI":
83
+ # FastAPI clients can all use unique system identifiers since their configurations can be independent, e.g. different auth tokens
84
+ identifier = str(uuid.uuid4())
85
+ else:
86
+ raise ValueError(f"Unsupported Chroma API implementation {api_impl}")
87
+
88
+ return identifier
89
+
90
+ @staticmethod
91
+ def _populate_data_from_system(system: System) -> str:
92
+ identifier = SharedSystemClient._get_identifier_from_settings(system.settings)
93
+ SharedSystemClient._identifer_to_system[identifier] = system
94
+ return identifier
95
+
96
+ @classmethod
97
+ def from_system(cls, system: System) -> "SharedSystemClient":
98
+ """Create a client from an existing system. This is useful for testing and debugging."""
99
+
100
+ SharedSystemClient._populate_data_from_system(system)
101
+ instance = cls(system.settings)
102
+ return instance
103
+
104
+ @staticmethod
105
+ def clear_system_cache() -> None:
106
+ SharedSystemClient._identifer_to_system = {}
107
+
108
+ @property
109
+ def _system(self) -> System:
110
+ return SharedSystemClient._identifer_to_system[self._identifier]
111
+
112
+ # endregion
113
+
114
+
115
+ class Client(SharedSystemClient, ClientAPI):
116
+ """A client for Chroma. This is the main entrypoint for interacting with Chroma.
117
+ A client internally stores its tenant and database and proxies calls to a
118
+ Server API instance of Chroma. It treats the Server API and corresponding System
119
+ as a singleton, so multiple clients connecting to the same resource will share the
120
+ same API instance.
121
+
122
+ Client implementations should implement their own API-caching strategies.
123
+ """
124
+
125
+ tenant: str = DEFAULT_TENANT
126
+ database: str = DEFAULT_DATABASE
127
+
128
+ _server: ServerAPI
129
+ # An internal admin client for verifying that databases and tenants exist
130
+ _admin_client: AdminAPI
131
+
132
+ # region Initialization
133
+ def __init__(
134
+ self,
135
+ tenant: str = DEFAULT_TENANT,
136
+ database: str = DEFAULT_DATABASE,
137
+ settings: Settings = Settings(),
138
+ ) -> None:
139
+ super().__init__(settings=settings)
140
+ self.tenant = tenant
141
+ self.database = database
142
+ # Create an admin client for verifying that databases and tenants exist
143
+ self._admin_client = AdminClient.from_system(self._system)
144
+ self._validate_tenant_database(tenant=tenant, database=database)
145
+
146
+ # Get the root system component we want to interact with
147
+ self._server = self._system.instance(ServerAPI)
148
+
149
+ # Submit event for a client start
150
+ telemetry_client = self._system.instance(ProductTelemetryClient)
151
+ telemetry_client.capture(ClientStartEvent())
152
+
153
+ @classmethod
154
+ @override
155
+ def from_system(
156
+ cls,
157
+ system: System,
158
+ tenant: str = DEFAULT_TENANT,
159
+ database: str = DEFAULT_DATABASE,
160
+ ) -> "Client":
161
+ SharedSystemClient._populate_data_from_system(system)
162
+ instance = cls(tenant=tenant, database=database, settings=system.settings)
163
+ return instance
164
+
165
+ # endregion
166
+
167
+ # region BaseAPI Methods
168
+ # Note - we could do this in less verbose ways, but they break type checking
169
+ @override
170
+ def heartbeat(self) -> int:
171
+ return self._server.heartbeat()
172
+
173
+ @override
174
+ def list_collections(
175
+ self, limit: Optional[int] = None, offset: Optional[int] = None
176
+ ) -> Sequence[Collection]:
177
+ return self._server.list_collections(
178
+ limit, offset, tenant=self.tenant, database=self.database
179
+ )
180
+
181
+ @override
182
+ def count_collections(self) -> int:
183
+ return self._server.count_collections(
184
+ tenant=self.tenant, database=self.database
185
+ )
186
+
187
+ @override
188
+ def create_collection(
189
+ self,
190
+ name: str,
191
+ metadata: Optional[CollectionMetadata] = None,
192
+ embedding_function: Optional[
193
+ EmbeddingFunction[Embeddable]
194
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
195
+ data_loader: Optional[DataLoader[Loadable]] = None,
196
+ get_or_create: bool = False,
197
+ ) -> Collection:
198
+ return self._server.create_collection(
199
+ name=name,
200
+ metadata=metadata,
201
+ embedding_function=embedding_function,
202
+ data_loader=data_loader,
203
+ tenant=self.tenant,
204
+ database=self.database,
205
+ get_or_create=get_or_create,
206
+ )
207
+
208
+ @override
209
+ def get_collection(
210
+ self,
211
+ name: str,
212
+ id: Optional[UUID] = None,
213
+ embedding_function: Optional[
214
+ EmbeddingFunction[Embeddable]
215
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
216
+ data_loader: Optional[DataLoader[Loadable]] = None,
217
+ ) -> Collection:
218
+ return self._server.get_collection(
219
+ id=id,
220
+ name=name,
221
+ embedding_function=embedding_function,
222
+ data_loader=data_loader,
223
+ tenant=self.tenant,
224
+ database=self.database,
225
+ )
226
+
227
+ @override
228
+ def get_or_create_collection(
229
+ self,
230
+ name: str,
231
+ metadata: Optional[CollectionMetadata] = None,
232
+ embedding_function: Optional[
233
+ EmbeddingFunction[Embeddable]
234
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
235
+ data_loader: Optional[DataLoader[Loadable]] = None,
236
+ ) -> Collection:
237
+ return self._server.get_or_create_collection(
238
+ name=name,
239
+ metadata=metadata,
240
+ embedding_function=embedding_function,
241
+ data_loader=data_loader,
242
+ tenant=self.tenant,
243
+ database=self.database,
244
+ )
245
+
246
+ @override
247
+ def _modify(
248
+ self,
249
+ id: UUID,
250
+ new_name: Optional[str] = None,
251
+ new_metadata: Optional[CollectionMetadata] = None,
252
+ ) -> None:
253
+ return self._server._modify(
254
+ id=id,
255
+ new_name=new_name,
256
+ new_metadata=new_metadata,
257
+ )
258
+
259
+ @override
260
+ def delete_collection(
261
+ self,
262
+ name: str,
263
+ ) -> None:
264
+ return self._server.delete_collection(
265
+ name=name,
266
+ tenant=self.tenant,
267
+ database=self.database,
268
+ )
269
+
270
+ #
271
+ # ITEM METHODS
272
+ #
273
+
274
+ @override
275
+ def _add(
276
+ self,
277
+ ids: IDs,
278
+ collection_id: UUID,
279
+ embeddings: Embeddings,
280
+ metadatas: Optional[Metadatas] = None,
281
+ documents: Optional[Documents] = None,
282
+ uris: Optional[URIs] = None,
283
+ ) -> bool:
284
+ return self._server._add(
285
+ ids=ids,
286
+ collection_id=collection_id,
287
+ embeddings=embeddings,
288
+ metadatas=metadatas,
289
+ documents=documents,
290
+ uris=uris,
291
+ )
292
+
293
+ @override
294
+ def _update(
295
+ self,
296
+ collection_id: UUID,
297
+ ids: IDs,
298
+ embeddings: Optional[Embeddings] = None,
299
+ metadatas: Optional[Metadatas] = None,
300
+ documents: Optional[Documents] = None,
301
+ uris: Optional[URIs] = None,
302
+ ) -> bool:
303
+ return self._server._update(
304
+ collection_id=collection_id,
305
+ ids=ids,
306
+ embeddings=embeddings,
307
+ metadatas=metadatas,
308
+ documents=documents,
309
+ uris=uris,
310
+ )
311
+
312
+ @override
313
+ def _upsert(
314
+ self,
315
+ collection_id: UUID,
316
+ ids: IDs,
317
+ embeddings: Embeddings,
318
+ metadatas: Optional[Metadatas] = None,
319
+ documents: Optional[Documents] = None,
320
+ uris: Optional[URIs] = None,
321
+ ) -> bool:
322
+ return self._server._upsert(
323
+ collection_id=collection_id,
324
+ ids=ids,
325
+ embeddings=embeddings,
326
+ metadatas=metadatas,
327
+ documents=documents,
328
+ uris=uris,
329
+ )
330
+
331
+ @override
332
+ def _count(self, collection_id: UUID) -> int:
333
+ return self._server._count(
334
+ collection_id=collection_id,
335
+ )
336
+
337
+ @override
338
+ def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
339
+ return self._server._peek(
340
+ collection_id=collection_id,
341
+ n=n,
342
+ )
343
+
344
+ @override
345
+ def _get(
346
+ self,
347
+ collection_id: UUID,
348
+ ids: Optional[IDs] = None,
349
+ where: Optional[Where] = {},
350
+ sort: Optional[str] = None,
351
+ limit: Optional[int] = None,
352
+ offset: Optional[int] = None,
353
+ page: Optional[int] = None,
354
+ page_size: Optional[int] = None,
355
+ where_document: Optional[WhereDocument] = {},
356
+ include: Include = ["embeddings", "metadatas", "documents"],
357
+ ) -> GetResult:
358
+ return self._server._get(
359
+ collection_id=collection_id,
360
+ ids=ids,
361
+ where=where,
362
+ sort=sort,
363
+ limit=limit,
364
+ offset=offset,
365
+ page=page,
366
+ page_size=page_size,
367
+ where_document=where_document,
368
+ include=include,
369
+ )
370
+
371
+ @override
+ def _delete(
372
+ self,
373
+ collection_id: UUID,
374
+ ids: Optional[IDs],
375
+ where: Optional[Where] = {},
376
+ where_document: Optional[WhereDocument] = {},
377
+ ) -> IDs:
378
+ return self._server._delete(
379
+ collection_id=collection_id,
380
+ ids=ids,
381
+ where=where,
382
+ where_document=where_document,
383
+ )
384
+
385
+ @override
386
+ def _query(
387
+ self,
388
+ collection_id: UUID,
389
+ query_embeddings: Embeddings,
390
+ n_results: int = 10,
391
+ where: Where = {},
392
+ where_document: WhereDocument = {},
393
+ include: Include = ["embeddings", "metadatas", "documents", "distances"],
394
+ ) -> QueryResult:
395
+ return self._server._query(
396
+ collection_id=collection_id,
397
+ query_embeddings=query_embeddings,
398
+ n_results=n_results,
399
+ where=where,
400
+ where_document=where_document,
401
+ include=include,
402
+ )
403
+
404
+ @override
405
+ def reset(self) -> bool:
406
+ return self._server.reset()
407
+
408
+ @override
409
+ def get_version(self) -> str:
410
+ return self._server.get_version()
411
+
412
+ @override
413
+ def get_settings(self) -> Settings:
414
+ return self._server.get_settings()
415
+
416
+ @property
417
+ @override
418
+ def max_batch_size(self) -> int:
419
+ return self._server.max_batch_size
420
+
421
+ # endregion
422
+
423
+ # region ClientAPI Methods
424
+
425
+ @override
426
+ def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
427
+ self._validate_tenant_database(tenant=tenant, database=database)
428
+ self.tenant = tenant
429
+ self.database = database
430
+
431
+ @override
432
+ def set_database(self, database: str) -> None:
433
+ self._validate_tenant_database(tenant=self.tenant, database=database)
434
+ self.database = database
435
+
436
+ def _validate_tenant_database(self, tenant: str, database: str) -> None:
437
+ try:
438
+ self._admin_client.get_tenant(name=tenant)
439
+ except requests.exceptions.ConnectionError:
440
+ raise ValueError(
441
+ "Could not connect to a Chroma server. Are you sure it is running?"
442
+ )
443
+ # Propagate ChromaErrors
444
+ except ChromaError as e:
445
+ raise e
446
+ except Exception:
447
+ raise ValueError(
448
+ f"Could not connect to tenant {tenant}. Are you sure it exists?"
449
+ )
450
+
451
+ try:
452
+ self._admin_client.get_database(name=database, tenant=tenant)
453
+ except requests.exceptions.ConnectionError:
454
+ raise ValueError(
455
+ "Could not connect to a Chroma server. Are you sure it is running?"
456
+ )
457
+ except Exception:
458
+ raise ValueError(
459
+ f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?"
460
+ )
461
+
462
+ # endregion
463
+
464
+
465
+ class AdminClient(SharedSystemClient, AdminAPI):
466
+ _server: ServerAPI
467
+
468
+ def __init__(self, settings: Settings = Settings()) -> None:
469
+ super().__init__(settings)
470
+ self._server = self._system.instance(ServerAPI)
471
+
472
+ @override
473
+ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
474
+ return self._server.create_database(name=name, tenant=tenant)
475
+
476
+ @override
477
+ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
478
+ return self._server.get_database(name=name, tenant=tenant)
479
+
480
+ @override
481
+ def create_tenant(self, name: str) -> None:
482
+ return self._server.create_tenant(name=name)
483
+
484
+ @override
485
+ def get_tenant(self, name: str) -> Tenant:
486
+ return self._server.get_tenant(name=name)
487
+
488
+ @classmethod
489
+ @override
490
+ def from_system(
491
+ cls,
492
+ system: System,
493
+ ) -> "AdminClient":
494
+ SharedSystemClient._populate_data_from_system(system)
495
+ instance = cls(settings=system.settings)
496
+ return instance
chromadb/api/fastapi.py ADDED
@@ -0,0 +1,654 @@
1
+ import json
2
+ import logging
3
+ from typing import Optional, cast, Tuple
4
+ from typing import Sequence
5
+ from uuid import UUID
6
+
7
+ import requests
8
+ from overrides import override
9
+
10
+ import chromadb.errors as errors
11
+ from chromadb.types import Database, Tenant
12
+ import chromadb.utils.embedding_functions as ef
13
+ from chromadb.api import ServerAPI
14
+ from chromadb.api.models.Collection import Collection
15
+ from chromadb.api.types import (
16
+ DataLoader,
17
+ Documents,
18
+ Embeddable,
19
+ Embeddings,
20
+ EmbeddingFunction,
21
+ IDs,
22
+ Include,
23
+ Loadable,
24
+ Metadatas,
25
+ URIs,
26
+ Where,
27
+ WhereDocument,
28
+ GetResult,
29
+ QueryResult,
30
+ CollectionMetadata,
31
+ validate_batch,
32
+ )
33
+ from chromadb.auth import (
34
+ ClientAuthProvider,
35
+ )
36
+ from chromadb.auth.providers import RequestsClientAuthProtocolAdapter
37
+ from chromadb.auth.registry import resolve_provider
38
+ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
39
+ from chromadb.telemetry.opentelemetry import (
40
+ OpenTelemetryClient,
41
+ OpenTelemetryGranularity,
42
+ trace_method,
43
+ )
44
+ from chromadb.telemetry.product import ProductTelemetryClient
45
+ from urllib.parse import urlparse, urlunparse, quote
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+ class FastAPI(ServerAPI):
51
+ _settings: Settings
52
+ _max_batch_size: int = -1
53
+
54
+ @staticmethod
55
+ def _validate_host(host: str) -> None:
56
+ parsed = urlparse(host)
57
+ if "/" in host and parsed.scheme not in {"http", "https"}:
58
+ raise ValueError(
59
+ "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
60
+ )
61
+ if "/" in host and (not host.startswith("http")):
62
+ raise ValueError(
63
+ "Invalid URL. "
64
+ "Seems that you are trying to pass URL as a host but without \
65
+ specifying the protocol. "
66
+ "Please add http:// or https:// to the host."
67
+ )
68
+
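+ # e.g. "localhost" and "http://example.com" pass the validation above,
+ # while "example.com/api" raises because it contains a path but no
+ # recognized protocol.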
69
+ @staticmethod
70
+ def resolve_url(
71
+ chroma_server_host: str,
72
+ chroma_server_ssl_enabled: Optional[bool] = False,
73
+ default_api_path: Optional[str] = "",
74
+ chroma_server_http_port: Optional[int] = 8000,
75
+ ) -> str:
76
+ _skip_port = False
77
+ _chroma_server_host = chroma_server_host
78
+ FastAPI._validate_host(_chroma_server_host)
79
+ if _chroma_server_host.startswith("http"):
80
+ logger.debug("Skipping port as the user is passing a full URL")
81
+ _skip_port = True
82
+ parsed = urlparse(_chroma_server_host)
83
+
84
+ scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
85
+ net_loc = parsed.netloc or parsed.hostname or chroma_server_host
86
+ port = (
87
+ ":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
88
+ )
89
+ path = parsed.path or default_api_path
90
+
91
+ if not path or path == net_loc:
92
+ path = default_api_path if default_api_path else ""
93
+ if not path.endswith(default_api_path or ""):
94
+ path = path + default_api_path if default_api_path else ""
95
+ full_url = urlunparse(
96
+ (scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
97
+ )
98
+
99
+ return full_url
100
+
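+ # Worked examples of the resolver above, assuming the defaults shown
+ # (port 8000, empty default_api_path):
+ #   resolve_url("localhost")                    -> "http://localhost:8000"
+ #   resolve_url("localhost", chroma_server_ssl_enabled=True)
+ #                                               -> "https://localhost:8000"
+ #   resolve_url("http://localhost:9000/api/v1") -> "http://localhost:9000/api/v1"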
101
+ def __init__(self, system: System):
102
+ super().__init__(system)
103
+ system.settings.require("chroma_server_host")
104
+ system.settings.require("chroma_server_http_port")
105
+
106
+ self._opentelemetry_client = self.require(OpenTelemetryClient)
107
+ self._product_telemetry_client = self.require(ProductTelemetryClient)
108
+ self._settings = system.settings
109
+
110
+ self._api_url = FastAPI.resolve_url(
111
+ chroma_server_host=str(system.settings.chroma_server_host),
112
+ chroma_server_http_port=int(str(system.settings.chroma_server_http_port)),
113
+ chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
114
+ default_api_path=system.settings.chroma_server_api_default_path,
115
+ )
116
+
117
+ self._header = system.settings.chroma_server_headers
118
+ if (
119
+ system.settings.chroma_client_auth_provider
120
+ and system.settings.chroma_client_auth_protocol_adapter
121
+ ):
122
+ self._auth_provider = self.require(
123
+ resolve_provider(
124
+ system.settings.chroma_client_auth_provider, ClientAuthProvider
125
+ )
126
+ )
127
+ self._adapter = cast(
128
+ RequestsClientAuthProtocolAdapter,
129
+ system.require(
130
+ resolve_provider(
131
+ system.settings.chroma_client_auth_protocol_adapter,
132
+ RequestsClientAuthProtocolAdapter,
133
+ )
134
+ ),
135
+ )
136
+ self._session = self._adapter.session
137
+ else:
138
+ self._session = requests.Session()
139
+ if self._header is not None:
140
+ self._session.headers.update(self._header)
141
+ if self._settings.chroma_server_ssl_verify is not None:
142
+ self._session.verify = self._settings.chroma_server_ssl_verify
143
+
144
+ @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
145
+ @override
146
+ def heartbeat(self) -> int:
147
+ """Returns the current server time in nanoseconds to check if the server is alive"""
148
+ resp = self._session.get(self._api_url)
149
+ raise_chroma_error(resp)
150
+ return int(resp.json()["nanosecond heartbeat"])
151
+
152
+ @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
153
+ @override
154
+ def create_database(
155
+ self,
156
+ name: str,
157
+ tenant: str = DEFAULT_TENANT,
158
+ ) -> None:
159
+ """Creates a database"""
160
+ resp = self._session.post(
161
+ self._api_url + "/databases",
162
+ data=json.dumps({"name": name}),
163
+ params={"tenant": tenant},
164
+ )
165
+ raise_chroma_error(resp)
166
+
167
+ @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
168
+ @override
169
+ def get_database(
170
+ self,
171
+ name: str,
172
+ tenant: str = DEFAULT_TENANT,
173
+ ) -> Database:
174
+ """Returns a database"""
175
+ resp = self._session.get(
176
+ self._api_url + "/databases/" + name,
177
+ params={"tenant": tenant},
178
+ )
179
+ raise_chroma_error(resp)
180
+ resp_json = resp.json()
181
+ return Database(
182
+ id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
183
+ )
184
+
185
+ @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
186
+ @override
187
+ def create_tenant(self, name: str) -> None:
188
+ resp = self._session.post(
189
+ self._api_url + "/tenants",
190
+ data=json.dumps({"name": name}),
191
+ )
192
+ raise_chroma_error(resp)
193
+
194
+ @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
195
+ @override
196
+ def get_tenant(self, name: str) -> Tenant:
197
+ resp = self._session.get(
198
+ self._api_url + "/tenants/" + name,
199
+ )
200
+ raise_chroma_error(resp)
201
+ resp_json = resp.json()
202
+ return Tenant(name=resp_json["name"])
203
+
204
+ @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
205
+ @override
206
+ def list_collections(
207
+ self,
208
+ limit: Optional[int] = None,
209
+ offset: Optional[int] = None,
210
+ tenant: str = DEFAULT_TENANT,
211
+ database: str = DEFAULT_DATABASE,
212
+ ) -> Sequence[Collection]:
213
+ """Returns a list of all collections"""
214
+ resp = self._session.get(
215
+ self._api_url + "/collections",
216
+ params={
217
+ "tenant": tenant,
218
+ "database": database,
219
+ "limit": limit,
220
+ "offset": offset,
221
+ },
222
+ )
223
+ raise_chroma_error(resp)
224
+ json_collections = resp.json()
225
+ collections = []
226
+ for json_collection in json_collections:
227
+ collections.append(Collection(self, **json_collection))
228
+
229
+ return collections
230
+
231
+ @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
232
+ @override
233
+ def count_collections(
234
+ self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
235
+ ) -> int:
236
+ """Returns a count of collections"""
237
+ resp = self._session.get(
238
+ self._api_url + "/count_collections",
239
+ params={"tenant": tenant, "database": database},
240
+ )
241
+ raise_chroma_error(resp)
242
+ return cast(int, resp.json())
243
+
244
+ @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
245
+ @override
246
+ def create_collection(
247
+ self,
248
+ name: str,
249
+ metadata: Optional[CollectionMetadata] = None,
250
+ embedding_function: Optional[
251
+ EmbeddingFunction[Embeddable]
252
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
253
+ data_loader: Optional[DataLoader[Loadable]] = None,
254
+ get_or_create: bool = False,
255
+ tenant: str = DEFAULT_TENANT,
256
+ database: str = DEFAULT_DATABASE,
257
+ ) -> Collection:
258
+ """Creates a collection"""
259
+ resp = self._session.post(
260
+ self._api_url + "/collections",
261
+ data=json.dumps(
262
+ {
263
+ "name": name,
264
+ "metadata": metadata,
265
+ "get_or_create": get_or_create,
266
+ }
267
+ ),
268
+ params={"tenant": tenant, "database": database},
269
+ )
270
+ raise_chroma_error(resp)
271
+ resp_json = resp.json()
272
+ return Collection(
273
+ client=self,
274
+ id=resp_json["id"],
275
+ name=resp_json["name"],
276
+ embedding_function=embedding_function,
277
+ data_loader=data_loader,
278
+ metadata=resp_json["metadata"],
279
+ )
280
+
281
+ @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
282
+ @override
283
+ def get_collection(
284
+ self,
285
+ name: str,
286
+ id: Optional[UUID] = None,
287
+ embedding_function: Optional[
288
+ EmbeddingFunction[Embeddable]
289
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
290
+ data_loader: Optional[DataLoader[Loadable]] = None,
291
+ tenant: str = DEFAULT_TENANT,
292
+ database: str = DEFAULT_DATABASE,
293
+ ) -> Collection:
294
+ """Returns a collection"""
295
+ if (name is None and id is None) or (name is not None and id is not None):
296
+ raise ValueError("Name or id must be specified, but not both")
297
+
298
+ _params = {"tenant": tenant, "database": database}
299
+ if id is not None:
300
+ _params["type"] = str(id)
301
+ resp = self._session.get(
302
+ self._api_url + "/collections/" + name if name else str(id), params=_params
303
+ )
304
+ raise_chroma_error(resp)
305
+ resp_json = resp.json()
306
+ return Collection(
307
+ client=self,
308
+ name=resp_json["name"],
309
+ id=resp_json["id"],
310
+ embedding_function=embedding_function,
311
+ data_loader=data_loader,
312
+ metadata=resp_json["metadata"],
313
+ )
314
+
315
+ @trace_method(
316
+ "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
317
+ )
318
+ @override
319
+ def get_or_create_collection(
320
+ self,
321
+ name: str,
322
+ metadata: Optional[CollectionMetadata] = None,
323
+ embedding_function: Optional[
324
+ EmbeddingFunction[Embeddable]
325
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
326
+ data_loader: Optional[DataLoader[Loadable]] = None,
327
+ tenant: str = DEFAULT_TENANT,
328
+ database: str = DEFAULT_DATABASE,
329
+ ) -> Collection:
330
+ return cast(
331
+ Collection,
332
+ self.create_collection(
333
+ name=name,
334
+ metadata=metadata,
335
+ embedding_function=embedding_function,
336
+ data_loader=data_loader,
337
+ get_or_create=True,
338
+ tenant=tenant,
339
+ database=database,
340
+ ),
341
+ )
342
+
343
+ @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
344
+ @override
345
+ def _modify(
346
+ self,
347
+ id: UUID,
348
+ new_name: Optional[str] = None,
349
+ new_metadata: Optional[CollectionMetadata] = None,
350
+ ) -> None:
351
+ """Updates a collection"""
352
+ resp = self._session.put(
353
+ self._api_url + "/collections/" + str(id),
354
+ data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
355
+ )
356
+ raise_chroma_error(resp)
357
+
358
+ @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
359
+ @override
360
+ def delete_collection(
361
+ self,
362
+ name: str,
363
+ tenant: str = DEFAULT_TENANT,
364
+ database: str = DEFAULT_DATABASE,
365
+ ) -> None:
366
+ """Deletes a collection"""
367
+ resp = self._session.delete(
368
+ self._api_url + "/collections/" + name,
369
+ params={"tenant": tenant, "database": database},
370
+ )
371
+ raise_chroma_error(resp)
372
+
373
+ @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
374
+ @override
375
+ def _count(
376
+ self,
377
+ collection_id: UUID,
378
+ ) -> int:
379
+ """Returns the number of embeddings in the database"""
380
+ resp = self._session.get(
381
+ self._api_url + "/collections/" + str(collection_id) + "/count"
382
+ )
383
+ raise_chroma_error(resp)
384
+ return cast(int, resp.json())
385
+
386
+ @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
387
+ @override
388
+ def _peek(
389
+ self,
390
+ collection_id: UUID,
391
+ n: int = 10,
392
+ ) -> GetResult:
393
+ return cast(
394
+ GetResult,
395
+ self._get(
396
+ collection_id,
397
+ limit=n,
398
+ include=["embeddings", "documents", "metadatas"],
399
+ ),
400
+ )
401
+
402
+ @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
403
+ @override
404
+ def _get(
405
+ self,
406
+ collection_id: UUID,
407
+ ids: Optional[IDs] = None,
408
+ where: Optional[Where] = {},
409
+ sort: Optional[str] = None,
410
+ limit: Optional[int] = None,
411
+ offset: Optional[int] = None,
412
+ page: Optional[int] = None,
413
+ page_size: Optional[int] = None,
414
+ where_document: Optional[WhereDocument] = {},
415
+ include: Include = ["metadatas", "documents"],
416
+ ) -> GetResult:
417
+ if page and page_size:
418
+ offset = (page - 1) * page_size
419
+ limit = page_size
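+ # e.g. page=3, page_size=20 -> offset=40, limit=20 (pages are 1-indexed)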
420
+
421
+ resp = self._session.post(
422
+ self._api_url + "/collections/" + str(collection_id) + "/get",
423
+ data=json.dumps(
424
+ {
425
+ "ids": ids,
426
+ "where": where,
427
+ "sort": sort,
428
+ "limit": limit,
429
+ "offset": offset,
430
+ "where_document": where_document,
431
+ "include": include,
432
+ }
433
+ ),
434
+ )
435
+
436
+ raise_chroma_error(resp)
437
+ body = resp.json()
438
+ return GetResult(
439
+ ids=body["ids"],
440
+ embeddings=body.get("embeddings", None),
441
+ metadatas=body.get("metadatas", None),
442
+ documents=body.get("documents", None),
443
+ data=None,
444
+ uris=body.get("uris", None),
445
+ )
446
+
447
+ @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
448
+ @override
449
+ def _delete(
450
+ self,
451
+ collection_id: UUID,
452
+ ids: Optional[IDs] = None,
453
+ where: Optional[Where] = {},
454
+ where_document: Optional[WhereDocument] = {},
455
+ ) -> IDs:
456
+ """Deletes embeddings from the database"""
457
+ resp = self._session.post(
458
+ self._api_url + "/collections/" + str(collection_id) + "/delete",
459
+ data=json.dumps(
460
+ {"where": where, "ids": ids, "where_document": where_document}
461
+ ),
462
+ )
463
+
464
+ raise_chroma_error(resp)
465
+ return cast(IDs, resp.json())
466
+
467
+ @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
468
+ def _submit_batch(
469
+ self,
470
+ batch: Tuple[
471
+ IDs,
472
+ Optional[Embeddings],
473
+ Optional[Metadatas],
474
+ Optional[Documents],
475
+ Optional[URIs],
476
+ ],
477
+ url: str,
478
+ ) -> requests.Response:
479
+ """
480
+ Submits a batch of embeddings to the database
481
+ """
482
+ resp = self._session.post(
483
+ self._api_url + url,
484
+ data=json.dumps(
485
+ {
486
+ "ids": batch[0],
487
+ "embeddings": batch[1],
488
+ "metadatas": batch[2],
489
+ "documents": batch[3],
490
+ "uris": batch[4],
491
+ }
492
+ ),
493
+ )
494
+ return resp
495
+
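+ # The batch tuple above is column-oriented; a hypothetical two-record batch:
+ #   (["id1", "id2"],            # ids
+ #    [[0.1, 0.2], [0.3, 0.4]],  # embeddings
+ #    [{"k": "a"}, {"k": "b"}],  # metadatas
+ #    ["doc one", "doc two"],    # documents
+ #    None)                      # uris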
496
+ @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
497
+ @override
498
+ def _add(
499
+ self,
500
+ ids: IDs,
501
+ collection_id: UUID,
502
+ embeddings: Embeddings,
503
+ metadatas: Optional[Metadatas] = None,
504
+ documents: Optional[Documents] = None,
505
+ uris: Optional[URIs] = None,
506
+ ) -> bool:
507
+ """
508
+ Adds a batch of embeddings to the database
509
+ - pass in column oriented data lists
510
+ """
511
+ batch = (ids, embeddings, metadatas, documents, uris)
512
+ validate_batch(batch, {"max_batch_size": self.max_batch_size})
513
+ resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
514
+ raise_chroma_error(resp)
515
+ return True
516
+
517
+ @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
518
+ @override
519
+ def _update(
520
+ self,
521
+ collection_id: UUID,
522
+ ids: IDs,
523
+ embeddings: Optional[Embeddings] = None,
524
+ metadatas: Optional[Metadatas] = None,
525
+ documents: Optional[Documents] = None,
526
+ uris: Optional[URIs] = None,
527
+ ) -> bool:
528
+ """
529
+ Updates a batch of embeddings in the database
530
+ - pass in column oriented data lists
531
+ """
532
+ batch = (ids, embeddings, metadatas, documents, uris)
533
+ validate_batch(batch, {"max_batch_size": self.max_batch_size})
534
+ resp = self._submit_batch(
535
+ batch, "/collections/" + str(collection_id) + "/update"
536
+ )
537
+ raise_chroma_error(resp)
538
+ return True
539
+
540
+ @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
541
+ @override
542
+ def _upsert(
543
+ self,
544
+ collection_id: UUID,
545
+ ids: IDs,
546
+ embeddings: Embeddings,
547
+ metadatas: Optional[Metadatas] = None,
548
+ documents: Optional[Documents] = None,
549
+ uris: Optional[URIs] = None,
550
+ ) -> bool:
551
+ """
552
+ Upserts a batch of embeddings in the database
553
+ - pass in column oriented data lists
554
+ """
555
+ batch = (ids, embeddings, metadatas, documents, uris)
556
+ validate_batch(batch, {"max_batch_size": self.max_batch_size})
557
+ resp = self._submit_batch(
558
+ batch, "/collections/" + str(collection_id) + "/upsert"
559
+ )
560
+ raise_chroma_error(resp)
561
+ return True
562
+
563
+ @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
564
+ @override
565
+ def _query(
566
+ self,
567
+ collection_id: UUID,
568
+ query_embeddings: Embeddings,
569
+ n_results: int = 10,
570
+ where: Optional[Where] = {},
571
+ where_document: Optional[WhereDocument] = {},
572
+ include: Include = ["metadatas", "documents", "distances"],
573
+ ) -> QueryResult:
574
+ """Gets the nearest neighbors of a single embedding"""
575
+ resp = self._session.post(
576
+ self._api_url + "/collections/" + str(collection_id) + "/query",
577
+ data=json.dumps(
578
+ {
579
+ "query_embeddings": query_embeddings,
580
+ "n_results": n_results,
581
+ "where": where,
582
+ "where_document": where_document,
583
+ "include": include,
584
+ }
585
+ ),
586
+ )
587
+
588
+ raise_chroma_error(resp)
589
+ body = resp.json()
590
+
591
+ return QueryResult(
592
+ ids=body["ids"],
593
+ distances=body.get("distances", None),
594
+ embeddings=body.get("embeddings", None),
595
+ metadatas=body.get("metadatas", None),
596
+ documents=body.get("documents", None),
597
+ uris=body.get("uris", None),
598
+ data=None,
599
+ )
600
+
601
+ @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
602
+ @override
603
+ def reset(self) -> bool:
604
+ """Resets the database"""
605
+ resp = self._session.post(self._api_url + "/reset")
606
+ raise_chroma_error(resp)
607
+ return cast(bool, resp.json())
608
+
609
+ @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
610
+ @override
611
+ def get_version(self) -> str:
612
+ """Returns the version of the server"""
613
+ resp = self._session.get(self._api_url + "/version")
614
+ raise_chroma_error(resp)
615
+ return cast(str, resp.json())
616
+
617
+ @override
618
+ def get_settings(self) -> Settings:
619
+ """Returns the settings of the client"""
620
+ return self._settings
621
+
622
+ @property
623
+ @trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION)
624
+ @override
625
+ def max_batch_size(self) -> int:
626
+ if self._max_batch_size == -1:
627
+ resp = self._session.get(self._api_url + "/pre-flight-checks")
628
+ raise_chroma_error(resp)
629
+ self._max_batch_size = cast(int, resp.json()["max_batch_size"])
630
+ return self._max_batch_size
631
+
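+ # The property above fetches /pre-flight-checks once and caches the result,
+ # so callers can cheaply chunk large writes. A sketch (hypothetical ids and
+ # docs, not part of this diff):
+ #   n = client.max_batch_size
+ #   for i in range(0, len(ids), n):
+ #       collection.add(ids=ids[i : i + n], documents=docs[i : i + n])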
632
+
633
+ def raise_chroma_error(resp: requests.Response) -> None:
634
+ """Raises an error if the response is not ok, using a ChromaError if possible"""
635
+ if resp.ok:
636
+ return
637
+
638
+ chroma_error = None
639
+ try:
640
+ body = resp.json()
641
+ if "error" in body:
642
+ if body["error"] in errors.error_types:
643
+ chroma_error = errors.error_types[body["error"]](body["message"])
644
+
645
+ except BaseException:
646
+ pass
647
+
648
+ if chroma_error:
649
+ raise chroma_error
650
+
651
+ try:
652
+ resp.raise_for_status()
653
+ except requests.HTTPError:
654
+ raise Exception(resp.text)
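+ # Illustration (assumed server error shape): a JSON body such as
+ # {"error": "InvalidCollection", "message": "Collection foo does not exist"}
+ # is re-raised client-side as the matching ChromaError subclass, provided
+ # that error name is registered in chromadb.errors.error_types; anything
+ # else falls through to raise_for_status() and the generic Exception above.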
chromadb/api/models/Collection.py ADDED
@@ -0,0 +1,633 @@
1
+ from typing import TYPE_CHECKING, Optional, Tuple, Any, Union
2
+
3
+ import numpy as np
4
+ from pydantic import BaseModel, PrivateAttr
5
+
6
+ from uuid import UUID
7
+ import chromadb.utils.embedding_functions as ef
8
+
9
+ from chromadb.api.types import (
10
+ URI,
11
+ CollectionMetadata,
12
+ DataLoader,
13
+ Embedding,
14
+ Embeddings,
15
+ Embeddable,
16
+ Include,
17
+ Loadable,
18
+ Metadata,
19
+ Metadatas,
20
+ Document,
21
+ Documents,
22
+ Image,
23
+ Images,
24
+ URIs,
25
+ Where,
26
+ IDs,
27
+ EmbeddingFunction,
28
+ GetResult,
29
+ QueryResult,
30
+ ID,
31
+ OneOrMany,
32
+ WhereDocument,
33
+ maybe_cast_one_to_many_ids,
34
+ maybe_cast_one_to_many_embedding,
35
+ maybe_cast_one_to_many_metadata,
36
+ maybe_cast_one_to_many_document,
37
+ maybe_cast_one_to_many_image,
38
+ maybe_cast_one_to_many_uri,
39
+ validate_ids,
40
+ validate_include,
41
+ validate_metadata,
42
+ validate_metadatas,
43
+ validate_where,
44
+ validate_where_document,
45
+ validate_n_results,
46
+ validate_embeddings,
47
+ validate_embedding_function,
48
+ )
49
+ import logging
50
+
51
+ logger = logging.getLogger(__name__)
52
+
53
+ if TYPE_CHECKING:
54
+ from chromadb.api import ServerAPI
55
+
56
+
57
+ class Collection(BaseModel):
58
+ name: str
59
+ id: UUID
60
+ metadata: Optional[CollectionMetadata] = None
61
+ tenant: Optional[str] = None
62
+ database: Optional[str] = None
63
+ _client: "ServerAPI" = PrivateAttr()
64
+ _embedding_function: Optional[EmbeddingFunction[Embeddable]] = PrivateAttr()
65
+ _data_loader: Optional[DataLoader[Loadable]] = PrivateAttr()
66
+
67
+ def __init__(
68
+ self,
69
+ client: "ServerAPI",
70
+ name: str,
71
+ id: UUID,
72
+ embedding_function: Optional[
73
+ EmbeddingFunction[Embeddable]
74
+ ] = ef.DefaultEmbeddingFunction(), # type: ignore
75
+ data_loader: Optional[DataLoader[Loadable]] = None,
76
+ tenant: Optional[str] = None,
77
+ database: Optional[str] = None,
78
+ metadata: Optional[CollectionMetadata] = None,
79
+ ):
80
+ super().__init__(
81
+ name=name, metadata=metadata, id=id, tenant=tenant, database=database
82
+ )
83
+ self._client = client
84
+
85
+ # Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol
86
+ if embedding_function is not None:
87
+ validate_embedding_function(embedding_function)
88
+
89
+ self._embedding_function = embedding_function
90
+ self._data_loader = data_loader
91
+
92
+ def __repr__(self) -> str:
93
+ return f"Collection(name={self.name})"
94
+
95
+ def count(self) -> int:
96
+ """The total number of embeddings added to the database
97
+
98
+ Returns:
99
+ int: The total number of embeddings added to the database
100
+
101
+ """
102
+ return self._client._count(collection_id=self.id)
103
+
104
+ def add(
105
+ self,
106
+ ids: OneOrMany[ID],
107
+ embeddings: Optional[
108
+ Union[
109
+ OneOrMany[Embedding],
110
+ OneOrMany[np.ndarray],
111
+ ]
112
+ ] = None,
113
+ metadatas: Optional[OneOrMany[Metadata]] = None,
114
+ documents: Optional[OneOrMany[Document]] = None,
115
+ images: Optional[OneOrMany[Image]] = None,
116
+ uris: Optional[OneOrMany[URI]] = None,
117
+ ) -> None:
118
+ """Add embeddings to the data store.
119
+ Args:
120
+ ids: The ids of the embeddings you wish to add
121
+ embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
122
+ metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
123
+ documents: The documents to associate with the embeddings. Optional.
124
+ images: The images to associate with the embeddings. Optional.
125
+ uris: The uris of the images to associate with the embeddings. Optional.
126
+
127
+ Returns:
128
+ None
129
+
130
+ Raises:
131
+ ValueError: If you don't provide one of embeddings, documents, images, or uris
131
+ ValueError: If the lengths of ids, embeddings, metadatas, or documents don't match
132
+ ValueError: If you don't provide an embedding function and don't provide embeddings
133
+ ValueError: If you provide both documents and images
135
+ ValueError: If you provide an id that already exists
136
+
137
+ """
138
+
139
+ (
140
+ ids,
141
+ embeddings,
142
+ metadatas,
143
+ documents,
144
+ images,
145
+ uris,
146
+ ) = self._validate_embedding_set(
147
+ ids, embeddings, metadatas, documents, images, uris
148
+ )
149
+
150
+ # We need to compute the embeddings if they're not provided
151
+ if embeddings is None:
152
+ # At this point, we know that one of documents or images are provided from the validation above
153
+ if documents is not None:
154
+ embeddings = self._embed(input=documents)
155
+ elif images is not None:
156
+ embeddings = self._embed(input=images)
157
+ else:
158
+ if uris is None:
159
+ raise ValueError(
160
+ "You must provide either embeddings, documents, images, or uris."
161
+ )
162
+ if self._data_loader is None:
163
+ raise ValueError(
164
+ "You must set a data loader on the collection if loading from URIs."
165
+ )
166
+ embeddings = self._embed(self._data_loader(uris))
167
+
168
+ self._client._add(ids, self.id, embeddings, metadatas, documents, uris)
169
+
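+ # A minimal usage sketch of add() (hypothetical data, relying on the
+ # collection's embedding function to compute embeddings from documents):
+ #   collection.add(
+ #       ids=["doc1", "doc2"],
+ #       documents=["first document", "second document"],
+ #       metadatas=[{"source": "a"}, {"source": "b"}],
+ #   )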
170
+ def get(
171
+ self,
172
+ ids: Optional[OneOrMany[ID]] = None,
173
+ where: Optional[Where] = None,
174
+ limit: Optional[int] = None,
175
+ offset: Optional[int] = None,
176
+ where_document: Optional[WhereDocument] = None,
177
+ include: Include = ["metadatas", "documents"],
178
+ ) -> GetResult:
179
+ """Get embeddings and their associate data from the data store. If no ids or where filter is provided returns
180
+ all embeddings up to limit starting at offset.
181
+
182
+ Args:
183
+ ids: The ids of the embeddings to get. Optional.
184
+ where: A Where type dict used to filter results by. E.g. `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
185
+ limit: The number of documents to return. Optional.
186
+ offset: The offset to start returning results from. Useful for paging results with limit. Optional.
187
+ where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
188
+ include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional.
189
+
190
+ Returns:
191
+ GetResult: A GetResult object containing the results.
192
+
193
+ """
194
+
195
+ valid_where = validate_where(where) if where else None
196
+ valid_where_document = (
197
+ validate_where_document(where_document) if where_document else None
198
+ )
199
+ valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
200
+ valid_include = validate_include(include, allow_distances=False)
201
+
202
+ if "data" in include and self._data_loader is None:
203
+ raise ValueError(
204
+ "You must set a data loader on the collection if loading from URIs."
205
+ )
206
+
207
+ # We need to include uris in the result from the API to load datas
208
+ if "data" in include and "uris" not in include:
209
+ valid_include.append("uris")
210
+
211
+ get_results = self._client._get(
212
+ self.id,
213
+ valid_ids,
214
+ valid_where,
215
+ None,
216
+ limit,
217
+ offset,
218
+ where_document=valid_where_document,
219
+ include=valid_include,
220
+ )
221
+
222
+ if (
223
+ "data" in include
224
+ and self._data_loader is not None
225
+ and get_results["uris"] is not None
226
+ ):
227
+ get_results["data"] = self._data_loader(get_results["uris"])
228
+
229
+ # Remove URIs from the result if they weren't requested
230
+ if "uris" not in include:
231
+ get_results["uris"] = None
232
+
233
+ return get_results
234
+
235
+ def peek(self, limit: int = 10) -> GetResult:
236
+ """Get the first few results in the database up to limit
237
+
238
+ Args:
239
+ limit: The number of results to return.
240
+
241
+ Returns:
242
+ GetResult: A GetResult object containing the results.
243
+ """
244
+ return self._client._peek(self.id, limit)
245
+
246
+ def query(
247
+ self,
248
+ query_embeddings: Optional[
249
+ Union[
250
+ OneOrMany[Embedding],
251
+ OneOrMany[np.ndarray],
252
+ ]
253
+ ] = None,
254
+ query_texts: Optional[OneOrMany[Document]] = None,
255
+ query_images: Optional[OneOrMany[Image]] = None,
256
+ query_uris: Optional[OneOrMany[URI]] = None,
257
+ n_results: int = 10,
258
+ where: Optional[Where] = None,
259
+ where_document: Optional[WhereDocument] = None,
260
+ include: Include = ["metadatas", "documents", "distances"],
261
+ ) -> QueryResult:
262
+ """Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.
263
+
264
+ Args:
265
+ query_embeddings: The embeddings to get the closest neighbors of. Optional.
267
+ query_texts: The document texts to get the closest neighbors of. Optional.
268
+ query_images: The images to get the closest neighbors of. Optional.
+ query_uris: The URIs to load and embed via the data loader. Optional.
268
+ n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
269
+ where: A Where type dict used to filter results by. E.g. `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
270
+ where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
271
+ include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.
272
+
273
+ Returns:
274
+ QueryResult: A QueryResult object containing the results.
275
+
276
+ Raises:
277
+ ValueError: If you don't provide either query_embeddings, query_texts, or query_images
278
+ ValueError: If you provide both query_embeddings and query_texts
279
+ ValueError: If you provide both query_embeddings and query_images
280
+ ValueError: If you provide both query_texts and query_images
281
+
282
+ """
283
+
284
+ # Users must provide only one of query_embeddings, query_texts, query_images, or query_uris
285
+ # An XOR chain is true for any odd number of inputs, so count the
286
+ # non-None inputs explicitly to enforce exactly one of them
287
+ provided = (query_embeddings, query_texts, query_images, query_uris)
288
+ if sum(
289
+ arg is not None for arg in provided
290
+ ) != 1:
291
+ raise ValueError(
292
+ "You must provide one of query_embeddings, query_texts, query_images, or query_uris."
293
+ )
294
+
295
+ valid_where = validate_where(where) if where else {}
296
+ valid_where_document = (
297
+ validate_where_document(where_document) if where_document else {}
298
+ )
299
+ valid_query_embeddings = (
300
+ validate_embeddings(
301
+ self._normalize_embeddings(
302
+ maybe_cast_one_to_many_embedding(query_embeddings)
303
+ )
304
+ )
305
+ if query_embeddings is not None
306
+ else None
307
+ )
308
+ valid_query_texts = (
309
+ maybe_cast_one_to_many_document(query_texts)
310
+ if query_texts is not None
311
+ else None
312
+ )
313
+ valid_query_images = (
314
+ maybe_cast_one_to_many_image(query_images)
315
+ if query_images is not None
316
+ else None
317
+ )
318
+ valid_query_uris = (
319
+ maybe_cast_one_to_many_uri(query_uris) if query_uris is not None else None
320
+ )
321
+ valid_include = validate_include(include, allow_distances=True)
322
+ valid_n_results = validate_n_results(n_results)
323
+
324
+ # If query_embeddings are not provided, we need to compute them from the inputs
325
+ if valid_query_embeddings is None:
326
+ if query_texts is not None:
327
+ valid_query_embeddings = self._embed(input=valid_query_texts)
328
+ elif query_images is not None:
329
+ valid_query_embeddings = self._embed(input=valid_query_images)
330
+ else:
331
+ if valid_query_uris is None:
332
+ raise ValueError(
333
+ "You must provide either query_embeddings, query_texts, query_images, or query_uris."
334
+ )
335
+ if self._data_loader is None:
336
+ raise ValueError(
337
+ "You must set a data loader on the collection if loading from URIs."
338
+ )
339
+ valid_query_embeddings = self._embed(
340
+ self._data_loader(valid_query_uris)
341
+ )
342
+
343
+ if "data" in include and "uris" not in include:
344
+ valid_include.append("uris")
345
+ query_results = self._client._query(
346
+ collection_id=self.id,
347
+ query_embeddings=valid_query_embeddings,
348
+ n_results=valid_n_results,
349
+ where=valid_where,
350
+ where_document=valid_where_document,
351
+ include=valid_include,
352
+ )
353
+
354
+ if (
355
+ "data" in include
356
+ and self._data_loader is not None
357
+ and query_results["uris"] is not None
358
+ ):
359
+ query_results["data"] = [
360
+ self._data_loader(uris) for uris in query_results["uris"]
361
+ ]
362
+
363
+ # Remove URIs from the result if they weren't requested
364
+ if "uris" not in include:
365
+ query_results["uris"] = None
366
+
367
+ return query_results
368
+
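+ # A minimal usage sketch of query() (hypothetical data):
+ #   collection.query(
+ #       query_texts=["hello world"],
+ #       n_results=5,
+ #       where={"source": "a"},
+ #   )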
369
+ def modify(
370
+ self, name: Optional[str] = None, metadata: Optional[CollectionMetadata] = None
371
+ ) -> None:
372
+ """Modify the collection name or metadata
373
+
374
+ Args:
375
+ name: The updated name for the collection. Optional.
376
+ metadata: The updated metadata for the collection. Optional.
377
+
378
+ Returns:
379
+ None
380
+ """
381
+ if metadata is not None:
382
+ validate_metadata(metadata)
383
+ if "hnsw:space" in metadata:
384
+ raise ValueError(
385
+ "Changing the distance function of a collection once it is created is not currently supported."
+ )
386
+
387
+ self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
388
+ if name:
389
+ self.name = name
390
+ if metadata:
391
+ self.metadata = metadata
392
+
393
+ def update(
394
+ self,
395
+ ids: OneOrMany[ID],
396
+ embeddings: Optional[
397
+ Union[
398
+ OneOrMany[Embedding],
399
+ OneOrMany[np.ndarray],
400
+ ]
401
+ ] = None,
402
+ metadatas: Optional[OneOrMany[Metadata]] = None,
403
+ documents: Optional[OneOrMany[Document]] = None,
404
+ images: Optional[OneOrMany[Image]] = None,
405
+ uris: Optional[OneOrMany[URI]] = None,
406
+ ) -> None:
407
+ """Update the embeddings, metadatas or documents for provided ids.
408
+
409
+ Args:
410
+ ids: The ids of the embeddings to update
411
+ embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional.
412
+ metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
413
+ documents: The documents to associate with the embeddings. Optional.
414
+ images: The images to associate with the embeddings. Optional.
415
+ Returns:
416
+ None
417
+ """
418
+
419
+ (
420
+ ids,
421
+ embeddings,
422
+ metadatas,
423
+ documents,
424
+ images,
425
+ uris,
426
+ ) = self._validate_embedding_set(
427
+ ids,
428
+ embeddings,
429
+ metadatas,
430
+ documents,
431
+ images,
432
+ uris,
433
+ require_embeddings_or_data=False,
434
+ )
435
+
436
+ if embeddings is None:
437
+ if documents is not None:
438
+ embeddings = self._embed(input=documents)
439
+ elif images is not None:
440
+ embeddings = self._embed(input=images)
441
+
442
+ self._client._update(self.id, ids, embeddings, metadatas, documents, uris)
443
+
444
+ def upsert(
445
+ self,
446
+ ids: OneOrMany[ID],
447
+ embeddings: Optional[
448
+ Union[
449
+ OneOrMany[Embedding],
450
+ OneOrMany[np.ndarray],
451
+ ]
452
+ ] = None,
453
+ metadatas: Optional[OneOrMany[Metadata]] = None,
454
+ documents: Optional[OneOrMany[Document]] = None,
455
+ images: Optional[OneOrMany[Image]] = None,
456
+ uris: Optional[OneOrMany[URI]] = None,
457
+ ) -> None:
458
+ """Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
459
+
460
+ Args:
461
+ ids: The ids of the embeddings to update
462
+ embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
463
+ metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
464
+ documents: The documents to associate with the embeddings. Optional.
465
+
466
+ Returns:
467
+ None
468
+ """
469
+
470
+ (
471
+ ids,
472
+ embeddings,
473
+ metadatas,
474
+ documents,
475
+ images,
476
+ uris,
477
+ ) = self._validate_embedding_set(
478
+ ids, embeddings, metadatas, documents, images, uris
479
+ )
480
+
481
+ if embeddings is None:
482
+ if documents is not None:
483
+ embeddings = self._embed(input=documents)
484
+ else:
485
+ embeddings = self._embed(input=images)
486
+
487
+ self._client._upsert(
488
+ collection_id=self.id,
489
+ ids=ids,
490
+ embeddings=embeddings,
491
+ metadatas=metadatas,
492
+ documents=documents,
493
+ uris=uris,
494
+ )
495
+
496
+ def delete(
497
+ self,
498
+ ids: Optional[IDs] = None,
499
+ where: Optional[Where] = None,
500
+ where_document: Optional[WhereDocument] = None,
501
+ ) -> None:
502
+ """Delete the embeddings based on ids and/or a where filter
503
+
504
+ Args:
505
+ ids: The ids of the embeddings to delete
506
+ where: A Where type dict used to filter the deletion by. E.g. `{"$and": [{"color": "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
507
+ where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. `{"$contains": "hello"}`. Optional.
508
+
509
+ Returns:
510
+ None
511
+
512
+ Raises:
513
+ ValueError: If you don't provide either ids, where, or where_document
514
+ """
515
+ ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None
516
+ where = validate_where(where) if where else None
517
+ where_document = (
518
+ validate_where_document(where_document) if where_document else None
519
+ )
520
+
521
+ self._client._delete(self.id, ids, where, where_document)
522
+
523
+ def _validate_embedding_set(
524
+ self,
525
+ ids: OneOrMany[ID],
526
+ embeddings: Optional[
527
+ Union[
528
+ OneOrMany[Embedding],
529
+ OneOrMany[np.ndarray],
530
+ ]
531
+ ],
532
+ metadatas: Optional[OneOrMany[Metadata]],
533
+ documents: Optional[OneOrMany[Document]],
534
+ images: Optional[OneOrMany[Image]] = None,
535
+ uris: Optional[OneOrMany[URI]] = None,
536
+ require_embeddings_or_data: bool = True,
537
+ ) -> Tuple[
538
+ IDs,
539
+ Optional[Embeddings],
540
+ Optional[Metadatas],
541
+ Optional[Documents],
542
+ Optional[Images],
543
+ Optional[URIs],
544
+ ]:
545
+ valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids))
546
+ valid_embeddings = (
547
+ validate_embeddings(
548
+ self._normalize_embeddings(maybe_cast_one_to_many_embedding(embeddings))
549
+ )
550
+ if embeddings is not None
551
+ else None
552
+ )
553
+ valid_metadatas = (
554
+ validate_metadatas(maybe_cast_one_to_many_metadata(metadatas))
555
+ if metadatas is not None
556
+ else None
557
+ )
558
+ valid_documents = (
559
+ maybe_cast_one_to_many_document(documents)
560
+ if documents is not None
561
+ else None
562
+ )
563
+ valid_images = (
564
+ maybe_cast_one_to_many_image(images) if images is not None else None
565
+ )
566
+
567
+ valid_uris = maybe_cast_one_to_many_uri(uris) if uris is not None else None
568
+
569
+ # Check that at least one of embeddings, documents, images, or uris is provided
570
+ if require_embeddings_or_data:
571
+ if (
572
+ valid_embeddings is None
573
+ and valid_documents is None
574
+ and valid_images is None
575
+ and valid_uris is None
576
+ ):
577
+ raise ValueError(
578
+ "You must provide embeddings, documents, images, or uris."
579
+ )
580
+
581
+ # Only one of documents or images can be provided
582
+ if valid_documents is not None and valid_images is not None:
583
+ raise ValueError("You can only provide documents or images, not both.")
584
+
585
+ # Check that, if they're provided, the lengths of the arrays match the length of ids
586
+ if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids):
587
+ raise ValueError(
588
+ f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}"
589
+ )
590
+ if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids):
591
+ raise ValueError(
592
+ f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
593
+ )
594
+ if valid_documents is not None and len(valid_documents) != len(valid_ids):
595
+ raise ValueError(
596
+ f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
597
+ )
598
+ if valid_images is not None and len(valid_images) != len(valid_ids):
599
+ raise ValueError(
600
+ f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
601
+ )
602
+ if valid_uris is not None and len(valid_uris) != len(valid_ids):
603
+ raise ValueError(
604
+ f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
605
+ )
606
+
607
+ return (
608
+ valid_ids,
609
+ valid_embeddings,
610
+ valid_metadatas,
611
+ valid_documents,
612
+ valid_images,
613
+ valid_uris,
614
+ )
615
+
616
+ @staticmethod
617
+ def _normalize_embeddings(
618
+ embeddings: Union[
619
+ OneOrMany[Embedding],
620
+ OneOrMany[np.ndarray],
621
+ ]
622
+ ) -> Embeddings:
623
+ if isinstance(embeddings, np.ndarray):
624
+ return embeddings.tolist()
625
+ return embeddings
626
+
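+ # Note: only a top-level ndarray is converted above; e.g.
+ #   _normalize_embeddings(np.array([[1.0, 2.0]])) -> [[1.0, 2.0]]
+ # A plain list of lists passes through unchanged.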
627
+ def _embed(self, input: Any) -> Embeddings:
628
+ if self._embedding_function is None:
629
+ raise ValueError(
630
+ "You must provide an embedding function to compute embeddings."
631
+ "https://docs.trychroma.com/embeddings"
632
+ )
633
+ return self._embedding_function(input=input)
chromadb/api/segment.py ADDED
@@ -0,0 +1,914 @@
+ from chromadb.api import ServerAPI
+ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
+ from chromadb.db.system import SysDB
+ from chromadb.segment import SegmentManager, MetadataReader, VectorReader
+ from chromadb.telemetry.opentelemetry import (
+     add_attributes_to_current_span,
+     OpenTelemetryClient,
+     OpenTelemetryGranularity,
+     trace_method,
+ )
+ from chromadb.telemetry.product import ProductTelemetryClient
+ from chromadb.ingest import Producer
+ from chromadb.api.models.Collection import Collection
+ from chromadb import __version__
+ from chromadb.errors import InvalidDimensionException, InvalidCollectionException
+ import chromadb.utils.embedding_functions as ef
+
+ from chromadb.api.types import (
+     URI,
+     CollectionMetadata,
+     Embeddable,
+     Document,
+     EmbeddingFunction,
+     DataLoader,
+     IDs,
+     Embeddings,
+     Embedding,
+     Loadable,
+     Metadatas,
+     Documents,
+     URIs,
+     Where,
+     WhereDocument,
+     Include,
+     GetResult,
+     QueryResult,
+     validate_metadata,
+     validate_update_metadata,
+     validate_where,
+     validate_where_document,
+     validate_batch,
+ )
+ from chromadb.telemetry.product.events import (
+     CollectionAddEvent,
+     CollectionDeleteEvent,
+     CollectionGetEvent,
+     CollectionUpdateEvent,
+     CollectionQueryEvent,
+     ClientCreateCollectionEvent,
+ )
+
+ import chromadb.types as t
+
+ from typing import Any, Optional, Sequence, Generator, List, cast, Set, Dict
+ from overrides import override
+ from uuid import UUID, uuid4
+ import time
+ import logging
+ import re
+
+
+ logger = logging.getLogger(__name__)
+
+
+ # mimics s3 bucket requirements for naming
+ def check_index_name(index_name: str) -> None:
+     msg = (
+         "Expected collection name that "
+         "(1) contains 3-63 characters, "
+         "(2) starts and ends with an alphanumeric character, "
+         "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
+         "(4) contains no two consecutive periods (..) and "
+         "(5) is not a valid IPv4 address, "
+         f"got {index_name}"
+     )
+     if len(index_name) < 3 or len(index_name) > 63:
+         raise ValueError(msg)
+     if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name):
+         raise ValueError(msg)
+     if ".." in index_name:
+         raise ValueError(msg)
+     if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
+         raise ValueError(msg)
+
+
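A quick illustration (assuming only the rules encoded above) of which collection names `check_index_name` accepts and rejects:

```python
check_index_name("my-collection_01")  # passes: 3-63 chars, alphanumeric at both ends
# Each of the following raises ValueError with the message built above:
# check_index_name("ab")            # fewer than 3 characters
# check_index_name("-starts-bad")   # must start with an alphanumeric character
# check_index_name("a..b")          # contains consecutive periods
# check_index_name("192.168.0.1")   # parses as an IPv4 address
```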
+ class SegmentAPI(ServerAPI):
+     """API implementation utilizing the new segment-based internal architecture"""
+
+     _settings: Settings
+     _sysdb: SysDB
+     _manager: SegmentManager
+     _producer: Producer
+     _product_telemetry_client: ProductTelemetryClient
+     _opentelemetry_client: OpenTelemetryClient
+     _tenant_id: str
+     _topic_ns: str
+     _collection_cache: Dict[UUID, t.Collection]
+
+     def __init__(self, system: System):
+         super().__init__(system)
+         self._settings = system.settings
+         self._sysdb = self.require(SysDB)
+         self._manager = self.require(SegmentManager)
+         self._product_telemetry_client = self.require(ProductTelemetryClient)
+         self._opentelemetry_client = self.require(OpenTelemetryClient)
+         self._producer = self.require(Producer)
+         self._collection_cache = {}
+
+     @override
+     def heartbeat(self) -> int:
+         return int(time.time_ns())
+
+     @override
+     def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
+         if len(name) < 3:
+             raise ValueError("Database name must be at least 3 characters long")
+
+         self._sysdb.create_database(
+             id=uuid4(),
+             name=name,
+             tenant=tenant,
+         )
+
+     @override
+     def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
+         return self._sysdb.get_database(name=name, tenant=tenant)
+
+     @override
+     def create_tenant(self, name: str) -> None:
+         if len(name) < 3:
+             raise ValueError("Tenant name must be at least 3 characters long")
+
+         self._sysdb.create_tenant(
+             name=name,
+         )
+
+     @override
+     def get_tenant(self, name: str) -> t.Tenant:
+         return self._sysdb.get_tenant(name=name)
+
+     # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
+     # necessary because changing the value type from `Any` to `Union[str, int, float]`
+     # causes the system to somehow convert all values to strings.
+     @trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION)
+     @override
+     def create_collection(
+         self,
+         name: str,
+         metadata: Optional[CollectionMetadata] = None,
+         embedding_function: Optional[
+             EmbeddingFunction[Any]
+         ] = ef.DefaultEmbeddingFunction(),
+         data_loader: Optional[DataLoader[Loadable]] = None,
+         get_or_create: bool = False,
+         tenant: str = DEFAULT_TENANT,
+         database: str = DEFAULT_DATABASE,
+     ) -> Collection:
+         if metadata is not None:
+             validate_metadata(metadata)
+
+         # TODO: remove backwards compatibility in naming requirements
+         check_index_name(name)
+
+         id = uuid4()
+
+         coll, created = self._sysdb.create_collection(
+             id=id,
+             name=name,
+             metadata=metadata,
+             dimension=None,
+             get_or_create=get_or_create,
+             tenant=tenant,
+             database=database,
+         )
+
+         if created:
+             segments = self._manager.create_segments(coll)
+             for segment in segments:
+                 self._sysdb.create_segment(segment)
+
+         # TODO: This event doesn't capture the get_or_create case appropriately
+         self._product_telemetry_client.capture(
+             ClientCreateCollectionEvent(
+                 collection_uuid=str(id),
+                 embedding_function=embedding_function.__class__.__name__,
+             )
+         )
+         add_attributes_to_current_span({"collection_uuid": str(id)})
+
+         return Collection(
+             client=self,
+             id=coll["id"],
+             name=name,
+             metadata=coll["metadata"],  # type: ignore
+             embedding_function=embedding_function,
+             data_loader=data_loader,
+             tenant=tenant,
+             database=database,
+         )
+
+     @trace_method(
+         "SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
+     )
+     @override
+     def get_or_create_collection(
+         self,
+         name: str,
+         metadata: Optional[CollectionMetadata] = None,
+         embedding_function: Optional[
+             EmbeddingFunction[Embeddable]
+         ] = ef.DefaultEmbeddingFunction(),  # type: ignore
+         data_loader: Optional[DataLoader[Loadable]] = None,
+         tenant: str = DEFAULT_TENANT,
+         database: str = DEFAULT_DATABASE,
+     ) -> Collection:
+         return self.create_collection(  # type: ignore
+             name=name,
+             metadata=metadata,
+             embedding_function=embedding_function,
+             data_loader=data_loader,
+             get_or_create=True,
+             tenant=tenant,
+             database=database,
+         )
+
+     # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
+     # necessary because changing the value type from `Any` to `Union[str, int, float]`
+     # causes the system to somehow convert all values to strings
+     @trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION)
+     @override
+     def get_collection(
+         self,
+         name: Optional[str] = None,
+         id: Optional[UUID] = None,
+         embedding_function: Optional[
+             EmbeddingFunction[Embeddable]
+         ] = ef.DefaultEmbeddingFunction(),  # type: ignore
+         data_loader: Optional[DataLoader[Loadable]] = None,
+         tenant: str = DEFAULT_TENANT,
+         database: str = DEFAULT_DATABASE,
+     ) -> Collection:
+         if id is None and name is None or (id is not None and name is not None):
+             raise ValueError("Name or id must be specified, but not both")
+         existing = self._sysdb.get_collections(
+             id=id, name=name, tenant=tenant, database=database
+         )
+
+         if existing:
+             return Collection(
+                 client=self,
+                 id=existing[0]["id"],
+                 name=existing[0]["name"],
+                 metadata=existing[0]["metadata"],  # type: ignore
+                 embedding_function=embedding_function,
+                 data_loader=data_loader,
+                 tenant=existing[0]["tenant"],
+                 database=existing[0]["database"],
+             )
+         else:
+             raise ValueError(f"Collection {name} does not exist.")
+
+     @trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION)
+     @override
+     def list_collections(
+         self,
+         limit: Optional[int] = None,
+         offset: Optional[int] = None,
+         tenant: str = DEFAULT_TENANT,
+         database: str = DEFAULT_DATABASE,
+     ) -> Sequence[Collection]:
+         collections = []
+         db_collections = self._sysdb.get_collections(
+             limit=limit, offset=offset, tenant=tenant, database=database
+         )
+         for db_collection in db_collections:
+             collections.append(
+                 Collection(
+                     client=self,
+                     id=db_collection["id"],
+                     name=db_collection["name"],
+                     metadata=db_collection["metadata"],  # type: ignore
+                     tenant=db_collection["tenant"],
+                     database=db_collection["database"],
+                 )
+             )
+         return collections
+
+     @trace_method("SegmentAPI.count_collections", OpenTelemetryGranularity.OPERATION)
+     @override
+     def count_collections(
+         self,
+         tenant: str = DEFAULT_TENANT,
+         database: str = DEFAULT_DATABASE,
+     ) -> int:
+         collection_count = len(
+             self._sysdb.get_collections(tenant=tenant, database=database)
+         )
+
+         return collection_count
+
+     @trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _modify(
+         self,
+         id: UUID,
+         new_name: Optional[str] = None,
+         new_metadata: Optional[CollectionMetadata] = None,
+     ) -> None:
+         if new_name:
+             # backwards compatibility in naming requirements (for now)
+             check_index_name(new_name)
+
+         if new_metadata:
+             validate_update_metadata(new_metadata)
+
+         # TODO eventually we'll want to use OptionalArgument and Unspecified in the
+         # signature of `_modify` but not changing the API right now.
+         if new_name and new_metadata:
+             self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
+         elif new_name:
+             self._sysdb.update_collection(id, name=new_name)
+         elif new_metadata:
+             self._sysdb.update_collection(id, metadata=new_metadata)
+
+     @trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
+     @override
+     def delete_collection(
+         self,
+         name: str,
+         tenant: str = DEFAULT_TENANT,
+         database: str = DEFAULT_DATABASE,
+     ) -> None:
+         existing = self._sysdb.get_collections(
+             name=name, tenant=tenant, database=database
+         )
+
+         if existing:
+             self._sysdb.delete_collection(
+                 existing[0]["id"], tenant=tenant, database=database
+             )
+             for s in self._manager.delete_segments(existing[0]["id"]):
+                 self._sysdb.delete_segment(s)
+             if existing and existing[0]["id"] in self._collection_cache:
+                 del self._collection_cache[existing[0]["id"]]
+         else:
+             raise ValueError(f"Collection {name} does not exist.")
+
+     @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _add(
+         self,
+         ids: IDs,
+         collection_id: UUID,
+         embeddings: Embeddings,
+         metadatas: Optional[Metadatas] = None,
+         documents: Optional[Documents] = None,
+         uris: Optional[URIs] = None,
+     ) -> bool:
+         coll = self._get_collection(collection_id)
+         self._manager.hint_use_collection(collection_id, t.Operation.ADD)
+         validate_batch(
+             (ids, embeddings, metadatas, documents, uris),
+             {"max_batch_size": self.max_batch_size},
+         )
+         records_to_submit = []
+         for r in _records(
+             t.Operation.ADD,
+             ids=ids,
+             collection_id=collection_id,
+             embeddings=embeddings,
+             metadatas=metadatas,
+             documents=documents,
+             uris=uris,
+         ):
+             self._validate_embedding_record(coll, r)
+             records_to_submit.append(r)
+         self._producer.submit_embeddings(coll["topic"], records_to_submit)
+
+         self._product_telemetry_client.capture(
+             CollectionAddEvent(
+                 collection_uuid=str(collection_id),
+                 add_amount=len(ids),
+                 with_metadata=len(ids) if metadatas is not None else 0,
+                 with_documents=len(ids) if documents is not None else 0,
+                 with_uris=len(ids) if uris is not None else 0,
+             )
+         )
+         return True
+
+     @trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _update(
+         self,
+         collection_id: UUID,
+         ids: IDs,
+         embeddings: Optional[Embeddings] = None,
+         metadatas: Optional[Metadatas] = None,
+         documents: Optional[Documents] = None,
+         uris: Optional[URIs] = None,
+     ) -> bool:
+         coll = self._get_collection(collection_id)
+         self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
+         validate_batch(
+             (ids, embeddings, metadatas, documents, uris),
+             {"max_batch_size": self.max_batch_size},
+         )
+         records_to_submit = []
+         for r in _records(
+             t.Operation.UPDATE,
+             ids=ids,
+             collection_id=collection_id,
+             embeddings=embeddings,
+             metadatas=metadatas,
+             documents=documents,
+             uris=uris,
+         ):
+             self._validate_embedding_record(coll, r)
+             records_to_submit.append(r)
+         self._producer.submit_embeddings(coll["topic"], records_to_submit)
+
+         self._product_telemetry_client.capture(
+             CollectionUpdateEvent(
+                 collection_uuid=str(collection_id),
+                 update_amount=len(ids),
+                 with_embeddings=len(embeddings) if embeddings else 0,
+                 with_metadata=len(metadatas) if metadatas else 0,
+                 with_documents=len(documents) if documents else 0,
+                 with_uris=len(uris) if uris else 0,
+             )
+         )
+
+         return True
+
+     @trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _upsert(
+         self,
+         collection_id: UUID,
+         ids: IDs,
+         embeddings: Embeddings,
+         metadatas: Optional[Metadatas] = None,
+         documents: Optional[Documents] = None,
+         uris: Optional[URIs] = None,
+     ) -> bool:
+         coll = self._get_collection(collection_id)
+         self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
+         validate_batch(
+             (ids, embeddings, metadatas, documents, uris),
+             {"max_batch_size": self.max_batch_size},
+         )
+         records_to_submit = []
+         for r in _records(
+             t.Operation.UPSERT,
+             ids=ids,
+             collection_id=collection_id,
+             embeddings=embeddings,
+             metadatas=metadatas,
+             documents=documents,
+             uris=uris,
+         ):
+             self._validate_embedding_record(coll, r)
+             records_to_submit.append(r)
+         self._producer.submit_embeddings(coll["topic"], records_to_submit)
+
+         return True
+
+     @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _get(
+         self,
+         collection_id: UUID,
+         ids: Optional[IDs] = None,
+         where: Optional[Where] = {},
+         sort: Optional[str] = None,
+         limit: Optional[int] = None,
+         offset: Optional[int] = None,
+         page: Optional[int] = None,
+         page_size: Optional[int] = None,
+         where_document: Optional[WhereDocument] = {},
+         include: Include = ["embeddings", "metadatas", "documents"],
+     ) -> GetResult:
+         add_attributes_to_current_span(
+             {
+                 "collection_id": str(collection_id),
+                 "ids_count": len(ids) if ids else 0,
+             }
+         )
+
+         where = validate_where(where) if where is not None and len(where) > 0 else None
+         where_document = (
+             validate_where_document(where_document)
+             if where_document is not None and len(where_document) > 0
+             else None
+         )
+
+         metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
+
+         if sort is not None:
+             raise NotImplementedError("Sorting is not yet supported")
+
+         if page and page_size:
+             offset = (page - 1) * page_size
+             limit = page_size
+
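The page/page_size branch above is a thin convenience layer over offset/limit: page 3 with a page size of 20 reads records 40-59. A hedged sketch of the same arithmetic (the helper name is hypothetical):

```python
def page_to_offset_limit(page, page_size):
    # page is 1-indexed, mirroring the branch in SegmentAPI._get above
    return (page - 1) * page_size, page_size

assert page_to_offset_limit(1, 20) == (0, 20)   # first page -> records 0-19
assert page_to_offset_limit(3, 20) == (40, 20)  # third page -> records 40-59
```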
+         records = metadata_segment.get_metadata(
+             where=where,
+             where_document=where_document,
+             ids=ids,
+             limit=limit,
+             offset=offset,
+         )
+
+         if len(records) == 0:
+             # Nothing to return if there are no records
+             return GetResult(
+                 ids=[],
+                 embeddings=[] if "embeddings" in include else None,
+                 metadatas=[] if "metadatas" in include else None,
+                 documents=[] if "documents" in include else None,
+                 uris=[] if "uris" in include else None,
+                 data=[] if "data" in include else None,
+             )
+
+         vectors: Sequence[t.VectorEmbeddingRecord] = []
+         if "embeddings" in include:
+             vector_ids = [r["id"] for r in records]
+             vector_segment = self._manager.get_segment(collection_id, VectorReader)
+             vectors = vector_segment.get_vectors(ids=vector_ids)
+
+         # TODO: Fix type so we don't need to ignore
+         # It is possible to have a set of records, some with metadata and some without
+         # Same with documents
+
+         metadatas = [r["metadata"] for r in records]
+
+         if "documents" in include:
+             documents = [_doc(m) for m in metadatas]
+
+         if "uris" in include:
+             uris = [_uri(m) for m in metadatas]
+
+         ids_amount = len(ids) if ids else 0
+         self._product_telemetry_client.capture(
+             CollectionGetEvent(
+                 collection_uuid=str(collection_id),
+                 ids_count=ids_amount,
+                 limit=limit if limit else 0,
+                 include_metadata=ids_amount if "metadatas" in include else 0,
+                 include_documents=ids_amount if "documents" in include else 0,
+                 include_uris=ids_amount if "uris" in include else 0,
+             )
+         )
+
+         return GetResult(
+             ids=[r["id"] for r in records],
+             embeddings=[r["embedding"] for r in vectors]
+             if "embeddings" in include
+             else None,
+             metadatas=_clean_metadatas(metadatas)
+             if "metadatas" in include
+             else None,  # type: ignore
+             documents=documents if "documents" in include else None,  # type: ignore
+             uris=uris if "uris" in include else None,  # type: ignore
+             data=None,
+         )
+
+     @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _delete(
+         self,
+         collection_id: UUID,
+         ids: Optional[IDs] = None,
+         where: Optional[Where] = None,
+         where_document: Optional[WhereDocument] = None,
+     ) -> IDs:
+         add_attributes_to_current_span(
+             {
+                 "collection_id": str(collection_id),
+                 "ids_count": len(ids) if ids else 0,
+             }
+         )
+
+         where = validate_where(where) if where is not None and len(where) > 0 else None
+         where_document = (
+             validate_where_document(where_document)
+             if where_document is not None and len(where_document) > 0
+             else None
+         )
+
+         # You must have at least one of non-empty ids, where, or where_document.
+         if (
+             (ids is None or (ids is not None and len(ids) == 0))
+             and (where is None or (where is not None and len(where) == 0))
+             and (
+                 where_document is None
+                 or (where_document is not None and len(where_document) == 0)
+             )
+         ):
+             raise ValueError(
+                 """
+                 You must provide either ids, where, or where_document to delete. If
+                 you want to delete all data in a collection you can delete the
+                 collection itself using the delete_collection method. Or alternatively,
+                 you can get() all the relevant ids and then delete them.
+                 """
+             )
+
+         coll = self._get_collection(collection_id)
+         self._manager.hint_use_collection(collection_id, t.Operation.DELETE)
+
+         if (where or where_document) or not ids:
+             metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
+             records = metadata_segment.get_metadata(
+                 where=where, where_document=where_document, ids=ids
+             )
+             ids_to_delete = [r["id"] for r in records]
+         else:
+             ids_to_delete = ids
+
+         if len(ids_to_delete) == 0:
+             return []
+
+         records_to_submit = []
+         for r in _records(
+             operation=t.Operation.DELETE, ids=ids_to_delete, collection_id=collection_id
+         ):
+             self._validate_embedding_record(coll, r)
+             records_to_submit.append(r)
+         self._producer.submit_embeddings(coll["topic"], records_to_submit)
+
+         self._product_telemetry_client.capture(
+             CollectionDeleteEvent(
+                 collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
+             )
+         )
+         return ids_to_delete
+
+     @trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _count(self, collection_id: UUID) -> int:
+         add_attributes_to_current_span({"collection_id": str(collection_id)})
+         metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
+         return metadata_segment.count()
+
+     @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _query(
+         self,
+         collection_id: UUID,
+         query_embeddings: Embeddings,
+         n_results: int = 10,
+         where: Where = {},
+         where_document: WhereDocument = {},
+         include: Include = ["documents", "metadatas", "distances"],
+     ) -> QueryResult:
+         add_attributes_to_current_span(
+             {
+                 "collection_id": str(collection_id),
+                 "n_results": n_results,
+                 "where": str(where),
+             }
+         )
+         where = validate_where(where) if where is not None and len(where) > 0 else where
+         where_document = (
+             validate_where_document(where_document)
+             if where_document is not None and len(where_document) > 0
+             else where_document
+         )
+
+         allowed_ids = None
+
+         coll = self._get_collection(collection_id)
+         for embedding in query_embeddings:
+             self._validate_dimension(coll, len(embedding), update=False)
+
+         metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
+
+         if where or where_document:
+             records = metadata_reader.get_metadata(
+                 where=where, where_document=where_document
+             )
+             allowed_ids = [r["id"] for r in records]
+
+         query = t.VectorQuery(
+             vectors=query_embeddings,
+             k=n_results,
+             allowed_ids=allowed_ids,
+             include_embeddings="embeddings" in include,
+             options=None,
+         )
+
+         vector_reader = self._manager.get_segment(collection_id, VectorReader)
+         results = vector_reader.query_vectors(query)
+
+         ids: List[List[str]] = []
+         distances: List[List[float]] = []
+         embeddings: List[List[Embedding]] = []
+         documents: List[List[Document]] = []
+         uris: List[List[URI]] = []
+         metadatas: List[List[t.Metadata]] = []
+
+         for result in results:
+             ids.append([r["id"] for r in result])
+             if "distances" in include:
+                 distances.append([r["distance"] for r in result])
+             if "embeddings" in include:
+                 embeddings.append([cast(Embedding, r["embedding"]) for r in result])
+
+         if "documents" in include or "metadatas" in include or "uris" in include:
+             all_ids: Set[str] = set()
+             for id_list in ids:
+                 all_ids.update(id_list)
+             records = metadata_reader.get_metadata(ids=list(all_ids))
+             metadata_by_id = {r["id"]: r["metadata"] for r in records}
+             for id_list in ids:
+                 # In the segment-based architecture, it is possible for one segment
+                 # to have a record that another segment does not have. This results in
+                 # data inconsistency. For the case of the local segments and the
+                 # local segment manager, there is a case where a thread writes
+                 # a record to the vector segment but not the metadata segment.
+                 # Then a querying thread reads from the vector segment and
+                 # queries the metadata segment. The metadata segment does not have
+                 # the record. In this case we choose to return potentially
+                 # incorrect data in the form of None.
+                 metadata_list = [metadata_by_id.get(id, None) for id in id_list]
+                 if "metadatas" in include:
+                     metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
+                 if "documents" in include:
+                     doc_list = [_doc(m) for m in metadata_list]
+                     documents.append(doc_list)  # type: ignore
+                 if "uris" in include:
+                     uri_list = [_uri(m) for m in metadata_list]
+                     uris.append(uri_list)  # type: ignore
+
+         query_amount = len(query_embeddings)
+         self._product_telemetry_client.capture(
+             CollectionQueryEvent(
+                 collection_uuid=str(collection_id),
+                 query_amount=query_amount,
+                 n_results=n_results,
+                 with_metadata_filter=query_amount if where is not None else 0,
+                 with_document_filter=query_amount if where_document is not None else 0,
+                 include_metadatas=query_amount if "metadatas" in include else 0,
+                 include_documents=query_amount if "documents" in include else 0,
+                 include_uris=query_amount if "uris" in include else 0,
+                 include_distances=query_amount if "distances" in include else 0,
+             )
+         )
+
+         return QueryResult(
+             ids=ids,
+             distances=distances if distances else None,
+             metadatas=metadatas if metadatas else None,
+             embeddings=embeddings if embeddings else None,
+             documents=documents if documents else None,
+             uris=uris if uris else None,
+             data=None,
+         )
+
+     @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
+     @override
+     def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
+         add_attributes_to_current_span({"collection_id": str(collection_id)})
+         return self._get(collection_id, limit=n)  # type: ignore
+
+     @override
+     def get_version(self) -> str:
+         return __version__
+
+     @override
+     def reset_state(self) -> None:
+         self._collection_cache = {}
+
+     @override
+     def reset(self) -> bool:
+         self._system.reset_state()
+         return True
+
+     @override
+     def get_settings(self) -> Settings:
+         return self._settings
+
+     @property
+     @override
+     def max_batch_size(self) -> int:
+         return self._producer.max_batch_size
+
+     # TODO: This could potentially cause race conditions in a distributed version of the
+     # system, since the cache is only local.
+     # TODO: promote collection -> topic to a base class method so that it can be
+     # used for channel assignment in the distributed version of the system.
+     @trace_method("SegmentAPI._validate_embedding_record", OpenTelemetryGranularity.ALL)
+     def _validate_embedding_record(
+         self, collection: t.Collection, record: t.SubmitEmbeddingRecord
+     ) -> None:
+         """Validate the dimension of an embedding record before submitting it to the system."""
+         add_attributes_to_current_span({"collection_id": str(collection["id"])})
+         if record["embedding"]:
+             self._validate_dimension(collection, len(record["embedding"]), update=True)
+
+     @trace_method("SegmentAPI._validate_dimension", OpenTelemetryGranularity.ALL)
+     def _validate_dimension(
+         self, collection: t.Collection, dim: int, update: bool
+     ) -> None:
+         """Validate that a collection supports records of the given dimension. If update
+         is true, update the collection if the collection doesn't already have a
+         dimension."""
+         if collection["dimension"] is None:
+             if update:
+                 id = collection["id"]
+                 self._sysdb.update_collection(id=id, dimension=dim)
+                 self._collection_cache[id]["dimension"] = dim
+         elif collection["dimension"] != dim:
+             raise InvalidDimensionException(
+                 f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
+             )
+         else:
+             return  # all is well
+
+     @trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL)
+     def _get_collection(self, collection_id: UUID) -> t.Collection:
+         """Read-through cache for collection data"""
+         if collection_id not in self._collection_cache:
+             collections = self._sysdb.get_collections(id=collection_id)
+             if not collections:
+                 raise InvalidCollectionException(
+                     f"Collection {collection_id} does not exist."
+                 )
+             self._collection_cache[collection_id] = collections[0]
+         return self._collection_cache[collection_id]
+
+
+ def _records(
+     operation: t.Operation,
+     ids: IDs,
+     collection_id: UUID,
+     embeddings: Optional[Embeddings] = None,
+     metadatas: Optional[Metadatas] = None,
+     documents: Optional[Documents] = None,
+     uris: Optional[URIs] = None,
+ ) -> Generator[t.SubmitEmbeddingRecord, None, None]:
+     """Convert parallel lists of embeddings, metadatas and documents to a sequence of
+     SubmitEmbeddingRecords"""
+
+     # Presumes that callers were invoked via Collection model, which means
+     # that we know that the embeddings, metadatas and documents have already been
+     # normalized and are guaranteed to be consistently named lists.
+
+     for i, id in enumerate(ids):
+         metadata = None
+         if metadatas:
+             metadata = metadatas[i]
+
+         if documents:
+             document = documents[i]
+             if metadata:
+                 metadata = {**metadata, "chroma:document": document}
+             else:
+                 metadata = {"chroma:document": document}
+
+         if uris:
+             uri = uris[i]
+             if metadata:
+                 metadata = {**metadata, "chroma:uri": uri}
+             else:
+                 metadata = {"chroma:uri": uri}
+
+         record = t.SubmitEmbeddingRecord(
+             id=id,
+             embedding=embeddings[i] if embeddings else None,
+             encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
+             metadata=metadata,
+             operation=operation,
+             collection_id=collection_id,
+         )
+         yield record
+
+
+ def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
+     """Retrieve the document (if any) from a Metadata map"""
+
+     if metadata and "chroma:document" in metadata:
+         return str(metadata["chroma:document"])
+     return None
+
+
+ def _uri(metadata: Optional[t.Metadata]) -> Optional[str]:
+     """Retrieve the uri (if any) from a Metadata map"""
+
+     if metadata and "chroma:uri" in metadata:
+         return str(metadata["chroma:uri"])
+     return None
+
+
+ def _clean_metadatas(
+     metadata: List[Optional[t.Metadata]],
+ ) -> List[Optional[t.Metadata]]:
+     """Remove any chroma-specific metadata keys that the client shouldn't see from a
+     list of metadata maps."""
+     return [_clean_metadata(m) for m in metadata]
+
+
+ def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
+     """Remove any chroma-specific metadata keys that the client shouldn't see from a
+     metadata map."""
+     if not metadata:
+         return None
+     result = {}
+     for k, v in metadata.items():
+         if not k.startswith("chroma:"):
+             result[k] = v
+     if len(result) == 0:
+         return None
+     return result
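Documents and URIs ride along inside the metadata map under reserved `chroma:` keys: `_records` packs them in, `_doc`/`_uri` pull them out, and `_clean_metadata` strips them before results reach the client. A small sketch of that round trip, reusing the helpers above (the sample values are illustrative):

```python
packed = {"author": "alice", "chroma:document": "hello", "chroma:uri": "s3://bucket/doc"}

_doc(packed)             # -> "hello"
_uri(packed)             # -> "s3://bucket/doc"
_clean_metadata(packed)  # -> {"author": "alice"}  (reserved keys removed)
_clean_metadata({"chroma:document": "x"})  # -> None (nothing client-visible remains)
```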
chromadb/api/types.py ADDED
@@ -0,0 +1,509 @@
+ from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
+ from numpy.typing import NDArray
+ import numpy as np
+ from typing_extensions import Literal, TypedDict, Protocol
+ import chromadb.errors as errors
+ from chromadb.types import (
+     Metadata,
+     UpdateMetadata,
+     Vector,
+     LiteralValue,
+     LogicalOperator,
+     WhereOperator,
+     OperatorExpression,
+     Where,
+     WhereDocumentOperator,
+     WhereDocument,
+ )
+ from inspect import signature
+ from tenacity import retry
+
+ # Re-export types from chromadb.types
+ __all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]
+
+ T = TypeVar("T")
+ OneOrMany = Union[T, List[T]]
+
+ # URIs
+ URI = str
+ URIs = List[URI]
+
+
+ def maybe_cast_one_to_many_uri(target: OneOrMany[URI]) -> URIs:
+     if isinstance(target, str):
+         # One URI
+         return cast(URIs, [target])
+     # Already a sequence
+     return cast(URIs, target)
+
+
+ # IDs
+ ID = str
+ IDs = List[ID]
+
+
+ def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
+     if isinstance(target, str):
+         # One ID
+         return cast(IDs, [target])
+     # Already a sequence
+     return cast(IDs, target)
+
+
+ # Embeddings
+ Embedding = Vector
+ Embeddings = List[Embedding]
+
+
+ def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings:
+     if isinstance(target, List):
+         # One Embedding
+         if isinstance(target[0], (int, float)):
+             return cast(Embeddings, [target])
+     # Already a sequence
+     return cast(Embeddings, target)
+
+
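The `maybe_cast_one_to_many_*` helpers let the public API accept either a single item or a list of items; everything downstream then works with lists only. For example (values illustrative):

```python
maybe_cast_one_to_many_ids("id1")                 # -> ["id1"]
maybe_cast_one_to_many_ids(["id1", "id2"])        # -> returned unchanged
maybe_cast_one_to_many_embedding([0.1, 0.2])      # -> [[0.1, 0.2]]  (one embedding)
maybe_cast_one_to_many_embedding([[0.1], [0.2]])  # -> returned unchanged
```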
+ # Metadatas
+ Metadatas = List[Metadata]
+
+
+ def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas:
+     # One Metadata dict
+     if isinstance(target, dict):
+         return cast(Metadatas, [target])
+     # Already a sequence
+     return cast(Metadatas, target)
+
+
+ CollectionMetadata = Dict[str, Any]
+ UpdateCollectionMetadata = UpdateMetadata
+
+ # Documents
+ Document = str
+ Documents = List[Document]
+
+
+ def is_document(target: Any) -> bool:
+     if not isinstance(target, str):
+         return False
+     return True
+
+
+ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:
+     # One Document
+     if is_document(target):
+         return cast(Documents, [target])
+     # Already a sequence
+     return cast(Documents, target)
+
+
+ # Images
+ ImageDType = Union[np.uint, np.int_, np.float_]
+ Image = NDArray[ImageDType]
+ Images = List[Image]
+
+
+ def is_image(target: Any) -> bool:
+     if not isinstance(target, np.ndarray):
+         return False
+     if len(target.shape) < 2:
+         return False
+     return True
+
+
+ def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:
+     if is_image(target):
+         return cast(Images, [target])
+     # Already a sequence
+     return cast(Images, target)
+
+
+ Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID)
+
+ # This should just be List[Literal["documents", "embeddings", "metadatas", "distances"]]
+ # However, this provokes an incompatibility with the Overrides library and Python 3.7
+ Include = List[
+     Union[
+         Literal["documents"],
+         Literal["embeddings"],
+         Literal["metadatas"],
+         Literal["distances"],
+         Literal["uris"],
+         Literal["data"],
+     ]
+ ]
+
+ # Re-export types from chromadb.types
+ LiteralValue = LiteralValue
+ LogicalOperator = LogicalOperator
+ WhereOperator = WhereOperator
+ OperatorExpression = OperatorExpression
+ Where = Where
+ WhereDocumentOperator = WhereDocumentOperator
+
+ Embeddable = Union[Documents, Images]
+ D = TypeVar("D", bound=Embeddable, contravariant=True)
+
+
+ Loadable = List[Optional[Image]]
+ L = TypeVar("L", covariant=True, bound=Loadable)
+
+
+ class GetResult(TypedDict):
+     ids: List[ID]
+     embeddings: Optional[List[Embedding]]
+     documents: Optional[List[Document]]
+     uris: Optional[URIs]
+     data: Optional[Loadable]
+     metadatas: Optional[List[Metadata]]
+
+
+ class QueryResult(TypedDict):
+     ids: List[IDs]
+     embeddings: Optional[List[List[Embedding]]]
+     documents: Optional[List[List[Document]]]
+     uris: Optional[List[List[URI]]]
+     data: Optional[List[Loadable]]
+     metadatas: Optional[List[List[Metadata]]]
+     distances: Optional[List[List[float]]]
+
+
+ class IndexMetadata(TypedDict):
+     dimensionality: int
+     # The current number of elements in the index (total = additions - deletes)
+     curr_elements: int
+     # The auto-incrementing ID of the last inserted element, never decreases so
+     # can be used as a count of total historical size. Should increase by 1 every add.
+     # Assume it cannot overflow.
+     total_elements_added: int
+     time_created: float
+
+
+ class EmbeddingFunction(Protocol[D]):
+     def __call__(self, input: D) -> Embeddings:
+         ...
+
+     def __init_subclass__(cls) -> None:
+         super().__init_subclass__()
+         # Raise an exception if __call__ is not defined since it is expected to be defined
+         call = getattr(cls, "__call__")
+
+         def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
+             result = call(self, input)
+             return validate_embeddings(maybe_cast_one_to_many_embedding(result))
+
+         setattr(cls, "__call__", __call__)
+
+     def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
+         return retry(**retry_kwargs)(self.__call__)(input)
+
+
+ def validate_embedding_function(
+     embedding_function: EmbeddingFunction[Embeddable],
+ ) -> None:
+     function_signature = signature(
+         embedding_function.__class__.__call__
+     ).parameters.keys()
+     protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys()
+
+     if not function_signature == protocol_signature:
+         raise ValueError(
+             f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n"
+             "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n"
+             "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
+         )
+
+
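The `__init_subclass__` hook above means that any subclass's `__call__` output is automatically normalized and validated. A minimal sketch (the subclass is hypothetical, not part of this commit), assuming the types defined in this file:

```python
class MyEmbeddingFunction(EmbeddingFunction[Documents]):
    def __call__(self, input: Documents) -> Embeddings:
        # Return one fixed-size vector per input document; a real
        # implementation would call an embedding model here.
        return [[float(len(doc)), 0.0] for doc in input]

embed_fn = MyEmbeddingFunction()
embed_fn(["hello", "world"])      # -> [[5.0, 0.0], [5.0, 0.0]], validated on the way out
embed_fn.embed_with_retries(["hi"])  # same call, wrapped in a tenacity retry
```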
+ class DataLoader(Protocol[L]):
+     def __call__(self, uris: URIs) -> L:
+         ...
+
+
+ def validate_ids(ids: IDs) -> IDs:
+     """Validates ids to ensure it is a list of strings"""
+     if not isinstance(ids, list):
+         raise ValueError(f"Expected IDs to be a list, got {ids}")
+     if len(ids) == 0:
+         raise ValueError(f"Expected IDs to be a non-empty list, got {ids}")
+     seen = set()
+     dups = set()
+     for id_ in ids:
+         if not isinstance(id_, str):
+             raise ValueError(f"Expected ID to be a str, got {id_}")
+         if id_ in seen:
+             dups.add(id_)
+         else:
+             seen.add(id_)
+     if dups:
+         n_dups = len(dups)
+         if n_dups < 10:
+             example_string = ", ".join(dups)
+             message = (
+                 f"Expected IDs to be unique, found duplicates of: {example_string}"
+             )
+         else:
+             examples = []
+             for idx, dup in enumerate(dups):
+                 examples.append(dup)
+                 if idx == 10:
+                     break
+             example_string = (
+                 f"{', '.join(examples[:5])}, ..., {', '.join(examples[-5:])}"
+             )
+             message = f"Expected IDs to be unique, found {n_dups} duplicated IDs: {example_string}"
+         raise errors.DuplicateIDError(message)
+     return ids
+
+
+ def validate_metadata(metadata: Metadata) -> Metadata:
+     """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools"""
+     if not isinstance(metadata, dict) and metadata is not None:
+         raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
+     if metadata is None:
+         return metadata
+     if len(metadata) == 0:
+         raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
+     for key, value in metadata.items():
+         if not isinstance(key, str):
+             raise TypeError(
+                 f"Expected metadata key to be a str, got {key} which is a {type(key)}"
+             )
+         # isinstance(True, int) evaluates to True, so we need to check for bools separately
+         if not isinstance(value, bool) and not isinstance(value, (str, int, float)):
+             raise ValueError(
+                 f"Expected metadata value to be a str, int, float or bool, got {value} which is a {type(value)}"
+             )
+     return metadata
+
+
+ def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
+     """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools"""
+     if not isinstance(metadata, dict) and metadata is not None:
+         raise ValueError(f"Expected metadata to be a dict or None, got {metadata}")
+     if metadata is None:
+         return metadata
+     if len(metadata) == 0:
+         raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
+     for key, value in metadata.items():
+         if not isinstance(key, str):
+             raise ValueError(f"Expected metadata key to be a str, got {key}")
+         # isinstance(True, int) evaluates to True, so we need to check for bools separately
+         if not isinstance(value, bool) and not isinstance(
+             value, (str, int, float, type(None))
+         ):
+             raise ValueError(
+                 f"Expected metadata value to be a str, int, float, bool or None, got {value}"
+             )
+     return metadata
+
+
+ def validate_metadatas(metadatas: Metadatas) -> Metadatas:
+     """Validates metadatas to ensure it is a list of dictionaries of strings to strings, ints, floats or bools"""
+     if not isinstance(metadatas, list):
+         raise ValueError(f"Expected metadatas to be a list, got {metadatas}")
+     for metadata in metadatas:
+         validate_metadata(metadata)
+     return metadatas
+
+
+ def validate_where(where: Where) -> Where:
+     """
+     Validates where to ensure it is a dictionary of strings to strings, ints, floats or operator expressions,
+     or in the case of $and and $or, a list of where expressions
+     """
+     if not isinstance(where, dict):
+         raise ValueError(f"Expected where to be a dict, got {where}")
+     if len(where) != 1:
+         raise ValueError(f"Expected where to have exactly one operator, got {where}")
+     for key, value in where.items():
+         if not isinstance(key, str):
+             raise ValueError(f"Expected where key to be a str, got {key}")
+         if (
+             key != "$and"
+             and key != "$or"
+             and key != "$in"
+             and key != "$nin"
+             and not isinstance(value, (str, int, float, dict))
+         ):
+             raise ValueError(
+                 f"Expected where value to be a str, int, float, or operator expression, got {value}"
+             )
+         if key == "$and" or key == "$or":
+             if not isinstance(value, list):
+                 raise ValueError(
+                     f"Expected where value for $and or $or to be a list of where expressions, got {value}"
+                 )
+             if len(value) <= 1:
+                 raise ValueError(
+                     f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}"
+                 )
+             for where_expression in value:
+                 validate_where(where_expression)
+         # Value is an operator expression
+         if isinstance(value, dict):
+             # Ensure there is only one operator
+             if len(value) != 1:
+                 raise ValueError(
+                     f"Expected operator expression to have exactly one operator, got {value}"
+                 )
+
+             for operator, operand in value.items():
+                 # Only numbers can be compared with gt, gte, lt, lte
+                 if operator in ["$gt", "$gte", "$lt", "$lte"]:
+                     if not isinstance(operand, (int, float)):
+                         raise ValueError(
+                             f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
+                         )
+                 if operator in ["$in", "$nin"]:
+                     if not isinstance(operand, list):
+                         raise ValueError(
+                             f"Expected operand value to be a list for operator {operator}, got {operand}"
+                         )
+                 if operator not in [
+                     "$gt",
+                     "$gte",
+                     "$lt",
+                     "$lte",
+                     "$ne",
+                     "$eq",
+                     "$in",
+                     "$nin",
+                 ]:
+                     raise ValueError(
+                         f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
+                         f"got {operator}"
+                     )
+
+                 if not isinstance(operand, (str, int, float, list)):
+                     raise ValueError(
+                         f"Expected where operand value to be a str, int, float, or list of those types, got {operand}"
+                     )
+                 if isinstance(operand, list) and (
+                     len(operand) == 0
+                     or not all(isinstance(x, type(operand[0])) for x in operand)
+                 ):
+                     raise ValueError(
+                         f"Expected where operand value to be a non-empty list with all values of the same type, "
+                         f"got {operand}"
+                     )
+     return where
+
+
+ def validate_where_document(where_document: WhereDocument) -> WhereDocument:
+     """
+     Validates where_document to ensure it is a dictionary of WhereDocumentOperator to strings, or in the case of $and and $or,
+     a list of where_document expressions
+     """
+     if not isinstance(where_document, dict):
+         raise ValueError(
+             f"Expected where document to be a dictionary, got {where_document}"
+         )
+     if len(where_document) != 1:
+         raise ValueError(
+             f"Expected where document to have exactly one operator, got {where_document}"
+         )
+     for operator, operand in where_document.items():
+         if operator not in ["$contains", "$not_contains", "$and", "$or"]:
+             raise ValueError(
+                 f"Expected where document operator to be one of $contains, $not_contains, $and, $or, got {operator}"
+             )
+         if operator == "$and" or operator == "$or":
+             if not isinstance(operand, list):
+                 raise ValueError(
+                     f"Expected document value for $and or $or to be a list of where document expressions, got {operand}"
+                 )
+             if len(operand) <= 1:
+                 raise ValueError(
+                     f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}"
+                 )
+             for where_document_expression in operand:
+                 validate_where_document(where_document_expression)
+         # Value is a $contains or $not_contains operator
+         elif not isinstance(operand, str):
+             raise ValueError(
+                 f"Expected where document operand value for operator $contains to be a str, got {operand}"
+             )
+         elif len(operand) == 0:
+             raise ValueError(
+                 "Expected where document operand value for operator $contains to be a non-empty str"
+             )
+     return where_document
+
+
+ def validate_include(include: Include, allow_distances: bool) -> Include:
+     """Validates include to ensure it is a list of strings. Since get does not allow distances, allow_distances is used
+     to control if distances is allowed"""
+
+     if not isinstance(include, list):
+         raise ValueError(f"Expected include to be a list, got {include}")
+     for item in include:
+         if not isinstance(item, str):
+             raise ValueError(f"Expected include item to be a str, got {item}")
+         allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
+         if allow_distances:
+             allowed_values.append("distances")
+         if item not in allowed_values:
+             raise ValueError(
+                 f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
+             )
+     return include
+
+
+ def validate_n_results(n_results: int) -> int:
+     """Validates n_results to ensure it is a positive integer, since hnswlib does not allow n_results to be negative."""
+     # Check the number of requested results
+     if not isinstance(n_results, int):
+         raise ValueError(
+             f"Expected requested number of results to be an int, got {n_results}"
+         )
+     if n_results <= 0:
+         raise TypeError(
+             f"Number of requested results {n_results} cannot be negative or zero."
+         )
+     return n_results
+
+
+ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
+     """Validates embeddings to ensure it is a list of lists of ints or floats"""
+     if not isinstance(embeddings, list):
+         raise ValueError(f"Expected embeddings to be a list, got {embeddings}")
+     if len(embeddings) == 0:
+         raise ValueError(
+             f"Expected embeddings to be a list with at least one item, got {embeddings}"
+         )
+     if not all([isinstance(e, list) for e in embeddings]):
+         raise ValueError(
+             f"Expected each embedding in the embeddings to be a list, got {embeddings}"
+         )
+     for i, embedding in enumerate(embeddings):
+         if len(embedding) == 0:
+             raise ValueError(
+                 f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {i}"
+             )
+         if not all(
+             [
+                 isinstance(value, (int, float)) and not isinstance(value, bool)
+                 for value in embedding
+             ]
+         ):
+             raise ValueError(
+                 f"Expected each value in the embedding to be an int or float, got {embeddings}"
+             )
+     return embeddings
+
+
+ def validate_batch(
+     batch: Tuple[
+         IDs,
+         Optional[Embeddings],
+         Optional[Metadatas],
+         Optional[Documents],
+         Optional[URIs],
+     ],
+     limits: Dict[str, Any],
+ ) -> None:
+     if len(batch[0]) > limits["max_batch_size"]:
+         raise ValueError(
+             f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}"
+         )
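A hedged usage sketch of `validate_batch`: only the length of the ids entry (`batch[0]`) is checked against the caller-supplied `max_batch_size` limit, which `SegmentAPI` takes from the producer (the sample values are illustrative):

```python
ids = ["a", "b", "c"]
batch = (ids, None, None, None, None)  # (ids, embeddings, metadatas, documents, uris)

validate_batch(batch, {"max_batch_size": 10})  # passes silently
# validate_batch(batch, {"max_batch_size": 2})
#   -> ValueError: Batch size 3 exceeds maximum batch size 2
```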