tuandunghcmut committed
Commit fdde15c · verified · 1 parent: db08794

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. clone-IDEA-Research/Grounded-SAM-2/.clang-format +85 -0
  2. clone-IDEA-Research/Grounded-SAM-2/.gitignore +147 -0
  3. clone-IDEA-Research/Grounded-SAM-2/.watchmanconfig +1 -0
  4. clone-IDEA-Research/Grounded-SAM-2/CODE_OF_CONDUCT.md +80 -0
  5. clone-IDEA-Research/Grounded-SAM-2/CONTRIBUTING.md +31 -0
  6. clone-IDEA-Research/Grounded-SAM-2/Dockerfile +37 -0
  7. clone-IDEA-Research/Grounded-SAM-2/INSTALL.md +189 -0
  8. clone-IDEA-Research/Grounded-SAM-2/LICENSE +201 -0
  9. clone-IDEA-Research/Grounded-SAM-2/LICENSE_cctorch +29 -0
  10. clone-IDEA-Research/Grounded-SAM-2/LICENSE_groundingdino +201 -0
  11. clone-IDEA-Research/Grounded-SAM-2/LICENSE_sam2 +201 -0
  12. clone-IDEA-Research/Grounded-SAM-2/MANIFEST.in +7 -0
  13. clone-IDEA-Research/Grounded-SAM-2/Makefile +37 -0
  14. clone-IDEA-Research/Grounded-SAM-2/README.md +484 -0
  15. clone-IDEA-Research/Grounded-SAM-2/SAM2_README.md +140 -0
  16. clone-IDEA-Research/Grounded-SAM-2/backend.Dockerfile +64 -0
  17. clone-IDEA-Research/Grounded-SAM-2/docker-compose.yaml +42 -0
  18. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_dinox_demo.py +245 -0
  19. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_autolabel_pipeline.py +198 -0
  20. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_image_demo.py +657 -0
  21. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_gd1.5_demo.py +249 -0
  22. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_hf_model_demo.py +187 -0
  23. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_local_demo.py +160 -0
  24. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo.py +198 -0
  25. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_dinox.py +237 -0
  26. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py +214 -0
  27. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py +220 -0
  28. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py +239 -0
  29. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py +203 -0
  30. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py +224 -0
  31. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_plus.py +247 -0
  32. clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_gd1.5.py +221 -0
  33. clone-IDEA-Research/Grounded-SAM-2/pyproject.toml +6 -0
  34. clone-IDEA-Research/Grounded-SAM-2/sam2/__init__.py +11 -0
  35. clone-IDEA-Research/Grounded-SAM-2/sam2/automatic_mask_generator.py +454 -0
  36. clone-IDEA-Research/Grounded-SAM-2/sam2/build_sam.py +167 -0
  37. clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_b+.yaml +113 -0
  38. clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_l.yaml +117 -0
  39. clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_s.yaml +116 -0
  40. clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_t.yaml +118 -0
  41. clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_image_predictor.py +466 -0
  42. clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_video_predictor.py +1172 -0
  43. clone-IDEA-Research/Grounded-SAM-2/setup.py +174 -0
  44. clone-IDEA-Research/Grounded-Segment-Anything/.gitignore +135 -0
  45. clone-IDEA-Research/Grounded-Segment-Anything/.gitmodules +7 -0
  46. clone-IDEA-Research/Grounded-Segment-Anything/CITATION.cff +8 -0
  47. clone-IDEA-Research/Grounded-Segment-Anything/Dockerfile +30 -0
  48. clone-IDEA-Research/Grounded-Segment-Anything/LICENSE +201 -0
  49. clone-IDEA-Research/Grounded-Segment-Anything/Makefile +43 -0
  50. clone-IDEA-Research/Grounded-Segment-Anything/README.md +808 -0
clone-IDEA-Research/Grounded-SAM-2/.clang-format ADDED
@@ -0,0 +1,85 @@
+ AccessModifierOffset: -1
+ AlignAfterOpenBracket: AlwaysBreak
+ AlignConsecutiveAssignments: false
+ AlignConsecutiveDeclarations: false
+ AlignEscapedNewlinesLeft: true
+ AlignOperands: false
+ AlignTrailingComments: false
+ AllowAllParametersOfDeclarationOnNextLine: false
+ AllowShortBlocksOnASingleLine: false
+ AllowShortCaseLabelsOnASingleLine: false
+ AllowShortFunctionsOnASingleLine: Empty
+ AllowShortIfStatementsOnASingleLine: false
+ AllowShortLoopsOnASingleLine: false
+ AlwaysBreakAfterReturnType: None
+ AlwaysBreakBeforeMultilineStrings: true
+ AlwaysBreakTemplateDeclarations: true
+ BinPackArguments: false
+ BinPackParameters: false
+ BraceWrapping:
+   AfterClass: false
+   AfterControlStatement: false
+   AfterEnum: false
+   AfterFunction: false
+   AfterNamespace: false
+   AfterObjCDeclaration: false
+   AfterStruct: false
+   AfterUnion: false
+   BeforeCatch: false
+   BeforeElse: false
+   IndentBraces: false
+ BreakBeforeBinaryOperators: None
+ BreakBeforeBraces: Attach
+ BreakBeforeTernaryOperators: true
+ BreakConstructorInitializersBeforeComma: false
+ BreakAfterJavaFieldAnnotations: false
+ BreakStringLiterals: false
+ ColumnLimit: 80
+ CommentPragmas: '^ IWYU pragma:'
+ ConstructorInitializerAllOnOneLineOrOnePerLine: true
+ ConstructorInitializerIndentWidth: 4
+ ContinuationIndentWidth: 4
+ Cpp11BracedListStyle: true
+ DerivePointerAlignment: false
+ DisableFormat: false
+ ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
+ IncludeCategories:
+   - Regex: '^<.*\.h(pp)?>'
+     Priority: 1
+   - Regex: '^<.*'
+     Priority: 2
+   - Regex: '.*'
+     Priority: 3
+ IndentCaseLabels: true
+ IndentWidth: 2
+ IndentWrappedFunctionNames: false
+ KeepEmptyLinesAtTheStartOfBlocks: false
+ MacroBlockBegin: ''
+ MacroBlockEnd: ''
+ MaxEmptyLinesToKeep: 1
+ NamespaceIndentation: None
+ ObjCBlockIndentWidth: 2
+ ObjCSpaceAfterProperty: false
+ ObjCSpaceBeforeProtocolList: false
+ PenaltyBreakBeforeFirstCallParameter: 1
+ PenaltyBreakComment: 300
+ PenaltyBreakFirstLessLess: 120
+ PenaltyBreakString: 1000
+ PenaltyExcessCharacter: 1000000
+ PenaltyReturnTypeOnItsOwnLine: 200
+ PointerAlignment: Left
+ ReflowComments: true
+ SortIncludes: true
+ SpaceAfterCStyleCast: false
+ SpaceBeforeAssignmentOperators: true
+ SpaceBeforeParens: ControlStatements
+ SpaceInEmptyParentheses: false
+ SpacesBeforeTrailingComments: 1
+ SpacesInAngles: false
+ SpacesInContainerLiterals: true
+ SpacesInCStyleCastParentheses: false
+ SpacesInParentheses: false
+ SpacesInSquareBrackets: false
+ Standard: Cpp11
+ TabWidth: 8
+ UseTab: Never
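Note: the `.clang-format` above defines the C++/CUDA code style for this repo. As a usage sketch (the target file path is taken from the `setup.py` snippet in INSTALL.md below; any other C++/CUDA source works the same way), clang-format picks this config up automatically from the repository root:

```bash
# Format a CUDA/C++ source in place; clang-format discovers the nearest
# .clang-format up the directory tree, so no extra flags are needed.
clang-format -i sam2/csrc/connected_components.cu

# Or preview the formatted output without modifying the file.
clang-format sam2/csrc/connected_components.cu | less
```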
clone-IDEA-Research/Grounded-SAM-2/.gitignore ADDED
@@ -0,0 +1,147 @@
+ # SAM 2
+ .vscode/
+ .DS_Store
+ __pycache__/
+ *-checkpoint.ipynb
+ .venv
+ *.egg*
+ build/*
+ _C.*
+ outputs/*
+ checkpoints/*.pt
+ *test*
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # checkpoint
+ *.pth
+ outputs/
+
+ .idea/
clone-IDEA-Research/Grounded-SAM-2/.watchmanconfig ADDED
@@ -0,0 +1 @@
+ {}
clone-IDEA-Research/Grounded-SAM-2/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
+ # Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to make participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
+ level of experience, education, socio-economic status, nationality, personal
+ appearance, race, religion, or sexual identity and orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+   advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+   address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies within all project spaces, and it also applies when
+ an individual is representing the project or its community in public spaces.
+ Examples of representing a project or community include using an official
+ project e-mail address, posting via an official social media account, or acting
+ as an appointed representative at an online or offline event. Representation of
+ a project may be further defined and clarified by project maintainers.
+
+ This Code of Conduct also applies outside the project spaces when there is a
+ reasonable belief that an individual's behavior may have a negative impact on
+ the project or its community.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the project team at <[email protected]>. All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+ [homepage]: https://www.contributor-covenant.org
+
+ For answers to common questions about this code of conduct, see
+ https://www.contributor-covenant.org/faq
clone-IDEA-Research/Grounded-SAM-2/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
+ # Contributing to segment-anything
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+ We actively welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## License
+ By contributing to segment-anything, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
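Note: step 5 of the CONTRIBUTING.md above names the lint toolchain but not the exact invocation. A minimal sketch of that workflow, assuming the `dev` extra referenced in that step is defined by this repo's packaging:

```bash
# Install the pinned formatters (black, usort, ufmt) via the dev extra.
pip install -e ".[dev]"

# Reformat the working tree; ufmt runs usort and black under the hood.
ufmt format .
```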
clone-IDEA-Research/Grounded-SAM-2/Dockerfile ADDED
@@ -0,0 +1,37 @@
+ FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel
+
+ # Arguments to build Docker Image using CUDA
+ ARG USE_CUDA=0
+ ARG TORCH_ARCH="7.0;7.5;8.0;8.6"
+
+ ENV AM_I_DOCKER=True
+ ENV BUILD_WITH_CUDA="${USE_CUDA}"
+ ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
+ ENV CUDA_HOME=/usr/local/cuda-12.1/
+ # Ensure CUDA is correctly set up
+ ENV PATH=/usr/local/cuda-12.1/bin:${PATH}
+ ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}
+
+ # Install required packages and specific gcc/g++
+ RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
+     libsm6=2:* libxext6=2:* git=1:* nano vim=2:* ninja-build gcc-10 g++-10 -y \
+     && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
+
+ ENV CC=gcc-10
+ ENV CXX=g++-10
+
+ RUN mkdir -p /home/appuser/Grounded-SAM-2
+ COPY . /home/appuser/Grounded-SAM-2/
+
+ WORKDIR /home/appuser/Grounded-SAM-2
+
+ # Install essential Python packages
+ RUN python -m pip install --upgrade pip setuptools wheel numpy \
+     opencv-python transformers supervision pycocotools addict yapf timm
+
+ # Install segment_anything package in editable mode
+ RUN python -m pip install -e .
+
+ # Install grounding dino
+ RUN python -m pip install --no-build-isolation -e grounding_dino
+
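Note: the Dockerfile above exposes `USE_CUDA` and `TORCH_ARCH` as build arguments and copies the repo into `/home/appuser/Grounded-SAM-2`. A minimal build-and-run sketch; the image tag and the trimmed architecture list are illustrative assumptions, not part of this commit:

```bash
# Build with the CUDA toggles enabled, targeting only compute capability 8.6.
docker build -t grounded-sam-2 \
  --build-arg USE_CUDA=1 \
  --build-arg TORCH_ARCH="8.6" .

# Run with GPU access and an interactive shell in the copied repo.
docker run --gpus all -it --rm grounded-sam-2 /bin/bash
```

The `docker-compose.yaml` listed among the changed files offers an alternative entry point, but its contents are not shown in this view.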
clone-IDEA-Research/Grounded-SAM-2/INSTALL.md ADDED
@@ -0,0 +1,189 @@
+ ## Installation
+
+ ### Requirements
+
+ - Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
+   * Note: older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
+ - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version of your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
+ - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
+
+ Then, install SAM 2 from the root of this repository via
+ ```bash
+ pip install -e ".[notebooks]"
+ ```
+
+ Note that you may skip building the SAM 2 CUDA extension during installation via the environment variable `SAM2_BUILD_CUDA=0`, as follows:
+ ```bash
+ # skip the SAM 2 CUDA extension
+ SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]"
+ ```
+ This also skips the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
+
+ ### Building the SAM 2 CUDA extension
+
+ By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless you pass `-v` to `pip install` for verbose output.)
+
+ If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.
+
+ If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with the environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows:
+ ```bash
+ pip uninstall -y SAM-2 && \
+ rm -f ./sam2/*.so && \
+ SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]"
+ ```
+
+ Note that PyTorch needs to be installed before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version of your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check the toolkit version via `nvcc --version`.
+
+ Please check the section below on common installation issues if the CUDA extension fails to build during installation or fails to load at runtime.
+
+ ### Common Installation Issues
+
+ Click each issue for its solutions:
+
+ <details>
+ <summary>
+ I got `ImportError: cannot import name '_C' from 'sam2'`
+ </summary>
+ <br/>
+
+ This is usually because you haven't run the `pip install -e ".[notebooks]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
+
+ On some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root, as suggested in https://github.com/facebookresearch/sam2/issues/77.
+ </details>
+
+ <details>
+ <summary>
+ I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'`
+ </summary>
+ <br/>
+
+ This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. If it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via
+ ```bash
+ export SAM2_REPO_ROOT=/path/to/sam2  # path to this repo
+ export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
+ ```
+ to manually add `sam2_configs` into your Python's `sys.path`.
+
+ </details>
+
+ <details>
+ <summary>
+ I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints
+ </summary>
+ <br/>
+
+ This is likely because you have installed a previous version of this repo, which doesn't yet have the new modules needed to support the SAM 2.1 checkpoints. Please try the following steps:
+
+ 1. pull the latest code from the `main` branch of this repo
+ 2. run `pip uninstall -y SAM-2` to uninstall any previous installations
+ 3. then install the latest repo again using `pip install -e ".[notebooks]"`
+
+ In case the steps above still don't resolve the error, please try running the following in your Python environment
+ ```python
+ from sam2.modeling import sam2_base
+
+ print(sam2_base.__file__)
+ ```
+ and check whether the content of the printed local path of `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`) to identify whether you're still using a previous installation.
+
+ </details>
+
+ <details>
+ <summary>
+ My installation failed with `CUDA_HOME environment variable is not set`
+ </summary>
+ <br/>
+
+ This usually happens because the installation step cannot find the CUDA toolkits (which contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches the CUDA version of your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
+ ```
+ export CUDA_HOME=/usr/local/cuda  # change to your CUDA toolkit path
+ ```
+ and rerun the installation.
+
+ Also, you should make sure
+ ```
+ python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
+ ```
+ prints `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up.
+
+ If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command:
+ ```
+ pip install --no-build-isolation -e .
+ ```
+
+ </details>
+
+ <details>
+ <summary>
+ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
+ </summary>
+ <br/>
+
+ This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version, while at run time it links against another. This might be because you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates so that only a single PyTorch and CUDA version remains.
+
+ In particular, if you have a PyTorch version lower than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to a duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
+
+ We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
+ </details>
+
+ <details>
+ <summary>
+ I got `CUDA error: no kernel image is available for execution on the device`
+ </summary>
+ <br/>
+
+ A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. on a Slurm system).
+
+ You can try pulling the latest code from the SAM 2 repo and running the following
+ ```
+ export TORCH_CUDA_ARCH_LIST="9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0"
+ ```
+ to manually specify the CUDA capability in the compilation target that matches your GPU.
+ </details>
+
+ <details>
+ <summary>
+ I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
+ </summary>
+ <br/>
+
+ This is probably because your machine doesn't have a GPU or a PyTorch version compatible with Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in the PyTorch forum). You may be able to resolve this error by replacing the line
+ ```python
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+ ```
+ in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
+ ```python
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
+ ```
+ to relax the attention kernel setting and use kernels other than Flash Attention.
+ </details>
+
+ <details>
+ <summary>
+ I got `Error compiling objects for extension`
+ </summary>
+ <br/>
+
+ You may see an error log like:
+ > unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.
+
+ This is probably because your versions of CUDA and Visual Studio are incompatible (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion on Stack Overflow).<br>
+ You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py). <br>
+ After adding the argument, `get_extensions()` will look like this:
+ ```python
+ def get_extensions():
+     srcs = ["sam2/csrc/connected_components.cu"]
+     compile_args = {
+         "cxx": [],
+         "nvcc": [
+             "-DCUDA_HAS_FP16=1",
+             "-D__CUDA_NO_HALF_OPERATORS__",
+             "-D__CUDA_NO_HALF_CONVERSIONS__",
+             "-D__CUDA_NO_HALF2_OPERATORS__",
+             "-allow-unsupported-compiler",  # Add this argument
+         ],
+     }
+     ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+     return ext_modules
+ ```
+ </details>
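Note: the troubleshooting entries in the INSTALL.md above come down to two checks: whether PyTorch sees a GPU with a matching CUDA toolkit, and whether the optional `sam2._C` extension was actually built. A small sketch bundling both checks; the first command is taken from the `CUDA_HOME` entry above, while the `sam2._C` probe is an assumption based on the `ImportError` entry:

```bash
# Should print "True <path to your CUDA toolkit>".
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'

# Probe for the optional CUDA extension; False only means the mask
# post-processing step will be skipped, per the notes above.
python -c "import importlib.util; print('sam2._C built:', importlib.util.find_spec('sam2._C') is not None)"
```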
clone-IDEA-Research/Grounded-SAM-2/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 - present, IDEA Research.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
clone-IDEA-Research/Grounded-SAM-2/LICENSE_cctorch ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+    contributors may be used to endorse or promote products derived from
+    this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
clone-IDEA-Research/Grounded-SAM-2/LICENSE_groundingdino ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 - present, IDEA Research.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
clone-IDEA-Research/Grounded-SAM-2/LICENSE_sam2 ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
clone-IDEA-Research/Grounded-SAM-2/MANIFEST.in ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ recursive-include sam2 *.yaml #include all config files
clone-IDEA-Research/Grounded-SAM-2/Makefile ADDED
@@ -0,0 +1,37 @@
1
+ # Get version of CUDA and enable it for compilation if CUDA > 11.0
2
+ # This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
3
+ # and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
4
+ # when running in Docker
5
+ # Check if nvcc is installed
6
+ NVCC := $(shell which nvcc)
7
+ ifeq ($(NVCC),)
8
+ # NVCC not found
9
+ USE_CUDA := 0
10
+ NVCC_VERSION := "not installed"
11
+ else
12
+ NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
13
+ USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
14
+ endif
15
+
16
+ # Add the list of supported ARCHs
17
+ ifeq ($(USE_CUDA), 1)
18
+ TORCH_CUDA_ARCH_LIST := "7.0;7.5;8.0;8.6+PTX"
19
+ BUILD_MESSAGE := "I will try to build the image with CUDA support"
20
+ else
21
+ TORCH_CUDA_ARCH_LIST :=
22
+ BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
23
+ endif
24
+
25
+
26
+ build-image:
27
+ @echo $(BUILD_MESSAGE)
28
+ docker build --build-arg USE_CUDA=$(USE_CUDA) \
29
+ --build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
30
+ -t grounded_sam2:1.0 .
31
+ run:
32
+ docker run --gpus all -it --rm --net=host --privileged \
33
+ -v /tmp/.X11-unix:/tmp/.X11-unix \
34
+ -v "${PWD}":/home/appuser/Grounded-SAM-2 \
35
+ -e DISPLAY=$$DISPLAY \
36
+ --name=gsa \
37
+ --ipc=host -it grounded_sam2:1.0
clone-IDEA-Research/Grounded-SAM-2/README.md ADDED
@@ -0,0 +1,484 @@
1
+ # Grounded SAM 2: Ground and Track Anything in Videos
2
+
3
+ **[IDEA-Research](https://github.com/idea-research)**
4
+
5
+ [Tianhe Ren](https://rentainhe.github.io/), [Shuo Shen](https://github.com/ShuoShenDe)
6
+
7
+ [[`SAM 2 Paper`](https://arxiv.org/abs/2408.00714)] [[`Grounding DINO Paper`](https://arxiv.org/abs/2303.05499)] [[`Grounding DINO 1.5 Paper`](https://arxiv.org/abs/2405.10300)] [[`DINO-X Paper`](https://arxiv.org/abs/2411.14347)] [[`BibTeX`](#citation)]
8
+
9
+ [![Video Name](./assets/grounded_sam_2_intro.jpg)](https://github.com/user-attachments/assets/f0fb0022-779a-49fb-8f46-3a18a8b4e893)
10
+
11
+ ## Highlights
12
+
13
+ Grounded SAM 2 is a foundation model pipeline for grounding and tracking anything in videos with [Grounding DINO](https://arxiv.org/abs/2303.05499), [Grounding DINO 1.5](https://arxiv.org/abs/2405.10300), [Florence-2](https://arxiv.org/abs/2311.06242), [DINO-X](https://arxiv.org/abs/2411.14347) and [SAM 2](https://arxiv.org/abs/2408.00714).
14
+
15
+ In this repo, we support the following demos with **simple implementations**:
16
+ - **Ground and Segment Anything** with Grounding DINO, Grounding DINO 1.5 & 1.6, DINO-X and SAM 2
17
+ - **Ground and Track Anything** with Grounding DINO, Grounding DINO 1.5 & 1.6, DINO-X and SAM 2
18
+ - **Detect, Segment and Track Visualization** based on the powerful [supervision](https://github.com/roboflow/supervision) library.
19
+
20
+ Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
21
+
22
+ ## Latest updates
23
+
24
+ - **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
25
+
26
+ - **2024.10.24**: Support [SAHI (Slicing Aided Hyper Inference)](https://docs.ultralytics.com/guides/sahi-tiled-inference/) on Grounded SAM 2 (with Grounding DINO 1.5), which may be helpful for inference on high-resolution images with dense small objects (e.g. **4K** images).
27
+
28
+ - **2024.10.10**: Support `SAM-2.1` models. If you want to use a `SAM 2.1` model, you need to update to the latest code and reinstall SAM 2 following [SAM 2.1 Installation](https://github.com/facebookresearch/sam2?tab=readme-ov-file#latest-updates).
29
+
30
+ - **2024.08.31**: Support `dump json results` in Grounded SAM 2 Image Demos (with Grounding DINO).
31
+
32
+ - **2024.08.20**: Support **Florence-2 SAM 2 Image Demo** which includes `dense region caption`, `object detection`, `phrase grounding`, and cascaded auto-label pipeline `caption + phrase grounding`.
33
+
34
+ - **2024.08.09**: Support **Ground and Track New Objects** throughout the whole video. This feature is still under development. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
35
+
36
+ - **2024.08.07**: Support **Custom Video Inputs**: users only need to submit their video file (e.g. an `.mp4` file) with specific text prompts to get impressive demo videos.
37
+
38
+ ## Contents
39
+ - [Installation](#installation)
40
+ - [Grounded SAM 2 Demos](#grounded-sam-2-demos)
41
+ - [Grounded SAM 2 Image Demo](#grounded-sam-2-image-demo-with-grounding-dino)
42
+ - [Grounded SAM 2 Image Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-image-demo-with-grounding-dino-15--16)
43
+ - [Grounded SAM 2 Image Demo (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x)
44
+ - [Grounded SAM 2 with SAHI for High Resolution Image Inference](#sahi-slicing-aided-hyper-inference-with-grounding-dino-15-and-sam-2)
45
+ - [Automatically Saving Grounding and Segmentation Results](#automatically-saving-grounding-results-image-demo)
46
+ - [Grounded SAM 2 Video Object Tracking Demo](#grounded-sam-2-video-object-tracking-demo)
47
+ - [Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-grounding-dino-15--16)
48
+ - [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino)
49
+ - [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16)
50
+ - [Grounded SAM 2 Video Object Tracking Demo (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x)
51
+ - [Grounded SAM 2 Video Object Tracking with Continuous ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino)
52
+ - [Grounded SAM 2 Florence-2 Demos](#grounded-sam-2-florence-2-demos)
53
+ - [Grounded SAM 2 Florence-2 Image Demo](#grounded-sam-2-florence-2-image-demo)
54
+ - [Grounded SAM 2 Florence-2 Image Auto-Labeling Demo](#grounded-sam-2-florence-2-image-auto-labeling-demo)
55
+ - [Citation](#citation)
56
+
57
+
58
+
59
+ ## Installation
60
+
61
+ Download the pretrained `SAM 2` checkpoints:
62
+
63
+ ```bash
64
+ cd checkpoints
65
+ bash download_ckpts.sh
66
+ ```
67
+
68
+ Download the pretrained `Grounding DINO` checkpoints:
69
+
70
+ ```bash
71
+ cd gdino_checkpoints
72
+ bash download_ckpts.sh
73
+ ```
74
+
75
+ ### Installation without docker
76
+
77
+ Install the PyTorch environment first. We use `python=3.10`, `torch >= 2.3.1`, `torchvision>=0.18.1` and `cuda-12.1` in our environment to run this demo. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. You can easily install the latest version of PyTorch as follows:
78
+
79
+ ```bash
80
+ pip3 install torch torchvision torchaudio
81
+ ```
82
+
83
+ Since the `Deformable Attention` operator used in Grounding DINO must be compiled with CUDA, check that the CUDA environment variables are set correctly (refer to [Grounding DINO Installation](https://github.com/IDEA-Research/GroundingDINO?tab=readme-ov-file#hammer_and_wrench-install) for more details). You can set the environment variable manually as follows if you want to build a local GPU environment for Grounding DINO to run Grounded SAM 2:
84
+
85
+ ```bash
86
+ export CUDA_HOME=/path/to/cuda-12.1/
87
+ ```
88
+
89
+ Install `Segment Anything 2`:
90
+
91
+ ```bash
92
+ pip install -e .
93
+ ```
94
+
95
+ Install `Grounding DINO`:
96
+
97
+ ```bash
98
+ pip install --no-build-isolation -e grounding_dino
99
+ ```
100
+
101
+ ### Installation with docker
102
+ Build the Docker image and run the Docker container:
103
+
104
+ ```bash
105
+ cd Grounded-SAM-2
106
+ make build-image
107
+ make run
108
+ ```
109
+ After executing these commands, you will be inside the Docker environment. The working directory within the container is set to: `/home/appuser/Grounded-SAM-2`
110
+
111
+ Once inside the Docker environment, you can start the demo by running:
112
+ ```bash
113
+ python grounded_sam2_tracking_demo.py
114
+ ```
115
+
116
+ ## Grounded SAM 2 Demos
117
+ ### Grounded SAM 2 Image Demo (with Grounding DINO)
118
+ Note that `Grounding DINO` is already supported in [Huggingface](https://huggingface.co/IDEA-Research/grounding-dino-tiny), so we provide two choices for running the `Grounded SAM 2` model:
119
+ - Use the Huggingface API to run inference with Grounding DINO (simple and clear)
120
+
121
+ ```bash
122
+ python grounded_sam2_hf_model_demo.py
123
+ ```
124
+
125
+ > [!NOTE]
126
+ > 🚨 If you encounter network issues while using the `HuggingFace` model, you can resolve them by setting the appropriate mirror source as `export HF_ENDPOINT=https://hf-mirror.com`
127
+
128
+ - Load a local pretrained Grounding DINO checkpoint and run inference with the original Grounding DINO API (make sure you've already downloaded the pretrained checkpoint)
129
+
130
+ ```bash
131
+ python grounded_sam2_local_demo.py
132
+ ```
133
+
134
+
135
+ ### Grounded SAM 2 Image Demo (with Grounding DINO 1.5 & 1.6)
136
+
137
+ We've already released our most capable open-set detection model [Grounding DINO 1.5 & 1.6](https://github.com/IDEA-Research/Grounding-DINO-1.5-API), which can be combined with SAM 2 for stronger open-set detection and segmentation capability. You can apply for an API token first and then run Grounded SAM 2 with Grounding DINO 1.5 as follows:
138
+
139
+ Install the latest DDS cloudapi:
140
+
141
+ ```bash
142
+ pip install dds-cloudapi-sdk --upgrade
143
+ ```
144
+
145
+ Apply your API token from our official website here: [request API token](https://deepdataspace.com/request_api).
146
+
147
+ ```bash
148
+ python grounded_sam2_gd1.5_demo.py
149
+ ```
150
+
151
+ ### SAHI (Slicing Aided Hyper Inference) with Grounding DINO 1.5 and SAM 2
152
+
153
+ If your images are high resolution with dense objects, directly using Grounding DINO 1.5 for inference on the original image may not be the best choice. We support [SAHI (Slicing Aided Hyper Inference)](https://docs.ultralytics.com/guides/sahi-tiled-inference/), which works by first dividing the original image into smaller overlapping patches. Inference is then performed separately on each patch, and the final detection results are merged. This method is highly effective and accurate for detecting dense, small objects in high-resolution images.
154
+
155
+ You can run SAHI inference by setting the following param in [grounded_sam2_gd1.5_demo.py](./grounded_sam2_gd1.5_demo.py):
156
+
157
+ ```python
158
+ WITH_SLICE_INFERENCE = True
159
+ ```
160
+
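+ Under the hood, the slice-inference code in this repo (see e.g. `grounded_sam2_dinox_demo.py`) is built on the `InferenceSlicer` utility from the [supervision](https://github.com/roboflow/supervision) library. The following is a minimal sketch of that pattern; the callback body, patch size, and image path are placeholders rather than the exact values used in the demo scripts:
+ 
+ ```python
+ import cv2
+ import numpy as np
+ import supervision as sv
+ 
+ def detect_on_slice(image_slice: np.ndarray) -> sv.Detections:
+     # Run your grounding model (e.g. the Grounding DINO 1.5 cloud API) on this patch and
+     # return its boxes/scores/class ids; here we return no detections as a placeholder.
+     return sv.Detections(
+         xyxy=np.zeros((0, 4)),
+         confidence=np.zeros(0),
+         class_id=np.zeros(0, dtype=int),
+     )
+ 
+ slicer = sv.InferenceSlicer(
+     callback=detect_on_slice,     # called once per overlapping image patch
+     slice_wh=(480, 480),          # patch size (illustrative)
+     overlap_ratio_wh=(0.2, 0.2),  # overlap between neighbouring patches
+     iou_threshold=0.5,            # NMS threshold used when merging patch detections
+ )
+ detections = slicer(cv2.imread("path/to/high_res_image.jpg"))  # merged detections for the full image
+ ```
+ 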
161
+ The visualization is shown as follows:
162
+
163
+ | Text Prompt | Input Image | Grounded SAM 2 | Grounded SAM 2 with SAHI |
164
+ |:----:|:----:|:----:|:----:|
165
+ | `Person` | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam_2/demo_images/dense%20people.png?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam_2/grounding_dino_1.5_slice_inference/grounded_sam2_annotated_image_with_mask.jpg?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam_2/grounding_dino_1.5_slice_inference/grounded_sam2_annotated_image_with_mask_with_slice_inference.jpg?raw=true) |
166
+
167
+ - **Note:** We only support SAHI with Grounding DINO 1.5 because SAHI works better with a stronger grounding model, which produces fewer hallucinated results.
168
+
169
+ ### Grounded SAM 2 Image Demo (with DINO-X)
170
+
171
+ We've implemented Grounded SAM 2 with the strongest open-world perception model [DINO-X](https://github.com/IDEA-Research/DINO-X-API) for better open-set detection and segmentation performance. You can apply for an API token first and then run Grounded SAM 2 with DINO-X as follows:
172
+
173
+ Install the latest DDS cloudapi:
174
+
175
+ ```bash
176
+ pip install dds-cloudapi-sdk --upgrade
177
+ ```
178
+
179
+ Apply your API token from our official website here: [request API token](https://deepdataspace.com/request_api).
180
+
181
+ ```bash
182
+ python grounded_sam2_dinox_demo.py
183
+ ```
184
+
185
+ ### Automatically Saving Grounding Results (Image Demo)
186
+
187
+ After setting `DUMP_JSON_RESULTS=True` in the following Grounded SAM 2 Image Demos:
188
+ - [grounded_sam2_local_demo.py](./grounded_sam2_local_demo.py)
189
+ - [grounded_sam2_hf_model_demo.py](./grounded_sam2_hf_model_demo.py)
190
+ - [grounded_sam2_gd1.5_demo.py](./grounded_sam2_gd1.5_demo.py)
191
+ - [grounded_sam2_dinox_demo.py](./grounded_sam2_dinox_demo.py)
192
+
193
+ The `grounding` and `segmentation` results will be automatically saved in the `outputs` dir with the following format:
194
+
195
+ ```python
196
+ {
197
+ "image_path": "path/to/image.jpg",
198
+ "annotations": [
199
+ {
200
+ "class_name": "class_name",
201
+ "bbox": [x1, y1, x2, y2],
202
+ "segmentation": {
203
+ "size": [h, w],
204
+ "counts": "rle_encoded_mask"
205
+ },
206
+ "score": confidence score
207
+ }
208
+ ],
209
+ "box_format": "xyxy",
210
+ "img_width": w,
211
+ "img_height": h
212
+ }
213
+ ```
214
+
215
+
216
+
217
+ ### Grounded SAM 2 Video Object Tracking Demo
218
+
219
+ Based on the strong tracking capability of SAM 2, we can combine it with Grounding DINO for open-set object segmentation and tracking. You can run the following script to get tracking results with Grounded SAM 2:
220
+
221
+ ```bash
222
+ python grounded_sam2_tracking_demo.py
223
+ ```
224
+
225
+ - The tracking results of each frame will be saved in `./tracking_results`
226
+ - The video will be saved as `children_tracking_demo_video.mp4`
227
+ - You can modify this file with different text prompts and video clips to get more tracking results.
228
+ - We only prompt the first video frame with Grounding DINO here for simplicity.
229
+
230
+ #### Support Various Prompt Type for Tracking
231
+
232
+ We support different types of prompts for the Grounded SAM 2 tracking demo:
233
+
234
+ - **Point Prompt**: In order to **get stable segmentation results**, we re-use the SAM 2 image predictor to get the prediction mask for each object based on the Grounding DINO box outputs, then we **uniformly sample points from the prediction mask** as point prompts for the SAM 2 video predictor (see the sketch below)
235
+ - **Box Prompt**: We directly use the box outputs from Grounding DINO as box prompts for the SAM 2 video predictor
236
+ - **Mask Prompt**: We use the SAM 2 mask prediction results based on the Grounding DINO box outputs as mask prompts for the SAM 2 video predictor.
237
+
238
+ ![Grounded SAM 2 Tracking Pipeline](./assets/g_sam2_tracking_pipeline_vis_new.png)
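+ 
+ A minimal sketch of the point-sampling step described above (the helper name and the number of points are illustrative, not the repo's exact implementation):
+ 
+ ```python
+ import numpy as np
+ 
+ def sample_points_from_mask(mask: np.ndarray, num_points: int = 10, seed: int = 0):
+     """Uniformly sample positive point prompts from a binary mask of shape (H, W)."""
+     ys, xs = np.nonzero(mask)
+     if len(xs) == 0:
+         return np.zeros((0, 2), dtype=np.float32), np.zeros(0, dtype=np.int32)
+     rng = np.random.default_rng(seed)
+     idx = rng.choice(len(xs), size=min(num_points, len(xs)), replace=False)
+     points = np.stack([xs[idx], ys[idx]], axis=1).astype(np.float32)  # (N, 2) in (x, y) order
+     labels = np.ones(len(points), dtype=np.int32)                     # 1 = positive point
+     return points, labels
+ ```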
239
+
240
+
241
+ ### Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)
242
+
243
+ We also support a video object tracking demo based on our stronger `Grounding DINO 1.5` model and `SAM 2`. You can try the following demo after applying for the API key needed to run `Grounding DINO 1.5`:
244
+
245
+ ```bash
246
+ python grounded_sam2_tracking_demo_with_gd1.5.py
247
+ ```
248
+
249
+ ### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with Grounding DINO)
250
+
251
+ Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with Grounding DINO and SAM 2 by using the following scripts:
252
+
253
+ ```bash
254
+ python grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py
255
+ ```
256
+
257
+ If the Huggingface demo is not convenient for you, you can also run the tracking demo with a local Grounding DINO model using the following script:
258
+
259
+ ```bash
260
+ python grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
261
+ ```
262
+
263
+ ### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with Grounding DINO 1.5 & 1.6)
264
+
265
+ Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with Grounding DINO 1.5 and SAM 2 by using the following scripts:
266
+
267
+ ```bash
268
+ python grounded_sam2_tracking_demo_custom_video_input_gd1.5.py
269
+ ```
270
+
271
+ You can specify the params in this file:
272
+
273
+ ```python
274
+ VIDEO_PATH = "./assets/hippopotamus.mp4"
275
+ TEXT_PROMPT = "hippopotamus."
276
+ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
277
+ API_TOKEN_FOR_GD1_5 = "Your API token" # api token for G-DINO 1.5
278
+ PROMPT_TYPE_FOR_VIDEO = "mask" # using SAM 2 mask prediction as prompt for video predictor
279
+ ```
280
+
281
+ After running our demo code, you can get the tracking results as follows:
282
+
283
+ [![Video Name](./assets/hippopotamus_seg.jpg)](https://github.com/user-attachments/assets/1fbdc6f4-3e50-4221-9600-98c397beecdf)
284
+
285
+ The tracking visualization results will be automatically saved to `OUTPUT_VIDEO_PATH`.
286
+
287
+ > [!WARNING]
288
+ > We initialize the box prompts on the first frame of the input video. If you want to start from a different frame, you can modify `ann_frame_idx` in our code.
289
+
290
+ ### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with DINO-X)
291
+
292
+ Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with DINO-X and SAM 2 by using the following scripts:
293
+
294
+ ```bash
295
+ python grounded_sam2_tracking_demo_custom_video_input_dinox.py
296
+ ```
297
+
298
+ ### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
299
+
300
+ In the above demos, we only prompt Grounded SAM 2 on a specific frame, which may not be suitable for finding new objects that appear later in the video. In this demo, we try to **find new objects** and assign them new IDs across the whole video. This function is **still under development** and is not very stable yet.
301
+
302
+ Users can upload their own video files and specify custom text prompts for grounding and tracking using the Grounding DINO and SAM 2 frameworks. To do this, execute the script:
303
+
304
+
305
+ ```bash
306
+ python grounded_sam2_tracking_demo_with_continuous_id.py
307
+ ```
308
+
309
+ You can customize various parameters, including the following (an illustrative settings sketch is shown after the list):
310
+
311
+ - `text`: The grounding text prompt.
312
+ - `video_dir`: Directory containing the video files.
313
+ - `output_dir`: Directory to save the processed output.
314
+ - `output_video_path`: Path for the output video.
315
+ - `step`: Frame stepping for processing.
316
+ - `box_threshold`: box threshold for the Grounding DINO model
317
+ - `text_threshold`: text threshold for the Grounding DINO model
318
+ Note: This method supports only the mask prompt type.
319
+
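+ An illustrative settings sketch for the parameters listed above (all values and paths below are placeholders, not the script's defaults):
+ 
+ ```python
+ text = "car."                       # grounding text prompt
+ video_dir = "path/to/video_files"   # directory containing the video files (placeholder)
+ output_dir = "./outputs"            # directory to save the processed output
+ output_video_path = "./output.mp4"  # path for the output video
+ step = 10                           # frame stepping for processing
+ box_threshold = 0.25                # box threshold for the Grounding DINO model
+ text_threshold = 0.25               # text threshold for the Grounding DINO model
+ ```
+ 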
320
+ After running our demo code, you can get the tracking results as follows:
321
+
322
+ [![Video Name](./assets/tracking_car_mask_1.jpg)](https://github.com/user-attachments/assets/d3f91ad0-3d32-43c4-a0dc-0bed661415f4)
323
+
324
+ If you want to try `Grounding DINO 1.5` model, you can run the following scripts after setting your API token:
325
+
326
+ ```bash
327
+ python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
328
+ ```
329
+
330
+ ### Grounded-SAM-2 Video Object Tracking with Continuous ID plus Reverse Tracking (with Grounding DINO)
331
+ This method can cover the whole lifetime of the object:
332
+ ```bash
333
+ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
334
+
335
+ ```
336
+
337
+ ## Grounded SAM 2 Florence-2 Demos
338
+ ### Grounded SAM 2 Florence-2 Image Demo
339
+
340
+ In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications.
341
+
342
+ [Florence-2](https://arxiv.org/abs/2311.06242) is a powerful vision foundation model from Microsoft which supports a series of vision tasks by prompting with a special `task_prompt`, including but not limited to:
343
+
344
+ | Task | Task Prompt | Text Input | Task Introduction |
345
+ |:---:|:---:|:---:|:---:|
346
+ | Object Detection | `<OD>` | &#10008; | Detect main objects with single category name |
347
+ | Dense Region Caption | `<DENSE_REGION_CAPTION>` | &#10008; | Detect main objects with short description |
348
+ | Region Proposal | `<REGION_PROPOSAL>` | &#10008; | Generate proposals without category name |
349
+ | Phrase Grounding | `<CAPTION_TO_PHRASE_GROUNDING>` | &#10004; | Ground main objects in image mentioned in caption |
350
+ | Referring Expression Segmentation | `<REFERRING_EXPRESSION_SEGMENTATION>` | &#10004; | Ground the object which is most related to the text input |
351
+ | Open Vocabulary Detection and Segmentation | `<OPEN_VOCABULARY_DETECTION>` | &#10004; | Ground any object with text input |
352
+
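+ As a hedged sketch of how Florence-2 task prompting typically looks with the Hugging Face `transformers` API (the model id, generation settings, and post-processing call follow the public Florence-2 release and are assumptions here, not code copied from this repo's demo script):
+ 
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ 
+ # Model id and usage follow the public Florence-2 release (assumption, not this repo's script).
+ model_id = "microsoft/Florence-2-large"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id, trust_remote_code=True, torch_dtype=torch.float16
+ ).eval().to("cuda")
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+ 
+ def run_florence2(image: Image.Image, task_prompt: str, text_input: str = None):
+     prompt = task_prompt if text_input is None else task_prompt + text_input
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         num_beams=3,
+     )
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     # e.g. {"<OD>": {"bboxes": [...], "labels": [...]}} for detection-style tasks
+     return processor.post_process_generation(
+         generated_text, task=task_prompt, image_size=(image.width, image.height)
+     )
+ 
+ image = Image.open("./notebooks/images/cars.jpg").convert("RGB")
+ results = run_florence2(image, "<OD>")  # the returned boxes can be passed to SAM 2 as box prompts
+ ```
+ 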
353
+
354
+ By integrating `Florence-2` with `SAM 2`, we can build a strong vision pipeline to solve complex vision tasks. You can try the following scripts to run the demo:
355
+
356
+ > [!NOTE]
357
+ > 🚨 If you encounter network issues while using the `HuggingFace` model, you can resolve them by setting the appropriate mirror source as `export HF_ENDPOINT=https://hf-mirror.com`
358
+
359
+ **Object Detection and Segmentation**
360
+ ```bash
361
+ python grounded_sam2_florence2_image_demo.py \
362
+ --pipeline object_detection_segmentation \
363
+ --image_path ./notebooks/images/cars.jpg
364
+ ```
365
+
366
+ **Dense Region Caption and Segmentation**
367
+ ```bash
368
+ python grounded_sam2_florence2_image_demo.py \
369
+ --pipeline dense_region_caption_segmentation \
370
+ --image_path ./notebooks/images/cars.jpg
371
+ ```
372
+
373
+ **Region Proposal and Segmentation**
374
+ ```bash
375
+ python grounded_sam2_florence2_image_demo.py \
376
+ --pipeline region_proposal_segmentation \
377
+ --image_path ./notebooks/images/cars.jpg
378
+ ```
379
+
380
+ **Phrase Grounding and Segmentation**
381
+ ```bash
382
+ python grounded_sam2_florence2_image_demo.py \
383
+ --pipeline phrase_grounding_segmentation \
384
+ --image_path ./notebooks/images/cars.jpg \
385
+ --text_input "The image shows two vintage Chevrolet cars parked side by side, with one being a red convertible and the other a pink sedan, \
386
+ set against the backdrop of an urban area with a multi-story building and trees. \
387
+ The cars have Cuban license plates, indicating a location likely in Cuba."
388
+ ```
389
+
390
+ **Referring Expression Segmentation**
391
+ ```bash
392
+ python grounded_sam2_florence2_image_demo.py \
393
+ --pipeline referring_expression_segmentation \
394
+ --image_path ./notebooks/images/cars.jpg \
395
+ --text_input "The left red car."
396
+ ```
397
+
398
+ **Open-Vocabulary Detection and Segmentation**
399
+ ```bash
400
+ python grounded_sam2_florence2_image_demo.py \
401
+ --pipeline open_vocabulary_detection_segmentation \
402
+ --image_path ./notebooks/images/cars.jpg \
403
+ --text_input "car <and> building"
404
+ ```
405
+ - Note that if you want to **detect multiple classes**, you should separate them with `<and>` in your input text.
406
+
407
+
408
+ ### Grounded SAM 2 Florence-2 Image Auto-Labeling Demo
409
+ `Florence-2` can be used as an automatic image annotator by cascading its captioning capability with its grounding capability.
410
+
411
+ | Task | Task Prompt | Text Input |
412
+ |:---:|:---:|:---:|
413
+ | Caption + Phrase Grounding | `<CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | &#10008; |
414
+ | Detailed Caption + Phrase Grounding | `<DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | &#10008; |
415
+ | More Detailed Caption + Phrase Grounding | `<MORE_DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | &#10008; |
416
+
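+ Conceptually, the cascade first asks Florence-2 for a caption and then grounds that caption back onto the image. Reusing the hypothetical `run_florence2` helper sketched in the image-demo section above (the return keys follow the public Florence-2 release and are assumptions here):
+ 
+ ```python
+ caption = run_florence2(image, "<CAPTION>")["<CAPTION>"]
+ grounding = run_florence2(image, "<CAPTION_TO_PHRASE_GROUNDING>", text_input=caption)
+ boxes = grounding["<CAPTION_TO_PHRASE_GROUNDING>"]["bboxes"]   # box prompts for SAM 2
+ labels = grounding["<CAPTION_TO_PHRASE_GROUNDING>"]["labels"]  # phrase for each box
+ ```
+ 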
417
+ You can try the following scripts to run these demos:
418
+
419
+ **Caption to Phrase Grounding**
420
+ ```bash
421
+ python grounded_sam2_florence2_autolabel_pipeline.py \
422
+ --image_path ./notebooks/images/groceries.jpg \
423
+ --pipeline caption_to_phrase_grounding \
424
+ --caption_type caption
425
+ ```
426
+
427
+ - You can specify `caption_type` to control the granularity of the caption; if you want a more detailed caption, try `--caption_type detailed_caption` or `--caption_type more_detailed_caption`.
428
+
429
+ ### Citation
430
+
431
+ If you find this project helpful for your research, please consider citing the following BibTeX entry.
432
+
433
+ ```BibTex
434
+ @misc{ravi2024sam2segmentimages,
435
+ title={SAM 2: Segment Anything in Images and Videos},
436
+ author={Nikhila Ravi and Valentin Gabeur and Yuan-Ting Hu and Ronghang Hu and Chaitanya Ryali and Tengyu Ma and Haitham Khedr and Roman Rädle and Chloe Rolland and Laura Gustafson and Eric Mintun and Junting Pan and Kalyan Vasudev Alwala and Nicolas Carion and Chao-Yuan Wu and Ross Girshick and Piotr Dollár and Christoph Feichtenhofer},
437
+ year={2024},
438
+ eprint={2408.00714},
439
+ archivePrefix={arXiv},
440
+ primaryClass={cs.CV},
441
+ url={https://arxiv.org/abs/2408.00714},
442
+ }
443
+
444
+ @article{liu2023grounding,
445
+ title={Grounding dino: Marrying dino with grounded pre-training for open-set object detection},
446
+ author={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},
447
+ journal={arXiv preprint arXiv:2303.05499},
448
+ year={2023}
449
+ }
450
+
451
+ @misc{ren2024grounding,
452
+ title={Grounding DINO 1.5: Advance the "Edge" of Open-Set Object Detection},
453
+ author={Tianhe Ren and Qing Jiang and Shilong Liu and Zhaoyang Zeng and Wenlong Liu and Han Gao and Hongjie Huang and Zhengyu Ma and Xiaoke Jiang and Yihao Chen and Yuda Xiong and Hao Zhang and Feng Li and Peijun Tang and Kent Yu and Lei Zhang},
454
+ year={2024},
455
+ eprint={2405.10300},
456
+ archivePrefix={arXiv},
457
+ primaryClass={cs.CV}
458
+ }
459
+
460
+ @misc{ren2024grounded,
461
+ title={Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks},
462
+ author={Tianhe Ren and Shilong Liu and Ailing Zeng and Jing Lin and Kunchang Li and He Cao and Jiayu Chen and Xinyu Huang and Yukang Chen and Feng Yan and Zhaoyang Zeng and Hao Zhang and Feng Li and Jie Yang and Hongyang Li and Qing Jiang and Lei Zhang},
463
+ year={2024},
464
+ eprint={2401.14159},
465
+ archivePrefix={arXiv},
466
+ primaryClass={cs.CV}
467
+ }
468
+
469
+ @article{kirillov2023segany,
470
+ title={Segment Anything},
471
+ author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
472
+ journal={arXiv:2304.02643},
473
+ year={2023}
474
+ }
475
+
476
+ @misc{jiang2024trex2,
477
+ title={T-Rex2: Towards Generic Object Detection via Text-Visual Prompt Synergy},
478
+ author={Qing Jiang and Feng Li and Zhaoyang Zeng and Tianhe Ren and Shilong Liu and Lei Zhang},
479
+ year={2024},
480
+ eprint={2403.14610},
481
+ archivePrefix={arXiv},
482
+ primaryClass={cs.CV}
483
+ }
484
+ ```
clone-IDEA-Research/Grounded-SAM-2/SAM2_README.md ADDED
@@ -0,0 +1,140 @@
1
+ # SAM 2: Segment Anything in Images and Videos
2
+
3
+ **[AI at Meta, FAIR](https://ai.meta.com/research/)**
4
+
5
+ [Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
6
+
7
+ [[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
8
+
9
+ ![SAM 2 architecture](assets/model_diagram.png?raw=true)
10
+
11
+ **Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
12
+
13
+ ![SA-V dataset](assets/sa_v_dataset.jpg?raw=true)
14
+
15
+ ## Installation
16
+
17
+ Please install SAM 2 on a GPU machine using:
18
+
19
+ ```bash
20
+ git clone https://github.com/facebookresearch/segment-anything-2.git
21
+
22
+ cd segment-anything-2; pip install -e .
23
+ ```
24
+
25
+ To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
26
+
27
+ ```bash
28
+ pip install -e ".[demo]"
29
+ ```
30
+
31
+ ## Getting Started
32
+
33
+ ### Download Checkpoints
34
+
35
+ First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
36
+
37
+ ```bash
38
+ cd checkpoints
39
+ ./download_ckpts.sh
40
+ ```
41
+
42
+ or individually from:
43
+
44
+ - [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)
45
+ - [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)
46
+ - [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)
47
+ - [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)
48
+
49
+ Then SAM 2 can be used in a few lines as follows for image and video prediction.
50
+
51
+ ### Image prediction
52
+
53
+ SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
54
+
55
+ ```python
56
+ import torch
57
+ from sam2.build_sam import build_sam2
58
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
59
+
60
+ checkpoint = "./checkpoints/sam2_hiera_large.pt"
61
+ model_cfg = "sam2_hiera_l.yaml"
62
+ predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
63
+
64
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
65
+ predictor.set_image(<your_image>)
66
+ masks, _, _ = predictor.predict(<input_prompts>)
67
+ ```
68
+
69
+ Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) for static image use cases.
70
+
71
+ SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) for automatic mask generation in images.
72
+
73
+ ### Video prediction
74
+
75
+ For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
76
+
77
+ ```python
78
+ import torch
79
+ from sam2.build_sam import build_sam2_video_predictor
80
+
81
+ checkpoint = "./checkpoints/sam2_hiera_large.pt"
82
+ model_cfg = "sam2_hiera_l.yaml"
83
+ predictor = build_sam2_video_predictor(model_cfg, checkpoint)
84
+
85
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
86
+ state = predictor.init_state(<your_video>)
87
+
88
+ # add new prompts and instantly get the output on the same frame
89
+ frame_idx, object_ids, masks = predictor.add_new_points(state, <your prompts>)
90
+
91
+ # propagate the prompts to get masklets throughout the video
92
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
93
+ ...
94
+ ```
95
+
96
+ Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
97
+
98
+ ## Model Description
99
+
100
+ | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
101
+ | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
102
+ | sam2_hiera_tiny | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
103
+ | sam2_hiera_small | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
104
+ | sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
105
+ | sam2_hiera_large | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
106
+
107
+ \* Compile the model by setting `compile_image_encoder: True` in the config.
108
+
109
+ ## Segment Anything Video Dataset
110
+
111
+ See [sav_dataset/README.md](sav_dataset/README.md) for details.
112
+
113
+ ## License
114
+
115
+ The models are licensed under the [Apache 2.0 license](./LICENSE). Please refer to our research paper for more details on the models.
116
+
117
+ ## Contributing
118
+
119
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
120
+
121
+ ## Contributors
122
+
123
+ The SAM 2 project was made possible with the help of many contributors (alphabetical):
124
+
125
+ Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
126
+
127
+ Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
128
+
129
+ ## Citing SAM 2
130
+
131
+ If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
132
+
133
+ ```bibtex
134
+ @article{ravi2024sam2,
135
+ title={SAM 2: Segment Anything in Images and Videos},
136
+ author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
137
+ journal={arXiv preprint},
138
+ year={2024}
139
+ }
140
+ ```
clone-IDEA-Research/Grounded-SAM-2/backend.Dockerfile ADDED
@@ -0,0 +1,64 @@
1
+ ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
2
+ ARG MODEL_SIZE=base_plus
3
+
4
+ FROM ${BASE_IMAGE}
5
+
6
+ # Gunicorn environment variables
7
+ ENV GUNICORN_WORKERS=1
8
+ ENV GUNICORN_THREADS=2
9
+ ENV GUNICORN_PORT=5000
10
+
11
+ # SAM 2 environment variables
12
+ ENV APP_ROOT=/opt/sam2
13
+ ENV PYTHONUNBUFFERED=1
14
+ ENV SAM2_BUILD_CUDA=0
15
+ ENV MODEL_SIZE=${MODEL_SIZE}
16
+
17
+ # Install system requirements
18
+ RUN apt-get update && apt-get install -y --no-install-recommends \
19
+ ffmpeg \
20
+ libavutil-dev \
21
+ libavcodec-dev \
22
+ libavformat-dev \
23
+ libswscale-dev \
24
+ pkg-config \
25
+ build-essential \
26
+ libffi-dev
27
+
28
+ COPY setup.py .
29
+ COPY README.md .
30
+
31
+ RUN pip install --upgrade pip setuptools
32
+ RUN pip install -e ".[interactive-demo]"
33
+
34
+ # https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707
35
+ RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg
36
+
37
+ # Make app directory. This directory will host all files required for the
38
+ # backend and SAM 2 inference files.
39
+ RUN mkdir ${APP_ROOT}
40
+
41
+ # Copy backend server files
42
+ COPY demo/backend/server ${APP_ROOT}/server
43
+
44
+ # Copy SAM 2 inference files
45
+ COPY sam2 ${APP_ROOT}/server/sam2
46
+
47
+ # Download SAM 2.1 checkpoints
48
+ ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt
49
+ ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt
50
+ ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt
51
+ ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt
52
+
53
+ WORKDIR ${APP_ROOT}/server
54
+
55
+ # https://pythonspeed.com/articles/gunicorn-in-docker/
56
+ CMD gunicorn --worker-tmp-dir /dev/shm \
57
+ --worker-class gthread app:app \
58
+ --log-level info \
59
+ --access-logfile /dev/stdout \
60
+ --log-file /dev/stderr \
61
+ --workers ${GUNICORN_WORKERS} \
62
+ --threads ${GUNICORN_THREADS} \
63
+ --bind 0.0.0.0:${GUNICORN_PORT} \
64
+ --timeout 60
clone-IDEA-Research/Grounded-SAM-2/docker-compose.yaml ADDED
@@ -0,0 +1,42 @@
1
+ services:
2
+ frontend:
3
+ image: sam2/frontend
4
+ build:
5
+ context: ./demo/frontend
6
+ dockerfile: frontend.Dockerfile
7
+ ports:
8
+ - 7262:80
9
+
10
+ backend:
11
+ image: sam2/backend
12
+ build:
13
+ context: .
14
+ dockerfile: backend.Dockerfile
15
+ ports:
16
+ - 7263:5000
17
+ volumes:
18
+ - ./demo/data/:/data/:rw
19
+ environment:
20
+ - SERVER_ENVIRONMENT=DEV
21
+ - GUNICORN_WORKERS=1
22
+ # Inference API needs to have at least 2 threads to handle an incoming
23
+ # parallel cancel propagation request
24
+ - GUNICORN_THREADS=2
25
+ - GUNICORN_PORT=5000
26
+ - API_URL=http://localhost:7263
27
+ - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4
28
+ # # ffmpeg/video encode settings
29
+ - FFMPEG_NUM_THREADS=1
30
+ - VIDEO_ENCODE_CODEC=libx264
31
+ - VIDEO_ENCODE_CRF=23
32
+ - VIDEO_ENCODE_FPS=24
33
+ - VIDEO_ENCODE_MAX_WIDTH=1280
34
+ - VIDEO_ENCODE_MAX_HEIGHT=720
35
+ - VIDEO_ENCODE_VERBOSE=False
36
+ deploy:
37
+ resources:
38
+ reservations:
39
+ devices:
40
+ - driver: nvidia
41
+ count: 1
42
+ capabilities: [gpu]
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_dinox_demo.py ADDED
@@ -0,0 +1,245 @@
1
+ # dds cloudapi for DINO-X
2
+ from dds_cloudapi_sdk import Config
3
+ from dds_cloudapi_sdk import Client
4
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
5
+ from dds_cloudapi_sdk.tasks.types import DetectionTarget
6
+ from dds_cloudapi_sdk import TextPrompt
7
+
8
+ import os
9
+ import cv2
10
+ import json
11
+ import torch
12
+ import tempfile
13
+ import numpy as np
14
+ import supervision as sv
15
+ import pycocotools.mask as mask_util
16
+ from pathlib import Path
17
+ from PIL import Image
18
+ from sam2.build_sam import build_sam2
19
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
20
+
21
+ """
22
+ Hyper parameters
23
+ """
24
+ API_TOKEN = "Your API token"
25
+ TEXT_PROMPT = "car . building ."
26
+ IMG_PATH = "notebooks/images/cars.jpg"
27
+ SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
28
+ SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
29
+ BOX_THRESHOLD = 0.2
30
+ WITH_SLICE_INFERENCE = False
31
+ SLICE_WH = (480, 480)
32
+ OVERLAP_RATIO = (0.2, 0.2)
33
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
35
+ DUMP_JSON_RESULTS = True
36
+
37
+ # create output directory
38
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
+
40
+ """
41
+ Prompt DINO-X with Text for Box Prompt Generation with Cloud API
42
+ """
43
+ # Step 1: initialize the config
44
+ token = API_TOKEN
45
+ config = Config(token)
46
+
47
+ # Step 2: initialize the client
48
+ client = Client(config)
49
+
50
+ # Step 3: run the task using the DinoxTask class
51
+ # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
52
+ # if you are processing local image file, upload them to DDS server to get the image url
53
+
54
+ classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
55
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
56
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
57
+
58
+ if WITH_SLICE_INFERENCE:
59
+ def callback(image_slice: np.ndarray) -> sv.Detections:
60
+ print("Inference on image slice")
61
+ # save the image slice as a temp file for the DINO-X API
62
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
63
+ temp_filename = tmpfile.name
64
+ cv2.imwrite(temp_filename, image_slice)
65
+ image_url = client.upload_file(temp_filename)
66
+ task = DinoxTask(
67
+ image_url=image_url,
68
+ prompts=[TextPrompt(text=TEXT_PROMPT)],
69
+ bbox_threshold=0.25,
70
+ targets=[DetectionTarget.BBox],
71
+ )
72
+ client.run_task(task)
73
+ result = task.result
74
+ # delete the temp file
75
+ os.remove(temp_filename)
76
+
77
+ input_boxes = []
78
+ confidences = []
79
+ class_ids = []
80
+ objects = result.objects
81
+ for idx, obj in enumerate(objects):
82
+ input_boxes.append(obj.bbox)
83
+ confidences.append(obj.score)
84
+ cls_name = obj.category.lower().strip()
85
+ class_ids.append(class_name_to_id[cls_name])
86
+ # ensure input_boxes with shape (_, 4)
87
+ input_boxes = np.array(input_boxes).reshape(-1, 4)
88
+ class_ids = np.array(class_ids)
89
+ confidences = np.array(confidences)
90
+ return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
91
+
92
+ slicer = sv.InferenceSlicer(
93
+ callback=callback,
94
+ slice_wh=SLICE_WH,
95
+ overlap_ratio_wh=OVERLAP_RATIO,
96
+ iou_threshold=0.5,
97
+ overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
98
+ )
99
+ detections = slicer(cv2.imread(IMG_PATH))
100
+ class_names = [class_id_to_name[id] for id in detections.class_id]
101
+ confidences = detections.confidence
102
+ class_ids = detections.class_id
103
+ input_boxes = detections.xyxy
104
+ else:
105
+ image_url = client.upload_file(IMG_PATH)
106
+
107
+ task = DinoxTask(
108
+ image_url=image_url,
109
+ prompts=[TextPrompt(text=TEXT_PROMPT)],
110
+ bbox_threshold=0.25,
111
+ targets=[DetectionTarget.BBox],
112
+ )
113
+
114
+ client.run_task(task)
115
+ result = task.result
116
+
117
+ objects = result.objects # the list of detected objects
118
+
119
+
120
+ input_boxes = []
121
+ confidences = []
122
+ class_names = []
123
+ class_ids = []
124
+
125
+ for idx, obj in enumerate(objects):
126
+ input_boxes.append(obj.bbox)
127
+ confidences.append(obj.score)
128
+ cls_name = obj.category.lower().strip()
129
+ class_names.append(cls_name)
130
+ class_ids.append(class_name_to_id[cls_name])
131
+
132
+ input_boxes = np.array(input_boxes)
133
+ class_ids = np.array(class_ids)
134
+
135
+ """
136
+ Init SAM 2 Model and Predict Mask with Box Prompt
137
+ """
138
+
139
+ # environment settings
140
+ # use bfloat16
141
+ torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
142
+
143
+ if torch.cuda.get_device_properties(0).major >= 8:
144
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
145
+ torch.backends.cuda.matmul.allow_tf32 = True
146
+ torch.backends.cudnn.allow_tf32 = True
147
+
148
+ # build SAM2 image predictor
149
+ sam2_checkpoint = SAM2_CHECKPOINT
150
+ model_cfg = SAM2_MODEL_CONFIG
151
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
152
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
153
+
154
+ image = Image.open(IMG_PATH)
155
+
156
+ sam2_predictor.set_image(np.array(image.convert("RGB")))
157
+
158
+ masks, scores, logits = sam2_predictor.predict(
159
+ point_coords=None,
160
+ point_labels=None,
161
+ box=input_boxes,
162
+ multimask_output=False,
163
+ )
164
+
165
+
166
+ """
167
+ Post-process the output of the model to get the masks, scores, and logits for visualization
168
+ """
169
+ # convert the shape to (n, H, W)
170
+ if masks.ndim == 4:
171
+ masks = masks.squeeze(1)
172
+
173
+
174
+ """
175
+ Visualize the Prediction Results
176
+ """
177
+
178
+ labels = [
179
+ f"{class_name} {confidence:.2f}"
180
+ for class_name, confidence
181
+ in zip(class_names, confidences)
182
+ ]
183
+
184
+ """
185
+ Visualize the image with the supervision API
186
+ """
187
+ img = cv2.imread(IMG_PATH)
188
+ detections = sv.Detections(
189
+ xyxy=input_boxes, # (n, 4)
190
+ mask=masks.astype(bool), # (n, h, w)
191
+ class_id=class_ids
192
+ )
193
+
194
+ box_annotator = sv.BoxAnnotator()
195
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
196
+
197
+ label_annotator = sv.LabelAnnotator()
198
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
199
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "dinox_annotated_image.jpg"), annotated_frame)
200
+
201
+ mask_annotator = sv.MaskAnnotator()
202
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
203
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "dinox_sam2_annotated_image_with_mask.jpg"), annotated_frame)
204
+
205
+ print(f'Annotated image has been saved to "{OUTPUT_DIR}"')
206
+
207
+ """
208
+ Dump the results in standard format and save as json files
209
+ """
210
+
211
+ def single_mask_to_rle(mask):
212
+ rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
213
+ rle["counts"] = rle["counts"].decode("utf-8")
214
+ return rle
215
+
216
+ if DUMP_JSON_RESULTS:
217
+ print("Start dumping the annotation...")
218
+ # convert mask into rle format
219
+ mask_rles = [single_mask_to_rle(mask) for mask in masks]
220
+
221
+ input_boxes = input_boxes.tolist()
222
+ scores = scores.tolist()
223
+ # FIXME: class_names should be a list of strings without spaces
224
+ class_names = [class_name.strip() for class_name in class_names]
225
+ # save the results in standard format
226
+ results = {
227
+ "image_path": IMG_PATH,
228
+ "annotations" : [
229
+ {
230
+ "class_name": class_name,
231
+ "bbox": box,
232
+ "segmentation": mask_rle,
233
+ "score": score,
234
+ }
235
+ for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
236
+ ],
237
+ "box_format": "xyxy",
238
+ "img_width": image.width,
239
+ "img_height": image.height,
240
+ }
241
+
242
+ with open(os.path.join(OUTPUT_DIR, "grounded_sam2_dinox_image_demo_results.json"), "w") as f:
243
+ json.dump(results, f, indent=4)
244
+
245
+ print(f'Annotation has already been saved to "{OUTPUT_DIR}"')
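As a usage note for the JSON dump written by this demo: the segmentation field is stored as COCO RLE, so it can be loaded back into binary masks with pycocotools. A minimal sketch, assuming the default OUTPUT_DIR and filename used above; everything else is illustrative:

import json
import pycocotools.mask as mask_util

# path follows the dump produced by the demo above; adjust if OUTPUT_DIR was changed
with open("outputs/grounded_sam2_dinox_demo/grounded_sam2_dinox_image_demo_results.json") as f:
    results = json.load(f)

for ann in results["annotations"]:
    rle = dict(ann["segmentation"])
    # some pycocotools versions expect bytes counts, so re-encode the JSON string
    rle["counts"] = rle["counts"].encode("utf-8")
    mask = mask_util.decode(rle)               # (H, W) uint8 binary mask
    x_min, y_min, x_max, y_max = ann["bbox"]   # boxes are stored in xyxy format
    print(ann["class_name"], ann["score"], int(mask.sum()), "foreground pixels")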
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_autolabel_pipeline.py ADDED
@@ -0,0 +1,198 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ import supervision as sv
7
+ from PIL import Image
8
+ from sam2.build_sam import build_sam2
9
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
10
+ from transformers import AutoProcessor, AutoModelForCausalLM
11
+ from utils.supervision_utils import CUSTOM_COLOR_MAP
12
+
13
+ """
14
+ Define Some Hyperparam
15
+ """
16
+
17
+ TASK_PROMPT = {
18
+ "caption": "<CAPTION>",
19
+ "detailed_caption": "<DETAILED_CAPTION>",
20
+ "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
21
+ "object_detection": "<OD>",
22
+ "dense_region_caption": "<DENSE_REGION_CAPTION>",
23
+ "region_proposal": "<REGION_PROPOSAL>",
24
+ "phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
25
+ "referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
26
+ "region_to_segmentation": "<REGION_TO_SEGMENTATION>",
27
+ "open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
28
+ "region_to_category": "<REGION_TO_CATEGORY>",
29
+ "region_to_description": "<REGION_TO_DESCRIPTION>",
30
+ "ocr": "<OCR>",
31
+ "ocr_with_region": "<OCR_WITH_REGION>",
32
+ }
33
+
34
+ OUTPUT_DIR = "./outputs"
35
+
36
+ if not os.path.exists(OUTPUT_DIR):
37
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
38
+
39
+ """
40
+ Init Florence-2 and SAM 2 Model
41
+ """
42
+
43
+ FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
44
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
45
+ SAM2_CONFIG = "sam2_hiera_l.yaml"
46
+
47
+ # environment settings
48
+ # use bfloat16
49
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
50
+
51
+ if torch.cuda.get_device_properties(0).major >= 8:
52
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
53
+ torch.backends.cuda.matmul.allow_tf32 = True
54
+ torch.backends.cudnn.allow_tf32 = True
55
+
56
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
57
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
58
+
59
+ # build florence-2
60
+ florence2_model = AutoModelForCausalLM.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto').eval().to(device)
61
+ florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)
62
+
63
+ # build sam 2
64
+ sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
65
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
66
+
67
+ def run_florence2(task_prompt, text_input, model, processor, image):
68
+ assert model is not None, "You should pass the init florence-2 model here"
69
+ assert processor is not None, "You should set florence-2 processor here"
70
+
71
+ device = model.device
72
+
73
+ if text_input is None:
74
+ prompt = task_prompt
75
+ else:
76
+ prompt = task_prompt + text_input
77
+
78
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch.float16)
79
+ generated_ids = model.generate(
80
+ input_ids=inputs["input_ids"].to(device),
81
+ pixel_values=inputs["pixel_values"].to(device),
82
+ max_new_tokens=1024,
83
+ early_stopping=False,
84
+ do_sample=False,
85
+ num_beams=3,
86
+ )
87
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
88
+ parsed_answer = processor.post_process_generation(
89
+ generated_text,
90
+ task=task_prompt,
91
+ image_size=(image.width, image.height)
92
+ )
93
+ return parsed_answer
94
+
95
+
96
+ """
97
+ We try to support a series of cascaded auto-labelling pipelines with Florence-2 and SAM 2
98
+ """
99
+
100
+ """
101
+ Auto-Labelling Pipeline 1: Caption/Detailed Caption/More Detailed Caption + Phrase Grounding + Segmentation
102
+ """
103
+ def caption_phrase_grounding_and_segmentation(
104
+ florence2_model,
105
+ florence2_processor,
106
+ sam2_predictor,
107
+ image_path,
108
+ caption_task_prompt='<CAPTION>',
109
+ output_dir=OUTPUT_DIR
110
+ ):
111
+ assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
112
+ image = Image.open(image_path).convert("RGB")
113
+
114
+ # image caption
115
+ caption_results = run_florence2(caption_task_prompt, None, florence2_model, florence2_processor, image)
116
+ text_input = caption_results[caption_task_prompt]
117
+ print(f'Image caption for "{image_path}": ', text_input)
118
+
119
+ # phrase grounding
120
+ grounding_results = run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', text_input, florence2_model, florence2_processor, image)
121
+ grounding_results = grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']
122
+
123
+ # parse florence-2 detection results
124
+ input_boxes = np.array(grounding_results["bboxes"])
125
+ class_names = grounding_results["labels"]
126
+ class_ids = np.array(list(range(len(class_names))))
127
+
128
+ # predict mask with SAM 2
129
+ sam2_predictor.set_image(np.array(image))
130
+ masks, scores, logits = sam2_predictor.predict(
131
+ point_coords=None,
132
+ point_labels=None,
133
+ box=input_boxes,
134
+ multimask_output=False,
135
+ )
136
+
137
+ if masks.ndim == 4:
138
+ masks = masks.squeeze(1)
139
+
140
+ # specify labels
141
+ labels = [
142
+ f"{class_name}" for class_name in class_names
143
+ ]
144
+
145
+ # visualization results
146
+ img = cv2.imread(image_path)
147
+ detections = sv.Detections(
148
+ xyxy=input_boxes,
149
+ mask=masks.astype(bool),
150
+ class_id=class_ids
151
+ )
152
+
153
+ box_annotator = sv.BoxAnnotator()
154
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
155
+
156
+ label_annotator = sv.LabelAnnotator()
157
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
158
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling.jpg"), annotated_frame)
159
+
160
+ mask_annotator = sv.MaskAnnotator()
161
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
162
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling_with_mask.jpg"), annotated_frame)
163
+
164
+ print(f'Successfully saved annotated image to "{output_dir}"')
165
+
166
+
167
+ if __name__ == "__main__":
168
+
169
+ parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
170
+ parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
171
+ parser.add_argument("--pipeline", type=str, default="caption_to_phrase_grounding", required=True, help="pipeline to use")
172
+ parser.add_argument("--caption_type", type=str, default="caption", required=False, help="granularity of caption")
173
+ args = parser.parse_args()
174
+
175
+ CAPTION_TO_TASK_PROMPT = {
176
+ "caption": "<CAPTION>",
177
+ "detailed_caption": "<DETAILED_CAPTION>",
178
+ "more_detailed_caption": "<MORE_DETAILED_CAPTION>"
179
+ }
180
+
181
+ IMAGE_PATH = args.image_path
182
+ PIPELINE = args.pipeline
183
+ CAPTION_TYPE = args.caption_type
184
+ assert CAPTION_TYPE in ["caption", "detailed_caption", "more_detailed_caption"]
185
+
186
+ print(f"Running pipeline: {PIPELINE} now.")
187
+
188
+ if PIPELINE == "caption_to_phrase_grounding":
189
+ # pipeline-1: caption + phrase grounding + segmentation
190
+ caption_phrase_grounding_and_segmentation(
191
+ florence2_model=florence2_model,
192
+ florence2_processor=florence2_processor,
193
+ sam2_predictor=sam2_predictor,
194
+ caption_task_prompt=CAPTION_TO_TASK_PROMPT[CAPTION_TYPE],
195
+ image_path=IMAGE_PATH
196
+ )
197
+ else:
198
+ raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
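As a quick illustration of the cascade above, the run_florence2 helper can also be called on its own once the models are built. A minimal sketch, assuming florence2_model and florence2_processor are loaded exactly as in this script; the image path is just an example:

from PIL import Image

image = Image.open("./notebooks/images/cars.jpg").convert("RGB")

# step 1: plain captioning, no extra text input
caption = run_florence2("<CAPTION>", None, florence2_model, florence2_processor, image)
print(caption["<CAPTION>"])

# step 2: ground the caption back to phrase-level boxes
grounding = run_florence2("<CAPTION_TO_PHRASE_GROUNDING>", caption["<CAPTION>"],
                          florence2_model, florence2_processor, image)
print(grounding["<CAPTION_TO_PHRASE_GROUNDING>"]["bboxes"])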
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_image_demo.py ADDED
@@ -0,0 +1,657 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ import supervision as sv
7
+ from PIL import Image
8
+ from sam2.build_sam import build_sam2
9
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
10
+ from transformers import AutoProcessor, AutoModelForCausalLM
11
+ from utils.supervision_utils import CUSTOM_COLOR_MAP
12
+
13
+ """
14
+ Define Some Hyperparam
15
+ """
16
+
17
+ TASK_PROMPT = {
18
+ "caption": "<CAPTION>",
19
+ "detailed_caption": "<DETAILED_CAPTION>",
20
+ "more_detailed_caption": "<MORE_DETAILED_CAPTION",
21
+ "object_detection": "<OD>",
22
+ "dense_region_caption": "<DENSE_REGION_CAPTION>",
23
+ "region_proposal": "<REGION_PROPOSAL>",
24
+ "phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
25
+ "referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
26
+ "region_to_segmentation": "<REGION_TO_SEGMENTATION>",
27
+ "open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
28
+ "region_to_category": "<REGION_TO_CATEGORY>",
29
+ "region_to_description": "<REGION_TO_DESCRIPTION>",
30
+ "ocr": "<OCR>",
31
+ "ocr_with_region": "<OCR_WITH_REGION>",
32
+ }
33
+
34
+ OUTPUT_DIR = "./outputs"
35
+
36
+ if not os.path.exists(OUTPUT_DIR):
37
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
38
+
39
+ """
40
+ Init Florence-2 and SAM 2 Model
41
+ """
42
+
43
+ FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
44
+ SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
45
+ SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
46
+
47
+ # environment settings
48
+ # use bfloat16
49
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
50
+
51
+ if torch.cuda.get_device_properties(0).major >= 8:
52
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
53
+ torch.backends.cuda.matmul.allow_tf32 = True
54
+ torch.backends.cudnn.allow_tf32 = True
55
+
56
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
57
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
58
+
59
+ # build florence-2
60
+ florence2_model = AutoModelForCausalLM.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto').eval().to(device)
61
+ florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)
62
+
63
+ # build sam 2
64
+ sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
65
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
66
+
67
+ def run_florence2(task_prompt, text_input, model, processor, image):
68
+ assert model is not None, "You should pass the init florence-2 model here"
69
+ assert processor is not None, "You should set florence-2 processor here"
70
+
71
+ device = model.device
72
+
73
+ if text_input is None:
74
+ prompt = task_prompt
75
+ else:
76
+ prompt = task_prompt + text_input
77
+
78
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch.float16)
79
+ generated_ids = model.generate(
80
+ input_ids=inputs["input_ids"].to(device),
81
+ pixel_values=inputs["pixel_values"].to(device),
82
+ max_new_tokens=1024,
83
+ early_stopping=False,
84
+ do_sample=False,
85
+ num_beams=3,
86
+ )
87
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
88
+ parsed_answer = processor.post_process_generation(
89
+ generated_text,
90
+ task=task_prompt,
91
+ image_size=(image.width, image.height)
92
+ )
93
+ return parsed_answer
94
+
95
+
96
+ """
97
+ We support a set of pipelines built with Florence-2 + SAM 2
98
+ """
99
+
100
+ """
101
+ Pipeline-1: Object Detection + Segmentation
102
+ """
103
+ def object_detection_and_segmentation(
104
+ florence2_model,
105
+ florence2_processor,
106
+ sam2_predictor,
107
+ image_path,
108
+ task_prompt="<OD>",
109
+ text_input=None,
110
+ output_dir=OUTPUT_DIR
111
+ ):
112
+ assert text_input is None, "Text input should be None when calling object detection pipeline."
113
+ # run florence-2 object detection in demo
114
+ image = Image.open(image_path).convert("RGB")
115
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
116
+
117
+ """ Florence-2 Object Detection Output Format
118
+ {'<OD>':
119
+ {
120
+ 'bboxes':
121
+ [
122
+ [33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
123
+ [454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
124
+ [224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
125
+ [449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
126
+ [91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
127
+ ],
128
+ 'labels': ['car', 'door', 'door', 'wheel', 'wheel']
129
+ }
130
+ }
131
+ """
132
+ results = results[task_prompt]
133
+ # parse florence-2 detection results
134
+ input_boxes = np.array(results["bboxes"])
135
+ class_names = results["labels"]
136
+ class_ids = np.array(list(range(len(class_names))))
137
+
138
+ # predict mask with SAM 2
139
+ sam2_predictor.set_image(np.array(image))
140
+ masks, scores, logits = sam2_predictor.predict(
141
+ point_coords=None,
142
+ point_labels=None,
143
+ box=input_boxes,
144
+ multimask_output=False,
145
+ )
146
+
147
+ if masks.ndim == 4:
148
+ masks = masks.squeeze(1)
149
+
150
+ # specify labels
151
+ labels = [
152
+ f"{class_name}" for class_name in class_names
153
+ ]
154
+
155
+ # visualization results
156
+ img = cv2.imread(image_path)
157
+ detections = sv.Detections(
158
+ xyxy=input_boxes,
159
+ mask=masks.astype(bool),
160
+ class_id=class_ids
161
+ )
162
+
163
+ box_annotator = sv.BoxAnnotator()
164
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
165
+
166
+ label_annotator = sv.LabelAnnotator()
167
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
168
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame)
169
+
170
+ mask_annotator = sv.MaskAnnotator()
171
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
172
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame)
173
+
174
+ print(f'Successfully saved annotated image to "{output_dir}"')
175
+
176
+ """
177
+ Pipeline 2: Dense Region Caption + Segmentation
178
+ """
179
+ def dense_region_caption_and_segmentation(
180
+ florence2_model,
181
+ florence2_processor,
182
+ sam2_predictor,
183
+ image_path,
184
+ task_prompt="<DENSE_REGION_CAPTION>",
185
+ text_input=None,
186
+ output_dir=OUTPUT_DIR
187
+ ):
188
+ assert text_input is None, "Text input should be None when calling dense region caption pipeline."
189
+ # run florence-2 object detection in demo
190
+ image = Image.open(image_path).convert("RGB")
191
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
192
+
193
+ """ Florence-2 Object Detection Output Format
194
+ {'<DENSE_REGION_CAPTION>':
195
+ {
196
+ 'bboxes':
197
+ [
198
+ [33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
199
+ [454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
200
+ [224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
201
+ [449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
202
+ [91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
203
+ ],
204
+ 'labels': ['turquoise Volkswagen Beetle', 'wooden double doors with metal handles', 'wheel', 'wheel', 'door']
205
+ }
206
+ }
207
+ """
208
+ results = results[task_prompt]
209
+ # parse florence-2 detection results
210
+ input_boxes = np.array(results["bboxes"])
211
+ class_names = results["labels"]
212
+ class_ids = np.array(list(range(len(class_names))))
213
+
214
+ # predict mask with SAM 2
215
+ sam2_predictor.set_image(np.array(image))
216
+ masks, scores, logits = sam2_predictor.predict(
217
+ point_coords=None,
218
+ point_labels=None,
219
+ box=input_boxes,
220
+ multimask_output=False,
221
+ )
222
+
223
+ if masks.ndim == 4:
224
+ masks = masks.squeeze(1)
225
+
226
+ # specify labels
227
+ labels = [
228
+ f"{class_name}" for class_name in class_names
229
+ ]
230
+
231
+ # visualization results
232
+ img = cv2.imread(image_path)
233
+ detections = sv.Detections(
234
+ xyxy=input_boxes,
235
+ mask=masks.astype(bool),
236
+ class_id=class_ids
237
+ )
238
+
239
+ box_annotator = sv.BoxAnnotator()
240
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
241
+
242
+ label_annotator = sv.LabelAnnotator()
243
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
244
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_annotated_image.jpg"), annotated_frame)
245
+
246
+ mask_annotator = sv.MaskAnnotator()
247
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
248
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_image_with_mask.jpg"), annotated_frame)
249
+
250
+ print(f'Successfully saved annotated image to "{output_dir}"')
251
+
252
+
253
+ """
254
+ Pipeline 3: Region Proposal + Segmentation
255
+ """
256
+ def region_proposal_and_segmentation(
257
+ florence2_model,
258
+ florence2_processor,
259
+ sam2_predictor,
260
+ image_path,
261
+ task_prompt="<REGION_PROPOSAL>",
262
+ text_input=None,
263
+ output_dir=OUTPUT_DIR
264
+ ):
265
+ assert text_input is None, "Text input should be None when calling region proposal pipeline."
266
+ # run florence-2 object detection in demo
267
+ image = Image.open(image_path).convert("RGB")
268
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
269
+
270
+ """ Florence-2 Object Detection Output Format
271
+ {'<REGION_PROPOSAL>':
272
+ {
273
+ 'bboxes':
274
+ [
275
+ [33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
276
+ [454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
277
+ [224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
278
+ [449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
279
+ [91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
280
+ ],
281
+ 'labels': ['', '', '', '', '', '', '']
282
+ }
283
+ }
284
+ """
285
+ results = results[task_prompt]
286
+ # parse florence-2 detection results
287
+ input_boxes = np.array(results["bboxes"])
288
+ class_names = results["labels"]
289
+ class_ids = np.array(list(range(len(class_names))))
290
+
291
+ # predict mask with SAM 2
292
+ sam2_predictor.set_image(np.array(image))
293
+ masks, scores, logits = sam2_predictor.predict(
294
+ point_coords=None,
295
+ point_labels=None,
296
+ box=input_boxes,
297
+ multimask_output=False,
298
+ )
299
+
300
+ if masks.ndim == 4:
301
+ masks = masks.squeeze(1)
302
+
303
+ # specify labels
304
+ labels = [
305
+ f"region_{idx}" for idx, class_name in enumerate(class_names)
306
+ ]
307
+
308
+ # visualization results
309
+ img = cv2.imread(image_path)
310
+ detections = sv.Detections(
311
+ xyxy=input_boxes,
312
+ mask=masks.astype(bool),
313
+ class_id=class_ids
314
+ )
315
+
316
+ box_annotator = sv.BoxAnnotator()
317
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
318
+
319
+ label_annotator = sv.LabelAnnotator()
320
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
321
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal.jpg"), annotated_frame)
322
+
323
+ mask_annotator = sv.MaskAnnotator()
324
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
325
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal_with_mask.jpg"), annotated_frame)
326
+
327
+ print(f'Successfully saved annotated image to "{output_dir}"')
328
+
329
+
330
+ """
331
+ Pipeline 4: Phrase Grounding + Segmentation
332
+ """
333
+ def phrase_grounding_and_segmentation(
334
+ florence2_model,
335
+ florence2_processor,
336
+ sam2_predictor,
337
+ image_path,
338
+ task_prompt="<CAPTION_TO_PHRASE_GROUNDING>",
339
+ text_input=None,
340
+ output_dir=OUTPUT_DIR
341
+ ):
342
+ # run florence-2 object detection in demo
343
+ image = Image.open(image_path).convert("RGB")
344
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
345
+
346
+ """ Florence-2 Object Detection Output Format
347
+ {'<CAPTION_TO_PHRASE_GROUNDING>':
348
+ {
349
+ 'bboxes':
350
+ [
351
+ [34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594],
352
+ [1.5999999046325684, 4.079999923706055, 639.0399780273438, 305.03997802734375]
353
+ ],
354
+ 'labels': ['A green car', 'a yellow building']
355
+ }
356
+ }
357
+ """
358
+ assert text_input is not None, "Text input should not be None when calling phrase grounding pipeline."
359
+ results = results[task_prompt]
360
+ # parse florence-2 detection results
361
+ input_boxes = np.array(results["bboxes"])
362
+ class_names = results["labels"]
363
+ class_ids = np.array(list(range(len(class_names))))
364
+
365
+ # predict mask with SAM 2
366
+ sam2_predictor.set_image(np.array(image))
367
+ masks, scores, logits = sam2_predictor.predict(
368
+ point_coords=None,
369
+ point_labels=None,
370
+ box=input_boxes,
371
+ multimask_output=False,
372
+ )
373
+
374
+ if masks.ndim == 4:
375
+ masks = masks.squeeze(1)
376
+
377
+ # specify labels
378
+ labels = [
379
+ f"{class_name}" for class_name in class_names
380
+ ]
381
+
382
+ # visualization results
383
+ img = cv2.imread(image_path)
384
+ detections = sv.Detections(
385
+ xyxy=input_boxes,
386
+ mask=masks.astype(bool),
387
+ class_id=class_ids
388
+ )
389
+
390
+ box_annotator = sv.BoxAnnotator()
391
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
392
+
393
+ label_annotator = sv.LabelAnnotator()
394
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
395
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding.jpg"), annotated_frame)
396
+
397
+ mask_annotator = sv.MaskAnnotator()
398
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
399
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding_with_mask.jpg"), annotated_frame)
400
+
401
+ print(f'Successfully saved annotated image to "{output_dir}"')
402
+
403
+
404
+ """
405
+ Pipeline 5: Referring Expression Segmentation
406
+
407
+ Note that Florence-2 directly supports referring segmentation with a polygon output format, which may not be that accurate;
408
+ therefore we decode a box from the polygon and use SAM 2 for mask prediction
409
+ """
410
+ def referring_expression_segmentation(
411
+ florence2_model,
412
+ florence2_processor,
413
+ sam2_predictor,
414
+ image_path,
415
+ task_prompt="<REFERRING_EXPRESSION_SEGMENTATION>",
416
+ text_input=None,
417
+ output_dir=OUTPUT_DIR
418
+ ):
419
+ # run florence-2 object detection in demo
420
+ image = Image.open(image_path).convert("RGB")
421
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
422
+
423
+ """ Florence-2 Object Detection Output Format
424
+ {'<REFERRING_EXPRESSION_SEGMENTATION>':
425
+ {
426
+ 'polygons': [[[...]]]
427
+ 'labels': ['']
428
+ }
429
+ }
430
+ """
431
+ assert text_input is not None, "Text input should not be None when calling referring segmentation pipeline."
432
+ results = results[task_prompt]
433
+ # parse florence-2 detection results
434
+ polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2)
435
+ class_names = [text_input]
436
+ class_ids = np.array(list(range(len(class_names))))
437
+
438
+ # parse polygon format to mask
439
+ img_width, img_height = image.size[0], image.size[1]
440
+ florence2_mask = np.zeros((img_height, img_width), dtype=np.uint8)
441
+ if len(polygon_points) < 3:
442
+ print("Invalid polygon:", polygon_points)
443
+ exit()
444
+ cv2.fillPoly(florence2_mask, [polygon_points], 1)
445
+ if florence2_mask.ndim == 2:
446
+ florence2_mask = florence2_mask[None]
447
+
448
+ # compute bounding box based on polygon points
449
+ x_min = np.min(polygon_points[:, 0])
450
+ y_min = np.min(polygon_points[:, 1])
451
+ x_max = np.max(polygon_points[:, 0])
452
+ y_max = np.max(polygon_points[:, 1])
453
+
454
+ input_boxes = np.array([[x_min, y_min, x_max, y_max]])
455
+
456
+ # predict mask with SAM 2
457
+ sam2_predictor.set_image(np.array(image))
458
+ sam2_masks, scores, logits = sam2_predictor.predict(
459
+ point_coords=None,
460
+ point_labels=None,
461
+ box=input_boxes,
462
+ multimask_output=False,
463
+ )
464
+
465
+ if sam2_masks.ndim == 4:
466
+ sam2_masks = sam2_masks.squeeze(1)
467
+
468
+ # specify labels
469
+ labels = [
470
+ f"{class_name}" for class_name in class_names
471
+ ]
472
+
473
+ # visualization florence2 mask
474
+ img = cv2.imread(image_path)
475
+ detections = sv.Detections(
476
+ xyxy=input_boxes,
477
+ mask=florence2_mask.astype(bool),
478
+ class_id=class_ids
479
+ )
480
+
481
+ box_annotator = sv.BoxAnnotator()
482
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
483
+
484
+ label_annotator = sv.LabelAnnotator()
485
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
486
+ cv2.imwrite(os.path.join(output_dir, "florence2_referring_segmentation_box.jpg"), annotated_frame)
487
+
488
+ mask_annotator = sv.MaskAnnotator()
489
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
490
+ cv2.imwrite(os.path.join(output_dir, "florence2_referring_segmentation_box_with_mask.jpg"), annotated_frame)
491
+
492
+ print(f'Successfully saved florence-2 annotated image to "{output_dir}"')
493
+
494
+ # visualize sam2 mask
495
+ img = cv2.imread(image_path)
496
+ detections = sv.Detections(
497
+ xyxy=input_boxes,
498
+ mask=sam2_masks.astype(bool),
499
+ class_id=class_ids
500
+ )
501
+
502
+ box_annotator = sv.BoxAnnotator()
503
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
504
+
505
+ label_annotator = sv.LabelAnnotator()
506
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
507
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_referring_box.jpg"), annotated_frame)
508
+
509
+ mask_annotator = sv.MaskAnnotator()
510
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
511
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_referring_box_with_sam2_mask.jpg"), annotated_frame)
512
+
513
+ print(f'Successfully saved sam2 annotated image to "{output_dir}"')
514
+
515
+
516
+ """
517
+ Pipeline 6: Open-Vocabulary Detection + Segmentation
518
+ """
519
+ def open_vocabulary_detection_and_segmentation(
520
+ florence2_model,
521
+ florence2_processor,
522
+ sam2_predictor,
523
+ image_path,
524
+ task_prompt="<OPEN_VOCABULARY_DETECTION>",
525
+ text_input=None,
526
+ output_dir=OUTPUT_DIR
527
+ ):
528
+ # run florence-2 object detection in demo
529
+ image = Image.open(image_path).convert("RGB")
530
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
531
+
532
+ """ Florence-2 Open-Vocabulary Detection Output Format
533
+ {'<OPEN_VOCABULARY_DETECTION>':
534
+ {
535
+ 'bboxes':
536
+ [
537
+ [34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594]
538
+ ],
539
+ 'bboxes_labels': ['A green car'],
540
+ 'polygons': [],
541
+ 'polygons_labels': []
542
+ }
543
+ }
544
+ """
545
+ assert text_input is not None, "Text input should not be None when calling open-vocabulary detection pipeline."
546
+ results = results[task_prompt]
547
+ # parse florence-2 detection results
548
+ input_boxes = np.array(results["bboxes"])
549
+ print(results)
550
+ class_names = results["bboxes_labels"]
551
+ class_ids = np.array(list(range(len(class_names))))
552
+
553
+ # predict mask with SAM 2
554
+ sam2_predictor.set_image(np.array(image))
555
+ masks, scores, logits = sam2_predictor.predict(
556
+ point_coords=None,
557
+ point_labels=None,
558
+ box=input_boxes,
559
+ multimask_output=False,
560
+ )
561
+
562
+ if masks.ndim == 4:
563
+ masks = masks.squeeze(1)
564
+
565
+ # specify labels
566
+ labels = [
567
+ f"{class_name}" for class_name in class_names
568
+ ]
569
+
570
+ # visualization results
571
+ img = cv2.imread(image_path)
572
+ detections = sv.Detections(
573
+ xyxy=input_boxes,
574
+ mask=masks.astype(bool),
575
+ class_id=class_ids
576
+ )
577
+
578
+ box_annotator = sv.BoxAnnotator()
579
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
580
+
581
+ label_annotator = sv.LabelAnnotator()
582
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
583
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection.jpg"), annotated_frame)
584
+
585
+ mask_annotator = sv.MaskAnnotator()
586
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
587
+ cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg"), annotated_frame)
588
+
589
+ print(f'Successfully saved annotated image to "{output_dir}"')
590
+
591
+ if __name__ == "__main__":
592
+
593
+ parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
594
+ parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
595
+ parser.add_argument("--pipeline", type=str, default="object_detection_segmentation", required=True, help="path to image file")
596
+ parser.add_argument("--text_input", type=str, default=None, required=False, help="path to image file")
597
+ args = parser.parse_args()
598
+
599
+ IMAGE_PATH = args.image_path
600
+ PIPELINE = args.pipeline
601
+ INPUT_TEXT = args.text_input
602
+
603
+ print(f"Running pipeline: {PIPELINE} now.")
604
+
605
+ if PIPELINE == "object_detection_segmentation":
606
+ # pipeline-1: detection + segmentation
607
+ object_detection_and_segmentation(
608
+ florence2_model=florence2_model,
609
+ florence2_processor=florence2_processor,
610
+ sam2_predictor=sam2_predictor,
611
+ image_path=IMAGE_PATH
612
+ )
613
+ elif PIPELINE == "dense_region_caption_segmentation":
614
+ # pipeline-2: dense region caption + segmentation
615
+ dense_region_caption_and_segmentation(
616
+ florence2_model=florence2_model,
617
+ florence2_processor=florence2_processor,
618
+ sam2_predictor=sam2_predictor,
619
+ image_path=IMAGE_PATH
620
+ )
621
+ elif PIPELINE == "region_proposal_segmentation":
622
+ # pipeline-3: region proposal + segmentation
623
+ region_proposal_and_segmentation(
624
+ florence2_model=florence2_model,
625
+ florence2_processor=florence2_processor,
626
+ sam2_predictor=sam2_predictor,
627
+ image_path=IMAGE_PATH
628
+ )
629
+ elif PIPELINE == "phrase_grounding_segmentation":
630
+ # pipeline-4: phrase grounding + segmentation
631
+ phrase_grounding_and_segmentation(
632
+ florence2_model=florence2_model,
633
+ florence2_processor=florence2_processor,
634
+ sam2_predictor=sam2_predictor,
635
+ image_path=IMAGE_PATH,
636
+ text_input=INPUT_TEXT
637
+ )
638
+ elif PIPELINE == "referring_expression_segmentation":
639
+ # pipeline-5: referring expression segmentation
640
+ referring_expression_segmentation(
641
+ florence2_model=florence2_model,
642
+ florence2_processor=florence2_processor,
643
+ sam2_predictor=sam2_predictor,
644
+ image_path=IMAGE_PATH,
645
+ text_input=INPUT_TEXT
646
+ )
647
+ elif PIPELINE == "open_vocabulary_detection_segmentation":
648
+ # pipeline-6: open-vocabulary detection + segmentation
649
+ open_vocabulary_detection_and_segmentation(
650
+ florence2_model=florence2_model,
651
+ florence2_processor=florence2_processor,
652
+ sam2_predictor=sam2_predictor,
653
+ image_path=IMAGE_PATH,
654
+ text_input=INPUT_TEXT
655
+ )
656
+ else:
657
+ raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
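One detail worth calling out from the referring-expression pipeline above: the Florence-2 polygon is reduced to an axis-aligned box before it is used as a SAM 2 prompt. A tiny self-contained sketch of that reduction, with made-up polygon values:

import numpy as np

# hypothetical polygon vertices (x, y) as Florence-2 might return them
polygon = np.array([[120, 80], [240, 75], [260, 190], [130, 200]], dtype=np.int32)

x_min, y_min = polygon[:, 0].min(), polygon[:, 1].min()
x_max, y_max = polygon[:, 0].max(), polygon[:, 1].max()
box_prompt = np.array([[x_min, y_min, x_max, y_max]])  # xyxy box prompt for SAM 2
print(box_prompt)  # [[120  75 260 200]]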
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_gd1.5_demo.py ADDED
@@ -0,0 +1,249 @@
1
+ # dds cloudapi for Grounding DINO 1.5
2
+ from dds_cloudapi_sdk import Config
3
+ from dds_cloudapi_sdk import Client
4
+ from dds_cloudapi_sdk import DetectionTask
5
+ from dds_cloudapi_sdk import TextPrompt
6
+ from dds_cloudapi_sdk import DetectionModel
7
+ from dds_cloudapi_sdk import DetectionTarget
8
+
9
+ import os
10
+ import cv2
11
+ import json
12
+ import torch
13
+ import tempfile
14
+ import numpy as np
15
+ import supervision as sv
16
+ import pycocotools.mask as mask_util
17
+ from pathlib import Path
18
+ from PIL import Image
19
+ from sam2.build_sam import build_sam2
20
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
21
+
22
+ """
23
+ Hyper parameters
24
+ """
25
+ API_TOKEN = "Your API token"
26
+ TEXT_PROMPT = "car . building ."
27
+ IMG_PATH = "notebooks/images/cars.jpg"
28
+ SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
29
+ SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
30
+ GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
31
+ BOX_THRESHOLD = 0.2
32
+ WITH_SLICE_INFERENCE = False
33
+ SLICE_WH = (480, 480)
34
+ OVERLAP_RATIO = (0.2, 0.2)
35
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
+ OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
37
+ DUMP_JSON_RESULTS = True
38
+
39
+ # create output directory
40
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
41
+
42
+ """
43
+ Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
44
+ """
45
+ # Step 1: initialize the config
46
+ token = API_TOKEN
47
+ config = Config(token)
48
+
49
+ # Step 2: initialize the client
50
+ client = Client(config)
51
+
52
+ # Step 3: run the task by DetectionTask class
53
+ # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
54
+ # if you are processing local image file, upload them to DDS server to get the image url
55
+
56
+ classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
57
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
58
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
59
+
60
+ if WITH_SLICE_INFERENCE:
61
+ def callback(image_slice: np.ndarray) -> sv.Detections:
62
+ print("Inference on image slice")
63
+ # save the image slice as a temporary file for the GD-1.5 API
64
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
65
+ temp_filename = tmpfile.name
66
+ cv2.imwrite(temp_filename, image_slice)
67
+ image_url = client.upload_file(temp_filename)
68
+ task = DetectionTask(
69
+ image_url=image_url,
70
+ prompts=[TextPrompt(text=TEXT_PROMPT)],
71
+ targets=[DetectionTarget.BBox], # detect bbox
72
+ model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
73
+ bbox_threshold=BOX_THRESHOLD, # box confidence threshold
74
+ )
75
+ client.run_task(task)
76
+ result = task.result
77
+ # delete the temporary file
78
+ os.remove(temp_filename)
79
+
80
+ input_boxes = []
81
+ confidences = []
82
+ class_ids = []
83
+ objects = result.objects
84
+ for idx, obj in enumerate(objects):
85
+ input_boxes.append(obj.bbox)
86
+ confidences.append(obj.score)
87
+ cls_name = obj.category.lower().strip()
88
+ class_ids.append(class_name_to_id[cls_name])
89
+ # ensure input_boxes with shape (_, 4)
90
+ input_boxes = np.array(input_boxes).reshape(-1, 4)
91
+ class_ids = np.array(class_ids)
92
+ confidences = np.array(confidences)
93
+ return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
94
+
95
+ slicer = sv.InferenceSlicer(
96
+ callback=callback,
97
+ slice_wh=SLICE_WH,
98
+ overlap_ratio_wh=OVERLAP_RATIO,
99
+ iou_threshold=0.5,
100
+ overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
101
+ )
102
+ detections = slicer(cv2.imread(IMG_PATH))
103
+ class_names = [class_id_to_name[id] for id in detections.class_id]
104
+ confidences = detections.confidence
105
+ class_ids = detections.class_id
106
+ input_boxes = detections.xyxy
107
+ else:
108
+ image_url = client.upload_file(IMG_PATH)
109
+
110
+ task = DetectionTask(
111
+ image_url=image_url,
112
+ prompts=[TextPrompt(text=TEXT_PROMPT)],
113
+ targets=[DetectionTarget.BBox], # detect bbox
114
+ model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
115
+ bbox_threshold=BOX_THRESHOLD, # box confidence threshold
116
+ )
117
+
118
+ client.run_task(task)
119
+ result = task.result
120
+
121
+ objects = result.objects # the list of detected objects
122
+
123
+
124
+ input_boxes = []
125
+ confidences = []
126
+ class_names = []
127
+ class_ids = []
128
+
129
+ for idx, obj in enumerate(objects):
130
+ input_boxes.append(obj.bbox)
131
+ confidences.append(obj.score)
132
+ cls_name = obj.category.lower().strip()
133
+ class_names.append(cls_name)
134
+ class_ids.append(class_name_to_id[cls_name])
135
+
136
+ input_boxes = np.array(input_boxes)
137
+ class_ids = np.array(class_ids)
138
+
139
+ """
140
+ Init SAM 2 Model and Predict Mask with Box Prompt
141
+ """
142
+
143
+ # environment settings
144
+ # use bfloat16
145
+ torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
146
+
147
+ if torch.cuda.get_device_properties(0).major >= 8:
148
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
149
+ torch.backends.cuda.matmul.allow_tf32 = True
150
+ torch.backends.cudnn.allow_tf32 = True
151
+
152
+ # build SAM2 image predictor
153
+ sam2_checkpoint = SAM2_CHECKPOINT
154
+ model_cfg = SAM2_MODEL_CONFIG
155
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
156
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
157
+
158
+ image = Image.open(IMG_PATH)
159
+
160
+ sam2_predictor.set_image(np.array(image.convert("RGB")))
161
+
162
+ masks, scores, logits = sam2_predictor.predict(
163
+ point_coords=None,
164
+ point_labels=None,
165
+ box=input_boxes,
166
+ multimask_output=False,
167
+ )
168
+
169
+
170
+ """
171
+ Post-process the output of the model to get the masks, scores, and logits for visualization
172
+ """
173
+ # convert the shape to (n, H, W)
174
+ if masks.ndim == 4:
175
+ masks = masks.squeeze(1)
176
+
177
+
178
+ """
179
+ Visualize the Prediction Results
180
+ """
181
+
182
+ labels = [
183
+ f"{class_name} {confidence:.2f}"
184
+ for class_name, confidence
185
+ in zip(class_names, confidences)
186
+ ]
187
+
188
+ """
189
+ Visualize the image with the supervision API
190
+ """
191
+ img = cv2.imread(IMG_PATH)
192
+ detections = sv.Detections(
193
+ xyxy=input_boxes, # (n, 4)
194
+ mask=masks.astype(bool), # (n, h, w)
195
+ class_id=class_ids
196
+ )
197
+
198
+ box_annotator = sv.BoxAnnotator()
199
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
200
+
201
+ label_annotator = sv.LabelAnnotator()
202
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
203
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
204
+
205
+ mask_annotator = sv.MaskAnnotator()
206
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
207
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
208
+
209
+ print(f'Annotated image has been saved to "{OUTPUT_DIR}"')
210
+
211
+ """
212
+ Dump the results in standard format and save as json files
213
+ """
214
+
215
+ def single_mask_to_rle(mask):
216
+ rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
217
+ rle["counts"] = rle["counts"].decode("utf-8")
218
+ return rle
219
+
220
+ if DUMP_JSON_RESULTS:
221
+ print("Start dumping the annotation...")
222
+ # convert mask into rle format
223
+ mask_rles = [single_mask_to_rle(mask) for mask in masks]
224
+
225
+ input_boxes = input_boxes.tolist()
226
+ scores = scores.tolist()
227
+ # FIXME: class_names should be a list of strings without spaces
228
+ class_names = [class_name.strip() for class_name in class_names]
229
+ # save the results in standard format
230
+ results = {
231
+ "image_path": IMG_PATH,
232
+ "annotations" : [
233
+ {
234
+ "class_name": class_name,
235
+ "bbox": box,
236
+ "segmentation": mask_rle,
237
+ "score": score,
238
+ }
239
+ for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
240
+ ],
241
+ "box_format": "xyxy",
242
+ "img_width": image.width,
243
+ "img_height": image.height,
244
+ }
245
+
246
+ with open(os.path.join(OUTPUT_DIR, "grounded_sam2_gd1.5_image_demo_results.json"), "w") as f:
247
+ json.dump(results, f, indent=4)
248
+
249
+ print(f'Annotation has already been saved to "{OUTPUT_DIR}"')
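Since BOX_THRESHOLD above is fairly low (0.2), it can be useful to filter the API detections again before prompting SAM 2. A hedged sketch that reuses the arrays built in this script; the 0.4 cutoff is arbitrary:

import numpy as np

# assumes input_boxes, confidences, class_ids, class_names as built above
confidences = np.asarray(confidences)
keep = confidences >= 0.4                     # tune per image / prompt
input_boxes = input_boxes[keep]
class_ids = class_ids[keep]
confidences = confidences[keep]
class_names = [name for name, k in zip(class_names, keep) if k]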
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_hf_model_demo.py ADDED
@@ -0,0 +1,187 @@
1
+ import argparse
2
+ import os
3
+ import cv2
4
+ import json
5
+ import torch
6
+ import numpy as np
7
+ import supervision as sv
8
+ import pycocotools.mask as mask_util
9
+ from pathlib import Path
10
+ from supervision.draw.color import ColorPalette
11
+ from utils.supervision_utils import CUSTOM_COLOR_MAP
12
+ from PIL import Image
13
+ from sam2.build_sam import build_sam2
14
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
15
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
16
+
17
+ """
18
+ Hyper parameters
19
+ """
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--grounding-model', default="IDEA-Research/grounding-dino-tiny")
22
+ parser.add_argument("--text-prompt", default="car. tire.")
23
+ parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
24
+ parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
25
+ parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
26
+ parser.add_argument("--output-dir", default="outputs/test_sam2.1")
27
+ parser.add_argument("--no-dump-json", action="store_true")
28
+ parser.add_argument("--force-cpu", action="store_true")
29
+ args = parser.parse_args()
30
+
31
+ GROUNDING_MODEL = args.grounding_model
32
+ TEXT_PROMPT = args.text_prompt
33
+ IMG_PATH = args.img_path
34
+ SAM2_CHECKPOINT = args.sam2_checkpoint
35
+ SAM2_MODEL_CONFIG = args.sam2_model_config
36
+ DEVICE = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
37
+ OUTPUT_DIR = Path(args.output_dir)
38
+ DUMP_JSON_RESULTS = not args.no_dump_json
39
+
40
+ # create output directory
41
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
42
+
43
+ # environment settings
44
+ # use bfloat16
45
+ torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
46
+
47
+ if torch.cuda.get_device_properties(0).major >= 8:
48
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
49
+ torch.backends.cuda.matmul.allow_tf32 = True
50
+ torch.backends.cudnn.allow_tf32 = True
51
+
52
+ # build SAM2 image predictor
53
+ sam2_checkpoint = SAM2_CHECKPOINT
54
+ model_cfg = SAM2_MODEL_CONFIG
55
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
56
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
57
+
58
+ # build grounding dino from huggingface
59
+ model_id = GROUNDING_MODEL
60
+ processor = AutoProcessor.from_pretrained(model_id)
61
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
62
+
63
+
64
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
65
+ # VERY important: text queries need to be lowercased + end with a dot
66
+ text = TEXT_PROMPT
67
+ img_path = IMG_PATH
68
+
69
+ image = Image.open(img_path)
70
+
71
+ sam2_predictor.set_image(np.array(image.convert("RGB")))
72
+
73
+ inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
74
+ with torch.no_grad():
75
+ outputs = grounding_model(**inputs)
76
+
77
+ results = processor.post_process_grounded_object_detection(
78
+ outputs,
79
+ inputs.input_ids,
80
+ box_threshold=0.4,
81
+ text_threshold=0.3,
82
+ target_sizes=[image.size[::-1]]
83
+ )
84
+
85
+ """
86
+ Results is a list of dict with the following structure:
87
+ [
88
+ {
89
+ 'scores': tensor([0.7969, 0.6469, 0.6002, 0.4220], device='cuda:0'),
90
+ 'labels': ['car', 'tire', 'tire', 'tire'],
91
+ 'boxes': tensor([[ 89.3244, 278.6940, 1710.3505, 851.5143],
92
+ [1392.4701, 554.4064, 1628.6133, 777.5872],
93
+ [ 436.1182, 621.8940, 676.5255, 851.6897],
94
+ [1236.0990, 688.3547, 1400.2427, 753.1256]], device='cuda:0')
95
+ }
96
+ ]
97
+ """
98
+
99
+ # get the box prompt for SAM 2
100
+ input_boxes = results[0]["boxes"].cpu().numpy()
101
+
102
+ masks, scores, logits = sam2_predictor.predict(
103
+ point_coords=None,
104
+ point_labels=None,
105
+ box=input_boxes,
106
+ multimask_output=False,
107
+ )
108
+
109
+
110
+ """
111
+ Post-process the output of the model to get the masks, scores, and logits for visualization
112
+ """
113
+ # convert the shape to (n, H, W)
114
+ if masks.ndim == 4:
115
+ masks = masks.squeeze(1)
116
+
117
+
118
+ confidences = results[0]["scores"].cpu().numpy().tolist()
119
+ class_names = results[0]["labels"]
120
+ class_ids = np.array(list(range(len(class_names))))
121
+
122
+ labels = [
123
+ f"{class_name} {confidence:.2f}"
124
+ for class_name, confidence
125
+ in zip(class_names, confidences)
126
+ ]
127
+
128
+ """
129
+ Visualize the image with the supervision API
130
+ """
131
+ img = cv2.imread(img_path)
132
+ detections = sv.Detections(
133
+ xyxy=input_boxes, # (n, 4)
134
+ mask=masks.astype(bool), # (n, h, w)
135
+ class_id=class_ids
136
+ )
137
+
138
+ """
139
+ Note that if you want to use default color map,
140
+ you can set color=ColorPalette.DEFAULT
141
+ """
142
+ box_annotator = sv.BoxAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
143
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
144
+
145
+ label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
146
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
147
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
148
+
149
+ mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
150
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
151
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
152
+
153
+
154
+ """
155
+ Dump the results in standard format and save as json files
156
+ """
157
+
158
+ def single_mask_to_rle(mask):
159
+ rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
160
+ rle["counts"] = rle["counts"].decode("utf-8")
161
+ return rle
162
+
163
+ if DUMP_JSON_RESULTS:
164
+ # convert mask into rle format
165
+ mask_rles = [single_mask_to_rle(mask) for mask in masks]
166
+
167
+ input_boxes = input_boxes.tolist()
168
+ scores = scores.tolist()
169
+ # save the results in standard format
170
+ results = {
171
+ "image_path": img_path,
172
+ "annotations" : [
173
+ {
174
+ "class_name": class_name,
175
+ "bbox": box,
176
+ "segmentation": mask_rle,
177
+ "score": score,
178
+ }
179
+ for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
180
+ ],
181
+ "box_format": "xyxy",
182
+ "img_width": image.width,
183
+ "img_height": image.height,
184
+ }
185
+
186
+ with open(os.path.join(OUTPUT_DIR, "grounded_sam2_hf_model_demo_results.json"), "w") as f:
187
+ json.dump(results, f, indent=4)
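As the comment in this script notes, the HF Grounding DINO path expects lowercased text queries that end with a dot. A small helper along these lines (purely illustrative, not part of the repo) can build such a prompt from a class list:

def build_text_prompt(class_names):
    """Join class names into the 'car. tire.' style prompt Grounding DINO expects."""
    return " ".join(f"{name.strip().lower()}." for name in class_names)

print(build_text_prompt(["Car", "Tire"]))  # -> "car. tire."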
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_local_demo.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import torch
5
+ import numpy as np
6
+ import supervision as sv
7
+ import pycocotools.mask as mask_util
8
+ from pathlib import Path
9
+ from torchvision.ops import box_convert
10
+ from sam2.build_sam import build_sam2
11
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
+ from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
13
+
14
+ """
15
+ Hyper parameters
16
+ """
17
+ TEXT_PROMPT = "car. tire."
18
+ IMG_PATH = "notebooks/images/truck.jpg"
19
+ SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
20
+ SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
21
+ GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
22
+ GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
23
+ BOX_THRESHOLD = 0.35
24
+ TEXT_THRESHOLD = 0.25
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+ OUTPUT_DIR = Path("outputs/grounded_sam2_local_demo")
27
+ DUMP_JSON_RESULTS = True
28
+
29
+ # create output directory
30
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
31
+
32
+ # environment settings
33
+ # use bfloat16
34
+
35
+ # build SAM2 image predictor
36
+ sam2_checkpoint = SAM2_CHECKPOINT
37
+ model_cfg = SAM2_MODEL_CONFIG
38
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
39
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
40
+
41
+ # build grounding dino model
42
+ grounding_model = load_model(
43
+ model_config_path=GROUNDING_DINO_CONFIG,
44
+ model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
45
+ device=DEVICE
46
+ )
47
+
48
+
49
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
50
+ # VERY important: text queries need to be lowercased + end with a dot
51
+ text = TEXT_PROMPT
52
+ img_path = IMG_PATH
53
+
54
+ image_source, image = load_image(img_path)
55
+
56
+ sam2_predictor.set_image(image_source)
57
+
58
+ boxes, confidences, labels = predict(
59
+ model=grounding_model,
60
+ image=image,
61
+ caption=text,
62
+ box_threshold=BOX_THRESHOLD,
63
+ text_threshold=TEXT_THRESHOLD,
64
+ )
65
+
66
+ # process the box prompt for SAM 2
67
+ h, w, _ = image_source.shape
68
+ boxes = boxes * torch.Tensor([w, h, w, h])
69
+ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
70
+
71
+
72
+ # FIXME: figure how does this influence the G-DINO model
73
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
74
+
75
+ if torch.cuda.get_device_properties(0).major >= 8:
76
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
77
+ torch.backends.cuda.matmul.allow_tf32 = True
78
+ torch.backends.cudnn.allow_tf32 = True
79
+
80
+ masks, scores, logits = sam2_predictor.predict(
81
+ point_coords=None,
82
+ point_labels=None,
83
+ box=input_boxes,
84
+ multimask_output=False,
85
+ )
86
+
87
+ """
88
+ Post-process the output of the model to get the masks, scores, and logits for visualization
89
+ """
90
+ # convert the shape to (n, H, W)
91
+ if masks.ndim == 4:
92
+ masks = masks.squeeze(1)
93
+
94
+
95
+ confidences = confidences.numpy().tolist()
96
+ class_names = labels
97
+
98
+ class_ids = np.array(list(range(len(class_names))))
99
+
100
+ labels = [
101
+ f"{class_name} {confidence:.2f}"
102
+ for class_name, confidence
103
+ in zip(class_names, confidences)
104
+ ]
105
+
106
+ """
107
+ Visualize image with supervision useful API
108
+ """
109
+ img = cv2.imread(img_path)
110
+ detections = sv.Detections(
111
+ xyxy=input_boxes, # (n, 4)
112
+ mask=masks.astype(bool), # (n, h, w)
113
+ class_id=class_ids
114
+ )
115
+
116
+ box_annotator = sv.BoxAnnotator()
117
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
118
+
119
+ label_annotator = sv.LabelAnnotator()
120
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
121
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
122
+
123
+ mask_annotator = sv.MaskAnnotator()
124
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
125
+ cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
126
+
127
+ """
128
+ Dump the results in standard format and save as json files
129
+ """
130
+
131
+ def single_mask_to_rle(mask):
132
+ rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
133
+ rle["counts"] = rle["counts"].decode("utf-8")
134
+ return rle
135
+
136
+ if DUMP_JSON_RESULTS:
137
+ # convert mask into rle format
138
+ mask_rles = [single_mask_to_rle(mask) for mask in masks]
139
+
140
+ input_boxes = input_boxes.tolist()
141
+ scores = scores.tolist()
142
+ # save the results in standard format
143
+ results = {
144
+ "image_path": img_path,
145
+ "annotations" : [
146
+ {
147
+ "class_name": class_name,
148
+ "bbox": box,
149
+ "segmentation": mask_rle,
150
+ "score": score,
151
+ }
152
+ for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
153
+ ],
154
+ "box_format": "xyxy",
155
+ "img_width": w,
156
+ "img_height": h,
157
+ }
158
+
159
+ with open(os.path.join(OUTPUT_DIR, "grounded_sam2_local_image_demo_results.json"), "w") as f:
160
+ json.dump(results, f, indent=4)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import supervision as sv
6
+ from PIL import Image
7
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
10
+ from utils.track_utils import sample_points_from_masks
11
+ from utils.video_utils import create_video_from_images
12
+
13
+
14
+ """
15
+ Step 1: Environment settings and model initialization
16
+ """
17
+ # use bfloat16 for the entire notebook
18
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
19
+
20
+ if torch.cuda.get_device_properties(0).major >= 8:
21
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
22
+ torch.backends.cuda.matmul.allow_tf32 = True
23
+ torch.backends.cudnn.allow_tf32 = True
24
+
25
+ # init sam image predictor and video predictor model
26
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
27
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
28
+
29
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
30
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
31
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
32
+
33
+
34
+ # init grounding dino model from huggingface
35
+ model_id = "IDEA-Research/grounding-dino-tiny"
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ processor = AutoProcessor.from_pretrained(model_id)
38
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
39
+
40
+
41
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
42
+ # VERY important: text queries need to be lowercased + end with a dot
43
+ text = "car."
44
+
45
+ # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
46
+
47
+ video_dir = "notebooks/videos/car"
48
+
49
+ # scan all the JPEG frame names in this directory
50
+ frame_names = [
51
+ p for p in os.listdir(video_dir)
52
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
53
+ ]
54
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
55
+
56
+ # init video predictor state
57
+ inference_state = video_predictor.init_state(video_path=video_dir)
58
+
59
+ ann_frame_idx = 0 # the frame index we interact with
60
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
61
+
62
+
63
+ """
64
+ Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for specific frame
65
+ """
66
+
67
+ # prompt grounding dino to get the box coordinates on specific frame
68
+ img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
69
+ image = Image.open(img_path)
70
+
71
+ # run Grounding DINO on the image
72
+ inputs = processor(images=image, text=text, return_tensors="pt").to(device)
73
+ with torch.no_grad():
74
+ outputs = grounding_model(**inputs)
75
+
76
+ results = processor.post_process_grounded_object_detection(
77
+ outputs,
78
+ inputs.input_ids,
79
+ box_threshold=0.25,
80
+ text_threshold=0.3,
81
+ target_sizes=[image.size[::-1]]
82
+ )
83
+
84
+ # prompt SAM image predictor to get the mask for the object
85
+ image_predictor.set_image(np.array(image.convert("RGB")))
86
+
87
+ # process the detection results
88
+ input_boxes = results[0]["boxes"].cpu().numpy()
89
+ OBJECTS = results[0]["labels"]
90
+
91
+ # prompt SAM 2 image predictor to get the mask for the object
92
+ masks, scores, logits = image_predictor.predict(
93
+ point_coords=None,
94
+ point_labels=None,
95
+ box=input_boxes,
96
+ multimask_output=False,
97
+ )
98
+
99
+ # convert the mask shape to (n, H, W)
100
+ if masks.ndim == 3:
101
+ masks = masks[None]
102
+ scores = scores[None]
103
+ logits = logits[None]
104
+ elif masks.ndim == 4:
105
+ masks = masks.squeeze(1)
106
+
107
+ """
108
+ Step 3: Register each object's positive points to video predictor with seperate add_new_points call
109
+ """
110
+
111
+ PROMPT_TYPE_FOR_VIDEO = "box" # or "point"
112
+
113
+ assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
114
+
115
+ # If you are using point prompts, we uniformly sample positive points based on the mask
116
+ if PROMPT_TYPE_FOR_VIDEO == "point":
117
+ # sample the positive points from mask for each objects
118
+ all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
119
+
120
+ for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
121
+ labels = np.ones((points.shape[0]), dtype=np.int32)
122
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
123
+ inference_state=inference_state,
124
+ frame_idx=ann_frame_idx,
125
+ obj_id=object_id,
126
+ points=points,
127
+ labels=labels,
128
+ )
129
+ # Using box prompt
130
+ elif PROMPT_TYPE_FOR_VIDEO == "box":
131
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
132
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
133
+ inference_state=inference_state,
134
+ frame_idx=ann_frame_idx,
135
+ obj_id=object_id,
136
+ box=box,
137
+ )
138
+ # Using mask prompt is a more straightforward way
139
+ elif PROMPT_TYPE_FOR_VIDEO == "mask":
140
+ for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
141
+ labels = np.ones((1), dtype=np.int32)
142
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
143
+ inference_state=inference_state,
144
+ frame_idx=ann_frame_idx,
145
+ obj_id=object_id,
146
+ mask=mask
147
+ )
148
+ else:
149
+ raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
150
+
151
+
152
+ """
153
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
154
+ """
155
+ video_segments = {} # video_segments contains the per-frame segmentation results
156
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
157
+ video_segments[out_frame_idx] = {
158
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
159
+ for i, out_obj_id in enumerate(out_obj_ids)
160
+ }
161
+
162
+ """
163
+ Step 5: Visualize the segment results across the video and save them
164
+ """
165
+
166
+ save_dir = "./tracking_results"
167
+
168
+ if not os.path.exists(save_dir):
169
+ os.makedirs(save_dir)
170
+
171
+ ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
172
+ for frame_idx, segments in video_segments.items():
173
+ img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
174
+
175
+ object_ids = list(segments.keys())
176
+ masks = list(segments.values())
177
+ masks = np.concatenate(masks, axis=0)
178
+
179
+ detections = sv.Detections(
180
+ xyxy=sv.mask_to_xyxy(masks), # (n, 4)
181
+ mask=masks, # (n, h, w)
182
+ class_id=np.array(object_ids, dtype=np.int32),
183
+ )
184
+ box_annotator = sv.BoxAnnotator()
185
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
186
+ label_annotator = sv.LabelAnnotator()
187
+ annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
188
+ mask_annotator = sv.MaskAnnotator()
189
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
190
+ cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
191
+
192
+
193
+ """
194
+ Step 6: Convert the annotated frames to video
195
+ """
196
+
197
+ output_video_path = "./children_tracking_demo_video.mp4"
198
+ create_video_from_images(save_dir, output_video_path)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_dinox.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dds cloudapi for Grounding DINO 1.5
2
+ from dds_cloudapi_sdk import Config
3
+ from dds_cloudapi_sdk import Client
4
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
5
+ from dds_cloudapi_sdk.tasks.types import DetectionTarget
6
+ from dds_cloudapi_sdk import TextPrompt
7
+
8
+ import os
9
+ import cv2
10
+ import torch
11
+ import numpy as np
12
+ import supervision as sv
13
+
14
+ from pathlib import Path
15
+ from tqdm import tqdm
16
+ from PIL import Image
17
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
18
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
19
+ from utils.track_utils import sample_points_from_masks
20
+ from utils.video_utils import create_video_from_images
21
+
22
+ """
23
+ Hyperparam for Ground and Tracking
24
+ """
25
+ VIDEO_PATH = "./assets/hippopotamus.mp4"
26
+ TEXT_PROMPT = "hippopotamus."
27
+ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
28
+ SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
29
+ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
30
+ API_TOKEN_FOR_DINOX = "Your API token"
31
+ PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
32
+ BOX_THRESHOLD = 0.2
33
+
34
+ """
35
+ Step 1: Environment settings and model initialization for SAM 2
36
+ """
37
+ # use bfloat16 for the entire notebook
38
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
39
+
40
+ if torch.cuda.get_device_properties(0).major >= 8:
41
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+ torch.backends.cudnn.allow_tf32 = True
44
+
45
+ # init sam image predictor and video predictor model
46
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
47
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
48
+
49
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
50
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
51
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
52
+
53
+
54
+ # # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
55
+ # video_dir = "notebooks/videos/bedroom"
56
+
57
+ """
58
+ Custom video input directly using video files
59
+ """
60
+ video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
61
+ print(video_info)
62
+ frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
63
+
64
+ # saving video to frames
65
+ source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
66
+ source_frames.mkdir(parents=True, exist_ok=True)
67
+
68
+ with sv.ImageSink(
69
+ target_dir_path=source_frames,
70
+ overwrite=True,
71
+ image_name_pattern="{:05d}.jpg"
72
+ ) as sink:
73
+ for frame in tqdm(frame_generator, desc="Saving Video Frames"):
74
+ sink.save_image(frame)
75
+
76
+ # scan all the JPEG frame names in this directory
77
+ frame_names = [
78
+ p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
79
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
80
+ ]
81
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
82
+
83
+ # init video predictor state
84
+ inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
85
+
86
+ ann_frame_idx = 0 # the frame index we interact with
87
+ """
88
+ Step 2: Prompt DINO-X with Cloud API for box coordinates
89
+ """
90
+
91
+ # prompt grounding dino to get the box coordinates on specific frame
92
+ img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
93
+ image = Image.open(img_path)
94
+
95
+ # Step 1: initialize the config
96
+ config = Config(API_TOKEN_FOR_DINOX)
97
+
98
+ # Step 2: initialize the client
99
+ client = Client(config)
100
+
101
+ # Step 3: run the task by DetectionTask class
102
+ # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
103
+ # if you are processing local image file, upload them to DDS server to get the image url
104
+ image_url = client.upload_file(img_path)
105
+
106
+ task = DinoxTask(
107
+ image_url=image_url,
108
+ prompts=[TextPrompt(text=TEXT_PROMPT)],
109
+ bbox_threshold=0.25,
110
+ targets=[DetectionTarget.BBox],
111
+ )
112
+
113
+ client.run_task(task)
114
+ result = task.result
115
+
116
+ objects = result.objects # the list of detected objects
117
+
118
+
119
+ input_boxes = []
120
+ confidences = []
121
+ class_names = []
122
+
123
+ for idx, obj in enumerate(objects):
124
+ input_boxes.append(obj.bbox)
125
+ confidences.append(obj.score)
126
+ class_names.append(obj.category)
127
+
128
+ input_boxes = np.array(input_boxes)
129
+
130
+ print(input_boxes)
131
+
132
+ # prompt SAM image predictor to get the mask for the object
133
+ image_predictor.set_image(np.array(image.convert("RGB")))
134
+
135
+ # process the detection results
136
+ OBJECTS = class_names
137
+
138
+ print(OBJECTS)
139
+
140
+ # prompt SAM 2 image predictor to get the mask for the object
141
+ masks, scores, logits = image_predictor.predict(
142
+ point_coords=None,
143
+ point_labels=None,
144
+ box=input_boxes,
145
+ multimask_output=False,
146
+ )
147
+ # convert the mask shape to (n, H, W)
148
+ if masks.ndim == 4:
149
+ masks = masks.squeeze(1)
150
+
151
+ """
152
+ Step 3: Register each object's positive points to video predictor with seperate add_new_points call
153
+ """
154
+
155
+ assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
156
+
157
+ # If you are using point prompts, we uniformly sample positive points based on the mask
158
+ if PROMPT_TYPE_FOR_VIDEO == "point":
159
+ # sample the positive points from mask for each objects
160
+ all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
161
+
162
+ for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
163
+ labels = np.ones((points.shape[0]), dtype=np.int32)
164
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
165
+ inference_state=inference_state,
166
+ frame_idx=ann_frame_idx,
167
+ obj_id=object_id,
168
+ points=points,
169
+ labels=labels,
170
+ )
171
+ # Using box prompt
172
+ elif PROMPT_TYPE_FOR_VIDEO == "box":
173
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
174
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
175
+ inference_state=inference_state,
176
+ frame_idx=ann_frame_idx,
177
+ obj_id=object_id,
178
+ box=box,
179
+ )
180
+ # Using mask prompt is a more straightforward way
181
+ elif PROMPT_TYPE_FOR_VIDEO == "mask":
182
+ for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
183
+ labels = np.ones((1), dtype=np.int32)
184
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
185
+ inference_state=inference_state,
186
+ frame_idx=ann_frame_idx,
187
+ obj_id=object_id,
188
+ mask=mask
189
+ )
190
+ else:
191
+ raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
192
+
193
+ """
194
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
195
+ """
196
+ video_segments = {} # video_segments contains the per-frame segmentation results
197
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
198
+ video_segments[out_frame_idx] = {
199
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
200
+ for i, out_obj_id in enumerate(out_obj_ids)
201
+ }
202
+
203
+ """
204
+ Step 5: Visualize the segment results across the video and save them
205
+ """
206
+
207
+ if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
208
+ os.makedirs(SAVE_TRACKING_RESULTS_DIR)
209
+
210
+ ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
211
+
212
+ for frame_idx, segments in video_segments.items():
213
+ img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
214
+
215
+ object_ids = list(segments.keys())
216
+ masks = list(segments.values())
217
+ masks = np.concatenate(masks, axis=0)
218
+
219
+ detections = sv.Detections(
220
+ xyxy=sv.mask_to_xyxy(masks), # (n, 4)
221
+ mask=masks, # (n, h, w)
222
+ class_id=np.array(object_ids, dtype=np.int32),
223
+ )
224
+ box_annotator = sv.BoxAnnotator()
225
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
226
+ label_annotator = sv.LabelAnnotator()
227
+ annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
228
+ mask_annotator = sv.MaskAnnotator()
229
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
230
+ cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
231
+
232
+
233
+ """
234
+ Step 6: Convert the annotated frames to video
235
+ """
236
+
237
+ create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import supervision as sv
6
+
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
11
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
13
+ from utils.track_utils import sample_points_from_masks
14
+ from utils.video_utils import create_video_from_images
15
+
16
+ """
17
+ Hyperparam for Ground and Tracking
18
+ """
19
+ MODEL_ID = "IDEA-Research/grounding-dino-tiny"
20
+ VIDEO_PATH = "./assets/hippopotamus.mp4"
21
+ TEXT_PROMPT = "hippopotamus."
22
+ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
23
+ SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
24
+ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
25
+ PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
26
+
27
+ """
28
+ Step 1: Environment settings and model initialization for SAM 2
29
+ """
30
+ # use bfloat16 for the entire notebook
31
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
32
+
33
+ if torch.cuda.get_device_properties(0).major >= 8:
34
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
35
+ torch.backends.cuda.matmul.allow_tf32 = True
36
+ torch.backends.cudnn.allow_tf32 = True
37
+
38
+ # init sam image predictor and video predictor model
39
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
40
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
41
+
42
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
43
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
44
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
45
+
46
+ # build grounding dino from huggingface
47
+ model_id = MODEL_ID
48
+ device = "cuda" if torch.cuda.is_available() else "cpu"
49
+ processor = AutoProcessor.from_pretrained(model_id)
50
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
51
+
52
+
53
+ """
54
+ Custom video input directly using video files
55
+ """
56
+ video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
57
+ print(video_info)
58
+ frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
59
+
60
+ # saving video to frames
61
+ source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
62
+ source_frames.mkdir(parents=True, exist_ok=True)
63
+
64
+ with sv.ImageSink(
65
+ target_dir_path=source_frames,
66
+ overwrite=True,
67
+ image_name_pattern="{:05d}.jpg"
68
+ ) as sink:
69
+ for frame in tqdm(frame_generator, desc="Saving Video Frames"):
70
+ sink.save_image(frame)
71
+
72
+ # scan all the JPEG frame names in this directory
73
+ frame_names = [
74
+ p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
75
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
76
+ ]
77
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
78
+
79
+ # init video predictor state
80
+ inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
81
+
82
+ ann_frame_idx = 0 # the frame index we interact with
83
+ """
84
+ Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
85
+ """
86
+
87
+ # prompt grounding dino to get the box coordinates on specific frame
88
+ img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
89
+ image = Image.open(img_path)
90
+ inputs = processor(images=image, text=TEXT_PROMPT, return_tensors="pt").to(device)
91
+ with torch.no_grad():
92
+ outputs = grounding_model(**inputs)
93
+
94
+ results = processor.post_process_grounded_object_detection(
95
+ outputs,
96
+ inputs.input_ids,
97
+ box_threshold=0.4,
98
+ text_threshold=0.3,
99
+ target_sizes=[image.size[::-1]]
100
+ )
101
+
102
+ input_boxes = results[0]["boxes"].cpu().numpy()
103
+ confidences = results[0]["scores"].cpu().numpy().tolist()
104
+ class_names = results[0]["labels"]
105
+
106
+ print(input_boxes)
107
+
108
+ # prompt SAM image predictor to get the mask for the object
109
+ image_predictor.set_image(np.array(image.convert("RGB")))
110
+
111
+ # process the detection results
112
+ OBJECTS = class_names
113
+
114
+ print(OBJECTS)
115
+
116
+ # prompt SAM 2 image predictor to get the mask for the object
117
+ masks, scores, logits = image_predictor.predict(
118
+ point_coords=None,
119
+ point_labels=None,
120
+ box=input_boxes,
121
+ multimask_output=False,
122
+ )
123
+ # convert the mask shape to (n, H, W)
124
+ if masks.ndim == 4:
125
+ masks = masks.squeeze(1)
126
+
127
+ """
128
+ Step 3: Register each object's positive points to video predictor with seperate add_new_points call
129
+ """
130
+
131
+ assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
132
+
133
+ # If you are using point prompts, we uniformly sample positive points based on the mask
134
+ if PROMPT_TYPE_FOR_VIDEO == "point":
135
+ # sample the positive points from mask for each objects
136
+ all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
137
+
138
+ for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
139
+ labels = np.ones((points.shape[0]), dtype=np.int32)
140
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
141
+ inference_state=inference_state,
142
+ frame_idx=ann_frame_idx,
143
+ obj_id=object_id,
144
+ points=points,
145
+ labels=labels,
146
+ )
147
+ # Using box prompt
148
+ elif PROMPT_TYPE_FOR_VIDEO == "box":
149
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
150
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
151
+ inference_state=inference_state,
152
+ frame_idx=ann_frame_idx,
153
+ obj_id=object_id,
154
+ box=box,
155
+ )
156
+ # Using mask prompt is a more straightforward way
157
+ elif PROMPT_TYPE_FOR_VIDEO == "mask":
158
+ for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
159
+ labels = np.ones((1), dtype=np.int32)
160
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
161
+ inference_state=inference_state,
162
+ frame_idx=ann_frame_idx,
163
+ obj_id=object_id,
164
+ mask=mask
165
+ )
166
+ else:
167
+ raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
168
+
169
+
170
+ """
171
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
172
+ """
173
+ video_segments = {} # video_segments contains the per-frame segmentation results
174
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
175
+ video_segments[out_frame_idx] = {
176
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
177
+ for i, out_obj_id in enumerate(out_obj_ids)
178
+ }
179
+
180
+ """
181
+ Step 5: Visualize the segment results across the video and save them
182
+ """
183
+
184
+ if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
185
+ os.makedirs(SAVE_TRACKING_RESULTS_DIR)
186
+
187
+ ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
188
+
189
+ for frame_idx, segments in video_segments.items():
190
+ img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
191
+
192
+ object_ids = list(segments.keys())
193
+ masks = list(segments.values())
194
+ masks = np.concatenate(masks, axis=0)
195
+
196
+ detections = sv.Detections(
197
+ xyxy=sv.mask_to_xyxy(masks), # (n, 4)
198
+ mask=masks, # (n, h, w)
199
+ class_id=np.array(object_ids, dtype=np.int32),
200
+ )
201
+ box_annotator = sv.BoxAnnotator()
202
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
203
+ label_annotator = sv.LabelAnnotator()
204
+ annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
205
+ mask_annotator = sv.MaskAnnotator()
206
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
207
+ cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
208
+
209
+
210
+ """
211
+ Step 6: Convert the annotated frames to video
212
+ """
213
+
214
+ create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import supervision as sv
6
+ from torchvision.ops import box_convert
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
11
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
+ from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
13
+ from utils.track_utils import sample_points_from_masks
14
+ from utils.video_utils import create_video_from_images
15
+
16
+ """
17
+ Hyperparam for Ground and Tracking
18
+ """
19
+ GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
20
+ GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
21
+ BOX_THRESHOLD = 0.35
22
+ TEXT_THRESHOLD = 0.25
23
+ VIDEO_PATH = "./assets/hippopotamus.mp4"
24
+ TEXT_PROMPT = "hippopotamus."
25
+ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
26
+ SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
27
+ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
28
+ PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
29
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ """
32
+ Step 1: Environment settings and model initialization for Grounding DINO and SAM 2
33
+ """
34
+ # build grounding dino model from local path
35
+ grounding_model = load_model(
36
+ model_config_path=GROUNDING_DINO_CONFIG,
37
+ model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
38
+ device=DEVICE
39
+ )
40
+
41
+
42
+ # init sam image predictor and video predictor model
43
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
44
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
45
+
46
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
47
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
48
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
49
+
50
+
51
+ """
52
+ Custom video input directly using video files
53
+ """
54
+ video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
55
+ print(video_info)
56
+ frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
57
+
58
+ # saving video to frames
59
+ source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
60
+ source_frames.mkdir(parents=True, exist_ok=True)
61
+
62
+ with sv.ImageSink(
63
+ target_dir_path=source_frames,
64
+ overwrite=True,
65
+ image_name_pattern="{:05d}.jpg"
66
+ ) as sink:
67
+ for frame in tqdm(frame_generator, desc="Saving Video Frames"):
68
+ sink.save_image(frame)
69
+
70
+ # scan all the JPEG frame names in this directory
71
+ frame_names = [
72
+ p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
73
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
74
+ ]
75
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
76
+
77
+ # init video predictor state
78
+ inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
79
+
80
+ ann_frame_idx = 0 # the frame index we interact with
81
+ """
82
+ Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
83
+ """
84
+
85
+ # prompt grounding dino to get the box coordinates on specific frame
86
+ img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
87
+ image_source, image = load_image(img_path)
88
+
89
+ boxes, confidences, labels = predict(
90
+ model=grounding_model,
91
+ image=image,
92
+ caption=TEXT_PROMPT,
93
+ box_threshold=BOX_THRESHOLD,
94
+ text_threshold=TEXT_THRESHOLD,
95
+ )
96
+
97
+ # process the box prompt for SAM 2
98
+ h, w, _ = image_source.shape
99
+ boxes = boxes * torch.Tensor([w, h, w, h])
100
+ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
101
+ confidences = confidences.numpy().tolist()
102
+ class_names = labels
103
+
104
+ print(input_boxes)
105
+
106
+ # prompt SAM image predictor to get the mask for the object
107
+ image_predictor.set_image(image_source)
108
+
109
+ # process the detection results
110
+ OBJECTS = class_names
111
+
112
+ print(OBJECTS)
113
+
114
+ # FIXME: figure how does this influence the G-DINO model
115
+ torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
116
+
117
+ if torch.cuda.get_device_properties(0).major >= 8:
118
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
119
+ torch.backends.cuda.matmul.allow_tf32 = True
120
+ torch.backends.cudnn.allow_tf32 = True
121
+
122
+ # prompt SAM 2 image predictor to get the mask for the object
123
+ masks, scores, logits = image_predictor.predict(
124
+ point_coords=None,
125
+ point_labels=None,
126
+ box=input_boxes,
127
+ multimask_output=False,
128
+ )
129
+ # convert the mask shape to (n, H, W)
130
+ if masks.ndim == 4:
131
+ masks = masks.squeeze(1)
132
+
133
+ """
134
+ Step 3: Register each object's positive points to video predictor with seperate add_new_points call
135
+ """
136
+
137
+ assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
138
+
139
+ # If you are using point prompts, we uniformly sample positive points based on the mask
140
+ if PROMPT_TYPE_FOR_VIDEO == "point":
141
+ # sample the positive points from mask for each objects
142
+ all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
143
+
144
+ for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
145
+ labels = np.ones((points.shape[0]), dtype=np.int32)
146
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
147
+ inference_state=inference_state,
148
+ frame_idx=ann_frame_idx,
149
+ obj_id=object_id,
150
+ points=points,
151
+ labels=labels,
152
+ )
153
+ # Using box prompt
154
+ elif PROMPT_TYPE_FOR_VIDEO == "box":
155
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
156
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
157
+ inference_state=inference_state,
158
+ frame_idx=ann_frame_idx,
159
+ obj_id=object_id,
160
+ box=box,
161
+ )
162
+ # Using mask prompt is a more straightforward way
163
+ elif PROMPT_TYPE_FOR_VIDEO == "mask":
164
+ for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
165
+ labels = np.ones((1), dtype=np.int32)
166
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
167
+ inference_state=inference_state,
168
+ frame_idx=ann_frame_idx,
169
+ obj_id=object_id,
170
+ mask=mask
171
+ )
172
+ else:
173
+ raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
174
+
175
+
176
+ """
177
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
178
+ """
179
+ video_segments = {} # video_segments contains the per-frame segmentation results
180
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
181
+ video_segments[out_frame_idx] = {
182
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
183
+ for i, out_obj_id in enumerate(out_obj_ids)
184
+ }
185
+
186
+ """
187
+ Step 5: Visualize the segment results across the video and save them
188
+ """
189
+
190
+ if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
191
+ os.makedirs(SAVE_TRACKING_RESULTS_DIR)
192
+
193
+ ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
194
+
195
+ for frame_idx, segments in video_segments.items():
196
+ img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
197
+
198
+ object_ids = list(segments.keys())
199
+ masks = list(segments.values())
200
+ masks = np.concatenate(masks, axis=0)
201
+
202
+ detections = sv.Detections(
203
+ xyxy=sv.mask_to_xyxy(masks), # (n, 4)
204
+ mask=masks, # (n, h, w)
205
+ class_id=np.array(object_ids, dtype=np.int32),
206
+ )
207
+ box_annotator = sv.BoxAnnotator()
208
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
209
+ label_annotator = sv.LabelAnnotator()
210
+ annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
211
+ mask_annotator = sv.MaskAnnotator()
212
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
213
+ cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
214
+
215
+
216
+ """
217
+ Step 6: Convert the annotated frames to video
218
+ """
219
+
220
+ create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dds cloudapi for Grounding DINO 1.5
2
+ from dds_cloudapi_sdk import Config
3
+ from dds_cloudapi_sdk import Client
4
+ from dds_cloudapi_sdk import DetectionTask
5
+ from dds_cloudapi_sdk import TextPrompt
6
+ from dds_cloudapi_sdk import DetectionModel
7
+ from dds_cloudapi_sdk import DetectionTarget
8
+
9
+ import os
10
+ import cv2
11
+ import torch
12
+ import numpy as np
13
+ import supervision as sv
14
+
15
+ from pathlib import Path
16
+ from tqdm import tqdm
17
+ from PIL import Image
18
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
19
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
20
+ from utils.track_utils import sample_points_from_masks
21
+ from utils.video_utils import create_video_from_images
22
+
23
+ """
24
+ Hyperparam for Ground and Tracking
25
+ """
26
+ VIDEO_PATH = "./assets/hippopotamus.mp4"
27
+ TEXT_PROMPT = "hippopotamus."
28
+ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
29
+ SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
30
+ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
31
+ API_TOKEN_FOR_GD1_5 = "Your API token"
32
+ PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
33
+ BOX_THRESHOLD = 0.2
34
+
35
+ """
36
+ Step 1: Environment settings and model initialization for SAM 2
37
+ """
38
+ # use bfloat16 for the entire notebook
39
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
40
+
41
+ if torch.cuda.get_device_properties(0).major >= 8:
42
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
43
+ torch.backends.cuda.matmul.allow_tf32 = True
44
+ torch.backends.cudnn.allow_tf32 = True
45
+
46
+ # init sam image predictor and video predictor model
47
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
48
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
49
+
50
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
51
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
52
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
53
+
54
+
55
+ # # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
56
+ # video_dir = "notebooks/videos/bedroom"
57
+
58
+ """
59
+ Custom video input directly using video files
60
+ """
61
+ video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
62
+ print(video_info)
63
+ frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
64
+
65
+ # saving video to frames
66
+ source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
67
+ source_frames.mkdir(parents=True, exist_ok=True)
68
+
69
+ with sv.ImageSink(
70
+ target_dir_path=source_frames,
71
+ overwrite=True,
72
+ image_name_pattern="{:05d}.jpg"
73
+ ) as sink:
74
+ for frame in tqdm(frame_generator, desc="Saving Video Frames"):
75
+ sink.save_image(frame)
76
+
77
+ # scan all the JPEG frame names in this directory
78
+ frame_names = [
79
+ p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
80
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
81
+ ]
82
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
83
+
84
+ # init video predictor state
85
+ inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
86
+
87
+ ann_frame_idx = 0 # the frame index we interact with
88
+ """
89
+ Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
90
+ """
91
+
92
+ # prompt grounding dino to get the box coordinates on specific frame
93
+ img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
94
+ image = Image.open(img_path)
95
+
96
+ # Step 1: initialize the config
97
+ config = Config(API_TOKEN_FOR_GD1_5)
98
+
99
+ # Step 2: initialize the client
100
+ client = Client(config)
101
+
102
+ # Step 3: run the task by DetectionTask class
103
+ # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
104
+ # if you are processing local image file, upload them to DDS server to get the image url
105
+ image_url = client.upload_file(img_path)
106
+
107
+ task = DetectionTask(
108
+ image_url=image_url,
109
+ prompts=[TextPrompt(text=TEXT_PROMPT)],
110
+ targets=[DetectionTarget.BBox], # detect bbox
111
+ model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
112
+ bbox_threshold=BOX_THRESHOLD,
113
+ )
114
+
115
+ client.run_task(task)
116
+ result = task.result
117
+
118
+ objects = result.objects # the list of detected objects
119
+
120
+
121
+ input_boxes = []
122
+ confidences = []
123
+ class_names = []
124
+
125
+ for idx, obj in enumerate(objects):
126
+ input_boxes.append(obj.bbox)
127
+ confidences.append(obj.score)
128
+ class_names.append(obj.category)
129
+
130
+ input_boxes = np.array(input_boxes)
131
+
132
+ print(input_boxes)
133
+
134
+ # prompt SAM image predictor to get the mask for the object
135
+ image_predictor.set_image(np.array(image.convert("RGB")))
136
+
137
+ # process the detection results
138
+ OBJECTS = class_names
139
+
140
+ print(OBJECTS)
141
+
142
+ # prompt SAM 2 image predictor to get the mask for the object
143
+ masks, scores, logits = image_predictor.predict(
144
+ point_coords=None,
145
+ point_labels=None,
146
+ box=input_boxes,
147
+ multimask_output=False,
148
+ )
149
+ # convert the mask shape to (n, H, W)
150
+ if masks.ndim == 4:
151
+ masks = masks.squeeze(1)
152
+
153
+ """
154
+ Step 3: Register each object's positive points to video predictor with seperate add_new_points call
155
+ """
156
+
157
+ assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
158
+
159
+ # If you are using point prompts, we uniformly sample positive points based on the mask
160
+ if PROMPT_TYPE_FOR_VIDEO == "point":
161
+ # sample the positive points from mask for each objects
162
+ all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
163
+
164
+ for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
165
+ labels = np.ones((points.shape[0]), dtype=np.int32)
166
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
167
+ inference_state=inference_state,
168
+ frame_idx=ann_frame_idx,
169
+ obj_id=object_id,
170
+ points=points,
171
+ labels=labels,
172
+ )
173
+ # Using box prompt
174
+ elif PROMPT_TYPE_FOR_VIDEO == "box":
175
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
176
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
177
+ inference_state=inference_state,
178
+ frame_idx=ann_frame_idx,
179
+ obj_id=object_id,
180
+ box=box,
181
+ )
182
+ # Using mask prompt is a more straightforward way
183
+ elif PROMPT_TYPE_FOR_VIDEO == "mask":
184
+ for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
185
+ labels = np.ones((1), dtype=np.int32)
186
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
187
+ inference_state=inference_state,
188
+ frame_idx=ann_frame_idx,
189
+ obj_id=object_id,
190
+ mask=mask
191
+ )
192
+ else:
193
+ raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
194
+
195
+ """
196
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
197
+ """
198
+ video_segments = {} # video_segments contains the per-frame segmentation results
199
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
200
+ video_segments[out_frame_idx] = {
201
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
202
+ for i, out_obj_id in enumerate(out_obj_ids)
203
+ }
204
+
205
+ """
206
+ Step 5: Visualize the segment results across the video and save them
207
+ """
208
+
209
+ if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
210
+ os.makedirs(SAVE_TRACKING_RESULTS_DIR)
211
+
212
+ ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
213
+
214
+ for frame_idx, segments in video_segments.items():
215
+ img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
216
+
217
+ object_ids = list(segments.keys())
218
+ masks = list(segments.values())
219
+ masks = np.concatenate(masks, axis=0)
220
+
221
+ detections = sv.Detections(
222
+ xyxy=sv.mask_to_xyxy(masks), # (n, 4)
223
+ mask=masks, # (n, h, w)
224
+ class_id=np.array(object_ids, dtype=np.int32),
225
+ )
226
+ box_annotator = sv.BoxAnnotator()
227
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
228
+ label_annotator = sv.LabelAnnotator()
229
+ annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
230
+ mask_annotator = sv.MaskAnnotator()
231
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
232
+ cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
233
+
234
+
235
+ """
236
+ Step 6: Convert the annotated frames to video
237
+ """
238
+
239
+ create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import supervision as sv
6
+ from PIL import Image
7
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
10
+ from utils.track_utils import sample_points_from_masks
11
+ from utils.video_utils import create_video_from_images
12
+ from utils.common_utils import CommonUtils
13
+ from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
14
+ import json
15
+ import copy
16
+
17
+ """
18
+ Step 1: Environment settings and model initialization
19
+ """
20
+ # use bfloat16 for the entire notebook
21
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
22
+
23
+ if torch.cuda.get_device_properties(0).major >= 8:
24
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+ torch.backends.cudnn.allow_tf32 = True
27
+
28
+ # init sam image predictor and video predictor model
29
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
30
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ print("device", device)
33
+
34
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
35
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
36
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
37
+
38
+
39
+ # init grounding dino model from huggingface
40
+ model_id = "IDEA-Research/grounding-dino-tiny"
41
+ processor = AutoProcessor.from_pretrained(model_id)
42
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
43
+
44
+
45
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
46
+ # VERY important: text queries need to be lowercased + end with a dot
47
+ text = "car."
48
+
49
+ # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
50
+ video_dir = "notebooks/videos/car"
51
+ # 'output_dir' is the directory to save the annotated frames
52
+ output_dir = "./outputs"
53
+ # 'output_video_path' is the path to save the final video
54
+ output_video_path = "./outputs/output.mp4"
55
+ # create the output directory
56
+ CommonUtils.creat_dirs(output_dir)
57
+ mask_data_dir = os.path.join(output_dir, "mask_data")
58
+ json_data_dir = os.path.join(output_dir, "json_data")
59
+ result_dir = os.path.join(output_dir, "result")
60
+ CommonUtils.creat_dirs(mask_data_dir)
61
+ CommonUtils.creat_dirs(json_data_dir)
62
+ # scan all the JPEG frame names in this directory
63
+ frame_names = [
64
+ p for p in os.listdir(video_dir)
65
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
66
+ ]
67
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
68
+
69
+ # init video predictor state
70
+ inference_state = video_predictor.init_state(video_path=video_dir, offload_video_to_cpu=True, async_loading_frames=True)
71
+ step = 20 # the step to sample frames for Grounding DINO predictor
72
+
73
+ sam2_masks = MaskDictionaryModel()
74
+ PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
75
+ objects_count = 0
76
+
77
+ """
78
+ Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
79
+ """
80
+ print("Total frames:", len(frame_names))
81
+ for start_frame_idx in range(0, len(frame_names), step):
82
+ # prompt grounding dino to get the box coordinates on specific frame
83
+ print("start_frame_idx", start_frame_idx)
84
+ # continue
85
+ img_path = os.path.join(video_dir, frame_names[start_frame_idx])
86
+ image = Image.open(img_path)
87
+ image_base_name = frame_names[start_frame_idx].split(".")[0]
88
+ mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
89
+
90
+ # run Grounding DINO on the image
91
+ inputs = processor(images=image, text=text, return_tensors="pt").to(device)
92
+ with torch.no_grad():
93
+ outputs = grounding_model(**inputs)
94
+
95
+ results = processor.post_process_grounded_object_detection(
96
+ outputs,
97
+ inputs.input_ids,
98
+ box_threshold=0.25,
99
+ text_threshold=0.25,
100
+ target_sizes=[image.size[::-1]]
101
+ )
102
+
103
+ # prompt SAM image predictor to get the mask for the object
104
+ image_predictor.set_image(np.array(image.convert("RGB")))
105
+
106
+ # process the detection results
107
+ input_boxes = results[0]["boxes"] # .cpu().numpy()
108
+ # print("results[0]",results[0])
109
+ OBJECTS = results[0]["labels"]
110
+ if input_boxes.shape[0] != 0:
111
+ # prompt SAM 2 image predictor to get the mask for the object
112
+ masks, scores, logits = image_predictor.predict(
113
+ point_coords=None,
114
+ point_labels=None,
115
+ box=input_boxes,
116
+ multimask_output=False,
117
+ )
118
+ # convert the mask shape to (n, H, W)
119
+ if masks.ndim == 2:
120
+ masks = masks[None]
121
+ scores = scores[None]
122
+ logits = logits[None]
123
+ elif masks.ndim == 4:
124
+ masks = masks.squeeze(1)
125
+
126
+ """
127
+ Step 3: Register each object's positive points to video predictor
128
+ """
129
+
130
+ # If you are using point prompts, we uniformly sample positive points based on the mask
131
+ if mask_dict.promote_type == "mask":
132
+ mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
133
+ else:
134
+ raise NotImplementedError("SAM 2 video predictor only support mask prompts")
135
+
136
+
137
+ """
138
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
139
+ """
140
+ objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
141
+ print("objects_count", objects_count)
142
+ else:
143
+ print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
144
+ mask_dict = sam2_masks
145
+
146
+
147
+ if len(mask_dict.labels) == 0:
148
+ mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
149
+ print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
150
+ continue
151
+ else:
152
+ video_predictor.reset_state(inference_state)
153
+
154
+ for object_id, object_info in mask_dict.labels.items():
155
+ frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
156
+ inference_state,
157
+ start_frame_idx,
158
+ object_id,
159
+ object_info.mask,
160
+ )
161
+
162
+ video_segments = {} # output the following {step} frames tracking masks
163
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
164
+ frame_masks = MaskDictionaryModel()
165
+
166
+ for i, out_obj_id in enumerate(out_obj_ids):
167
+ out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
168
+ object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
169
+ object_info.update_box()
170
+ frame_masks.labels[out_obj_id] = object_info
171
+ image_base_name = frame_names[out_frame_idx].split(".")[0]
172
+ frame_masks.mask_name = f"mask_{image_base_name}.npy"
173
+ frame_masks.mask_height = out_mask.shape[-2]
174
+ frame_masks.mask_width = out_mask.shape[-1]
175
+
176
+ video_segments[out_frame_idx] = frame_masks
177
+ sam2_masks = copy.deepcopy(frame_masks)
178
+
179
+ print("video_segments:", len(video_segments))
180
+ """
181
+ Step 5: save the tracking masks and json files
182
+ """
183
+ for frame_idx, frame_masks_info in video_segments.items():
184
+ mask = frame_masks_info.labels
185
+ mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
186
+ for obj_id, obj_info in mask.items():
187
+ mask_img[obj_info.mask == True] = obj_id
188
+
189
+ mask_img = mask_img.numpy().astype(np.uint16)
190
+ np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
191
+
192
+ json_data = frame_masks_info.to_dict()
193
+ json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
194
+ with open(json_data_path, "w") as f:
195
+ json.dump(json_data, f)
196
+
197
+
198
+ """
199
+ Step 6: Draw the results and save the video
200
+ """
201
+ CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
202
+
203
+ create_video_from_images(result_dir, output_video_path, frame_rate=15)
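For reference, Step 5 above writes one uint16 mask array per frame (pixel value = object id, 0 = background) plus a JSON metadata file produced from MaskDictionaryModel.to_dict(). A minimal sketch of reading a single frame's outputs back, assuming the same output_dir and the mask_{image_base_name} naming used above (the frame name "00000" is a hypothetical placeholder):

import os, json
import numpy as np

output_dir = "./outputs"                        # same output_dir as in the script above
image_base_name = "00000"                       # hypothetical frame name
mask_path = os.path.join(output_dir, "mask_data", f"mask_{image_base_name}.npy")
json_path = os.path.join(output_dir, "json_data", f"mask_{image_base_name}.json")

mask_img = np.load(mask_path)                   # (H, W) uint16, pixel value = object id
with open(json_path) as f:
    frame_info = json.load(f)                   # dict saved from MaskDictionaryModel.to_dict()

object_ids = np.unique(mask_img)
object_ids = object_ids[object_ids != 0]        # drop background
print(object_ids, list(frame_info.keys()))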
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py ADDED
@@ -0,0 +1,224 @@
1
+ # dds cloudapi for Grounding DINO 1.5
2
+ from dds_cloudapi_sdk import Config
3
+ from dds_cloudapi_sdk import Client
4
+ from dds_cloudapi_sdk import DetectionTask
5
+ from dds_cloudapi_sdk import TextPrompt
6
+ from dds_cloudapi_sdk import DetectionModel
7
+ from dds_cloudapi_sdk import DetectionTarget
8
+
9
+
10
+ import os
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
17
+ from utils.video_utils import create_video_from_images
18
+ from utils.common_utils import CommonUtils
19
+ from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
20
+ import json
21
+ import copy
22
+
23
+ """
24
+ Step 1: Environment settings and model initialization
25
+ """
26
+ # use bfloat16 for the entire notebook
27
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
28
+
29
+ if torch.cuda.get_device_properties(0).major >= 8:
30
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
31
+ torch.backends.cuda.matmul.allow_tf32 = True
32
+ torch.backends.cudnn.allow_tf32 = True
33
+
34
+ # init sam image predictor and video predictor model
35
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
36
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ print("device", device)
39
+
40
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
41
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
42
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
43
+
44
+
45
+ # init grounding dino model from huggingface
46
+ model_id = "IDEA-Research/grounding-dino-tiny"
47
+ processor = AutoProcessor.from_pretrained(model_id)
48
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
49
+
50
+
51
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
52
+ # VERY important: text queries need to be lowercased + end with a dot
53
+ text = "car."
54
+
55
+ # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
56
+ video_dir = "notebooks/videos/car"
57
+ # 'output_dir' is the directory to save the annotated frames
58
+ output_dir = "./outputs"
59
+ # 'output_video_path' is the path to save the final video
60
+ output_video_path = "./outputs/output.mp4"
61
+ # create the output directory
62
+ CommonUtils.creat_dirs(output_dir)
63
+ mask_data_dir = os.path.join(output_dir, "mask_data")
64
+ json_data_dir = os.path.join(output_dir, "json_data")
65
+ result_dir = os.path.join(output_dir, "result")
66
+ CommonUtils.creat_dirs(mask_data_dir)
67
+ CommonUtils.creat_dirs(json_data_dir)
68
+ # scan all the JPEG frame names in this directory
69
+ frame_names = [
70
+ p for p in os.listdir(video_dir)
71
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
72
+ ]
73
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
74
+
75
+ # init video predictor state
76
+ inference_state = video_predictor.init_state(video_path=video_dir)
77
+ step = 10 # the step to sample frames for Grounding DINO predictor
78
+
79
+ sam2_masks = MaskDictionaryModel()
80
+ PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
81
+ objects_count = 0
82
+
83
+ """
84
+ Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
85
+ """
86
+ print("Total frames:", len(frame_names))
87
+ for start_frame_idx in range(0, len(frame_names), step):
88
+ # prompt grounding dino to get the box coordinates on specific frame
89
+ print("start_frame_idx", start_frame_idx)
90
+ # continue
91
+ img_path = os.path.join(video_dir, frame_names[start_frame_idx])
92
+ image = Image.open(img_path)
93
+ image_base_name = frame_names[start_frame_idx].split(".")[0]
94
+ mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
95
+
96
+ # run Grounding DINO 1.5 on the image
97
+
98
+ API_TOKEN_FOR_GD1_5 = "Your API token"
99
+
100
+ config = Config(API_TOKEN_FOR_GD1_5)
101
+ # Step 2: initialize the client
102
+ client = Client(config)
103
+
104
+ image_url = client.upload_file(img_path)
105
+ task = DetectionTask(
106
+ image_url=image_url,
107
+ prompts=[TextPrompt(text=text)],
108
+ targets=[DetectionTarget.BBox], # detect bbox
109
+ model=DetectionModel.GDino1_6_Pro, # detect with the GroundingDino-1.6-Pro model
110
+ )
111
+ client.run_task(task)
112
+ result = task.result
113
+
114
+ objects = result.objects # the list of detected objects
115
+ input_boxes = []
116
+ confidences = []
117
+ class_names = []
118
+
119
+ for idx, obj in enumerate(objects):
120
+ input_boxes.append(obj.bbox)
121
+ confidences.append(obj.score)
122
+ class_names.append(obj.category)
123
+
124
+ input_boxes = np.array(input_boxes)
125
+ OBJECTS = class_names
126
+ if input_boxes.shape[0] != 0:
127
+ # prompt SAM image predictor to get the mask for the object
128
+ image_predictor.set_image(np.array(image.convert("RGB")))
129
+
130
+ # prompt SAM 2 image predictor to get the mask for the object
131
+ masks, scores, logits = image_predictor.predict(
132
+ point_coords=None,
133
+ point_labels=None,
134
+ box=input_boxes,
135
+ multimask_output=False,
136
+ )
137
+ # convert the mask shape to (n, H, W)
138
+ if masks.ndim == 2:
139
+ masks = masks[None]
140
+ scores = scores[None]
141
+ logits = logits[None]
142
+ elif masks.ndim == 4:
143
+ masks = masks.squeeze(1)
144
+
145
+ """
146
+ Step 3: Register each object's positive points to video predictor
147
+ """
148
+
149
+ # If you are using point prompts, we uniformly sample positive points based on the mask
150
+ if mask_dict.promote_type == "mask":
151
+ mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
152
+ else:
153
+ raise NotImplementedError("SAM 2 video predictor only supports mask prompts")
154
+
155
+
156
+
157
+ objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
158
+ print("objects_count", objects_count)
159
+
160
+ else:
161
+ print("No object detected in the frame, skipping merge for frame {}".format(frame_names[start_frame_idx]))
162
+ mask_dict = sam2_masks
163
+
164
+
165
+ """
166
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
167
+ """
168
+ if len(mask_dict.labels) == 0:
169
+ mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
170
+ print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
171
+ continue
172
+ else:
173
+ video_predictor.reset_state(inference_state)
174
+
175
+ for object_id, object_info in mask_dict.labels.items():
176
+ frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
177
+ inference_state,
178
+ start_frame_idx,
179
+ object_id,
180
+ object_info.mask,
181
+ )
182
+
183
+ video_segments = {} # output the following {step} frames tracking masks
184
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
185
+ frame_masks = MaskDictionaryModel()
186
+
187
+ for i, out_obj_id in enumerate(out_obj_ids):
188
+ out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
189
+ object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
190
+ object_info.update_box()
191
+ frame_masks.labels[out_obj_id] = object_info
192
+ image_base_name = frame_names[out_frame_idx].split(".")[0]
193
+ frame_masks.mask_name = f"mask_{image_base_name}.npy"
194
+ frame_masks.mask_height = out_mask.shape[-2]
195
+ frame_masks.mask_width = out_mask.shape[-1]
196
+
197
+ video_segments[out_frame_idx] = frame_masks
198
+ sam2_masks = copy.deepcopy(frame_masks)
199
+
200
+ print("video_segments:", len(video_segments))
201
+ """
202
+ Step 5: save the tracking masks and json files
203
+ """
204
+ for frame_idx, frame_masks_info in video_segments.items():
205
+ mask = frame_masks_info.labels
206
+ mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
207
+ for obj_id, obj_info in mask.items():
208
+ mask_img[obj_info.mask == True] = obj_id
209
+
210
+ mask_img = mask_img.numpy().astype(np.uint16)
211
+ np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
212
+
213
+ json_data = frame_masks_info.to_dict()
214
+ json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
215
+ with open(json_data_path, "w") as f:
216
+ json.dump(json_data, f)
217
+
218
+
219
+ """
220
+ Step 6: Draw the results and save the video
221
+ """
222
+ CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
223
+
224
+ create_video_from_images(result_dir, output_video_path, frame_rate=30)
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_plus.py ADDED
@@ -0,0 +1,247 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import supervision as sv
6
+ from PIL import Image
7
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
10
+ from utils.track_utils import sample_points_from_masks
11
+ from utils.video_utils import create_video_from_images
12
+ from utils.common_utils import CommonUtils
13
+ from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
14
+ import json
15
+ import copy
16
+
17
+ # This demo shows the continuous object tracking plus reverse tracking with Grounding DINO and SAM 2
18
+ """
19
+ Step 1: Environment settings and model initialization
20
+ """
21
+ # use bfloat16 for the entire notebook
22
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
23
+
24
+ if torch.cuda.get_device_properties(0).major >= 8:
25
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.allow_tf32 = True
28
+
29
+ # init sam image predictor and video predictor model
30
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
31
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ print("device", device)
34
+
35
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
36
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
37
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
38
+
39
+
40
+ # init grounding dino model from huggingface
41
+ model_id = "IDEA-Research/grounding-dino-tiny"
42
+ processor = AutoProcessor.from_pretrained(model_id)
43
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
44
+
45
+
46
+ # setup the input image and text prompt for SAM 2 and Grounding DINO
47
+ # VERY important: text queries need to be lowercased + end with a dot
48
+ text = "car."
49
+
50
+ # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
51
+ video_dir = "notebooks/videos/car"
52
+ # 'output_dir' is the directory to save the annotated frames
53
+ output_dir = "outputs"
54
+ # 'output_video_path' is the path to save the final video
55
+ output_video_path = "./outputs/output.mp4"
56
+ # create the output directory
57
+ mask_data_dir = os.path.join(output_dir, "mask_data")
58
+ json_data_dir = os.path.join(output_dir, "json_data")
59
+ result_dir = os.path.join(output_dir, "result")
60
+ CommonUtils.creat_dirs(mask_data_dir)
61
+ CommonUtils.creat_dirs(json_data_dir)
62
+ # scan all the JPEG frame names in this directory
63
+ frame_names = [
64
+ p for p in os.listdir(video_dir)
65
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
66
+ ]
67
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
68
+
69
+ # init video predictor state
70
+ inference_state = video_predictor.init_state(video_path=video_dir)
71
+ step = 20 # the step to sample frames for Grounding DINO predictor
72
+
73
+ sam2_masks = MaskDictionaryModel()
74
+ PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
75
+ objects_count = 0
76
+ frame_object_count = {}
77
+ """
78
+ Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
79
+ """
80
+ print("Total frames:", len(frame_names))
81
+ for start_frame_idx in range(0, len(frame_names), step):
82
+ # prompt grounding dino to get the box coordinates on specific frame
83
+ print("start_frame_idx", start_frame_idx)
84
+ # continue
85
+ img_path = os.path.join(video_dir, frame_names[start_frame_idx])
86
+ image = Image.open(img_path).convert("RGB")
87
+ image_base_name = frame_names[start_frame_idx].split(".")[0]
88
+ mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
89
+
90
+ # run Grounding DINO on the image
91
+ inputs = processor(images=image, text=text, return_tensors="pt").to(device)
92
+ with torch.no_grad():
93
+ outputs = grounding_model(**inputs)
94
+
95
+ results = processor.post_process_grounded_object_detection(
96
+ outputs,
97
+ inputs.input_ids,
98
+ box_threshold=0.25,
99
+ text_threshold=0.25,
100
+ target_sizes=[image.size[::-1]]
101
+ )
102
+
103
+ # prompt SAM image predictor to get the mask for the object
104
+ image_predictor.set_image(np.array(image.convert("RGB")))
105
+
106
+ # process the detection results
107
+ input_boxes = results[0]["boxes"] # .cpu().numpy()
108
+ # print("results[0]",results[0])
109
+ OBJECTS = results[0]["labels"]
110
+ if input_boxes.shape[0] != 0:
111
+
112
+ # prompt SAM 2 image predictor to get the mask for the object
113
+ masks, scores, logits = image_predictor.predict(
114
+ point_coords=None,
115
+ point_labels=None,
116
+ box=input_boxes,
117
+ multimask_output=False,
118
+ )
119
+ # convert the mask shape to (n, H, W)
120
+ if masks.ndim == 2:
121
+ masks = masks[None]
122
+ scores = scores[None]
123
+ logits = logits[None]
124
+ elif masks.ndim == 4:
125
+ masks = masks.squeeze(1)
126
+ """
127
+ Step 3: Register each object's positive points to video predictor
128
+ """
129
+
130
+ # If you are using point prompts, we uniformly sample positive points based on the mask
131
+ if mask_dict.promote_type == "mask":
132
+ mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
133
+ else:
134
+ raise NotImplementedError("SAM 2 video predictor only supports mask prompts")
135
+ else:
136
+ print("No object detected in the frame, skipping merge for frame {}".format(frame_names[start_frame_idx]))
137
+ mask_dict = sam2_masks
138
+
139
+ """
140
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
141
+ """
142
+ objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
143
+ frame_object_count[start_frame_idx] = objects_count
144
+ print("objects_count", objects_count)
145
+
146
+ if len(mask_dict.labels) == 0:
147
+ mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
148
+ print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
149
+ continue
150
+ else:
151
+ video_predictor.reset_state(inference_state)
152
+
153
+ for object_id, object_info in mask_dict.labels.items():
154
+ frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
155
+ inference_state,
156
+ start_frame_idx,
157
+ object_id,
158
+ object_info.mask,
159
+ )
160
+
161
+ video_segments = {} # output the following {step} frames tracking masks
162
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
163
+ frame_masks = MaskDictionaryModel()
164
+
165
+ for i, out_obj_id in enumerate(out_obj_ids):
166
+ out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
167
+ object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
168
+ object_info.update_box()
169
+ frame_masks.labels[out_obj_id] = object_info
170
+ image_base_name = frame_names[out_frame_idx].split(".")[0]
171
+ frame_masks.mask_name = f"mask_{image_base_name}.npy"
172
+ frame_masks.mask_height = out_mask.shape[-2]
173
+ frame_masks.mask_width = out_mask.shape[-1]
174
+
175
+ video_segments[out_frame_idx] = frame_masks
176
+ sam2_masks = copy.deepcopy(frame_masks)
177
+
178
+ print("video_segments:", len(video_segments))
179
+ """
180
+ Step 5: save the tracking masks and json files
181
+ """
182
+ for frame_idx, frame_masks_info in video_segments.items():
183
+ mask = frame_masks_info.labels
184
+ mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
185
+ for obj_id, obj_info in mask.items():
186
+ mask_img[obj_info.mask == True] = obj_id
187
+
188
+ mask_img = mask_img.numpy().astype(np.uint16)
189
+ np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
190
+
191
+ json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
192
+ frame_masks_info.to_json(json_data_path)
193
+
194
+
195
+ CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
196
+
197
+ print("try reverse tracking")
198
+ start_object_id = 0
199
+ object_info_dict = {}
200
+ for frame_idx, current_object_count in frame_object_count.items():
201
+ print("reverse tracking frame", frame_idx, frame_names[frame_idx])
202
+ if frame_idx != 0:
203
+ video_predictor.reset_state(inference_state)
204
+ image_base_name = frame_names[frame_idx].split(".")[0]
205
+ json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json")
206
+ json_data = MaskDictionaryModel().from_json(json_data_path)
207
+ mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy")
208
+ mask_array = np.load(mask_data_path)
209
+ for object_id in range(start_object_id+1, current_object_count+1):
210
+ print("reverse tracking object", object_id)
211
+ object_info_dict[object_id] = json_data.labels[object_id]
212
+ video_predictor.add_new_mask(inference_state, frame_idx, object_id, mask_array == object_id)
213
+ start_object_id = current_object_count
214
+
215
+
216
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step*2, start_frame_idx=frame_idx, reverse=True):
217
+ image_base_name = frame_names[out_frame_idx].split(".")[0]
218
+ json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json")
219
+ json_data = MaskDictionaryModel().from_json(json_data_path)
220
+ mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy")
221
+ mask_array = np.load(mask_data_path)
222
+ # merge the reverse tracking masks with the original masks
223
+ for i, out_obj_id in enumerate(out_obj_ids):
224
+ out_mask = (out_mask_logits[i] > 0.0).cpu()
225
+ if out_mask.sum() == 0:
226
+ print("no mask for object", out_obj_id, "at frame", out_frame_idx)
227
+ continue
228
+ object_info = object_info_dict[out_obj_id]
229
+ object_info.mask = out_mask[0]
230
+ object_info.update_box()
231
+ json_data.labels[out_obj_id] = object_info
232
+ mask_array = np.where(mask_array != out_obj_id, mask_array, 0)
233
+ mask_array[object_info.mask] = out_obj_id
234
+
235
+ np.save(mask_data_path, mask_array)
236
+ json_data.to_json(json_data_path)
237
+
238
+
239
+
240
+
241
+
242
+ """
243
+ Step 6: Draw the results and save the video
244
+ """
245
+ CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir+"_reverse")
246
+
247
+ create_video_from_images(result_dir, output_video_path, frame_rate=15)
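The reverse-tracking merge above clears an object's old pixels from the per-frame id map and then writes its newly propagated footprint back in. A toy numpy illustration of those two update lines (the 4x4 arrays are hypothetical, not data from the repo):

import numpy as np

mask_array = np.array([[0, 0, 2, 2],
                       [1, 1, 2, 2],
                       [1, 1, 0, 0],
                       [0, 0, 0, 0]], dtype=np.uint16)   # pixel value = object id
out_obj_id = 2
new_mask = np.zeros_like(mask_array, dtype=bool)
new_mask[0:2, 1:3] = True                                 # where the reverse pass now places object 2

mask_array = np.where(mask_array != out_obj_id, mask_array, 0)  # clear object 2's old pixels
mask_array[new_mask] = out_obj_id                               # write its new footprint (overwrites overlaps)
print(mask_array)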
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_gd1.5.py ADDED
@@ -0,0 +1,221 @@
1
+ # dds cloudapi for Grounding DINO 1.5
2
+ from dds_cloudapi_sdk import Config
3
+ from dds_cloudapi_sdk import Client
4
+ from dds_cloudapi_sdk import DetectionTask
5
+ from dds_cloudapi_sdk import TextPrompt
6
+ from dds_cloudapi_sdk import DetectionModel
7
+ from dds_cloudapi_sdk import DetectionTarget
8
+
9
+ import os
10
+ import cv2
11
+ import torch
12
+ import numpy as np
13
+ import supervision as sv
14
+ from PIL import Image
15
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
16
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
17
+ from utils.track_utils import sample_points_from_masks
18
+ from utils.video_utils import create_video_from_images
19
+
20
+
21
+ """
22
+ Step 1: Environment settings and model initialization for SAM 2
23
+ """
24
+ # use bfloat16 for the entire notebook
25
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
26
+
27
+ if torch.cuda.get_device_properties(0).major >= 8:
28
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
29
+ torch.backends.cuda.matmul.allow_tf32 = True
30
+ torch.backends.cudnn.allow_tf32 = True
31
+
32
+ # init sam image predictor and video predictor model
33
+ sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
34
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
35
+
36
+ video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
37
+ sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
38
+ image_predictor = SAM2ImagePredictor(sam2_image_model)
39
+
40
+
41
+ # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
42
+ video_dir = "notebooks/videos/bedroom"
43
+
44
+ # scan all the JPEG frame names in this directory
45
+ frame_names = [
46
+ p for p in os.listdir(video_dir)
47
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
48
+ ]
49
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
50
+
51
+ # init video predictor state
52
+ inference_state = video_predictor.init_state(video_path=video_dir)
53
+
54
+ ann_frame_idx = 0 # the frame index we interact with
55
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
56
+
57
+
58
+ """
59
+ Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
60
+ """
61
+
62
+ # prompt grounding dino to get the box coordinates on specific frame
63
+ img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
64
+ image = Image.open(img_path)
65
+
66
+ # Step 1: initialize the config
67
+ token = "Your API token"
68
+ config = Config(token)
69
+
70
+ # Step 2: initialize the client
71
+ client = Client(config)
72
+
73
+ # Step 3: run the task by DetectionTask class
74
+ # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
75
+ # if you are processing local image file, upload them to DDS server to get the image url
76
+ image_url = client.upload_file(img_path)
77
+
78
+ task = DetectionTask(
79
+ image_url=image_url,
80
+ prompts=[TextPrompt(text="children. pillow")],
81
+ targets=[DetectionTarget.BBox], # detect bbox
82
+ model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
83
+ bbox_threshold=0.2,
84
+ )
85
+
86
+ client.run_task(task)
87
+ result = task.result
88
+
89
+ objects = result.objects # the list of detected objects
90
+
91
+
92
+ input_boxes = []
93
+ confidences = []
94
+ class_names = []
95
+
96
+ for idx, obj in enumerate(objects):
97
+ input_boxes.append(obj.bbox)
98
+ confidences.append(obj.score)
99
+ class_names.append(obj.category)
100
+
101
+ input_boxes = np.array(input_boxes)
102
+
103
+ print(input_boxes)
104
+
105
+ # prompt SAM image predictor to get the mask for the object
106
+ image_predictor.set_image(np.array(image.convert("RGB")))
107
+
108
+ # process the detection results
109
+ OBJECTS = class_names
110
+
111
+ print(OBJECTS)
112
+
113
+ # prompt SAM 2 image predictor to get the mask for the object
114
+ masks, scores, logits = image_predictor.predict(
115
+ point_coords=None,
116
+ point_labels=None,
117
+ box=input_boxes,
118
+ multimask_output=False,
119
+ )
120
+
121
+ # convert the mask shape to (n, H, W)
122
+ if masks.ndim == 2:
123
+ masks = masks[None]
124
+ scores = scores[None]
125
+ logits = logits[None]
126
+ elif masks.ndim == 4:
127
+ masks = masks.squeeze(1)
128
+
129
+ """
130
+ Step 3: Register each object's positive points to video predictor with separate add_new_points call
131
+ """
132
+
133
+ PROMPT_TYPE_FOR_VIDEO = "box" # or "point"
134
+
135
+ assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only supports point/box/mask prompts"
136
+
137
+ # If you are using point prompts, we uniformly sample positive points based on the mask
138
+ if PROMPT_TYPE_FOR_VIDEO == "point":
139
+ # sample the positive points from the mask for each object
140
+ all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
141
+
142
+ for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
143
+ labels = np.ones((points.shape[0]), dtype=np.int32)
144
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
145
+ inference_state=inference_state,
146
+ frame_idx=ann_frame_idx,
147
+ obj_id=object_id,
148
+ points=points,
149
+ labels=labels,
150
+ )
151
+ # Using box prompt
152
+ elif PROMPT_TYPE_FOR_VIDEO == "box":
153
+ for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
154
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
155
+ inference_state=inference_state,
156
+ frame_idx=ann_frame_idx,
157
+ obj_id=object_id,
158
+ box=box,
159
+ )
160
+ # Using mask prompt is a more straightforward way
161
+ elif PROMPT_TYPE_FOR_VIDEO == "mask":
162
+ for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
163
+ labels = np.ones((1), dtype=np.int32)
164
+ _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
165
+ inference_state=inference_state,
166
+ frame_idx=ann_frame_idx,
167
+ obj_id=object_id,
168
+ mask=mask
169
+ )
170
+ else:
171
+ raise NotImplementedError("SAM 2 video predictor only supports point/box/mask prompts")
172
+
173
+
174
+
175
+ """
176
+ Step 4: Propagate the video predictor to get the segmentation results for each frame
177
+ """
178
+ video_segments = {} # video_segments contains the per-frame segmentation results
179
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
180
+ video_segments[out_frame_idx] = {
181
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
182
+ for i, out_obj_id in enumerate(out_obj_ids)
183
+ }
184
+
185
+ """
186
+ Step 5: Visualize the segment results across the video and save them
187
+ """
188
+
189
+ save_dir = "./tracking_results"
190
+
191
+ if not os.path.exists(save_dir):
192
+ os.makedirs(save_dir)
193
+
194
+ ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
195
+ for frame_idx, segments in video_segments.items():
196
+ img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
197
+
198
+ object_ids = list(segments.keys())
199
+ masks = list(segments.values())
200
+ masks = np.concatenate(masks, axis=0)
201
+
202
+ detections = sv.Detections(
203
+ xyxy=sv.mask_to_xyxy(masks), # (n, 4)
204
+ mask=masks, # (n, h, w)
205
+ class_id=np.array(object_ids, dtype=np.int32),
206
+ )
207
+ box_annotator = sv.BoxAnnotator()
208
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
209
+ label_annotator = sv.LabelAnnotator()
210
+ annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
211
+ mask_annotator = sv.MaskAnnotator()
212
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
213
+ cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
214
+
215
+
216
+ """
217
+ Step 6: Convert the annotated frames to video
218
+ """
219
+
220
+ output_video_path = "./children_tracking_demo_video.mp4"
221
+ create_video_from_images(save_dir, output_video_path)
clone-IDEA-Research/Grounded-SAM-2/pyproject.toml ADDED
@@ -0,0 +1,6 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=61.0",
4
+ "torch>=2.3.1",
5
+ ]
6
+ build-backend = "setuptools.build_meta"
clone-IDEA-Research/Grounded-SAM-2/sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from hydra import initialize_config_module
8
+ from hydra.core.global_hydra import GlobalHydra
9
+
10
+ if not GlobalHydra.instance().is_initialized():
11
+ initialize_config_module("sam2", version_base="1.2")
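Because importing `sam2` initializes its Hydra config module (above), `build_sam2` can resolve relative config names. A minimal sketch using the same config/checkpoint pair that the demo scripts in this repo pass in:

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_l.yaml",   # resolved through the Hydra config module
    "./checkpoints/sam2.1_hiera_large.pt",  # local checkpoint path used by the demos
    device=device,
)
image_predictor = SAM2ImagePredictor(sam2_model)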
clone-IDEA-Research/Grounded-SAM-2/sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,454 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from sam2.modeling.sam2_base import SAM2Base
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from sam2.utils.amg import (
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ MaskData,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ **kwargs,
57
+ ) -> None:
58
+ """
59
+ Using a SAM 2 model, generates masks for the entire image.
60
+ Generates a grid of point prompts over the image, then filters
61
+ low quality and duplicate masks. The default settings are chosen
62
+ for SAM 2 with a HieraL backbone.
63
+
64
+ Arguments:
65
+ model (Sam): The SAM 2 model to use for mask prediction.
66
+ points_per_side (int or None): The number of points to be sampled
67
+ along one side of the image. The total number of points is
68
+ points_per_side**2. If None, 'point_grids' must provide explicit
69
+ point sampling.
70
+ points_per_batch (int): Sets the number of points run simultaneously
71
+ by the model. Higher numbers may be faster but use more GPU memory.
72
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
73
+ model's predicted mask quality.
74
+ stability_score_thresh (float): A filtering threshold in [0,1], using
75
+ the stability of the mask under changes to the cutoff used to binarize
76
+ the model's mask predictions.
77
+ stability_score_offset (float): The amount to shift the cutoff when
78
+ calculated the stability score.
79
+ mask_threshold (float): Threshold for binarizing the mask logits
80
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
81
+ suppression to filter duplicate masks.
82
+ crop_n_layers (int): If >0, mask prediction will be run again on
83
+ crops of the image. Sets the number of layers to run, where each
84
+ layer has 2**i_layer number of image crops.
85
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
86
+ suppression to filter duplicate masks between different crops.
87
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
88
+ In the first crop layer, crops will overlap by this fraction of
89
+ the image length. Later layers with more crops scale down this overlap.
90
+ crop_n_points_downscale_factor (int): The number of points-per-side
91
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
92
+ point_grids (list(np.ndarray) or None): A list over explicit grids
93
+ of points used for sampling, normalized to [0,1]. The nth grid in the
94
+ list is used in the nth crop layer. Exclusive with points_per_side.
95
+ min_mask_region_area (int): If >0, postprocessing will be applied
96
+ to remove disconnected regions and holes in masks with area smaller
97
+ than min_mask_region_area. Requires opencv.
98
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
99
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
100
+ For large resolutions, 'binary_mask' may consume large amounts of
101
+ memory.
102
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
103
+ multimask_output (bool): Whether to output multimask at each point of the grid.
104
+ """
105
+
106
+ assert (points_per_side is None) != (
107
+ point_grids is None
108
+ ), "Exactly one of points_per_side or point_grid must be provided."
109
+ if points_per_side is not None:
110
+ self.point_grids = build_all_layer_point_grids(
111
+ points_per_side,
112
+ crop_n_layers,
113
+ crop_n_points_downscale_factor,
114
+ )
115
+ elif point_grids is not None:
116
+ self.point_grids = point_grids
117
+ else:
118
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
119
+
120
+ assert output_mode in [
121
+ "binary_mask",
122
+ "uncompressed_rle",
123
+ "coco_rle",
124
+ ], f"Unknown output_mode {output_mode}."
125
+ if output_mode == "coco_rle":
126
+ try:
127
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
128
+ except ImportError as e:
129
+ print("Please install pycocotools")
130
+ raise e
131
+
132
+ self.predictor = SAM2ImagePredictor(
133
+ model,
134
+ max_hole_area=min_mask_region_area,
135
+ max_sprinkle_area=min_mask_region_area,
136
+ )
137
+ self.points_per_batch = points_per_batch
138
+ self.pred_iou_thresh = pred_iou_thresh
139
+ self.stability_score_thresh = stability_score_thresh
140
+ self.stability_score_offset = stability_score_offset
141
+ self.mask_threshold = mask_threshold
142
+ self.box_nms_thresh = box_nms_thresh
143
+ self.crop_n_layers = crop_n_layers
144
+ self.crop_nms_thresh = crop_nms_thresh
145
+ self.crop_overlap_ratio = crop_overlap_ratio
146
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
147
+ self.min_mask_region_area = min_mask_region_area
148
+ self.output_mode = output_mode
149
+ self.use_m2m = use_m2m
150
+ self.multimask_output = multimask_output
151
+
152
+ @classmethod
153
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
154
+ """
155
+ Load a pretrained model from the Hugging Face hub.
156
+
157
+ Arguments:
158
+ model_id (str): The Hugging Face repository ID.
159
+ **kwargs: Additional arguments to pass to the model constructor.
160
+
161
+ Returns:
162
+ (SAM2AutomaticMaskGenerator): The loaded model.
163
+ """
164
+ from sam2.build_sam import build_sam2_hf
165
+
166
+ sam_model = build_sam2_hf(model_id, **kwargs)
167
+ return cls(sam_model, **kwargs)
168
+
169
+ @torch.no_grad()
170
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
171
+ """
172
+ Generates masks for the given image.
173
+
174
+ Arguments:
175
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
176
+
177
+ Returns:
178
+ list(dict(str, any)): A list over records for masks. Each record is
179
+ a dict containing the following keys:
180
+ segmentation (dict(str, any) or np.ndarray): The mask. If
181
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
182
+ is a dictionary containing the RLE.
183
+ bbox (list(float)): The box around the mask, in XYWH format.
184
+ area (int): The area in pixels of the mask.
185
+ predicted_iou (float): The model's own prediction of the mask's
186
+ quality. This is filtered by the pred_iou_thresh parameter.
187
+ point_coords (list(list(float))): The point coordinates input
188
+ to the model to generate this mask.
189
+ stability_score (float): A measure of the mask's quality. This
190
+ is filtered on using the stability_score_thresh parameter.
191
+ crop_box (list(float)): The crop of the image used to generate
192
+ the mask, given in XYWH format.
193
+ """
194
+
195
+ # Generate masks
196
+ mask_data = self._generate_masks(image)
197
+
198
+ # Encode masks
199
+ if self.output_mode == "coco_rle":
200
+ mask_data["segmentations"] = [
201
+ coco_encode_rle(rle) for rle in mask_data["rles"]
202
+ ]
203
+ elif self.output_mode == "binary_mask":
204
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
205
+ else:
206
+ mask_data["segmentations"] = mask_data["rles"]
207
+
208
+ # Write mask records
209
+ curr_anns = []
210
+ for idx in range(len(mask_data["segmentations"])):
211
+ ann = {
212
+ "segmentation": mask_data["segmentations"][idx],
213
+ "area": area_from_rle(mask_data["rles"][idx]),
214
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
215
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
216
+ "point_coords": [mask_data["points"][idx].tolist()],
217
+ "stability_score": mask_data["stability_score"][idx].item(),
218
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
219
+ }
220
+ curr_anns.append(ann)
221
+
222
+ return curr_anns
223
+
224
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
225
+ orig_size = image.shape[:2]
226
+ crop_boxes, layer_idxs = generate_crop_boxes(
227
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
228
+ )
229
+
230
+ # Iterate over image crops
231
+ data = MaskData()
232
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
233
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
234
+ data.cat(crop_data)
235
+
236
+ # Remove duplicate masks between crops
237
+ if len(crop_boxes) > 1:
238
+ # Prefer masks from smaller crops
239
+ scores = 1 / box_area(data["crop_boxes"])
240
+ scores = scores.to(data["boxes"].device)
241
+ keep_by_nms = batched_nms(
242
+ data["boxes"].float(),
243
+ scores,
244
+ torch.zeros_like(data["boxes"][:, 0]), # categories
245
+ iou_threshold=self.crop_nms_thresh,
246
+ )
247
+ data.filter(keep_by_nms)
248
+ data.to_numpy()
249
+ return data
250
+
251
+ def _process_crop(
252
+ self,
253
+ image: np.ndarray,
254
+ crop_box: List[int],
255
+ crop_layer_idx: int,
256
+ orig_size: Tuple[int, ...],
257
+ ) -> MaskData:
258
+ # Crop the image and calculate embeddings
259
+ x0, y0, x1, y1 = crop_box
260
+ cropped_im = image[y0:y1, x0:x1, :]
261
+ cropped_im_size = cropped_im.shape[:2]
262
+ self.predictor.set_image(cropped_im)
263
+
264
+ # Get points for this crop
265
+ points_scale = np.array(cropped_im_size)[None, ::-1]
266
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
267
+
268
+ # Generate masks for this crop in batches
269
+ data = MaskData()
270
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
271
+ batch_data = self._process_batch(
272
+ points, cropped_im_size, crop_box, orig_size, normalize=True
273
+ )
274
+ data.cat(batch_data)
275
+ del batch_data
276
+ self.predictor.reset_predictor()
277
+
278
+ # Remove duplicates within this crop.
279
+ keep_by_nms = batched_nms(
280
+ data["boxes"].float(),
281
+ data["iou_preds"],
282
+ torch.zeros_like(data["boxes"][:, 0]), # categories
283
+ iou_threshold=self.box_nms_thresh,
284
+ )
285
+ data.filter(keep_by_nms)
286
+
287
+ # Return to the original image frame
288
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
289
+ data["points"] = uncrop_points(data["points"], crop_box)
290
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
291
+
292
+ return data
293
+
294
+ def _process_batch(
295
+ self,
296
+ points: np.ndarray,
297
+ im_size: Tuple[int, ...],
298
+ crop_box: List[int],
299
+ orig_size: Tuple[int, ...],
300
+ normalize=False,
301
+ ) -> MaskData:
302
+ orig_h, orig_w = orig_size
303
+
304
+ # Run model on this batch
305
+ points = torch.as_tensor(
306
+ points, dtype=torch.float32, device=self.predictor.device
307
+ )
308
+ in_points = self.predictor._transforms.transform_coords(
309
+ points, normalize=normalize, orig_hw=im_size
310
+ )
311
+ in_labels = torch.ones(
312
+ in_points.shape[0], dtype=torch.int, device=in_points.device
313
+ )
314
+ masks, iou_preds, low_res_masks = self.predictor._predict(
315
+ in_points[:, None, :],
316
+ in_labels[:, None],
317
+ multimask_output=self.multimask_output,
318
+ return_logits=True,
319
+ )
320
+
321
+ # Serialize predictions and store in MaskData
322
+ data = MaskData(
323
+ masks=masks.flatten(0, 1),
324
+ iou_preds=iou_preds.flatten(0, 1),
325
+ points=points.repeat_interleave(masks.shape[1], dim=0),
326
+ low_res_masks=low_res_masks.flatten(0, 1),
327
+ )
328
+ del masks
329
+
330
+ if not self.use_m2m:
331
+ # Filter by predicted IoU
332
+ if self.pred_iou_thresh > 0.0:
333
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
334
+ data.filter(keep_mask)
335
+
336
+ # Calculate and filter by stability score
337
+ data["stability_score"] = calculate_stability_score(
338
+ data["masks"], self.mask_threshold, self.stability_score_offset
339
+ )
340
+ if self.stability_score_thresh > 0.0:
341
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
342
+ data.filter(keep_mask)
343
+ else:
344
+ # One step refinement using previous mask predictions
345
+ in_points = self.predictor._transforms.transform_coords(
346
+ data["points"], normalize=normalize, orig_hw=im_size
347
+ )
348
+ labels = torch.ones(
349
+ in_points.shape[0], dtype=torch.int, device=in_points.device
350
+ )
351
+ masks, ious = self.refine_with_m2m(
352
+ in_points, labels, data["low_res_masks"], self.points_per_batch
353
+ )
354
+ data["masks"] = masks.squeeze(1)
355
+ data["iou_preds"] = ious.squeeze(1)
356
+
357
+ if self.pred_iou_thresh > 0.0:
358
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
359
+ data.filter(keep_mask)
360
+
361
+ data["stability_score"] = calculate_stability_score(
362
+ data["masks"], self.mask_threshold, self.stability_score_offset
363
+ )
364
+ if self.stability_score_thresh > 0.0:
365
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
366
+ data.filter(keep_mask)
367
+
368
+ # Threshold masks and calculate boxes
369
+ data["masks"] = data["masks"] > self.mask_threshold
370
+ data["boxes"] = batched_mask_to_box(data["masks"])
371
+
372
+ # Filter boxes that touch crop boundaries
373
+ keep_mask = ~is_box_near_crop_edge(
374
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
375
+ )
376
+ if not torch.all(keep_mask):
377
+ data.filter(keep_mask)
378
+
379
+ # Compress to RLE
380
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
381
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
382
+ del data["masks"]
383
+
384
+ return data
385
+
386
+ @staticmethod
387
+ def postprocess_small_regions(
388
+ mask_data: MaskData, min_area: int, nms_thresh: float
389
+ ) -> MaskData:
390
+ """
391
+ Removes small disconnected regions and holes in masks, then reruns
392
+ box NMS to remove any new duplicates.
393
+
394
+ Edits mask_data in place.
395
+
396
+ Requires open-cv as a dependency.
397
+ """
398
+ if len(mask_data["rles"]) == 0:
399
+ return mask_data
400
+
401
+ # Filter small disconnected regions and holes
402
+ new_masks = []
403
+ scores = []
404
+ for rle in mask_data["rles"]:
405
+ mask = rle_to_mask(rle)
406
+
407
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
408
+ unchanged = not changed
409
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
410
+ unchanged = unchanged and not changed
411
+
412
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
413
+ # Give score=0 to changed masks and score=1 to unchanged masks
414
+ # so NMS will prefer ones that didn't need postprocessing
415
+ scores.append(float(unchanged))
416
+
417
+ # Recalculate boxes and remove any new duplicates
418
+ masks = torch.cat(new_masks, dim=0)
419
+ boxes = batched_mask_to_box(masks)
420
+ keep_by_nms = batched_nms(
421
+ boxes.float(),
422
+ torch.as_tensor(scores),
423
+ torch.zeros_like(boxes[:, 0]), # categories
424
+ iou_threshold=nms_thresh,
425
+ )
426
+
427
+ # Only recalculate RLEs for masks that have changed
428
+ for i_mask in keep_by_nms:
429
+ if scores[i_mask] == 0.0:
430
+ mask_torch = masks[i_mask].unsqueeze(0)
431
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
432
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
433
+ mask_data.filter(keep_by_nms)
434
+
435
+ return mask_data
436
+
437
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
438
+ new_masks = []
439
+ new_iou_preds = []
440
+
441
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
442
+ points_per_batch, points, point_labels, low_res_masks
443
+ ):
444
+ best_masks, best_iou_preds, _ = self.predictor._predict(
445
+ cur_points[:, None, :],
446
+ cur_point_labels[:, None],
447
+ mask_input=low_res_mask[:, None, :],
448
+ multimask_output=False,
449
+ return_logits=True,
450
+ )
451
+ new_masks.append(best_masks)
452
+ new_iou_preds.append(best_iou_preds)
453
+ masks = torch.cat(new_masks, dim=0)
454
+ return masks, torch.cat(new_iou_preds, dim=0)
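A minimal usage sketch for the class above: `from_pretrained` resolves the config/checkpoint pair listed in `build_sam.py`'s `HF_MODEL_ID_TO_FILENAMES`, and the zero-filled image is only a placeholder for a real HWC uint8 image (assumes a CUDA device, as in the demos):

import numpy as np
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

mask_generator = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2.1-hiera-large")

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC uint8 image
masks = mask_generator.generate(image)
# each record contains "segmentation", "bbox" (XYWH), "area",
# "predicted_iou", "point_coords", "stability_score", "crop_box"
for record in masks[:3]:
    print(record["bbox"], record["predicted_iou"])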
clone-IDEA-Research/Grounded-SAM-2/sam2/build_sam.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ import torch
11
+ from hydra import compose
12
+ from hydra.utils import instantiate
13
+ from omegaconf import OmegaConf
14
+
15
+ import sam2
16
+
17
+ # Check if the user is running Python from the parent directory of the sam2 repo
18
+ # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
+ # it could shadow the sam2 package and cause issues.
20
+ if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21
+ # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
22
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23
+ # This typically happens because the user is running Python from the parent directory
24
+ # that contains the sam2 repo they cloned.
25
+ raise RuntimeError(
26
+ "You're likely running Python from the parent directory of the sam2 repository "
27
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28
+ "This is not supported since the `sam2` Python package could be shadowed by the "
29
+ "repository name (the repository is also named `sam2` and contains the Python package "
30
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
32
+ )
33
+
34
+
35
+ HF_MODEL_ID_TO_FILENAMES = {
36
+ "facebook/sam2-hiera-tiny": (
37
+ "configs/sam2/sam2_hiera_t.yaml",
38
+ "sam2_hiera_tiny.pt",
39
+ ),
40
+ "facebook/sam2-hiera-small": (
41
+ "configs/sam2/sam2_hiera_s.yaml",
42
+ "sam2_hiera_small.pt",
43
+ ),
44
+ "facebook/sam2-hiera-base-plus": (
45
+ "configs/sam2/sam2_hiera_b+.yaml",
46
+ "sam2_hiera_base_plus.pt",
47
+ ),
48
+ "facebook/sam2-hiera-large": (
49
+ "configs/sam2/sam2_hiera_l.yaml",
50
+ "sam2_hiera_large.pt",
51
+ ),
52
+ "facebook/sam2.1-hiera-tiny": (
53
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
54
+ "sam2.1_hiera_tiny.pt",
55
+ ),
56
+ "facebook/sam2.1-hiera-small": (
57
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
58
+ "sam2.1_hiera_small.pt",
59
+ ),
60
+ "facebook/sam2.1-hiera-base-plus": (
61
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
62
+ "sam2.1_hiera_base_plus.pt",
63
+ ),
64
+ "facebook/sam2.1-hiera-large": (
65
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
66
+ "sam2.1_hiera_large.pt",
67
+ ),
68
+ }
69
+
70
+
71
+ def build_sam2(
72
+ config_file,
73
+ ckpt_path=None,
74
+ device="cuda",
75
+ mode="eval",
76
+ hydra_overrides_extra=[],
77
+ apply_postprocessing=True,
78
+ **kwargs,
79
+ ):
80
+
81
+ if apply_postprocessing:
82
+ hydra_overrides_extra = hydra_overrides_extra.copy()
83
+ hydra_overrides_extra += [
84
+ # dynamically fall back to multi-mask if the single mask is not stable
85
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
+ ]
89
+ # Read config and init model
90
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91
+ OmegaConf.resolve(cfg)
92
+ model = instantiate(cfg.model, _recursive_=True)
93
+ _load_checkpoint(model, ckpt_path)
94
+ model = model.to(device)
95
+ if mode == "eval":
96
+ model.eval()
97
+ return model
98
+
99
+
100
+ def build_sam2_video_predictor(
101
+ config_file,
102
+ ckpt_path=None,
103
+ device="cuda",
104
+ mode="eval",
105
+ hydra_overrides_extra=[],
106
+ apply_postprocessing=True,
107
+ **kwargs,
108
+ ):
109
+ hydra_overrides = [
110
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
111
+ ]
112
+ if apply_postprocessing:
113
+ hydra_overrides_extra = hydra_overrides_extra.copy()
114
+ hydra_overrides_extra += [
115
+ # dynamically fall back to multi-mask if the single mask is not stable
116
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
117
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
118
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
119
+ # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly what users see from clicking
120
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
121
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
122
+ "++model.fill_hole_area=8",
123
+ ]
124
+ hydra_overrides.extend(hydra_overrides_extra)
125
+
126
+ # Read config and init model
127
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
128
+ OmegaConf.resolve(cfg)
129
+ model = instantiate(cfg.model, _recursive_=True)
130
+ _load_checkpoint(model, ckpt_path)
131
+ model = model.to(device)
132
+ if mode == "eval":
133
+ model.eval()
134
+ return model
135
+
136
+
137
+ def _hf_download(model_id):
138
+ from huggingface_hub import hf_hub_download
139
+
140
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
141
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
142
+ return config_name, ckpt_path
143
+
144
+
145
+ def build_sam2_hf(model_id, **kwargs):
146
+ config_name, ckpt_path = _hf_download(model_id)
147
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
148
+
149
+
150
+ def build_sam2_video_predictor_hf(model_id, **kwargs):
151
+ config_name, ckpt_path = _hf_download(model_id)
152
+ return build_sam2_video_predictor(
153
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
154
+ )
155
+
156
+
157
+ def _load_checkpoint(model, ckpt_path):
158
+ if ckpt_path is not None:
159
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
160
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
161
+ if missing_keys:
162
+ logging.error(missing_keys)
163
+ raise RuntimeError()
164
+ if unexpected_keys:
165
+ logging.error(unexpected_keys)
166
+ raise RuntimeError()
167
+ logging.info("Loaded checkpoint successfully")
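
For reference, a minimal usage sketch of the builders defined above. The config name and checkpoint path are illustrative assumptions (substitute whichever config/checkpoint this repo actually ships); everything else mirrors the functions in this file:

import torch
from sam2.build_sam import build_sam2, build_sam2_video_predictor

model_cfg = "sam2_hiera_l.yaml"                   # assumed config name
checkpoint = "./checkpoints/sam2_hiera_large.pt"  # assumed checkpoint path
device = "cuda" if torch.cuda.is_available() else "cpu"

sam2_model = build_sam2(model_cfg, checkpoint, device=device)                        # eval-mode SAM2Base
video_predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)   # SAM2VideoPredictor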
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
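
A rough sketch of how build_sam.py (above) consumes a config like this one: Hydra composes the YAML, OmegaConf resolves it, and `cfg.model` is instantiated recursively. This assumes that importing the sam2 package registers its Hydra config module (as its __init__.py is expected to do); the config name is the file above.

import sam2  # importing the package is assumed to register its Hydra config search path
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = compose(config_name="sam2_hiera_b+.yaml")
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)   # SAM2Base with the Hiera-B+ trunk configured above
print(sum(p.numel() for p in model.parameters()))  # rough parameter count of the assembled model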
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_image_predictor.py ADDED
@@ -0,0 +1,466 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from PIL.Image import Image
14
+
15
+ from sam2.modeling.sam2_base import SAM2Base
16
+
17
+ from sam2.utils.transforms import SAM2Transforms
18
+
19
+
20
+ class SAM2ImagePredictor:
21
+ def __init__(
22
+ self,
23
+ sam_model: SAM2Base,
24
+ mask_threshold=0.0,
25
+ max_hole_area=0.0,
26
+ max_sprinkle_area=0.0,
27
+ **kwargs,
28
+ ) -> None:
29
+ """
30
+ Uses SAM-2 to calculate the image embedding for an image, and then
31
+ allows repeated, efficient mask prediction given prompts.
32
+
33
+ Arguments:
34
+ sam_model (Sam-2): The model to use for mask prediction.
35
+ mask_threshold (float): The threshold to use when converting mask logits
36
+ to binary masks. Masks are thresholded at 0 by default.
37
+ max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
38
+ the maximum area of max_hole_area in low_res_masks.
39
+ max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
40
+ the maximum area of max_sprinkle_area in low_res_masks.
41
+ """
42
+ super().__init__()
43
+ self.model = sam_model
44
+ self._transforms = SAM2Transforms(
45
+ resolution=self.model.image_size,
46
+ mask_threshold=mask_threshold,
47
+ max_hole_area=max_hole_area,
48
+ max_sprinkle_area=max_sprinkle_area,
49
+ )
50
+
51
+ # Predictor state
52
+ self._is_image_set = False
53
+ self._features = None
54
+ self._orig_hw = None
55
+ # Whether the predictor is set for single image or a batch of images
56
+ self._is_batch = False
57
+
58
+ # Predictor config
59
+ self.mask_threshold = mask_threshold
60
+
61
+ # Spatial dim for backbone feature maps
62
+ self._bb_feat_sizes = [
63
+ (256, 256),
64
+ (128, 128),
65
+ (64, 64),
66
+ ]
67
+
68
+ @classmethod
69
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
70
+ """
71
+ Load a pretrained model from the Hugging Face hub.
72
+
73
+ Arguments:
74
+ model_id (str): The Hugging Face repository ID.
75
+ **kwargs: Additional arguments to pass to the model constructor.
76
+
77
+ Returns:
78
+ (SAM2ImagePredictor): The loaded model.
79
+ """
80
+ from sam2.build_sam import build_sam2_hf
81
+
82
+ sam_model = build_sam2_hf(model_id, **kwargs)
83
+ return cls(sam_model, **kwargs)
84
+
85
+ @torch.no_grad()
86
+ def set_image(
87
+ self,
88
+ image: Union[np.ndarray, Image],
89
+ ) -> None:
90
+ """
91
+ Calculates the image embeddings for the provided image, allowing
92
+ masks to be predicted with the 'predict' method.
93
+
94
+ Arguments:
95
+ image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
96
+ with pixel values in [0, 255].
97
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
98
+ """
99
+ self.reset_predictor()
100
+ # Transform the image to the form expected by the model
101
+ if isinstance(image, np.ndarray):
102
+ logging.info("For numpy array image, we assume (HxWxC) format")
103
+ self._orig_hw = [image.shape[:2]]
104
+ elif isinstance(image, Image):
105
+ w, h = image.size
106
+ self._orig_hw = [(h, w)]
107
+ else:
108
+ raise NotImplementedError("Image format not supported")
109
+
110
+ input_image = self._transforms(image)
111
+ input_image = input_image[None, ...].to(self.device)
112
+
113
+ assert (
114
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
115
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
116
+ logging.info("Computing image embeddings for the provided image...")
117
+ backbone_out = self.model.forward_image(input_image)
118
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
119
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
120
+ if self.model.directly_add_no_mem_embed:
121
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
122
+
123
+ feats = [
124
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
125
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
126
+ ][::-1]
127
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
128
+ self._is_image_set = True
129
+ logging.info("Image embeddings computed.")
130
+
131
+ @torch.no_grad()
132
+ def set_image_batch(
133
+ self,
134
+ image_list: List[np.ndarray],
135
+ ) -> None:
136
+ """
137
+ Calculates the image embeddings for the provided image batch, allowing
138
+ masks to be predicted with the 'predict_batch' method.
139
+
140
+ Arguments:
141
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
142
+ with pixel values in [0, 255].
143
+ """
144
+ self.reset_predictor()
145
+ assert isinstance(image_list, list)
146
+ self._orig_hw = []
147
+ for image in image_list:
148
+ assert isinstance(
149
+ image, np.ndarray
150
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
151
+ self._orig_hw.append(image.shape[:2])
152
+ # Transform the image to the form expected by the model
153
+ img_batch = self._transforms.forward_batch(image_list)
154
+ img_batch = img_batch.to(self.device)
155
+ batch_size = img_batch.shape[0]
156
+ assert (
157
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
158
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
159
+ logging.info("Computing image embeddings for the provided images...")
160
+ backbone_out = self.model.forward_image(img_batch)
161
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
162
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
163
+ if self.model.directly_add_no_mem_embed:
164
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
165
+
166
+ feats = [
167
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
168
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
169
+ ][::-1]
170
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
171
+ self._is_image_set = True
172
+ self._is_batch = True
173
+ logging.info("Image embeddings computed.")
174
+
175
+ def predict_batch(
176
+ self,
177
+ point_coords_batch: List[np.ndarray] = None,
178
+ point_labels_batch: List[np.ndarray] = None,
179
+ box_batch: List[np.ndarray] = None,
180
+ mask_input_batch: List[np.ndarray] = None,
181
+ multimask_output: bool = True,
182
+ return_logits: bool = False,
183
+ normalize_coords=True,
184
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
185
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
186
+ It returns a tuple of lists of masks, ious, and low_res_masks_logits.
187
+ """
188
+ assert self._is_batch, "This function should only be used when in batched mode"
189
+ if not self._is_image_set:
190
+ raise RuntimeError(
191
+ "An image must be set with .set_image_batch(...) before mask prediction."
192
+ )
193
+ num_images = len(self._features["image_embed"])
194
+ all_masks = []
195
+ all_ious = []
196
+ all_low_res_masks = []
197
+ for img_idx in range(num_images):
198
+ # Transform input prompts
199
+ point_coords = (
200
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
201
+ )
202
+ point_labels = (
203
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
204
+ )
205
+ box = box_batch[img_idx] if box_batch is not None else None
206
+ mask_input = (
207
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
208
+ )
209
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
210
+ point_coords,
211
+ point_labels,
212
+ box,
213
+ mask_input,
214
+ normalize_coords,
215
+ img_idx=img_idx,
216
+ )
217
+ masks, iou_predictions, low_res_masks = self._predict(
218
+ unnorm_coords,
219
+ labels,
220
+ unnorm_box,
221
+ mask_input,
222
+ multimask_output,
223
+ return_logits=return_logits,
224
+ img_idx=img_idx,
225
+ )
226
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
227
+ iou_predictions_np = (
228
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
229
+ )
230
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
231
+ all_masks.append(masks_np)
232
+ all_ious.append(iou_predictions_np)
233
+ all_low_res_masks.append(low_res_masks_np)
234
+
235
+ return all_masks, all_ious, all_low_res_masks
236
+
237
+ def predict(
238
+ self,
239
+ point_coords: Optional[np.ndarray] = None,
240
+ point_labels: Optional[np.ndarray] = None,
241
+ box: Optional[np.ndarray] = None,
242
+ mask_input: Optional[np.ndarray] = None,
243
+ multimask_output: bool = True,
244
+ return_logits: bool = False,
245
+ normalize_coords=True,
246
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
247
+ """
248
+ Predict masks for the given input prompts, using the currently set image.
249
+
250
+ Arguments:
251
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
252
+ model. Each point is in (X,Y) in pixels.
253
+ point_labels (np.ndarray or None): A length N array of labels for the
254
+ point prompts. 1 indicates a foreground point and 0 indicates a
255
+ background point.
256
+ box (np.ndarray or None): A length 4 array given a box prompt to the
257
+ model, in XYXY format.
258
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
259
+ coming from a previous prediction iteration. Has form 1xHxW, where
260
+ for SAM, H=W=256.
261
+ multimask_output (bool): If true, the model will return three masks.
262
+ For ambiguous input prompts (such as a single click), this will often
263
+ produce better masks than a single prediction. If only a single
264
+ mask is needed, the model's predicted quality score can be used
265
+ to select the best mask. For non-ambiguous prompts, such as multiple
266
+ input prompts, multimask_output=False can give better results.
267
+ return_logits (bool): If true, returns un-thresholded masks logits
268
+ instead of a binary mask.
269
+ normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
270
+
271
+ Returns:
272
+ (np.ndarray): The output masks in CxHxW format, where C is the
273
+ number of masks, and (H, W) is the original image size.
274
+ (np.ndarray): An array of length C containing the model's
275
+ predictions for the quality of each mask.
276
+ (np.ndarray): An array of shape CxHxW, where C is the number
277
+ of masks and H=W=256. These low resolution logits can be passed to
278
+ a subsequent iteration as mask input.
279
+ """
280
+ if not self._is_image_set:
281
+ raise RuntimeError(
282
+ "An image must be set with .set_image(...) before mask prediction."
283
+ )
284
+
285
+ # Transform input prompts
286
+
287
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
288
+ point_coords, point_labels, box, mask_input, normalize_coords
289
+ )
290
+
291
+ masks, iou_predictions, low_res_masks = self._predict(
292
+ unnorm_coords,
293
+ labels,
294
+ unnorm_box,
295
+ mask_input,
296
+ multimask_output,
297
+ return_logits=return_logits,
298
+ )
299
+
300
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
301
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
302
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
303
+ return masks_np, iou_predictions_np, low_res_masks_np
304
+
305
+ def _prep_prompts(
306
+ self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
307
+ ):
308
+
309
+ unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
310
+ if point_coords is not None:
311
+ assert (
312
+ point_labels is not None
313
+ ), "point_labels must be supplied if point_coords is supplied."
314
+ point_coords = torch.as_tensor(
315
+ point_coords, dtype=torch.float, device=self.device
316
+ )
317
+ unnorm_coords = self._transforms.transform_coords(
318
+ point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
319
+ )
320
+ labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
321
+ if len(unnorm_coords.shape) == 2:
322
+ unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
323
+ if box is not None:
324
+ box = torch.as_tensor(box, dtype=torch.float, device=self.device)
325
+ unnorm_box = self._transforms.transform_boxes(
326
+ box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
327
+ ) # Bx2x2
328
+ if mask_logits is not None:
329
+ mask_input = torch.as_tensor(
330
+ mask_logits, dtype=torch.float, device=self.device
331
+ )
332
+ if len(mask_input.shape) == 3:
333
+ mask_input = mask_input[None, :, :, :]
334
+ return mask_input, unnorm_coords, labels, unnorm_box
335
+
336
+ @torch.no_grad()
337
+ def _predict(
338
+ self,
339
+ point_coords: Optional[torch.Tensor],
340
+ point_labels: Optional[torch.Tensor],
341
+ boxes: Optional[torch.Tensor] = None,
342
+ mask_input: Optional[torch.Tensor] = None,
343
+ multimask_output: bool = True,
344
+ return_logits: bool = False,
345
+ img_idx: int = -1,
346
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
347
+ """
348
+ Predict masks for the given input prompts, using the currently set image.
349
+ Input prompts are batched torch tensors and are expected to already be
350
+ transformed to the input frame using SAM2Transforms.
351
+
352
+ Arguments:
353
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
354
+ model. Each point is in (X,Y) in pixels.
355
+ point_labels (torch.Tensor or None): A BxN array of labels for the
356
+ point prompts. 1 indicates a foreground point and 0 indicates a
357
+ background point.
358
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
359
+ model, in XYXY format.
360
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
361
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
362
+ for SAM, H=W=256. Masks returned by a previous iteration of the
363
+ predict method do not need further transformation.
364
+ multimask_output (bool): If true, the model will return three masks.
365
+ For ambiguous input prompts (such as a single click), this will often
366
+ produce better masks than a single prediction. If only a single
367
+ mask is needed, the model's predicted quality score can be used
368
+ to select the best mask. For non-ambiguous prompts, such as multiple
369
+ input prompts, multimask_output=False can give better results.
370
+ return_logits (bool): If true, returns un-thresholded masks logits
371
+ instead of a binary mask.
372
+
373
+ Returns:
374
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
375
+ number of masks, and (H, W) is the original image size.
376
+ (torch.Tensor): An array of shape BxC containing the model's
377
+ predictions for the quality of each mask.
378
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
379
+ of masks and H=W=256. These low res logits can be passed to
380
+ a subsequent iteration as mask input.
381
+ """
382
+ if not self._is_image_set:
383
+ raise RuntimeError(
384
+ "An image must be set with .set_image(...) before mask prediction."
385
+ )
386
+
387
+ if point_coords is not None:
388
+ concat_points = (point_coords, point_labels)
389
+ else:
390
+ concat_points = None
391
+
392
+ # Embed prompts
393
+ if boxes is not None:
394
+ box_coords = boxes.reshape(-1, 2, 2)
395
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
396
+ box_labels = box_labels.repeat(boxes.size(0), 1)
397
+ # we merge "boxes" and "points" into a single "concat_points" input (where
398
+ # boxes are added at the beginning) to sam_prompt_encoder
399
+ if concat_points is not None:
400
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
401
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
402
+ concat_points = (concat_coords, concat_labels)
403
+ else:
404
+ concat_points = (box_coords, box_labels)
405
+
406
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
407
+ points=concat_points,
408
+ boxes=None,
409
+ masks=mask_input,
410
+ )
411
+
412
+ # Predict masks
413
+ batched_mode = (
414
+ concat_points is not None and concat_points[0].shape[0] > 1
415
+ ) # multi object prediction
416
+ high_res_features = [
417
+ feat_level[img_idx].unsqueeze(0)
418
+ for feat_level in self._features["high_res_feats"]
419
+ ]
420
+ low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
421
+ image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
422
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
423
+ sparse_prompt_embeddings=sparse_embeddings,
424
+ dense_prompt_embeddings=dense_embeddings,
425
+ multimask_output=multimask_output,
426
+ repeat_image=batched_mode,
427
+ high_res_features=high_res_features,
428
+ )
429
+
430
+ # Upscale the masks to the original image resolution
431
+ masks = self._transforms.postprocess_masks(
432
+ low_res_masks, self._orig_hw[img_idx]
433
+ )
434
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
435
+ if not return_logits:
436
+ masks = masks > self.mask_threshold
437
+
438
+ return masks, iou_predictions, low_res_masks
439
+
440
+ def get_image_embedding(self) -> torch.Tensor:
441
+ """
442
+ Returns the image embeddings for the currently set image, with
443
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
444
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
445
+ """
446
+ if not self._is_image_set:
447
+ raise RuntimeError(
448
+ "An image must be set with .set_image(...) to generate an embedding."
449
+ )
450
+ assert (
451
+ self._features is not None
452
+ ), "Features must exist if an image has been set."
453
+ return self._features["image_embed"]
454
+
455
+ @property
456
+ def device(self) -> torch.device:
457
+ return self.model.device
458
+
459
+ def reset_predictor(self) -> None:
460
+ """
461
+ Resets the image embeddings and other state variables.
462
+ """
463
+ self._is_image_set = False
464
+ self._features = None
465
+ self._orig_hw = None
466
+ self._is_batch = False
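
A minimal usage sketch of SAM2ImagePredictor as defined above, using a box prompt. The config and checkpoint names are assumptions, and the dummy array stands in for any RGB image in HWC format:

import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device=device)
predictor = SAM2ImagePredictor(sam2_model)

image = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder RGB image in HWC format
predictor.set_image(image)                        # embeddings are computed once and cached
box = np.array([100, 100, 300, 300])              # XYXY box prompt in pixel coordinates
masks, scores, low_res_logits = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores.shape)                  # e.g. (1, 480, 640) and (1,)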
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_video_predictor.py ADDED
@@ -0,0 +1,1172 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+
12
+ from tqdm import tqdm
13
+
14
+ from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
15
+ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
16
+
17
+
18
+ class SAM2VideoPredictor(SAM2Base):
19
+ """The predictor class to handle user interactions and manage inference states."""
20
+
21
+ def __init__(
22
+ self,
23
+ fill_hole_area=0,
24
+ # whether to apply non-overlapping constraints on the output object masks
25
+ non_overlap_masks=False,
26
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
27
+ # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
28
+ clear_non_cond_mem_around_input=False,
29
+ # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
30
+ clear_non_cond_mem_for_multi_obj=False,
31
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
32
+ # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only those initial conditioning frames
33
+ add_all_frames_to_correct_as_cond=False,
34
+ **kwargs,
35
+ ):
36
+ super().__init__(**kwargs)
37
+ self.fill_hole_area = fill_hole_area
38
+ self.non_overlap_masks = non_overlap_masks
39
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
40
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
41
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
42
+
43
+ @torch.inference_mode()
44
+ def init_state(
45
+ self,
46
+ video_path,
47
+ offload_video_to_cpu=False,
48
+ offload_state_to_cpu=False,
49
+ async_loading_frames=False,
50
+ ):
51
+ """Initialize an inference state."""
52
+ compute_device = self.device # device of the model
53
+ images, video_height, video_width = load_video_frames(
54
+ video_path=video_path,
55
+ image_size=self.image_size,
56
+ offload_video_to_cpu=offload_video_to_cpu,
57
+ async_loading_frames=async_loading_frames,
58
+ compute_device=compute_device,
59
+ )
60
+ inference_state = {}
61
+ inference_state["images"] = images
62
+ inference_state["num_frames"] = len(images)
63
+ # whether to offload the video frames to CPU memory
64
+ # turning on this option saves the GPU memory with only a very small overhead
65
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
66
+ # whether to offload the inference state to CPU memory
67
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
68
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
69
+ # and from 24 to 21 when tracking two objects)
70
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
71
+ # the original video height and width, used for resizing final output scores
72
+ inference_state["video_height"] = video_height
73
+ inference_state["video_width"] = video_width
74
+ inference_state["device"] = compute_device
75
+ if offload_state_to_cpu:
76
+ inference_state["storage_device"] = torch.device("cpu")
77
+ else:
78
+ inference_state["storage_device"] = compute_device
79
+ # inputs on each frame
80
+ inference_state["point_inputs_per_obj"] = {}
81
+ inference_state["mask_inputs_per_obj"] = {}
82
+ # visual features on a small number of recently visited frames for quick interactions
83
+ inference_state["cached_features"] = {}
84
+ # values that don't change across frames (so we only need to hold one copy of them)
85
+ inference_state["constants"] = {}
86
+ # mapping between client-side object id and model-side object index
87
+ inference_state["obj_id_to_idx"] = OrderedDict()
88
+ inference_state["obj_idx_to_id"] = OrderedDict()
89
+ inference_state["obj_ids"] = []
90
+ # A storage to hold the model's tracking results and states on each frame
91
+ inference_state["output_dict"] = {
92
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
93
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
94
+ }
95
+ # Slice (view) of each object's tracking results, sharing the same memory with "output_dict"
96
+ inference_state["output_dict_per_obj"] = {}
97
+ # A temporary storage to hold new outputs when the user interacts with a frame
98
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
99
+ inference_state["temp_output_dict_per_obj"] = {}
100
+ # Frames that already hold consolidated outputs from click or mask inputs
101
+ # (we directly use their consolidated outputs during tracking)
102
+ inference_state["consolidated_frame_inds"] = {
103
+ "cond_frame_outputs": set(), # set containing frame indices
104
+ "non_cond_frame_outputs": set(), # set containing frame indices
105
+ }
106
+ # metadata for each tracking frame (e.g. which direction it's tracked)
107
+ inference_state["tracking_has_started"] = False
108
+ inference_state["frames_already_tracked"] = {}
109
+ # Warm up the visual backbone and cache the image feature on frame 0
110
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
111
+ return inference_state
112
+
113
+ @classmethod
114
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
115
+ """
116
+ Load a pretrained model from the Hugging Face hub.
117
+
118
+ Arguments:
119
+ model_id (str): The Hugging Face repository ID.
120
+ **kwargs: Additional arguments to pass to the model constructor.
121
+
122
+ Returns:
123
+ (SAM2VideoPredictor): The loaded model.
124
+ """
125
+ from sam2.build_sam import build_sam2_video_predictor_hf
126
+
127
+ sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
128
+ return sam_model
129
+
130
+ def _obj_id_to_idx(self, inference_state, obj_id):
131
+ """Map client-side object id to model-side object index."""
132
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
133
+ if obj_idx is not None:
134
+ return obj_idx
135
+
136
+ # This is a new object id not sent to the server before. We only allow adding
137
+ # new objects *before* the tracking starts.
138
+ allow_new_object = not inference_state["tracking_has_started"]
139
+ if allow_new_object:
140
+ # get the next object slot
141
+ obj_idx = len(inference_state["obj_id_to_idx"])
142
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
143
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
144
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
145
+ # set up input and output structures for this object
146
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
147
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
148
+ inference_state["output_dict_per_obj"][obj_idx] = {
149
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
150
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
151
+ }
152
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
153
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
154
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
155
+ }
156
+ return obj_idx
157
+ else:
158
+ raise RuntimeError(
159
+ f"Cannot add new object id {obj_id} after tracking starts. "
160
+ f"All existing object ids: {inference_state['obj_ids']}. "
161
+ f"Please call 'reset_state' to restart from scratch."
162
+ )
163
+
164
+ def _obj_idx_to_id(self, inference_state, obj_idx):
165
+ """Map model-side object index to client-side object id."""
166
+ return inference_state["obj_idx_to_id"][obj_idx]
167
+
168
+ def _get_obj_num(self, inference_state):
169
+ """Get the total number of unique object ids received so far in this session."""
170
+ return len(inference_state["obj_idx_to_id"])
171
+
172
+ @torch.inference_mode()
173
+ def add_new_points_or_box(
174
+ self,
175
+ inference_state,
176
+ frame_idx,
177
+ obj_id,
178
+ points=None,
179
+ labels=None,
180
+ clear_old_points=True,
181
+ normalize_coords=True,
182
+ box=None,
183
+ ):
184
+ """Add new points to a frame."""
185
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
186
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
187
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
188
+
189
+ if (points is not None) != (labels is not None):
190
+ raise ValueError("points and labels must be provided together")
191
+ if points is None and box is None:
192
+ raise ValueError("at least one of points or box must be provided as input")
193
+
194
+ if points is None:
195
+ points = torch.zeros(0, 2, dtype=torch.float32)
196
+ elif not isinstance(points, torch.Tensor):
197
+ points = torch.tensor(points, dtype=torch.float32)
198
+ if labels is None:
199
+ labels = torch.zeros(0, dtype=torch.int32)
200
+ elif not isinstance(labels, torch.Tensor):
201
+ labels = torch.tensor(labels, dtype=torch.int32)
202
+ if points.dim() == 2:
203
+ points = points.unsqueeze(0) # add batch dimension
204
+ if labels.dim() == 1:
205
+ labels = labels.unsqueeze(0) # add batch dimension
206
+
207
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
208
+ # along with the user-provided points (consistent with how SAM 2 is trained).
209
+ if box is not None:
210
+ if not clear_old_points:
211
+ raise ValueError(
212
+ "cannot add box without clearing old points, since "
213
+ "box prompt must be provided before any point prompt "
214
+ "(please use clear_old_points=True instead)"
215
+ )
216
+ if inference_state["tracking_has_started"]:
217
+ warnings.warn(
218
+ "You are adding a box after tracking starts. SAM 2 may not always be "
219
+ "able to incorporate a box prompt for *refinement*. If you intend to "
220
+ "use box prompt as an *initial* input before tracking, please call "
221
+ "'reset_state' on the inference state to restart from scratch.",
222
+ category=UserWarning,
223
+ stacklevel=2,
224
+ )
225
+ if not isinstance(box, torch.Tensor):
226
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
227
+ box_coords = box.reshape(1, 2, 2)
228
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
229
+ box_labels = box_labels.reshape(1, 2)
230
+ points = torch.cat([box_coords, points], dim=1)
231
+ labels = torch.cat([box_labels, labels], dim=1)
232
+
233
+ if normalize_coords:
234
+ video_H = inference_state["video_height"]
235
+ video_W = inference_state["video_width"]
236
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
237
+ # scale the (normalized) coordinates by the model's internal image size
238
+ points = points * self.image_size
239
+ points = points.to(inference_state["device"])
240
+ labels = labels.to(inference_state["device"])
241
+
242
+ if not clear_old_points:
243
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
244
+ else:
245
+ point_inputs = None
246
+ point_inputs = concat_points(point_inputs, points, labels)
247
+
248
+ point_inputs_per_frame[frame_idx] = point_inputs
249
+ mask_inputs_per_frame.pop(frame_idx, None)
250
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
251
+ # frame, meaning that the input points are used to generate segments on this frame without
252
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
253
+ # the input points will be used to correct the already tracked masks.
254
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
255
+ # whether to track in reverse time order
256
+ if is_init_cond_frame:
257
+ reverse = False
258
+ else:
259
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
260
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
261
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
262
+ # Add a frame to conditioning output if it's an initial conditioning frame or
263
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
264
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
265
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
266
+
267
+ # Get any previously predicted mask logits on this object and feed it along with
268
+ # the new clicks into the SAM mask decoder.
269
+ prev_sam_mask_logits = None
270
+ # lookup temporary output dict first, which contains the most recent output
271
+ # (if not found, then lookup conditioning and non-conditioning frame output)
272
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
273
+ if prev_out is None:
274
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
275
+ if prev_out is None:
276
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
277
+
278
+ if prev_out is not None and prev_out["pred_masks"] is not None:
279
+ device = inference_state["device"]
280
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
281
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
282
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
283
+ current_out, _ = self._run_single_frame_inference(
284
+ inference_state=inference_state,
285
+ output_dict=obj_output_dict, # run on the slice of a single object
286
+ frame_idx=frame_idx,
287
+ batch_size=1, # run on the slice of a single object
288
+ is_init_cond_frame=is_init_cond_frame,
289
+ point_inputs=point_inputs,
290
+ mask_inputs=None,
291
+ reverse=reverse,
292
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
293
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
294
+ # allows us to enforce non-overlapping constraints on all objects before encoding
295
+ # them into memory.
296
+ run_mem_encoder=False,
297
+ prev_sam_mask_logits=prev_sam_mask_logits,
298
+ )
299
+ # Add the output to the output dict (to be used as future memory)
300
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
301
+
302
+ # Resize the output mask to the original video resolution
303
+ obj_ids = inference_state["obj_ids"]
304
+ consolidated_out = self._consolidate_temp_output_across_obj(
305
+ inference_state,
306
+ frame_idx,
307
+ is_cond=is_cond,
308
+ run_mem_encoder=False,
309
+ consolidate_at_video_res=True,
310
+ )
311
+ _, video_res_masks = self._get_orig_video_res_output(
312
+ inference_state, consolidated_out["pred_masks_video_res"]
313
+ )
314
+ return frame_idx, obj_ids, video_res_masks
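
For orientation, a hedged sketch of how the interaction methods above are typically driven. The config, checkpoint, and frame-directory paths are assumptions; `propagate_in_video` is defined later in this file:

from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")
state = predictor.init_state(video_path="./videos/example_frames")  # directory of JPEG frames (assumed path)
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    state, frame_idx=0, obj_id=1,
    points=[[210.0, 350.0]], labels=[1],   # one positive click in original-video pixel coordinates
)
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
    binary_masks = (mask_logits > 0.0)     # per-object masks at the original video resolution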
315
+
316
+ def add_new_points(self, *args, **kwargs):
317
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
318
+ return self.add_new_points_or_box(*args, **kwargs)
319
+
320
+ @torch.inference_mode()
321
+ def add_new_mask(
322
+ self,
323
+ inference_state,
324
+ frame_idx,
325
+ obj_id,
326
+ mask,
327
+ ):
328
+ """Add new mask to a frame."""
329
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
330
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
331
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
332
+
333
+ if not isinstance(mask, torch.Tensor):
334
+ mask = torch.tensor(mask, dtype=torch.bool)
335
+ assert mask.dim() == 2
336
+ mask_H, mask_W = mask.shape
337
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
338
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
339
+
340
+ # resize the mask if it doesn't match the model's image size
341
+ if mask_H != self.image_size or mask_W != self.image_size:
342
+ mask_inputs = torch.nn.functional.interpolate(
343
+ mask_inputs_orig,
344
+ size=(self.image_size, self.image_size),
345
+ align_corners=False,
346
+ mode="bilinear",
347
+ antialias=True, # use antialias for downsampling
348
+ )
349
+ mask_inputs = (mask_inputs >= 0.5).float()
350
+ else:
351
+ mask_inputs = mask_inputs_orig
352
+
353
+ mask_inputs_per_frame[frame_idx] = mask_inputs
354
+ point_inputs_per_frame.pop(frame_idx, None)
355
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
356
+ # frame, meaning that the input points are used to generate segments on this frame without
357
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
358
+ # the input points will be used to correct the already tracked masks.
359
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
360
+ # whether to track in reverse time order
361
+ if is_init_cond_frame:
362
+ reverse = False
363
+ else:
364
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
365
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
366
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
367
+ # Add a frame to conditioning output if it's an initial conditioning frame or
368
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
369
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
370
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
371
+
372
+ current_out, _ = self._run_single_frame_inference(
373
+ inference_state=inference_state,
374
+ output_dict=obj_output_dict, # run on the slice of a single object
375
+ frame_idx=frame_idx,
376
+ batch_size=1, # run on the slice of a single object
377
+ is_init_cond_frame=is_init_cond_frame,
378
+ point_inputs=None,
379
+ mask_inputs=mask_inputs,
380
+ reverse=reverse,
381
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
382
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
383
+ # allows us to enforce non-overlapping constraints on all objects before encoding
384
+ # them into memory.
385
+ run_mem_encoder=False,
386
+ )
387
+ # Add the output to the output dict (to be used as future memory)
388
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
389
+
390
+ # Resize the output mask to the original video resolution
391
+ obj_ids = inference_state["obj_ids"]
392
+ consolidated_out = self._consolidate_temp_output_across_obj(
393
+ inference_state,
394
+ frame_idx,
395
+ is_cond=is_cond,
396
+ run_mem_encoder=False,
397
+ consolidate_at_video_res=True,
398
+ )
399
+ _, video_res_masks = self._get_orig_video_res_output(
400
+ inference_state, consolidated_out["pred_masks_video_res"]
401
+ )
402
+ return frame_idx, obj_ids, video_res_masks
403
+
404
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
405
+ """
406
+ Resize the object scores to the original video resolution (video_res_masks)
407
+ and apply non-overlapping constraints for final output.
408
+ """
409
+ device = inference_state["device"]
410
+ video_H = inference_state["video_height"]
411
+ video_W = inference_state["video_width"]
412
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
413
+ if any_res_masks.shape[-2:] == (video_H, video_W):
414
+ video_res_masks = any_res_masks
415
+ else:
416
+ video_res_masks = torch.nn.functional.interpolate(
417
+ any_res_masks,
418
+ size=(video_H, video_W),
419
+ mode="bilinear",
420
+ align_corners=False,
421
+ )
422
+ if self.non_overlap_masks:
423
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
424
+ return any_res_masks, video_res_masks
425
+
426
+ def _consolidate_temp_output_across_obj(
427
+ self,
428
+ inference_state,
429
+ frame_idx,
430
+ is_cond,
431
+ run_mem_encoder,
432
+ consolidate_at_video_res=False,
433
+ ):
434
+ """
435
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
436
+ a frame into a single output for all objects, including
437
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
438
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
439
+ (if they don't exist in `output_dict_per_obj` for this frame);
440
+ # 2) if specified, rerun the memory encoder after applying non-overlapping constraints
441
+ on the object scores.
442
+ """
443
+ batch_size = self._get_obj_num(inference_state)
444
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
445
+ # Optionally, we allow consolidating the temporary outputs at the original
446
+ # video resolution (to provide a better editing experience for mask prompts).
447
+ if consolidate_at_video_res:
448
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
449
+ consolidated_H = inference_state["video_height"]
450
+ consolidated_W = inference_state["video_width"]
451
+ consolidated_mask_key = "pred_masks_video_res"
452
+ else:
453
+ consolidated_H = consolidated_W = self.image_size // 4
454
+ consolidated_mask_key = "pred_masks"
455
+
456
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
457
+ # will be added when rerunning the memory encoder after applying non-overlapping
458
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
459
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
460
+ consolidated_out = {
461
+ "maskmem_features": None,
462
+ "maskmem_pos_enc": None,
463
+ consolidated_mask_key: torch.full(
464
+ size=(batch_size, 1, consolidated_H, consolidated_W),
465
+ fill_value=NO_OBJ_SCORE,
466
+ dtype=torch.float32,
467
+ device=inference_state["storage_device"],
468
+ ),
469
+ "obj_ptr": torch.full(
470
+ size=(batch_size, self.hidden_dim),
471
+ fill_value=NO_OBJ_SCORE,
472
+ dtype=torch.float32,
473
+ device=inference_state["device"],
474
+ ),
475
+ "object_score_logits": torch.full(
476
+ size=(batch_size, 1),
477
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
478
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
479
+ fill_value=10.0,
480
+ dtype=torch.float32,
481
+ device=inference_state["device"],
482
+ ),
483
+ }
484
+ empty_mask_ptr = None
485
+ for obj_idx in range(batch_size):
486
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
487
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
488
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
489
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
490
+ # we fall back and look up its previous output in "output_dict_per_obj".
491
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
492
+ # "output_dict_per_obj" to find a previous output for this object.
493
+ if out is None:
494
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
495
+ if out is None:
496
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
497
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
498
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
499
+ # placeholder above) and set its object pointer to be a dummy pointer.
500
+ if out is None:
501
+ # Fill in dummy object pointers for those objects without any inputs or
502
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
503
+ # i.e. when we need to build the memory for tracking).
504
+ if run_mem_encoder:
505
+ if empty_mask_ptr is None:
506
+ empty_mask_ptr = self._get_empty_mask_ptr(
507
+ inference_state, frame_idx
508
+ )
509
+ # fill object pointer with a dummy pointer (based on an empty mask)
510
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
511
+ continue
512
+ # Add the temporary object output mask to consolidated output mask
513
+ obj_mask = out["pred_masks"]
514
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
515
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
516
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
517
+ else:
518
+ # Resize first if temporary object mask has a different resolution
519
+ resized_obj_mask = torch.nn.functional.interpolate(
520
+ obj_mask,
521
+ size=consolidated_pred_masks.shape[-2:],
522
+ mode="bilinear",
523
+ align_corners=False,
524
+ )
525
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
526
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
527
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
528
+ "object_score_logits"
529
+ ]
530
+
531
+ # Optionally, apply non-overlapping constraints on the consolidated scores
532
+ # and rerun the memory encoder
533
+ if run_mem_encoder:
534
+ device = inference_state["device"]
535
+ high_res_masks = torch.nn.functional.interpolate(
536
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
537
+ size=(self.image_size, self.image_size),
538
+ mode="bilinear",
539
+ align_corners=False,
540
+ )
541
+ if self.non_overlap_masks_for_mem_enc:
542
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
543
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
544
+ inference_state=inference_state,
545
+ frame_idx=frame_idx,
546
+ batch_size=batch_size,
547
+ high_res_masks=high_res_masks,
548
+ object_score_logits=consolidated_out["object_score_logits"],
549
+ is_mask_from_pts=True, # these frames are what the user interacted with
550
+ )
551
+ consolidated_out["maskmem_features"] = maskmem_features
552
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
553
+
554
+ return consolidated_out
555
+
556
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
557
+ """Get a dummy object pointer based on an empty mask on the current frame."""
558
+ # A dummy (empty) mask with a single object
559
+ batch_size = 1
560
+ mask_inputs = torch.zeros(
561
+ (batch_size, 1, self.image_size, self.image_size),
562
+ dtype=torch.float32,
563
+ device=inference_state["device"],
564
+ )
565
+
566
+ # Retrieve correct image features
567
+ (
568
+ _,
569
+ _,
570
+ current_vision_feats,
571
+ current_vision_pos_embeds,
572
+ feat_sizes,
573
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
574
+
575
+ # Feed the empty mask and image feature above to get a dummy object pointer
576
+ current_out = self.track_step(
577
+ frame_idx=frame_idx,
578
+ is_init_cond_frame=True,
579
+ current_vision_feats=current_vision_feats,
580
+ current_vision_pos_embeds=current_vision_pos_embeds,
581
+ feat_sizes=feat_sizes,
582
+ point_inputs=None,
583
+ mask_inputs=mask_inputs,
584
+ output_dict={},
585
+ num_frames=inference_state["num_frames"],
586
+ track_in_reverse=False,
587
+ run_mem_encoder=False,
588
+ prev_sam_mask_logits=None,
589
+ )
590
+ return current_out["obj_ptr"]
591
+
592
+ @torch.inference_mode()
593
+ def propagate_in_video_preflight(self, inference_state):
594
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
595
+ # Tracking has started and we don't allow adding new objects until session is reset.
596
+ inference_state["tracking_has_started"] = True
597
+ batch_size = self._get_obj_num(inference_state)
598
+
599
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
600
+ # add them into "output_dict".
601
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
602
+ output_dict = inference_state["output_dict"]
603
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
604
+ # temporary outputs have been added (either in this call or any previous calls
605
+ # to `propagate_in_video_preflight`).
606
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
607
+ for is_cond in [False, True]:
608
+ # Separately consolidate conditioning and non-conditioning temp outputs
609
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
610
+ # Find all the frames that contain temporary outputs for any objects
611
+ # (these should be the frames that have just received clicks or mask inputs
612
+ # via `add_new_points_or_box` or `add_new_mask`)
613
+ temp_frame_inds = set()
614
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
615
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
616
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
617
+ # consolidate the temporary output across all objects on this frame
618
+ for frame_idx in temp_frame_inds:
619
+ consolidated_out = self._consolidate_temp_output_across_obj(
620
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
621
+ )
622
+ # merge them into "output_dict" and also create per-object slices
623
+ output_dict[storage_key][frame_idx] = consolidated_out
624
+ self._add_output_per_object(
625
+ inference_state, frame_idx, consolidated_out, storage_key
626
+ )
627
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
628
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
629
+ )
630
+ if clear_non_cond_mem:
631
+ # clear non-conditioning memory of the surrounding frames
632
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
633
+
634
+ # clear temporary outputs in `temp_output_dict_per_obj`
635
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
636
+ obj_temp_output_dict[storage_key].clear()
637
+
638
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
639
+ # output on the same frame in "non_cond_frame_outputs"
640
+ for frame_idx in output_dict["cond_frame_outputs"]:
641
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
642
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
643
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
644
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
645
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
646
+ assert frame_idx in output_dict["cond_frame_outputs"]
647
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
648
+
649
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
650
+ # with either points or mask inputs (which should be true under a correct workflow).
651
+ all_consolidated_frame_inds = (
652
+ consolidated_frame_inds["cond_frame_outputs"]
653
+ | consolidated_frame_inds["non_cond_frame_outputs"]
654
+ )
655
+ input_frames_inds = set()
656
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
657
+ input_frames_inds.update(point_inputs_per_frame.keys())
658
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
659
+ input_frames_inds.update(mask_inputs_per_frame.keys())
660
+ assert all_consolidated_frame_inds == input_frames_inds
661
+
662
+ @torch.inference_mode()
663
+ def propagate_in_video(
664
+ self,
665
+ inference_state,
666
+ start_frame_idx=None,
667
+ max_frame_num_to_track=None,
668
+ reverse=False,
669
+ ):
670
+ """Propagate the input points across frames to track in the entire video."""
671
+ self.propagate_in_video_preflight(inference_state)
672
+
673
+ output_dict = inference_state["output_dict"]
674
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
675
+ obj_ids = inference_state["obj_ids"]
676
+ num_frames = inference_state["num_frames"]
677
+ batch_size = self._get_obj_num(inference_state)
678
+ if len(output_dict["cond_frame_outputs"]) == 0:
679
+ raise RuntimeError("No points are provided; please add points first")
680
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
681
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
682
+ )
683
+
684
+ # set start index, end index, and processing order
685
+ if start_frame_idx is None:
686
+ # default: start from the earliest frame with input points
687
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
688
+ if max_frame_num_to_track is None:
689
+ # default: track all the frames in the video
690
+ max_frame_num_to_track = num_frames
691
+ if reverse:
692
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
693
+ if start_frame_idx > 0:
694
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
695
+ else:
696
+ processing_order = [] # skip reverse tracking if starting from frame 0
697
+ else:
698
+ end_frame_idx = min(
699
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
700
+ )
701
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
702
+
703
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
704
+ # We skip those frames already in consolidated outputs (these are frames
705
+ # that received input clicks or mask). Note that we cannot directly run
706
+ # batched forward on them via `_run_single_frame_inference` because the
707
+ # number of clicks on each object might be different.
708
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
709
+ storage_key = "cond_frame_outputs"
710
+ current_out = output_dict[storage_key][frame_idx]
711
+ pred_masks = current_out["pred_masks"]
712
+ if clear_non_cond_mem:
713
+ # clear non-conditioning memory of the surrounding frames
714
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
715
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
716
+ storage_key = "non_cond_frame_outputs"
717
+ current_out = output_dict[storage_key][frame_idx]
718
+ pred_masks = current_out["pred_masks"]
719
+ else:
720
+ storage_key = "non_cond_frame_outputs"
721
+ current_out, pred_masks = self._run_single_frame_inference(
722
+ inference_state=inference_state,
723
+ output_dict=output_dict,
724
+ frame_idx=frame_idx,
725
+ batch_size=batch_size,
726
+ is_init_cond_frame=False,
727
+ point_inputs=None,
728
+ mask_inputs=None,
729
+ reverse=reverse,
730
+ run_mem_encoder=True,
731
+ )
732
+ output_dict[storage_key][frame_idx] = current_out
733
+ # Create slices of per-object outputs for subsequent interaction with each
734
+ # individual object after tracking.
735
+ self._add_output_per_object(
736
+ inference_state, frame_idx, current_out, storage_key
737
+ )
738
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
739
+
740
+ # Resize the output mask to the original video resolution (we directly use
741
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
742
+ _, video_res_masks = self._get_orig_video_res_output(
743
+ inference_state, pred_masks
744
+ )
745
+ yield frame_idx, obj_ids, video_res_masks
746
+
747
+ def _add_output_per_object(
748
+ self, inference_state, frame_idx, current_out, storage_key
749
+ ):
750
+ """
751
+ Split a multi-object output into per-object output slices and add them into
752
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
753
+ """
754
+ maskmem_features = current_out["maskmem_features"]
755
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
756
+
757
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
758
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
759
+
760
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
761
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
762
+ obj_slice = slice(obj_idx, obj_idx + 1)
763
+ obj_out = {
764
+ "maskmem_features": None,
765
+ "maskmem_pos_enc": None,
766
+ "pred_masks": current_out["pred_masks"][obj_slice],
767
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
768
+ "object_score_logits": current_out["object_score_logits"][obj_slice],
769
+ }
770
+ if maskmem_features is not None:
771
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
772
+ if maskmem_pos_enc is not None:
773
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
774
+ obj_output_dict[storage_key][frame_idx] = obj_out
775
+
776
+ @torch.inference_mode()
777
+ def clear_all_prompts_in_frame(
778
+ self, inference_state, frame_idx, obj_id, need_output=True
779
+ ):
780
+ """Remove all input points or mask in a specific frame for a given object."""
781
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
782
+
783
+ # Clear the conditioning information on the given frame
784
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
785
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
786
+
787
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
788
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
789
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
790
+
791
+ # Check and see if there are still any inputs left on this frame
792
+ batch_size = self._get_obj_num(inference_state)
793
+ frame_has_input = False
794
+ for obj_idx2 in range(batch_size):
795
+ if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
796
+ frame_has_input = True
797
+ break
798
+ if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
799
+ frame_has_input = True
800
+ break
801
+
802
+ # If this frame has no remaining inputs for any objects, we further clear its
803
+ # conditioning frame status
804
+ if not frame_has_input:
805
+ output_dict = inference_state["output_dict"]
806
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
807
+ consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
808
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
809
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
810
+ out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
811
+ if out is not None:
812
+ # The frame is not a conditioning frame anymore since it's not receiving inputs,
813
+ # so we "downgrade" its output (if exists) to a non-conditioning frame output.
814
+ output_dict["non_cond_frame_outputs"][frame_idx] = out
815
+ inference_state["frames_already_tracked"].pop(frame_idx, None)
816
+ # Similarly, do it for the sliced output on each object.
817
+ for obj_idx2 in range(batch_size):
818
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
819
+ obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
820
+ if obj_out is not None:
821
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
822
+
823
+ # If all the conditioning frames have been removed, we also clear the tracking outputs
824
+ if len(output_dict["cond_frame_outputs"]) == 0:
825
+ self._reset_tracking_results(inference_state)
826
+
827
+ if not need_output:
828
+ return
829
+ # Finally, output updated masks per object (after removing the inputs above)
830
+ obj_ids = inference_state["obj_ids"]
831
+ is_cond = any(
832
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
833
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
834
+ )
835
+ consolidated_out = self._consolidate_temp_output_across_obj(
836
+ inference_state,
837
+ frame_idx,
838
+ is_cond=is_cond,
839
+ run_mem_encoder=False,
840
+ consolidate_at_video_res=True,
841
+ )
842
+ _, video_res_masks = self._get_orig_video_res_output(
843
+ inference_state, consolidated_out["pred_masks_video_res"]
844
+ )
845
+ return frame_idx, obj_ids, video_res_masks
846
+
847
+ @torch.inference_mode()
848
+ def reset_state(self, inference_state):
849
+ """Remove all input points or mask in all frames throughout the video."""
850
+ self._reset_tracking_results(inference_state)
851
+ # Remove all object ids
852
+ inference_state["obj_id_to_idx"].clear()
853
+ inference_state["obj_idx_to_id"].clear()
854
+ inference_state["obj_ids"].clear()
855
+ inference_state["point_inputs_per_obj"].clear()
856
+ inference_state["mask_inputs_per_obj"].clear()
857
+ inference_state["output_dict_per_obj"].clear()
858
+ inference_state["temp_output_dict_per_obj"].clear()
859
+
860
+ def _reset_tracking_results(self, inference_state):
861
+ """Reset all tracking inputs and results across the videos."""
862
+ for v in inference_state["point_inputs_per_obj"].values():
863
+ v.clear()
864
+ for v in inference_state["mask_inputs_per_obj"].values():
865
+ v.clear()
866
+ for v in inference_state["output_dict_per_obj"].values():
867
+ v["cond_frame_outputs"].clear()
868
+ v["non_cond_frame_outputs"].clear()
869
+ for v in inference_state["temp_output_dict_per_obj"].values():
870
+ v["cond_frame_outputs"].clear()
871
+ v["non_cond_frame_outputs"].clear()
872
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
873
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
874
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
875
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
876
+ inference_state["tracking_has_started"] = False
877
+ inference_state["frames_already_tracked"].clear()
878
+
879
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
880
+ """Compute the image features on a given frame."""
881
+ # Look up in the cache first
882
+ image, backbone_out = inference_state["cached_features"].get(
883
+ frame_idx, (None, None)
884
+ )
885
+ if backbone_out is None:
886
+ # Cache miss -- we will run inference on a single image
887
+ device = inference_state["device"]
888
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
889
+ backbone_out = self.forward_image(image)
890
+ # Cache the most recent frame's feature (for repeated interactions with
891
+ # a frame; we can use an LRU cache for more frames in the future).
892
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
893
+
894
+ # expand the features to have the same dimension as the number of objects
895
+ expanded_image = image.expand(batch_size, -1, -1, -1)
896
+ expanded_backbone_out = {
897
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
898
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
899
+ }
900
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
901
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
902
+ batch_size, -1, -1, -1
903
+ )
904
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
905
+ pos = pos.expand(batch_size, -1, -1, -1)
906
+ expanded_backbone_out["vision_pos_enc"][i] = pos
907
+
908
+ features = self._prepare_backbone_features(expanded_backbone_out)
909
+ features = (expanded_image,) + features
910
+ return features
911
+
912
+ def _run_single_frame_inference(
913
+ self,
914
+ inference_state,
915
+ output_dict,
916
+ frame_idx,
917
+ batch_size,
918
+ is_init_cond_frame,
919
+ point_inputs,
920
+ mask_inputs,
921
+ reverse,
922
+ run_mem_encoder,
923
+ prev_sam_mask_logits=None,
924
+ ):
925
+ """Run tracking on a single frame based on current inputs and previous memory."""
926
+ # Retrieve correct image features
927
+ (
928
+ _,
929
+ _,
930
+ current_vision_feats,
931
+ current_vision_pos_embeds,
932
+ feat_sizes,
933
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
934
+
935
+ # point and mask should not appear as input simultaneously on the same frame
936
+ assert point_inputs is None or mask_inputs is None
937
+ current_out = self.track_step(
938
+ frame_idx=frame_idx,
939
+ is_init_cond_frame=is_init_cond_frame,
940
+ current_vision_feats=current_vision_feats,
941
+ current_vision_pos_embeds=current_vision_pos_embeds,
942
+ feat_sizes=feat_sizes,
943
+ point_inputs=point_inputs,
944
+ mask_inputs=mask_inputs,
945
+ output_dict=output_dict,
946
+ num_frames=inference_state["num_frames"],
947
+ track_in_reverse=reverse,
948
+ run_mem_encoder=run_mem_encoder,
949
+ prev_sam_mask_logits=prev_sam_mask_logits,
950
+ )
951
+
952
+ # optionally offload the output to CPU memory to save GPU space
953
+ storage_device = inference_state["storage_device"]
954
+ maskmem_features = current_out["maskmem_features"]
955
+ if maskmem_features is not None:
956
+ maskmem_features = maskmem_features.to(torch.bfloat16)
957
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
958
+ pred_masks_gpu = current_out["pred_masks"]
959
+ # potentially fill holes in the predicted masks
960
+ if self.fill_hole_area > 0:
961
+ pred_masks_gpu = fill_holes_in_mask_scores(
962
+ pred_masks_gpu, self.fill_hole_area
963
+ )
964
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
965
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
966
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
967
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
968
+ obj_ptr = current_out["obj_ptr"]
969
+ object_score_logits = current_out["object_score_logits"]
970
+ # make a compact version of this frame's output to reduce the state size
971
+ compact_current_out = {
972
+ "maskmem_features": maskmem_features,
973
+ "maskmem_pos_enc": maskmem_pos_enc,
974
+ "pred_masks": pred_masks,
975
+ "obj_ptr": obj_ptr,
976
+ "object_score_logits": object_score_logits,
977
+ }
978
+ return compact_current_out, pred_masks_gpu
979
+
980
+ def _run_memory_encoder(
981
+ self,
982
+ inference_state,
983
+ frame_idx,
984
+ batch_size,
985
+ high_res_masks,
986
+ object_score_logits,
987
+ is_mask_from_pts,
988
+ ):
989
+ """
990
+ Run the memory encoder on `high_res_masks`. This is usually after applying
991
+ non-overlapping constraints to object scores. Since their scores have changed, their
992
+ memory also needs to be recomputed with the memory encoder.
993
+ """
994
+ # Retrieve correct image features
995
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
996
+ inference_state, frame_idx, batch_size
997
+ )
998
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
999
+ current_vision_feats=current_vision_feats,
1000
+ feat_sizes=feat_sizes,
1001
+ pred_masks_high_res=high_res_masks,
1002
+ object_score_logits=object_score_logits,
1003
+ is_mask_from_pts=is_mask_from_pts,
1004
+ )
1005
+
1006
+ # optionally offload the output to CPU memory to save GPU space
1007
+ storage_device = inference_state["storage_device"]
1008
+ maskmem_features = maskmem_features.to(torch.bfloat16)
1009
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
1010
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1011
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
1012
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
1013
+ )
1014
+ return maskmem_features, maskmem_pos_enc
1015
+
1016
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
1017
+ """
1018
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
1019
+ a constant in the inference session to reduce session storage size.
1020
+ """
1021
+ model_constants = inference_state["constants"]
1022
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
1023
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
1024
+ if out_maskmem_pos_enc is not None:
1025
+ if "maskmem_pos_enc" not in model_constants:
1026
+ assert isinstance(out_maskmem_pos_enc, list)
1027
+ # only take the slice for one object, since it's same across objects
1028
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
1029
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
1030
+ else:
1031
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
1032
+ # expand the cached maskmem_pos_enc to the actual batch size
1033
+ batch_size = out_maskmem_pos_enc[0].size(0)
1034
+ expanded_maskmem_pos_enc = [
1035
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
1036
+ ]
1037
+ else:
1038
+ expanded_maskmem_pos_enc = None
1039
+ return expanded_maskmem_pos_enc
1040
+
1041
+ @torch.inference_mode()
1042
+ def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
1043
+ """
1044
+ Remove an object id from the tracking state. If strict is True, we check whether
1045
+ the object id actually exists and raise an error if it doesn't exist.
1046
+ """
1047
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
1048
+ updated_frames = []
1049
+ # Check whether this object_id to remove actually exists and possibly raise an error.
1050
+ if old_obj_idx_to_rm is None:
1051
+ if not strict:
1052
+ return inference_state["obj_ids"], updated_frames
1053
+ raise RuntimeError(
1054
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
1055
+ f"All existing object ids: {inference_state['obj_ids']}."
1056
+ )
1057
+
1058
+ # If this is the only remaining object id, we simply reset the state.
1059
+ if len(inference_state["obj_id_to_idx"]) == 1:
1060
+ self.reset_state(inference_state)
1061
+ return inference_state["obj_ids"], updated_frames
1062
+
1063
+ # There are still remaining objects after removing this object id. In this case,
1064
+ # we need to delete the object storage from inference state tensors.
1065
+ # Step 0: clear the input on those frames where this object id has point or mask input
1066
+ # (note that this step is required as it might downgrade conditioning frames to
1067
+ # non-conditioning ones)
1068
+ obj_input_frames_inds = set()
1069
+ obj_input_frames_inds.update(
1070
+ inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
1071
+ )
1072
+ obj_input_frames_inds.update(
1073
+ inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
1074
+ )
1075
+ for frame_idx in obj_input_frames_inds:
1076
+ self.clear_all_prompts_in_frame(
1077
+ inference_state, frame_idx, obj_id, need_output=False
1078
+ )
1079
+
1080
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
1081
+ # since Step 0 still requires the old object id mappings in inference_state)
1082
+ old_obj_ids = inference_state["obj_ids"]
1083
+ old_obj_inds = list(range(len(old_obj_ids)))
1084
+ remain_old_obj_inds = old_obj_inds.copy()
1085
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
1086
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
1087
+ new_obj_inds = list(range(len(new_obj_ids)))
1088
+ # build new mappings
1089
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
1090
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
1091
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
1092
+ inference_state["obj_ids"] = new_obj_ids
1093
+
1094
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
1095
+ # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
1096
+ # it's already handled in Step 0)
1097
+ def _map_keys(container):
1098
+ new_kvs = []
1099
+ for k in old_obj_inds:
1100
+ v = container.pop(k)
1101
+ if k in old_idx_to_new_idx:
1102
+ new_kvs.append((old_idx_to_new_idx[k], v))
1103
+ container.update(new_kvs)
1104
+
1105
+ _map_keys(inference_state["point_inputs_per_obj"])
1106
+ _map_keys(inference_state["mask_inputs_per_obj"])
1107
+ _map_keys(inference_state["output_dict_per_obj"])
1108
+ _map_keys(inference_state["temp_output_dict_per_obj"])
1109
+
1110
+ # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
1111
+ def _slice_state(output_dict, storage_key):
1112
+ for frame_idx, out in output_dict[storage_key].items():
1113
+ out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
1114
+ out["maskmem_pos_enc"] = [
1115
+ x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
1116
+ ]
1117
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1118
+ out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
1119
+ out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
1120
+ out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
1121
+ out["object_score_logits"] = out["object_score_logits"][
1122
+ remain_old_obj_inds
1123
+ ]
1124
+ # also update the per-object slices
1125
+ self._add_output_per_object(
1126
+ inference_state, frame_idx, out, storage_key
1127
+ )
1128
+
1129
+ _slice_state(inference_state["output_dict"], "cond_frame_outputs")
1130
+ _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
1131
+
1132
+ # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
1133
+ # could show an updated mask for objects previously occluded by the object being removed
1134
+ if need_output:
1135
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
1136
+ for frame_idx in obj_input_frames_inds:
1137
+ is_cond = any(
1138
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
1139
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
1140
+ )
1141
+ consolidated_out = self._consolidate_temp_output_across_obj(
1142
+ inference_state,
1143
+ frame_idx,
1144
+ is_cond=is_cond,
1145
+ run_mem_encoder=False,
1146
+ consolidate_at_video_res=True,
1147
+ )
1148
+ _, video_res_masks = self._get_orig_video_res_output(
1149
+ inference_state, consolidated_out["pred_masks_video_res"]
1150
+ )
1151
+ updated_frames.append((frame_idx, video_res_masks))
1152
+
1153
+ return inference_state["obj_ids"], updated_frames
1154
+
1155
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
1156
+ """
1157
+ Remove the non-conditioning memory around the input frame. When users provide
1158
+ correction clicks, the surrounding frames' non-conditioning memories can still
1159
+ contain outdated object appearance information and could confuse the model.
1160
+
1161
+ This method clears those non-conditioning memories surrounding the interacted
1162
+ frame to avoid giving the model both old and new information about the object.
1163
+ """
1164
+ r = self.memory_temporal_stride_for_eval
1165
+ frame_idx_begin = frame_idx - r * self.num_maskmem
1166
+ frame_idx_end = frame_idx + r * self.num_maskmem
1167
+ output_dict = inference_state["output_dict"]
1168
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
1169
+ for t in range(frame_idx_begin, frame_idx_end + 1):
1170
+ non_cond_frame_outputs.pop(t, None)
1171
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
1172
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
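The methods above make up the interactive video API of `SAM2VideoPredictor`: `add_new_mask` (and the earlier `add_new_points_or_box`) attach prompts to a frame, `propagate_in_video` consolidates the prompts and then tracks through the video, and `clear_all_prompts_in_frame` / `remove_object` / `reset_state` undo prompts, objects, or the whole session. As a rough orientation only, here is a minimal, hedged usage sketch; it assumes the predictor is built with `build_sam2_video_predictor` from `sam2/build_sam.py`, that `init_state` (defined earlier in this file) is given a directory of JPEG frames, and that the config name, checkpoint path, frame directory, and the prompt mask are placeholders rather than values taken from this repository's demos.

```python
# Minimal sketch (not one of this repo's demo scripts): prompt frame 0 of a video
# with a binary mask for one object, then propagate it through all frames.
import numpy as np

from sam2.build_sam import build_sam2_video_predictor

# Assumed config name and checkpoint path -- adjust to your local setup.
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt"
)

# init_state loads a folder of JPEG frames and prepares the inference_state dict.
state = predictor.init_state(video_path="./video_frames_dir")

# Attach a user-provided boolean mask (H, W) for object id 1 on frame 0.
# The placeholder below is all-background; a real prompt would mark the object.
first_frame_mask = np.zeros((720, 1280), dtype=bool)
_, obj_ids, masks = predictor.add_new_mask(
    inference_state=state, frame_idx=0, obj_id=1, mask=first_frame_mask
)

# Propagate across the whole video; each iteration yields masks at video resolution.
video_segments = {}
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    video_segments[frame_idx] = {
        obj_id: (masks[i] > 0.0).cpu().numpy() for i, obj_id in enumerate(obj_ids)
    }

predictor.reset_state(state)  # clear all prompts and results when done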
clone-IDEA-Research/Grounded-SAM-2/setup.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+
8
+ from setuptools import find_packages, setup
9
+
10
+ # Package metadata
11
+ NAME = "SAM-2"
12
+ VERSION = "1.0"
13
+ DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
14
+ URL = "https://github.com/facebookresearch/sam2"
15
+ AUTHOR = "Meta AI"
16
+ AUTHOR_EMAIL = "[email protected]"
17
+ LICENSE = "Apache 2.0"
18
+
19
+ # Read the contents of README file
20
+ with open("README.md", "r", encoding="utf-8") as f:
21
+ LONG_DESCRIPTION = f.read()
22
+
23
+ # Required dependencies
24
+ REQUIRED_PACKAGES = [
25
+ "torch>=2.3.1",
26
+ "torchvision>=0.18.1",
27
+ "numpy>=1.24.4",
28
+ "tqdm>=4.66.1",
29
+ "hydra-core>=1.3.2",
30
+ "iopath>=0.1.10",
31
+ "pillow>=9.4.0",
32
+ ]
33
+
34
+ EXTRA_PACKAGES = {
35
+ "notebooks": [
36
+ "matplotlib>=3.9.1",
37
+ "jupyter>=1.0.0",
38
+ "opencv-python>=4.7.0",
39
+ "eva-decord>=0.6.1",
40
+ ],
41
+ "interactive-demo": [
42
+ "Flask>=3.0.3",
43
+ "Flask-Cors>=5.0.0",
44
+ "av>=13.0.0",
45
+ "dataclasses-json>=0.6.7",
46
+ "eva-decord>=0.6.1",
47
+ "gunicorn>=23.0.0",
48
+ "imagesize>=1.4.1",
49
+ "pycocotools>=2.0.8",
50
+ "strawberry-graphql>=0.243.0",
51
+ ],
52
+ "dev": [
53
+ "black==24.2.0",
54
+ "usort==1.0.2",
55
+ "ufmt==2.0.0b2",
56
+ "fvcore>=0.1.5.post20221221",
57
+ "pandas>=2.2.2",
58
+ "scikit-image>=0.24.0",
59
+ "tensorboard>=2.17.0",
60
+ "pycocotools>=2.0.8",
61
+ "tensordict>=0.5.0",
62
+ "opencv-python>=4.7.0",
63
+ "submitit>=1.5.1",
64
+ ],
65
+ }
66
+
67
+ # By default, we also build the SAM 2 CUDA extension.
68
+ # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
69
+ BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
70
+ # By default, we allow SAM 2 installation to proceed even with build errors.
71
+ # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
72
+ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
73
+
74
+ # Catch and skip errors during extension building and print a warning message
75
+ # (note that this message only shows up under verbose build mode
76
+ # "pip install -v -e ." or "python setup.py build_ext -v")
77
+ CUDA_ERROR_MSG = (
78
+ "{}\n\n"
79
+ "Failed to build the SAM 2 CUDA extension due to the error above. "
80
+ "You can still use SAM 2 and it's OK to ignore the error above, although some "
81
+ "post-processing functionality may be limited (which doesn't affect the results in most cases; "
82
+ "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
83
+ )
84
+
85
+
86
+ def get_extensions():
87
+ if not BUILD_CUDA:
88
+ return []
89
+
90
+ try:
91
+ from torch.utils.cpp_extension import CUDAExtension
92
+
93
+ srcs = ["sam2/csrc/connected_components.cu"]
94
+ compile_args = {
95
+ "cxx": [],
96
+ "nvcc": [
97
+ "-DCUDA_HAS_FP16=1",
98
+ "-D__CUDA_NO_HALF_OPERATORS__",
99
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
100
+ "-D__CUDA_NO_HALF2_OPERATORS__",
101
+ ],
102
+ }
103
+ ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
104
+ except Exception as e:
105
+ if BUILD_ALLOW_ERRORS:
106
+ print(CUDA_ERROR_MSG.format(e))
107
+ ext_modules = []
108
+ else:
109
+ raise e
110
+
111
+ return ext_modules
112
+
113
+
114
+ try:
115
+ from torch.utils.cpp_extension import BuildExtension
116
+
117
+ class BuildExtensionIgnoreErrors(BuildExtension):
118
+
119
+ def finalize_options(self):
120
+ try:
121
+ super().finalize_options()
122
+ except Exception as e:
123
+ print(CUDA_ERROR_MSG.format(e))
124
+ self.extensions = []
125
+
126
+ def build_extensions(self):
127
+ try:
128
+ super().build_extensions()
129
+ except Exception as e:
130
+ print(CUDA_ERROR_MSG.format(e))
131
+ self.extensions = []
132
+
133
+ def get_ext_filename(self, ext_name):
134
+ try:
135
+ return super().get_ext_filename(ext_name)
136
+ except Exception as e:
137
+ print(CUDA_ERROR_MSG.format(e))
138
+ self.extensions = []
139
+ return "_C.so"
140
+
141
+ cmdclass = {
142
+ "build_ext": (
143
+ BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
144
+ if BUILD_ALLOW_ERRORS
145
+ else BuildExtension.with_options(no_python_abi_suffix=True)
146
+ )
147
+ }
148
+ except Exception as e:
149
+ cmdclass = {}
150
+ if BUILD_ALLOW_ERRORS:
151
+ print(CUDA_ERROR_MSG.format(e))
152
+ else:
153
+ raise e
154
+
155
+
156
+ # Setup configuration
157
+ setup(
158
+ name=NAME,
159
+ version=VERSION,
160
+ description=DESCRIPTION,
161
+ long_description=LONG_DESCRIPTION,
162
+ long_description_content_type="text/markdown",
163
+ url=URL,
164
+ author=AUTHOR,
165
+ author_email=AUTHOR_EMAIL,
166
+ license=LICENSE,
167
+ packages=find_packages(exclude="notebooks"),
168
+ include_package_data=True,
169
+ install_requires=REQUIRED_PACKAGES,
170
+ extras_require=EXTRA_PACKAGES,
171
+ python_requires=">=3.10.0",
172
+ ext_modules=get_extensions(),
173
+ cmdclass=cmdclass,
174
+ )
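Note that the build behaviour of this `setup.py` is driven by two environment variables read at install time: `SAM2_BUILD_CUDA` (default `"1"`) toggles building the optional `sam2._C` CUDA extension, and `SAM2_BUILD_ALLOW_ERRORS` (default `"1"`) decides whether a failed extension build aborts the install or is only warned about. Below is a hedged sketch of driving an editable install with these toggles from Python, equivalent to exporting the variables before `pip install -e .`; the checkout path is a placeholder.

```python
# Sketch: install the package in editable mode while skipping the optional
# CUDA post-processing extension (SAM2_BUILD_CUDA=0). Path is a placeholder.
import os
import subprocess
import sys

env = os.environ.copy()
env["SAM2_BUILD_CUDA"] = "0"          # skip building sam2._C entirely
env["SAM2_BUILD_ALLOW_ERRORS"] = "1"  # tolerate extension build errors (default)

subprocess.run(
    [sys.executable, "-m", "pip", "install", "-v", "-e", "."],
    cwd="/path/to/Grounded-SAM-2",  # placeholder checkout location
    env=env,
    check=True,
)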
clone-IDEA-Research/Grounded-Segment-Anything/.gitignore ADDED
@@ -0,0 +1,135 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # checkpoint
132
+ *.pth
133
+ outputs/
134
+
135
+ .idea/
clone-IDEA-Research/Grounded-Segment-Anything/.gitmodules ADDED
@@ -0,0 +1,7 @@
+
+ [submodule "grounded-sam-osx"]
+ path = grounded-sam-osx
+ url = https://github.com/linjing7/grounded-sam-osx.git
+ [submodule "VISAM"]
+ path = VISAM
+ url = https://github.com/BingfengYan/VISAM
clone-IDEA-Research/Grounded-Segment-Anything/CITATION.cff ADDED
@@ -0,0 +1,8 @@
+ cff-version: 1.2.0
+ message: "If you use this software, please cite it as below."
+ authors:
+ - name: "Grounded-SAM Contributors"
+ title: "Grounded-Segment-Anything"
+ date-released: 2023-04-06
+ url: "https://github.com/IDEA-Research/Grounded-Segment-Anything"
+ license: Apache-2.0
clone-IDEA-Research/Grounded-Segment-Anything/Dockerfile ADDED
@@ -0,0 +1,30 @@
+ FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
+
+ # Arguments to build Docker Image using CUDA
+ ARG USE_CUDA=0
+ ARG TORCH_ARCH=
+
+ ENV AM_I_DOCKER True
+ ENV BUILD_WITH_CUDA "${USE_CUDA}"
+ ENV TORCH_CUDA_ARCH_LIST "${TORCH_ARCH}"
+ ENV CUDA_HOME /usr/local/cuda-11.6/
+
+ RUN mkdir -p /home/appuser/Grounded-Segment-Anything
+ COPY . /home/appuser/Grounded-Segment-Anything/
+
+ RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
+ libsm6=2:* libxext6=2:* git=1:* nano=2.* \
+ vim=2:* -y \
+ && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /home/appuser/Grounded-Segment-Anything
+ RUN python -m pip install --no-cache-dir -e segment_anything
+
+ # When using build isolation, PyTorch with newer CUDA is installed and can't compile GroundingDINO
+ RUN python -m pip install --no-cache-dir wheel
+ RUN python -m pip install --no-cache-dir --no-build-isolation -e GroundingDINO
+
+ WORKDIR /home/appuser
+ RUN pip install --no-cache-dir diffusers[torch]==0.15.1 opencv-python==4.7.0.72 \
+ pycocotools==2.0.6 matplotlib==3.5.3 \
+ onnxruntime==1.14.1 onnx==1.13.1 ipykernel==6.16.2 scipy gradio openai
clone-IDEA-Research/Grounded-Segment-Anything/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 - present, IDEA, Inc
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
clone-IDEA-Research/Grounded-Segment-Anything/Makefile ADDED
@@ -0,0 +1,43 @@
1
+ # Get version of CUDA and enable it for compilation if CUDA > 11.0
2
+ # This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
3
+ # and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
4
+ # when running in Docker
5
+ # Check if nvcc is installed
6
+ NVCC := $(shell which nvcc)
7
+ ifeq ($(NVCC),)
8
+ # NVCC not found
9
+ USE_CUDA := 0
10
+ NVCC_VERSION := "not installed"
11
+ else
12
+ NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
13
+ USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
14
+ endif
15
+
16
+ # Add the list of supported ARCHs
17
+ ifeq ($(USE_CUDA), 1)
18
+ TORCH_CUDA_ARCH_LIST := "3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
19
+ BUILD_MESSAGE := "I will try to build the image with CUDA support"
20
+ else
21
+ TORCH_CUDA_ARCH_LIST :=
22
+ BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
23
+ endif
24
+
25
+
26
+ build-image:
27
+ @echo $(BUILD_MESSAGE)
28
+ docker build --build-arg USE_CUDA=$(USE_CUDA) \
29
+ --build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
30
+ -t gsa:v0 .
31
+ run:
32
+ ifeq (,$(wildcard ./sam_vit_h_4b8939.pth))
33
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
34
+ endif
35
+ ifeq (,$(wildcard ./groundingdino_swint_ogc.pth))
36
+ wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
37
+ endif
38
+ docker run --gpus all -it --rm --net=host --privileged \
39
+ -v /tmp/.X11-unix:/tmp/.X11-unix \
40
+ -v "${PWD}":/home/appuser/Grounded-Segment-Anything \
41
+ -e DISPLAY=$$DISPLAY \
42
+ --name=gsa \
43
+ --ipc=host gsa:v0
clone-IDEA-Research/Grounded-Segment-Anything/README.md ADDED
@@ -0,0 +1,808 @@
1
+ ![](./assets/Grounded-SAM_logo.png)
2
+
3
+ # Grounded-Segment-Anything
4
+ [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/oEQYStnF2l8) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/automated-dataset-annotation-and-evaluation-with-grounding-dino-and-sam.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/camenduru/grounded-segment-anything-colab) [![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/IDEA-Research/Grounded-SAM) [![Replicate](https://replicate.com/cjwbw/grounded-recognize-anything/badge)](https://replicate.com/cjwbw/grounded-recognize-anything) [![ModelScope Official Demo](https://img.shields.io/badge/ModelScope-Official%20Demo-important)](https://modelscope.cn/studios/tuofeilunhifi/Grounded-Segment-Anything/summary) [![Huggingface Demo by Community](https://img.shields.io/badge/Huggingface-Demo%20by%20Community-red)](https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything) [![Stable-Diffusion WebUI](https://img.shields.io/badge/Stable--Diffusion-WebUI%20by%20Community-critical)](https://github.com/continue-revolution/sd-webui-segment-anything) [![Jupyter Notebook Demo](https://img.shields.io/badge/Demo-Jupyter%20Notebook-informational)](./grounded_sam.ipynb) [![Static Badge](https://img.shields.io/badge/GroundingDINO-arXiv-blue)](https://arxiv.org/abs/2303.05499) [![Static Badge](https://img.shields.io/badge/Segment_Anything-arXiv-blue)](https://arxiv.org/abs/2304.02643) [![Static Badge](https://img.shields.io/badge/Grounded_SAM-arXiv-blue)](https://arxiv.org/abs/2401.14159)
5
+
6
+
7
+ This project combines [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment Anything](https://github.com/facebookresearch/segment-anything) into a very interesting demo which aims to detect and segment anything with text inputs! We will continue to improve it and create more interesting demos based on this foundation. We have also released an overall technical report about our project on arXiv; please check [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159) for more details.
8
+
9
+ - 🔥 **[Grounded SAM 2](https://github.com/IDEA-Research/Grounded-SAM-2)** is released now, which combines Grounding DINO with [SAM 2](https://github.com/facebookresearch/segment-anything-2) for any object tracking in open-world scenarios.
10
+ - 🔥 **[Grounding DINO 1.5](https://github.com/IDEA-Research/Grounding-DINO-1.5-API)** is released now, which is IDEA Research's **Most Capable** Open-World Object Detection Model!
11
+ - 🔥 **[Grounding DINO](https://arxiv.org/abs/2303.05499)** and **[Grounded SAM](https://arxiv.org/abs/2401.14159)** are now supported in Huggingface. For more convenient use, you can refer to [this documentation](https://huggingface.co/docs/transformers/model_doc/grounding-dino)
12
+
13
+ We are very willing to **help everyone share and promote new projects** based on Segment-Anything. Please check out [Highlight Extension Projects](#highlighted-projects) for more amazing demos and works from the community. You can submit a new issue (with the `project` tag) or a new pull request to add links to new projects.
14
+
15
+ ![](./assets/grounded_sam_new_demo_image.png)
16
+
17
+ ![](./assets/ram_grounded_sam_new.png)
18
+
19
+ **🍄 Why Build this Project?**
20
+
21
+ The **core idea** behind this project is to **combine the strengths of different models in order to build a very powerful pipeline for solving complex problems**. It is worth mentioning that this is a workflow for combining strong expert models, where **all parts can be used separately or in combination, and can be replaced with similar but different models (such as replacing Grounding DINO with GLIP or other detectors, replacing Stable-Diffusion with ControlNet or GLIGEN, or combining with ChatGPT)**.
22
+
23
+ **🍇 Updates**
24
+ - **`2024/01/26`** We have released a comprehensive technical report about our project on arXiv, please check [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159) for more details. And we are profoundly grateful for the contributions of all the contributors in this project.
25
+ - **`2023/12/17`** Support [Grounded-RepViT-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-repvit-sam-demo) demo, thanks a lot for their great work!
26
+ - **`2023/12/16`** Support [Grounded-Edge-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-edge-sam-demo) demo, thanks a lot for their great work!
27
+ - **`2023/12/10`** Support [Grounded-Efficient-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-efficient-sam-demo) demo, thanks a lot for their great work!
28
+ - **`2023/11/24`** Release [RAM++](https://arxiv.org/abs/2310.15200), which is the next generation of RAM. RAM++ can recognize any category with high accuracy, including both predefined common categories and diverse open-set categories.
29
+ - **`2023/11/23`** Release our newly proposed visual prompt counting model [T-Rex](https://github.com/IDEA-Research/T-Rex). The introduction [Video](https://www.youtube.com/watch?v=engIEhZogAQ) and [Demo](https://deepdataspace.com/playground/ivp) is available in [DDS](https://github.com/IDEA-Research/deepdataspace) now.
30
+ - **`2023/07/25`** Support [Light-HQ-SAM](https://github.com/SysCV/sam-hq) in [EfficientSAM](./EfficientSAM/), credits to [Mingqiao Ye](https://github.com/ymq2017) and [Lei Ke](https://github.com/lkeab), thanks a lot for their great work!
31
+ - **`2023/07/14`** Combining **Grounding-DINO-B** with [SAM-HQ](https://github.com/SysCV/sam-hq) achieves **49.6 mean AP** in [Segmentation in the Wild](https://eval.ai/web/challenges/challenge-page/1931/overview) competition zero-shot track, surpassing Grounded-SAM by **3.6 mean AP**, thanks for their great work!
32
+ - **`2023/06/28`** Combining Grounding-DINO with Efficient SAM variants including [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) and [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) in [EfficientSAM](./EfficientSAM/) for faster annotating, thanks a lot for their great work!
33
+ - **`2023/06/20`** By combining **Grounding-DINO-L** with **SAM-ViT-H**, Grounded-SAM achieves 46.0 mean AP in [Segmentation in the Wild](https://eval.ai/web/challenges/challenge-page/1931/overview) competition zero-shot track on [CVPR 2023 workshop](https://computer-vision-in-the-wild.github.io/cvpr-2023/), surpassing [UNINEXT (CVPR 2023)](https://github.com/MasterBin-IIAU/UNINEXT) by about **4 mean AP**.
34
+ - **`2023/06/16`** Release [RAM-Grounded-SAM Replicate Online Demo](https://replicate.com/cjwbw/ram-grounded-sam). Thanks a lot to [Chenxi](https://chenxwh.github.io/) for providing this nice demo 🌹.
35
+ - **`2023/06/14`** Support [RAM-Grounded-SAM & SAM-HQ](./automatic_label_ram_demo.py) and update [Simple Automatic Label Demo](./automatic_label_ram_demo.py) to support [RAM](https://github.com/OPPOMKLab/recognize-anything), setting up a strong automatic annotation pipeline.
36
+ - **`2023/06/13`** Checkout the [Autodistill: Train YOLOv8 with ZERO Annotations](https://youtu.be/gKTYMfwPo4M) tutorial to learn how to use Grounded-SAM + [Autodistill](https://github.com/autodistill/autodistill) for automated data labeling and real-time model training.
37
+ - **`2023/06/13`** Support [SAM-HQ](https://github.com/SysCV/sam-hq) in [Grounded-SAM Demo](#running_man-grounded-sam-detect-and-segment-everything-with-text-prompt) for higher quality prediction.
38
+ - **`2023/06/12`** Support [RAM-Grounded-SAM](#label-grounded-sam-with-ram-or-tag2text-for-automatic-labeling) for strong automatic labeling pipeline! Thanks for [Recognize-Anything](https://github.com/OPPOMKLab/recognize-anything).
39
+ - **`2023/06/01`** Our Grounded-SAM has been accepted to present a **demo** at [ICCV 2023](https://iccv2023.thecvf.com/)! See you in Paris!
40
+ - **`2023/05/23`**: Support `Image-Referring-Segment`, `Audio-Referring-Segment` and `Text-Referring-Segment` in [ImageBind-SAM](./playground/ImageBind_SAM/).
41
+ - **`2023/05/03`**: Check out the [Automated Dataset Annotation and Evaluation with GroundingDINO and SAM](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/automated-dataset-annotation-and-evaluation-with-grounding-dino-and-sam.ipynb), which is an amazing tutorial on automatic labeling! Thanks a lot to [Piotr Skalski](https://github.com/SkalskiP) and [Roboflow](https://github.com/roboflow/notebooks)!
42
+
43
+
44
+ ## Table of Contents
45
+ - [Grounded-Segment-Anything](#grounded-segment-anything)
46
+ - [Preliminary Works](#preliminary-works)
47
+ - [Highlighted Projects](#highlighted-projects)
48
+ - [Installation](#installation)
49
+ - [Install with Docker](#install-with-docker)
50
+ - [Install locally](#install-without-docker)
51
+ - [Grounded-SAM Playground](#grounded-sam-playground)
52
+ - [Step-by-Step Notebook Demo](#open_book-step-by-step-notebook-demo)
53
+ - [GroundingDINO: Detect Everything with Text Prompt](#running_man-groundingdino-detect-everything-with-text-prompt)
54
+ - [Grounded-SAM: Detect and Segment Everything with Text Prompt](#running_man-grounded-sam-detect-and-segment-everything-with-text-prompt)
55
+ - [Grounded-SAM with Inpainting: Detect, Segment and Generate Everything with Text Prompt](#skier-grounded-sam-with-inpainting-detect-segment-and-generate-everything-with-text-prompt)
56
+ - [Grounded-SAM and Inpaint Gradio APP](#golfing-grounded-sam-and-inpaint-gradio-app)
57
+ - [Grounded-SAM with RAM or Tag2Text for Automatic Labeling](#label-grounded-sam-with-ram-or-tag2text-for-automatic-labeling)
58
+ - [Grounded-SAM with BLIP & ChatGPT for Automatic Labeling](#robot-grounded-sam-with-blip-for-automatic-labeling)
59
+ - [Grounded-SAM with Whisper: Detect and Segment Anything with Audio](#open_mouth-grounded-sam-with-whisper-detect-and-segment-anything-with-audio)
60
+ - [Grounded-SAM ChatBot with Visual ChatGPT](#speech_balloon-grounded-sam-chatbot-demo)
61
+ - [Grounded-SAM with OSX for 3D Whole-Body Mesh Recovery](#man_dancing-run-grounded-segment-anything--osx-demo)
62
+ - [Grounded-SAM with VISAM for Tracking and Segment Anything](#man_dancing-run-grounded-segment-anything--visam-demo)
63
+ - [Interactive Fashion-Edit Playground: Click for Segmentation And Editing](#dancers-interactive-editing)
64
+ - [Interactive Human-face Editing Playground: Click And Editing Human Face](#dancers-interactive-editing)
65
+ - [3D Box Via Segment Anything](#camera-3d-box-via-segment-anything)
66
+ - [Playground: More Interesting and Imaginative Demos with Grounded-SAM](./playground/)
67
+ - [DeepFloyd: Image Generation with Text Prompt](./playground/DeepFloyd/)
68
+ - [PaintByExample: Exemplar-based Image Editing with Diffusion Models](./playground/PaintByExample/)
69
+ - [LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions](./playground/LaMa/)
70
+ - [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](./playground/RePaint/)
71
+ - [ImageBind with SAM: Segment with Different Modalities](./playground/ImageBind_SAM/)
72
+ - [Efficient SAM Series for Faster Annotation](./EfficientSAM/)
73
+ - [Grounded-FastSAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-fastsam-demo)
74
+ - [Grounded-MobileSAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-mobilesam-demo)
75
+ - [Grounded-Light-HQSAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-light-hqsam-demo)
76
+ - [Grounded-Efficient-SAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-efficient-sam-demo)
77
+ - [Grounded-Edge-SAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-edge-sam-demo)
78
+ - [Grounded-RepViT-SAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-repvit-sam-demo)
79
+ - [Citation](#citation)
80
+
81
+ ## Preliminary Works
82
+
83
+ Here we provide some background knowledge that you may need to know before trying the demos.
84
+
85
+ <div align="center">
86
+
87
+ | Title | Intro | Description | Links |
88
+ |:----:|:----:|:----:|:----:|
89
+ | [Segment-Anything](https://arxiv.org/abs/2304.02643) | ![](https://github.com/facebookresearch/segment-anything/blob/main/assets/model_diagram.png?raw=true) | A strong foundation model that aims to segment everything in an image and needs prompts (boxes/points/text) to generate masks | [[Github](https://github.com/facebookresearch/segment-anything)] <br> [[Page](https://segment-anything.com/)] <br> [[Demo](https://segment-anything.com/demo)] |
90
+ | [Grounding DINO](https://arxiv.org/abs/2303.05499) | ![](https://github.com/IDEA-Research/GroundingDINO/blob/main/.asset/hero_figure.png?raw=True) | A strong zero-shot detector capable of generating high-quality boxes and labels from free-form text. | [[Github](https://github.com/IDEA-Research/GroundingDINO)] <br> [[Demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)] |
91
+ | [OSX](http://arxiv.org/abs/2303.16160) | ![](https://github.com/IDEA-Research/OSX/blob/main/assets/demo_video.gif?raw=True) | A strong and efficient one-stage motion capture method that generates high-quality 3D human meshes from monocular images. OSX also releases a large-scale upper-body dataset, UBody, for more accurate reconstruction in upper-body scenes. | [[Github](https://github.com/IDEA-Research/OSX)] <br> [[Page](https://osx-ubody.github.io/)] <br> [[Video](https://osx-ubody.github.io/)] <br> [[Data](https://docs.google.com/forms/d/e/1FAIpQLSehgBP7wdn_XznGAM2AiJPiPLTqXXHw5uX9l7qeQ1Dh9HoO_A/viewform)] |
92
+ | [Stable-Diffusion](https://arxiv.org/abs/2112.10752) | ![](https://github.com/CompVis/stable-diffusion/blob/main/assets/stable-samples/txt2img/merged-0006.png?raw=True) | A super powerful open-source latent text-to-image diffusion model | [[Github](https://github.com/CompVis/stable-diffusion)] <br> [[Page](https://ommer-lab.com/research/latent-diffusion-models/)] |
93
+ | [RAM++](https://arxiv.org/abs/2310.15200) | ![](https://github.com/xinyu1205/recognize-anything/blob/main/images/ram_plus_compare.jpg) | RAM++ is the next generation of RAM, which can recognize any category with high accuracy. | [[Github](https://github.com/OPPOMKLab/recognize-anything)] |
94
+ | [RAM](https://recognize-anything.github.io/) | ![](https://github.com/xinyu1205/Tag2Text/raw/main/images/localization_and_recognition.jpg) | RAM is an image tagging model, which can recognize any common category with high accuracy. | [[Github](https://github.com/OPPOMKLab/recognize-anything)] <br> [[Demo](https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text)] |
95
+ | [BLIP](https://arxiv.org/abs/2201.12086) | ![](https://github.com/salesforce/LAVIS/raw/main/docs/_static/logo_final.png) | A wonderful language-vision model for image understanding. | [[GitHub](https://github.com/salesforce/LAVIS)] |
96
+ | [Visual ChatGPT](https://arxiv.org/abs/2303.04671) | ![](https://github.com/microsoft/TaskMatrix/raw/main/assets/figure.jpg) | A wonderful tool that connects ChatGPT and a series of Visual Foundation Models to enable sending and receiving images during chatting. | [[Github](https://github.com/microsoft/TaskMatrix)] <br> [[Demo](https://huggingface.co/spaces/microsoft/visual_chatgpt)] |
97
+ | [Tag2Text](https://tag2text.github.io/) | ![](https://github.com/xinyu1205/Tag2Text/raw/main/images/tag2text_framework.png) | An efficient and controllable vision-language model which can simultaneously output superior image captioning and image tagging. | [[Github](https://github.com/OPPOMKLab/recognize-anything)] <br> [[Demo](https://huggingface.co/spaces/xinyu1205/Tag2Text)] |
98
+ | [VoxelNeXt](https://arxiv.org/abs/2303.11301) | ![](https://github.com/dvlab-research/VoxelNeXt/raw/master/docs/sequence-v2.gif) | A clean, simple, and fully-sparse 3D object detector, which predicts objects directly upon sparse voxel features. | [[Github](https://github.com/dvlab-research/VoxelNeXt)]
99
+
100
+ </div>
101
+
102
+ ## Highlighted Projects
103
+
104
+ Here we provide some impressive works you may find interesting:
105
+
106
+ <div align="center">
107
+
108
+ | Title | Description | Links |
109
+ |:---:|:---:|:---:|
110
+ | [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM) | A universal image segmentation model to enable segment and recognize anything at any desired granularity | [[Github](https://github.com/UX-Decoder/Semantic-SAM)] <br> [[Demo](https://github.com/UX-Decoder/Semantic-SAM)] |
111
+ | [SEEM: Segment Everything Everywhere All at Once](https://arxiv.org/pdf/2304.06718.pdf) | A powerful promptable segmentation model supports segmenting with various types of prompts (text, point, scribble, referring image, etc.) and any combination of prompts. | [[Github](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)] <br> [[Demo](https://huggingface.co/spaces/xdecoder/SEEM)] |
112
+ | [OpenSeeD](https://arxiv.org/pdf/2303.08131.pdf) | A simple framework for open-vocabulary segmentation and detection which supports interactive segmentation with box input to generate mask | [[Github](https://github.com/IDEA-Research/OpenSeeD)] |
113
+ | [LLaVA](https://arxiv.org/abs/2304.08485) | Visual instruction tuning with GPT-4 | [[Github](https://github.com/haotian-liu/LLaVA)] <br> [[Page](https://llava-vl.github.io/)] <br> [[Demo](https://llava.hliu.cc/)] <br> [[Data](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K)] <br> [[Model](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)] |
114
+ | [GenSAM](https://arxiv.org/abs/2312.07374) | Relaxing the instance-specific manual prompt requirement in SAM through training-free test-time adaptation | [[Github](https://github.com/jyLin8100/GenSAM)] <br> [[Page](https://lwpyh.github.io/GenSAM/)] |
115
+
116
+ </div>
117
+
118
+ We also list some awesome segment-anything extension projects here you may find interesting:
119
+ - [Computer Vision in the Wild (CVinW) Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set tasks in computer vision.
120
+ - [Zero-Shot Anomaly Detection](https://github.com/caoyunkang/GroundedSAM-zero-shot-anomaly-detection) by Yunkang Cao
121
+ - [EditAnything: ControlNet + StableDiffusion based on the SAM segmentation mask](https://github.com/sail-sg/EditAnything) by Shanghua Gao and Pan Zhou
122
+ - [IEA: Image Editing Anything](https://github.com/feizc/IEA) by Zhengcong Fei
123
+ - [SAM-MMRotate: Combining Rotated Object Detector and SAM](https://github.com/Li-Qingyun/sam-mmrotate) by Qingyun Li and Xue Yang
124
+ - [Awesome-Anything](https://github.com/VainF/Awesome-Anything) by Gongfan Fang
125
+ - [Prompt-Segment-Anything](https://github.com/RockeyCoss/Prompt-Segment-Anything) by Rockey
126
+ - [WebUI for Segment-Anything and Grounded-SAM](https://github.com/continue-revolution/sd-webui-segment-anything) by Chengsong Zhang
127
+ - [Inpainting Anything: Inpaint Anything with SAM + Inpainting models](https://github.com/geekyutao/Inpaint-Anything) by Tao Yu
128
+ - [Grounded Segment Anything From Objects to Parts: Combining Segment-Anything with VLPart & GLIP & Visual ChatGPT](https://github.com/Cheems-Seminar/segment-anything-and-name-it) by Peize Sun and Shoufa Chen
129
+ - [Napari-SAM: Integration of Segment Anything into napari (a nice viewer for SAM)](https://github.com/MIC-DKFZ/napari-sam) by MIC-DKFZ
130
+ - [Grounded Segment Anything Colab](https://github.com/camenduru/grounded-segment-anything-colab) by camenduru
131
+ - [Optical Character Recognition with Segment Anything](https://github.com/yeungchenwa/OCR-SAM) by Zhenhua Yang
132
+ - [Transform Image into Unique Paragraph with ChatGPT, BLIP2, OFA, GRIT, Segment Anything, ControlNet](https://github.com/showlab/Image2Paragraph) by showlab
133
+ - [Lang-Segment-Anything: Another awesome demo for combining GroundingDINO with Segment-Anything](https://github.com/luca-medeiros/lang-segment-anything) by Luca Medeiros
134
+ - [🥳 🚀 **Playground: Integrate SAM and OpenMMLab!**](https://github.com/open-mmlab/playground)
135
+ - [3D-object via Segment Anything](https://github.com/dvlab-research/3D-Box-Segment-Anything) by Yukang Chen
136
+ - [Image2Paragraph: Transform Image Into Unique Paragraph](https://github.com/showlab/Image2Paragraph) by Show Lab
137
+ - [Zero-shot Scene Graph Generate with Grounded-SAM](https://github.com/showlab/Image2Paragraph) by JackWhite-rwx
138
+ - [CLIP Surgery for Better Explainability with Enhancement in Open-Vocabulary Tasks](https://github.com/xmed-lab/CLIP_Surgery) by Eli-YiLi
139
+ - [Panoptic-Segment-Anything: Zero-shot panoptic segmentation using SAM](https://github.com/segments-ai/panoptic-segment-anything) by segments-ai
140
+ - [Caption-Anything: Generates Descriptive Captions for Any Object within an Image](https://github.com/ttengwang/Caption-Anything) by Teng Wang
141
+ - [Segment-Anything-3D: Transferring Segmentation Information of 2D Images to 3D Space](https://github.com/Pointcept/SegmentAnything3D) by Yunhan Yang
142
+ - [Expediting SAM without Fine-tuning](https://github.com/Expedit-LargeScale-Vision-Transformer/Expedit-SAM) by Weicong Liang and Yuhui Yuan
143
+ - [Semantic Segment Anything: Providing Rich Semantic Category Annotations for SAM](https://github.com/fudan-zvg/Semantic-Segment-Anything) by Jiaqi Chen and Zeyu Yang and Li Zhang
144
+ - [Enhance Everything: Combining SAM with Image Restoration and Enhancement Tasks](https://github.com/lixinustc/Enhance-Anything) by Xin Li
145
+ - [DragGAN](https://github.com/Zeqiang-Lai/DragGAN) by Shanghai AI Lab.
146
+ - [Tabletop HandyBot: Robotic arm assistant that performs tabletop tasks using Grounded-SAM](https://github.com/ycheng517/tabletop-handybot) by Yifei Cheng
147
+
148
+ ## Installation
149
+ The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
150
+
151
+ ### Install with Docker
152
+
153
+ Open one terminal:
154
+
155
+ ```
156
+ make build-image
157
+ ```
158
+
159
+ ```
160
+ make run
161
+ ```
162
+
163
+ That's it.
164
+
165
+ If you would like to allow visualization across the Docker container, open another terminal and type:
166
+
167
+ ```
168
+ xhost +
169
+ ```
170
+
171
+
172
+ ### Install without Docker
173
+ You should set the environment variables manually as follows if you want to build a local GPU environment for Grounded-SAM:
174
+ ```bash
175
+ export AM_I_DOCKER=False
176
+ export BUILD_WITH_CUDA=True
177
+ export CUDA_HOME=/path/to/cuda-11.3/
178
+ ```
179
+
180
+ Install Segment Anything:
181
+
182
+ ```bash
183
+ python -m pip install -e segment_anything
184
+ ```
185
+
186
+ Install Grounding DINO:
187
+
188
+ ```bash
189
+ pip install --no-build-isolation -e GroundingDINO
190
+ ```
191
+
192
+
193
+ Install diffusers:
194
+
195
+ ```bash
196
+ pip install --upgrade diffusers[torch]
197
+ ```
198
+
199
+ Install osx:
200
+
201
+ ```bash
202
+ git submodule update --init --recursive
203
+ cd grounded-sam-osx && bash install.sh
204
+ ```
205
+
206
+ Install RAM & Tag2Text:
207
+
208
+ ```bash
209
+ git clone https://github.com/xinyu1205/recognize-anything.git
210
+ pip install -r ./recognize-anything/requirements.txt
211
+ pip install -e ./recognize-anything/
212
+ ```
213
+
214
+ The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks.
215
+
216
+ ```
217
+ pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel
218
+ ```
219
+
220
+ More details can be found in [install segment anything](https://github.com/facebookresearch/segment-anything#installation), [install GroundingDINO](https://github.com/IDEA-Research/GroundingDINO#install), and [install OSX](https://github.com/IDEA-Research/OSX).
221
+
222
+
223
+ ## Grounded-SAM Playground
224
+ Let's start exploring our Grounded-SAM Playground. We will release more interesting demos in the future, stay tuned!
225
+
226
+ ## :open_book: Step-by-Step Notebook Demo
227
+ Here we list some notebook demos provided in this project:
228
+ - [grounded_sam.ipynb](grounded_sam.ipynb)
229
+ - [grounded_sam_colab_demo.ipynb](grounded_sam_colab_demo.ipynb)
230
+ - [grounded_sam_3d_box.ipynb](grounded_sam_3d_box)
231
+
232
+
233
+ ### :running_man: GroundingDINO: Detect Everything with Text Prompt
234
+
235
+ :grapes: [[arXiv Paper](https://arxiv.org/abs/2303.05499)] &nbsp; :rose:[[Try the Colab Demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)] &nbsp; :sunflower: [[Try Huggingface Demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)] &nbsp; :mushroom: [[Automated Dataset Annotation and Evaluation](https://youtu.be/C4NqaRBz_Kw)]
236
+
237
+ Here's the step-by-step tutorial on running `GroundingDINO` demo:
238
+
239
+ **Step 1: Download the pretrained weights**
240
+
241
+ ```bash
242
+ cd Grounded-Segment-Anything
243
+
244
+ # download the pretrained groundingdino-swin-tiny model
245
+ wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
246
+ ```
247
+
248
+ **Step 2: Running the demo**
249
+
250
+ ```bash
251
+ python grounding_dino_demo.py
252
+ ```
253
+
254
+ <details>
255
+ <summary> <b> Running with Python (same as demo but you can run it anywhere after installing GroundingDINO) </b> </summary>
256
+
257
+ ```python
258
+ from groundingdino.util.inference import load_model, load_image, predict, annotate
259
+ import cv2
260
+
261
+ model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "./groundingdino_swint_ogc.pth")
262
+ IMAGE_PATH = "assets/demo1.jpg"
263
+ TEXT_PROMPT = "bear."
264
+ BOX_THRESHOLD = 0.35
265
+ TEXT_THRESHOLD = 0.25
266
+
267
+ image_source, image = load_image(IMAGE_PATH)
268
+
269
+ boxes, logits, phrases = predict(
270
+ model=model,
271
+ image=image,
272
+ caption=TEXT_PROMPT,
273
+ box_threshold=BOX_THRESHOLD,
274
+ text_threshold=TEXT_THRESHOLD
275
+ )
276
+
277
+ annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
278
+ cv2.imwrite("annotated_image.jpg", annotated_frame)
279
+ ```
280
+
281
+ </details>
282
+ <br>
283
+
284
+ **Tips**
285
+ - If you want to detect multiple objects in one sentence with [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO), we suggest separating each name with `.` . An example: `cat . dog . chair .`
286
+
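+ For example, the Python snippet above can be reused with a multi-object prompt; the following is a minimal sketch (it assumes the same config and checkpoint paths as the snippet above, and the output file name is arbitrary):
+
+ ```python
+ from groundingdino.util.inference import load_model, load_image, predict, annotate
+ import cv2
+
+ model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "./groundingdino_swint_ogc.pth")
+ image_source, image = load_image("assets/demo7.jpg")
+
+ # one prompt, several categories separated by "."
+ boxes, logits, phrases = predict(
+     model=model,
+     image=image,
+     caption="horse . clouds . grasses . sky . hill .",
+     box_threshold=0.35,
+     text_threshold=0.25,
+ )
+
+ annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
+ cv2.imwrite("multi_object_annotated_image.jpg", annotated_frame)
+ ```
+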
287
+ **Step 3: Check the annotated image**
288
+
289
+ The annotated image will be saved as `./annotated_image.jpg`.
290
+
291
+ <div align="center">
292
+
293
+ | Text Prompt | Demo Image | Annotated Image |
294
+ |:----:|:----:|:----:|
295
+ | `Bear.` | ![](./assets/demo1.jpg) | ![](./assets/annotated_image.jpg) |
296
+ | `Horse. Clouds. Grasses. Sky. Hill` | ![](./assets/demo7.jpg) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounding_dino/groundingdino_demo7.jpg?raw=true) |
297
+
298
+ </div>
299
+
300
+
301
+ ### :running_man: Grounded-SAM: Detect and Segment Everything with Text Prompt
302
+
303
+ Here's the step-by-step tutorial on running `Grounded-SAM` demo:
304
+
305
+ **Step 1: Download the pretrained weights**
306
+
307
+ ```bash
308
+ cd Grounded-Segment-Anything
309
+
310
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
311
+ wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
312
+ ```
313
+
314
+ We provide two versions of Grounded-SAM demo here:
315
+ - [grounded_sam_demo.py](./grounded_sam_demo.py): our original implementation for Grounded-SAM.
316
+ - [grounded_sam_simple_demo.py](./grounded_sam_simple_demo.py): our updated, more elegant version of Grounded-SAM.
317
+
318
+ **Step 2: Running original grounded-sam demo**
319
+ ```bash
320
+ # depends on your device
321
+ export CUDA_VISIBLE_DEVICES=0
322
+ ```
323
+
324
+ ```bash
325
+
326
+ python grounded_sam_demo.py \
327
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
328
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
329
+ --sam_checkpoint sam_vit_h_4b8939.pth \
330
+ --input_image assets/demo1.jpg \
331
+ --output_dir "outputs" \
332
+ --box_threshold 0.3 \
333
+ --text_threshold 0.25 \
334
+ --text_prompt "bear" \
335
+ --device "cuda"
336
+ ```
337
+
338
+ The annotated results will be saved in `./outputs` as follows
339
+
340
+ <div align="center">
341
+
342
+ | Input Image | Annotated Image | Generated Mask |
343
+ |:----:|:----:|:----:|
344
+ | ![](./assets/demo1.jpg) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounded_sam/original_grounded_sam_demo1.jpg?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounded_sam/mask.jpg?raw=true) |
345
+
346
+ </div>
347
+
348
+ **Step 3: Running grounded-sam demo with sam-hq**
349
+ - Download the demo image
350
+ ```bash
351
+ wget https://github.com/IDEA-Research/detrex-storage/releases/download/grounded-sam-storage/sam_hq_demo_image.png
352
+ ```
353
+
354
+ - Download SAM-HQ checkpoint [here](https://github.com/SysCV/sam-hq#model-checkpoints)
355
+
356
+ - Running grounded-sam-hq demo as follows:
357
+ ```bash
+ # --sam_hq_checkpoint: path to the SAM-HQ checkpoint, --use_sam_hq: switch to the SAM-HQ model
358
+ export CUDA_VISIBLE_DEVICES=0
359
+ python grounded_sam_demo.py \
360
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
361
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
362
+ --sam_hq_checkpoint ./sam_hq_vit_h.pth \
363
+ --use_sam_hq \
364
+ --input_image sam_hq_demo_image.png \
365
+ --output_dir "outputs" \
366
+ --box_threshold 0.3 \
367
+ --text_threshold 0.25 \
368
+ --text_prompt "chair." \
369
+ --device "cuda"
370
+ ```
371
+
372
+ The annotated results will be saved in `./outputs` as follows
373
+
374
+ <div align="center">
375
+
376
+ | Input Image | SAM Output | SAM-HQ Output |
377
+ |:----:|:----:|:----:|
378
+ | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/sam_hq/sam_hq_demo.png?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/sam_hq/sam_output.jpg?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/sam_hq/sam_hq_output.jpg?raw=true) |
379
+
380
+ </div>
381
+
382
+ **Step 4: Running the updated grounded-sam demo (optional)**
383
+
384
+ Note that this demo is almost the same as the original demo, but **with more elegant code**.
385
+
386
+ ```bash
387
+ python grounded_sam_simple_demo.py
388
+ ```
389
+
390
+ The annotated results will be saved as `./groundingdino_annotated_image.jpg` and `./grounded_sam_annotated_image.jpg`
391
+
392
+ <div align="center">
393
+
394
+ | Text Prompt | Input Image | GroundingDINO Annotated Image | Grounded-SAM Annotated Image |
395
+ |:----:|:----:|:----:|:----:|
396
+ | `The running dog` | ![](./assets/demo2.jpg) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounded_sam/groundingdino_annotated_image_demo2.jpg?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounded_sam/grounded_sam_annotated_image_demo2.jpg?raw=true) |
397
+ | `Horse. Clouds. Grasses. Sky. Hill` | ![](./assets/demo7.jpg) | ![](assets/groundingdino_annotated_image.jpg) | ![](assets/grounded_sam_annotated_image.jpg) |
398
+
399
+ </div>
400
+
401
+ **Step 5: Running the SAM model with multiple GPUs**
402
+ ```bash
403
+ export CUDA_VISIBLE_DEVICES=0,1
404
+ ```
405
+ ```bash
406
+
407
+ python grounded_sam_multi_gpu_demo.py \
408
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
409
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
410
+ --sam_checkpoint sam_vit_h_4b8939.pth \
411
+ --input_path assets/car \
412
+ --output_dir "outputs" \
413
+ --box_threshold 0.3 \
414
+ --text_threshold 0.25 \
415
+ --text_prompt "car" \
416
+ --device "cuda"
417
+ ```
418
+ You will see that the model is loaded once per GPU ![](assets/multi-gpu.png)
419
+
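+ One simple way to spread a folder of images across several GPUs is to launch one worker per device, load the models once in each worker, and split the file list; the sketch below illustrates that pattern with a hypothetical `run_grounded_sam(image_path, device)` helper (it is an illustration only, not the exact logic of `grounded_sam_multi_gpu_demo.py`):
+
+ ```python
+ import os
+ import torch
+ import torch.multiprocessing as mp
+
+ def run_grounded_sam(image_path: str, device: str) -> None:
+     # hypothetical placeholder: run the single-image Grounded-SAM pipeline on `device`
+     ...
+
+ def worker(rank: int, chunks: list) -> None:
+     device = f"cuda:{rank}"  # each process owns one GPU and loads the models once
+     for image_path in chunks[rank]:
+         run_grounded_sam(image_path, device)
+
+ if __name__ == "__main__":
+     images = sorted(os.path.join("assets/car", f) for f in os.listdir("assets/car"))
+     n_gpus = torch.cuda.device_count()
+     chunks = [images[i::n_gpus] for i in range(n_gpus)]  # round-robin split of the inputs
+     mp.spawn(worker, args=(chunks,), nprocs=n_gpus)
+ ```
+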
420
+ ### :skier: Grounded-SAM with Inpainting: Detect, Segment and Generate Everything with Text Prompt
421
+
422
+ **Step 1: Download the pretrained weights**
423
+
424
+ ```bash
425
+ cd Grounded-Segment-Anything
426
+
427
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
428
+ wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
429
+ ```
430
+
431
+ **Step 2: Running grounded-sam inpainting demo**
432
+
433
+ ```bash
434
+ export CUDA_VISIBLE_DEVICES=0
435
+ python grounded_sam_inpainting_demo.py \
436
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
437
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
438
+ --sam_checkpoint sam_vit_h_4b8939.pth \
439
+ --input_image assets/inpaint_demo.jpg \
440
+ --output_dir "outputs" \
441
+ --box_threshold 0.3 \
442
+ --text_threshold 0.25 \
443
+ --det_prompt "bench" \
444
+ --inpaint_prompt "A sofa, high quality, detailed" \
445
+ --device "cuda"
446
+ ```
447
+
448
+ The annotated and inpainted images will be saved in `./outputs`
449
+
450
+ **Step 3: Check the results**
451
+
452
+
453
+ <div align="center">
454
+
455
+ | Input Image | Det Prompt | Annotated Image | Inpaint Prompt | Inpaint Image |
456
+ |:---:|:---:|:---:|:---:|:---:|
457
+ |![](./assets/inpaint_demo.jpg) | `Bench` | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounded_sam_inpaint/grounded_sam_output.jpg?raw=true) | `A sofa, high quality, detailed` | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/grounded_sam_inpaint/grounded_sam_inpainting_output.jpg?raw=true) |
458
+
459
+ </div>
460
+
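+ The inpainting stage itself boils down to handing the SAM mask and the inpaint prompt to a Stable Diffusion inpainting pipeline. A minimal sketch of that stage is shown below (it assumes you already have the input image and a binary mask saved to disk; the model id, file paths and 512x512 resizing are illustrative):
+
+ ```python
+ import torch
+ from PIL import Image
+ from diffusers import StableDiffusionInpaintPipeline
+
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ).to("cuda")
+
+ image = Image.open("assets/inpaint_demo.jpg").convert("RGB").resize((512, 512))
+ mask = Image.open("outputs/mask.jpg").convert("L").resize((512, 512))  # white = region to repaint
+
+ result = pipe(prompt="A sofa, high quality, detailed", image=image, mask_image=mask).images[0]
+ result.save("outputs/grounded_sam_inpainting_sketch.jpg")
+ ```
+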
461
+ ### :golfing: Grounded-SAM and Inpaint Gradio APP
462
+
463
+ We support 6 tasks in the local Gradio APP:
464
+
465
+ 1. **scribble**: Segmentation is achieved through Segment Anything and mouse click interaction (you need to click on the object with the mouse, no need to specify the prompt).
466
+ 2. **automask**: Segment the entire image at once through Segment Anything (no need to specify a prompt).
467
+ 3. **det**: Detection via Grounding DINO and text interaction (a text prompt needs to be specified).
468
+ 4. **seg**: Detection + segmentation by combining Grounding DINO and Segment Anything through text interaction (a text prompt needs to be specified).
469
+ 5. **inpainting**: Replace the target object by combining Grounding DINO + Segment Anything + Stable Diffusion (a text prompt and an inpaint prompt need to be specified).
470
+ 6. **automatic**: Non-interactive detection + segmentation by combining BLIP + Grounding DINO + Segment Anything (no need to specify a prompt).
471
+
472
+ ```bash
473
+ python gradio_app.py
474
+ ```
475
+
476
+ - The gradio_app visualization is as follows:
477
+
478
+ ![](./assets/gradio_demo.png)
479
+
480
+
481
+ ### :label: Grounded-SAM with RAM or Tag2Text for Automatic Labeling
482
+ [**The Recognize Anything Models**](https://github.com/OPPOMKLab/recognize-anything) are a series of strong, open-source foundational image recognition models, including [RAM++](https://arxiv.org/abs/2310.15200), [RAM](https://arxiv.org/abs/2306.03514) and [Tag2text](https://arxiv.org/abs/2303.05657).
483
+
484
+
485
+ They can be seamlessly linked with Grounded-SAM to generate pseudo labels automatically as follows (see the sketch after this list):
486
+ 1. Use RAM/Tag2Text to generate tags.
487
+ 2. Use Grounded-Segment-Anything to generate the boxes and masks.
488
+
489
+
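+ Conceptually, the glue between the two steps is just turning the predicted tags into a Grounding DINO text prompt. A minimal sketch (the `generate_tags` helper is a hypothetical placeholder for RAM / Tag2Text inference and returns hard-coded tags here):
+
+ ```python
+ from groundingdino.util.inference import load_model, load_image, predict
+
+ def generate_tags(image_path: str) -> list:
+     # hypothetical placeholder for RAM / Tag2Text inference
+     return ["dog", "grass", "frisbee"]
+
+ model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "groundingdino_swint_ogc.pth")
+ image_source, image = load_image("assets/demo9.jpg")
+
+ caption = " . ".join(generate_tags("assets/demo9.jpg")) + " ."  # "dog . grass . frisbee ."
+ boxes, logits, phrases = predict(model=model, image=image, caption=caption,
+                                  box_threshold=0.25, text_threshold=0.2)
+ # the boxes can then be passed to SAM exactly as in the Grounded-SAM sketch above
+ ```
+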
490
+ **Step 1: Init submodule and download the pretrained checkpoint**
491
+
492
+ - Init submodule:
493
+
494
+ ```bash
495
+ cd Grounded-Segment-Anything
496
+ git submodule init
497
+ git submodule update
498
+ ```
499
+
500
+ - Download pretrained weights for `GroundingDINO`, `SAM` and `RAM/Tag2Text`:
501
+
502
+ ```bash
503
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
504
+ wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
505
+
506
+
507
+ wget https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/ram_swin_large_14m.pth
508
+ wget https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/tag2text_swin_14m.pth
509
+ ```
510
+
511
+ **Step 2: Running the demo with RAM**
512
+ ```bash
513
+ export CUDA_VISIBLE_DEVICES=0
514
+ python automatic_label_ram_demo.py \
515
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
516
+ --ram_checkpoint ram_swin_large_14m.pth \
517
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
518
+ --sam_checkpoint sam_vit_h_4b8939.pth \
519
+ --input_image assets/demo9.jpg \
520
+ --output_dir "outputs" \
521
+ --box_threshold 0.25 \
522
+ --text_threshold 0.2 \
523
+ --iou_threshold 0.5 \
524
+ --device "cuda"
525
+ ```
526
+
527
+
528
+ **Step 2 (alternative): Running the demo with Tag2Text**
529
+ ```bash
530
+ export CUDA_VISIBLE_DEVICES=0
531
+ python automatic_label_tag2text_demo.py \
532
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
533
+ --tag2text_checkpoint tag2text_swin_14m.pth \
534
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
535
+ --sam_checkpoint sam_vit_h_4b8939.pth \
536
+ --input_image assets/demo9.jpg \
537
+ --output_dir "outputs" \
538
+ --box_threshold 0.25 \
539
+ --text_threshold 0.2 \
540
+ --iou_threshold 0.5 \
541
+ --device "cuda"
542
+ ```
543
+
544
+ - RAM++ significantly improves the open-set capability of RAM; see [RAM++ inference on unseen categories](https://github.com/xinyu1205/recognize-anything#ram-inference-on-unseen-categories-open-set).
545
+ - Tag2Text also provides powerful captioning capabilities; for the caption-based process, refer to [BLIP](#robot-grounded-sam-with-blip-for-automatic-labeling).
546
+ - The pseudo labels and model prediction visualization will be saved in `output_dir` as follows (right figure):
547
+
548
+ ![](./assets/automatic_label_output/demo9_tag2text_ram.jpg)
549
+
550
+
551
+ ### :robot: Grounded-SAM with BLIP for Automatic Labeling
552
+ It is easy to generate pseudo labels automatically as follows:
553
+ 1. Use BLIP (or other caption models) to generate a caption.
554
+ 2. Extract tags from the caption. We use ChatGPT to handle the potential complicated sentences.
555
+ 3. Use Grounded-Segment-Anything to generate the boxes and masks.
556
+
557
+ - Run Demo
558
+ ```bash
559
+ export OPENAI_API_KEY=your_openai_key
560
+ export OPENAI_API_BASE=https://closeai.deno.dev/v1
561
+ export CUDA_VISIBLE_DEVICES=0
562
+ python automatic_label_demo.py \
563
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
564
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
565
+ --sam_checkpoint sam_vit_h_4b8939.pth \
566
+ --input_image assets/demo3.jpg \
567
+ --output_dir "outputs" \
568
+ --openai_key $OPENAI_API_KEY \
569
+ --box_threshold 0.25 \
570
+ --text_threshold 0.2 \
571
+ --iou_threshold 0.5 \
572
+ --device "cuda"
573
+ ```
574
+
575
+ - If you don't have a paid ChatGPT account, it is also possible to use NLTK instead. Just don't include the `openai_key` parameter when starting the demo (see the sketch at the end of this section).
576
+ - The script will automatically download the necessary NLTK data.
577
+ - The pseudo labels and model prediction visualization will be saved in `output_dir` as follows:
578
+
579
+ ![](./assets/automatic_label_output_demo3.jpg)
580
+
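+ When ChatGPT is not available, the tag-extraction step can be approximated with NLTK part-of-speech tagging, roughly as sketched below (a minimal illustration; the demo's actual filtering may differ):
+
+ ```python
+ import nltk
+
+ nltk.download("punkt", quiet=True)
+ nltk.download("averaged_perceptron_tagger", quiet=True)
+
+ caption = "a dog is running on the grass with a frisbee"
+ tokens = nltk.word_tokenize(caption)
+ nouns = [word for word, tag in nltk.pos_tag(tokens) if tag.startswith("NN")]  # keep nouns only
+ print(" . ".join(dict.fromkeys(nouns)) + " .")  # "dog . grass . frisbee ."
+ ```
+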
581
+
582
+ ### :open_mouth: Grounded-SAM with Whisper: Detect and Segment Anything with Audio
583
+ Detect and segment anything with speech!
584
+
585
+ ![](assets/acoustics/gsam_whisper_inpainting_demo.png)
586
+
587
+ **Install Whisper**
588
+ ```bash
589
+ pip install -U openai-whisper
590
+ ```
591
+ See the [whisper official page](https://github.com/openai/whisper#setup) if you have other questions about the installation.
592
+
593
+ **Run Voice-to-Label Demo**
594
+
595
+ Optional: Download the demo audio file
596
+
597
+ ```bash
598
+ wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/demo_audio.mp3
599
+ ```
600
+
601
+
602
+ ```bash
603
+ export CUDA_VISIBLE_DEVICES=0
604
+ python grounded_sam_whisper_demo.py \
605
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
606
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
607
+ --sam_checkpoint sam_vit_h_4b8939.pth \
608
+ --input_image assets/demo4.jpg \
609
+ --output_dir "outputs" \
610
+ --box_threshold 0.3 \
611
+ --text_threshold 0.25 \
612
+ --speech_file "demo_audio.mp3" \
613
+ --device "cuda"
614
+ ```
615
+
616
+ ![](./assets/grounded_sam_whisper_output.jpg)
617
+
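+ The speech-to-prompt step is a straightforward use of Whisper: the audio is transcribed and the resulting text is used as the Grounding DINO prompt. A minimal sketch (not the exact demo code; the audio file is the optional download above):
+
+ ```python
+ import whisper
+
+ model = whisper.load_model("base")
+ result = model.transcribe("demo_audio.mp3")
+ text_prompt = result["text"].strip().lower()
+ print(text_prompt)  # this string is then passed to Grounding DINO as the text prompt
+ ```
+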
618
+ **Run Voice-to-inpaint Demo**
619
+
620
+ You can enable ChatGPT to help you automatically determine the object to detect and the inpainting instruction with `--enable_chatgpt`.
621
+
622
+ Or you can specify the object you want to inpaint [stored in `args.det_speech_file`] and the text you want to inpaint with [stored in `args.inpaint_speech_file`].
623
+
624
+ ```bash
625
+ export OPENAI_API_KEY=your_openai_key
626
+ export OPENAI_API_BASE=https://closeai.deno.dev/v1
627
+ # Example: enable chatgpt
628
+ export CUDA_VISIBLE_DEVICES=0
629
+ python grounded_sam_whisper_inpainting_demo.py \
630
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
631
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
632
+ --sam_checkpoint sam_vit_h_4b8939.pth \
633
+ --input_image assets/inpaint_demo.jpg \
634
+ --output_dir "outputs" \
635
+ --box_threshold 0.3 \
636
+ --text_threshold 0.25 \
637
+ --prompt_speech_file assets/acoustics/prompt_speech_file.mp3 \
638
+ --enable_chatgpt \
639
+ --openai_key $OPENAI_API_KEY\
640
+ --device "cuda"
641
+ ```
642
+
643
+ ```bash
644
+ # Example: without chatgpt
645
+ export CUDA_VISIBLE_DEVICES=0
646
+ python grounded_sam_whisper_inpainting_demo.py \
647
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
648
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
649
+ --sam_checkpoint sam_vit_h_4b8939.pth \
650
+ --input_image assets/inpaint_demo.jpg \
651
+ --output_dir "outputs" \
652
+ --box_threshold 0.3 \
653
+ --text_threshold 0.25 \
654
+ --det_speech_file "assets/acoustics/det_voice.mp3" \
655
+ --inpaint_speech_file "assets/acoustics/inpaint_voice.mp3" \
656
+ --device "cuda"
657
+ ```
658
+
659
+ ![](./assets/acoustics/gsam_whisper_inpainting_pipeline.png)
660
+
661
+ ### :speech_balloon: Grounded-SAM ChatBot Demo
662
+
663
+ https://user-images.githubusercontent.com/24236723/231955561-2ae4ec1a-c75f-4cc5-9b7b-517aa1432123.mp4
664
+
665
+ Following [Visual ChatGPT](https://github.com/microsoft/visual-chatgpt), we add a ChatBot for our project. Currently, it supports:
666
+ 1. "Describe the image."
667
+ 2. "Detect the dog (and the cat) in the image."
668
+ 3. "Segment anything in the image."
669
+ 4. "Segment the dog (and the cat) in the image."
670
+ 5. "Help me label the image."
671
+ 6. "Replace the dog with a cat in the image."
672
+
673
+ To use the ChatBot:
674
+ - Install whisper if you want to use audio as input.
675
+ - Set the default model setting in the tool `Grounded_dino_sam_inpainting`.
676
+ - Run Demo
677
+ ```bash
678
+ export OPENAI_API_KEY=your_openai_key
679
+ export OPENAI_API_BASE=https://closeai.deno.dev/v1
680
+ export CUDA_VISIBLE_DEVICES=0
681
+ python chatbot.py
682
+ ```
683
+
684
+ ### :man_dancing: Run Grounded-Segment-Anything + OSX Demo
685
+
686
+ <p align="middle">
687
+ <img src="assets/osx/grouned_sam_osx_demo.gif">
688
+ <br>
689
+ </p>
690
+
691
+
692
+ - Download the checkpoint `osx_l_wo_decoder.pth.tar` from [here](https://drive.google.com/drive/folders/1x7MZbB6eAlrq5PKC9MaeIm4GqkBpokow?usp=share_link) for OSX:
693
+ - Download the human model files and place them into `grounded-sam-osx/utils/human_model_files` following the instructions of [OSX](https://github.com/IDEA-Research/OSX).
694
+
695
+ - Run Demo
696
+
697
+ ```shell
698
+ export CUDA_VISIBLE_DEVICES=0
699
+ python grounded_sam_osx_demo.py \
700
+ --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
701
+ --grounded_checkpoint groundingdino_swint_ogc.pth \
702
+ --sam_checkpoint sam_vit_h_4b8939.pth \
703
+ --osx_checkpoint osx_l_wo_decoder.pth.tar \
704
+ --input_image assets/osx/grounded_sam_osx_demo.png \
705
+ --output_dir "outputs" \
706
+ --box_threshold 0.3 \
707
+ --text_threshold 0.25 \
708
+ --text_prompt "humans, chairs" \
709
+ --device "cuda"
710
+ ```
711
+
712
+ - The model prediction visualization will be saved in `output_dir` as follows:
713
+
714
+ <img src="assets/osx/grounded_sam_osx_output.jpg" style="zoom: 49%;" />
715
+
716
+ - We also support promptable 3D whole-body mesh recovery. For example, you can track someone with a text prompt and estimate their 3D pose and shape:
717
+
718
+ | ![space-1.jpg](assets/osx/grounded_sam_osx_output1.jpg) |
719
+ | :---------------------------------------------------: |
720
+ | *A person with pink clothes* |
721
+
722
+ | ![space-1.jpg](assets/osx/grounded_sam_osx_output2.jpg) |
723
+ | :---------------------------------------------------: |
724
+ | *A man with sunglasses* |
725
+
726
+
727
+ ## :man_dancing: Run Grounded-Segment-Anything + VISAM Demo
728
+
729
+ - Download the checkpoint `motrv2_dancetrack.pth` from [here](https://drive.google.com/file/d/1EA4lndu2yQcVgBKR09KfMe5efbf631Th/view?usp=share_link) for MOTRv2:
730
+ - Refer to the MOTRv2 and VISAM documentation if you have other questions about the installation.
731
+
732
+ - Run Demo
733
+
734
+ ```shell
735
+ export CUDA_VISIBLE_DEVICES=0
736
+ python grounded_sam_visam.py \
737
+ --meta_arch motr \
738
+ --dataset_file e2e_dance \
739
+ --with_box_refine \
740
+ --query_interaction_layer QIMv2 \
741
+ --num_queries 10 \
742
+ --det_db det_db_motrv2.json \
743
+ --use_checkpoint \
744
+ --mot_path your_data_path \
745
+ --resume motrv2_dancetrack.pth \
746
+ --sam_checkpoint sam_vit_h_4b8939.pth \
747
+ --video_path DanceTrack/test/dancetrack0003
748
+ ```
749
+ |![](https://raw.githubusercontent.com/BingfengYan/MOTSAM/main/visam.gif)|
750
+
751
+
752
+ ### :dancers: Interactive Editing
753
+ - Released the interactive fashion-edit playground [here](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/humanFace). Run it in the notebook and just click to annotate points for further segmentation. Enjoy it!
754
+
755
+
756
+ - Released the human-face-edit branch [here](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/humanFace). We'll keep updating this branch with more interesting features. Here are some examples:
757
+
758
+ ![](https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/humanFace/assets/231-hair-edit.png)
759
+
760
+ ## :camera: 3D-Box via Segment Anything
761
+ We extend the scope to the 3D world by combining Segment Anything and [VoxelNeXt](https://github.com/dvlab-research/VoxelNeXt). When we provide a prompt (e.g., a point / box), the result is not only a 2D segmentation mask but also 3D boxes. Please check [voxelnext_3d_box](./voxelnext_3d_box/) for more details.
762
+ ![](https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/voxelnext_3d_box/images/sam-voxelnext.png)
763
+ ![](https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/voxelnext_3d_box/images/image_boxes2.png)
764
+
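+ The 2D half of this pipeline is a standard point/box-prompted SAM call; the 3D boxes are then produced by VoxelNeXt as described in [voxelnext_3d_box](./voxelnext_3d_box/). A minimal sketch of the 2D prompting is shown below (checkpoint and image paths follow the earlier demos; the click coordinates are illustrative):
+
+ ```python
+ import cv2
+ import numpy as np
+ from segment_anything import sam_model_registry, SamPredictor
+
+ sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to("cuda")
+ predictor = SamPredictor(sam)
+
+ image = cv2.cvtColor(cv2.imread("assets/demo1.jpg"), cv2.COLOR_BGR2RGB)
+ predictor.set_image(image)
+
+ # a single foreground click (label 1) as the prompt
+ masks, scores, _ = predictor.predict(
+     point_coords=np.array([[500, 375]]),
+     point_labels=np.array([1]),
+     multimask_output=True,
+ )
+ best_mask = masks[np.argmax(scores)]  # keep the highest-scoring candidate mask
+ ```
+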
765
+
766
+
767
+
768
+ ## :cupid: Acknowledgements
769
+
770
+ - [Segment Anything](https://github.com/facebookresearch/segment-anything)
771
+ - [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)
772
+
773
+
774
+ ## Contributors
775
+
776
+ Our project wouldn't be possible without the contributions of these amazing people! Thank you all for making this project better.
777
+
778
+ <a href="https://github.com/IDEA-Research/Grounded-Segment-Anything/graphs/contributors">
779
+ <img src="https://contrib.rocks/image?repo=IDEA-Research/Grounded-Segment-Anything" />
780
+ </a>
781
+
782
+
783
+ ## Citation
784
+ If you find this project helpful for your research, please consider citing the following BibTeX entries.
785
+ ```BibTex
786
+ @article{kirillov2023segany,
787
+ title={Segment Anything},
788
+ author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
789
+ journal={arXiv:2304.02643},
790
+ year={2023}
791
+ }
792
+
793
+ @article{liu2023grounding,
794
+ title={Grounding dino: Marrying dino with grounded pre-training for open-set object detection},
795
+ author={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},
796
+ journal={arXiv preprint arXiv:2303.05499},
797
+ year={2023}
798
+ }
799
+
800
+ @misc{ren2024grounded,
801
+ title={Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks},
802
+ author={Tianhe Ren and Shilong Liu and Ailing Zeng and Jing Lin and Kunchang Li and He Cao and Jiayu Chen and Xinyu Huang and Yukang Chen and Feng Yan and Zhaoyang Zeng and Hao Zhang and Feng Li and Jie Yang and Hongyang Li and Qing Jiang and Lei Zhang},
803
+ year={2024},
804
+ eprint={2401.14159},
805
+ archivePrefix={arXiv},
806
+ primaryClass={cs.CV}
807
+ }
808
+ ```