Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- clone-IDEA-Research/Grounded-SAM-2/.clang-format +85 -0
- clone-IDEA-Research/Grounded-SAM-2/.gitignore +147 -0
- clone-IDEA-Research/Grounded-SAM-2/.watchmanconfig +1 -0
- clone-IDEA-Research/Grounded-SAM-2/CODE_OF_CONDUCT.md +80 -0
- clone-IDEA-Research/Grounded-SAM-2/CONTRIBUTING.md +31 -0
- clone-IDEA-Research/Grounded-SAM-2/Dockerfile +37 -0
- clone-IDEA-Research/Grounded-SAM-2/INSTALL.md +189 -0
- clone-IDEA-Research/Grounded-SAM-2/LICENSE +201 -0
- clone-IDEA-Research/Grounded-SAM-2/LICENSE_cctorch +29 -0
- clone-IDEA-Research/Grounded-SAM-2/LICENSE_groundingdino +201 -0
- clone-IDEA-Research/Grounded-SAM-2/LICENSE_sam2 +201 -0
- clone-IDEA-Research/Grounded-SAM-2/MANIFEST.in +7 -0
- clone-IDEA-Research/Grounded-SAM-2/Makefile +37 -0
- clone-IDEA-Research/Grounded-SAM-2/README.md +484 -0
- clone-IDEA-Research/Grounded-SAM-2/SAM2_README.md +140 -0
- clone-IDEA-Research/Grounded-SAM-2/backend.Dockerfile +64 -0
- clone-IDEA-Research/Grounded-SAM-2/docker-compose.yaml +42 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_dinox_demo.py +245 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_autolabel_pipeline.py +198 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_image_demo.py +657 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_gd1.5_demo.py +249 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_hf_model_demo.py +187 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_local_demo.py +160 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo.py +198 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_dinox.py +237 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py +214 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py +220 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py +239 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py +203 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py +224 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_plus.py +247 -0
- clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_gd1.5.py +221 -0
- clone-IDEA-Research/Grounded-SAM-2/pyproject.toml +6 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/__init__.py +11 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/automatic_mask_generator.py +454 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/build_sam.py +167 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_b+.yaml +113 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_l.yaml +117 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_s.yaml +116 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_t.yaml +118 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_image_predictor.py +466 -0
- clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_video_predictor.py +1172 -0
- clone-IDEA-Research/Grounded-SAM-2/setup.py +174 -0
- clone-IDEA-Research/Grounded-Segment-Anything/.gitignore +135 -0
- clone-IDEA-Research/Grounded-Segment-Anything/.gitmodules +7 -0
- clone-IDEA-Research/Grounded-Segment-Anything/CITATION.cff +8 -0
- clone-IDEA-Research/Grounded-Segment-Anything/Dockerfile +30 -0
- clone-IDEA-Research/Grounded-Segment-Anything/LICENSE +201 -0
- clone-IDEA-Research/Grounded-Segment-Anything/Makefile +43 -0
- clone-IDEA-Research/Grounded-Segment-Anything/README.md +808 -0
clone-IDEA-Research/Grounded-SAM-2/.clang-format
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AccessModifierOffset: -1
|
2 |
+
AlignAfterOpenBracket: AlwaysBreak
|
3 |
+
AlignConsecutiveAssignments: false
|
4 |
+
AlignConsecutiveDeclarations: false
|
5 |
+
AlignEscapedNewlinesLeft: true
|
6 |
+
AlignOperands: false
|
7 |
+
AlignTrailingComments: false
|
8 |
+
AllowAllParametersOfDeclarationOnNextLine: false
|
9 |
+
AllowShortBlocksOnASingleLine: false
|
10 |
+
AllowShortCaseLabelsOnASingleLine: false
|
11 |
+
AllowShortFunctionsOnASingleLine: Empty
|
12 |
+
AllowShortIfStatementsOnASingleLine: false
|
13 |
+
AllowShortLoopsOnASingleLine: false
|
14 |
+
AlwaysBreakAfterReturnType: None
|
15 |
+
AlwaysBreakBeforeMultilineStrings: true
|
16 |
+
AlwaysBreakTemplateDeclarations: true
|
17 |
+
BinPackArguments: false
|
18 |
+
BinPackParameters: false
|
19 |
+
BraceWrapping:
|
20 |
+
AfterClass: false
|
21 |
+
AfterControlStatement: false
|
22 |
+
AfterEnum: false
|
23 |
+
AfterFunction: false
|
24 |
+
AfterNamespace: false
|
25 |
+
AfterObjCDeclaration: false
|
26 |
+
AfterStruct: false
|
27 |
+
AfterUnion: false
|
28 |
+
BeforeCatch: false
|
29 |
+
BeforeElse: false
|
30 |
+
IndentBraces: false
|
31 |
+
BreakBeforeBinaryOperators: None
|
32 |
+
BreakBeforeBraces: Attach
|
33 |
+
BreakBeforeTernaryOperators: true
|
34 |
+
BreakConstructorInitializersBeforeComma: false
|
35 |
+
BreakAfterJavaFieldAnnotations: false
|
36 |
+
BreakStringLiterals: false
|
37 |
+
ColumnLimit: 80
|
38 |
+
CommentPragmas: '^ IWYU pragma:'
|
39 |
+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
40 |
+
ConstructorInitializerIndentWidth: 4
|
41 |
+
ContinuationIndentWidth: 4
|
42 |
+
Cpp11BracedListStyle: true
|
43 |
+
DerivePointerAlignment: false
|
44 |
+
DisableFormat: false
|
45 |
+
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
|
46 |
+
IncludeCategories:
|
47 |
+
- Regex: '^<.*\.h(pp)?>'
|
48 |
+
Priority: 1
|
49 |
+
- Regex: '^<.*'
|
50 |
+
Priority: 2
|
51 |
+
- Regex: '.*'
|
52 |
+
Priority: 3
|
53 |
+
IndentCaseLabels: true
|
54 |
+
IndentWidth: 2
|
55 |
+
IndentWrappedFunctionNames: false
|
56 |
+
KeepEmptyLinesAtTheStartOfBlocks: false
|
57 |
+
MacroBlockBegin: ''
|
58 |
+
MacroBlockEnd: ''
|
59 |
+
MaxEmptyLinesToKeep: 1
|
60 |
+
NamespaceIndentation: None
|
61 |
+
ObjCBlockIndentWidth: 2
|
62 |
+
ObjCSpaceAfterProperty: false
|
63 |
+
ObjCSpaceBeforeProtocolList: false
|
64 |
+
PenaltyBreakBeforeFirstCallParameter: 1
|
65 |
+
PenaltyBreakComment: 300
|
66 |
+
PenaltyBreakFirstLessLess: 120
|
67 |
+
PenaltyBreakString: 1000
|
68 |
+
PenaltyExcessCharacter: 1000000
|
69 |
+
PenaltyReturnTypeOnItsOwnLine: 200
|
70 |
+
PointerAlignment: Left
|
71 |
+
ReflowComments: true
|
72 |
+
SortIncludes: true
|
73 |
+
SpaceAfterCStyleCast: false
|
74 |
+
SpaceBeforeAssignmentOperators: true
|
75 |
+
SpaceBeforeParens: ControlStatements
|
76 |
+
SpaceInEmptyParentheses: false
|
77 |
+
SpacesBeforeTrailingComments: 1
|
78 |
+
SpacesInAngles: false
|
79 |
+
SpacesInContainerLiterals: true
|
80 |
+
SpacesInCStyleCastParentheses: false
|
81 |
+
SpacesInParentheses: false
|
82 |
+
SpacesInSquareBrackets: false
|
83 |
+
Standard: Cpp11
|
84 |
+
TabWidth: 8
|
85 |
+
UseTab: Never
|
clone-IDEA-Research/Grounded-SAM-2/.gitignore
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SAM 2
|
2 |
+
.vscode/
|
3 |
+
.DS_Store
|
4 |
+
__pycache__/
|
5 |
+
*-checkpoint.ipynb
|
6 |
+
.venv
|
7 |
+
*.egg*
|
8 |
+
build/*
|
9 |
+
_C.*
|
10 |
+
outputs/*
|
11 |
+
checkpoints/*.pt
|
12 |
+
*test*
|
13 |
+
# Byte-compiled / optimized / DLL files
|
14 |
+
__pycache__/
|
15 |
+
*.py[cod]
|
16 |
+
*$py.class
|
17 |
+
|
18 |
+
# C extensions
|
19 |
+
*.so
|
20 |
+
|
21 |
+
# Distribution / packaging
|
22 |
+
.Python
|
23 |
+
build/
|
24 |
+
develop-eggs/
|
25 |
+
dist/
|
26 |
+
downloads/
|
27 |
+
eggs/
|
28 |
+
.eggs/
|
29 |
+
lib/
|
30 |
+
lib64/
|
31 |
+
parts/
|
32 |
+
sdist/
|
33 |
+
var/
|
34 |
+
wheels/
|
35 |
+
pip-wheel-metadata/
|
36 |
+
share/python-wheels/
|
37 |
+
*.egg-info/
|
38 |
+
.installed.cfg
|
39 |
+
*.egg
|
40 |
+
MANIFEST
|
41 |
+
|
42 |
+
# PyInstaller
|
43 |
+
# Usually these files are written by a python script from a template
|
44 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
45 |
+
*.manifest
|
46 |
+
*.spec
|
47 |
+
|
48 |
+
# Installer logs
|
49 |
+
pip-log.txt
|
50 |
+
pip-delete-this-directory.txt
|
51 |
+
|
52 |
+
# Unit test / coverage reports
|
53 |
+
htmlcov/
|
54 |
+
.tox/
|
55 |
+
.nox/
|
56 |
+
.coverage
|
57 |
+
.coverage.*
|
58 |
+
.cache
|
59 |
+
nosetests.xml
|
60 |
+
coverage.xml
|
61 |
+
*.cover
|
62 |
+
*.py,cover
|
63 |
+
.hypothesis/
|
64 |
+
.pytest_cache/
|
65 |
+
|
66 |
+
# Translations
|
67 |
+
*.mo
|
68 |
+
*.pot
|
69 |
+
|
70 |
+
# Django stuff:
|
71 |
+
*.log
|
72 |
+
local_settings.py
|
73 |
+
db.sqlite3
|
74 |
+
db.sqlite3-journal
|
75 |
+
|
76 |
+
# Flask stuff:
|
77 |
+
instance/
|
78 |
+
.webassets-cache
|
79 |
+
|
80 |
+
# Scrapy stuff:
|
81 |
+
.scrapy
|
82 |
+
|
83 |
+
# Sphinx documentation
|
84 |
+
docs/_build/
|
85 |
+
|
86 |
+
# PyBuilder
|
87 |
+
target/
|
88 |
+
|
89 |
+
# Jupyter Notebook
|
90 |
+
.ipynb_checkpoints
|
91 |
+
|
92 |
+
# IPython
|
93 |
+
profile_default/
|
94 |
+
ipython_config.py
|
95 |
+
|
96 |
+
# pyenv
|
97 |
+
.python-version
|
98 |
+
|
99 |
+
# pipenv
|
100 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
101 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
102 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
103 |
+
# install all needed dependencies.
|
104 |
+
#Pipfile.lock
|
105 |
+
|
106 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
107 |
+
__pypackages__/
|
108 |
+
|
109 |
+
# Celery stuff
|
110 |
+
celerybeat-schedule
|
111 |
+
celerybeat.pid
|
112 |
+
|
113 |
+
# SageMath parsed files
|
114 |
+
*.sage.py
|
115 |
+
|
116 |
+
# Environments
|
117 |
+
.env
|
118 |
+
.venv
|
119 |
+
env/
|
120 |
+
venv/
|
121 |
+
ENV/
|
122 |
+
env.bak/
|
123 |
+
venv.bak/
|
124 |
+
|
125 |
+
# Spyder project settings
|
126 |
+
.spyderproject
|
127 |
+
.spyproject
|
128 |
+
|
129 |
+
# Rope project settings
|
130 |
+
.ropeproject
|
131 |
+
|
132 |
+
# mkdocs documentation
|
133 |
+
/site
|
134 |
+
|
135 |
+
# mypy
|
136 |
+
.mypy_cache/
|
137 |
+
.dmypy.json
|
138 |
+
dmypy.json
|
139 |
+
|
140 |
+
# Pyre type checker
|
141 |
+
.pyre/
|
142 |
+
|
143 |
+
# checkpoint
|
144 |
+
*.pth
|
145 |
+
outputs/
|
146 |
+
|
147 |
+
.idea/
|
clone-IDEA-Research/Grounded-SAM-2/.watchmanconfig
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
clone-IDEA-Research/Grounded-SAM-2/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
In the interest of fostering an open and welcoming environment, we as
|
6 |
+
contributors and maintainers pledge to make participation in our project and
|
7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
+
level of experience, education, socio-economic status, nationality, personal
|
10 |
+
appearance, race, religion, or sexual identity and orientation.
|
11 |
+
|
12 |
+
## Our Standards
|
13 |
+
|
14 |
+
Examples of behavior that contributes to creating a positive environment
|
15 |
+
include:
|
16 |
+
|
17 |
+
* Using welcoming and inclusive language
|
18 |
+
* Being respectful of differing viewpoints and experiences
|
19 |
+
* Gracefully accepting constructive criticism
|
20 |
+
* Focusing on what is best for the community
|
21 |
+
* Showing empathy towards other community members
|
22 |
+
|
23 |
+
Examples of unacceptable behavior by participants include:
|
24 |
+
|
25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
+
advances
|
27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
+
* Public or private harassment
|
29 |
+
* Publishing others' private information, such as a physical or electronic
|
30 |
+
address, without explicit permission
|
31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
+
professional setting
|
33 |
+
|
34 |
+
## Our Responsibilities
|
35 |
+
|
36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
38 |
+
response to any instances of unacceptable behavior.
|
39 |
+
|
40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
+
threatening, offensive, or harmful.
|
45 |
+
|
46 |
+
## Scope
|
47 |
+
|
48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
+
an individual is representing the project or its community in public spaces.
|
50 |
+
Examples of representing a project or community include using an official
|
51 |
+
project e-mail address, posting via an official social media account, or acting
|
52 |
+
as an appointed representative at an online or offline event. Representation of
|
53 |
+
a project may be further defined and clarified by project maintainers.
|
54 |
+
|
55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
57 |
+
the project or its community.
|
58 |
+
|
59 |
+
## Enforcement
|
60 |
+
|
61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
+
reported by contacting the project team at <[email protected]>. All
|
63 |
+
complaints will be reviewed and investigated and will result in a response that
|
64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
66 |
+
Further details of specific enforcement policies may be posted separately.
|
67 |
+
|
68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
69 |
+
faith may face temporary or permanent repercussions as determined by other
|
70 |
+
members of the project's leadership.
|
71 |
+
|
72 |
+
## Attribution
|
73 |
+
|
74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
76 |
+
|
77 |
+
[homepage]: https://www.contributor-covenant.org
|
78 |
+
|
79 |
+
For answers to common questions about this code of conduct, see
|
80 |
+
https://www.contributor-covenant.org/faq
|
clone-IDEA-Research/Grounded-SAM-2/CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to segment-anything
|
2 |
+
We want to make contributing to this project as easy and transparent as
|
3 |
+
possible.
|
4 |
+
|
5 |
+
## Pull Requests
|
6 |
+
We actively welcome your pull requests.
|
7 |
+
|
8 |
+
1. Fork the repo and create your branch from `main`.
|
9 |
+
2. If you've added code that should be tested, add tests.
|
10 |
+
3. If you've changed APIs, update the documentation.
|
11 |
+
4. Ensure the test suite passes.
|
12 |
+
5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`.
|
13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
14 |
+
|
15 |
+
## Contributor License Agreement ("CLA")
|
16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
17 |
+
to do this once to work on any of Facebook's open source projects.
|
18 |
+
|
19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
20 |
+
|
21 |
+
## Issues
|
22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
24 |
+
|
25 |
+
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
26 |
+
disclosure of security bugs. In those cases, please go through the process
|
27 |
+
outlined on that page and do not file a public issue.
|
28 |
+
|
29 |
+
## License
|
30 |
+
By contributing to segment-anything, you agree that your contributions will be licensed
|
31 |
+
under the LICENSE file in the root directory of this source tree.
|
clone-IDEA-Research/Grounded-SAM-2/Dockerfile
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel
|
2 |
+
|
3 |
+
# Arguments to build Docker Image using CUDA
|
4 |
+
ARG USE_CUDA=0
|
5 |
+
ARG TORCH_ARCH="7.0;7.5;8.0;8.6"
|
6 |
+
|
7 |
+
ENV AM_I_DOCKER=True
|
8 |
+
ENV BUILD_WITH_CUDA="${USE_CUDA}"
|
9 |
+
ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
|
10 |
+
ENV CUDA_HOME=/usr/local/cuda-12.1/
|
11 |
+
# Ensure CUDA is correctly set up
|
12 |
+
ENV PATH=/usr/local/cuda-12.1/bin:${PATH}
|
13 |
+
ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}
|
14 |
+
|
15 |
+
# Install required packages and specific gcc/g++
|
16 |
+
RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
|
17 |
+
libsm6=2:* libxext6=2:* git=1:* nano vim=2:* ninja-build gcc-10 g++-10 -y \
|
18 |
+
&& apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
|
19 |
+
|
20 |
+
ENV CC=gcc-10
|
21 |
+
ENV CXX=g++-10
|
22 |
+
|
23 |
+
RUN mkdir -p /home/appuser/Grounded-SAM-2
|
24 |
+
COPY . /home/appuser/Grounded-SAM-2/
|
25 |
+
|
26 |
+
WORKDIR /home/appuser/Grounded-SAM-2
|
27 |
+
|
28 |
+
|
29 |
+
# Install essential Python packages
|
30 |
+
RUN python -m pip install --upgrade pip setuptools wheel numpy \
|
31 |
+
opencv-python transformers supervision pycocotools addict yapf timm
|
32 |
+
|
33 |
+
# Install segment_anything package in editable mode
|
34 |
+
RUN python -m pip install -e .
|
35 |
+
|
36 |
+
# Install grounding dino
|
37 |
+
RUN python -m pip install --no-build-isolation -e grounding_dino
|
clone-IDEA-Research/Grounded-SAM-2/INSTALL.md
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Installation
|
2 |
+
|
3 |
+
### Requirements
|
4 |
+
|
5 |
+
- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
|
6 |
+
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
|
7 |
+
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
|
8 |
+
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
9 |
+
|
10 |
+
Then, install SAM 2 from the root of this repository via
|
11 |
+
```bash
|
12 |
+
pip install -e ".[notebooks]"
|
13 |
+
```
|
14 |
+
|
15 |
+
Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
|
16 |
+
```bash
|
17 |
+
# skip the SAM 2 CUDA extension
|
18 |
+
SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]"
|
19 |
+
```
|
20 |
+
This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
|
21 |
+
|
22 |
+
### Building the SAM 2 CUDA extension
|
23 |
+
|
24 |
+
By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)
|
25 |
+
|
26 |
+
If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.
|
27 |
+
|
28 |
+
If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows
|
29 |
+
```bash
|
30 |
+
pip uninstall -y SAM-2 && \
|
31 |
+
rm -f ./sam2/*.so && \
|
32 |
+
SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]"
|
33 |
+
```
|
34 |
+
|
35 |
+
Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.
|
36 |
+
|
37 |
+
Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.
|
38 |
+
|
39 |
+
### Common Installation Issues
|
40 |
+
|
41 |
+
Click each issue for its solutions:
|
42 |
+
|
43 |
+
<details>
|
44 |
+
<summary>
|
45 |
+
I got `ImportError: cannot import name '_C' from 'sam2'`
|
46 |
+
</summary>
|
47 |
+
<br/>
|
48 |
+
|
49 |
+
This is usually because you haven't run the `pip install -e ".[notebooks]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
|
50 |
+
|
51 |
+
In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/sam2/issues/77.
|
52 |
+
</details>
|
53 |
+
|
54 |
+
<details>
|
55 |
+
<summary>
|
56 |
+
I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'`
|
57 |
+
</summary>
|
58 |
+
<br/>
|
59 |
+
|
60 |
+
This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via
|
61 |
+
```bash
|
62 |
+
export SAM2_REPO_ROOT=/path/to/sam2 # path to this repo
|
63 |
+
export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
|
64 |
+
```
|
65 |
+
to manually add `sam2_configs` into your Python's `sys.path`.
|
66 |
+
|
67 |
+
</details>
|
68 |
+
|
69 |
+
<details>
|
70 |
+
<summary>
|
71 |
+
I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints
|
72 |
+
</summary>
|
73 |
+
<br/>
|
74 |
+
|
75 |
+
This is likely because you have installed a previous version of this repo, which doesn't have the new modules to support the SAM 2.1 checkpoints yet. Please try the following steps:
|
76 |
+
|
77 |
+
1. pull the latest code from the `main` branch of this repo
|
78 |
+
2. run `pip uninstall -y SAM-2` to uninstall any previous installations
|
79 |
+
3. then install the latest repo again using `pip install -e ".[notebooks]"`
|
80 |
+
|
81 |
+
In case the steps above still don't resolve the error, please try running in your Python environment the following
|
82 |
+
```python
|
83 |
+
from sam2.modeling import sam2_base
|
84 |
+
|
85 |
+
print(sam2_base.__file__)
|
86 |
+
```
|
87 |
+
and check whether the content in the printed local path of `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`) to indentify if you're still using a previous installation.
|
88 |
+
|
89 |
+
</details>
|
90 |
+
|
91 |
+
<details>
|
92 |
+
<summary>
|
93 |
+
My installation failed with `CUDA_HOME environment variable is not set`
|
94 |
+
</summary>
|
95 |
+
<br/>
|
96 |
+
|
97 |
+
This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
|
98 |
+
```
|
99 |
+
export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path
|
100 |
+
```
|
101 |
+
and rerun the installation.
|
102 |
+
|
103 |
+
Also, you should make sure
|
104 |
+
```
|
105 |
+
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
|
106 |
+
```
|
107 |
+
print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up.
|
108 |
+
|
109 |
+
If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command:
|
110 |
+
```
|
111 |
+
pip install --no-build-isolation -e .
|
112 |
+
```
|
113 |
+
|
114 |
+
</details>
|
115 |
+
|
116 |
+
<details>
|
117 |
+
<summary>
|
118 |
+
I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
|
119 |
+
</summary>
|
120 |
+
<br/>
|
121 |
+
|
122 |
+
This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
|
123 |
+
|
124 |
+
In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
|
125 |
+
|
126 |
+
We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
|
127 |
+
</details>
|
128 |
+
|
129 |
+
<details>
|
130 |
+
<summary>
|
131 |
+
I got `CUDA error: no kernel image is available for execution on the device`
|
132 |
+
</summary>
|
133 |
+
<br/>
|
134 |
+
|
135 |
+
A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. in a slurm system).
|
136 |
+
|
137 |
+
You can try pulling the latest code from the SAM 2 repo and running the following
|
138 |
+
```
|
139 |
+
export TORCH_CUDA_ARCH_LIST=9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0`
|
140 |
+
```
|
141 |
+
to manually specify the CUDA capability in the compilation target that matches your GPU.
|
142 |
+
</details>
|
143 |
+
|
144 |
+
<details>
|
145 |
+
<summary>
|
146 |
+
I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
|
147 |
+
</summary>
|
148 |
+
<br/>
|
149 |
+
|
150 |
+
This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line
|
151 |
+
```python
|
152 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
153 |
+
```
|
154 |
+
in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
|
155 |
+
```python
|
156 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
|
157 |
+
```
|
158 |
+
to relax the attention kernel setting and use other kernels than Flash Attention.
|
159 |
+
</details>
|
160 |
+
|
161 |
+
<details>
|
162 |
+
<summary>
|
163 |
+
I got `Error compiling objects for extension`
|
164 |
+
</summary>
|
165 |
+
<br/>
|
166 |
+
|
167 |
+
You may see error log of:
|
168 |
+
> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.
|
169 |
+
|
170 |
+
This is probably because your versions of CUDA and Visual Studio are incompatible. (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion in stackoverflow).<br>
|
171 |
+
You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py). <br>
|
172 |
+
After adding the argument, `get_extension()` will look like this:
|
173 |
+
```python
|
174 |
+
def get_extensions():
|
175 |
+
srcs = ["sam2/csrc/connected_components.cu"]
|
176 |
+
compile_args = {
|
177 |
+
"cxx": [],
|
178 |
+
"nvcc": [
|
179 |
+
"-DCUDA_HAS_FP16=1",
|
180 |
+
"-D__CUDA_NO_HALF_OPERATORS__",
|
181 |
+
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
182 |
+
"-D__CUDA_NO_HALF2_OPERATORS__",
|
183 |
+
"-allow-unsupported-compiler" # Add this argument
|
184 |
+
],
|
185 |
+
}
|
186 |
+
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
187 |
+
return ext_modules
|
188 |
+
```
|
189 |
+
</details>
|
clone-IDEA-Research/Grounded-SAM-2/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 - present, IDEA Research.
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
clone-IDEA-Research/Grounded-SAM-2/LICENSE_cctorch
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file.
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without
|
7 |
+
modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
this list of conditions and the following disclaimer in the documentation
|
14 |
+
and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
3. Neither the name of the copyright holder nor the names of its
|
17 |
+
contributors may be used to endorse or promote products derived from
|
18 |
+
this software without specific prior written permission.
|
19 |
+
|
20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
clone-IDEA-Research/Grounded-SAM-2/LICENSE_groundingdino
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 - present, IDEA Research.
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
clone-IDEA-Research/Grounded-SAM-2/LICENSE_sam2
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
clone-IDEA-Research/Grounded-SAM-2/MANIFEST.in
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
recursive-include sam2 *.yaml #include all config files
|
clone-IDEA-Research/Grounded-SAM-2/Makefile
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Get version of CUDA and enable it for compilation if CUDA > 11.0
|
2 |
+
# This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
|
3 |
+
# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
|
4 |
+
# when running in Docker
|
5 |
+
# Check if nvcc is installed
|
6 |
+
NVCC := $(shell which nvcc)
|
7 |
+
ifeq ($(NVCC),)
|
8 |
+
# NVCC not found
|
9 |
+
USE_CUDA := 0
|
10 |
+
NVCC_VERSION := "not installed"
|
11 |
+
else
|
12 |
+
NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
|
13 |
+
USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
|
14 |
+
endif
|
15 |
+
|
16 |
+
# Add the list of supported ARCHs
|
17 |
+
ifeq ($(USE_CUDA), 1)
|
18 |
+
TORCH_CUDA_ARCH_LIST := "7.0;7.5;8.0;8.6+PTX"
|
19 |
+
BUILD_MESSAGE := "I will try to build the image with CUDA support"
|
20 |
+
else
|
21 |
+
TORCH_CUDA_ARCH_LIST :=
|
22 |
+
BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
|
23 |
+
endif
|
24 |
+
|
25 |
+
|
26 |
+
build-image:
|
27 |
+
@echo $(BUILD_MESSAGE)
|
28 |
+
docker build --build-arg USE_CUDA=$(USE_CUDA) \
|
29 |
+
--build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
|
30 |
+
-t grounded_sam2:1.0 .
|
31 |
+
run:
|
32 |
+
docker run --gpus all -it --rm --net=host --privileged \
|
33 |
+
-v /tmp/.X11-unix:/tmp/.X11-unix \
|
34 |
+
-v "${PWD}":/home/appuser/Grounded-SAM-2 \
|
35 |
+
-e DISPLAY=$DISPLAY \
|
36 |
+
--name=gsa \
|
37 |
+
--ipc=host -it grounded_sam2:1.0
|
clone-IDEA-Research/Grounded-SAM-2/README.md
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Grounded SAM 2: Ground and Track Anything in Videos
|
2 |
+
|
3 |
+
**[IDEA-Research](https://github.com/idea-research)**
|
4 |
+
|
5 |
+
[Tianhe Ren](https://rentainhe.github.io/), [Shuo Shen](https://github.com/ShuoShenDe)
|
6 |
+
|
7 |
+
[[`SAM 2 Paper`](https://arxiv.org/abs/2408.00714)] [[`Grounding DINO Paper`](https://arxiv.org/abs/2303.05499)] [[`Grounding DINO 1.5 Paper`](https://arxiv.org/abs/2405.10300)] [[`DINO-X Paper`](https://arxiv.org/abs/2411.14347)] [[`BibTeX`](#citation)]
|
8 |
+
|
9 |
+
[](https://github.com/user-attachments/assets/f0fb0022-779a-49fb-8f46-3a18a8b4e893)
|
10 |
+
|
11 |
+
## Highlights
|
12 |
+
|
13 |
+
Grounded SAM 2 is a foundation model pipeline towards grounding and track anything in Videos with [Grounding DINO](https://arxiv.org/abs/2303.05499), [Grounding DINO 1.5](https://arxiv.org/abs/2405.10300), [Florence-2](https://arxiv.org/abs/2311.06242), [DINO-X](https://arxiv.org/abs/2411.14347) and [SAM 2](https://arxiv.org/abs/2408.00714).
|
14 |
+
|
15 |
+
In this repo, we've supported the following demo with **simple implementations**:
|
16 |
+
- **Ground and Segment Anything** with Grounding DINO, Grounding DINO 1.5 & 1.6, DINO-X and SAM 2
|
17 |
+
- **Ground and Track Anything** with Grounding DINO, Grounding DINO 1.5 & 1.6, DINO-X and SAM 2
|
18 |
+
- **Detect, Segment and Track Visualization** based on the powerful [supervision](https://github.com/roboflow/supervision) library.
|
19 |
+
|
20 |
+
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
|
21 |
+
|
22 |
+
## Latest updates
|
23 |
+
|
24 |
+
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
|
25 |
+
|
26 |
+
- **2024.10.24**: Support [SAHI (Slicing Aided Hyper Inference)](https://docs.ultralytics.com/guides/sahi-tiled-inference/) on Grounded SAM 2 (with Grounding DINO 1.5) which may be helpful for inferencing high resolution image with dense small objects (e.g. **4K** images).
|
27 |
+
|
28 |
+
- **2024.10.10**: Support `SAM-2.1` models, if you want to use `SAM 2.1` model, you need to update to the latest code and reinstall SAM 2 follow [SAM 2.1 Installation](https://github.com/facebookresearch/sam2?tab=readme-ov-file#latest-updates).
|
29 |
+
|
30 |
+
- **2024.08.31**: Support `dump json results` in Grounded SAM 2 Image Demos (with Grounding DINO).
|
31 |
+
|
32 |
+
- **2024.08.20**: Support **Florence-2 SAM 2 Image Demo** which includes `dense region caption`, `object detection`, `phrase grounding`, and cascaded auto-label pipeline `caption + phrase grounding`.
|
33 |
+
|
34 |
+
- **2024.08.09**: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
|
35 |
+
|
36 |
+
- **2024.08.07**: Support **Custom Video Inputs**, users need only submit their video file (e.g. `.mp4` file) with specific text prompts to get an impressive demo videos.
|
37 |
+
|
38 |
+
## Contents
|
39 |
+
- [Installation](#installation)
|
40 |
+
- [Grounded SAM 2 Demos](#grounded-sam-2-demos)
|
41 |
+
- [Grounded SAM 2 Image Demo](#grounded-sam-2-image-demo-with-grounding-dino)
|
42 |
+
- [Grounded SAM 2 Image Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-image-demo-with-grounding-dino-15--16)
|
43 |
+
- [Grounded SAM 2 Image Demo (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x)
|
44 |
+
- [Grounded SAM 2 with SAHI for High Resolution Image Inference](#sahi-slicing-aided-hyper-inference-with-grounding-dino-15-and-sam-2)
|
45 |
+
- [Automatically Saving Grounding and Segmentation Results](#automatically-saving-grounding-results-image-demo)
|
46 |
+
- [Grounded SAM 2 Video Object Tracking Demo](#grounded-sam-2-video-object-tracking-demo)
|
47 |
+
- [Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-grounding-dino-15--16)
|
48 |
+
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino)
|
49 |
+
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16)
|
50 |
+
- [Grounded SAM 2 Video Object Tracking Demo (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x)
|
51 |
+
- [Grounded SAM 2 Video Object Tracking with Continues ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino)
|
52 |
+
- [Grounded SAM 2 Florence-2 Demos](#grounded-sam-2-florence-2-demos)
|
53 |
+
- [Grounded SAM 2 Florence-2 Image Demo](#grounded-sam-2-florence-2-image-demo)
|
54 |
+
- [Grounded SAM 2 Florence-2 Image Auto-Labeling Demo](#grounded-sam-2-florence-2-image-auto-labeling-demo)
|
55 |
+
- [Citation](#citation)
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
## Installation
|
60 |
+
|
61 |
+
Download the pretrained `SAM 2` checkpoints:
|
62 |
+
|
63 |
+
```bash
|
64 |
+
cd checkpoints
|
65 |
+
bash download_ckpts.sh
|
66 |
+
```
|
67 |
+
|
68 |
+
Download the pretrained `Grounding DINO` checkpoints:
|
69 |
+
|
70 |
+
```bash
|
71 |
+
cd gdino_checkpoints
|
72 |
+
bash download_ckpts.sh
|
73 |
+
```
|
74 |
+
|
75 |
+
### Installation without docker
|
76 |
+
|
77 |
+
Install PyTorch environment first. We use `python=3.10`, as well as `torch >= 2.3.1`, `torchvision>=0.18.1` and `cuda-12.1` in our environment to run this demo. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. You can easily install the latest version of PyTorch as follows:
|
78 |
+
|
79 |
+
```bash
|
80 |
+
pip3 install torch torchvision torchaudio
|
81 |
+
```
|
82 |
+
|
83 |
+
Since we need the CUDA compilation environment to compile the `Deformable Attention` operator used in Grounding DINO, we need to check whether the CUDA environment variables have been set correctly (which you can refer to [Grounding DINO Installation](https://github.com/IDEA-Research/GroundingDINO?tab=readme-ov-file#hammer_and_wrench-install) for more details). You can set the environment variable manually as follows if you want to build a local GPU environment for Grounding DINO to run Grounded SAM 2:
|
84 |
+
|
85 |
+
```bash
|
86 |
+
export CUDA_HOME=/path/to/cuda-12.1/
|
87 |
+
```
|
88 |
+
|
89 |
+
Install `Segment Anything 2`:
|
90 |
+
|
91 |
+
```bash
|
92 |
+
pip install -e .
|
93 |
+
```
|
94 |
+
|
95 |
+
Install `Grounding DINO`:
|
96 |
+
|
97 |
+
```bash
|
98 |
+
pip install --no-build-isolation -e grounding_dino
|
99 |
+
```
|
100 |
+
|
101 |
+
### Installation with docker
|
102 |
+
Build the Docker image and Run the Docker container:
|
103 |
+
|
104 |
+
```
|
105 |
+
cd Grounded-SAM-2
|
106 |
+
make build-image
|
107 |
+
make run
|
108 |
+
```
|
109 |
+
After executing these commands, you will be inside the Docker environment. The working directory within the container is set to: `/home/appuser/Grounded-SAM-2`
|
110 |
+
|
111 |
+
Once inside the Docker environment, you can start the demo by running:
|
112 |
+
```
|
113 |
+
python grounded_sam2_tracking_demo.py
|
114 |
+
```
|
115 |
+
|
116 |
+
## Grounded SAM 2 Demos
|
117 |
+
### Grounded SAM 2 Image Demo (with Grounding DINO)
|
118 |
+
Note that `Grounding DINO` has already been supported in [Huggingface](https://huggingface.co/IDEA-Research/grounding-dino-tiny), so we provide two choices for running `Grounded SAM 2` model:
|
119 |
+
- Use huggingface API to inference Grounding DINO (which is simple and clear)
|
120 |
+
|
121 |
+
```bash
|
122 |
+
python grounded_sam2_hf_model_demo.py
|
123 |
+
```
|
124 |
+
|
125 |
+
> [!NOTE]
|
126 |
+
> 🚨 If you encounter network issues while using the `HuggingFace` model, you can resolve them by setting the appropriate mirror source as `export HF_ENDPOINT=https://hf-mirror.com`
|
127 |
+
|
128 |
+
- Load local pretrained Grounding DINO checkpoint and inference with Grounding DINO original API (make sure you've already downloaded the pretrained checkpoint)
|
129 |
+
|
130 |
+
```bash
|
131 |
+
python grounded_sam2_local_demo.py
|
132 |
+
```
|
133 |
+
|
134 |
+
|
135 |
+
### Grounded SAM 2 Image Demo (with Grounding DINO 1.5 & 1.6)
|
136 |
+
|
137 |
+
We've already released our most capable open-set detection model [Grounding DINO 1.5 & 1.6](https://github.com/IDEA-Research/Grounding-DINO-1.5-API), which can be combined with SAM 2 for stronger open-set detection and segmentation capability. You can apply the API token first and run Grounded SAM 2 with Grounding DINO 1.5 as follows:
|
138 |
+
|
139 |
+
Install the latest DDS cloudapi:
|
140 |
+
|
141 |
+
```bash
|
142 |
+
pip install dds-cloudapi-sdk --upgrade
|
143 |
+
```
|
144 |
+
|
145 |
+
Apply your API token from our official website here: [request API token](https://deepdataspace.com/request_api).
|
146 |
+
|
147 |
+
```bash
|
148 |
+
python grounded_sam2_gd1.5_demo.py
|
149 |
+
```
|
150 |
+
|
151 |
+
### SAHI (Slicing Aided Hyper Inference) with Grounding DINO 1.5 and SAM 2
|
152 |
+
|
153 |
+
If your images are high resolution with dense objects, directly using Grounding DINO 1.5 for inference on the original image may not be the best choice. We support [SAHI (Slicing Aided Hyper Inference)](https://docs.ultralytics.com/guides/sahi-tiled-inference/), which works by first dividing the original image into smaller overlapping patches. Inference is then performed separately on each patch, and the final detection results are merged. This method is highly effective and accuracy for dense and small objects detection in high resolution images.
|
154 |
+
|
155 |
+
You can run SAHI inference by setting the following param in [grounded_sam2_gd1.5_demo.py](./grounded_sam2_gd1.5_demo.py):
|
156 |
+
|
157 |
+
```python
|
158 |
+
WITH_SLICE_INFERENCE = True
|
159 |
+
```
|
160 |
+
|
161 |
+
The visualization is shown as follows:
|
162 |
+
|
163 |
+
| Text Prompt | Input Image | Grounded SAM 2 | Grounded SAM 2 with SAHI |
|
164 |
+
|:----:|:----:|:----:|:----:|
|
165 |
+
| `Person` |  |  |  |
|
166 |
+
|
167 |
+
- **Notes:** We only support SAHI on Grounding DINO 1.5 because it works better with stronger grounding model which may produce less hallucination results.
|
168 |
+
|
169 |
+
### Grounded SAM 2 Image Demo (with DINO-X)
|
170 |
+
|
171 |
+
We've implemented Grounded SAM 2 with the strongest open-world perception model [DINO-X](https://github.com/IDEA-Research/DINO-X-API) for better open-set detection and segmentation performance. You can apply the API token first and run Grounded SAM 2 with DINO-X as follows:
|
172 |
+
|
173 |
+
Install the latest DDS cloudapi:
|
174 |
+
|
175 |
+
```bash
|
176 |
+
pip install dds-cloudapi-sdk --upgrade
|
177 |
+
```
|
178 |
+
|
179 |
+
Apply your API token from our official website here: [request API token](https://deepdataspace.com/request_api).
|
180 |
+
|
181 |
+
```bash
|
182 |
+
python grounded_sam2_dinox_demo.py
|
183 |
+
```
|
184 |
+
|
185 |
+
### Automatically Saving Grounding Results (Image Demo)
|
186 |
+
|
187 |
+
After setting `DUMP_JSON_RESULTS=True` in the following Grounded SAM 2 Image Demos:
|
188 |
+
- [grounded_sam2_local_demo.py](./grounded_sam2_local_demo.py)
|
189 |
+
- [grounded_sam2_hf_model_demo.py](./grounded_sam2_hf_model_demo.py)
|
190 |
+
- [grounded_sam2_gd1.5_demo.py](./grounded_sam2_gd1.5_demo.py)
|
191 |
+
- [grounded_sam2_dinox_demo.py](./grounded_sam2_dinox_demo.py)
|
192 |
+
|
193 |
+
The `grounding` and `segmentation` results will be automatically saved in the `outputs` dir with the following format:
|
194 |
+
|
195 |
+
```python
|
196 |
+
{
|
197 |
+
"image_path": "path/to/image.jpg",
|
198 |
+
"annotations": [
|
199 |
+
{
|
200 |
+
"class_name": "class_name",
|
201 |
+
"bbox": [x1, y1, x2, y2],
|
202 |
+
"segmentation": {
|
203 |
+
"size": [h, w],
|
204 |
+
"counts": "rle_encoded_mask"
|
205 |
+
},
|
206 |
+
"score": confidence score
|
207 |
+
}
|
208 |
+
],
|
209 |
+
"box_format": "xyxy",
|
210 |
+
"img_width": w,
|
211 |
+
"img_height": h
|
212 |
+
}
|
213 |
+
```
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
### Grounded SAM 2 Video Object Tracking Demo
|
218 |
+
|
219 |
+
Based on the strong tracking capability of SAM 2, we can combined it with Grounding DINO for open-set object segmentation and tracking. You can run the following scripts to get the tracking results with Grounded SAM 2:
|
220 |
+
|
221 |
+
```bash
|
222 |
+
python grounded_sam2_tracking_demo.py
|
223 |
+
```
|
224 |
+
|
225 |
+
- The tracking results of each frame will be saved in `./tracking_results`
|
226 |
+
- The video will be save as `children_tracking_demo_video.mp4`
|
227 |
+
- You can refine this file with different text prompt and video clips yourself to get more tracking results.
|
228 |
+
- We only prompt the first video frame with Grounding DINO here for simple usage.
|
229 |
+
|
230 |
+
#### Support Various Prompt Type for Tracking
|
231 |
+
|
232 |
+
We've supported different types of prompt for Grounded SAM 2 tracking demo:
|
233 |
+
|
234 |
+
- **Point Prompt**: In order to **get a stable segmentation results**, we re-use the SAM 2 image predictor to get the prediction mask from each object based on Grounding DINO box outputs, then we **uniformly sample points from the prediction mask** as point prompts for SAM 2 video predictor
|
235 |
+
- **Box Prompt**: We directly use the box outputs from Grounding DINO as box prompts for SAM 2 video predictor
|
236 |
+
- **Mask Prompt**: We use the SAM 2 mask prediction results based on Grounding DINO box outputs as mask prompt for SAM 2 video predictor.
|
237 |
+
|
238 |
+

|
239 |
+
|
240 |
+
|
241 |
+
### Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)
|
242 |
+
|
243 |
+
We've also support video object tracking demo based on our stronger `Grounding DINO 1.5` model and `SAM 2`, you can try the following demo after applying the API keys for running `Grounding DINO 1.5`:
|
244 |
+
|
245 |
+
```bash
|
246 |
+
python grounded_sam2_tracking_demo_with_gd1.5.py
|
247 |
+
```
|
248 |
+
|
249 |
+
### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with Grounding DINO)
|
250 |
+
|
251 |
+
Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with Grounding DINO and SAM 2 by using the following scripts:
|
252 |
+
|
253 |
+
```bash
|
254 |
+
python grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py
|
255 |
+
```
|
256 |
+
|
257 |
+
If you are not convenient to use huggingface demo, you can also run tracking demo with local grounding dino model with the following scripts:
|
258 |
+
|
259 |
+
```bash
|
260 |
+
python grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
|
261 |
+
```
|
262 |
+
|
263 |
+
### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with Grounding DINO 1.5 & 1.6)
|
264 |
+
|
265 |
+
Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with Grounding DINO 1.5 and SAM 2 by using the following scripts:
|
266 |
+
|
267 |
+
```bash
|
268 |
+
python grounded_sam2_tracking_demo_custom_video_input_gd1.5.py
|
269 |
+
```
|
270 |
+
|
271 |
+
You can specify the params in this file:
|
272 |
+
|
273 |
+
```python
|
274 |
+
VIDEO_PATH = "./assets/hippopotamus.mp4"
|
275 |
+
TEXT_PROMPT = "hippopotamus."
|
276 |
+
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
277 |
+
API_TOKEN_FOR_GD1_5 = "Your API token" # api token for G-DINO 1.5
|
278 |
+
PROMPT_TYPE_FOR_VIDEO = "mask" # using SAM 2 mask prediction as prompt for video predictor
|
279 |
+
```
|
280 |
+
|
281 |
+
After running our demo code, you can get the tracking results as follows:
|
282 |
+
|
283 |
+
[](https://github.com/user-attachments/assets/1fbdc6f4-3e50-4221-9600-98c397beecdf)
|
284 |
+
|
285 |
+
And we will automatically save the tracking visualization results in `OUTPUT_VIDEO_PATH`.
|
286 |
+
|
287 |
+
> [!WARNING]
|
288 |
+
> We initialize the box prompts on the first frame of the input video. If you want to start from different frame, you can refine `ann_frame_idx` by yourself in our code.
|
289 |
+
|
290 |
+
### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with DINO-X)
|
291 |
+
|
292 |
+
Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with DINO-X and SAM 2 by using the following scripts:
|
293 |
+
|
294 |
+
```bash
|
295 |
+
python grounded_sam2_tracking_demo_custom_video_input_dinox.py
|
296 |
+
```
|
297 |
+
|
298 |
+
### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
|
299 |
+
|
300 |
+
In above demos, we only prompt Grounded SAM 2 in specific frame, which may not be friendly to find new object during the whole video. In this demo, we try to **find new objects** and assign them with new ID across the whole video, this function is **still under develop**. it's not that stable now.
|
301 |
+
|
302 |
+
Users can upload their own video files and specify custom text prompts for grounding and tracking using the Grounding DINO and SAM 2 frameworks. To do this, execute the script:
|
303 |
+
|
304 |
+
|
305 |
+
```bash
|
306 |
+
python grounded_sam2_tracking_demo_with_continuous_id.py
|
307 |
+
```
|
308 |
+
|
309 |
+
You can customize various parameters including:
|
310 |
+
|
311 |
+
- `text`: The grounding text prompt.
|
312 |
+
- `video_dir`: Directory containing the video files.
|
313 |
+
- `output_dir`: Directory to save the processed output.
|
314 |
+
- `output_video_path`: Path for the output video.
|
315 |
+
- `step`: Frame stepping for processing.
|
316 |
+
- `box_threshold`: box threshold for groundingdino model
|
317 |
+
- `text_threshold`: text threshold for groundingdino model
|
318 |
+
Note: This method supports only the mask type of text prompt.
|
319 |
+
|
320 |
+
After running our demo code, you can get the tracking results as follows:
|
321 |
+
|
322 |
+
[](https://github.com/user-attachments/assets/d3f91ad0-3d32-43c4-a0dc-0bed661415f4)
|
323 |
+
|
324 |
+
If you want to try `Grounding DINO 1.5` model, you can run the following scripts after setting your API token:
|
325 |
+
|
326 |
+
```bash
|
327 |
+
python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
|
328 |
+
```
|
329 |
+
|
330 |
+
### Grounded-SAM-2 Video Object Tracking with Continuous ID plus Reverse Tracking(with Grounding DINO)
|
331 |
+
This method could simply cover the whole lifetime of the object
|
332 |
+
```bash
|
333 |
+
python grounded_sam2_tracking_demo_with_continuous_id_plus.py
|
334 |
+
|
335 |
+
```
|
336 |
+
|
337 |
+
## Grounded SAM 2 Florence-2 Demos
|
338 |
+
### Grounded SAM 2 Florence-2 Image Demo
|
339 |
+
|
340 |
+
In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications.
|
341 |
+
|
342 |
+
[Florence-2](https://arxiv.org/abs/2311.06242) is a powerful vision foundation model by Microsoft which supports a series of vision tasks by prompting with special `task_prompt` includes but not limited to:
|
343 |
+
|
344 |
+
| Task | Task Prompt | Text Input | Task Introduction |
|
345 |
+
|:---:|:---:|:---:|:---:|
|
346 |
+
| Object Detection | `<OD>` | ✘ | Detect main objects with single category name |
|
347 |
+
| Dense Region Caption | `<DENSE_REGION_CAPTION>` | ✘ | Detect main objects with short description |
|
348 |
+
| Region Proposal | `<REGION_PROPOSAL>` | ✘ | Generate proposals without category name |
|
349 |
+
| Phrase Grounding | `<CAPTION_TO_PHRASE_GROUNDING>` | ✔ | Ground main objects in image mentioned in caption |
|
350 |
+
| Referring Expression Segmentation | `<REFERRING_EXPRESSION_SEGMENTATION>` | ✔ | Ground the object which is most related to the text input |
|
351 |
+
| Open Vocabulary Detection and Segmentation | `<OPEN_VOCABULARY_DETECTION>` | ✔ | Ground any object with text input |
|
352 |
+
|
353 |
+
|
354 |
+
Integrate `Florence-2` with `SAM-2`, we can build a strong vision pipeline to solve complex vision tasks, you can try the following scripts to run the demo:
|
355 |
+
|
356 |
+
> [!NOTE]
|
357 |
+
> 🚨 If you encounter network issues while using the `HuggingFace` model, you can resolve them by setting the appropriate mirror source as `export HF_ENDPOINT=https://hf-mirror.com`
|
358 |
+
|
359 |
+
**Object Detection and Segmentation**
|
360 |
+
```bash
|
361 |
+
python grounded_sam2_florence2_image_demo.py \
|
362 |
+
--pipeline object_detection_segmentation \
|
363 |
+
--image_path ./notebooks/images/cars.jpg
|
364 |
+
```
|
365 |
+
|
366 |
+
**Dense Region Caption and Segmentation**
|
367 |
+
```bash
|
368 |
+
python grounded_sam2_florence2_image_demo.py \
|
369 |
+
--pipeline dense_region_caption_segmentation \
|
370 |
+
--image_path ./notebooks/images/cars.jpg
|
371 |
+
```
|
372 |
+
|
373 |
+
**Region Proposal and Segmentation**
|
374 |
+
```bash
|
375 |
+
python grounded_sam2_florence2_image_demo.py \
|
376 |
+
--pipeline region_proposal_segmentation \
|
377 |
+
--image_path ./notebooks/images/cars.jpg
|
378 |
+
```
|
379 |
+
|
380 |
+
**Phrase Grounding and Segmentation**
|
381 |
+
```bash
|
382 |
+
python grounded_sam2_florence2_image_demo.py \
|
383 |
+
--pipeline phrase_grounding_segmentation \
|
384 |
+
--image_path ./notebooks/images/cars.jpg \
|
385 |
+
--text_input "The image shows two vintage Chevrolet cars parked side by side, with one being a red convertible and the other a pink sedan, \
|
386 |
+
set against the backdrop of an urban area with a multi-story building and trees. \
|
387 |
+
The cars have Cuban license plates, indicating a location likely in Cuba."
|
388 |
+
```
|
389 |
+
|
390 |
+
**Referring Expression Segmentation**
|
391 |
+
```bash
|
392 |
+
python grounded_sam2_florence2_image_demo.py \
|
393 |
+
--pipeline referring_expression_segmentation \
|
394 |
+
--image_path ./notebooks/images/cars.jpg \
|
395 |
+
--text_input "The left red car."
|
396 |
+
```
|
397 |
+
|
398 |
+
**Open-Vocabulary Detection and Segmentation**
|
399 |
+
```bash
|
400 |
+
python grounded_sam2_florence2_image_demo.py \
|
401 |
+
--pipeline open_vocabulary_detection_segmentation \
|
402 |
+
--image_path ./notebooks/images/cars.jpg \
|
403 |
+
--text_input "car <and> building"
|
404 |
+
```
|
405 |
+
- Note that if you want to **detect multiple classes** you should split them with `<and>` in your input text.
|
406 |
+
|
407 |
+
|
408 |
+
### Grounded SAM 2 Florence-2 Image Auto-Labeling Demo
|
409 |
+
`Florence-2` can be used as a auto image annotator by cascading its caption capability with its grounding capability.
|
410 |
+
|
411 |
+
| Task | Task Prompt | Text Input |
|
412 |
+
|:---:|:---:|:---:|
|
413 |
+
| Caption + Phrase Grounding | `<CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ |
|
414 |
+
| Detailed Caption + Phrase Grounding | `<DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ |
|
415 |
+
| More Detailed Caption + Phrase Grounding | `<MORE_DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ |
|
416 |
+
|
417 |
+
You can try the following scripts to run these demo:
|
418 |
+
|
419 |
+
**Caption to Phrase Grounding**
|
420 |
+
```bash
|
421 |
+
python grounded_sam2_florence2_autolabel_pipeline.py \
|
422 |
+
--image_path ./notebooks/images/groceries.jpg \
|
423 |
+
--pipeline caption_to_phrase_grounding \
|
424 |
+
--caption_type caption
|
425 |
+
```
|
426 |
+
|
427 |
+
- You can specify `caption_type` to control the granularity of the caption, if you want a more detailed caption, you can try `--caption_type detailed_caption` or `--caption_type more_detailed_caption`.
|
428 |
+
|
429 |
+
### Citation
|
430 |
+
|
431 |
+
If you find this project helpful for your research, please consider citing the following BibTeX entry.
|
432 |
+
|
433 |
+
```BibTex
|
434 |
+
@misc{ravi2024sam2segmentimages,
|
435 |
+
title={SAM 2: Segment Anything in Images and Videos},
|
436 |
+
author={Nikhila Ravi and Valentin Gabeur and Yuan-Ting Hu and Ronghang Hu and Chaitanya Ryali and Tengyu Ma and Haitham Khedr and Roman Rädle and Chloe Rolland and Laura Gustafson and Eric Mintun and Junting Pan and Kalyan Vasudev Alwala and Nicolas Carion and Chao-Yuan Wu and Ross Girshick and Piotr Dollár and Christoph Feichtenhofer},
|
437 |
+
year={2024},
|
438 |
+
eprint={2408.00714},
|
439 |
+
archivePrefix={arXiv},
|
440 |
+
primaryClass={cs.CV},
|
441 |
+
url={https://arxiv.org/abs/2408.00714},
|
442 |
+
}
|
443 |
+
|
444 |
+
@article{liu2023grounding,
|
445 |
+
title={Grounding dino: Marrying dino with grounded pre-training for open-set object detection},
|
446 |
+
author={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},
|
447 |
+
journal={arXiv preprint arXiv:2303.05499},
|
448 |
+
year={2023}
|
449 |
+
}
|
450 |
+
|
451 |
+
@misc{ren2024grounding,
|
452 |
+
title={Grounding DINO 1.5: Advance the "Edge" of Open-Set Object Detection},
|
453 |
+
author={Tianhe Ren and Qing Jiang and Shilong Liu and Zhaoyang Zeng and Wenlong Liu and Han Gao and Hongjie Huang and Zhengyu Ma and Xiaoke Jiang and Yihao Chen and Yuda Xiong and Hao Zhang and Feng Li and Peijun Tang and Kent Yu and Lei Zhang},
|
454 |
+
year={2024},
|
455 |
+
eprint={2405.10300},
|
456 |
+
archivePrefix={arXiv},
|
457 |
+
primaryClass={cs.CV}
|
458 |
+
}
|
459 |
+
|
460 |
+
@misc{ren2024grounded,
|
461 |
+
title={Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks},
|
462 |
+
author={Tianhe Ren and Shilong Liu and Ailing Zeng and Jing Lin and Kunchang Li and He Cao and Jiayu Chen and Xinyu Huang and Yukang Chen and Feng Yan and Zhaoyang Zeng and Hao Zhang and Feng Li and Jie Yang and Hongyang Li and Qing Jiang and Lei Zhang},
|
463 |
+
year={2024},
|
464 |
+
eprint={2401.14159},
|
465 |
+
archivePrefix={arXiv},
|
466 |
+
primaryClass={cs.CV}
|
467 |
+
}
|
468 |
+
|
469 |
+
@article{kirillov2023segany,
|
470 |
+
title={Segment Anything},
|
471 |
+
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
|
472 |
+
journal={arXiv:2304.02643},
|
473 |
+
year={2023}
|
474 |
+
}
|
475 |
+
|
476 |
+
@misc{jiang2024trex2,
|
477 |
+
title={T-Rex2: Towards Generic Object Detection via Text-Visual Prompt Synergy},
|
478 |
+
author={Qing Jiang and Feng Li and Zhaoyang Zeng and Tianhe Ren and Shilong Liu and Lei Zhang},
|
479 |
+
year={2024},
|
480 |
+
eprint={2403.14610},
|
481 |
+
archivePrefix={arXiv},
|
482 |
+
primaryClass={cs.CV}
|
483 |
+
}
|
484 |
+
```
|
clone-IDEA-Research/Grounded-SAM-2/SAM2_README.md
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SAM 2: Segment Anything in Images and Videos
|
2 |
+
|
3 |
+
**[AI at Meta, FAIR](https://ai.meta.com/research/)**
|
4 |
+
|
5 |
+
[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
|
6 |
+
|
7 |
+
[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
|
8 |
+
|
9 |
+

|
10 |
+
|
11 |
+
**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
|
12 |
+
|
13 |
+

|
14 |
+
|
15 |
+
## Installation
|
16 |
+
|
17 |
+
Please install SAM 2 on a GPU machine using:
|
18 |
+
|
19 |
+
```bash
|
20 |
+
git clone https://github.com/facebookresearch/segment-anything-2.git
|
21 |
+
|
22 |
+
cd segment-anything-2; pip install -e .
|
23 |
+
```
|
24 |
+
|
25 |
+
To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
pip install -e ".[demo]"
|
29 |
+
```
|
30 |
+
|
31 |
+
## Getting Started
|
32 |
+
|
33 |
+
### Download Checkpoints
|
34 |
+
|
35 |
+
First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
|
36 |
+
|
37 |
+
```bash
|
38 |
+
cd checkpoints
|
39 |
+
./download_ckpts.sh
|
40 |
+
```
|
41 |
+
|
42 |
+
or individually from:
|
43 |
+
|
44 |
+
- [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)
|
45 |
+
- [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)
|
46 |
+
- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)
|
47 |
+
- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)
|
48 |
+
|
49 |
+
Then SAM 2 can be used in a few lines as follows for image and video prediction.
|
50 |
+
|
51 |
+
### Image prediction
|
52 |
+
|
53 |
+
SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
|
54 |
+
|
55 |
+
```python
|
56 |
+
import torch
|
57 |
+
from sam2.build_sam import build_sam2
|
58 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
59 |
+
|
60 |
+
checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
61 |
+
model_cfg = "sam2_hiera_l.yaml"
|
62 |
+
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
|
63 |
+
|
64 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
65 |
+
predictor.set_image(<your_image>)
|
66 |
+
masks, _, _ = predictor.predict(<input_prompts>)
|
67 |
+
```
|
68 |
+
|
69 |
+
Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) for static image use cases.
|
70 |
+
|
71 |
+
SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) for automatic mask generation in images.
|
72 |
+
|
73 |
+
### Video prediction
|
74 |
+
|
75 |
+
For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
|
76 |
+
|
77 |
+
```python
|
78 |
+
import torch
|
79 |
+
from sam2.build_sam import build_sam2_video_predictor
|
80 |
+
|
81 |
+
checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
82 |
+
model_cfg = "sam2_hiera_l.yaml"
|
83 |
+
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
|
84 |
+
|
85 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
86 |
+
state = predictor.init_state(<your_video>)
|
87 |
+
|
88 |
+
# add new prompts and instantly get the output on the same frame
|
89 |
+
frame_idx, object_ids, masks = predictor.add_new_points(state, <your prompts>):
|
90 |
+
|
91 |
+
# propagate the prompts to get masklets throughout the video
|
92 |
+
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
93 |
+
...
|
94 |
+
```
|
95 |
+
|
96 |
+
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
|
97 |
+
|
98 |
+
## Model Description
|
99 |
+
|
100 |
+
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
|
101 |
+
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
|
102 |
+
| sam2_hiera_tiny | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
|
103 |
+
| sam2_hiera_small | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
|
104 |
+
| sam2_hiera_base_plus | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
|
105 |
+
| sam2_hiera_large | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
|
106 |
+
|
107 |
+
\* Compile the model by setting `compile_image_encoder: True` in the config.
|
108 |
+
|
109 |
+
## Segment Anything Video Dataset
|
110 |
+
|
111 |
+
See [sav_dataset/README.md](sav_dataset/README.md) for details.
|
112 |
+
|
113 |
+
## License
|
114 |
+
|
115 |
+
The models are licensed under the [Apache 2.0 license](./LICENSE). Please refer to our research paper for more details on the models.
|
116 |
+
|
117 |
+
## Contributing
|
118 |
+
|
119 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
120 |
+
|
121 |
+
## Contributors
|
122 |
+
|
123 |
+
The SAM 2 project was made possible with the help of many contributors (alphabetical):
|
124 |
+
|
125 |
+
Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
|
126 |
+
|
127 |
+
Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
|
128 |
+
|
129 |
+
## Citing SAM 2
|
130 |
+
|
131 |
+
If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
|
132 |
+
|
133 |
+
```bibtex
|
134 |
+
@article{ravi2024sam2,
|
135 |
+
title={SAM 2: Segment Anything in Images and Videos},
|
136 |
+
author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
|
137 |
+
journal={arXiv preprint},
|
138 |
+
year={2024}
|
139 |
+
}
|
140 |
+
```
|
clone-IDEA-Research/Grounded-SAM-2/backend.Dockerfile
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
|
2 |
+
ARG MODEL_SIZE=base_plus
|
3 |
+
|
4 |
+
FROM ${BASE_IMAGE}
|
5 |
+
|
6 |
+
# Gunicorn environment variables
|
7 |
+
ENV GUNICORN_WORKERS=1
|
8 |
+
ENV GUNICORN_THREADS=2
|
9 |
+
ENV GUNICORN_PORT=5000
|
10 |
+
|
11 |
+
# SAM 2 environment variables
|
12 |
+
ENV APP_ROOT=/opt/sam2
|
13 |
+
ENV PYTHONUNBUFFERED=1
|
14 |
+
ENV SAM2_BUILD_CUDA=0
|
15 |
+
ENV MODEL_SIZE=${MODEL_SIZE}
|
16 |
+
|
17 |
+
# Install system requirements
|
18 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
19 |
+
ffmpeg \
|
20 |
+
libavutil-dev \
|
21 |
+
libavcodec-dev \
|
22 |
+
libavformat-dev \
|
23 |
+
libswscale-dev \
|
24 |
+
pkg-config \
|
25 |
+
build-essential \
|
26 |
+
libffi-dev
|
27 |
+
|
28 |
+
COPY setup.py .
|
29 |
+
COPY README.md .
|
30 |
+
|
31 |
+
RUN pip install --upgrade pip setuptools
|
32 |
+
RUN pip install -e ".[interactive-demo]"
|
33 |
+
|
34 |
+
# https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707
|
35 |
+
RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg
|
36 |
+
|
37 |
+
# Make app directory. This directory will host all files required for the
|
38 |
+
# backend and SAM 2 inference files.
|
39 |
+
RUN mkdir ${APP_ROOT}
|
40 |
+
|
41 |
+
# Copy backend server files
|
42 |
+
COPY demo/backend/server ${APP_ROOT}/server
|
43 |
+
|
44 |
+
# Copy SAM 2 inference files
|
45 |
+
COPY sam2 ${APP_ROOT}/server/sam2
|
46 |
+
|
47 |
+
# Download SAM 2.1 checkpoints
|
48 |
+
ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt
|
49 |
+
ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt
|
50 |
+
ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt
|
51 |
+
ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt
|
52 |
+
|
53 |
+
WORKDIR ${APP_ROOT}/server
|
54 |
+
|
55 |
+
# https://pythonspeed.com/articles/gunicorn-in-docker/
|
56 |
+
CMD gunicorn --worker-tmp-dir /dev/shm \
|
57 |
+
--worker-class gthread app:app \
|
58 |
+
--log-level info \
|
59 |
+
--access-logfile /dev/stdout \
|
60 |
+
--log-file /dev/stderr \
|
61 |
+
--workers ${GUNICORN_WORKERS} \
|
62 |
+
--threads ${GUNICORN_THREADS} \
|
63 |
+
--bind 0.0.0.0:${GUNICORN_PORT} \
|
64 |
+
--timeout 60
|
clone-IDEA-Research/Grounded-SAM-2/docker-compose.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
frontend:
|
3 |
+
image: sam2/frontend
|
4 |
+
build:
|
5 |
+
context: ./demo/frontend
|
6 |
+
dockerfile: frontend.Dockerfile
|
7 |
+
ports:
|
8 |
+
- 7262:80
|
9 |
+
|
10 |
+
backend:
|
11 |
+
image: sam2/backend
|
12 |
+
build:
|
13 |
+
context: .
|
14 |
+
dockerfile: backend.Dockerfile
|
15 |
+
ports:
|
16 |
+
- 7263:5000
|
17 |
+
volumes:
|
18 |
+
- ./demo/data/:/data/:rw
|
19 |
+
environment:
|
20 |
+
- SERVER_ENVIRONMENT=DEV
|
21 |
+
- GUNICORN_WORKERS=1
|
22 |
+
# Inference API needs to have at least 2 threads to handle an incoming
|
23 |
+
# parallel cancel propagation request
|
24 |
+
- GUNICORN_THREADS=2
|
25 |
+
- GUNICORN_PORT=5000
|
26 |
+
- API_URL=http://localhost:7263
|
27 |
+
- DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4
|
28 |
+
# # ffmpeg/video encode settings
|
29 |
+
- FFMPEG_NUM_THREADS=1
|
30 |
+
- VIDEO_ENCODE_CODEC=libx264
|
31 |
+
- VIDEO_ENCODE_CRF=23
|
32 |
+
- VIDEO_ENCODE_FPS=24
|
33 |
+
- VIDEO_ENCODE_MAX_WIDTH=1280
|
34 |
+
- VIDEO_ENCODE_MAX_HEIGHT=720
|
35 |
+
- VIDEO_ENCODE_VERBOSE=False
|
36 |
+
deploy:
|
37 |
+
resources:
|
38 |
+
reservations:
|
39 |
+
devices:
|
40 |
+
- driver: nvidia
|
41 |
+
count: 1
|
42 |
+
capabilities: [gpu]
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_dinox_demo.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dds cloudapi for Grounding DINO 1.5
|
2 |
+
from dds_cloudapi_sdk import Config
|
3 |
+
from dds_cloudapi_sdk import Client
|
4 |
+
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
|
5 |
+
from dds_cloudapi_sdk.tasks.types import DetectionTarget
|
6 |
+
from dds_cloudapi_sdk import TextPrompt
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import json
|
11 |
+
import torch
|
12 |
+
import tempfile
|
13 |
+
import numpy as np
|
14 |
+
import supervision as sv
|
15 |
+
import pycocotools.mask as mask_util
|
16 |
+
from pathlib import Path
|
17 |
+
from PIL import Image
|
18 |
+
from sam2.build_sam import build_sam2
|
19 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
20 |
+
|
21 |
+
"""
|
22 |
+
Hyper parameters
|
23 |
+
"""
|
24 |
+
API_TOKEN = "Your API token"
|
25 |
+
TEXT_PROMPT = "car . building ."
|
26 |
+
IMG_PATH = "notebooks/images/cars.jpg"
|
27 |
+
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
28 |
+
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
29 |
+
BOX_THRESHOLD = 0.2
|
30 |
+
WITH_SLICE_INFERENCE = False
|
31 |
+
SLICE_WH = (480, 480)
|
32 |
+
OVERLAP_RATIO = (0.2, 0.2)
|
33 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
34 |
+
OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
|
35 |
+
DUMP_JSON_RESULTS = True
|
36 |
+
|
37 |
+
# create output directory
|
38 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
39 |
+
|
40 |
+
"""
|
41 |
+
Prompt DINO-X with Text for Box Prompt Generation with Cloud API
|
42 |
+
"""
|
43 |
+
# Step 1: initialize the config
|
44 |
+
token = API_TOKEN
|
45 |
+
config = Config(token)
|
46 |
+
|
47 |
+
# Step 2: initialize the client
|
48 |
+
client = Client(config)
|
49 |
+
|
50 |
+
# Step 3: run the task by DetectionTask class
|
51 |
+
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
52 |
+
# if you are processing local image file, upload them to DDS server to get the image url
|
53 |
+
|
54 |
+
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
55 |
+
class_name_to_id = {name: id for id, name in enumerate(classes)}
|
56 |
+
class_id_to_name = {id: name for name, id in class_name_to_id.items()}
|
57 |
+
|
58 |
+
if WITH_SLICE_INFERENCE:
|
59 |
+
def callback(image_slice: np.ndarray) -> sv.Detections:
|
60 |
+
print("Inference on image slice")
|
61 |
+
# save the img as temp img file for GD-1.5 API usage
|
62 |
+
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
63 |
+
temp_filename = tmpfile.name
|
64 |
+
cv2.imwrite(temp_filename, image_slice)
|
65 |
+
image_url = client.upload_file(temp_filename)
|
66 |
+
task = DinoxTask(
|
67 |
+
image_url=image_url,
|
68 |
+
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
69 |
+
bbox_threshold=0.25,
|
70 |
+
targets=[DetectionTarget.BBox],
|
71 |
+
)
|
72 |
+
client.run_task(task)
|
73 |
+
result = task.result
|
74 |
+
# detele the tempfile
|
75 |
+
os.remove(temp_filename)
|
76 |
+
|
77 |
+
input_boxes = []
|
78 |
+
confidences = []
|
79 |
+
class_ids = []
|
80 |
+
objects = result.objects
|
81 |
+
for idx, obj in enumerate(objects):
|
82 |
+
input_boxes.append(obj.bbox)
|
83 |
+
confidences.append(obj.score)
|
84 |
+
cls_name = obj.category.lower().strip()
|
85 |
+
class_ids.append(class_name_to_id[cls_name])
|
86 |
+
# ensure input_boxes with shape (_, 4)
|
87 |
+
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
88 |
+
class_ids = np.array(class_ids)
|
89 |
+
confidences = np.array(confidences)
|
90 |
+
return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
|
91 |
+
|
92 |
+
slicer = sv.InferenceSlicer(
|
93 |
+
callback=callback,
|
94 |
+
slice_wh=SLICE_WH,
|
95 |
+
overlap_ratio_wh=OVERLAP_RATIO,
|
96 |
+
iou_threshold=0.5,
|
97 |
+
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
98 |
+
)
|
99 |
+
detections = slicer(cv2.imread(IMG_PATH))
|
100 |
+
class_names = [class_id_to_name[id] for id in detections.class_id]
|
101 |
+
confidences = detections.confidence
|
102 |
+
class_ids = detections.class_id
|
103 |
+
input_boxes = detections.xyxy
|
104 |
+
else:
|
105 |
+
image_url = client.upload_file(IMG_PATH)
|
106 |
+
|
107 |
+
task = DinoxTask(
|
108 |
+
image_url=image_url,
|
109 |
+
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
110 |
+
bbox_threshold=0.25,
|
111 |
+
targets=[DetectionTarget.BBox],
|
112 |
+
)
|
113 |
+
|
114 |
+
client.run_task(task)
|
115 |
+
result = task.result
|
116 |
+
|
117 |
+
objects = result.objects # the list of detected objects
|
118 |
+
|
119 |
+
|
120 |
+
input_boxes = []
|
121 |
+
confidences = []
|
122 |
+
class_names = []
|
123 |
+
class_ids = []
|
124 |
+
|
125 |
+
for idx, obj in enumerate(objects):
|
126 |
+
input_boxes.append(obj.bbox)
|
127 |
+
confidences.append(obj.score)
|
128 |
+
cls_name = obj.category.lower().strip()
|
129 |
+
class_names.append(cls_name)
|
130 |
+
class_ids.append(class_name_to_id[cls_name])
|
131 |
+
|
132 |
+
input_boxes = np.array(input_boxes)
|
133 |
+
class_ids = np.array(class_ids)
|
134 |
+
|
135 |
+
"""
|
136 |
+
Init SAM 2 Model and Predict Mask with Box Prompt
|
137 |
+
"""
|
138 |
+
|
139 |
+
# environment settings
|
140 |
+
# use bfloat16
|
141 |
+
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
142 |
+
|
143 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
144 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
145 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
146 |
+
torch.backends.cudnn.allow_tf32 = True
|
147 |
+
|
148 |
+
# build SAM2 image predictor
|
149 |
+
sam2_checkpoint = SAM2_CHECKPOINT
|
150 |
+
model_cfg = SAM2_MODEL_CONFIG
|
151 |
+
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
152 |
+
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
153 |
+
|
154 |
+
image = Image.open(IMG_PATH)
|
155 |
+
|
156 |
+
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
157 |
+
|
158 |
+
masks, scores, logits = sam2_predictor.predict(
|
159 |
+
point_coords=None,
|
160 |
+
point_labels=None,
|
161 |
+
box=input_boxes,
|
162 |
+
multimask_output=False,
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
"""
|
167 |
+
Post-process the output of the model to get the masks, scores, and logits for visualization
|
168 |
+
"""
|
169 |
+
# convert the shape to (n, H, W)
|
170 |
+
if masks.ndim == 4:
|
171 |
+
masks = masks.squeeze(1)
|
172 |
+
|
173 |
+
|
174 |
+
"""
|
175 |
+
Visualization the Predict Results
|
176 |
+
"""
|
177 |
+
|
178 |
+
labels = [
|
179 |
+
f"{class_name} {confidence:.2f}"
|
180 |
+
for class_name, confidence
|
181 |
+
in zip(class_names, confidences)
|
182 |
+
]
|
183 |
+
|
184 |
+
"""
|
185 |
+
Visualize image with supervision useful API
|
186 |
+
"""
|
187 |
+
img = cv2.imread(IMG_PATH)
|
188 |
+
detections = sv.Detections(
|
189 |
+
xyxy=input_boxes, # (n, 4)
|
190 |
+
mask=masks.astype(bool), # (n, h, w)
|
191 |
+
class_id=class_ids
|
192 |
+
)
|
193 |
+
|
194 |
+
box_annotator = sv.BoxAnnotator()
|
195 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
196 |
+
|
197 |
+
label_annotator = sv.LabelAnnotator()
|
198 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
199 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "dinox_annotated_image.jpg"), annotated_frame)
|
200 |
+
|
201 |
+
mask_annotator = sv.MaskAnnotator()
|
202 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
203 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "dinox_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
204 |
+
|
205 |
+
print(f'Annotated image has already been saved as to "{OUTPUT_DIR}"')
|
206 |
+
|
207 |
+
"""
|
208 |
+
Dump the results in standard format and save as json files
|
209 |
+
"""
|
210 |
+
|
211 |
+
def single_mask_to_rle(mask):
|
212 |
+
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
213 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
214 |
+
return rle
|
215 |
+
|
216 |
+
if DUMP_JSON_RESULTS:
|
217 |
+
print("Start dumping the annotation...")
|
218 |
+
# convert mask into rle format
|
219 |
+
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
220 |
+
|
221 |
+
input_boxes = input_boxes.tolist()
|
222 |
+
scores = scores.tolist()
|
223 |
+
# FIXME: class_names should be a list of strings without spaces
|
224 |
+
class_names = [class_name.strip() for class_name in class_names]
|
225 |
+
# save the results in standard format
|
226 |
+
results = {
|
227 |
+
"image_path": IMG_PATH,
|
228 |
+
"annotations" : [
|
229 |
+
{
|
230 |
+
"class_name": class_name,
|
231 |
+
"bbox": box,
|
232 |
+
"segmentation": mask_rle,
|
233 |
+
"score": score,
|
234 |
+
}
|
235 |
+
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
236 |
+
],
|
237 |
+
"box_format": "xyxy",
|
238 |
+
"img_width": image.width,
|
239 |
+
"img_height": image.height,
|
240 |
+
}
|
241 |
+
|
242 |
+
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_dinox_image_demo_results.json"), "w") as f:
|
243 |
+
json.dump(results, f, indent=4)
|
244 |
+
|
245 |
+
print(f'Annotation has already been saved to "{OUTPUT_DIR}"')
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_autolabel_pipeline.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import supervision as sv
|
7 |
+
from PIL import Image
|
8 |
+
from sam2.build_sam import build_sam2
|
9 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
10 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
11 |
+
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
12 |
+
|
13 |
+
"""
|
14 |
+
Define Some Hyperparam
|
15 |
+
"""
|
16 |
+
|
17 |
+
TASK_PROMPT = {
|
18 |
+
"caption": "<CAPTION>",
|
19 |
+
"detailed_caption": "<DETAILED_CAPTION>",
|
20 |
+
"more_detailed_caption": "<MORE_DETAILED_CAPTION>",
|
21 |
+
"object_detection": "<OD>",
|
22 |
+
"dense_region_caption": "<DENSE_REGION_CAPTION>",
|
23 |
+
"region_proposal": "<REGION_PROPOSAL>",
|
24 |
+
"phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
|
25 |
+
"referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
|
26 |
+
"region_to_segmentation": "<REGION_TO_SEGMENTATION>",
|
27 |
+
"open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
|
28 |
+
"region_to_category": "<REGION_TO_CATEGORY>",
|
29 |
+
"region_to_description": "<REGION_TO_DESCRIPTION>",
|
30 |
+
"ocr": "<OCR>",
|
31 |
+
"ocr_with_region": "<OCR_WITH_REGION>",
|
32 |
+
}
|
33 |
+
|
34 |
+
OUTPUT_DIR = "./outputs"
|
35 |
+
|
36 |
+
if not os.path.exists(OUTPUT_DIR):
|
37 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
38 |
+
|
39 |
+
"""
|
40 |
+
Init Florence-2 and SAM 2 Model
|
41 |
+
"""
|
42 |
+
|
43 |
+
FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
|
44 |
+
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
45 |
+
SAM2_CONFIG = "sam2_hiera_l.yaml"
|
46 |
+
|
47 |
+
# environment settings
|
48 |
+
# use bfloat16
|
49 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
50 |
+
|
51 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
52 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
53 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
54 |
+
torch.backends.cudnn.allow_tf32 = True
|
55 |
+
|
56 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
57 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
58 |
+
|
59 |
+
# build florence-2
|
60 |
+
florence2_model = AutoModelForCausalLM.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto').eval().to(device)
|
61 |
+
florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)
|
62 |
+
|
63 |
+
# build sam 2
|
64 |
+
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
|
65 |
+
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
66 |
+
|
67 |
+
def run_florence2(task_prompt, text_input, model, processor, image):
|
68 |
+
assert model is not None, "You should pass the init florence-2 model here"
|
69 |
+
assert processor is not None, "You should set florence-2 processor here"
|
70 |
+
|
71 |
+
device = model.device
|
72 |
+
|
73 |
+
if text_input is None:
|
74 |
+
prompt = task_prompt
|
75 |
+
else:
|
76 |
+
prompt = task_prompt + text_input
|
77 |
+
|
78 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch.float16)
|
79 |
+
generated_ids = model.generate(
|
80 |
+
input_ids=inputs["input_ids"].to(device),
|
81 |
+
pixel_values=inputs["pixel_values"].to(device),
|
82 |
+
max_new_tokens=1024,
|
83 |
+
early_stopping=False,
|
84 |
+
do_sample=False,
|
85 |
+
num_beams=3,
|
86 |
+
)
|
87 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
88 |
+
parsed_answer = processor.post_process_generation(
|
89 |
+
generated_text,
|
90 |
+
task=task_prompt,
|
91 |
+
image_size=(image.width, image.height)
|
92 |
+
)
|
93 |
+
return parsed_answer
|
94 |
+
|
95 |
+
|
96 |
+
"""
|
97 |
+
We try to support a series of cascaded auto-labelling pipelines with Florence-2 and SAM 2
|
98 |
+
"""
|
99 |
+
|
100 |
+
"""
|
101 |
+
Auto-Labelling Pipeline 1: Caption/Detailed Caption/More Detailed Caption + Phrase Grounding + Segmentation
|
102 |
+
"""
|
103 |
+
def caption_phrase_grounding_and_segmentation(
|
104 |
+
florence2_model,
|
105 |
+
florence2_processor,
|
106 |
+
sam2_predictor,
|
107 |
+
image_path,
|
108 |
+
caption_task_prompt='<CAPTION>',
|
109 |
+
output_dir=OUTPUT_DIR
|
110 |
+
):
|
111 |
+
assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
|
112 |
+
image = Image.open(image_path).convert("RGB")
|
113 |
+
|
114 |
+
# image caption
|
115 |
+
caption_results = run_florence2(caption_task_prompt, None, florence2_model, florence2_processor, image)
|
116 |
+
text_input = caption_results[caption_task_prompt]
|
117 |
+
print(f'Image caption for "{image_path}": ', text_input)
|
118 |
+
|
119 |
+
# phrase grounding
|
120 |
+
grounding_results = run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', text_input, florence2_model, florence2_processor, image)
|
121 |
+
grounding_results = grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']
|
122 |
+
|
123 |
+
# parse florence-2 detection results
|
124 |
+
input_boxes = np.array(grounding_results["bboxes"])
|
125 |
+
class_names = grounding_results["labels"]
|
126 |
+
class_ids = np.array(list(range(len(class_names))))
|
127 |
+
|
128 |
+
# predict mask with SAM 2
|
129 |
+
sam2_predictor.set_image(np.array(image))
|
130 |
+
masks, scores, logits = sam2_predictor.predict(
|
131 |
+
point_coords=None,
|
132 |
+
point_labels=None,
|
133 |
+
box=input_boxes,
|
134 |
+
multimask_output=False,
|
135 |
+
)
|
136 |
+
|
137 |
+
if masks.ndim == 4:
|
138 |
+
masks = masks.squeeze(1)
|
139 |
+
|
140 |
+
# specify labels
|
141 |
+
labels = [
|
142 |
+
f"{class_name}" for class_name in class_names
|
143 |
+
]
|
144 |
+
|
145 |
+
# visualization results
|
146 |
+
img = cv2.imread(image_path)
|
147 |
+
detections = sv.Detections(
|
148 |
+
xyxy=input_boxes,
|
149 |
+
mask=masks.astype(bool),
|
150 |
+
class_id=class_ids
|
151 |
+
)
|
152 |
+
|
153 |
+
box_annotator = sv.BoxAnnotator()
|
154 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
155 |
+
|
156 |
+
label_annotator = sv.LabelAnnotator()
|
157 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
158 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling.jpg"), annotated_frame)
|
159 |
+
|
160 |
+
mask_annotator = sv.MaskAnnotator()
|
161 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
162 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling_with_mask.jpg"), annotated_frame)
|
163 |
+
|
164 |
+
print(f'Successfully save annotated image to "{output_dir}"')
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
|
169 |
+
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
170 |
+
parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
|
171 |
+
parser.add_argument("--pipeline", type=str, default="caption_to_phrase_grounding", required=True, help="pipeline to use")
|
172 |
+
parser.add_argument("--caption_type", type=str, default="caption", required=False, help="granularity of caption")
|
173 |
+
args = parser.parse_args()
|
174 |
+
|
175 |
+
CAPTION_TO_TASK_PROMPT = {
|
176 |
+
"caption": "<CAPTION>",
|
177 |
+
"detailed_caption": "<DETAILED_CAPTION>",
|
178 |
+
"more_detailed_caption": "<MORE_DETAILED_CAPTION>"
|
179 |
+
}
|
180 |
+
|
181 |
+
IMAGE_PATH = args.image_path
|
182 |
+
PIPELINE = args.pipeline
|
183 |
+
CAPTION_TYPE = args.caption_type
|
184 |
+
assert CAPTION_TYPE in ["caption", "detailed_caption", "more_detailed_caption"]
|
185 |
+
|
186 |
+
print(f"Running pipeline: {PIPELINE} now.")
|
187 |
+
|
188 |
+
if PIPELINE == "caption_to_phrase_grounding":
|
189 |
+
# pipeline-1: caption + phrase grounding + segmentation
|
190 |
+
caption_phrase_grounding_and_segmentation(
|
191 |
+
florence2_model=florence2_model,
|
192 |
+
florence2_processor=florence2_processor,
|
193 |
+
sam2_predictor=sam2_predictor,
|
194 |
+
caption_task_prompt=CAPTION_TO_TASK_PROMPT[CAPTION_TYPE],
|
195 |
+
image_path=IMAGE_PATH
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_florence2_image_demo.py
ADDED
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import supervision as sv
|
7 |
+
from PIL import Image
|
8 |
+
from sam2.build_sam import build_sam2
|
9 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
10 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
11 |
+
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
12 |
+
|
13 |
+
"""
|
14 |
+
Define Some Hyperparam
|
15 |
+
"""
|
16 |
+
|
17 |
+
TASK_PROMPT = {
|
18 |
+
"caption": "<CAPTION>",
|
19 |
+
"detailed_caption": "<DETAILED_CAPTION>",
|
20 |
+
"more_detailed_caption": "<MORE_DETAILED_CAPTION",
|
21 |
+
"object_detection": "<OD>",
|
22 |
+
"dense_region_caption": "<DENSE_REGION_CAPTION>",
|
23 |
+
"region_proposal": "<REGION_PROPOSAL>",
|
24 |
+
"phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
|
25 |
+
"referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
|
26 |
+
"region_to_segmentation": "<REGION_TO_SEGMENTATION>",
|
27 |
+
"open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
|
28 |
+
"region_to_category": "<REGION_TO_CATEGORY>",
|
29 |
+
"region_to_description": "<REGION_TO_DESCRIPTION>",
|
30 |
+
"ocr": "<OCR>",
|
31 |
+
"ocr_with_region": "<OCR_WITH_REGION>",
|
32 |
+
}
|
33 |
+
|
34 |
+
OUTPUT_DIR = "./outputs"
|
35 |
+
|
36 |
+
if not os.path.exists(OUTPUT_DIR):
|
37 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
38 |
+
|
39 |
+
"""
|
40 |
+
Init Florence-2 and SAM 2 Model
|
41 |
+
"""
|
42 |
+
|
43 |
+
FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
|
44 |
+
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
45 |
+
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
46 |
+
|
47 |
+
# environment settings
|
48 |
+
# use bfloat16
|
49 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
50 |
+
|
51 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
52 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
53 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
54 |
+
torch.backends.cudnn.allow_tf32 = True
|
55 |
+
|
56 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
57 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
58 |
+
|
59 |
+
# build florence-2
|
60 |
+
florence2_model = AutoModelForCausalLM.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto').eval().to(device)
|
61 |
+
florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)
|
62 |
+
|
63 |
+
# build sam 2
|
64 |
+
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
|
65 |
+
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
66 |
+
|
67 |
+
def run_florence2(task_prompt, text_input, model, processor, image):
|
68 |
+
assert model is not None, "You should pass the init florence-2 model here"
|
69 |
+
assert processor is not None, "You should set florence-2 processor here"
|
70 |
+
|
71 |
+
device = model.device
|
72 |
+
|
73 |
+
if text_input is None:
|
74 |
+
prompt = task_prompt
|
75 |
+
else:
|
76 |
+
prompt = task_prompt + text_input
|
77 |
+
|
78 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch.float16)
|
79 |
+
generated_ids = model.generate(
|
80 |
+
input_ids=inputs["input_ids"].to(device),
|
81 |
+
pixel_values=inputs["pixel_values"].to(device),
|
82 |
+
max_new_tokens=1024,
|
83 |
+
early_stopping=False,
|
84 |
+
do_sample=False,
|
85 |
+
num_beams=3,
|
86 |
+
)
|
87 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
88 |
+
parsed_answer = processor.post_process_generation(
|
89 |
+
generated_text,
|
90 |
+
task=task_prompt,
|
91 |
+
image_size=(image.width, image.height)
|
92 |
+
)
|
93 |
+
return parsed_answer
|
94 |
+
|
95 |
+
|
96 |
+
"""
|
97 |
+
We support a set of pipelines built by Florence-2 + SAM 2
|
98 |
+
"""
|
99 |
+
|
100 |
+
"""
|
101 |
+
Pipeline-1: Object Detection + Segmentation
|
102 |
+
"""
|
103 |
+
def object_detection_and_segmentation(
|
104 |
+
florence2_model,
|
105 |
+
florence2_processor,
|
106 |
+
sam2_predictor,
|
107 |
+
image_path,
|
108 |
+
task_prompt="<OD>",
|
109 |
+
text_input=None,
|
110 |
+
output_dir=OUTPUT_DIR
|
111 |
+
):
|
112 |
+
assert text_input is None, "Text input should be None when calling object detection pipeline."
|
113 |
+
# run florence-2 object detection in demo
|
114 |
+
image = Image.open(image_path).convert("RGB")
|
115 |
+
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
116 |
+
|
117 |
+
""" Florence-2 Object Detection Output Format
|
118 |
+
{'<OD>':
|
119 |
+
{
|
120 |
+
'bboxes':
|
121 |
+
[
|
122 |
+
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
|
123 |
+
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
|
124 |
+
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
|
125 |
+
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
|
126 |
+
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
|
127 |
+
],
|
128 |
+
'labels': ['car', 'door', 'door', 'wheel', 'wheel']
|
129 |
+
}
|
130 |
+
}
|
131 |
+
"""
|
132 |
+
results = results[task_prompt]
|
133 |
+
# parse florence-2 detection results
|
134 |
+
input_boxes = np.array(results["bboxes"])
|
135 |
+
class_names = results["labels"]
|
136 |
+
class_ids = np.array(list(range(len(class_names))))
|
137 |
+
|
138 |
+
# predict mask with SAM 2
|
139 |
+
sam2_predictor.set_image(np.array(image))
|
140 |
+
masks, scores, logits = sam2_predictor.predict(
|
141 |
+
point_coords=None,
|
142 |
+
point_labels=None,
|
143 |
+
box=input_boxes,
|
144 |
+
multimask_output=False,
|
145 |
+
)
|
146 |
+
|
147 |
+
if masks.ndim == 4:
|
148 |
+
masks = masks.squeeze(1)
|
149 |
+
|
150 |
+
# specify labels
|
151 |
+
labels = [
|
152 |
+
f"{class_name}" for class_name in class_names
|
153 |
+
]
|
154 |
+
|
155 |
+
# visualization results
|
156 |
+
img = cv2.imread(image_path)
|
157 |
+
detections = sv.Detections(
|
158 |
+
xyxy=input_boxes,
|
159 |
+
mask=masks.astype(bool),
|
160 |
+
class_id=class_ids
|
161 |
+
)
|
162 |
+
|
163 |
+
box_annotator = sv.BoxAnnotator()
|
164 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
165 |
+
|
166 |
+
label_annotator = sv.LabelAnnotator()
|
167 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
168 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame)
|
169 |
+
|
170 |
+
mask_annotator = sv.MaskAnnotator()
|
171 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
172 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame)
|
173 |
+
|
174 |
+
print(f'Successfully save annotated image to "{output_dir}"')
|
175 |
+
|
176 |
+
"""
|
177 |
+
Pipeline 2: Dense Region Caption + Segmentation
|
178 |
+
"""
|
179 |
+
def dense_region_caption_and_segmentation(
|
180 |
+
florence2_model,
|
181 |
+
florence2_processor,
|
182 |
+
sam2_predictor,
|
183 |
+
image_path,
|
184 |
+
task_prompt="<DENSE_REGION_CAPTION>",
|
185 |
+
text_input=None,
|
186 |
+
output_dir=OUTPUT_DIR
|
187 |
+
):
|
188 |
+
assert text_input is None, "Text input should be None when calling dense region caption pipeline."
|
189 |
+
# run florence-2 object detection in demo
|
190 |
+
image = Image.open(image_path).convert("RGB")
|
191 |
+
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
192 |
+
|
193 |
+
""" Florence-2 Object Detection Output Format
|
194 |
+
{'<DENSE_REGION_CAPTION>':
|
195 |
+
{
|
196 |
+
'bboxes':
|
197 |
+
[
|
198 |
+
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
|
199 |
+
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
|
200 |
+
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
|
201 |
+
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
|
202 |
+
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
|
203 |
+
],
|
204 |
+
'labels': ['turquoise Volkswagen Beetle', 'wooden double doors with metal handles', 'wheel', 'wheel', 'door']
|
205 |
+
}
|
206 |
+
}
|
207 |
+
"""
|
208 |
+
results = results[task_prompt]
|
209 |
+
# parse florence-2 detection results
|
210 |
+
input_boxes = np.array(results["bboxes"])
|
211 |
+
class_names = results["labels"]
|
212 |
+
class_ids = np.array(list(range(len(class_names))))
|
213 |
+
|
214 |
+
# predict mask with SAM 2
|
215 |
+
sam2_predictor.set_image(np.array(image))
|
216 |
+
masks, scores, logits = sam2_predictor.predict(
|
217 |
+
point_coords=None,
|
218 |
+
point_labels=None,
|
219 |
+
box=input_boxes,
|
220 |
+
multimask_output=False,
|
221 |
+
)
|
222 |
+
|
223 |
+
if masks.ndim == 4:
|
224 |
+
masks = masks.squeeze(1)
|
225 |
+
|
226 |
+
# specify labels
|
227 |
+
labels = [
|
228 |
+
f"{class_name}" for class_name in class_names
|
229 |
+
]
|
230 |
+
|
231 |
+
# visualization results
|
232 |
+
img = cv2.imread(image_path)
|
233 |
+
detections = sv.Detections(
|
234 |
+
xyxy=input_boxes,
|
235 |
+
mask=masks.astype(bool),
|
236 |
+
class_id=class_ids
|
237 |
+
)
|
238 |
+
|
239 |
+
box_annotator = sv.BoxAnnotator()
|
240 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
241 |
+
|
242 |
+
label_annotator = sv.LabelAnnotator()
|
243 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
244 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_annotated_image.jpg"), annotated_frame)
|
245 |
+
|
246 |
+
mask_annotator = sv.MaskAnnotator()
|
247 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
248 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_image_with_mask.jpg"), annotated_frame)
|
249 |
+
|
250 |
+
print(f'Successfully save annotated image to "{output_dir}"')
|
251 |
+
|
252 |
+
|
253 |
+
"""
|
254 |
+
Pipeline 3: Region Proposal + Segmentation
|
255 |
+
"""
|
256 |
+
def region_proposal_and_segmentation(
|
257 |
+
florence2_model,
|
258 |
+
florence2_processor,
|
259 |
+
sam2_predictor,
|
260 |
+
image_path,
|
261 |
+
task_prompt="<REGION_PROPOSAL>",
|
262 |
+
text_input=None,
|
263 |
+
output_dir=OUTPUT_DIR
|
264 |
+
):
|
265 |
+
assert text_input is None, "Text input should be None when calling region proposal pipeline."
|
266 |
+
# run florence-2 object detection in demo
|
267 |
+
image = Image.open(image_path).convert("RGB")
|
268 |
+
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
269 |
+
|
270 |
+
""" Florence-2 Object Detection Output Format
|
271 |
+
{'<REGION_PROPOSAL>':
|
272 |
+
{
|
273 |
+
'bboxes':
|
274 |
+
[
|
275 |
+
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
|
276 |
+
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
|
277 |
+
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
|
278 |
+
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
|
279 |
+
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
|
280 |
+
],
|
281 |
+
'labels': ['', '', '', '', '', '', '']
|
282 |
+
}
|
283 |
+
}
|
284 |
+
"""
|
285 |
+
results = results[task_prompt]
|
286 |
+
# parse florence-2 detection results
|
287 |
+
input_boxes = np.array(results["bboxes"])
|
288 |
+
class_names = results["labels"]
|
289 |
+
class_ids = np.array(list(range(len(class_names))))
|
290 |
+
|
291 |
+
# predict mask with SAM 2
|
292 |
+
sam2_predictor.set_image(np.array(image))
|
293 |
+
masks, scores, logits = sam2_predictor.predict(
|
294 |
+
point_coords=None,
|
295 |
+
point_labels=None,
|
296 |
+
box=input_boxes,
|
297 |
+
multimask_output=False,
|
298 |
+
)
|
299 |
+
|
300 |
+
if masks.ndim == 4:
|
301 |
+
masks = masks.squeeze(1)
|
302 |
+
|
303 |
+
# specify labels
|
304 |
+
labels = [
|
305 |
+
f"region_{idx}" for idx, class_name in enumerate(class_names)
|
306 |
+
]
|
307 |
+
|
308 |
+
# visualization results
|
309 |
+
img = cv2.imread(image_path)
|
310 |
+
detections = sv.Detections(
|
311 |
+
xyxy=input_boxes,
|
312 |
+
mask=masks.astype(bool),
|
313 |
+
class_id=class_ids
|
314 |
+
)
|
315 |
+
|
316 |
+
box_annotator = sv.BoxAnnotator()
|
317 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
318 |
+
|
319 |
+
label_annotator = sv.LabelAnnotator()
|
320 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
321 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal.jpg"), annotated_frame)
|
322 |
+
|
323 |
+
mask_annotator = sv.MaskAnnotator()
|
324 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
325 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal_with_mask.jpg"), annotated_frame)
|
326 |
+
|
327 |
+
print(f'Successfully save annotated image to "{output_dir}"')
|
328 |
+
|
329 |
+
|
330 |
+
"""
|
331 |
+
Pipeline 4: Phrase Grounding + Segmentation
|
332 |
+
"""
|
333 |
+
def phrase_grounding_and_segmentation(
|
334 |
+
florence2_model,
|
335 |
+
florence2_processor,
|
336 |
+
sam2_predictor,
|
337 |
+
image_path,
|
338 |
+
task_prompt="<CAPTION_TO_PHRASE_GROUNDING>",
|
339 |
+
text_input=None,
|
340 |
+
output_dir=OUTPUT_DIR
|
341 |
+
):
|
342 |
+
# run florence-2 object detection in demo
|
343 |
+
image = Image.open(image_path).convert("RGB")
|
344 |
+
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
345 |
+
|
346 |
+
""" Florence-2 Object Detection Output Format
|
347 |
+
{'<CAPTION_TO_PHRASE_GROUNDING>':
|
348 |
+
{
|
349 |
+
'bboxes':
|
350 |
+
[
|
351 |
+
[34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594],
|
352 |
+
[1.5999999046325684, 4.079999923706055, 639.0399780273438, 305.03997802734375]
|
353 |
+
],
|
354 |
+
'labels': ['A green car', 'a yellow building']
|
355 |
+
}
|
356 |
+
}
|
357 |
+
"""
|
358 |
+
assert text_input is not None, "Text input should not be None when calling phrase grounding pipeline."
|
359 |
+
results = results[task_prompt]
|
360 |
+
# parse florence-2 detection results
|
361 |
+
input_boxes = np.array(results["bboxes"])
|
362 |
+
class_names = results["labels"]
|
363 |
+
class_ids = np.array(list(range(len(class_names))))
|
364 |
+
|
365 |
+
# predict mask with SAM 2
|
366 |
+
sam2_predictor.set_image(np.array(image))
|
367 |
+
masks, scores, logits = sam2_predictor.predict(
|
368 |
+
point_coords=None,
|
369 |
+
point_labels=None,
|
370 |
+
box=input_boxes,
|
371 |
+
multimask_output=False,
|
372 |
+
)
|
373 |
+
|
374 |
+
if masks.ndim == 4:
|
375 |
+
masks = masks.squeeze(1)
|
376 |
+
|
377 |
+
# specify labels
|
378 |
+
labels = [
|
379 |
+
f"{class_name}" for class_name in class_names
|
380 |
+
]
|
381 |
+
|
382 |
+
# visualization results
|
383 |
+
img = cv2.imread(image_path)
|
384 |
+
detections = sv.Detections(
|
385 |
+
xyxy=input_boxes,
|
386 |
+
mask=masks.astype(bool),
|
387 |
+
class_id=class_ids
|
388 |
+
)
|
389 |
+
|
390 |
+
box_annotator = sv.BoxAnnotator()
|
391 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
392 |
+
|
393 |
+
label_annotator = sv.LabelAnnotator()
|
394 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
395 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding.jpg"), annotated_frame)
|
396 |
+
|
397 |
+
mask_annotator = sv.MaskAnnotator()
|
398 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
399 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding_with_mask.jpg"), annotated_frame)
|
400 |
+
|
401 |
+
print(f'Successfully save annotated image to "{output_dir}"')
|
402 |
+
|
403 |
+
|
404 |
+
"""
|
405 |
+
Pipeline 5: Referring Expression Segmentation
|
406 |
+
|
407 |
+
Note that Florence-2 directly support referring segmentation with polygon output format, which may be not that accurate,
|
408 |
+
therefore we try to decode box from polygon and use SAM 2 for mask prediction
|
409 |
+
"""
|
410 |
+
def referring_expression_segmentation(
|
411 |
+
florence2_model,
|
412 |
+
florence2_processor,
|
413 |
+
sam2_predictor,
|
414 |
+
image_path,
|
415 |
+
task_prompt="<REFERRING_EXPRESSION_SEGMENTATION>",
|
416 |
+
text_input=None,
|
417 |
+
output_dir=OUTPUT_DIR
|
418 |
+
):
|
419 |
+
# run florence-2 object detection in demo
|
420 |
+
image = Image.open(image_path).convert("RGB")
|
421 |
+
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
422 |
+
|
423 |
+
""" Florence-2 Object Detection Output Format
|
424 |
+
{'<REFERRING_EXPRESSION_SEGMENTATION>':
|
425 |
+
{
|
426 |
+
'polygons': [[[...]]]
|
427 |
+
'labels': ['']
|
428 |
+
}
|
429 |
+
}
|
430 |
+
"""
|
431 |
+
assert text_input is not None, "Text input should not be None when calling referring segmentation pipeline."
|
432 |
+
results = results[task_prompt]
|
433 |
+
# parse florence-2 detection results
|
434 |
+
polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2)
|
435 |
+
class_names = [text_input]
|
436 |
+
class_ids = np.array(list(range(len(class_names))))
|
437 |
+
|
438 |
+
# parse polygon format to mask
|
439 |
+
img_width, img_height = image.size[0], image.size[1]
|
440 |
+
florence2_mask = np.zeros((img_height, img_width), dtype=np.uint8)
|
441 |
+
if len(polygon_points) < 3:
|
442 |
+
print("Invalid polygon:", polygon_points)
|
443 |
+
exit()
|
444 |
+
cv2.fillPoly(florence2_mask, [polygon_points], 1)
|
445 |
+
if florence2_mask.ndim == 2:
|
446 |
+
florence2_mask = florence2_mask[None]
|
447 |
+
|
448 |
+
# compute bounding box based on polygon points
|
449 |
+
x_min = np.min(polygon_points[:, 0])
|
450 |
+
y_min = np.min(polygon_points[:, 1])
|
451 |
+
x_max = np.max(polygon_points[:, 0])
|
452 |
+
y_max = np.max(polygon_points[:, 1])
|
453 |
+
|
454 |
+
input_boxes = np.array([[x_min, y_min, x_max, y_max]])
|
455 |
+
|
456 |
+
# predict mask with SAM 2
|
457 |
+
sam2_predictor.set_image(np.array(image))
|
458 |
+
sam2_masks, scores, logits = sam2_predictor.predict(
|
459 |
+
point_coords=None,
|
460 |
+
point_labels=None,
|
461 |
+
box=input_boxes,
|
462 |
+
multimask_output=False,
|
463 |
+
)
|
464 |
+
|
465 |
+
if sam2_masks.ndim == 4:
|
466 |
+
sam2_masks = sam2_masks.squeeze(1)
|
467 |
+
|
468 |
+
# specify labels
|
469 |
+
labels = [
|
470 |
+
f"{class_name}" for class_name in class_names
|
471 |
+
]
|
472 |
+
|
473 |
+
# visualization florence2 mask
|
474 |
+
img = cv2.imread(image_path)
|
475 |
+
detections = sv.Detections(
|
476 |
+
xyxy=input_boxes,
|
477 |
+
mask=florence2_mask.astype(bool),
|
478 |
+
class_id=class_ids
|
479 |
+
)
|
480 |
+
|
481 |
+
box_annotator = sv.BoxAnnotator()
|
482 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
483 |
+
|
484 |
+
label_annotator = sv.LabelAnnotator()
|
485 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
486 |
+
cv2.imwrite(os.path.join(output_dir, "florence2_referring_segmentation_box.jpg"), annotated_frame)
|
487 |
+
|
488 |
+
mask_annotator = sv.MaskAnnotator()
|
489 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
490 |
+
cv2.imwrite(os.path.join(output_dir, "florence2_referring_segmentation_box_with_mask.jpg"), annotated_frame)
|
491 |
+
|
492 |
+
print(f'Successfully save florence-2 annotated image to "{output_dir}"')
|
493 |
+
|
494 |
+
# visualize sam2 mask
|
495 |
+
img = cv2.imread(image_path)
|
496 |
+
detections = sv.Detections(
|
497 |
+
xyxy=input_boxes,
|
498 |
+
mask=sam2_masks.astype(bool),
|
499 |
+
class_id=class_ids
|
500 |
+
)
|
501 |
+
|
502 |
+
box_annotator = sv.BoxAnnotator()
|
503 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
504 |
+
|
505 |
+
label_annotator = sv.LabelAnnotator()
|
506 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
507 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_referring_box.jpg"), annotated_frame)
|
508 |
+
|
509 |
+
mask_annotator = sv.MaskAnnotator()
|
510 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
511 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_referring_box_with_sam2_mask.jpg"), annotated_frame)
|
512 |
+
|
513 |
+
print(f'Successfully save sam2 annotated image to "{output_dir}"')
|
514 |
+
|
515 |
+
|
516 |
+
"""
|
517 |
+
Pipeline 6: Open-Vocabulary Detection + Segmentation
|
518 |
+
"""
|
519 |
+
def open_vocabulary_detection_and_segmentation(
|
520 |
+
florence2_model,
|
521 |
+
florence2_processor,
|
522 |
+
sam2_predictor,
|
523 |
+
image_path,
|
524 |
+
task_prompt="<OPEN_VOCABULARY_DETECTION>",
|
525 |
+
text_input=None,
|
526 |
+
output_dir=OUTPUT_DIR
|
527 |
+
):
|
528 |
+
# run florence-2 object detection in demo
|
529 |
+
image = Image.open(image_path).convert("RGB")
|
530 |
+
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
531 |
+
|
532 |
+
""" Florence-2 Open-Vocabulary Detection Output Format
|
533 |
+
{'<OPEN_VOCABULARY_DETECTION>':
|
534 |
+
{
|
535 |
+
'bboxes':
|
536 |
+
[
|
537 |
+
[34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594]
|
538 |
+
],
|
539 |
+
'bboxes_labels': ['A green car'],
|
540 |
+
'polygons': [],
|
541 |
+
'polygons_labels': []
|
542 |
+
}
|
543 |
+
}
|
544 |
+
"""
|
545 |
+
assert text_input is not None, "Text input should not be None when calling open-vocabulary detection pipeline."
|
546 |
+
results = results[task_prompt]
|
547 |
+
# parse florence-2 detection results
|
548 |
+
input_boxes = np.array(results["bboxes"])
|
549 |
+
print(results)
|
550 |
+
class_names = results["bboxes_labels"]
|
551 |
+
class_ids = np.array(list(range(len(class_names))))
|
552 |
+
|
553 |
+
# predict mask with SAM 2
|
554 |
+
sam2_predictor.set_image(np.array(image))
|
555 |
+
masks, scores, logits = sam2_predictor.predict(
|
556 |
+
point_coords=None,
|
557 |
+
point_labels=None,
|
558 |
+
box=input_boxes,
|
559 |
+
multimask_output=False,
|
560 |
+
)
|
561 |
+
|
562 |
+
if masks.ndim == 4:
|
563 |
+
masks = masks.squeeze(1)
|
564 |
+
|
565 |
+
# specify labels
|
566 |
+
labels = [
|
567 |
+
f"{class_name}" for class_name in class_names
|
568 |
+
]
|
569 |
+
|
570 |
+
# visualization results
|
571 |
+
img = cv2.imread(image_path)
|
572 |
+
detections = sv.Detections(
|
573 |
+
xyxy=input_boxes,
|
574 |
+
mask=masks.astype(bool),
|
575 |
+
class_id=class_ids
|
576 |
+
)
|
577 |
+
|
578 |
+
box_annotator = sv.BoxAnnotator()
|
579 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
580 |
+
|
581 |
+
label_annotator = sv.LabelAnnotator()
|
582 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
583 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection.jpg"), annotated_frame)
|
584 |
+
|
585 |
+
mask_annotator = sv.MaskAnnotator()
|
586 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
587 |
+
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg"), annotated_frame)
|
588 |
+
|
589 |
+
print(f'Successfully save annotated image to "{output_dir}"')
|
590 |
+
|
591 |
+
if __name__ == "__main__":
|
592 |
+
|
593 |
+
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
594 |
+
parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
|
595 |
+
parser.add_argument("--pipeline", type=str, default="object_detection_segmentation", required=True, help="path to image file")
|
596 |
+
parser.add_argument("--text_input", type=str, default=None, required=False, help="path to image file")
|
597 |
+
args = parser.parse_args()
|
598 |
+
|
599 |
+
IMAGE_PATH = args.image_path
|
600 |
+
PIPELINE = args.pipeline
|
601 |
+
INPUT_TEXT = args.text_input
|
602 |
+
|
603 |
+
print(f"Running pipeline: {PIPELINE} now.")
|
604 |
+
|
605 |
+
if PIPELINE == "object_detection_segmentation":
|
606 |
+
# pipeline-1: detection + segmentation
|
607 |
+
object_detection_and_segmentation(
|
608 |
+
florence2_model=florence2_model,
|
609 |
+
florence2_processor=florence2_processor,
|
610 |
+
sam2_predictor=sam2_predictor,
|
611 |
+
image_path=IMAGE_PATH
|
612 |
+
)
|
613 |
+
elif PIPELINE == "dense_region_caption_segmentation":
|
614 |
+
# pipeline-2: dense region caption + segmentation
|
615 |
+
dense_region_caption_and_segmentation(
|
616 |
+
florence2_model=florence2_model,
|
617 |
+
florence2_processor=florence2_processor,
|
618 |
+
sam2_predictor=sam2_predictor,
|
619 |
+
image_path=IMAGE_PATH
|
620 |
+
)
|
621 |
+
elif PIPELINE == "region_proposal_segmentation":
|
622 |
+
# pipeline-3: dense region caption + segmentation
|
623 |
+
region_proposal_and_segmentation(
|
624 |
+
florence2_model=florence2_model,
|
625 |
+
florence2_processor=florence2_processor,
|
626 |
+
sam2_predictor=sam2_predictor,
|
627 |
+
image_path=IMAGE_PATH
|
628 |
+
)
|
629 |
+
elif PIPELINE == "phrase_grounding_segmentation":
|
630 |
+
# pipeline-4: phrase grounding + segmentation
|
631 |
+
phrase_grounding_and_segmentation(
|
632 |
+
florence2_model=florence2_model,
|
633 |
+
florence2_processor=florence2_processor,
|
634 |
+
sam2_predictor=sam2_predictor,
|
635 |
+
image_path=IMAGE_PATH,
|
636 |
+
text_input=INPUT_TEXT
|
637 |
+
)
|
638 |
+
elif PIPELINE == "referring_expression_segmentation":
|
639 |
+
# pipeline-5: referring segmentation + segmentation
|
640 |
+
referring_expression_segmentation(
|
641 |
+
florence2_model=florence2_model,
|
642 |
+
florence2_processor=florence2_processor,
|
643 |
+
sam2_predictor=sam2_predictor,
|
644 |
+
image_path=IMAGE_PATH,
|
645 |
+
text_input=INPUT_TEXT
|
646 |
+
)
|
647 |
+
elif PIPELINE == "open_vocabulary_detection_segmentation":
|
648 |
+
# pipeline-6: open-vocabulary detection + segmentation
|
649 |
+
open_vocabulary_detection_and_segmentation(
|
650 |
+
florence2_model=florence2_model,
|
651 |
+
florence2_processor=florence2_processor,
|
652 |
+
sam2_predictor=sam2_predictor,
|
653 |
+
image_path=IMAGE_PATH,
|
654 |
+
text_input=INPUT_TEXT
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_gd1.5_demo.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dds cloudapi for Grounding DINO 1.5
|
2 |
+
from dds_cloudapi_sdk import Config
|
3 |
+
from dds_cloudapi_sdk import Client
|
4 |
+
from dds_cloudapi_sdk import DetectionTask
|
5 |
+
from dds_cloudapi_sdk import TextPrompt
|
6 |
+
from dds_cloudapi_sdk import DetectionModel
|
7 |
+
from dds_cloudapi_sdk import DetectionTarget
|
8 |
+
|
9 |
+
import os
|
10 |
+
import cv2
|
11 |
+
import json
|
12 |
+
import torch
|
13 |
+
import tempfile
|
14 |
+
import numpy as np
|
15 |
+
import supervision as sv
|
16 |
+
import pycocotools.mask as mask_util
|
17 |
+
from pathlib import Path
|
18 |
+
from PIL import Image
|
19 |
+
from sam2.build_sam import build_sam2
|
20 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
21 |
+
|
22 |
+
"""
|
23 |
+
Hyper parameters
|
24 |
+
"""
|
25 |
+
API_TOKEN = "Your API token"
|
26 |
+
TEXT_PROMPT = "car . building ."
|
27 |
+
IMG_PATH = "notebooks/images/cars.jpg"
|
28 |
+
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
29 |
+
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
30 |
+
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
31 |
+
BOX_THRESHOLD = 0.2
|
32 |
+
WITH_SLICE_INFERENCE = False
|
33 |
+
SLICE_WH = (480, 480)
|
34 |
+
OVERLAP_RATIO = (0.2, 0.2)
|
35 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
36 |
+
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
|
37 |
+
DUMP_JSON_RESULTS = True
|
38 |
+
|
39 |
+
# create output directory
|
40 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
41 |
+
|
42 |
+
"""
|
43 |
+
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
|
44 |
+
"""
|
45 |
+
# Step 1: initialize the config
|
46 |
+
token = API_TOKEN
|
47 |
+
config = Config(token)
|
48 |
+
|
49 |
+
# Step 2: initialize the client
|
50 |
+
client = Client(config)
|
51 |
+
|
52 |
+
# Step 3: run the task by DetectionTask class
|
53 |
+
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
54 |
+
# if you are processing local image file, upload them to DDS server to get the image url
|
55 |
+
|
56 |
+
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
57 |
+
class_name_to_id = {name: id for id, name in enumerate(classes)}
|
58 |
+
class_id_to_name = {id: name for name, id in class_name_to_id.items()}
|
59 |
+
|
60 |
+
if WITH_SLICE_INFERENCE:
|
61 |
+
def callback(image_slice: np.ndarray) -> sv.Detections:
|
62 |
+
print("Inference on image slice")
|
63 |
+
# save the img as temp img file for GD-1.5 API usage
|
64 |
+
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
65 |
+
temp_filename = tmpfile.name
|
66 |
+
cv2.imwrite(temp_filename, image_slice)
|
67 |
+
image_url = client.upload_file(temp_filename)
|
68 |
+
task = DetectionTask(
|
69 |
+
image_url=image_url,
|
70 |
+
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
71 |
+
targets=[DetectionTarget.BBox], # detect bbox
|
72 |
+
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
73 |
+
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
74 |
+
)
|
75 |
+
client.run_task(task)
|
76 |
+
result = task.result
|
77 |
+
# detele the tempfile
|
78 |
+
os.remove(temp_filename)
|
79 |
+
|
80 |
+
input_boxes = []
|
81 |
+
confidences = []
|
82 |
+
class_ids = []
|
83 |
+
objects = result.objects
|
84 |
+
for idx, obj in enumerate(objects):
|
85 |
+
input_boxes.append(obj.bbox)
|
86 |
+
confidences.append(obj.score)
|
87 |
+
cls_name = obj.category.lower().strip()
|
88 |
+
class_ids.append(class_name_to_id[cls_name])
|
89 |
+
# ensure input_boxes with shape (_, 4)
|
90 |
+
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
91 |
+
class_ids = np.array(class_ids)
|
92 |
+
confidences = np.array(confidences)
|
93 |
+
return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
|
94 |
+
|
95 |
+
slicer = sv.InferenceSlicer(
|
96 |
+
callback=callback,
|
97 |
+
slice_wh=SLICE_WH,
|
98 |
+
overlap_ratio_wh=OVERLAP_RATIO,
|
99 |
+
iou_threshold=0.5,
|
100 |
+
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
101 |
+
)
|
102 |
+
detections = slicer(cv2.imread(IMG_PATH))
|
103 |
+
class_names = [class_id_to_name[id] for id in detections.class_id]
|
104 |
+
confidences = detections.confidence
|
105 |
+
class_ids = detections.class_id
|
106 |
+
input_boxes = detections.xyxy
|
107 |
+
else:
|
108 |
+
image_url = client.upload_file(IMG_PATH)
|
109 |
+
|
110 |
+
task = DetectionTask(
|
111 |
+
image_url=image_url,
|
112 |
+
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
113 |
+
targets=[DetectionTarget.BBox], # detect bbox
|
114 |
+
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
|
115 |
+
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
116 |
+
)
|
117 |
+
|
118 |
+
client.run_task(task)
|
119 |
+
result = task.result
|
120 |
+
|
121 |
+
objects = result.objects # the list of detected objects
|
122 |
+
|
123 |
+
|
124 |
+
input_boxes = []
|
125 |
+
confidences = []
|
126 |
+
class_names = []
|
127 |
+
class_ids = []
|
128 |
+
|
129 |
+
for idx, obj in enumerate(objects):
|
130 |
+
input_boxes.append(obj.bbox)
|
131 |
+
confidences.append(obj.score)
|
132 |
+
cls_name = obj.category.lower().strip()
|
133 |
+
class_names.append(cls_name)
|
134 |
+
class_ids.append(class_name_to_id[cls_name])
|
135 |
+
|
136 |
+
input_boxes = np.array(input_boxes)
|
137 |
+
class_ids = np.array(class_ids)
|
138 |
+
|
139 |
+
"""
|
140 |
+
Init SAM 2 Model and Predict Mask with Box Prompt
|
141 |
+
"""
|
142 |
+
|
143 |
+
# environment settings
|
144 |
+
# use bfloat16
|
145 |
+
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
146 |
+
|
147 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
148 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
149 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
150 |
+
torch.backends.cudnn.allow_tf32 = True
|
151 |
+
|
152 |
+
# build SAM2 image predictor
|
153 |
+
sam2_checkpoint = SAM2_CHECKPOINT
|
154 |
+
model_cfg = SAM2_MODEL_CONFIG
|
155 |
+
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
156 |
+
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
157 |
+
|
158 |
+
image = Image.open(IMG_PATH)
|
159 |
+
|
160 |
+
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
161 |
+
|
162 |
+
masks, scores, logits = sam2_predictor.predict(
|
163 |
+
point_coords=None,
|
164 |
+
point_labels=None,
|
165 |
+
box=input_boxes,
|
166 |
+
multimask_output=False,
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
"""
|
171 |
+
Post-process the output of the model to get the masks, scores, and logits for visualization
|
172 |
+
"""
|
173 |
+
# convert the shape to (n, H, W)
|
174 |
+
if masks.ndim == 4:
|
175 |
+
masks = masks.squeeze(1)
|
176 |
+
|
177 |
+
|
178 |
+
"""
|
179 |
+
Visualization the Predict Results
|
180 |
+
"""
|
181 |
+
|
182 |
+
labels = [
|
183 |
+
f"{class_name} {confidence:.2f}"
|
184 |
+
for class_name, confidence
|
185 |
+
in zip(class_names, confidences)
|
186 |
+
]
|
187 |
+
|
188 |
+
"""
|
189 |
+
Visualize image with supervision useful API
|
190 |
+
"""
|
191 |
+
img = cv2.imread(IMG_PATH)
|
192 |
+
detections = sv.Detections(
|
193 |
+
xyxy=input_boxes, # (n, 4)
|
194 |
+
mask=masks.astype(bool), # (n, h, w)
|
195 |
+
class_id=class_ids
|
196 |
+
)
|
197 |
+
|
198 |
+
box_annotator = sv.BoxAnnotator()
|
199 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
200 |
+
|
201 |
+
label_annotator = sv.LabelAnnotator()
|
202 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
203 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
204 |
+
|
205 |
+
mask_annotator = sv.MaskAnnotator()
|
206 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
207 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
208 |
+
|
209 |
+
print(f'Annotated image has already been saved as to "{OUTPUT_DIR}"')
|
210 |
+
|
211 |
+
"""
|
212 |
+
Dump the results in standard format and save as json files
|
213 |
+
"""
|
214 |
+
|
215 |
+
def single_mask_to_rle(mask):
|
216 |
+
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
217 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
218 |
+
return rle
|
219 |
+
|
220 |
+
if DUMP_JSON_RESULTS:
|
221 |
+
print("Start dumping the annotation...")
|
222 |
+
# convert mask into rle format
|
223 |
+
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
224 |
+
|
225 |
+
input_boxes = input_boxes.tolist()
|
226 |
+
scores = scores.tolist()
|
227 |
+
# FIXME: class_names should be a list of strings without spaces
|
228 |
+
class_names = [class_name.strip() for class_name in class_names]
|
229 |
+
# save the results in standard format
|
230 |
+
results = {
|
231 |
+
"image_path": IMG_PATH,
|
232 |
+
"annotations" : [
|
233 |
+
{
|
234 |
+
"class_name": class_name,
|
235 |
+
"bbox": box,
|
236 |
+
"segmentation": mask_rle,
|
237 |
+
"score": score,
|
238 |
+
}
|
239 |
+
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
240 |
+
],
|
241 |
+
"box_format": "xyxy",
|
242 |
+
"img_width": image.width,
|
243 |
+
"img_height": image.height,
|
244 |
+
}
|
245 |
+
|
246 |
+
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_gd1.5_image_demo_results.json"), "w") as f:
|
247 |
+
json.dump(results, f, indent=4)
|
248 |
+
|
249 |
+
print(f'Annotation has already been saved to "{OUTPUT_DIR}"')
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_hf_model_demo.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import supervision as sv
|
8 |
+
import pycocotools.mask as mask_util
|
9 |
+
from pathlib import Path
|
10 |
+
from supervision.draw.color import ColorPalette
|
11 |
+
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
12 |
+
from PIL import Image
|
13 |
+
from sam2.build_sam import build_sam2
|
14 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
15 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
16 |
+
|
17 |
+
"""
|
18 |
+
Hyper parameters
|
19 |
+
"""
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument('--grounding-model', default="IDEA-Research/grounding-dino-tiny")
|
22 |
+
parser.add_argument("--text-prompt", default="car. tire.")
|
23 |
+
parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
|
24 |
+
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
|
25 |
+
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
|
26 |
+
parser.add_argument("--output-dir", default="outputs/test_sam2.1")
|
27 |
+
parser.add_argument("--no-dump-json", action="store_true")
|
28 |
+
parser.add_argument("--force-cpu", action="store_true")
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
GROUNDING_MODEL = args.grounding_model
|
32 |
+
TEXT_PROMPT = args.text_prompt
|
33 |
+
IMG_PATH = args.img_path
|
34 |
+
SAM2_CHECKPOINT = args.sam2_checkpoint
|
35 |
+
SAM2_MODEL_CONFIG = args.sam2_model_config
|
36 |
+
DEVICE = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
|
37 |
+
OUTPUT_DIR = Path(args.output_dir)
|
38 |
+
DUMP_JSON_RESULTS = not args.no_dump_json
|
39 |
+
|
40 |
+
# create output directory
|
41 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
42 |
+
|
43 |
+
# environment settings
|
44 |
+
# use bfloat16
|
45 |
+
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
46 |
+
|
47 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
48 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
49 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
50 |
+
torch.backends.cudnn.allow_tf32 = True
|
51 |
+
|
52 |
+
# build SAM2 image predictor
|
53 |
+
sam2_checkpoint = SAM2_CHECKPOINT
|
54 |
+
model_cfg = SAM2_MODEL_CONFIG
|
55 |
+
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
56 |
+
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
57 |
+
|
58 |
+
# build grounding dino from huggingface
|
59 |
+
model_id = GROUNDING_MODEL
|
60 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
61 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
|
62 |
+
|
63 |
+
|
64 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
65 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
66 |
+
text = TEXT_PROMPT
|
67 |
+
img_path = IMG_PATH
|
68 |
+
|
69 |
+
image = Image.open(img_path)
|
70 |
+
|
71 |
+
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
72 |
+
|
73 |
+
inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
|
74 |
+
with torch.no_grad():
|
75 |
+
outputs = grounding_model(**inputs)
|
76 |
+
|
77 |
+
results = processor.post_process_grounded_object_detection(
|
78 |
+
outputs,
|
79 |
+
inputs.input_ids,
|
80 |
+
box_threshold=0.4,
|
81 |
+
text_threshold=0.3,
|
82 |
+
target_sizes=[image.size[::-1]]
|
83 |
+
)
|
84 |
+
|
85 |
+
"""
|
86 |
+
Results is a list of dict with the following structure:
|
87 |
+
[
|
88 |
+
{
|
89 |
+
'scores': tensor([0.7969, 0.6469, 0.6002, 0.4220], device='cuda:0'),
|
90 |
+
'labels': ['car', 'tire', 'tire', 'tire'],
|
91 |
+
'boxes': tensor([[ 89.3244, 278.6940, 1710.3505, 851.5143],
|
92 |
+
[1392.4701, 554.4064, 1628.6133, 777.5872],
|
93 |
+
[ 436.1182, 621.8940, 676.5255, 851.6897],
|
94 |
+
[1236.0990, 688.3547, 1400.2427, 753.1256]], device='cuda:0')
|
95 |
+
}
|
96 |
+
]
|
97 |
+
"""
|
98 |
+
|
99 |
+
# get the box prompt for SAM 2
|
100 |
+
input_boxes = results[0]["boxes"].cpu().numpy()
|
101 |
+
|
102 |
+
masks, scores, logits = sam2_predictor.predict(
|
103 |
+
point_coords=None,
|
104 |
+
point_labels=None,
|
105 |
+
box=input_boxes,
|
106 |
+
multimask_output=False,
|
107 |
+
)
|
108 |
+
|
109 |
+
|
110 |
+
"""
|
111 |
+
Post-process the output of the model to get the masks, scores, and logits for visualization
|
112 |
+
"""
|
113 |
+
# convert the shape to (n, H, W)
|
114 |
+
if masks.ndim == 4:
|
115 |
+
masks = masks.squeeze(1)
|
116 |
+
|
117 |
+
|
118 |
+
confidences = results[0]["scores"].cpu().numpy().tolist()
|
119 |
+
class_names = results[0]["labels"]
|
120 |
+
class_ids = np.array(list(range(len(class_names))))
|
121 |
+
|
122 |
+
labels = [
|
123 |
+
f"{class_name} {confidence:.2f}"
|
124 |
+
for class_name, confidence
|
125 |
+
in zip(class_names, confidences)
|
126 |
+
]
|
127 |
+
|
128 |
+
"""
|
129 |
+
Visualize image with supervision useful API
|
130 |
+
"""
|
131 |
+
img = cv2.imread(img_path)
|
132 |
+
detections = sv.Detections(
|
133 |
+
xyxy=input_boxes, # (n, 4)
|
134 |
+
mask=masks.astype(bool), # (n, h, w)
|
135 |
+
class_id=class_ids
|
136 |
+
)
|
137 |
+
|
138 |
+
"""
|
139 |
+
Note that if you want to use default color map,
|
140 |
+
you can set color=ColorPalette.DEFAULT
|
141 |
+
"""
|
142 |
+
box_annotator = sv.BoxAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
143 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
144 |
+
|
145 |
+
label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
146 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
147 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
148 |
+
|
149 |
+
mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
150 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
151 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
152 |
+
|
153 |
+
|
154 |
+
"""
|
155 |
+
Dump the results in standard format and save as json files
|
156 |
+
"""
|
157 |
+
|
158 |
+
def single_mask_to_rle(mask):
|
159 |
+
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
160 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
161 |
+
return rle
|
162 |
+
|
163 |
+
if DUMP_JSON_RESULTS:
|
164 |
+
# convert mask into rle format
|
165 |
+
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
166 |
+
|
167 |
+
input_boxes = input_boxes.tolist()
|
168 |
+
scores = scores.tolist()
|
169 |
+
# save the results in standard format
|
170 |
+
results = {
|
171 |
+
"image_path": img_path,
|
172 |
+
"annotations" : [
|
173 |
+
{
|
174 |
+
"class_name": class_name,
|
175 |
+
"bbox": box,
|
176 |
+
"segmentation": mask_rle,
|
177 |
+
"score": score,
|
178 |
+
}
|
179 |
+
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
180 |
+
],
|
181 |
+
"box_format": "xyxy",
|
182 |
+
"img_width": image.width,
|
183 |
+
"img_height": image.height,
|
184 |
+
}
|
185 |
+
|
186 |
+
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_hf_model_demo_results.json"), "w") as f:
|
187 |
+
json.dump(results, f, indent=4)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_local_demo.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import supervision as sv
|
7 |
+
import pycocotools.mask as mask_util
|
8 |
+
from pathlib import Path
|
9 |
+
from torchvision.ops import box_convert
|
10 |
+
from sam2.build_sam import build_sam2
|
11 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
12 |
+
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
|
13 |
+
|
14 |
+
"""
|
15 |
+
Hyper parameters
|
16 |
+
"""
|
17 |
+
TEXT_PROMPT = "car. tire."
|
18 |
+
IMG_PATH = "notebooks/images/truck.jpg"
|
19 |
+
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
20 |
+
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
21 |
+
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
22 |
+
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
|
23 |
+
BOX_THRESHOLD = 0.35
|
24 |
+
TEXT_THRESHOLD = 0.25
|
25 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
+
OUTPUT_DIR = Path("outputs/grounded_sam2_local_demo")
|
27 |
+
DUMP_JSON_RESULTS = True
|
28 |
+
|
29 |
+
# create output directory
|
30 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
31 |
+
|
32 |
+
# environment settings
|
33 |
+
# use bfloat16
|
34 |
+
|
35 |
+
# build SAM2 image predictor
|
36 |
+
sam2_checkpoint = SAM2_CHECKPOINT
|
37 |
+
model_cfg = SAM2_MODEL_CONFIG
|
38 |
+
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
39 |
+
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
40 |
+
|
41 |
+
# build grounding dino model
|
42 |
+
grounding_model = load_model(
|
43 |
+
model_config_path=GROUNDING_DINO_CONFIG,
|
44 |
+
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
|
45 |
+
device=DEVICE
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
50 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
51 |
+
text = TEXT_PROMPT
|
52 |
+
img_path = IMG_PATH
|
53 |
+
|
54 |
+
image_source, image = load_image(img_path)
|
55 |
+
|
56 |
+
sam2_predictor.set_image(image_source)
|
57 |
+
|
58 |
+
boxes, confidences, labels = predict(
|
59 |
+
model=grounding_model,
|
60 |
+
image=image,
|
61 |
+
caption=text,
|
62 |
+
box_threshold=BOX_THRESHOLD,
|
63 |
+
text_threshold=TEXT_THRESHOLD,
|
64 |
+
)
|
65 |
+
|
66 |
+
# process the box prompt for SAM 2
|
67 |
+
h, w, _ = image_source.shape
|
68 |
+
boxes = boxes * torch.Tensor([w, h, w, h])
|
69 |
+
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
70 |
+
|
71 |
+
|
72 |
+
# FIXME: figure how does this influence the G-DINO model
|
73 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
74 |
+
|
75 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
76 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
77 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
78 |
+
torch.backends.cudnn.allow_tf32 = True
|
79 |
+
|
80 |
+
masks, scores, logits = sam2_predictor.predict(
|
81 |
+
point_coords=None,
|
82 |
+
point_labels=None,
|
83 |
+
box=input_boxes,
|
84 |
+
multimask_output=False,
|
85 |
+
)
|
86 |
+
|
87 |
+
"""
|
88 |
+
Post-process the output of the model to get the masks, scores, and logits for visualization
|
89 |
+
"""
|
90 |
+
# convert the shape to (n, H, W)
|
91 |
+
if masks.ndim == 4:
|
92 |
+
masks = masks.squeeze(1)
|
93 |
+
|
94 |
+
|
95 |
+
confidences = confidences.numpy().tolist()
|
96 |
+
class_names = labels
|
97 |
+
|
98 |
+
class_ids = np.array(list(range(len(class_names))))
|
99 |
+
|
100 |
+
labels = [
|
101 |
+
f"{class_name} {confidence:.2f}"
|
102 |
+
for class_name, confidence
|
103 |
+
in zip(class_names, confidences)
|
104 |
+
]
|
105 |
+
|
106 |
+
"""
|
107 |
+
Visualize image with supervision useful API
|
108 |
+
"""
|
109 |
+
img = cv2.imread(img_path)
|
110 |
+
detections = sv.Detections(
|
111 |
+
xyxy=input_boxes, # (n, 4)
|
112 |
+
mask=masks.astype(bool), # (n, h, w)
|
113 |
+
class_id=class_ids
|
114 |
+
)
|
115 |
+
|
116 |
+
box_annotator = sv.BoxAnnotator()
|
117 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
118 |
+
|
119 |
+
label_annotator = sv.LabelAnnotator()
|
120 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
121 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
122 |
+
|
123 |
+
mask_annotator = sv.MaskAnnotator()
|
124 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
125 |
+
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
126 |
+
|
127 |
+
"""
|
128 |
+
Dump the results in standard format and save as json files
|
129 |
+
"""
|
130 |
+
|
131 |
+
def single_mask_to_rle(mask):
|
132 |
+
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
133 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
134 |
+
return rle
|
135 |
+
|
136 |
+
if DUMP_JSON_RESULTS:
|
137 |
+
# convert mask into rle format
|
138 |
+
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
139 |
+
|
140 |
+
input_boxes = input_boxes.tolist()
|
141 |
+
scores = scores.tolist()
|
142 |
+
# save the results in standard format
|
143 |
+
results = {
|
144 |
+
"image_path": img_path,
|
145 |
+
"annotations" : [
|
146 |
+
{
|
147 |
+
"class_name": class_name,
|
148 |
+
"bbox": box,
|
149 |
+
"segmentation": mask_rle,
|
150 |
+
"score": score,
|
151 |
+
}
|
152 |
+
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
153 |
+
],
|
154 |
+
"box_format": "xyxy",
|
155 |
+
"img_width": w,
|
156 |
+
"img_height": h,
|
157 |
+
}
|
158 |
+
|
159 |
+
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_local_image_demo_results.json"), "w") as f:
|
160 |
+
json.dump(results, f, indent=4)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
+
from PIL import Image
|
7 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
8 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
9 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
10 |
+
from utils.track_utils import sample_points_from_masks
|
11 |
+
from utils.video_utils import create_video_from_images
|
12 |
+
|
13 |
+
|
14 |
+
"""
|
15 |
+
Step 1: Environment settings and model initialization
|
16 |
+
"""
|
17 |
+
# use bfloat16 for the entire notebook
|
18 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
19 |
+
|
20 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
21 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
22 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
23 |
+
torch.backends.cudnn.allow_tf32 = True
|
24 |
+
|
25 |
+
# init sam image predictor and video predictor model
|
26 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
27 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
28 |
+
|
29 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
30 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
|
31 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
32 |
+
|
33 |
+
|
34 |
+
# init grounding dino model from huggingface
|
35 |
+
model_id = "IDEA-Research/grounding-dino-tiny"
|
36 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
37 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
38 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
39 |
+
|
40 |
+
|
41 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
42 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
43 |
+
text = "car."
|
44 |
+
|
45 |
+
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
46 |
+
|
47 |
+
video_dir = "notebooks/videos/car"
|
48 |
+
|
49 |
+
# scan all the JPEG frame names in this directory
|
50 |
+
frame_names = [
|
51 |
+
p for p in os.listdir(video_dir)
|
52 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
53 |
+
]
|
54 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
55 |
+
|
56 |
+
# init video predictor state
|
57 |
+
inference_state = video_predictor.init_state(video_path=video_dir)
|
58 |
+
|
59 |
+
ann_frame_idx = 0 # the frame index we interact with
|
60 |
+
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
61 |
+
|
62 |
+
|
63 |
+
"""
|
64 |
+
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for specific frame
|
65 |
+
"""
|
66 |
+
|
67 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
68 |
+
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
|
69 |
+
image = Image.open(img_path)
|
70 |
+
|
71 |
+
# run Grounding DINO on the image
|
72 |
+
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
73 |
+
with torch.no_grad():
|
74 |
+
outputs = grounding_model(**inputs)
|
75 |
+
|
76 |
+
results = processor.post_process_grounded_object_detection(
|
77 |
+
outputs,
|
78 |
+
inputs.input_ids,
|
79 |
+
box_threshold=0.25,
|
80 |
+
text_threshold=0.3,
|
81 |
+
target_sizes=[image.size[::-1]]
|
82 |
+
)
|
83 |
+
|
84 |
+
# prompt SAM image predictor to get the mask for the object
|
85 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
86 |
+
|
87 |
+
# process the detection results
|
88 |
+
input_boxes = results[0]["boxes"].cpu().numpy()
|
89 |
+
OBJECTS = results[0]["labels"]
|
90 |
+
|
91 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
92 |
+
masks, scores, logits = image_predictor.predict(
|
93 |
+
point_coords=None,
|
94 |
+
point_labels=None,
|
95 |
+
box=input_boxes,
|
96 |
+
multimask_output=False,
|
97 |
+
)
|
98 |
+
|
99 |
+
# convert the mask shape to (n, H, W)
|
100 |
+
if masks.ndim == 3:
|
101 |
+
masks = masks[None]
|
102 |
+
scores = scores[None]
|
103 |
+
logits = logits[None]
|
104 |
+
elif masks.ndim == 4:
|
105 |
+
masks = masks.squeeze(1)
|
106 |
+
|
107 |
+
"""
|
108 |
+
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
109 |
+
"""
|
110 |
+
|
111 |
+
PROMPT_TYPE_FOR_VIDEO = "box" # or "point"
|
112 |
+
|
113 |
+
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
114 |
+
|
115 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
116 |
+
if PROMPT_TYPE_FOR_VIDEO == "point":
|
117 |
+
# sample the positive points from mask for each objects
|
118 |
+
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
119 |
+
|
120 |
+
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
121 |
+
labels = np.ones((points.shape[0]), dtype=np.int32)
|
122 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
123 |
+
inference_state=inference_state,
|
124 |
+
frame_idx=ann_frame_idx,
|
125 |
+
obj_id=object_id,
|
126 |
+
points=points,
|
127 |
+
labels=labels,
|
128 |
+
)
|
129 |
+
# Using box prompt
|
130 |
+
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
131 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
132 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
133 |
+
inference_state=inference_state,
|
134 |
+
frame_idx=ann_frame_idx,
|
135 |
+
obj_id=object_id,
|
136 |
+
box=box,
|
137 |
+
)
|
138 |
+
# Using mask prompt is a more straightforward way
|
139 |
+
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
140 |
+
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
141 |
+
labels = np.ones((1), dtype=np.int32)
|
142 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
143 |
+
inference_state=inference_state,
|
144 |
+
frame_idx=ann_frame_idx,
|
145 |
+
obj_id=object_id,
|
146 |
+
mask=mask
|
147 |
+
)
|
148 |
+
else:
|
149 |
+
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
150 |
+
|
151 |
+
|
152 |
+
"""
|
153 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
154 |
+
"""
|
155 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
156 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
157 |
+
video_segments[out_frame_idx] = {
|
158 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
159 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
160 |
+
}
|
161 |
+
|
162 |
+
"""
|
163 |
+
Step 5: Visualize the segment results across the video and save them
|
164 |
+
"""
|
165 |
+
|
166 |
+
save_dir = "./tracking_results"
|
167 |
+
|
168 |
+
if not os.path.exists(save_dir):
|
169 |
+
os.makedirs(save_dir)
|
170 |
+
|
171 |
+
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
172 |
+
for frame_idx, segments in video_segments.items():
|
173 |
+
img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
|
174 |
+
|
175 |
+
object_ids = list(segments.keys())
|
176 |
+
masks = list(segments.values())
|
177 |
+
masks = np.concatenate(masks, axis=0)
|
178 |
+
|
179 |
+
detections = sv.Detections(
|
180 |
+
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
181 |
+
mask=masks, # (n, h, w)
|
182 |
+
class_id=np.array(object_ids, dtype=np.int32),
|
183 |
+
)
|
184 |
+
box_annotator = sv.BoxAnnotator()
|
185 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
186 |
+
label_annotator = sv.LabelAnnotator()
|
187 |
+
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
188 |
+
mask_annotator = sv.MaskAnnotator()
|
189 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
190 |
+
cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
191 |
+
|
192 |
+
|
193 |
+
"""
|
194 |
+
Step 6: Convert the annotated frames to video
|
195 |
+
"""
|
196 |
+
|
197 |
+
output_video_path = "./children_tracking_demo_video.mp4"
|
198 |
+
create_video_from_images(save_dir, output_video_path)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_dinox.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dds cloudapi for Grounding DINO 1.5
|
2 |
+
from dds_cloudapi_sdk import Config
|
3 |
+
from dds_cloudapi_sdk import Client
|
4 |
+
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
|
5 |
+
from dds_cloudapi_sdk.tasks.types import DetectionTarget
|
6 |
+
from dds_cloudapi_sdk import TextPrompt
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import supervision as sv
|
13 |
+
|
14 |
+
from pathlib import Path
|
15 |
+
from tqdm import tqdm
|
16 |
+
from PIL import Image
|
17 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
18 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
19 |
+
from utils.track_utils import sample_points_from_masks
|
20 |
+
from utils.video_utils import create_video_from_images
|
21 |
+
|
22 |
+
"""
|
23 |
+
Hyperparam for Ground and Tracking
|
24 |
+
"""
|
25 |
+
VIDEO_PATH = "./assets/hippopotamus.mp4"
|
26 |
+
TEXT_PROMPT = "hippopotamus."
|
27 |
+
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
28 |
+
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
29 |
+
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
30 |
+
API_TOKEN_FOR_DINOX = "Your API token"
|
31 |
+
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
32 |
+
BOX_THRESHOLD = 0.2
|
33 |
+
|
34 |
+
"""
|
35 |
+
Step 1: Environment settings and model initialization for SAM 2
|
36 |
+
"""
|
37 |
+
# use bfloat16 for the entire notebook
|
38 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
39 |
+
|
40 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
41 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
42 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
43 |
+
torch.backends.cudnn.allow_tf32 = True
|
44 |
+
|
45 |
+
# init sam image predictor and video predictor model
|
46 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
47 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
48 |
+
|
49 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
50 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
|
51 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
52 |
+
|
53 |
+
|
54 |
+
# # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
55 |
+
# video_dir = "notebooks/videos/bedroom"
|
56 |
+
|
57 |
+
"""
|
58 |
+
Custom video input directly using video files
|
59 |
+
"""
|
60 |
+
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
|
61 |
+
print(video_info)
|
62 |
+
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
|
63 |
+
|
64 |
+
# saving video to frames
|
65 |
+
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
|
66 |
+
source_frames.mkdir(parents=True, exist_ok=True)
|
67 |
+
|
68 |
+
with sv.ImageSink(
|
69 |
+
target_dir_path=source_frames,
|
70 |
+
overwrite=True,
|
71 |
+
image_name_pattern="{:05d}.jpg"
|
72 |
+
) as sink:
|
73 |
+
for frame in tqdm(frame_generator, desc="Saving Video Frames"):
|
74 |
+
sink.save_image(frame)
|
75 |
+
|
76 |
+
# scan all the JPEG frame names in this directory
|
77 |
+
frame_names = [
|
78 |
+
p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
|
79 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
80 |
+
]
|
81 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
82 |
+
|
83 |
+
# init video predictor state
|
84 |
+
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
|
85 |
+
|
86 |
+
ann_frame_idx = 0 # the frame index we interact with
|
87 |
+
"""
|
88 |
+
Step 2: Prompt DINO-X with Cloud API for box coordinates
|
89 |
+
"""
|
90 |
+
|
91 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
92 |
+
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
|
93 |
+
image = Image.open(img_path)
|
94 |
+
|
95 |
+
# Step 1: initialize the config
|
96 |
+
config = Config(API_TOKEN_FOR_DINOX)
|
97 |
+
|
98 |
+
# Step 2: initialize the client
|
99 |
+
client = Client(config)
|
100 |
+
|
101 |
+
# Step 3: run the task by DetectionTask class
|
102 |
+
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
103 |
+
# if you are processing local image file, upload them to DDS server to get the image url
|
104 |
+
image_url = client.upload_file(img_path)
|
105 |
+
|
106 |
+
task = DinoxTask(
|
107 |
+
image_url=image_url,
|
108 |
+
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
109 |
+
bbox_threshold=0.25,
|
110 |
+
targets=[DetectionTarget.BBox],
|
111 |
+
)
|
112 |
+
|
113 |
+
client.run_task(task)
|
114 |
+
result = task.result
|
115 |
+
|
116 |
+
objects = result.objects # the list of detected objects
|
117 |
+
|
118 |
+
|
119 |
+
input_boxes = []
|
120 |
+
confidences = []
|
121 |
+
class_names = []
|
122 |
+
|
123 |
+
for idx, obj in enumerate(objects):
|
124 |
+
input_boxes.append(obj.bbox)
|
125 |
+
confidences.append(obj.score)
|
126 |
+
class_names.append(obj.category)
|
127 |
+
|
128 |
+
input_boxes = np.array(input_boxes)
|
129 |
+
|
130 |
+
print(input_boxes)
|
131 |
+
|
132 |
+
# prompt SAM image predictor to get the mask for the object
|
133 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
134 |
+
|
135 |
+
# process the detection results
|
136 |
+
OBJECTS = class_names
|
137 |
+
|
138 |
+
print(OBJECTS)
|
139 |
+
|
140 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
141 |
+
masks, scores, logits = image_predictor.predict(
|
142 |
+
point_coords=None,
|
143 |
+
point_labels=None,
|
144 |
+
box=input_boxes,
|
145 |
+
multimask_output=False,
|
146 |
+
)
|
147 |
+
# convert the mask shape to (n, H, W)
|
148 |
+
if masks.ndim == 4:
|
149 |
+
masks = masks.squeeze(1)
|
150 |
+
|
151 |
+
"""
|
152 |
+
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
153 |
+
"""
|
154 |
+
|
155 |
+
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
156 |
+
|
157 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
158 |
+
if PROMPT_TYPE_FOR_VIDEO == "point":
|
159 |
+
# sample the positive points from mask for each objects
|
160 |
+
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
161 |
+
|
162 |
+
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
163 |
+
labels = np.ones((points.shape[0]), dtype=np.int32)
|
164 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
165 |
+
inference_state=inference_state,
|
166 |
+
frame_idx=ann_frame_idx,
|
167 |
+
obj_id=object_id,
|
168 |
+
points=points,
|
169 |
+
labels=labels,
|
170 |
+
)
|
171 |
+
# Using box prompt
|
172 |
+
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
173 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
174 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
175 |
+
inference_state=inference_state,
|
176 |
+
frame_idx=ann_frame_idx,
|
177 |
+
obj_id=object_id,
|
178 |
+
box=box,
|
179 |
+
)
|
180 |
+
# Using mask prompt is a more straightforward way
|
181 |
+
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
182 |
+
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
183 |
+
labels = np.ones((1), dtype=np.int32)
|
184 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
185 |
+
inference_state=inference_state,
|
186 |
+
frame_idx=ann_frame_idx,
|
187 |
+
obj_id=object_id,
|
188 |
+
mask=mask
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
192 |
+
|
193 |
+
"""
|
194 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
195 |
+
"""
|
196 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
197 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
198 |
+
video_segments[out_frame_idx] = {
|
199 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
200 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
201 |
+
}
|
202 |
+
|
203 |
+
"""
|
204 |
+
Step 5: Visualize the segment results across the video and save them
|
205 |
+
"""
|
206 |
+
|
207 |
+
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
|
208 |
+
os.makedirs(SAVE_TRACKING_RESULTS_DIR)
|
209 |
+
|
210 |
+
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
211 |
+
|
212 |
+
for frame_idx, segments in video_segments.items():
|
213 |
+
img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
|
214 |
+
|
215 |
+
object_ids = list(segments.keys())
|
216 |
+
masks = list(segments.values())
|
217 |
+
masks = np.concatenate(masks, axis=0)
|
218 |
+
|
219 |
+
detections = sv.Detections(
|
220 |
+
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
221 |
+
mask=masks, # (n, h, w)
|
222 |
+
class_id=np.array(object_ids, dtype=np.int32),
|
223 |
+
)
|
224 |
+
box_annotator = sv.BoxAnnotator()
|
225 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
226 |
+
label_annotator = sv.LabelAnnotator()
|
227 |
+
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
228 |
+
mask_annotator = sv.MaskAnnotator()
|
229 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
230 |
+
cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
231 |
+
|
232 |
+
|
233 |
+
"""
|
234 |
+
Step 6: Convert the annotated frames to video
|
235 |
+
"""
|
236 |
+
|
237 |
+
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from tqdm import tqdm
|
9 |
+
from PIL import Image
|
10 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
11 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
12 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
13 |
+
from utils.track_utils import sample_points_from_masks
|
14 |
+
from utils.video_utils import create_video_from_images
|
15 |
+
|
16 |
+
"""
|
17 |
+
Hyperparam for Ground and Tracking
|
18 |
+
"""
|
19 |
+
MODEL_ID = "IDEA-Research/grounding-dino-tiny"
|
20 |
+
VIDEO_PATH = "./assets/hippopotamus.mp4"
|
21 |
+
TEXT_PROMPT = "hippopotamus."
|
22 |
+
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
23 |
+
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
24 |
+
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
25 |
+
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
26 |
+
|
27 |
+
"""
|
28 |
+
Step 1: Environment settings and model initialization for SAM 2
|
29 |
+
"""
|
30 |
+
# use bfloat16 for the entire notebook
|
31 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
32 |
+
|
33 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
34 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
35 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
36 |
+
torch.backends.cudnn.allow_tf32 = True
|
37 |
+
|
38 |
+
# init sam image predictor and video predictor model
|
39 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
40 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
41 |
+
|
42 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
43 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
|
44 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
45 |
+
|
46 |
+
# build grounding dino from huggingface
|
47 |
+
model_id = MODEL_ID
|
48 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
49 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
50 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
51 |
+
|
52 |
+
|
53 |
+
"""
|
54 |
+
Custom video input directly using video files
|
55 |
+
"""
|
56 |
+
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
|
57 |
+
print(video_info)
|
58 |
+
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
|
59 |
+
|
60 |
+
# saving video to frames
|
61 |
+
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
|
62 |
+
source_frames.mkdir(parents=True, exist_ok=True)
|
63 |
+
|
64 |
+
with sv.ImageSink(
|
65 |
+
target_dir_path=source_frames,
|
66 |
+
overwrite=True,
|
67 |
+
image_name_pattern="{:05d}.jpg"
|
68 |
+
) as sink:
|
69 |
+
for frame in tqdm(frame_generator, desc="Saving Video Frames"):
|
70 |
+
sink.save_image(frame)
|
71 |
+
|
72 |
+
# scan all the JPEG frame names in this directory
|
73 |
+
frame_names = [
|
74 |
+
p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
|
75 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
76 |
+
]
|
77 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
78 |
+
|
79 |
+
# init video predictor state
|
80 |
+
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
|
81 |
+
|
82 |
+
ann_frame_idx = 0 # the frame index we interact with
|
83 |
+
"""
|
84 |
+
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
85 |
+
"""
|
86 |
+
|
87 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
88 |
+
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
|
89 |
+
image = Image.open(img_path)
|
90 |
+
inputs = processor(images=image, text=TEXT_PROMPT, return_tensors="pt").to(device)
|
91 |
+
with torch.no_grad():
|
92 |
+
outputs = grounding_model(**inputs)
|
93 |
+
|
94 |
+
results = processor.post_process_grounded_object_detection(
|
95 |
+
outputs,
|
96 |
+
inputs.input_ids,
|
97 |
+
box_threshold=0.4,
|
98 |
+
text_threshold=0.3,
|
99 |
+
target_sizes=[image.size[::-1]]
|
100 |
+
)
|
101 |
+
|
102 |
+
input_boxes = results[0]["boxes"].cpu().numpy()
|
103 |
+
confidences = results[0]["scores"].cpu().numpy().tolist()
|
104 |
+
class_names = results[0]["labels"]
|
105 |
+
|
106 |
+
print(input_boxes)
|
107 |
+
|
108 |
+
# prompt SAM image predictor to get the mask for the object
|
109 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
110 |
+
|
111 |
+
# process the detection results
|
112 |
+
OBJECTS = class_names
|
113 |
+
|
114 |
+
print(OBJECTS)
|
115 |
+
|
116 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
117 |
+
masks, scores, logits = image_predictor.predict(
|
118 |
+
point_coords=None,
|
119 |
+
point_labels=None,
|
120 |
+
box=input_boxes,
|
121 |
+
multimask_output=False,
|
122 |
+
)
|
123 |
+
# convert the mask shape to (n, H, W)
|
124 |
+
if masks.ndim == 4:
|
125 |
+
masks = masks.squeeze(1)
|
126 |
+
|
127 |
+
"""
|
128 |
+
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
129 |
+
"""
|
130 |
+
|
131 |
+
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
132 |
+
|
133 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
134 |
+
if PROMPT_TYPE_FOR_VIDEO == "point":
|
135 |
+
# sample the positive points from mask for each objects
|
136 |
+
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
137 |
+
|
138 |
+
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
139 |
+
labels = np.ones((points.shape[0]), dtype=np.int32)
|
140 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
141 |
+
inference_state=inference_state,
|
142 |
+
frame_idx=ann_frame_idx,
|
143 |
+
obj_id=object_id,
|
144 |
+
points=points,
|
145 |
+
labels=labels,
|
146 |
+
)
|
147 |
+
# Using box prompt
|
148 |
+
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
149 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
150 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
151 |
+
inference_state=inference_state,
|
152 |
+
frame_idx=ann_frame_idx,
|
153 |
+
obj_id=object_id,
|
154 |
+
box=box,
|
155 |
+
)
|
156 |
+
# Using mask prompt is a more straightforward way
|
157 |
+
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
158 |
+
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
159 |
+
labels = np.ones((1), dtype=np.int32)
|
160 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
161 |
+
inference_state=inference_state,
|
162 |
+
frame_idx=ann_frame_idx,
|
163 |
+
obj_id=object_id,
|
164 |
+
mask=mask
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
168 |
+
|
169 |
+
|
170 |
+
"""
|
171 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
172 |
+
"""
|
173 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
174 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
175 |
+
video_segments[out_frame_idx] = {
|
176 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
177 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
178 |
+
}
|
179 |
+
|
180 |
+
"""
|
181 |
+
Step 5: Visualize the segment results across the video and save them
|
182 |
+
"""
|
183 |
+
|
184 |
+
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
|
185 |
+
os.makedirs(SAVE_TRACKING_RESULTS_DIR)
|
186 |
+
|
187 |
+
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
188 |
+
|
189 |
+
for frame_idx, segments in video_segments.items():
|
190 |
+
img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
|
191 |
+
|
192 |
+
object_ids = list(segments.keys())
|
193 |
+
masks = list(segments.values())
|
194 |
+
masks = np.concatenate(masks, axis=0)
|
195 |
+
|
196 |
+
detections = sv.Detections(
|
197 |
+
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
198 |
+
mask=masks, # (n, h, w)
|
199 |
+
class_id=np.array(object_ids, dtype=np.int32),
|
200 |
+
)
|
201 |
+
box_annotator = sv.BoxAnnotator()
|
202 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
203 |
+
label_annotator = sv.LabelAnnotator()
|
204 |
+
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
205 |
+
mask_annotator = sv.MaskAnnotator()
|
206 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
207 |
+
cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
208 |
+
|
209 |
+
|
210 |
+
"""
|
211 |
+
Step 6: Convert the annotated frames to video
|
212 |
+
"""
|
213 |
+
|
214 |
+
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
+
from torchvision.ops import box_convert
|
7 |
+
from pathlib import Path
|
8 |
+
from tqdm import tqdm
|
9 |
+
from PIL import Image
|
10 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
11 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
12 |
+
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
|
13 |
+
from utils.track_utils import sample_points_from_masks
|
14 |
+
from utils.video_utils import create_video_from_images
|
15 |
+
|
16 |
+
"""
|
17 |
+
Hyperparam for Ground and Tracking
|
18 |
+
"""
|
19 |
+
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
20 |
+
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
|
21 |
+
BOX_THRESHOLD = 0.35
|
22 |
+
TEXT_THRESHOLD = 0.25
|
23 |
+
VIDEO_PATH = "./assets/hippopotamus.mp4"
|
24 |
+
TEXT_PROMPT = "hippopotamus."
|
25 |
+
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
26 |
+
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
27 |
+
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
28 |
+
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
29 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
+
|
31 |
+
"""
|
32 |
+
Step 1: Environment settings and model initialization for Grounding DINO and SAM 2
|
33 |
+
"""
|
34 |
+
# build grounding dino model from local path
|
35 |
+
grounding_model = load_model(
|
36 |
+
model_config_path=GROUNDING_DINO_CONFIG,
|
37 |
+
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
|
38 |
+
device=DEVICE
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
# init sam image predictor and video predictor model
|
43 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
44 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
45 |
+
|
46 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
47 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
|
48 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
49 |
+
|
50 |
+
|
51 |
+
"""
|
52 |
+
Custom video input directly using video files
|
53 |
+
"""
|
54 |
+
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
|
55 |
+
print(video_info)
|
56 |
+
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
|
57 |
+
|
58 |
+
# saving video to frames
|
59 |
+
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
|
60 |
+
source_frames.mkdir(parents=True, exist_ok=True)
|
61 |
+
|
62 |
+
with sv.ImageSink(
|
63 |
+
target_dir_path=source_frames,
|
64 |
+
overwrite=True,
|
65 |
+
image_name_pattern="{:05d}.jpg"
|
66 |
+
) as sink:
|
67 |
+
for frame in tqdm(frame_generator, desc="Saving Video Frames"):
|
68 |
+
sink.save_image(frame)
|
69 |
+
|
70 |
+
# scan all the JPEG frame names in this directory
|
71 |
+
frame_names = [
|
72 |
+
p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
|
73 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
74 |
+
]
|
75 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
76 |
+
|
77 |
+
# init video predictor state
|
78 |
+
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
|
79 |
+
|
80 |
+
ann_frame_idx = 0 # the frame index we interact with
|
81 |
+
"""
|
82 |
+
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
83 |
+
"""
|
84 |
+
|
85 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
86 |
+
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
|
87 |
+
image_source, image = load_image(img_path)
|
88 |
+
|
89 |
+
boxes, confidences, labels = predict(
|
90 |
+
model=grounding_model,
|
91 |
+
image=image,
|
92 |
+
caption=TEXT_PROMPT,
|
93 |
+
box_threshold=BOX_THRESHOLD,
|
94 |
+
text_threshold=TEXT_THRESHOLD,
|
95 |
+
)
|
96 |
+
|
97 |
+
# process the box prompt for SAM 2
|
98 |
+
h, w, _ = image_source.shape
|
99 |
+
boxes = boxes * torch.Tensor([w, h, w, h])
|
100 |
+
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
101 |
+
confidences = confidences.numpy().tolist()
|
102 |
+
class_names = labels
|
103 |
+
|
104 |
+
print(input_boxes)
|
105 |
+
|
106 |
+
# prompt SAM image predictor to get the mask for the object
|
107 |
+
image_predictor.set_image(image_source)
|
108 |
+
|
109 |
+
# process the detection results
|
110 |
+
OBJECTS = class_names
|
111 |
+
|
112 |
+
print(OBJECTS)
|
113 |
+
|
114 |
+
# FIXME: figure how does this influence the G-DINO model
|
115 |
+
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
116 |
+
|
117 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
118 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
119 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
120 |
+
torch.backends.cudnn.allow_tf32 = True
|
121 |
+
|
122 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
123 |
+
masks, scores, logits = image_predictor.predict(
|
124 |
+
point_coords=None,
|
125 |
+
point_labels=None,
|
126 |
+
box=input_boxes,
|
127 |
+
multimask_output=False,
|
128 |
+
)
|
129 |
+
# convert the mask shape to (n, H, W)
|
130 |
+
if masks.ndim == 4:
|
131 |
+
masks = masks.squeeze(1)
|
132 |
+
|
133 |
+
"""
|
134 |
+
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
135 |
+
"""
|
136 |
+
|
137 |
+
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
138 |
+
|
139 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
140 |
+
if PROMPT_TYPE_FOR_VIDEO == "point":
|
141 |
+
# sample the positive points from mask for each objects
|
142 |
+
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
143 |
+
|
144 |
+
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
145 |
+
labels = np.ones((points.shape[0]), dtype=np.int32)
|
146 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
147 |
+
inference_state=inference_state,
|
148 |
+
frame_idx=ann_frame_idx,
|
149 |
+
obj_id=object_id,
|
150 |
+
points=points,
|
151 |
+
labels=labels,
|
152 |
+
)
|
153 |
+
# Using box prompt
|
154 |
+
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
155 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
156 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
157 |
+
inference_state=inference_state,
|
158 |
+
frame_idx=ann_frame_idx,
|
159 |
+
obj_id=object_id,
|
160 |
+
box=box,
|
161 |
+
)
|
162 |
+
# Using mask prompt is a more straightforward way
|
163 |
+
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
164 |
+
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
165 |
+
labels = np.ones((1), dtype=np.int32)
|
166 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
167 |
+
inference_state=inference_state,
|
168 |
+
frame_idx=ann_frame_idx,
|
169 |
+
obj_id=object_id,
|
170 |
+
mask=mask
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
174 |
+
|
175 |
+
|
176 |
+
"""
|
177 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
178 |
+
"""
|
179 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
180 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
181 |
+
video_segments[out_frame_idx] = {
|
182 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
183 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
184 |
+
}
|
185 |
+
|
186 |
+
"""
|
187 |
+
Step 5: Visualize the segment results across the video and save them
|
188 |
+
"""
|
189 |
+
|
190 |
+
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
|
191 |
+
os.makedirs(SAVE_TRACKING_RESULTS_DIR)
|
192 |
+
|
193 |
+
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
194 |
+
|
195 |
+
for frame_idx, segments in video_segments.items():
|
196 |
+
img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
|
197 |
+
|
198 |
+
object_ids = list(segments.keys())
|
199 |
+
masks = list(segments.values())
|
200 |
+
masks = np.concatenate(masks, axis=0)
|
201 |
+
|
202 |
+
detections = sv.Detections(
|
203 |
+
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
204 |
+
mask=masks, # (n, h, w)
|
205 |
+
class_id=np.array(object_ids, dtype=np.int32),
|
206 |
+
)
|
207 |
+
box_annotator = sv.BoxAnnotator()
|
208 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
209 |
+
label_annotator = sv.LabelAnnotator()
|
210 |
+
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
211 |
+
mask_annotator = sv.MaskAnnotator()
|
212 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
213 |
+
cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
214 |
+
|
215 |
+
|
216 |
+
"""
|
217 |
+
Step 6: Convert the annotated frames to video
|
218 |
+
"""
|
219 |
+
|
220 |
+
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dds cloudapi for Grounding DINO 1.5
|
2 |
+
from dds_cloudapi_sdk import Config
|
3 |
+
from dds_cloudapi_sdk import Client
|
4 |
+
from dds_cloudapi_sdk import DetectionTask
|
5 |
+
from dds_cloudapi_sdk import TextPrompt
|
6 |
+
from dds_cloudapi_sdk import DetectionModel
|
7 |
+
from dds_cloudapi_sdk import DetectionTarget
|
8 |
+
|
9 |
+
import os
|
10 |
+
import cv2
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
import supervision as sv
|
14 |
+
|
15 |
+
from pathlib import Path
|
16 |
+
from tqdm import tqdm
|
17 |
+
from PIL import Image
|
18 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
19 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
20 |
+
from utils.track_utils import sample_points_from_masks
|
21 |
+
from utils.video_utils import create_video_from_images
|
22 |
+
|
23 |
+
"""
|
24 |
+
Hyperparam for Ground and Tracking
|
25 |
+
"""
|
26 |
+
VIDEO_PATH = "./assets/hippopotamus.mp4"
|
27 |
+
TEXT_PROMPT = "hippopotamus."
|
28 |
+
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
29 |
+
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
30 |
+
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
31 |
+
API_TOKEN_FOR_GD1_5 = "Your API token"
|
32 |
+
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
33 |
+
BOX_THRESHOLD = 0.2
|
34 |
+
|
35 |
+
"""
|
36 |
+
Step 1: Environment settings and model initialization for SAM 2
|
37 |
+
"""
|
38 |
+
# use bfloat16 for the entire notebook
|
39 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
40 |
+
|
41 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
42 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
43 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
44 |
+
torch.backends.cudnn.allow_tf32 = True
|
45 |
+
|
46 |
+
# init sam image predictor and video predictor model
|
47 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
48 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
49 |
+
|
50 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
51 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
|
52 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
53 |
+
|
54 |
+
|
55 |
+
# # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
56 |
+
# video_dir = "notebooks/videos/bedroom"
|
57 |
+
|
58 |
+
"""
|
59 |
+
Custom video input directly using video files
|
60 |
+
"""
|
61 |
+
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
|
62 |
+
print(video_info)
|
63 |
+
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
|
64 |
+
|
65 |
+
# saving video to frames
|
66 |
+
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
|
67 |
+
source_frames.mkdir(parents=True, exist_ok=True)
|
68 |
+
|
69 |
+
with sv.ImageSink(
|
70 |
+
target_dir_path=source_frames,
|
71 |
+
overwrite=True,
|
72 |
+
image_name_pattern="{:05d}.jpg"
|
73 |
+
) as sink:
|
74 |
+
for frame in tqdm(frame_generator, desc="Saving Video Frames"):
|
75 |
+
sink.save_image(frame)
|
76 |
+
|
77 |
+
# scan all the JPEG frame names in this directory
|
78 |
+
frame_names = [
|
79 |
+
p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
|
80 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
81 |
+
]
|
82 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
83 |
+
|
84 |
+
# init video predictor state
|
85 |
+
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
|
86 |
+
|
87 |
+
ann_frame_idx = 0 # the frame index we interact with
|
88 |
+
"""
|
89 |
+
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
90 |
+
"""
|
91 |
+
|
92 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
93 |
+
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
|
94 |
+
image = Image.open(img_path)
|
95 |
+
|
96 |
+
# Step 1: initialize the config
|
97 |
+
config = Config(API_TOKEN_FOR_GD1_5)
|
98 |
+
|
99 |
+
# Step 2: initialize the client
|
100 |
+
client = Client(config)
|
101 |
+
|
102 |
+
# Step 3: run the task by DetectionTask class
|
103 |
+
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
104 |
+
# if you are processing local image file, upload them to DDS server to get the image url
|
105 |
+
image_url = client.upload_file(img_path)
|
106 |
+
|
107 |
+
task = DetectionTask(
|
108 |
+
image_url=image_url,
|
109 |
+
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
110 |
+
targets=[DetectionTarget.BBox], # detect bbox
|
111 |
+
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
112 |
+
bbox_threshold=BOX_THRESHOLD,
|
113 |
+
)
|
114 |
+
|
115 |
+
client.run_task(task)
|
116 |
+
result = task.result
|
117 |
+
|
118 |
+
objects = result.objects # the list of detected objects
|
119 |
+
|
120 |
+
|
121 |
+
input_boxes = []
|
122 |
+
confidences = []
|
123 |
+
class_names = []
|
124 |
+
|
125 |
+
for idx, obj in enumerate(objects):
|
126 |
+
input_boxes.append(obj.bbox)
|
127 |
+
confidences.append(obj.score)
|
128 |
+
class_names.append(obj.category)
|
129 |
+
|
130 |
+
input_boxes = np.array(input_boxes)
|
131 |
+
|
132 |
+
print(input_boxes)
|
133 |
+
|
134 |
+
# prompt SAM image predictor to get the mask for the object
|
135 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
136 |
+
|
137 |
+
# process the detection results
|
138 |
+
OBJECTS = class_names
|
139 |
+
|
140 |
+
print(OBJECTS)
|
141 |
+
|
142 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
143 |
+
masks, scores, logits = image_predictor.predict(
|
144 |
+
point_coords=None,
|
145 |
+
point_labels=None,
|
146 |
+
box=input_boxes,
|
147 |
+
multimask_output=False,
|
148 |
+
)
|
149 |
+
# convert the mask shape to (n, H, W)
|
150 |
+
if masks.ndim == 4:
|
151 |
+
masks = masks.squeeze(1)
|
152 |
+
|
153 |
+
"""
|
154 |
+
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
155 |
+
"""
|
156 |
+
|
157 |
+
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
158 |
+
|
159 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
160 |
+
if PROMPT_TYPE_FOR_VIDEO == "point":
|
161 |
+
# sample the positive points from mask for each objects
|
162 |
+
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
163 |
+
|
164 |
+
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
165 |
+
labels = np.ones((points.shape[0]), dtype=np.int32)
|
166 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
167 |
+
inference_state=inference_state,
|
168 |
+
frame_idx=ann_frame_idx,
|
169 |
+
obj_id=object_id,
|
170 |
+
points=points,
|
171 |
+
labels=labels,
|
172 |
+
)
|
173 |
+
# Using box prompt
|
174 |
+
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
175 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
176 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
177 |
+
inference_state=inference_state,
|
178 |
+
frame_idx=ann_frame_idx,
|
179 |
+
obj_id=object_id,
|
180 |
+
box=box,
|
181 |
+
)
|
182 |
+
# Using mask prompt is a more straightforward way
|
183 |
+
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
184 |
+
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
185 |
+
labels = np.ones((1), dtype=np.int32)
|
186 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
187 |
+
inference_state=inference_state,
|
188 |
+
frame_idx=ann_frame_idx,
|
189 |
+
obj_id=object_id,
|
190 |
+
mask=mask
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
194 |
+
|
195 |
+
"""
|
196 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
197 |
+
"""
|
198 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
199 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
200 |
+
video_segments[out_frame_idx] = {
|
201 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
202 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
203 |
+
}
|
204 |
+
|
205 |
+
"""
|
206 |
+
Step 5: Visualize the segment results across the video and save them
|
207 |
+
"""
|
208 |
+
|
209 |
+
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
|
210 |
+
os.makedirs(SAVE_TRACKING_RESULTS_DIR)
|
211 |
+
|
212 |
+
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
213 |
+
|
214 |
+
for frame_idx, segments in video_segments.items():
|
215 |
+
img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
|
216 |
+
|
217 |
+
object_ids = list(segments.keys())
|
218 |
+
masks = list(segments.values())
|
219 |
+
masks = np.concatenate(masks, axis=0)
|
220 |
+
|
221 |
+
detections = sv.Detections(
|
222 |
+
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
223 |
+
mask=masks, # (n, h, w)
|
224 |
+
class_id=np.array(object_ids, dtype=np.int32),
|
225 |
+
)
|
226 |
+
box_annotator = sv.BoxAnnotator()
|
227 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
228 |
+
label_annotator = sv.LabelAnnotator()
|
229 |
+
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
230 |
+
mask_annotator = sv.MaskAnnotator()
|
231 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
232 |
+
cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
233 |
+
|
234 |
+
|
235 |
+
"""
|
236 |
+
Step 6: Convert the annotated frames to video
|
237 |
+
"""
|
238 |
+
|
239 |
+
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
+
from PIL import Image
|
7 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
8 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
9 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
10 |
+
from utils.track_utils import sample_points_from_masks
|
11 |
+
from utils.video_utils import create_video_from_images
|
12 |
+
from utils.common_utils import CommonUtils
|
13 |
+
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
|
14 |
+
import json
|
15 |
+
import copy
|
16 |
+
|
17 |
+
"""
|
18 |
+
Step 1: Environment settings and model initialization
|
19 |
+
"""
|
20 |
+
# use bfloat16 for the entire notebook
|
21 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
22 |
+
|
23 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
24 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
25 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
26 |
+
torch.backends.cudnn.allow_tf32 = True
|
27 |
+
|
28 |
+
# init sam image predictor and video predictor model
|
29 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
30 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
print("device", device)
|
33 |
+
|
34 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
35 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
36 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
37 |
+
|
38 |
+
|
39 |
+
# init grounding dino model from huggingface
|
40 |
+
model_id = "IDEA-Research/grounding-dino-tiny"
|
41 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
42 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
43 |
+
|
44 |
+
|
45 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
46 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
47 |
+
text = "car."
|
48 |
+
|
49 |
+
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
50 |
+
video_dir = "notebooks/videos/car"
|
51 |
+
# 'output_dir' is the directory to save the annotated frames
|
52 |
+
output_dir = "./outputs"
|
53 |
+
# 'output_video_path' is the path to save the final video
|
54 |
+
output_video_path = "./outputs/output.mp4"
|
55 |
+
# create the output directory
|
56 |
+
CommonUtils.creat_dirs(output_dir)
|
57 |
+
mask_data_dir = os.path.join(output_dir, "mask_data")
|
58 |
+
json_data_dir = os.path.join(output_dir, "json_data")
|
59 |
+
result_dir = os.path.join(output_dir, "result")
|
60 |
+
CommonUtils.creat_dirs(mask_data_dir)
|
61 |
+
CommonUtils.creat_dirs(json_data_dir)
|
62 |
+
# scan all the JPEG frame names in this directory
|
63 |
+
frame_names = [
|
64 |
+
p for p in os.listdir(video_dir)
|
65 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
66 |
+
]
|
67 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
68 |
+
|
69 |
+
# init video predictor state
|
70 |
+
inference_state = video_predictor.init_state(video_path=video_dir, offload_video_to_cpu=True, async_loading_frames=True)
|
71 |
+
step = 20 # the step to sample frames for Grounding DINO predictor
|
72 |
+
|
73 |
+
sam2_masks = MaskDictionaryModel()
|
74 |
+
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
75 |
+
objects_count = 0
|
76 |
+
|
77 |
+
"""
|
78 |
+
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
|
79 |
+
"""
|
80 |
+
print("Total frames:", len(frame_names))
|
81 |
+
for start_frame_idx in range(0, len(frame_names), step):
|
82 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
83 |
+
print("start_frame_idx", start_frame_idx)
|
84 |
+
# continue
|
85 |
+
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
|
86 |
+
image = Image.open(img_path)
|
87 |
+
image_base_name = frame_names[start_frame_idx].split(".")[0]
|
88 |
+
mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
|
89 |
+
|
90 |
+
# run Grounding DINO on the image
|
91 |
+
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
92 |
+
with torch.no_grad():
|
93 |
+
outputs = grounding_model(**inputs)
|
94 |
+
|
95 |
+
results = processor.post_process_grounded_object_detection(
|
96 |
+
outputs,
|
97 |
+
inputs.input_ids,
|
98 |
+
box_threshold=0.25,
|
99 |
+
text_threshold=0.25,
|
100 |
+
target_sizes=[image.size[::-1]]
|
101 |
+
)
|
102 |
+
|
103 |
+
# prompt SAM image predictor to get the mask for the object
|
104 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
105 |
+
|
106 |
+
# process the detection results
|
107 |
+
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
108 |
+
# print("results[0]",results[0])
|
109 |
+
OBJECTS = results[0]["labels"]
|
110 |
+
if input_boxes.shape[0] != 0:
|
111 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
112 |
+
masks, scores, logits = image_predictor.predict(
|
113 |
+
point_coords=None,
|
114 |
+
point_labels=None,
|
115 |
+
box=input_boxes,
|
116 |
+
multimask_output=False,
|
117 |
+
)
|
118 |
+
# convert the mask shape to (n, H, W)
|
119 |
+
if masks.ndim == 2:
|
120 |
+
masks = masks[None]
|
121 |
+
scores = scores[None]
|
122 |
+
logits = logits[None]
|
123 |
+
elif masks.ndim == 4:
|
124 |
+
masks = masks.squeeze(1)
|
125 |
+
|
126 |
+
"""
|
127 |
+
Step 3: Register each object's positive points to video predictor
|
128 |
+
"""
|
129 |
+
|
130 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
131 |
+
if mask_dict.promote_type == "mask":
|
132 |
+
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
133 |
+
else:
|
134 |
+
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
135 |
+
|
136 |
+
|
137 |
+
"""
|
138 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
139 |
+
"""
|
140 |
+
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
141 |
+
print("objects_count", objects_count)
|
142 |
+
else:
|
143 |
+
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
144 |
+
mask_dict = sam2_masks
|
145 |
+
|
146 |
+
|
147 |
+
if len(mask_dict.labels) == 0:
|
148 |
+
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
149 |
+
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
150 |
+
continue
|
151 |
+
else:
|
152 |
+
video_predictor.reset_state(inference_state)
|
153 |
+
|
154 |
+
for object_id, object_info in mask_dict.labels.items():
|
155 |
+
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
156 |
+
inference_state,
|
157 |
+
start_frame_idx,
|
158 |
+
object_id,
|
159 |
+
object_info.mask,
|
160 |
+
)
|
161 |
+
|
162 |
+
video_segments = {} # output the following {step} frames tracking masks
|
163 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
164 |
+
frame_masks = MaskDictionaryModel()
|
165 |
+
|
166 |
+
for i, out_obj_id in enumerate(out_obj_ids):
|
167 |
+
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
168 |
+
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
|
169 |
+
object_info.update_box()
|
170 |
+
frame_masks.labels[out_obj_id] = object_info
|
171 |
+
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
172 |
+
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
173 |
+
frame_masks.mask_height = out_mask.shape[-2]
|
174 |
+
frame_masks.mask_width = out_mask.shape[-1]
|
175 |
+
|
176 |
+
video_segments[out_frame_idx] = frame_masks
|
177 |
+
sam2_masks = copy.deepcopy(frame_masks)
|
178 |
+
|
179 |
+
print("video_segments:", len(video_segments))
|
180 |
+
"""
|
181 |
+
Step 5: save the tracking masks and json files
|
182 |
+
"""
|
183 |
+
for frame_idx, frame_masks_info in video_segments.items():
|
184 |
+
mask = frame_masks_info.labels
|
185 |
+
mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
|
186 |
+
for obj_id, obj_info in mask.items():
|
187 |
+
mask_img[obj_info.mask == True] = obj_id
|
188 |
+
|
189 |
+
mask_img = mask_img.numpy().astype(np.uint16)
|
190 |
+
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
|
191 |
+
|
192 |
+
json_data = frame_masks_info.to_dict()
|
193 |
+
json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
|
194 |
+
with open(json_data_path, "w") as f:
|
195 |
+
json.dump(json_data, f)
|
196 |
+
|
197 |
+
|
198 |
+
"""
|
199 |
+
Step 6: Draw the results and save the video
|
200 |
+
"""
|
201 |
+
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
202 |
+
|
203 |
+
create_video_from_images(result_dir, output_video_path, frame_rate=15)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dds cloudapi for Grounding DINO 1.5
|
2 |
+
from dds_cloudapi_sdk import Config
|
3 |
+
from dds_cloudapi_sdk import Client
|
4 |
+
from dds_cloudapi_sdk import DetectionTask
|
5 |
+
from dds_cloudapi_sdk import TextPrompt
|
6 |
+
from dds_cloudapi_sdk import DetectionModel
|
7 |
+
from dds_cloudapi_sdk import DetectionTarget
|
8 |
+
|
9 |
+
|
10 |
+
import os
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
from PIL import Image
|
14 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
15 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
16 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
17 |
+
from utils.video_utils import create_video_from_images
|
18 |
+
from utils.common_utils import CommonUtils
|
19 |
+
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
|
20 |
+
import json
|
21 |
+
import copy
|
22 |
+
|
23 |
+
"""
|
24 |
+
Step 1: Environment settings and model initialization
|
25 |
+
"""
|
26 |
+
# use bfloat16 for the entire notebook
|
27 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
28 |
+
|
29 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
30 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
31 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
32 |
+
torch.backends.cudnn.allow_tf32 = True
|
33 |
+
|
34 |
+
# init sam image predictor and video predictor model
|
35 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
36 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
37 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
38 |
+
print("device", device)
|
39 |
+
|
40 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
41 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
42 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
43 |
+
|
44 |
+
|
45 |
+
# init grounding dino model from huggingface
|
46 |
+
model_id = "IDEA-Research/grounding-dino-tiny"
|
47 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
48 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
49 |
+
|
50 |
+
|
51 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
52 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
53 |
+
text = "car."
|
54 |
+
|
55 |
+
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
56 |
+
video_dir = "notebooks/videos/car"
|
57 |
+
# 'output_dir' is the directory to save the annotated frames
|
58 |
+
output_dir = "./outputs"
|
59 |
+
# 'output_video_path' is the path to save the final video
|
60 |
+
output_video_path = "./outputs/output.mp4"
|
61 |
+
# create the output directory
|
62 |
+
CommonUtils.creat_dirs(output_dir)
|
63 |
+
mask_data_dir = os.path.join(output_dir, "mask_data")
|
64 |
+
json_data_dir = os.path.join(output_dir, "json_data")
|
65 |
+
result_dir = os.path.join(output_dir, "result")
|
66 |
+
CommonUtils.creat_dirs(mask_data_dir)
|
67 |
+
CommonUtils.creat_dirs(json_data_dir)
|
68 |
+
# scan all the JPEG frame names in this directory
|
69 |
+
frame_names = [
|
70 |
+
p for p in os.listdir(video_dir)
|
71 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
72 |
+
]
|
73 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
74 |
+
|
75 |
+
# init video predictor state
|
76 |
+
inference_state = video_predictor.init_state(video_path=video_dir)
|
77 |
+
step = 10 # the step to sample frames for Grounding DINO predictor
|
78 |
+
|
79 |
+
sam2_masks = MaskDictionaryModel()
|
80 |
+
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
81 |
+
objects_count = 0
|
82 |
+
|
83 |
+
"""
|
84 |
+
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
|
85 |
+
"""
|
86 |
+
print("Total frames:", len(frame_names))
|
87 |
+
for start_frame_idx in range(0, len(frame_names), step):
|
88 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
89 |
+
print("start_frame_idx", start_frame_idx)
|
90 |
+
# continue
|
91 |
+
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
|
92 |
+
image = Image.open(img_path)
|
93 |
+
image_base_name = frame_names[start_frame_idx].split(".")[0]
|
94 |
+
mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
|
95 |
+
|
96 |
+
# run Grounding DINO 1.5 on the image
|
97 |
+
|
98 |
+
API_TOKEN_FOR_GD1_5 = "Your API token"
|
99 |
+
|
100 |
+
config = Config(API_TOKEN_FOR_GD1_5)
|
101 |
+
# Step 2: initialize the client
|
102 |
+
client = Client(config)
|
103 |
+
|
104 |
+
image_url = client.upload_file(img_path)
|
105 |
+
task = DetectionTask(
|
106 |
+
image_url=image_url,
|
107 |
+
prompts=[TextPrompt(text=text)],
|
108 |
+
targets=[DetectionTarget.BBox], # detect bbox
|
109 |
+
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
110 |
+
)
|
111 |
+
client.run_task(task)
|
112 |
+
result = task.result
|
113 |
+
|
114 |
+
objects = result.objects # the list of detected objects
|
115 |
+
input_boxes = []
|
116 |
+
confidences = []
|
117 |
+
class_names = []
|
118 |
+
|
119 |
+
for idx, obj in enumerate(objects):
|
120 |
+
input_boxes.append(obj.bbox)
|
121 |
+
confidences.append(obj.score)
|
122 |
+
class_names.append(obj.category)
|
123 |
+
|
124 |
+
input_boxes = np.array(input_boxes)
|
125 |
+
OBJECTS = class_names
|
126 |
+
if input_boxes.shape[0] != 0:
|
127 |
+
# prompt SAM image predictor to get the mask for the object
|
128 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
129 |
+
|
130 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
131 |
+
masks, scores, logits = image_predictor.predict(
|
132 |
+
point_coords=None,
|
133 |
+
point_labels=None,
|
134 |
+
box=input_boxes,
|
135 |
+
multimask_output=False,
|
136 |
+
)
|
137 |
+
# convert the mask shape to (n, H, W)
|
138 |
+
if masks.ndim == 2:
|
139 |
+
masks = masks[None]
|
140 |
+
scores = scores[None]
|
141 |
+
logits = logits[None]
|
142 |
+
elif masks.ndim == 4:
|
143 |
+
masks = masks.squeeze(1)
|
144 |
+
|
145 |
+
"""
|
146 |
+
Step 3: Register each object's positive points to video predictor
|
147 |
+
"""
|
148 |
+
|
149 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
150 |
+
if mask_dict.promote_type == "mask":
|
151 |
+
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
152 |
+
else:
|
153 |
+
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
158 |
+
print("objects_count", objects_count)
|
159 |
+
|
160 |
+
else:
|
161 |
+
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
162 |
+
mask_dict = sam2_masks
|
163 |
+
|
164 |
+
|
165 |
+
"""
|
166 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
167 |
+
"""
|
168 |
+
if len(mask_dict.labels) == 0:
|
169 |
+
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
170 |
+
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
171 |
+
continue
|
172 |
+
else:
|
173 |
+
video_predictor.reset_state(inference_state)
|
174 |
+
|
175 |
+
for object_id, object_info in mask_dict.labels.items():
|
176 |
+
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
177 |
+
inference_state,
|
178 |
+
start_frame_idx,
|
179 |
+
object_id,
|
180 |
+
object_info.mask,
|
181 |
+
)
|
182 |
+
|
183 |
+
video_segments = {} # output the following {step} frames tracking masks
|
184 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
185 |
+
frame_masks = MaskDictionaryModel()
|
186 |
+
|
187 |
+
for i, out_obj_id in enumerate(out_obj_ids):
|
188 |
+
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
189 |
+
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
|
190 |
+
object_info.update_box()
|
191 |
+
frame_masks.labels[out_obj_id] = object_info
|
192 |
+
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
193 |
+
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
194 |
+
frame_masks.mask_height = out_mask.shape[-2]
|
195 |
+
frame_masks.mask_width = out_mask.shape[-1]
|
196 |
+
|
197 |
+
video_segments[out_frame_idx] = frame_masks
|
198 |
+
sam2_masks = copy.deepcopy(frame_masks)
|
199 |
+
|
200 |
+
print("video_segments:", len(video_segments))
|
201 |
+
"""
|
202 |
+
Step 5: save the tracking masks and json files
|
203 |
+
"""
|
204 |
+
for frame_idx, frame_masks_info in video_segments.items():
|
205 |
+
mask = frame_masks_info.labels
|
206 |
+
mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
|
207 |
+
for obj_id, obj_info in mask.items():
|
208 |
+
mask_img[obj_info.mask == True] = obj_id
|
209 |
+
|
210 |
+
mask_img = mask_img.numpy().astype(np.uint16)
|
211 |
+
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
|
212 |
+
|
213 |
+
json_data = frame_masks_info.to_dict()
|
214 |
+
json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
|
215 |
+
with open(json_data_path, "w") as f:
|
216 |
+
json.dump(json_data, f)
|
217 |
+
|
218 |
+
|
219 |
+
"""
|
220 |
+
Step 6: Draw the results and save the video
|
221 |
+
"""
|
222 |
+
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
223 |
+
|
224 |
+
create_video_from_images(result_dir, output_video_path, frame_rate=30)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id_plus.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
+
from PIL import Image
|
7 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
8 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
9 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
10 |
+
from utils.track_utils import sample_points_from_masks
|
11 |
+
from utils.video_utils import create_video_from_images
|
12 |
+
from utils.common_utils import CommonUtils
|
13 |
+
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
|
14 |
+
import json
|
15 |
+
import copy
|
16 |
+
|
17 |
+
# This demo shows the continuous object tracking plus reverse tracking with Grounding DINO and SAM 2
|
18 |
+
"""
|
19 |
+
Step 1: Environment settings and model initialization
|
20 |
+
"""
|
21 |
+
# use bfloat16 for the entire notebook
|
22 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
23 |
+
|
24 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
25 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
26 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
27 |
+
torch.backends.cudnn.allow_tf32 = True
|
28 |
+
|
29 |
+
# init sam image predictor and video predictor model
|
30 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
31 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
+
print("device", device)
|
34 |
+
|
35 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
36 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
37 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
38 |
+
|
39 |
+
|
40 |
+
# init grounding dino model from huggingface
|
41 |
+
model_id = "IDEA-Research/grounding-dino-tiny"
|
42 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
43 |
+
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
44 |
+
|
45 |
+
|
46 |
+
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
47 |
+
# VERY important: text queries need to be lowercased + end with a dot
|
48 |
+
text = "car."
|
49 |
+
|
50 |
+
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
51 |
+
video_dir = "notebooks/videos/car"
|
52 |
+
# 'output_dir' is the directory to save the annotated frames
|
53 |
+
output_dir = "outputs"
|
54 |
+
# 'output_video_path' is the path to save the final video
|
55 |
+
output_video_path = "./outputs/output.mp4"
|
56 |
+
# create the output directory
|
57 |
+
mask_data_dir = os.path.join(output_dir, "mask_data")
|
58 |
+
json_data_dir = os.path.join(output_dir, "json_data")
|
59 |
+
result_dir = os.path.join(output_dir, "result")
|
60 |
+
CommonUtils.creat_dirs(mask_data_dir)
|
61 |
+
CommonUtils.creat_dirs(json_data_dir)
|
62 |
+
# scan all the JPEG frame names in this directory
|
63 |
+
frame_names = [
|
64 |
+
p for p in os.listdir(video_dir)
|
65 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
66 |
+
]
|
67 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
68 |
+
|
69 |
+
# init video predictor state
|
70 |
+
inference_state = video_predictor.init_state(video_path=video_dir)
|
71 |
+
step = 20 # the step to sample frames for Grounding DINO predictor
|
72 |
+
|
73 |
+
sam2_masks = MaskDictionaryModel()
|
74 |
+
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
75 |
+
objects_count = 0
|
76 |
+
frame_object_count = {}
|
77 |
+
"""
|
78 |
+
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
|
79 |
+
"""
|
80 |
+
print("Total frames:", len(frame_names))
|
81 |
+
for start_frame_idx in range(0, len(frame_names), step):
|
82 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
83 |
+
print("start_frame_idx", start_frame_idx)
|
84 |
+
# continue
|
85 |
+
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
|
86 |
+
image = Image.open(img_path).convert("RGB")
|
87 |
+
image_base_name = frame_names[start_frame_idx].split(".")[0]
|
88 |
+
mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
|
89 |
+
|
90 |
+
# run Grounding DINO on the image
|
91 |
+
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
92 |
+
with torch.no_grad():
|
93 |
+
outputs = grounding_model(**inputs)
|
94 |
+
|
95 |
+
results = processor.post_process_grounded_object_detection(
|
96 |
+
outputs,
|
97 |
+
inputs.input_ids,
|
98 |
+
box_threshold=0.25,
|
99 |
+
text_threshold=0.25,
|
100 |
+
target_sizes=[image.size[::-1]]
|
101 |
+
)
|
102 |
+
|
103 |
+
# prompt SAM image predictor to get the mask for the object
|
104 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
105 |
+
|
106 |
+
# process the detection results
|
107 |
+
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
108 |
+
# print("results[0]",results[0])
|
109 |
+
OBJECTS = results[0]["labels"]
|
110 |
+
if input_boxes.shape[0] != 0:
|
111 |
+
|
112 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
113 |
+
masks, scores, logits = image_predictor.predict(
|
114 |
+
point_coords=None,
|
115 |
+
point_labels=None,
|
116 |
+
box=input_boxes,
|
117 |
+
multimask_output=False,
|
118 |
+
)
|
119 |
+
# convert the mask shape to (n, H, W)
|
120 |
+
if masks.ndim == 2:
|
121 |
+
masks = masks[None]
|
122 |
+
scores = scores[None]
|
123 |
+
logits = logits[None]
|
124 |
+
elif masks.ndim == 4:
|
125 |
+
masks = masks.squeeze(1)
|
126 |
+
"""
|
127 |
+
Step 3: Register each object's positive points to video predictor
|
128 |
+
"""
|
129 |
+
|
130 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
131 |
+
if mask_dict.promote_type == "mask":
|
132 |
+
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
133 |
+
else:
|
134 |
+
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
135 |
+
else:
|
136 |
+
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
137 |
+
mask_dict = sam2_masks
|
138 |
+
|
139 |
+
"""
|
140 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
141 |
+
"""
|
142 |
+
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
143 |
+
frame_object_count[start_frame_idx] = objects_count
|
144 |
+
print("objects_count", objects_count)
|
145 |
+
|
146 |
+
if len(mask_dict.labels) == 0:
|
147 |
+
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
148 |
+
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
149 |
+
continue
|
150 |
+
else:
|
151 |
+
video_predictor.reset_state(inference_state)
|
152 |
+
|
153 |
+
for object_id, object_info in mask_dict.labels.items():
|
154 |
+
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
155 |
+
inference_state,
|
156 |
+
start_frame_idx,
|
157 |
+
object_id,
|
158 |
+
object_info.mask,
|
159 |
+
)
|
160 |
+
|
161 |
+
video_segments = {} # output the following {step} frames tracking masks
|
162 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
163 |
+
frame_masks = MaskDictionaryModel()
|
164 |
+
|
165 |
+
for i, out_obj_id in enumerate(out_obj_ids):
|
166 |
+
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
167 |
+
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
|
168 |
+
object_info.update_box()
|
169 |
+
frame_masks.labels[out_obj_id] = object_info
|
170 |
+
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
171 |
+
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
172 |
+
frame_masks.mask_height = out_mask.shape[-2]
|
173 |
+
frame_masks.mask_width = out_mask.shape[-1]
|
174 |
+
|
175 |
+
video_segments[out_frame_idx] = frame_masks
|
176 |
+
sam2_masks = copy.deepcopy(frame_masks)
|
177 |
+
|
178 |
+
print("video_segments:", len(video_segments))
|
179 |
+
"""
|
180 |
+
Step 5: save the tracking masks and json files
|
181 |
+
"""
|
182 |
+
for frame_idx, frame_masks_info in video_segments.items():
|
183 |
+
mask = frame_masks_info.labels
|
184 |
+
mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
|
185 |
+
for obj_id, obj_info in mask.items():
|
186 |
+
mask_img[obj_info.mask == True] = obj_id
|
187 |
+
|
188 |
+
mask_img = mask_img.numpy().astype(np.uint16)
|
189 |
+
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
|
190 |
+
|
191 |
+
json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
|
192 |
+
frame_masks_info.to_json(json_data_path)
|
193 |
+
|
194 |
+
|
195 |
+
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
196 |
+
|
197 |
+
print("try reverse tracking")
|
198 |
+
start_object_id = 0
|
199 |
+
object_info_dict = {}
|
200 |
+
for frame_idx, current_object_count in frame_object_count.items():
|
201 |
+
print("reverse tracking frame", frame_idx, frame_names[frame_idx])
|
202 |
+
if frame_idx != 0:
|
203 |
+
video_predictor.reset_state(inference_state)
|
204 |
+
image_base_name = frame_names[frame_idx].split(".")[0]
|
205 |
+
json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json")
|
206 |
+
json_data = MaskDictionaryModel().from_json(json_data_path)
|
207 |
+
mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy")
|
208 |
+
mask_array = np.load(mask_data_path)
|
209 |
+
for object_id in range(start_object_id+1, current_object_count+1):
|
210 |
+
print("reverse tracking object", object_id)
|
211 |
+
object_info_dict[object_id] = json_data.labels[object_id]
|
212 |
+
video_predictor.add_new_mask(inference_state, frame_idx, object_id, mask_array == object_id)
|
213 |
+
start_object_id = current_object_count
|
214 |
+
|
215 |
+
|
216 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step*2, start_frame_idx=frame_idx, reverse=True):
|
217 |
+
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
218 |
+
json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json")
|
219 |
+
json_data = MaskDictionaryModel().from_json(json_data_path)
|
220 |
+
mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy")
|
221 |
+
mask_array = np.load(mask_data_path)
|
222 |
+
# merge the reverse tracking masks with the original masks
|
223 |
+
for i, out_obj_id in enumerate(out_obj_ids):
|
224 |
+
out_mask = (out_mask_logits[i] > 0.0).cpu()
|
225 |
+
if out_mask.sum() == 0:
|
226 |
+
print("no mask for object", out_obj_id, "at frame", out_frame_idx)
|
227 |
+
continue
|
228 |
+
object_info = object_info_dict[out_obj_id]
|
229 |
+
object_info.mask = out_mask[0]
|
230 |
+
object_info.update_box()
|
231 |
+
json_data.labels[out_obj_id] = object_info
|
232 |
+
mask_array = np.where(mask_array != out_obj_id, mask_array, 0)
|
233 |
+
mask_array[object_info.mask] = out_obj_id
|
234 |
+
|
235 |
+
np.save(mask_data_path, mask_array)
|
236 |
+
json_data.to_json(json_data_path)
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
|
241 |
+
|
242 |
+
"""
|
243 |
+
Step 6: Draw the results and save the video
|
244 |
+
"""
|
245 |
+
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir+"_reverse")
|
246 |
+
|
247 |
+
create_video_from_images(result_dir, output_video_path, frame_rate=15)
|
clone-IDEA-Research/Grounded-SAM-2/grounded_sam2_tracking_demo_with_gd1.5.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dds cloudapi for Grounding DINO 1.5
|
2 |
+
from dds_cloudapi_sdk import Config
|
3 |
+
from dds_cloudapi_sdk import Client
|
4 |
+
from dds_cloudapi_sdk import DetectionTask
|
5 |
+
from dds_cloudapi_sdk import TextPrompt
|
6 |
+
from dds_cloudapi_sdk import DetectionModel
|
7 |
+
from dds_cloudapi_sdk import DetectionTarget
|
8 |
+
|
9 |
+
import os
|
10 |
+
import cv2
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
import supervision as sv
|
14 |
+
from PIL import Image
|
15 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
16 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
17 |
+
from utils.track_utils import sample_points_from_masks
|
18 |
+
from utils.video_utils import create_video_from_images
|
19 |
+
|
20 |
+
|
21 |
+
"""
|
22 |
+
Step 1: Environment settings and model initialization for SAM 2
|
23 |
+
"""
|
24 |
+
# use bfloat16 for the entire notebook
|
25 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
26 |
+
|
27 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
28 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
29 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
30 |
+
torch.backends.cudnn.allow_tf32 = True
|
31 |
+
|
32 |
+
# init sam image predictor and video predictor model
|
33 |
+
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
34 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
35 |
+
|
36 |
+
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
37 |
+
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
|
38 |
+
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
39 |
+
|
40 |
+
|
41 |
+
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
42 |
+
video_dir = "notebooks/videos/bedroom"
|
43 |
+
|
44 |
+
# scan all the JPEG frame names in this directory
|
45 |
+
frame_names = [
|
46 |
+
p for p in os.listdir(video_dir)
|
47 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
48 |
+
]
|
49 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
50 |
+
|
51 |
+
# init video predictor state
|
52 |
+
inference_state = video_predictor.init_state(video_path=video_dir)
|
53 |
+
|
54 |
+
ann_frame_idx = 0 # the frame index we interact with
|
55 |
+
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
56 |
+
|
57 |
+
|
58 |
+
"""
|
59 |
+
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
60 |
+
"""
|
61 |
+
|
62 |
+
# prompt grounding dino to get the box coordinates on specific frame
|
63 |
+
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
|
64 |
+
image = Image.open(img_path)
|
65 |
+
|
66 |
+
# Step 1: initialize the config
|
67 |
+
token = "Your API token"
|
68 |
+
config = Config(token)
|
69 |
+
|
70 |
+
# Step 2: initialize the client
|
71 |
+
client = Client(config)
|
72 |
+
|
73 |
+
# Step 3: run the task by DetectionTask class
|
74 |
+
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
75 |
+
# if you are processing local image file, upload them to DDS server to get the image url
|
76 |
+
image_url = client.upload_file(img_path)
|
77 |
+
|
78 |
+
task = DetectionTask(
|
79 |
+
image_url=image_url,
|
80 |
+
prompts=[TextPrompt(text="children. pillow")],
|
81 |
+
targets=[DetectionTarget.BBox], # detect bbox
|
82 |
+
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
|
83 |
+
bbox_threshold=0.2,
|
84 |
+
)
|
85 |
+
|
86 |
+
client.run_task(task)
|
87 |
+
result = task.result
|
88 |
+
|
89 |
+
objects = result.objects # the list of detected objects
|
90 |
+
|
91 |
+
|
92 |
+
input_boxes = []
|
93 |
+
confidences = []
|
94 |
+
class_names = []
|
95 |
+
|
96 |
+
for idx, obj in enumerate(objects):
|
97 |
+
input_boxes.append(obj.bbox)
|
98 |
+
confidences.append(obj.score)
|
99 |
+
class_names.append(obj.category)
|
100 |
+
|
101 |
+
input_boxes = np.array(input_boxes)
|
102 |
+
|
103 |
+
print(input_boxes)
|
104 |
+
|
105 |
+
# prompt SAM image predictor to get the mask for the object
|
106 |
+
image_predictor.set_image(np.array(image.convert("RGB")))
|
107 |
+
|
108 |
+
# process the detection results
|
109 |
+
OBJECTS = class_names
|
110 |
+
|
111 |
+
print(OBJECTS)
|
112 |
+
|
113 |
+
# prompt SAM 2 image predictor to get the mask for the object
|
114 |
+
masks, scores, logits = image_predictor.predict(
|
115 |
+
point_coords=None,
|
116 |
+
point_labels=None,
|
117 |
+
box=input_boxes,
|
118 |
+
multimask_output=False,
|
119 |
+
)
|
120 |
+
|
121 |
+
# convert the mask shape to (n, H, W)
|
122 |
+
if masks.ndim == 3:
|
123 |
+
masks = masks[None]
|
124 |
+
scores = scores[None]
|
125 |
+
logits = logits[None]
|
126 |
+
elif masks.ndim == 4:
|
127 |
+
masks = masks.squeeze(1)
|
128 |
+
|
129 |
+
"""
|
130 |
+
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
131 |
+
"""
|
132 |
+
|
133 |
+
PROMPT_TYPE_FOR_VIDEO = "box" # or "point"
|
134 |
+
|
135 |
+
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
136 |
+
|
137 |
+
# If you are using point prompts, we uniformly sample positive points based on the mask
|
138 |
+
if PROMPT_TYPE_FOR_VIDEO == "point":
|
139 |
+
# sample the positive points from mask for each objects
|
140 |
+
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
141 |
+
|
142 |
+
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
143 |
+
labels = np.ones((points.shape[0]), dtype=np.int32)
|
144 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
145 |
+
inference_state=inference_state,
|
146 |
+
frame_idx=ann_frame_idx,
|
147 |
+
obj_id=object_id,
|
148 |
+
points=points,
|
149 |
+
labels=labels,
|
150 |
+
)
|
151 |
+
# Using box prompt
|
152 |
+
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
153 |
+
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
154 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
155 |
+
inference_state=inference_state,
|
156 |
+
frame_idx=ann_frame_idx,
|
157 |
+
obj_id=object_id,
|
158 |
+
box=box,
|
159 |
+
)
|
160 |
+
# Using mask prompt is a more straightforward way
|
161 |
+
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
162 |
+
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
163 |
+
labels = np.ones((1), dtype=np.int32)
|
164 |
+
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
165 |
+
inference_state=inference_state,
|
166 |
+
frame_idx=ann_frame_idx,
|
167 |
+
obj_id=object_id,
|
168 |
+
mask=mask
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
"""
|
176 |
+
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
177 |
+
"""
|
178 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
179 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
180 |
+
video_segments[out_frame_idx] = {
|
181 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
182 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
183 |
+
}
|
184 |
+
|
185 |
+
"""
|
186 |
+
Step 5: Visualize the segment results across the video and save them
|
187 |
+
"""
|
188 |
+
|
189 |
+
save_dir = "./tracking_results"
|
190 |
+
|
191 |
+
if not os.path.exists(save_dir):
|
192 |
+
os.makedirs(save_dir)
|
193 |
+
|
194 |
+
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
195 |
+
for frame_idx, segments in video_segments.items():
|
196 |
+
img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
|
197 |
+
|
198 |
+
object_ids = list(segments.keys())
|
199 |
+
masks = list(segments.values())
|
200 |
+
masks = np.concatenate(masks, axis=0)
|
201 |
+
|
202 |
+
detections = sv.Detections(
|
203 |
+
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
204 |
+
mask=masks, # (n, h, w)
|
205 |
+
class_id=np.array(object_ids, dtype=np.int32),
|
206 |
+
)
|
207 |
+
box_annotator = sv.BoxAnnotator()
|
208 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
209 |
+
label_annotator = sv.LabelAnnotator()
|
210 |
+
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
211 |
+
mask_annotator = sv.MaskAnnotator()
|
212 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
213 |
+
cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
214 |
+
|
215 |
+
|
216 |
+
"""
|
217 |
+
Step 6: Convert the annotated frames to video
|
218 |
+
"""
|
219 |
+
|
220 |
+
output_video_path = "./children_tracking_demo_video.mp4"
|
221 |
+
create_video_from_images(save_dir, output_video_path)
|
clone-IDEA-Research/Grounded-SAM-2/pyproject.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = [
|
3 |
+
"setuptools>=61.0",
|
4 |
+
"torch>=2.3.1",
|
5 |
+
]
|
6 |
+
build-backend = "setuptools.build_meta"
|
clone-IDEA-Research/Grounded-SAM-2/sam2/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from hydra import initialize_config_module
|
8 |
+
from hydra.core.global_hydra import GlobalHydra
|
9 |
+
|
10 |
+
if not GlobalHydra.instance().is_initialized():
|
11 |
+
initialize_config_module("sam2", version_base="1.2")
|
clone-IDEA-Research/Grounded-SAM-2/sam2/automatic_mask_generator.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
|
8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
13 |
+
|
14 |
+
from sam2.modeling.sam2_base import SAM2Base
|
15 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
16 |
+
from sam2.utils.amg import (
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
MaskData,
|
28 |
+
remove_small_regions,
|
29 |
+
rle_to_mask,
|
30 |
+
uncrop_boxes_xyxy,
|
31 |
+
uncrop_masks,
|
32 |
+
uncrop_points,
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class SAM2AutomaticMaskGenerator:
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
model: SAM2Base,
|
40 |
+
points_per_side: Optional[int] = 32,
|
41 |
+
points_per_batch: int = 64,
|
42 |
+
pred_iou_thresh: float = 0.8,
|
43 |
+
stability_score_thresh: float = 0.95,
|
44 |
+
stability_score_offset: float = 1.0,
|
45 |
+
mask_threshold: float = 0.0,
|
46 |
+
box_nms_thresh: float = 0.7,
|
47 |
+
crop_n_layers: int = 0,
|
48 |
+
crop_nms_thresh: float = 0.7,
|
49 |
+
crop_overlap_ratio: float = 512 / 1500,
|
50 |
+
crop_n_points_downscale_factor: int = 1,
|
51 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
52 |
+
min_mask_region_area: int = 0,
|
53 |
+
output_mode: str = "binary_mask",
|
54 |
+
use_m2m: bool = False,
|
55 |
+
multimask_output: bool = True,
|
56 |
+
**kwargs,
|
57 |
+
) -> None:
|
58 |
+
"""
|
59 |
+
Using a SAM 2 model, generates masks for the entire image.
|
60 |
+
Generates a grid of point prompts over the image, then filters
|
61 |
+
low quality and duplicate masks. The default settings are chosen
|
62 |
+
for SAM 2 with a HieraL backbone.
|
63 |
+
|
64 |
+
Arguments:
|
65 |
+
model (Sam): The SAM 2 model to use for mask prediction.
|
66 |
+
points_per_side (int or None): The number of points to be sampled
|
67 |
+
along one side of the image. The total number of points is
|
68 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
69 |
+
point sampling.
|
70 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
71 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
72 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
73 |
+
model's predicted mask quality.
|
74 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
75 |
+
the stability of the mask under changes to the cutoff used to binarize
|
76 |
+
the model's mask predictions.
|
77 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
78 |
+
calculated the stability score.
|
79 |
+
mask_threshold (float): Threshold for binarizing the mask logits
|
80 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
81 |
+
suppression to filter duplicate masks.
|
82 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
83 |
+
crops of the image. Sets the number of layers to run, where each
|
84 |
+
layer has 2**i_layer number of image crops.
|
85 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
86 |
+
suppression to filter duplicate masks between different crops.
|
87 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
88 |
+
In the first crop layer, crops will overlap by this fraction of
|
89 |
+
the image length. Later layers with more crops scale down this overlap.
|
90 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
91 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
92 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
93 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
94 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
95 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
96 |
+
to remove disconnected regions and holes in masks with area smaller
|
97 |
+
than min_mask_region_area. Requires opencv.
|
98 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
99 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
100 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
101 |
+
memory.
|
102 |
+
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
|
103 |
+
multimask_output (bool): Whether to output multimask at each point of the grid.
|
104 |
+
"""
|
105 |
+
|
106 |
+
assert (points_per_side is None) != (
|
107 |
+
point_grids is None
|
108 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
109 |
+
if points_per_side is not None:
|
110 |
+
self.point_grids = build_all_layer_point_grids(
|
111 |
+
points_per_side,
|
112 |
+
crop_n_layers,
|
113 |
+
crop_n_points_downscale_factor,
|
114 |
+
)
|
115 |
+
elif point_grids is not None:
|
116 |
+
self.point_grids = point_grids
|
117 |
+
else:
|
118 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
119 |
+
|
120 |
+
assert output_mode in [
|
121 |
+
"binary_mask",
|
122 |
+
"uncompressed_rle",
|
123 |
+
"coco_rle",
|
124 |
+
], f"Unknown output_mode {output_mode}."
|
125 |
+
if output_mode == "coco_rle":
|
126 |
+
try:
|
127 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
128 |
+
except ImportError as e:
|
129 |
+
print("Please install pycocotools")
|
130 |
+
raise e
|
131 |
+
|
132 |
+
self.predictor = SAM2ImagePredictor(
|
133 |
+
model,
|
134 |
+
max_hole_area=min_mask_region_area,
|
135 |
+
max_sprinkle_area=min_mask_region_area,
|
136 |
+
)
|
137 |
+
self.points_per_batch = points_per_batch
|
138 |
+
self.pred_iou_thresh = pred_iou_thresh
|
139 |
+
self.stability_score_thresh = stability_score_thresh
|
140 |
+
self.stability_score_offset = stability_score_offset
|
141 |
+
self.mask_threshold = mask_threshold
|
142 |
+
self.box_nms_thresh = box_nms_thresh
|
143 |
+
self.crop_n_layers = crop_n_layers
|
144 |
+
self.crop_nms_thresh = crop_nms_thresh
|
145 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
146 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
147 |
+
self.min_mask_region_area = min_mask_region_area
|
148 |
+
self.output_mode = output_mode
|
149 |
+
self.use_m2m = use_m2m
|
150 |
+
self.multimask_output = multimask_output
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
154 |
+
"""
|
155 |
+
Load a pretrained model from the Hugging Face hub.
|
156 |
+
|
157 |
+
Arguments:
|
158 |
+
model_id (str): The Hugging Face repository ID.
|
159 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
(SAM2AutomaticMaskGenerator): The loaded model.
|
163 |
+
"""
|
164 |
+
from sam2.build_sam import build_sam2_hf
|
165 |
+
|
166 |
+
sam_model = build_sam2_hf(model_id, **kwargs)
|
167 |
+
return cls(sam_model, **kwargs)
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
171 |
+
"""
|
172 |
+
Generates masks for the given image.
|
173 |
+
|
174 |
+
Arguments:
|
175 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
179 |
+
a dict containing the following keys:
|
180 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
181 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
182 |
+
is a dictionary containing the RLE.
|
183 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
184 |
+
area (int): The area in pixels of the mask.
|
185 |
+
predicted_iou (float): The model's own prediction of the mask's
|
186 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
187 |
+
point_coords (list(list(float))): The point coordinates input
|
188 |
+
to the model to generate this mask.
|
189 |
+
stability_score (float): A measure of the mask's quality. This
|
190 |
+
is filtered on using the stability_score_thresh parameter.
|
191 |
+
crop_box (list(float)): The crop of the image used to generate
|
192 |
+
the mask, given in XYWH format.
|
193 |
+
"""
|
194 |
+
|
195 |
+
# Generate masks
|
196 |
+
mask_data = self._generate_masks(image)
|
197 |
+
|
198 |
+
# Encode masks
|
199 |
+
if self.output_mode == "coco_rle":
|
200 |
+
mask_data["segmentations"] = [
|
201 |
+
coco_encode_rle(rle) for rle in mask_data["rles"]
|
202 |
+
]
|
203 |
+
elif self.output_mode == "binary_mask":
|
204 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
205 |
+
else:
|
206 |
+
mask_data["segmentations"] = mask_data["rles"]
|
207 |
+
|
208 |
+
# Write mask records
|
209 |
+
curr_anns = []
|
210 |
+
for idx in range(len(mask_data["segmentations"])):
|
211 |
+
ann = {
|
212 |
+
"segmentation": mask_data["segmentations"][idx],
|
213 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
214 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
215 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
216 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
217 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
218 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
219 |
+
}
|
220 |
+
curr_anns.append(ann)
|
221 |
+
|
222 |
+
return curr_anns
|
223 |
+
|
224 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
225 |
+
orig_size = image.shape[:2]
|
226 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
227 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
228 |
+
)
|
229 |
+
|
230 |
+
# Iterate over image crops
|
231 |
+
data = MaskData()
|
232 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
233 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
234 |
+
data.cat(crop_data)
|
235 |
+
|
236 |
+
# Remove duplicate masks between crops
|
237 |
+
if len(crop_boxes) > 1:
|
238 |
+
# Prefer masks from smaller crops
|
239 |
+
scores = 1 / box_area(data["crop_boxes"])
|
240 |
+
scores = scores.to(data["boxes"].device)
|
241 |
+
keep_by_nms = batched_nms(
|
242 |
+
data["boxes"].float(),
|
243 |
+
scores,
|
244 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
245 |
+
iou_threshold=self.crop_nms_thresh,
|
246 |
+
)
|
247 |
+
data.filter(keep_by_nms)
|
248 |
+
data.to_numpy()
|
249 |
+
return data
|
250 |
+
|
251 |
+
def _process_crop(
|
252 |
+
self,
|
253 |
+
image: np.ndarray,
|
254 |
+
crop_box: List[int],
|
255 |
+
crop_layer_idx: int,
|
256 |
+
orig_size: Tuple[int, ...],
|
257 |
+
) -> MaskData:
|
258 |
+
# Crop the image and calculate embeddings
|
259 |
+
x0, y0, x1, y1 = crop_box
|
260 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
261 |
+
cropped_im_size = cropped_im.shape[:2]
|
262 |
+
self.predictor.set_image(cropped_im)
|
263 |
+
|
264 |
+
# Get points for this crop
|
265 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
266 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
267 |
+
|
268 |
+
# Generate masks for this crop in batches
|
269 |
+
data = MaskData()
|
270 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
271 |
+
batch_data = self._process_batch(
|
272 |
+
points, cropped_im_size, crop_box, orig_size, normalize=True
|
273 |
+
)
|
274 |
+
data.cat(batch_data)
|
275 |
+
del batch_data
|
276 |
+
self.predictor.reset_predictor()
|
277 |
+
|
278 |
+
# Remove duplicates within this crop.
|
279 |
+
keep_by_nms = batched_nms(
|
280 |
+
data["boxes"].float(),
|
281 |
+
data["iou_preds"],
|
282 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
283 |
+
iou_threshold=self.box_nms_thresh,
|
284 |
+
)
|
285 |
+
data.filter(keep_by_nms)
|
286 |
+
|
287 |
+
# Return to the original image frame
|
288 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
289 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
290 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
291 |
+
|
292 |
+
return data
|
293 |
+
|
294 |
+
def _process_batch(
|
295 |
+
self,
|
296 |
+
points: np.ndarray,
|
297 |
+
im_size: Tuple[int, ...],
|
298 |
+
crop_box: List[int],
|
299 |
+
orig_size: Tuple[int, ...],
|
300 |
+
normalize=False,
|
301 |
+
) -> MaskData:
|
302 |
+
orig_h, orig_w = orig_size
|
303 |
+
|
304 |
+
# Run model on this batch
|
305 |
+
points = torch.as_tensor(
|
306 |
+
points, dtype=torch.float32, device=self.predictor.device
|
307 |
+
)
|
308 |
+
in_points = self.predictor._transforms.transform_coords(
|
309 |
+
points, normalize=normalize, orig_hw=im_size
|
310 |
+
)
|
311 |
+
in_labels = torch.ones(
|
312 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
313 |
+
)
|
314 |
+
masks, iou_preds, low_res_masks = self.predictor._predict(
|
315 |
+
in_points[:, None, :],
|
316 |
+
in_labels[:, None],
|
317 |
+
multimask_output=self.multimask_output,
|
318 |
+
return_logits=True,
|
319 |
+
)
|
320 |
+
|
321 |
+
# Serialize predictions and store in MaskData
|
322 |
+
data = MaskData(
|
323 |
+
masks=masks.flatten(0, 1),
|
324 |
+
iou_preds=iou_preds.flatten(0, 1),
|
325 |
+
points=points.repeat_interleave(masks.shape[1], dim=0),
|
326 |
+
low_res_masks=low_res_masks.flatten(0, 1),
|
327 |
+
)
|
328 |
+
del masks
|
329 |
+
|
330 |
+
if not self.use_m2m:
|
331 |
+
# Filter by predicted IoU
|
332 |
+
if self.pred_iou_thresh > 0.0:
|
333 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
334 |
+
data.filter(keep_mask)
|
335 |
+
|
336 |
+
# Calculate and filter by stability score
|
337 |
+
data["stability_score"] = calculate_stability_score(
|
338 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
339 |
+
)
|
340 |
+
if self.stability_score_thresh > 0.0:
|
341 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
342 |
+
data.filter(keep_mask)
|
343 |
+
else:
|
344 |
+
# One step refinement using previous mask predictions
|
345 |
+
in_points = self.predictor._transforms.transform_coords(
|
346 |
+
data["points"], normalize=normalize, orig_hw=im_size
|
347 |
+
)
|
348 |
+
labels = torch.ones(
|
349 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
350 |
+
)
|
351 |
+
masks, ious = self.refine_with_m2m(
|
352 |
+
in_points, labels, data["low_res_masks"], self.points_per_batch
|
353 |
+
)
|
354 |
+
data["masks"] = masks.squeeze(1)
|
355 |
+
data["iou_preds"] = ious.squeeze(1)
|
356 |
+
|
357 |
+
if self.pred_iou_thresh > 0.0:
|
358 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
359 |
+
data.filter(keep_mask)
|
360 |
+
|
361 |
+
data["stability_score"] = calculate_stability_score(
|
362 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
363 |
+
)
|
364 |
+
if self.stability_score_thresh > 0.0:
|
365 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
366 |
+
data.filter(keep_mask)
|
367 |
+
|
368 |
+
# Threshold masks and calculate boxes
|
369 |
+
data["masks"] = data["masks"] > self.mask_threshold
|
370 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
371 |
+
|
372 |
+
# Filter boxes that touch crop boundaries
|
373 |
+
keep_mask = ~is_box_near_crop_edge(
|
374 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
375 |
+
)
|
376 |
+
if not torch.all(keep_mask):
|
377 |
+
data.filter(keep_mask)
|
378 |
+
|
379 |
+
# Compress to RLE
|
380 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
381 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
382 |
+
del data["masks"]
|
383 |
+
|
384 |
+
return data
|
385 |
+
|
386 |
+
@staticmethod
|
387 |
+
def postprocess_small_regions(
|
388 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
389 |
+
) -> MaskData:
|
390 |
+
"""
|
391 |
+
Removes small disconnected regions and holes in masks, then reruns
|
392 |
+
box NMS to remove any new duplicates.
|
393 |
+
|
394 |
+
Edits mask_data in place.
|
395 |
+
|
396 |
+
Requires open-cv as a dependency.
|
397 |
+
"""
|
398 |
+
if len(mask_data["rles"]) == 0:
|
399 |
+
return mask_data
|
400 |
+
|
401 |
+
# Filter small disconnected regions and holes
|
402 |
+
new_masks = []
|
403 |
+
scores = []
|
404 |
+
for rle in mask_data["rles"]:
|
405 |
+
mask = rle_to_mask(rle)
|
406 |
+
|
407 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
408 |
+
unchanged = not changed
|
409 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
410 |
+
unchanged = unchanged and not changed
|
411 |
+
|
412 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
413 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
414 |
+
# so NMS will prefer ones that didn't need postprocessing
|
415 |
+
scores.append(float(unchanged))
|
416 |
+
|
417 |
+
# Recalculate boxes and remove any new duplicates
|
418 |
+
masks = torch.cat(new_masks, dim=0)
|
419 |
+
boxes = batched_mask_to_box(masks)
|
420 |
+
keep_by_nms = batched_nms(
|
421 |
+
boxes.float(),
|
422 |
+
torch.as_tensor(scores),
|
423 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
424 |
+
iou_threshold=nms_thresh,
|
425 |
+
)
|
426 |
+
|
427 |
+
# Only recalculate RLEs for masks that have changed
|
428 |
+
for i_mask in keep_by_nms:
|
429 |
+
if scores[i_mask] == 0.0:
|
430 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
431 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
432 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
433 |
+
mask_data.filter(keep_by_nms)
|
434 |
+
|
435 |
+
return mask_data
|
436 |
+
|
437 |
+
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
|
438 |
+
new_masks = []
|
439 |
+
new_iou_preds = []
|
440 |
+
|
441 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
442 |
+
points_per_batch, points, point_labels, low_res_masks
|
443 |
+
):
|
444 |
+
best_masks, best_iou_preds, _ = self.predictor._predict(
|
445 |
+
cur_points[:, None, :],
|
446 |
+
cur_point_labels[:, None],
|
447 |
+
mask_input=low_res_mask[:, None, :],
|
448 |
+
multimask_output=False,
|
449 |
+
return_logits=True,
|
450 |
+
)
|
451 |
+
new_masks.append(best_masks)
|
452 |
+
new_iou_preds.append(best_iou_preds)
|
453 |
+
masks = torch.cat(new_masks, dim=0)
|
454 |
+
return masks, torch.cat(new_iou_preds, dim=0)
|
clone-IDEA-Research/Grounded-SAM-2/sam2/build_sam.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from hydra import compose
|
12 |
+
from hydra.utils import instantiate
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
|
15 |
+
import sam2
|
16 |
+
|
17 |
+
# Check if the user is running Python from the parent directory of the sam2 repo
|
18 |
+
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
19 |
+
# it could shadow the sam2 package and cause issues.
|
20 |
+
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
|
21 |
+
# If the user has "sam2/sam2" in their path, they are likey importing the repo itself
|
22 |
+
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
|
23 |
+
# This typically happens because the user is running Python from the parent directory
|
24 |
+
# that contains the sam2 repo they cloned.
|
25 |
+
raise RuntimeError(
|
26 |
+
"You're likely running Python from the parent directory of the sam2 repository "
|
27 |
+
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
|
28 |
+
"This is not supported since the `sam2` Python package could be shadowed by the "
|
29 |
+
"repository name (the repository is also named `sam2` and contains the Python package "
|
30 |
+
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
|
31 |
+
"rather than its parent dir, or from your home directory) after installing SAM 2."
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
HF_MODEL_ID_TO_FILENAMES = {
|
36 |
+
"facebook/sam2-hiera-tiny": (
|
37 |
+
"configs/sam2/sam2_hiera_t.yaml",
|
38 |
+
"sam2_hiera_tiny.pt",
|
39 |
+
),
|
40 |
+
"facebook/sam2-hiera-small": (
|
41 |
+
"configs/sam2/sam2_hiera_s.yaml",
|
42 |
+
"sam2_hiera_small.pt",
|
43 |
+
),
|
44 |
+
"facebook/sam2-hiera-base-plus": (
|
45 |
+
"configs/sam2/sam2_hiera_b+.yaml",
|
46 |
+
"sam2_hiera_base_plus.pt",
|
47 |
+
),
|
48 |
+
"facebook/sam2-hiera-large": (
|
49 |
+
"configs/sam2/sam2_hiera_l.yaml",
|
50 |
+
"sam2_hiera_large.pt",
|
51 |
+
),
|
52 |
+
"facebook/sam2.1-hiera-tiny": (
|
53 |
+
"configs/sam2.1/sam2.1_hiera_t.yaml",
|
54 |
+
"sam2.1_hiera_tiny.pt",
|
55 |
+
),
|
56 |
+
"facebook/sam2.1-hiera-small": (
|
57 |
+
"configs/sam2.1/sam2.1_hiera_s.yaml",
|
58 |
+
"sam2.1_hiera_small.pt",
|
59 |
+
),
|
60 |
+
"facebook/sam2.1-hiera-base-plus": (
|
61 |
+
"configs/sam2.1/sam2.1_hiera_b+.yaml",
|
62 |
+
"sam2.1_hiera_base_plus.pt",
|
63 |
+
),
|
64 |
+
"facebook/sam2.1-hiera-large": (
|
65 |
+
"configs/sam2.1/sam2.1_hiera_l.yaml",
|
66 |
+
"sam2.1_hiera_large.pt",
|
67 |
+
),
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
def build_sam2(
|
72 |
+
config_file,
|
73 |
+
ckpt_path=None,
|
74 |
+
device="cuda",
|
75 |
+
mode="eval",
|
76 |
+
hydra_overrides_extra=[],
|
77 |
+
apply_postprocessing=True,
|
78 |
+
**kwargs,
|
79 |
+
):
|
80 |
+
|
81 |
+
if apply_postprocessing:
|
82 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
83 |
+
hydra_overrides_extra += [
|
84 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
85 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
86 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
87 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
88 |
+
]
|
89 |
+
# Read config and init model
|
90 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
91 |
+
OmegaConf.resolve(cfg)
|
92 |
+
model = instantiate(cfg.model, _recursive_=True)
|
93 |
+
_load_checkpoint(model, ckpt_path)
|
94 |
+
model = model.to(device)
|
95 |
+
if mode == "eval":
|
96 |
+
model.eval()
|
97 |
+
return model
|
98 |
+
|
99 |
+
|
100 |
+
def build_sam2_video_predictor(
|
101 |
+
config_file,
|
102 |
+
ckpt_path=None,
|
103 |
+
device="cuda",
|
104 |
+
mode="eval",
|
105 |
+
hydra_overrides_extra=[],
|
106 |
+
apply_postprocessing=True,
|
107 |
+
**kwargs,
|
108 |
+
):
|
109 |
+
hydra_overrides = [
|
110 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
111 |
+
]
|
112 |
+
if apply_postprocessing:
|
113 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
114 |
+
hydra_overrides_extra += [
|
115 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
116 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
117 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
118 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
119 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
120 |
+
"++model.binarize_mask_from_pts_for_mem_enc=true",
|
121 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
122 |
+
"++model.fill_hole_area=8",
|
123 |
+
]
|
124 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
125 |
+
|
126 |
+
# Read config and init model
|
127 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides)
|
128 |
+
OmegaConf.resolve(cfg)
|
129 |
+
model = instantiate(cfg.model, _recursive_=True)
|
130 |
+
_load_checkpoint(model, ckpt_path)
|
131 |
+
model = model.to(device)
|
132 |
+
if mode == "eval":
|
133 |
+
model.eval()
|
134 |
+
return model
|
135 |
+
|
136 |
+
|
137 |
+
def _hf_download(model_id):
|
138 |
+
from huggingface_hub import hf_hub_download
|
139 |
+
|
140 |
+
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
|
141 |
+
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
142 |
+
return config_name, ckpt_path
|
143 |
+
|
144 |
+
|
145 |
+
def build_sam2_hf(model_id, **kwargs):
|
146 |
+
config_name, ckpt_path = _hf_download(model_id)
|
147 |
+
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
148 |
+
|
149 |
+
|
150 |
+
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
151 |
+
config_name, ckpt_path = _hf_download(model_id)
|
152 |
+
return build_sam2_video_predictor(
|
153 |
+
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
154 |
+
)
|
155 |
+
|
156 |
+
|
157 |
+
def _load_checkpoint(model, ckpt_path):
|
158 |
+
if ckpt_path is not None:
|
159 |
+
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
160 |
+
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
161 |
+
if missing_keys:
|
162 |
+
logging.error(missing_keys)
|
163 |
+
raise RuntimeError()
|
164 |
+
if unexpected_keys:
|
165 |
+
logging.error(unexpected_keys)
|
166 |
+
raise RuntimeError()
|
167 |
+
logging.info("Loaded checkpoint sucessfully")
|
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 112
|
12 |
+
num_heads: 2
|
13 |
+
neck:
|
14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
15 |
+
position_encoding:
|
16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
17 |
+
num_pos_feats: 256
|
18 |
+
normalize: true
|
19 |
+
scale: null
|
20 |
+
temperature: 10000
|
21 |
+
d_model: 256
|
22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
24 |
+
fpn_interp_model: nearest
|
25 |
+
|
26 |
+
memory_attention:
|
27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
28 |
+
d_model: 256
|
29 |
+
pos_enc_at_input: true
|
30 |
+
layer:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
32 |
+
activation: relu
|
33 |
+
dim_feedforward: 2048
|
34 |
+
dropout: 0.1
|
35 |
+
pos_enc_at_attn: false
|
36 |
+
self_attention:
|
37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
38 |
+
rope_theta: 10000.0
|
39 |
+
feat_sizes: [32, 32]
|
40 |
+
embedding_dim: 256
|
41 |
+
num_heads: 1
|
42 |
+
downsample_rate: 1
|
43 |
+
dropout: 0.1
|
44 |
+
d_model: 256
|
45 |
+
pos_enc_at_cross_attn_keys: true
|
46 |
+
pos_enc_at_cross_attn_queries: false
|
47 |
+
cross_attention:
|
48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
49 |
+
rope_theta: 10000.0
|
50 |
+
feat_sizes: [32, 32]
|
51 |
+
rope_k_repeat: True
|
52 |
+
embedding_dim: 256
|
53 |
+
num_heads: 1
|
54 |
+
downsample_rate: 1
|
55 |
+
dropout: 0.1
|
56 |
+
kv_in_dim: 64
|
57 |
+
num_layers: 4
|
58 |
+
|
59 |
+
memory_encoder:
|
60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
61 |
+
out_dim: 64
|
62 |
+
position_encoding:
|
63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
64 |
+
num_pos_feats: 64
|
65 |
+
normalize: true
|
66 |
+
scale: null
|
67 |
+
temperature: 10000
|
68 |
+
mask_downsampler:
|
69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
70 |
+
kernel_size: 3
|
71 |
+
stride: 2
|
72 |
+
padding: 1
|
73 |
+
fuser:
|
74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
75 |
+
layer:
|
76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
77 |
+
dim: 256
|
78 |
+
kernel_size: 7
|
79 |
+
padding: 3
|
80 |
+
layer_scale_init_value: 1e-6
|
81 |
+
use_dwconv: True # depth-wise convs
|
82 |
+
num_layers: 2
|
83 |
+
|
84 |
+
num_maskmem: 7
|
85 |
+
image_size: 1024
|
86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
89 |
+
use_mask_input_as_output_without_sam: true
|
90 |
+
# Memory
|
91 |
+
directly_add_no_mem_embed: true
|
92 |
+
# use high-resolution feature map in the SAM mask decoder
|
93 |
+
use_high_res_features_in_sam: true
|
94 |
+
# output 3 masks on the first click on initial conditioning frames
|
95 |
+
multimask_output_in_sam: true
|
96 |
+
# SAM heads
|
97 |
+
iou_prediction_use_sigmoid: True
|
98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
99 |
+
use_obj_ptrs_in_encoder: true
|
100 |
+
add_tpos_enc_to_obj_ptrs: false
|
101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
102 |
+
# object occlusion prediction
|
103 |
+
pred_obj_scores: true
|
104 |
+
pred_obj_scores_mlp: true
|
105 |
+
fixed_no_obj_ptr: true
|
106 |
+
# multimask tracking settings
|
107 |
+
multimask_output_for_tracking: true
|
108 |
+
use_multimask_token_for_obj_ptr: true
|
109 |
+
multimask_min_pt_num: 0
|
110 |
+
multimask_max_pt_num: 1
|
111 |
+
use_mlp_for_obj_ptr_proj: true
|
112 |
+
# Compilation flag
|
113 |
+
compile_image_encoder: False
|
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_l.yaml
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 144
|
12 |
+
num_heads: 2
|
13 |
+
stages: [2, 6, 36, 4]
|
14 |
+
global_att_blocks: [23, 33, 43]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
window_spec: [8, 4, 16, 8]
|
17 |
+
neck:
|
18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
19 |
+
position_encoding:
|
20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
21 |
+
num_pos_feats: 256
|
22 |
+
normalize: true
|
23 |
+
scale: null
|
24 |
+
temperature: 10000
|
25 |
+
d_model: 256
|
26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
28 |
+
fpn_interp_model: nearest
|
29 |
+
|
30 |
+
memory_attention:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
32 |
+
d_model: 256
|
33 |
+
pos_enc_at_input: true
|
34 |
+
layer:
|
35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
36 |
+
activation: relu
|
37 |
+
dim_feedforward: 2048
|
38 |
+
dropout: 0.1
|
39 |
+
pos_enc_at_attn: false
|
40 |
+
self_attention:
|
41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
42 |
+
rope_theta: 10000.0
|
43 |
+
feat_sizes: [32, 32]
|
44 |
+
embedding_dim: 256
|
45 |
+
num_heads: 1
|
46 |
+
downsample_rate: 1
|
47 |
+
dropout: 0.1
|
48 |
+
d_model: 256
|
49 |
+
pos_enc_at_cross_attn_keys: true
|
50 |
+
pos_enc_at_cross_attn_queries: false
|
51 |
+
cross_attention:
|
52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
53 |
+
rope_theta: 10000.0
|
54 |
+
feat_sizes: [32, 32]
|
55 |
+
rope_k_repeat: True
|
56 |
+
embedding_dim: 256
|
57 |
+
num_heads: 1
|
58 |
+
downsample_rate: 1
|
59 |
+
dropout: 0.1
|
60 |
+
kv_in_dim: 64
|
61 |
+
num_layers: 4
|
62 |
+
|
63 |
+
memory_encoder:
|
64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
65 |
+
out_dim: 64
|
66 |
+
position_encoding:
|
67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
68 |
+
num_pos_feats: 64
|
69 |
+
normalize: true
|
70 |
+
scale: null
|
71 |
+
temperature: 10000
|
72 |
+
mask_downsampler:
|
73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
74 |
+
kernel_size: 3
|
75 |
+
stride: 2
|
76 |
+
padding: 1
|
77 |
+
fuser:
|
78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
79 |
+
layer:
|
80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
81 |
+
dim: 256
|
82 |
+
kernel_size: 7
|
83 |
+
padding: 3
|
84 |
+
layer_scale_init_value: 1e-6
|
85 |
+
use_dwconv: True # depth-wise convs
|
86 |
+
num_layers: 2
|
87 |
+
|
88 |
+
num_maskmem: 7
|
89 |
+
image_size: 1024
|
90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
# use high-resolution feature map in the SAM mask decoder
|
97 |
+
use_high_res_features_in_sam: true
|
98 |
+
# output 3 masks on the first click on initial conditioning frames
|
99 |
+
multimask_output_in_sam: true
|
100 |
+
# SAM heads
|
101 |
+
iou_prediction_use_sigmoid: True
|
102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
103 |
+
use_obj_ptrs_in_encoder: true
|
104 |
+
add_tpos_enc_to_obj_ptrs: false
|
105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
106 |
+
# object occlusion prediction
|
107 |
+
pred_obj_scores: true
|
108 |
+
pred_obj_scores_mlp: true
|
109 |
+
fixed_no_obj_ptr: true
|
110 |
+
# multimask tracking settings
|
111 |
+
multimask_output_for_tracking: true
|
112 |
+
use_multimask_token_for_obj_ptr: true
|
113 |
+
multimask_min_pt_num: 0
|
114 |
+
multimask_max_pt_num: 1
|
115 |
+
use_mlp_for_obj_ptr_proj: true
|
116 |
+
# Compilation flag
|
117 |
+
compile_image_encoder: False
|
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_s.yaml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 11, 2]
|
14 |
+
global_att_blocks: [7, 10, 13]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [32, 32]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [32, 32]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
92 |
+
use_mask_input_as_output_without_sam: true
|
93 |
+
# Memory
|
94 |
+
directly_add_no_mem_embed: true
|
95 |
+
# use high-resolution feature map in the SAM mask decoder
|
96 |
+
use_high_res_features_in_sam: true
|
97 |
+
# output 3 masks on the first click on initial conditioning frames
|
98 |
+
multimask_output_in_sam: true
|
99 |
+
# SAM heads
|
100 |
+
iou_prediction_use_sigmoid: True
|
101 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
102 |
+
use_obj_ptrs_in_encoder: true
|
103 |
+
add_tpos_enc_to_obj_ptrs: false
|
104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
105 |
+
# object occlusion prediction
|
106 |
+
pred_obj_scores: true
|
107 |
+
pred_obj_scores_mlp: true
|
108 |
+
fixed_no_obj_ptr: true
|
109 |
+
# multimask tracking settings
|
110 |
+
multimask_output_for_tracking: true
|
111 |
+
use_multimask_token_for_obj_ptr: true
|
112 |
+
multimask_min_pt_num: 0
|
113 |
+
multimask_max_pt_num: 1
|
114 |
+
use_mlp_for_obj_ptr_proj: true
|
115 |
+
# Compilation flag
|
116 |
+
compile_image_encoder: False
|
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_hiera_t.yaml
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 7, 2]
|
14 |
+
global_att_blocks: [5, 7, 9]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [32, 32]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [32, 32]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
# SAM decoder
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
# use high-resolution feature map in the SAM mask decoder
|
97 |
+
use_high_res_features_in_sam: true
|
98 |
+
# output 3 masks on the first click on initial conditioning frames
|
99 |
+
multimask_output_in_sam: true
|
100 |
+
# SAM heads
|
101 |
+
iou_prediction_use_sigmoid: True
|
102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
103 |
+
use_obj_ptrs_in_encoder: true
|
104 |
+
add_tpos_enc_to_obj_ptrs: false
|
105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
106 |
+
# object occlusion prediction
|
107 |
+
pred_obj_scores: true
|
108 |
+
pred_obj_scores_mlp: true
|
109 |
+
fixed_no_obj_ptr: true
|
110 |
+
# multimask tracking settings
|
111 |
+
multimask_output_for_tracking: true
|
112 |
+
use_multimask_token_for_obj_ptr: true
|
113 |
+
multimask_min_pt_num: 0
|
114 |
+
multimask_max_pt_num: 1
|
115 |
+
use_mlp_for_obj_ptr_proj: true
|
116 |
+
# Compilation flag
|
117 |
+
# HieraT does not currently support compilation, should always be set to False
|
118 |
+
compile_image_encoder: False
|
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_image_predictor.py
ADDED
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
|
9 |
+
from typing import List, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from PIL.Image import Image
|
14 |
+
|
15 |
+
from sam2.modeling.sam2_base import SAM2Base
|
16 |
+
|
17 |
+
from sam2.utils.transforms import SAM2Transforms
|
18 |
+
|
19 |
+
|
20 |
+
class SAM2ImagePredictor:
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
sam_model: SAM2Base,
|
24 |
+
mask_threshold=0.0,
|
25 |
+
max_hole_area=0.0,
|
26 |
+
max_sprinkle_area=0.0,
|
27 |
+
**kwargs,
|
28 |
+
) -> None:
|
29 |
+
"""
|
30 |
+
Uses SAM-2 to calculate the image embedding for an image, and then
|
31 |
+
allow repeated, efficient mask prediction given prompts.
|
32 |
+
|
33 |
+
Arguments:
|
34 |
+
sam_model (Sam-2): The model to use for mask prediction.
|
35 |
+
mask_threshold (float): The threshold to use when converting mask logits
|
36 |
+
to binary masks. Masks are thresholded at 0 by default.
|
37 |
+
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
|
38 |
+
the maximum area of max_hole_area in low_res_masks.
|
39 |
+
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
|
40 |
+
the maximum area of max_sprinkle_area in low_res_masks.
|
41 |
+
"""
|
42 |
+
super().__init__()
|
43 |
+
self.model = sam_model
|
44 |
+
self._transforms = SAM2Transforms(
|
45 |
+
resolution=self.model.image_size,
|
46 |
+
mask_threshold=mask_threshold,
|
47 |
+
max_hole_area=max_hole_area,
|
48 |
+
max_sprinkle_area=max_sprinkle_area,
|
49 |
+
)
|
50 |
+
|
51 |
+
# Predictor state
|
52 |
+
self._is_image_set = False
|
53 |
+
self._features = None
|
54 |
+
self._orig_hw = None
|
55 |
+
# Whether the predictor is set for single image or a batch of images
|
56 |
+
self._is_batch = False
|
57 |
+
|
58 |
+
# Predictor config
|
59 |
+
self.mask_threshold = mask_threshold
|
60 |
+
|
61 |
+
# Spatial dim for backbone feature maps
|
62 |
+
self._bb_feat_sizes = [
|
63 |
+
(256, 256),
|
64 |
+
(128, 128),
|
65 |
+
(64, 64),
|
66 |
+
]
|
67 |
+
|
68 |
+
@classmethod
|
69 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
|
70 |
+
"""
|
71 |
+
Load a pretrained model from the Hugging Face hub.
|
72 |
+
|
73 |
+
Arguments:
|
74 |
+
model_id (str): The Hugging Face repository ID.
|
75 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
(SAM2ImagePredictor): The loaded model.
|
79 |
+
"""
|
80 |
+
from sam2.build_sam import build_sam2_hf
|
81 |
+
|
82 |
+
sam_model = build_sam2_hf(model_id, **kwargs)
|
83 |
+
return cls(sam_model, **kwargs)
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def set_image(
|
87 |
+
self,
|
88 |
+
image: Union[np.ndarray, Image],
|
89 |
+
) -> None:
|
90 |
+
"""
|
91 |
+
Calculates the image embeddings for the provided image, allowing
|
92 |
+
masks to be predicted with the 'predict' method.
|
93 |
+
|
94 |
+
Arguments:
|
95 |
+
image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
|
96 |
+
with pixel values in [0, 255].
|
97 |
+
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
98 |
+
"""
|
99 |
+
self.reset_predictor()
|
100 |
+
# Transform the image to the form expected by the model
|
101 |
+
if isinstance(image, np.ndarray):
|
102 |
+
logging.info("For numpy array image, we assume (HxWxC) format")
|
103 |
+
self._orig_hw = [image.shape[:2]]
|
104 |
+
elif isinstance(image, Image):
|
105 |
+
w, h = image.size
|
106 |
+
self._orig_hw = [(h, w)]
|
107 |
+
else:
|
108 |
+
raise NotImplementedError("Image format not supported")
|
109 |
+
|
110 |
+
input_image = self._transforms(image)
|
111 |
+
input_image = input_image[None, ...].to(self.device)
|
112 |
+
|
113 |
+
assert (
|
114 |
+
len(input_image.shape) == 4 and input_image.shape[1] == 3
|
115 |
+
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
116 |
+
logging.info("Computing image embeddings for the provided image...")
|
117 |
+
backbone_out = self.model.forward_image(input_image)
|
118 |
+
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
119 |
+
# Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
|
120 |
+
if self.model.directly_add_no_mem_embed:
|
121 |
+
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
122 |
+
|
123 |
+
feats = [
|
124 |
+
feat.permute(1, 2, 0).view(1, -1, *feat_size)
|
125 |
+
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
126 |
+
][::-1]
|
127 |
+
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
128 |
+
self._is_image_set = True
|
129 |
+
logging.info("Image embeddings computed.")
|
130 |
+
|
131 |
+
@torch.no_grad()
|
132 |
+
def set_image_batch(
|
133 |
+
self,
|
134 |
+
image_list: List[Union[np.ndarray]],
|
135 |
+
) -> None:
|
136 |
+
"""
|
137 |
+
Calculates the image embeddings for the provided image batch, allowing
|
138 |
+
masks to be predicted with the 'predict_batch' method.
|
139 |
+
|
140 |
+
Arguments:
|
141 |
+
image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
|
142 |
+
with pixel values in [0, 255].
|
143 |
+
"""
|
144 |
+
self.reset_predictor()
|
145 |
+
assert isinstance(image_list, list)
|
146 |
+
self._orig_hw = []
|
147 |
+
for image in image_list:
|
148 |
+
assert isinstance(
|
149 |
+
image, np.ndarray
|
150 |
+
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
151 |
+
self._orig_hw.append(image.shape[:2])
|
152 |
+
# Transform the image to the form expected by the model
|
153 |
+
img_batch = self._transforms.forward_batch(image_list)
|
154 |
+
img_batch = img_batch.to(self.device)
|
155 |
+
batch_size = img_batch.shape[0]
|
156 |
+
assert (
|
157 |
+
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
|
158 |
+
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
159 |
+
logging.info("Computing image embeddings for the provided images...")
|
160 |
+
backbone_out = self.model.forward_image(img_batch)
|
161 |
+
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
162 |
+
# Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
|
163 |
+
if self.model.directly_add_no_mem_embed:
|
164 |
+
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
165 |
+
|
166 |
+
feats = [
|
167 |
+
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
168 |
+
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
169 |
+
][::-1]
|
170 |
+
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
171 |
+
self._is_image_set = True
|
172 |
+
self._is_batch = True
|
173 |
+
logging.info("Image embeddings computed.")
|
174 |
+
|
175 |
+
def predict_batch(
|
176 |
+
self,
|
177 |
+
point_coords_batch: List[np.ndarray] = None,
|
178 |
+
point_labels_batch: List[np.ndarray] = None,
|
179 |
+
box_batch: List[np.ndarray] = None,
|
180 |
+
mask_input_batch: List[np.ndarray] = None,
|
181 |
+
multimask_output: bool = True,
|
182 |
+
return_logits: bool = False,
|
183 |
+
normalize_coords=True,
|
184 |
+
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
185 |
+
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
|
186 |
+
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
|
187 |
+
"""
|
188 |
+
assert self._is_batch, "This function should only be used when in batched mode"
|
189 |
+
if not self._is_image_set:
|
190 |
+
raise RuntimeError(
|
191 |
+
"An image must be set with .set_image_batch(...) before mask prediction."
|
192 |
+
)
|
193 |
+
num_images = len(self._features["image_embed"])
|
194 |
+
all_masks = []
|
195 |
+
all_ious = []
|
196 |
+
all_low_res_masks = []
|
197 |
+
for img_idx in range(num_images):
|
198 |
+
# Transform input prompts
|
199 |
+
point_coords = (
|
200 |
+
point_coords_batch[img_idx] if point_coords_batch is not None else None
|
201 |
+
)
|
202 |
+
point_labels = (
|
203 |
+
point_labels_batch[img_idx] if point_labels_batch is not None else None
|
204 |
+
)
|
205 |
+
box = box_batch[img_idx] if box_batch is not None else None
|
206 |
+
mask_input = (
|
207 |
+
mask_input_batch[img_idx] if mask_input_batch is not None else None
|
208 |
+
)
|
209 |
+
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
210 |
+
point_coords,
|
211 |
+
point_labels,
|
212 |
+
box,
|
213 |
+
mask_input,
|
214 |
+
normalize_coords,
|
215 |
+
img_idx=img_idx,
|
216 |
+
)
|
217 |
+
masks, iou_predictions, low_res_masks = self._predict(
|
218 |
+
unnorm_coords,
|
219 |
+
labels,
|
220 |
+
unnorm_box,
|
221 |
+
mask_input,
|
222 |
+
multimask_output,
|
223 |
+
return_logits=return_logits,
|
224 |
+
img_idx=img_idx,
|
225 |
+
)
|
226 |
+
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
|
227 |
+
iou_predictions_np = (
|
228 |
+
iou_predictions.squeeze(0).float().detach().cpu().numpy()
|
229 |
+
)
|
230 |
+
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
231 |
+
all_masks.append(masks_np)
|
232 |
+
all_ious.append(iou_predictions_np)
|
233 |
+
all_low_res_masks.append(low_res_masks_np)
|
234 |
+
|
235 |
+
return all_masks, all_ious, all_low_res_masks
|
236 |
+
|
237 |
+
def predict(
|
238 |
+
self,
|
239 |
+
point_coords: Optional[np.ndarray] = None,
|
240 |
+
point_labels: Optional[np.ndarray] = None,
|
241 |
+
box: Optional[np.ndarray] = None,
|
242 |
+
mask_input: Optional[np.ndarray] = None,
|
243 |
+
multimask_output: bool = True,
|
244 |
+
return_logits: bool = False,
|
245 |
+
normalize_coords=True,
|
246 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
247 |
+
"""
|
248 |
+
Predict masks for the given input prompts, using the currently set image.
|
249 |
+
|
250 |
+
Arguments:
|
251 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
252 |
+
model. Each point is in (X,Y) in pixels.
|
253 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
254 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
255 |
+
background point.
|
256 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
257 |
+
model, in XYXY format.
|
258 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
259 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
260 |
+
for SAM, H=W=256.
|
261 |
+
multimask_output (bool): If true, the model will return three masks.
|
262 |
+
For ambiguous input prompts (such as a single click), this will often
|
263 |
+
produce better masks than a single prediction. If only a single
|
264 |
+
mask is needed, the model's predicted quality score can be used
|
265 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
266 |
+
input prompts, multimask_output=False can give better results.
|
267 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
268 |
+
instead of a binary mask.
|
269 |
+
normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
273 |
+
number of masks, and (H, W) is the original image size.
|
274 |
+
(np.ndarray): An array of length C containing the model's
|
275 |
+
predictions for the quality of each mask.
|
276 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
277 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
278 |
+
a subsequent iteration as mask input.
|
279 |
+
"""
|
280 |
+
if not self._is_image_set:
|
281 |
+
raise RuntimeError(
|
282 |
+
"An image must be set with .set_image(...) before mask prediction."
|
283 |
+
)
|
284 |
+
|
285 |
+
# Transform input prompts
|
286 |
+
|
287 |
+
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
288 |
+
point_coords, point_labels, box, mask_input, normalize_coords
|
289 |
+
)
|
290 |
+
|
291 |
+
masks, iou_predictions, low_res_masks = self._predict(
|
292 |
+
unnorm_coords,
|
293 |
+
labels,
|
294 |
+
unnorm_box,
|
295 |
+
mask_input,
|
296 |
+
multimask_output,
|
297 |
+
return_logits=return_logits,
|
298 |
+
)
|
299 |
+
|
300 |
+
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
|
301 |
+
iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
|
302 |
+
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
303 |
+
return masks_np, iou_predictions_np, low_res_masks_np
|
304 |
+
|
305 |
+
def _prep_prompts(
|
306 |
+
self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
|
307 |
+
):
|
308 |
+
|
309 |
+
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
310 |
+
if point_coords is not None:
|
311 |
+
assert (
|
312 |
+
point_labels is not None
|
313 |
+
), "point_labels must be supplied if point_coords is supplied."
|
314 |
+
point_coords = torch.as_tensor(
|
315 |
+
point_coords, dtype=torch.float, device=self.device
|
316 |
+
)
|
317 |
+
unnorm_coords = self._transforms.transform_coords(
|
318 |
+
point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
|
319 |
+
)
|
320 |
+
labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
321 |
+
if len(unnorm_coords.shape) == 2:
|
322 |
+
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
|
323 |
+
if box is not None:
|
324 |
+
box = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
325 |
+
unnorm_box = self._transforms.transform_boxes(
|
326 |
+
box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
|
327 |
+
) # Bx2x2
|
328 |
+
if mask_logits is not None:
|
329 |
+
mask_input = torch.as_tensor(
|
330 |
+
mask_logits, dtype=torch.float, device=self.device
|
331 |
+
)
|
332 |
+
if len(mask_input.shape) == 3:
|
333 |
+
mask_input = mask_input[None, :, :, :]
|
334 |
+
return mask_input, unnorm_coords, labels, unnorm_box
|
335 |
+
|
336 |
+
@torch.no_grad()
|
337 |
+
def _predict(
|
338 |
+
self,
|
339 |
+
point_coords: Optional[torch.Tensor],
|
340 |
+
point_labels: Optional[torch.Tensor],
|
341 |
+
boxes: Optional[torch.Tensor] = None,
|
342 |
+
mask_input: Optional[torch.Tensor] = None,
|
343 |
+
multimask_output: bool = True,
|
344 |
+
return_logits: bool = False,
|
345 |
+
img_idx: int = -1,
|
346 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
347 |
+
"""
|
348 |
+
Predict masks for the given input prompts, using the currently set image.
|
349 |
+
Input prompts are batched torch tensors and are expected to already be
|
350 |
+
transformed to the input frame using SAM2Transforms.
|
351 |
+
|
352 |
+
Arguments:
|
353 |
+
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
354 |
+
model. Each point is in (X,Y) in pixels.
|
355 |
+
point_labels (torch.Tensor or None): A BxN array of labels for the
|
356 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
357 |
+
background point.
|
358 |
+
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
|
359 |
+
model, in XYXY format.
|
360 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
361 |
+
coming from a previous prediction iteration. Has form Bx1xHxW, where
|
362 |
+
for SAM, H=W=256. Masks returned by a previous iteration of the
|
363 |
+
predict method do not need further transformation.
|
364 |
+
multimask_output (bool): If true, the model will return three masks.
|
365 |
+
For ambiguous input prompts (such as a single click), this will often
|
366 |
+
produce better masks than a single prediction. If only a single
|
367 |
+
mask is needed, the model's predicted quality score can be used
|
368 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
369 |
+
input prompts, multimask_output=False can give better results.
|
370 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
371 |
+
instead of a binary mask.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
(torch.Tensor): The output masks in BxCxHxW format, where C is the
|
375 |
+
number of masks, and (H, W) is the original image size.
|
376 |
+
(torch.Tensor): An array of shape BxC containing the model's
|
377 |
+
predictions for the quality of each mask.
|
378 |
+
(torch.Tensor): An array of shape BxCxHxW, where C is the number
|
379 |
+
of masks and H=W=256. These low res logits can be passed to
|
380 |
+
a subsequent iteration as mask input.
|
381 |
+
"""
|
382 |
+
if not self._is_image_set:
|
383 |
+
raise RuntimeError(
|
384 |
+
"An image must be set with .set_image(...) before mask prediction."
|
385 |
+
)
|
386 |
+
|
387 |
+
if point_coords is not None:
|
388 |
+
concat_points = (point_coords, point_labels)
|
389 |
+
else:
|
390 |
+
concat_points = None
|
391 |
+
|
392 |
+
# Embed prompts
|
393 |
+
if boxes is not None:
|
394 |
+
box_coords = boxes.reshape(-1, 2, 2)
|
395 |
+
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
|
396 |
+
box_labels = box_labels.repeat(boxes.size(0), 1)
|
397 |
+
# we merge "boxes" and "points" into a single "concat_points" input (where
|
398 |
+
# boxes are added at the beginning) to sam_prompt_encoder
|
399 |
+
if concat_points is not None:
|
400 |
+
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
|
401 |
+
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
|
402 |
+
concat_points = (concat_coords, concat_labels)
|
403 |
+
else:
|
404 |
+
concat_points = (box_coords, box_labels)
|
405 |
+
|
406 |
+
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
407 |
+
points=concat_points,
|
408 |
+
boxes=None,
|
409 |
+
masks=mask_input,
|
410 |
+
)
|
411 |
+
|
412 |
+
# Predict masks
|
413 |
+
batched_mode = (
|
414 |
+
concat_points is not None and concat_points[0].shape[0] > 1
|
415 |
+
) # multi object prediction
|
416 |
+
high_res_features = [
|
417 |
+
feat_level[img_idx].unsqueeze(0)
|
418 |
+
for feat_level in self._features["high_res_feats"]
|
419 |
+
]
|
420 |
+
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
|
421 |
+
image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
|
422 |
+
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
|
423 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
424 |
+
dense_prompt_embeddings=dense_embeddings,
|
425 |
+
multimask_output=multimask_output,
|
426 |
+
repeat_image=batched_mode,
|
427 |
+
high_res_features=high_res_features,
|
428 |
+
)
|
429 |
+
|
430 |
+
# Upscale the masks to the original image resolution
|
431 |
+
masks = self._transforms.postprocess_masks(
|
432 |
+
low_res_masks, self._orig_hw[img_idx]
|
433 |
+
)
|
434 |
+
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
|
435 |
+
if not return_logits:
|
436 |
+
masks = masks > self.mask_threshold
|
437 |
+
|
438 |
+
return masks, iou_predictions, low_res_masks
|
439 |
+
|
440 |
+
def get_image_embedding(self) -> torch.Tensor:
|
441 |
+
"""
|
442 |
+
Returns the image embeddings for the currently set image, with
|
443 |
+
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
|
444 |
+
the embedding spatial dimension of SAM (typically C=256, H=W=64).
|
445 |
+
"""
|
446 |
+
if not self._is_image_set:
|
447 |
+
raise RuntimeError(
|
448 |
+
"An image must be set with .set_image(...) to generate an embedding."
|
449 |
+
)
|
450 |
+
assert (
|
451 |
+
self._features is not None
|
452 |
+
), "Features must exist if an image has been set."
|
453 |
+
return self._features["image_embed"]
|
454 |
+
|
455 |
+
@property
|
456 |
+
def device(self) -> torch.device:
|
457 |
+
return self.model.device
|
458 |
+
|
459 |
+
def reset_predictor(self) -> None:
|
460 |
+
"""
|
461 |
+
Resets the image embeddings and other state variables.
|
462 |
+
"""
|
463 |
+
self._is_image_set = False
|
464 |
+
self._features = None
|
465 |
+
self._orig_hw = None
|
466 |
+
self._is_batch = False
|
clone-IDEA-Research/Grounded-SAM-2/sam2/sam2_video_predictor.py
ADDED
@@ -0,0 +1,1172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import warnings
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
15 |
+
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
|
16 |
+
|
17 |
+
|
18 |
+
class SAM2VideoPredictor(SAM2Base):
|
19 |
+
"""The predictor class to handle user interactions and manage inference states."""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
fill_hole_area=0,
|
24 |
+
# whether to apply non-overlapping constraints on the output object masks
|
25 |
+
non_overlap_masks=False,
|
26 |
+
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
27 |
+
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
28 |
+
clear_non_cond_mem_around_input=False,
|
29 |
+
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
30 |
+
clear_non_cond_mem_for_multi_obj=False,
|
31 |
+
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
32 |
+
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
33 |
+
add_all_frames_to_correct_as_cond=False,
|
34 |
+
**kwargs,
|
35 |
+
):
|
36 |
+
super().__init__(**kwargs)
|
37 |
+
self.fill_hole_area = fill_hole_area
|
38 |
+
self.non_overlap_masks = non_overlap_masks
|
39 |
+
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
40 |
+
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
41 |
+
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
42 |
+
|
43 |
+
@torch.inference_mode()
|
44 |
+
def init_state(
|
45 |
+
self,
|
46 |
+
video_path,
|
47 |
+
offload_video_to_cpu=False,
|
48 |
+
offload_state_to_cpu=False,
|
49 |
+
async_loading_frames=False,
|
50 |
+
):
|
51 |
+
"""Initialize an inference state."""
|
52 |
+
compute_device = self.device # device of the model
|
53 |
+
images, video_height, video_width = load_video_frames(
|
54 |
+
video_path=video_path,
|
55 |
+
image_size=self.image_size,
|
56 |
+
offload_video_to_cpu=offload_video_to_cpu,
|
57 |
+
async_loading_frames=async_loading_frames,
|
58 |
+
compute_device=compute_device,
|
59 |
+
)
|
60 |
+
inference_state = {}
|
61 |
+
inference_state["images"] = images
|
62 |
+
inference_state["num_frames"] = len(images)
|
63 |
+
# whether to offload the video frames to CPU memory
|
64 |
+
# turning on this option saves the GPU memory with only a very small overhead
|
65 |
+
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
|
66 |
+
# whether to offload the inference state to CPU memory
|
67 |
+
# turning on this option saves the GPU memory at the cost of a lower tracking fps
|
68 |
+
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
|
69 |
+
# and from 24 to 21 when tracking two objects)
|
70 |
+
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
|
71 |
+
# the original video height and width, used for resizing final output scores
|
72 |
+
inference_state["video_height"] = video_height
|
73 |
+
inference_state["video_width"] = video_width
|
74 |
+
inference_state["device"] = compute_device
|
75 |
+
if offload_state_to_cpu:
|
76 |
+
inference_state["storage_device"] = torch.device("cpu")
|
77 |
+
else:
|
78 |
+
inference_state["storage_device"] = compute_device
|
79 |
+
# inputs on each frame
|
80 |
+
inference_state["point_inputs_per_obj"] = {}
|
81 |
+
inference_state["mask_inputs_per_obj"] = {}
|
82 |
+
# visual features on a small number of recently visited frames for quick interactions
|
83 |
+
inference_state["cached_features"] = {}
|
84 |
+
# values that don't change across frames (so we only need to hold one copy of them)
|
85 |
+
inference_state["constants"] = {}
|
86 |
+
# mapping between client-side object id and model-side object index
|
87 |
+
inference_state["obj_id_to_idx"] = OrderedDict()
|
88 |
+
inference_state["obj_idx_to_id"] = OrderedDict()
|
89 |
+
inference_state["obj_ids"] = []
|
90 |
+
# A storage to hold the model's tracking results and states on each frame
|
91 |
+
inference_state["output_dict"] = {
|
92 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
93 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
94 |
+
}
|
95 |
+
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
96 |
+
inference_state["output_dict_per_obj"] = {}
|
97 |
+
# A temporary storage to hold new outputs when user interact with a frame
|
98 |
+
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
99 |
+
inference_state["temp_output_dict_per_obj"] = {}
|
100 |
+
# Frames that already holds consolidated outputs from click or mask inputs
|
101 |
+
# (we directly use their consolidated outputs during tracking)
|
102 |
+
inference_state["consolidated_frame_inds"] = {
|
103 |
+
"cond_frame_outputs": set(), # set containing frame indices
|
104 |
+
"non_cond_frame_outputs": set(), # set containing frame indices
|
105 |
+
}
|
106 |
+
# metadata for each tracking frame (e.g. which direction it's tracked)
|
107 |
+
inference_state["tracking_has_started"] = False
|
108 |
+
inference_state["frames_already_tracked"] = {}
|
109 |
+
# Warm up the visual backbone and cache the image feature on frame 0
|
110 |
+
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
111 |
+
return inference_state
|
112 |
+
|
113 |
+
@classmethod
|
114 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
|
115 |
+
"""
|
116 |
+
Load a pretrained model from the Hugging Face hub.
|
117 |
+
|
118 |
+
Arguments:
|
119 |
+
model_id (str): The Hugging Face repository ID.
|
120 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
(SAM2VideoPredictor): The loaded model.
|
124 |
+
"""
|
125 |
+
from sam2.build_sam import build_sam2_video_predictor_hf
|
126 |
+
|
127 |
+
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
128 |
+
return sam_model
|
129 |
+
|
130 |
+
def _obj_id_to_idx(self, inference_state, obj_id):
|
131 |
+
"""Map client-side object id to model-side object index."""
|
132 |
+
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
133 |
+
if obj_idx is not None:
|
134 |
+
return obj_idx
|
135 |
+
|
136 |
+
# This is a new object id not sent to the server before. We only allow adding
|
137 |
+
# new objects *before* the tracking starts.
|
138 |
+
allow_new_object = not inference_state["tracking_has_started"]
|
139 |
+
if allow_new_object:
|
140 |
+
# get the next object slot
|
141 |
+
obj_idx = len(inference_state["obj_id_to_idx"])
|
142 |
+
inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
143 |
+
inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
144 |
+
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
|
145 |
+
# set up input and output structures for this object
|
146 |
+
inference_state["point_inputs_per_obj"][obj_idx] = {}
|
147 |
+
inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
148 |
+
inference_state["output_dict_per_obj"][obj_idx] = {
|
149 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
150 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
151 |
+
}
|
152 |
+
inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
153 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
154 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
155 |
+
}
|
156 |
+
return obj_idx
|
157 |
+
else:
|
158 |
+
raise RuntimeError(
|
159 |
+
f"Cannot add new object id {obj_id} after tracking starts. "
|
160 |
+
f"All existing object ids: {inference_state['obj_ids']}. "
|
161 |
+
f"Please call 'reset_state' to restart from scratch."
|
162 |
+
)
|
163 |
+
|
164 |
+
def _obj_idx_to_id(self, inference_state, obj_idx):
|
165 |
+
"""Map model-side object index to client-side object id."""
|
166 |
+
return inference_state["obj_idx_to_id"][obj_idx]
|
167 |
+
|
168 |
+
def _get_obj_num(self, inference_state):
|
169 |
+
"""Get the total number of unique object ids received so far in this session."""
|
170 |
+
return len(inference_state["obj_idx_to_id"])
|
171 |
+
|
172 |
+
@torch.inference_mode()
|
173 |
+
def add_new_points_or_box(
|
174 |
+
self,
|
175 |
+
inference_state,
|
176 |
+
frame_idx,
|
177 |
+
obj_id,
|
178 |
+
points=None,
|
179 |
+
labels=None,
|
180 |
+
clear_old_points=True,
|
181 |
+
normalize_coords=True,
|
182 |
+
box=None,
|
183 |
+
):
|
184 |
+
"""Add new points to a frame."""
|
185 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
186 |
+
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
187 |
+
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
188 |
+
|
189 |
+
if (points is not None) != (labels is not None):
|
190 |
+
raise ValueError("points and labels must be provided together")
|
191 |
+
if points is None and box is None:
|
192 |
+
raise ValueError("at least one of points or box must be provided as input")
|
193 |
+
|
194 |
+
if points is None:
|
195 |
+
points = torch.zeros(0, 2, dtype=torch.float32)
|
196 |
+
elif not isinstance(points, torch.Tensor):
|
197 |
+
points = torch.tensor(points, dtype=torch.float32)
|
198 |
+
if labels is None:
|
199 |
+
labels = torch.zeros(0, dtype=torch.int32)
|
200 |
+
elif not isinstance(labels, torch.Tensor):
|
201 |
+
labels = torch.tensor(labels, dtype=torch.int32)
|
202 |
+
if points.dim() == 2:
|
203 |
+
points = points.unsqueeze(0) # add batch dimension
|
204 |
+
if labels.dim() == 1:
|
205 |
+
labels = labels.unsqueeze(0) # add batch dimension
|
206 |
+
|
207 |
+
# If `box` is provided, we add it as the first two points with labels 2 and 3
|
208 |
+
# along with the user-provided points (consistent with how SAM 2 is trained).
|
209 |
+
if box is not None:
|
210 |
+
if not clear_old_points:
|
211 |
+
raise ValueError(
|
212 |
+
"cannot add box without clearing old points, since "
|
213 |
+
"box prompt must be provided before any point prompt "
|
214 |
+
"(please use clear_old_points=True instead)"
|
215 |
+
)
|
216 |
+
if inference_state["tracking_has_started"]:
|
217 |
+
warnings.warn(
|
218 |
+
"You are adding a box after tracking starts. SAM 2 may not always be "
|
219 |
+
"able to incorporate a box prompt for *refinement*. If you intend to "
|
220 |
+
"use box prompt as an *initial* input before tracking, please call "
|
221 |
+
"'reset_state' on the inference state to restart from scratch.",
|
222 |
+
category=UserWarning,
|
223 |
+
stacklevel=2,
|
224 |
+
)
|
225 |
+
if not isinstance(box, torch.Tensor):
|
226 |
+
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
227 |
+
box_coords = box.reshape(1, 2, 2)
|
228 |
+
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
|
229 |
+
box_labels = box_labels.reshape(1, 2)
|
230 |
+
points = torch.cat([box_coords, points], dim=1)
|
231 |
+
labels = torch.cat([box_labels, labels], dim=1)
|
232 |
+
|
233 |
+
if normalize_coords:
|
234 |
+
video_H = inference_state["video_height"]
|
235 |
+
video_W = inference_state["video_width"]
|
236 |
+
points = points / torch.tensor([video_W, video_H]).to(points.device)
|
237 |
+
# scale the (normalized) coordinates by the model's internal image size
|
238 |
+
points = points * self.image_size
|
239 |
+
points = points.to(inference_state["device"])
|
240 |
+
labels = labels.to(inference_state["device"])
|
241 |
+
|
242 |
+
if not clear_old_points:
|
243 |
+
point_inputs = point_inputs_per_frame.get(frame_idx, None)
|
244 |
+
else:
|
245 |
+
point_inputs = None
|
246 |
+
point_inputs = concat_points(point_inputs, points, labels)
|
247 |
+
|
248 |
+
point_inputs_per_frame[frame_idx] = point_inputs
|
249 |
+
mask_inputs_per_frame.pop(frame_idx, None)
|
250 |
+
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
251 |
+
# frame, meaning that the inputs points are to generate segments on this frame without
|
252 |
+
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
253 |
+
# the input points will be used to correct the already tracked masks.
|
254 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
255 |
+
# whether to track in reverse time order
|
256 |
+
if is_init_cond_frame:
|
257 |
+
reverse = False
|
258 |
+
else:
|
259 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
260 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
261 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
262 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
263 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
264 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
265 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
266 |
+
|
267 |
+
# Get any previously predicted mask logits on this object and feed it along with
|
268 |
+
# the new clicks into the SAM mask decoder.
|
269 |
+
prev_sam_mask_logits = None
|
270 |
+
# lookup temporary output dict first, which contains the most recent output
|
271 |
+
# (if not found, then lookup conditioning and non-conditioning frame output)
|
272 |
+
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
|
273 |
+
if prev_out is None:
|
274 |
+
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
275 |
+
if prev_out is None:
|
276 |
+
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
277 |
+
|
278 |
+
if prev_out is not None and prev_out["pred_masks"] is not None:
|
279 |
+
device = inference_state["device"]
|
280 |
+
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
281 |
+
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
282 |
+
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
283 |
+
current_out, _ = self._run_single_frame_inference(
|
284 |
+
inference_state=inference_state,
|
285 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
286 |
+
frame_idx=frame_idx,
|
287 |
+
batch_size=1, # run on the slice of a single object
|
288 |
+
is_init_cond_frame=is_init_cond_frame,
|
289 |
+
point_inputs=point_inputs,
|
290 |
+
mask_inputs=None,
|
291 |
+
reverse=reverse,
|
292 |
+
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
293 |
+
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
|
294 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
295 |
+
# them into memory.
|
296 |
+
run_mem_encoder=False,
|
297 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
298 |
+
)
|
299 |
+
# Add the output to the output dict (to be used as future memory)
|
300 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
301 |
+
|
302 |
+
# Resize the output mask to the original video resolution
|
303 |
+
obj_ids = inference_state["obj_ids"]
|
304 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
305 |
+
inference_state,
|
306 |
+
frame_idx,
|
307 |
+
is_cond=is_cond,
|
308 |
+
run_mem_encoder=False,
|
309 |
+
consolidate_at_video_res=True,
|
310 |
+
)
|
311 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
312 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
313 |
+
)
|
314 |
+
return frame_idx, obj_ids, video_res_masks
|
315 |
+
|
316 |
+
def add_new_points(self, *args, **kwargs):
|
317 |
+
"""Deprecated method. Please use `add_new_points_or_box` instead."""
|
318 |
+
return self.add_new_points_or_box(*args, **kwargs)
|
319 |
+
|
320 |
+
@torch.inference_mode()
|
321 |
+
def add_new_mask(
|
322 |
+
self,
|
323 |
+
inference_state,
|
324 |
+
frame_idx,
|
325 |
+
obj_id,
|
326 |
+
mask,
|
327 |
+
):
|
328 |
+
"""Add new mask to a frame."""
|
329 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
330 |
+
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
331 |
+
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
332 |
+
|
333 |
+
if not isinstance(mask, torch.Tensor):
|
334 |
+
mask = torch.tensor(mask, dtype=torch.bool)
|
335 |
+
assert mask.dim() == 2
|
336 |
+
mask_H, mask_W = mask.shape
|
337 |
+
mask_inputs_orig = mask[None, None] # add batch and channel dimension
|
338 |
+
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
|
339 |
+
|
340 |
+
# resize the mask if it doesn't match the model's image size
|
341 |
+
if mask_H != self.image_size or mask_W != self.image_size:
|
342 |
+
mask_inputs = torch.nn.functional.interpolate(
|
343 |
+
mask_inputs_orig,
|
344 |
+
size=(self.image_size, self.image_size),
|
345 |
+
align_corners=False,
|
346 |
+
mode="bilinear",
|
347 |
+
antialias=True, # use antialias for downsampling
|
348 |
+
)
|
349 |
+
mask_inputs = (mask_inputs >= 0.5).float()
|
350 |
+
else:
|
351 |
+
mask_inputs = mask_inputs_orig
|
352 |
+
|
353 |
+
mask_inputs_per_frame[frame_idx] = mask_inputs
|
354 |
+
point_inputs_per_frame.pop(frame_idx, None)
|
355 |
+
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
356 |
+
# frame, meaning that the inputs points are to generate segments on this frame without
|
357 |
+
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
358 |
+
# the input points will be used to correct the already tracked masks.
|
359 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
360 |
+
# whether to track in reverse time order
|
361 |
+
if is_init_cond_frame:
|
362 |
+
reverse = False
|
363 |
+
else:
|
364 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
365 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
366 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
367 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
368 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
369 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
370 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
371 |
+
|
372 |
+
current_out, _ = self._run_single_frame_inference(
|
373 |
+
inference_state=inference_state,
|
374 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
375 |
+
frame_idx=frame_idx,
|
376 |
+
batch_size=1, # run on the slice of a single object
|
377 |
+
is_init_cond_frame=is_init_cond_frame,
|
378 |
+
point_inputs=None,
|
379 |
+
mask_inputs=mask_inputs,
|
380 |
+
reverse=reverse,
|
381 |
+
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
382 |
+
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
|
383 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
384 |
+
# them into memory.
|
385 |
+
run_mem_encoder=False,
|
386 |
+
)
|
387 |
+
# Add the output to the output dict (to be used as future memory)
|
388 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
389 |
+
|
390 |
+
# Resize the output mask to the original video resolution
|
391 |
+
obj_ids = inference_state["obj_ids"]
|
392 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
393 |
+
inference_state,
|
394 |
+
frame_idx,
|
395 |
+
is_cond=is_cond,
|
396 |
+
run_mem_encoder=False,
|
397 |
+
consolidate_at_video_res=True,
|
398 |
+
)
|
399 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
400 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
401 |
+
)
|
402 |
+
return frame_idx, obj_ids, video_res_masks
|
403 |
+
|
404 |
+
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
405 |
+
"""
|
406 |
+
Resize the object scores to the original video resolution (video_res_masks)
|
407 |
+
and apply non-overlapping constraints for final output.
|
408 |
+
"""
|
409 |
+
device = inference_state["device"]
|
410 |
+
video_H = inference_state["video_height"]
|
411 |
+
video_W = inference_state["video_width"]
|
412 |
+
any_res_masks = any_res_masks.to(device, non_blocking=True)
|
413 |
+
if any_res_masks.shape[-2:] == (video_H, video_W):
|
414 |
+
video_res_masks = any_res_masks
|
415 |
+
else:
|
416 |
+
video_res_masks = torch.nn.functional.interpolate(
|
417 |
+
any_res_masks,
|
418 |
+
size=(video_H, video_W),
|
419 |
+
mode="bilinear",
|
420 |
+
align_corners=False,
|
421 |
+
)
|
422 |
+
if self.non_overlap_masks:
|
423 |
+
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
|
424 |
+
return any_res_masks, video_res_masks
|
425 |
+
|
426 |
+
def _consolidate_temp_output_across_obj(
|
427 |
+
self,
|
428 |
+
inference_state,
|
429 |
+
frame_idx,
|
430 |
+
is_cond,
|
431 |
+
run_mem_encoder,
|
432 |
+
consolidate_at_video_res=False,
|
433 |
+
):
|
434 |
+
"""
|
435 |
+
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
|
436 |
+
a frame into a single output for all objects, including
|
437 |
+
1) fill any missing objects either from `output_dict_per_obj` (if they exist in
|
438 |
+
`output_dict_per_obj` for this frame) or leave them as placeholder values
|
439 |
+
(if they don't exist in `output_dict_per_obj` for this frame);
|
440 |
+
2) if specified, rerun memory encoder after apply non-overlapping constraints
|
441 |
+
on the object scores.
|
442 |
+
"""
|
443 |
+
batch_size = self._get_obj_num(inference_state)
|
444 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
445 |
+
# Optionally, we allow consolidating the temporary outputs at the original
|
446 |
+
# video resolution (to provide a better editing experience for mask prompts).
|
447 |
+
if consolidate_at_video_res:
|
448 |
+
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
449 |
+
consolidated_H = inference_state["video_height"]
|
450 |
+
consolidated_W = inference_state["video_width"]
|
451 |
+
consolidated_mask_key = "pred_masks_video_res"
|
452 |
+
else:
|
453 |
+
consolidated_H = consolidated_W = self.image_size // 4
|
454 |
+
consolidated_mask_key = "pred_masks"
|
455 |
+
|
456 |
+
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
|
457 |
+
# will be added when rerunning the memory encoder after applying non-overlapping
|
458 |
+
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
459 |
+
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
460 |
+
consolidated_out = {
|
461 |
+
"maskmem_features": None,
|
462 |
+
"maskmem_pos_enc": None,
|
463 |
+
consolidated_mask_key: torch.full(
|
464 |
+
size=(batch_size, 1, consolidated_H, consolidated_W),
|
465 |
+
fill_value=NO_OBJ_SCORE,
|
466 |
+
dtype=torch.float32,
|
467 |
+
device=inference_state["storage_device"],
|
468 |
+
),
|
469 |
+
"obj_ptr": torch.full(
|
470 |
+
size=(batch_size, self.hidden_dim),
|
471 |
+
fill_value=NO_OBJ_SCORE,
|
472 |
+
dtype=torch.float32,
|
473 |
+
device=inference_state["device"],
|
474 |
+
),
|
475 |
+
"object_score_logits": torch.full(
|
476 |
+
size=(batch_size, 1),
|
477 |
+
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
478 |
+
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
479 |
+
fill_value=10.0,
|
480 |
+
dtype=torch.float32,
|
481 |
+
device=inference_state["device"],
|
482 |
+
),
|
483 |
+
}
|
484 |
+
empty_mask_ptr = None
|
485 |
+
for obj_idx in range(batch_size):
|
486 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
487 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
488 |
+
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
|
489 |
+
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
490 |
+
# we fall back and look up its previous output in "output_dict_per_obj".
|
491 |
+
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
|
492 |
+
# "output_dict_per_obj" to find a previous output for this object.
|
493 |
+
if out is None:
|
494 |
+
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
|
495 |
+
if out is None:
|
496 |
+
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
|
497 |
+
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
|
498 |
+
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
499 |
+
# placeholder above) and set its object pointer to be a dummy pointer.
|
500 |
+
if out is None:
|
501 |
+
# Fill in dummy object pointers for those objects without any inputs or
|
502 |
+
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
503 |
+
# i.e. when we need to build the memory for tracking).
|
504 |
+
if run_mem_encoder:
|
505 |
+
if empty_mask_ptr is None:
|
506 |
+
empty_mask_ptr = self._get_empty_mask_ptr(
|
507 |
+
inference_state, frame_idx
|
508 |
+
)
|
509 |
+
# fill object pointer with a dummy pointer (based on an empty mask)
|
510 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
511 |
+
continue
|
512 |
+
# Add the temporary object output mask to consolidated output mask
|
513 |
+
obj_mask = out["pred_masks"]
|
514 |
+
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
|
515 |
+
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
|
516 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
|
517 |
+
else:
|
518 |
+
# Resize first if temporary object mask has a different resolution
|
519 |
+
resized_obj_mask = torch.nn.functional.interpolate(
|
520 |
+
obj_mask,
|
521 |
+
size=consolidated_pred_masks.shape[-2:],
|
522 |
+
mode="bilinear",
|
523 |
+
align_corners=False,
|
524 |
+
)
|
525 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
526 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
527 |
+
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
528 |
+
"object_score_logits"
|
529 |
+
]
|
530 |
+
|
531 |
+
# Optionally, apply non-overlapping constraints on the consolidated scores
|
532 |
+
# and rerun the memory encoder
|
533 |
+
if run_mem_encoder:
|
534 |
+
device = inference_state["device"]
|
535 |
+
high_res_masks = torch.nn.functional.interpolate(
|
536 |
+
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
537 |
+
size=(self.image_size, self.image_size),
|
538 |
+
mode="bilinear",
|
539 |
+
align_corners=False,
|
540 |
+
)
|
541 |
+
if self.non_overlap_masks_for_mem_enc:
|
542 |
+
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
543 |
+
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
544 |
+
inference_state=inference_state,
|
545 |
+
frame_idx=frame_idx,
|
546 |
+
batch_size=batch_size,
|
547 |
+
high_res_masks=high_res_masks,
|
548 |
+
object_score_logits=consolidated_out["object_score_logits"],
|
549 |
+
is_mask_from_pts=True, # these frames are what the user interacted with
|
550 |
+
)
|
551 |
+
consolidated_out["maskmem_features"] = maskmem_features
|
552 |
+
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
553 |
+
|
554 |
+
return consolidated_out
|
555 |
+
|
556 |
+
def _get_empty_mask_ptr(self, inference_state, frame_idx):
|
557 |
+
"""Get a dummy object pointer based on an empty mask on the current frame."""
|
558 |
+
# A dummy (empty) mask with a single object
|
559 |
+
batch_size = 1
|
560 |
+
mask_inputs = torch.zeros(
|
561 |
+
(batch_size, 1, self.image_size, self.image_size),
|
562 |
+
dtype=torch.float32,
|
563 |
+
device=inference_state["device"],
|
564 |
+
)
|
565 |
+
|
566 |
+
# Retrieve correct image features
|
567 |
+
(
|
568 |
+
_,
|
569 |
+
_,
|
570 |
+
current_vision_feats,
|
571 |
+
current_vision_pos_embeds,
|
572 |
+
feat_sizes,
|
573 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
574 |
+
|
575 |
+
# Feed the empty mask and image feature above to get a dummy object pointer
|
576 |
+
current_out = self.track_step(
|
577 |
+
frame_idx=frame_idx,
|
578 |
+
is_init_cond_frame=True,
|
579 |
+
current_vision_feats=current_vision_feats,
|
580 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
581 |
+
feat_sizes=feat_sizes,
|
582 |
+
point_inputs=None,
|
583 |
+
mask_inputs=mask_inputs,
|
584 |
+
output_dict={},
|
585 |
+
num_frames=inference_state["num_frames"],
|
586 |
+
track_in_reverse=False,
|
587 |
+
run_mem_encoder=False,
|
588 |
+
prev_sam_mask_logits=None,
|
589 |
+
)
|
590 |
+
return current_out["obj_ptr"]
|
591 |
+
|
592 |
+
@torch.inference_mode()
|
593 |
+
def propagate_in_video_preflight(self, inference_state):
|
594 |
+
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
595 |
+
# Tracking has started and we don't allow adding new objects until session is reset.
|
596 |
+
inference_state["tracking_has_started"] = True
|
597 |
+
batch_size = self._get_obj_num(inference_state)
|
598 |
+
|
599 |
+
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
600 |
+
# add them into "output_dict".
|
601 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
602 |
+
output_dict = inference_state["output_dict"]
|
603 |
+
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
604 |
+
# temporary outputs have been added (either in this call or any previous calls
|
605 |
+
# to `propagate_in_video_preflight`).
|
606 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
607 |
+
for is_cond in [False, True]:
|
608 |
+
# Separately consolidate conditioning and non-conditioning temp outputs
|
609 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
610 |
+
# Find all the frames that contain temporary outputs for any objects
|
611 |
+
# (these should be the frames that have just received clicks for mask inputs
|
612 |
+
# via `add_new_points_or_box` or `add_new_mask`)
|
613 |
+
temp_frame_inds = set()
|
614 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
615 |
+
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
616 |
+
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
617 |
+
# consolidate the temporary output across all objects on this frame
|
618 |
+
for frame_idx in temp_frame_inds:
|
619 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
620 |
+
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
621 |
+
)
|
622 |
+
# merge them into "output_dict" and also create per-object slices
|
623 |
+
output_dict[storage_key][frame_idx] = consolidated_out
|
624 |
+
self._add_output_per_object(
|
625 |
+
inference_state, frame_idx, consolidated_out, storage_key
|
626 |
+
)
|
627 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
628 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
629 |
+
)
|
630 |
+
if clear_non_cond_mem:
|
631 |
+
# clear non-conditioning memory of the surrounding frames
|
632 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
633 |
+
|
634 |
+
# clear temporary outputs in `temp_output_dict_per_obj`
|
635 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
636 |
+
obj_temp_output_dict[storage_key].clear()
|
637 |
+
|
638 |
+
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
639 |
+
# output on the same frame in "non_cond_frame_outputs"
|
640 |
+
for frame_idx in output_dict["cond_frame_outputs"]:
|
641 |
+
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
642 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
643 |
+
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
644 |
+
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
645 |
+
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
646 |
+
assert frame_idx in output_dict["cond_frame_outputs"]
|
647 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
648 |
+
|
649 |
+
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
650 |
+
# with either points or mask inputs (which should be true under a correct workflow).
|
651 |
+
all_consolidated_frame_inds = (
|
652 |
+
consolidated_frame_inds["cond_frame_outputs"]
|
653 |
+
| consolidated_frame_inds["non_cond_frame_outputs"]
|
654 |
+
)
|
655 |
+
input_frames_inds = set()
|
656 |
+
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
657 |
+
input_frames_inds.update(point_inputs_per_frame.keys())
|
658 |
+
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
659 |
+
input_frames_inds.update(mask_inputs_per_frame.keys())
|
660 |
+
assert all_consolidated_frame_inds == input_frames_inds
|
661 |
+
|
662 |
+
@torch.inference_mode()
|
663 |
+
def propagate_in_video(
|
664 |
+
self,
|
665 |
+
inference_state,
|
666 |
+
start_frame_idx=None,
|
667 |
+
max_frame_num_to_track=None,
|
668 |
+
reverse=False,
|
669 |
+
):
|
670 |
+
"""Propagate the input points across frames to track in the entire video."""
|
671 |
+
self.propagate_in_video_preflight(inference_state)
|
672 |
+
|
673 |
+
output_dict = inference_state["output_dict"]
|
674 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
675 |
+
obj_ids = inference_state["obj_ids"]
|
676 |
+
num_frames = inference_state["num_frames"]
|
677 |
+
batch_size = self._get_obj_num(inference_state)
|
678 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
679 |
+
raise RuntimeError("No points are provided; please add points first")
|
680 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
681 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
682 |
+
)
|
683 |
+
|
684 |
+
# set start index, end index, and processing order
|
685 |
+
if start_frame_idx is None:
|
686 |
+
# default: start from the earliest frame with input points
|
687 |
+
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
688 |
+
if max_frame_num_to_track is None:
|
689 |
+
# default: track all the frames in the video
|
690 |
+
max_frame_num_to_track = num_frames
|
691 |
+
if reverse:
|
692 |
+
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
|
693 |
+
if start_frame_idx > 0:
|
694 |
+
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
|
695 |
+
else:
|
696 |
+
processing_order = [] # skip reverse tracking if starting from frame 0
|
697 |
+
else:
|
698 |
+
end_frame_idx = min(
|
699 |
+
start_frame_idx + max_frame_num_to_track, num_frames - 1
|
700 |
+
)
|
701 |
+
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
702 |
+
|
703 |
+
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
704 |
+
# We skip those frames already in consolidated outputs (these are frames
|
705 |
+
# that received input clicks or mask). Note that we cannot directly run
|
706 |
+
# batched forward on them via `_run_single_frame_inference` because the
|
707 |
+
# number of clicks on each object might be different.
|
708 |
+
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
709 |
+
storage_key = "cond_frame_outputs"
|
710 |
+
current_out = output_dict[storage_key][frame_idx]
|
711 |
+
pred_masks = current_out["pred_masks"]
|
712 |
+
if clear_non_cond_mem:
|
713 |
+
# clear non-conditioning memory of the surrounding frames
|
714 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
715 |
+
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
716 |
+
storage_key = "non_cond_frame_outputs"
|
717 |
+
current_out = output_dict[storage_key][frame_idx]
|
718 |
+
pred_masks = current_out["pred_masks"]
|
719 |
+
else:
|
720 |
+
storage_key = "non_cond_frame_outputs"
|
721 |
+
current_out, pred_masks = self._run_single_frame_inference(
|
722 |
+
inference_state=inference_state,
|
723 |
+
output_dict=output_dict,
|
724 |
+
frame_idx=frame_idx,
|
725 |
+
batch_size=batch_size,
|
726 |
+
is_init_cond_frame=False,
|
727 |
+
point_inputs=None,
|
728 |
+
mask_inputs=None,
|
729 |
+
reverse=reverse,
|
730 |
+
run_mem_encoder=True,
|
731 |
+
)
|
732 |
+
output_dict[storage_key][frame_idx] = current_out
|
733 |
+
# Create slices of per-object outputs for subsequent interaction with each
|
734 |
+
# individual object after tracking.
|
735 |
+
self._add_output_per_object(
|
736 |
+
inference_state, frame_idx, current_out, storage_key
|
737 |
+
)
|
738 |
+
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
739 |
+
|
740 |
+
# Resize the output mask to the original video resolution (we directly use
|
741 |
+
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
742 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
743 |
+
inference_state, pred_masks
|
744 |
+
)
|
745 |
+
yield frame_idx, obj_ids, video_res_masks
|
746 |
+
|
747 |
+
def _add_output_per_object(
|
748 |
+
self, inference_state, frame_idx, current_out, storage_key
|
749 |
+
):
|
750 |
+
"""
|
751 |
+
Split a multi-object output into per-object output slices and add them into
|
752 |
+
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
753 |
+
"""
|
754 |
+
maskmem_features = current_out["maskmem_features"]
|
755 |
+
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
756 |
+
|
757 |
+
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
758 |
+
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
759 |
+
|
760 |
+
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
761 |
+
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
762 |
+
obj_slice = slice(obj_idx, obj_idx + 1)
|
763 |
+
obj_out = {
|
764 |
+
"maskmem_features": None,
|
765 |
+
"maskmem_pos_enc": None,
|
766 |
+
"pred_masks": current_out["pred_masks"][obj_slice],
|
767 |
+
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
768 |
+
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
769 |
+
}
|
770 |
+
if maskmem_features is not None:
|
771 |
+
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
772 |
+
if maskmem_pos_enc is not None:
|
773 |
+
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
774 |
+
obj_output_dict[storage_key][frame_idx] = obj_out
|
775 |
+
|
776 |
+
@torch.inference_mode()
|
777 |
+
def clear_all_prompts_in_frame(
|
778 |
+
self, inference_state, frame_idx, obj_id, need_output=True
|
779 |
+
):
|
780 |
+
"""Remove all input points or mask in a specific frame for a given object."""
|
781 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
782 |
+
|
783 |
+
# Clear the conditioning information on the given frame
|
784 |
+
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
785 |
+
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
786 |
+
|
787 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
788 |
+
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
789 |
+
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
790 |
+
|
791 |
+
# Check and see if there are still any inputs left on this frame
|
792 |
+
batch_size = self._get_obj_num(inference_state)
|
793 |
+
frame_has_input = False
|
794 |
+
for obj_idx2 in range(batch_size):
|
795 |
+
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
796 |
+
frame_has_input = True
|
797 |
+
break
|
798 |
+
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
799 |
+
frame_has_input = True
|
800 |
+
break
|
801 |
+
|
802 |
+
# If this frame has no remaining inputs for any objects, we further clear its
|
803 |
+
# conditioning frame status
|
804 |
+
if not frame_has_input:
|
805 |
+
output_dict = inference_state["output_dict"]
|
806 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
807 |
+
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
808 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
809 |
+
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
810 |
+
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
811 |
+
if out is not None:
|
812 |
+
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
813 |
+
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
814 |
+
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
815 |
+
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
816 |
+
# Similarly, do it for the sliced output on each object.
|
817 |
+
for obj_idx2 in range(batch_size):
|
818 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
819 |
+
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
820 |
+
if obj_out is not None:
|
821 |
+
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
822 |
+
|
823 |
+
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
824 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
825 |
+
self._reset_tracking_results(inference_state)
|
826 |
+
|
827 |
+
if not need_output:
|
828 |
+
return
|
829 |
+
# Finally, output updated masks per object (after removing the inputs above)
|
830 |
+
obj_ids = inference_state["obj_ids"]
|
831 |
+
is_cond = any(
|
832 |
+
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
833 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
834 |
+
)
|
835 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
836 |
+
inference_state,
|
837 |
+
frame_idx,
|
838 |
+
is_cond=is_cond,
|
839 |
+
run_mem_encoder=False,
|
840 |
+
consolidate_at_video_res=True,
|
841 |
+
)
|
842 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
843 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
844 |
+
)
|
845 |
+
return frame_idx, obj_ids, video_res_masks
|
846 |
+
|
847 |
+
@torch.inference_mode()
|
848 |
+
def reset_state(self, inference_state):
|
849 |
+
"""Remove all input points or mask in all frames throughout the video."""
|
850 |
+
self._reset_tracking_results(inference_state)
|
851 |
+
# Remove all object ids
|
852 |
+
inference_state["obj_id_to_idx"].clear()
|
853 |
+
inference_state["obj_idx_to_id"].clear()
|
854 |
+
inference_state["obj_ids"].clear()
|
855 |
+
inference_state["point_inputs_per_obj"].clear()
|
856 |
+
inference_state["mask_inputs_per_obj"].clear()
|
857 |
+
inference_state["output_dict_per_obj"].clear()
|
858 |
+
inference_state["temp_output_dict_per_obj"].clear()
|
859 |
+
|
860 |
+
def _reset_tracking_results(self, inference_state):
|
861 |
+
"""Reset all tracking inputs and results across the videos."""
|
862 |
+
for v in inference_state["point_inputs_per_obj"].values():
|
863 |
+
v.clear()
|
864 |
+
for v in inference_state["mask_inputs_per_obj"].values():
|
865 |
+
v.clear()
|
866 |
+
for v in inference_state["output_dict_per_obj"].values():
|
867 |
+
v["cond_frame_outputs"].clear()
|
868 |
+
v["non_cond_frame_outputs"].clear()
|
869 |
+
for v in inference_state["temp_output_dict_per_obj"].values():
|
870 |
+
v["cond_frame_outputs"].clear()
|
871 |
+
v["non_cond_frame_outputs"].clear()
|
872 |
+
inference_state["output_dict"]["cond_frame_outputs"].clear()
|
873 |
+
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
|
874 |
+
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
|
875 |
+
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
|
876 |
+
inference_state["tracking_has_started"] = False
|
877 |
+
inference_state["frames_already_tracked"].clear()
|
878 |
+
|
879 |
+
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
880 |
+
"""Compute the image features on a given frame."""
|
881 |
+
# Look up in the cache first
|
882 |
+
image, backbone_out = inference_state["cached_features"].get(
|
883 |
+
frame_idx, (None, None)
|
884 |
+
)
|
885 |
+
if backbone_out is None:
|
886 |
+
# Cache miss -- we will run inference on a single image
|
887 |
+
device = inference_state["device"]
|
888 |
+
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
889 |
+
backbone_out = self.forward_image(image)
|
890 |
+
# Cache the most recent frame's feature (for repeated interactions with
|
891 |
+
# a frame; we can use an LRU cache for more frames in the future).
|
892 |
+
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
893 |
+
|
894 |
+
# expand the features to have the same dimension as the number of objects
|
895 |
+
expanded_image = image.expand(batch_size, -1, -1, -1)
|
896 |
+
expanded_backbone_out = {
|
897 |
+
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
|
898 |
+
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
|
899 |
+
}
|
900 |
+
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
901 |
+
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
902 |
+
batch_size, -1, -1, -1
|
903 |
+
)
|
904 |
+
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
905 |
+
pos = pos.expand(batch_size, -1, -1, -1)
|
906 |
+
expanded_backbone_out["vision_pos_enc"][i] = pos
|
907 |
+
|
908 |
+
features = self._prepare_backbone_features(expanded_backbone_out)
|
909 |
+
features = (expanded_image,) + features
|
910 |
+
return features
|
911 |
+
|
912 |
+
def _run_single_frame_inference(
|
913 |
+
self,
|
914 |
+
inference_state,
|
915 |
+
output_dict,
|
916 |
+
frame_idx,
|
917 |
+
batch_size,
|
918 |
+
is_init_cond_frame,
|
919 |
+
point_inputs,
|
920 |
+
mask_inputs,
|
921 |
+
reverse,
|
922 |
+
run_mem_encoder,
|
923 |
+
prev_sam_mask_logits=None,
|
924 |
+
):
|
925 |
+
"""Run tracking on a single frame based on current inputs and previous memory."""
|
926 |
+
# Retrieve correct image features
|
927 |
+
(
|
928 |
+
_,
|
929 |
+
_,
|
930 |
+
current_vision_feats,
|
931 |
+
current_vision_pos_embeds,
|
932 |
+
feat_sizes,
|
933 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
934 |
+
|
935 |
+
# point and mask should not appear as input simultaneously on the same frame
|
936 |
+
assert point_inputs is None or mask_inputs is None
|
937 |
+
current_out = self.track_step(
|
938 |
+
frame_idx=frame_idx,
|
939 |
+
is_init_cond_frame=is_init_cond_frame,
|
940 |
+
current_vision_feats=current_vision_feats,
|
941 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
942 |
+
feat_sizes=feat_sizes,
|
943 |
+
point_inputs=point_inputs,
|
944 |
+
mask_inputs=mask_inputs,
|
945 |
+
output_dict=output_dict,
|
946 |
+
num_frames=inference_state["num_frames"],
|
947 |
+
track_in_reverse=reverse,
|
948 |
+
run_mem_encoder=run_mem_encoder,
|
949 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
950 |
+
)
|
951 |
+
|
952 |
+
# optionally offload the output to CPU memory to save GPU space
|
953 |
+
storage_device = inference_state["storage_device"]
|
954 |
+
maskmem_features = current_out["maskmem_features"]
|
955 |
+
if maskmem_features is not None:
|
956 |
+
maskmem_features = maskmem_features.to(torch.bfloat16)
|
957 |
+
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
958 |
+
pred_masks_gpu = current_out["pred_masks"]
|
959 |
+
# potentially fill holes in the predicted masks
|
960 |
+
if self.fill_hole_area > 0:
|
961 |
+
pred_masks_gpu = fill_holes_in_mask_scores(
|
962 |
+
pred_masks_gpu, self.fill_hole_area
|
963 |
+
)
|
964 |
+
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
965 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
966 |
+
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
967 |
+
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
968 |
+
obj_ptr = current_out["obj_ptr"]
|
969 |
+
object_score_logits = current_out["object_score_logits"]
|
970 |
+
# make a compact version of this frame's output to reduce the state size
|
971 |
+
compact_current_out = {
|
972 |
+
"maskmem_features": maskmem_features,
|
973 |
+
"maskmem_pos_enc": maskmem_pos_enc,
|
974 |
+
"pred_masks": pred_masks,
|
975 |
+
"obj_ptr": obj_ptr,
|
976 |
+
"object_score_logits": object_score_logits,
|
977 |
+
}
|
978 |
+
return compact_current_out, pred_masks_gpu
|
979 |
+
|
980 |
+
def _run_memory_encoder(
|
981 |
+
self,
|
982 |
+
inference_state,
|
983 |
+
frame_idx,
|
984 |
+
batch_size,
|
985 |
+
high_res_masks,
|
986 |
+
object_score_logits,
|
987 |
+
is_mask_from_pts,
|
988 |
+
):
|
989 |
+
"""
|
990 |
+
Run the memory encoder on `high_res_masks`. This is usually after applying
|
991 |
+
non-overlapping constraints to object scores. Since their scores changed, their
|
992 |
+
memory also need to be computed again with the memory encoder.
|
993 |
+
"""
|
994 |
+
# Retrieve correct image features
|
995 |
+
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
|
996 |
+
inference_state, frame_idx, batch_size
|
997 |
+
)
|
998 |
+
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
999 |
+
current_vision_feats=current_vision_feats,
|
1000 |
+
feat_sizes=feat_sizes,
|
1001 |
+
pred_masks_high_res=high_res_masks,
|
1002 |
+
object_score_logits=object_score_logits,
|
1003 |
+
is_mask_from_pts=is_mask_from_pts,
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
# optionally offload the output to CPU memory to save GPU space
|
1007 |
+
storage_device = inference_state["storage_device"]
|
1008 |
+
maskmem_features = maskmem_features.to(torch.bfloat16)
|
1009 |
+
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
1010 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
1011 |
+
maskmem_pos_enc = self._get_maskmem_pos_enc(
|
1012 |
+
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
|
1013 |
+
)
|
1014 |
+
return maskmem_features, maskmem_pos_enc
|
1015 |
+
|
1016 |
+
def _get_maskmem_pos_enc(self, inference_state, current_out):
|
1017 |
+
"""
|
1018 |
+
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
1019 |
+
a constant in the inference session to reduce session storage size.
|
1020 |
+
"""
|
1021 |
+
model_constants = inference_state["constants"]
|
1022 |
+
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
1023 |
+
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
1024 |
+
if out_maskmem_pos_enc is not None:
|
1025 |
+
if "maskmem_pos_enc" not in model_constants:
|
1026 |
+
assert isinstance(out_maskmem_pos_enc, list)
|
1027 |
+
# only take the slice for one object, since it's same across objects
|
1028 |
+
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
1029 |
+
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
1030 |
+
else:
|
1031 |
+
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
1032 |
+
# expand the cached maskmem_pos_enc to the actual batch size
|
1033 |
+
batch_size = out_maskmem_pos_enc[0].size(0)
|
1034 |
+
expanded_maskmem_pos_enc = [
|
1035 |
+
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
|
1036 |
+
]
|
1037 |
+
else:
|
1038 |
+
expanded_maskmem_pos_enc = None
|
1039 |
+
return expanded_maskmem_pos_enc
|
1040 |
+
|
1041 |
+
@torch.inference_mode()
|
1042 |
+
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
|
1043 |
+
"""
|
1044 |
+
Remove an object id from the tracking state. If strict is True, we check whether
|
1045 |
+
the object id actually exists and raise an error if it doesn't exist.
|
1046 |
+
"""
|
1047 |
+
old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
|
1048 |
+
updated_frames = []
|
1049 |
+
# Check whether this object_id to remove actually exists and possibly raise an error.
|
1050 |
+
if old_obj_idx_to_rm is None:
|
1051 |
+
if not strict:
|
1052 |
+
return inference_state["obj_ids"], updated_frames
|
1053 |
+
raise RuntimeError(
|
1054 |
+
f"Cannot remove object id {obj_id} as it doesn't exist. "
|
1055 |
+
f"All existing object ids: {inference_state['obj_ids']}."
|
1056 |
+
)
|
1057 |
+
|
1058 |
+
# If this is the only remaining object id, we simply reset the state.
|
1059 |
+
if len(inference_state["obj_id_to_idx"]) == 1:
|
1060 |
+
self.reset_state(inference_state)
|
1061 |
+
return inference_state["obj_ids"], updated_frames
|
1062 |
+
|
1063 |
+
# There are still remaining objects after removing this object id. In this case,
|
1064 |
+
# we need to delete the object storage from inference state tensors.
|
1065 |
+
# Step 0: clear the input on those frames where this object id has point or mask input
|
1066 |
+
# (note that this step is required as it might downgrade conditioning frames to
|
1067 |
+
# non-conditioning ones)
|
1068 |
+
obj_input_frames_inds = set()
|
1069 |
+
obj_input_frames_inds.update(
|
1070 |
+
inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
|
1071 |
+
)
|
1072 |
+
obj_input_frames_inds.update(
|
1073 |
+
inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
|
1074 |
+
)
|
1075 |
+
for frame_idx in obj_input_frames_inds:
|
1076 |
+
self.clear_all_prompts_in_frame(
|
1077 |
+
inference_state, frame_idx, obj_id, need_output=False
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
# Step 1: Update the object id mapping (note that it must be done after Step 0,
|
1081 |
+
# since Step 0 still requires the old object id mappings in inference_state)
|
1082 |
+
old_obj_ids = inference_state["obj_ids"]
|
1083 |
+
old_obj_inds = list(range(len(old_obj_ids)))
|
1084 |
+
remain_old_obj_inds = old_obj_inds.copy()
|
1085 |
+
remain_old_obj_inds.remove(old_obj_idx_to_rm)
|
1086 |
+
new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
|
1087 |
+
new_obj_inds = list(range(len(new_obj_ids)))
|
1088 |
+
# build new mappings
|
1089 |
+
old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
|
1090 |
+
inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
|
1091 |
+
inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
|
1092 |
+
inference_state["obj_ids"] = new_obj_ids
|
1093 |
+
|
1094 |
+
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
1095 |
+
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
|
1096 |
+
# it's already handled in Step 0)
|
1097 |
+
def _map_keys(container):
|
1098 |
+
new_kvs = []
|
1099 |
+
for k in old_obj_inds:
|
1100 |
+
v = container.pop(k)
|
1101 |
+
if k in old_idx_to_new_idx:
|
1102 |
+
new_kvs.append((old_idx_to_new_idx[k], v))
|
1103 |
+
container.update(new_kvs)
|
1104 |
+
|
1105 |
+
_map_keys(inference_state["point_inputs_per_obj"])
|
1106 |
+
_map_keys(inference_state["mask_inputs_per_obj"])
|
1107 |
+
_map_keys(inference_state["output_dict_per_obj"])
|
1108 |
+
_map_keys(inference_state["temp_output_dict_per_obj"])
|
1109 |
+
|
1110 |
+
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
|
1111 |
+
def _slice_state(output_dict, storage_key):
|
1112 |
+
for frame_idx, out in output_dict[storage_key].items():
|
1113 |
+
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
|
1114 |
+
out["maskmem_pos_enc"] = [
|
1115 |
+
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
|
1116 |
+
]
|
1117 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
1118 |
+
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
|
1119 |
+
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
|
1120 |
+
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
|
1121 |
+
out["object_score_logits"] = out["object_score_logits"][
|
1122 |
+
remain_old_obj_inds
|
1123 |
+
]
|
1124 |
+
# also update the per-object slices
|
1125 |
+
self._add_output_per_object(
|
1126 |
+
inference_state, frame_idx, out, storage_key
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
|
1130 |
+
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
|
1131 |
+
|
1132 |
+
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
1133 |
+
# could show an updated mask for objects previously occluded by the object being removed
|
1134 |
+
if need_output:
|
1135 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
1136 |
+
for frame_idx in obj_input_frames_inds:
|
1137 |
+
is_cond = any(
|
1138 |
+
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
1139 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
1140 |
+
)
|
1141 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
1142 |
+
inference_state,
|
1143 |
+
frame_idx,
|
1144 |
+
is_cond=is_cond,
|
1145 |
+
run_mem_encoder=False,
|
1146 |
+
consolidate_at_video_res=True,
|
1147 |
+
)
|
1148 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
1149 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
1150 |
+
)
|
1151 |
+
updated_frames.append((frame_idx, video_res_masks))
|
1152 |
+
|
1153 |
+
return inference_state["obj_ids"], updated_frames
|
1154 |
+
|
1155 |
+
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
|
1156 |
+
"""
|
1157 |
+
Remove the non-conditioning memory around the input frame. When users provide
|
1158 |
+
correction clicks, the surrounding frames' non-conditioning memories can still
|
1159 |
+
contain outdated object appearance information and could confuse the model.
|
1160 |
+
|
1161 |
+
This method clears those non-conditioning memories surrounding the interacted
|
1162 |
+
frame to avoid giving the model both old and new information about the object.
|
1163 |
+
"""
|
1164 |
+
r = self.memory_temporal_stride_for_eval
|
1165 |
+
frame_idx_begin = frame_idx - r * self.num_maskmem
|
1166 |
+
frame_idx_end = frame_idx + r * self.num_maskmem
|
1167 |
+
output_dict = inference_state["output_dict"]
|
1168 |
+
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
|
1169 |
+
for t in range(frame_idx_begin, frame_idx_end + 1):
|
1170 |
+
non_cond_frame_outputs.pop(t, None)
|
1171 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
1172 |
+
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
clone-IDEA-Research/Grounded-SAM-2/setup.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import os
|
7 |
+
|
8 |
+
from setuptools import find_packages, setup
|
9 |
+
|
10 |
+
# Package metadata
|
11 |
+
NAME = "SAM-2"
|
12 |
+
VERSION = "1.0"
|
13 |
+
DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
|
14 |
+
URL = "https://github.com/facebookresearch/sam2"
|
15 |
+
AUTHOR = "Meta AI"
|
16 |
+
AUTHOR_EMAIL = "[email protected]"
|
17 |
+
LICENSE = "Apache 2.0"
|
18 |
+
|
19 |
+
# Read the contents of README file
|
20 |
+
with open("README.md", "r", encoding="utf-8") as f:
|
21 |
+
LONG_DESCRIPTION = f.read()
|
22 |
+
|
23 |
+
# Required dependencies
|
24 |
+
REQUIRED_PACKAGES = [
|
25 |
+
"torch>=2.3.1",
|
26 |
+
"torchvision>=0.18.1",
|
27 |
+
"numpy>=1.24.4",
|
28 |
+
"tqdm>=4.66.1",
|
29 |
+
"hydra-core>=1.3.2",
|
30 |
+
"iopath>=0.1.10",
|
31 |
+
"pillow>=9.4.0",
|
32 |
+
]
|
33 |
+
|
34 |
+
EXTRA_PACKAGES = {
|
35 |
+
"notebooks": [
|
36 |
+
"matplotlib>=3.9.1",
|
37 |
+
"jupyter>=1.0.0",
|
38 |
+
"opencv-python>=4.7.0",
|
39 |
+
"eva-decord>=0.6.1",
|
40 |
+
],
|
41 |
+
"interactive-demo": [
|
42 |
+
"Flask>=3.0.3",
|
43 |
+
"Flask-Cors>=5.0.0",
|
44 |
+
"av>=13.0.0",
|
45 |
+
"dataclasses-json>=0.6.7",
|
46 |
+
"eva-decord>=0.6.1",
|
47 |
+
"gunicorn>=23.0.0",
|
48 |
+
"imagesize>=1.4.1",
|
49 |
+
"pycocotools>=2.0.8",
|
50 |
+
"strawberry-graphql>=0.243.0",
|
51 |
+
],
|
52 |
+
"dev": [
|
53 |
+
"black==24.2.0",
|
54 |
+
"usort==1.0.2",
|
55 |
+
"ufmt==2.0.0b2",
|
56 |
+
"fvcore>=0.1.5.post20221221",
|
57 |
+
"pandas>=2.2.2",
|
58 |
+
"scikit-image>=0.24.0",
|
59 |
+
"tensorboard>=2.17.0",
|
60 |
+
"pycocotools>=2.0.8",
|
61 |
+
"tensordict>=0.5.0",
|
62 |
+
"opencv-python>=4.7.0",
|
63 |
+
"submitit>=1.5.1",
|
64 |
+
],
|
65 |
+
}
|
66 |
+
|
67 |
+
# By default, we also build the SAM 2 CUDA extension.
|
68 |
+
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
|
69 |
+
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
|
70 |
+
# By default, we allow SAM 2 installation to proceed even with build errors.
|
71 |
+
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
|
72 |
+
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
73 |
+
|
74 |
+
# Catch and skip errors during extension building and print a warning message
|
75 |
+
# (note that this message only shows up under verbose build mode
|
76 |
+
# "pip install -v -e ." or "python setup.py build_ext -v")
|
77 |
+
CUDA_ERROR_MSG = (
|
78 |
+
"{}\n\n"
|
79 |
+
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
80 |
+
"You can still use SAM 2 and it's OK to ignore the error above, although some "
|
81 |
+
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
|
82 |
+
"(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
def get_extensions():
|
87 |
+
if not BUILD_CUDA:
|
88 |
+
return []
|
89 |
+
|
90 |
+
try:
|
91 |
+
from torch.utils.cpp_extension import CUDAExtension
|
92 |
+
|
93 |
+
srcs = ["sam2/csrc/connected_components.cu"]
|
94 |
+
compile_args = {
|
95 |
+
"cxx": [],
|
96 |
+
"nvcc": [
|
97 |
+
"-DCUDA_HAS_FP16=1",
|
98 |
+
"-D__CUDA_NO_HALF_OPERATORS__",
|
99 |
+
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
100 |
+
"-D__CUDA_NO_HALF2_OPERATORS__",
|
101 |
+
],
|
102 |
+
}
|
103 |
+
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
104 |
+
except Exception as e:
|
105 |
+
if BUILD_ALLOW_ERRORS:
|
106 |
+
print(CUDA_ERROR_MSG.format(e))
|
107 |
+
ext_modules = []
|
108 |
+
else:
|
109 |
+
raise e
|
110 |
+
|
111 |
+
return ext_modules
|
112 |
+
|
113 |
+
|
114 |
+
try:
|
115 |
+
from torch.utils.cpp_extension import BuildExtension
|
116 |
+
|
117 |
+
class BuildExtensionIgnoreErrors(BuildExtension):
|
118 |
+
|
119 |
+
def finalize_options(self):
|
120 |
+
try:
|
121 |
+
super().finalize_options()
|
122 |
+
except Exception as e:
|
123 |
+
print(CUDA_ERROR_MSG.format(e))
|
124 |
+
self.extensions = []
|
125 |
+
|
126 |
+
def build_extensions(self):
|
127 |
+
try:
|
128 |
+
super().build_extensions()
|
129 |
+
except Exception as e:
|
130 |
+
print(CUDA_ERROR_MSG.format(e))
|
131 |
+
self.extensions = []
|
132 |
+
|
133 |
+
def get_ext_filename(self, ext_name):
|
134 |
+
try:
|
135 |
+
return super().get_ext_filename(ext_name)
|
136 |
+
except Exception as e:
|
137 |
+
print(CUDA_ERROR_MSG.format(e))
|
138 |
+
self.extensions = []
|
139 |
+
return "_C.so"
|
140 |
+
|
141 |
+
cmdclass = {
|
142 |
+
"build_ext": (
|
143 |
+
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
144 |
+
if BUILD_ALLOW_ERRORS
|
145 |
+
else BuildExtension.with_options(no_python_abi_suffix=True)
|
146 |
+
)
|
147 |
+
}
|
148 |
+
except Exception as e:
|
149 |
+
cmdclass = {}
|
150 |
+
if BUILD_ALLOW_ERRORS:
|
151 |
+
print(CUDA_ERROR_MSG.format(e))
|
152 |
+
else:
|
153 |
+
raise e
|
154 |
+
|
155 |
+
|
156 |
+
# Setup configuration
|
157 |
+
setup(
|
158 |
+
name=NAME,
|
159 |
+
version=VERSION,
|
160 |
+
description=DESCRIPTION,
|
161 |
+
long_description=LONG_DESCRIPTION,
|
162 |
+
long_description_content_type="text/markdown",
|
163 |
+
url=URL,
|
164 |
+
author=AUTHOR,
|
165 |
+
author_email=AUTHOR_EMAIL,
|
166 |
+
license=LICENSE,
|
167 |
+
packages=find_packages(exclude="notebooks"),
|
168 |
+
include_package_data=True,
|
169 |
+
install_requires=REQUIRED_PACKAGES,
|
170 |
+
extras_require=EXTRA_PACKAGES,
|
171 |
+
python_requires=">=3.10.0",
|
172 |
+
ext_modules=get_extensions(),
|
173 |
+
cmdclass=cmdclass,
|
174 |
+
)
|
clone-IDEA-Research/Grounded-Segment-Anything/.gitignore
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# checkpoint
|
132 |
+
*.pth
|
133 |
+
outputs/
|
134 |
+
|
135 |
+
.idea/
|
clone-IDEA-Research/Grounded-Segment-Anything/.gitmodules
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
[submodule "grounded-sam-osx"]
|
3 |
+
path = grounded-sam-osx
|
4 |
+
url = https://github.com/linjing7/grounded-sam-osx.git
|
5 |
+
[submodule "VISAM"]
|
6 |
+
path = VISAM
|
7 |
+
url = https://github.com/BingfengYan/VISAM
|
clone-IDEA-Research/Grounded-Segment-Anything/CITATION.cff
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.2.0
|
2 |
+
message: "If you use this software, please cite it as below."
|
3 |
+
authors:
|
4 |
+
- name: "Grounded-SAM Contributors"
|
5 |
+
title: "Grounded-Segment-Anything"
|
6 |
+
date-released: 2023-04-06
|
7 |
+
url: "https://github.com/IDEA-Research/Grounded-Segment-Anything"
|
8 |
+
license: Apache-2.0
|
clone-IDEA-Research/Grounded-Segment-Anything/Dockerfile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
|
2 |
+
|
3 |
+
# Arguments to build Docker Image using CUDA
|
4 |
+
ARG USE_CUDA=0
|
5 |
+
ARG TORCH_ARCH=
|
6 |
+
|
7 |
+
ENV AM_I_DOCKER True
|
8 |
+
ENV BUILD_WITH_CUDA "${USE_CUDA}"
|
9 |
+
ENV TORCH_CUDA_ARCH_LIST "${TORCH_ARCH}"
|
10 |
+
ENV CUDA_HOME /usr/local/cuda-11.6/
|
11 |
+
|
12 |
+
RUN mkdir -p /home/appuser/Grounded-Segment-Anything
|
13 |
+
COPY . /home/appuser/Grounded-Segment-Anything/
|
14 |
+
|
15 |
+
RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
|
16 |
+
libsm6=2:* libxext6=2:* git=1:* nano=2.* \
|
17 |
+
vim=2:* -y \
|
18 |
+
&& apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
|
19 |
+
|
20 |
+
WORKDIR /home/appuser/Grounded-Segment-Anything
|
21 |
+
RUN python -m pip install --no-cache-dir -e segment_anything
|
22 |
+
|
23 |
+
# When using build isolation, PyTorch with newer CUDA is installed and can't compile GroundingDINO
|
24 |
+
RUN python -m pip install --no-cache-dir wheel
|
25 |
+
RUN python -m pip install --no-cache-dir --no-build-isolation -e GroundingDINO
|
26 |
+
|
27 |
+
WORKDIR /home/appuser
|
28 |
+
RUN pip install --no-cache-dir diffusers[torch]==0.15.1 opencv-python==4.7.0.72 \
|
29 |
+
pycocotools==2.0.6 matplotlib==3.5.3 \
|
30 |
+
onnxruntime==1.14.1 onnx==1.13.1 ipykernel==6.16.2 scipy gradio openai
|
clone-IDEA-Research/Grounded-Segment-Anything/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2020 - present, IDEA, Inc
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
clone-IDEA-Research/Grounded-Segment-Anything/Makefile
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Get version of CUDA and enable it for compilation if CUDA > 11.0
|
2 |
+
# This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
|
3 |
+
# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
|
4 |
+
# when running in Docker
|
5 |
+
# Check if nvcc is installed
|
6 |
+
NVCC := $(shell which nvcc)
|
7 |
+
ifeq ($(NVCC),)
|
8 |
+
# NVCC not found
|
9 |
+
USE_CUDA := 0
|
10 |
+
NVCC_VERSION := "not installed"
|
11 |
+
else
|
12 |
+
NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
|
13 |
+
USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
|
14 |
+
endif
|
15 |
+
|
16 |
+
# Add the list of supported ARCHs
|
17 |
+
ifeq ($(USE_CUDA), 1)
|
18 |
+
TORCH_CUDA_ARCH_LIST := "3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
|
19 |
+
BUILD_MESSAGE := "I will try to build the image with CUDA support"
|
20 |
+
else
|
21 |
+
TORCH_CUDA_ARCH_LIST :=
|
22 |
+
BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
|
23 |
+
endif
|
24 |
+
|
25 |
+
|
26 |
+
build-image:
|
27 |
+
@echo $(BUILD_MESSAGE)
|
28 |
+
docker build --build-arg USE_CUDA=$(USE_CUDA) \
|
29 |
+
--build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
|
30 |
+
-t gsa:v0 .
|
31 |
+
run:
|
32 |
+
ifeq (,$(wildcard ./sam_vit_h_4b8939.pth))
|
33 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
34 |
+
endif
|
35 |
+
ifeq (,$(wildcard ./groundingdino_swint_ogc.pth))
|
36 |
+
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
37 |
+
endif
|
38 |
+
docker run --gpus all -it --rm --net=host --privileged \
|
39 |
+
-v /tmp/.X11-unix:/tmp/.X11-unix \
|
40 |
+
-v "${PWD}":/home/appuser/Grounded-Segment-Anything \
|
41 |
+
-e DISPLAY=$DISPLAY \
|
42 |
+
--name=gsa \
|
43 |
+
--ipc=host -it gsa:v0
|
clone-IDEA-Research/Grounded-Segment-Anything/README.md
ADDED
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+

|
2 |
+
|
3 |
+
# Grounded-Segment-Anything
|
4 |
+
[](https://youtu.be/oEQYStnF2l8) [](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/automated-dataset-annotation-and-evaluation-with-grounding-dino-and-sam.ipynb) [](https://github.com/camenduru/grounded-segment-anything-colab) [](https://huggingface.co/spaces/IDEA-Research/Grounded-SAM) [](https://replicate.com/cjwbw/grounded-recognize-anything) [](https://modelscope.cn/studios/tuofeilunhifi/Grounded-Segment-Anything/summary) [](https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything) [](https://github.com/continue-revolution/sd-webui-segment-anything) [](./grounded_sam.ipynb) [](https://arxiv.org/abs/2303.05499) [](https://arxiv.org/abs/2304.02643) [](https://arxiv.org/abs/2401.14159)
|
5 |
+
|
6 |
+
|
7 |
+
We plan to create a very interesting demo by combining [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment Anything](https://github.com/facebookresearch/segment-anything) which aims to detect and segment anything with text inputs! And we will continue to improve it and create more interesting demos based on this foundation. And we have already released an overall technical report about our project on arXiv, please check [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159) for more details.
|
8 |
+
|
9 |
+
- 🔥 **[Grounded SAM 2](https://github.com/IDEA-Research/Grounded-SAM-2)** is released now, which combines Grounding DINO with [SAM 2](https://github.com/facebookresearch/segment-anything-2) for any object tracking in open-world scenarios.
|
10 |
+
- 🔥 **[Grounding DINO 1.5](https://github.com/IDEA-Research/Grounding-DINO-1.5-API)** is released now, which is IDEA Research's **Most Capable** Open-World Object Detection Model!
|
11 |
+
- 🔥 **[Grounding DINO](https://arxiv.org/abs/2303.05499)** and **[Grounded SAM](https://arxiv.org/abs/2401.14159)** are now supported in Huggingface. For more convenient use, you can refer to [this documentation](https://huggingface.co/docs/transformers/model_doc/grounding-dino)
|
12 |
+
|
13 |
+
We are very willing to **help everyone share and promote new projects** based on Segment-Anything, Please check out here for more amazing demos and works in the community: [Highlight Extension Projects](#highlighted-projects). You can submit a new issue (with `project` tag) or a new pull request to add new project's links.
|
14 |
+
|
15 |
+

|
16 |
+
|
17 |
+

|
18 |
+
|
19 |
+
**🍄 Why Building this Project?**
|
20 |
+
|
21 |
+
The **core idea** behind this project is to **combine the strengths of different models in order to build a very powerful pipeline for solving complex problems**. And it's worth mentioning that this is a workflow for combining strong expert models, where **all parts can be used separately or in combination, and can be replaced with any similar but different models (like replacing Grounding DINO with GLIP or other detectors / replacing Stable-Diffusion with ControlNet or GLIGEN/ Combining with ChatGPT)**.
|
22 |
+
|
23 |
+
**🍇 Updates**
|
24 |
+
- **`2024/01/26`** We have released a comprehensive technical report about our project on arXiv, please check [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159) for more details. And we are profoundly grateful for the contributions of all the contributors in this project.
|
25 |
+
- **`2023/12/17`** Support [Grounded-RepViT-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-repvit-sam-demo) demo, thanks a lot for their great work!
|
26 |
+
- **`2023/12/16`** Support [Grounded-Edge-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-edge-sam-demo) demo, thanks a lot for their great work!
|
27 |
+
- **`2023/12/10`** Support [Grounded-Efficient-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-efficient-sam-demo) demo, thanks a lot for their great work!
|
28 |
+
- **`2023/11/24`** Release [RAM++](https://arxiv.org/abs/2310.15200), which is the next generation of RAM. RAM++ can recognize any category with high accuracy, including both predefined common categories and diverse open-set categories.
|
29 |
+
- **`2023/11/23`** Release our newly proposed visual prompt counting model [T-Rex](https://github.com/IDEA-Research/T-Rex). The introduction [Video](https://www.youtube.com/watch?v=engIEhZogAQ) and [Demo](https://deepdataspace.com/playground/ivp) is available in [DDS](https://github.com/IDEA-Research/deepdataspace) now.
|
30 |
+
- **`2023/07/25`** Support [Light-HQ-SAM](https://github.com/SysCV/sam-hq) in [EfficientSAM](./EfficientSAM/), credits to [Mingqiao Ye](https://github.com/ymq2017) and [Lei Ke](https://github.com/lkeab), thanks a lot for their great work!
|
31 |
+
- **`2023/07/14`** Combining **Grounding-DINO-B** with [SAM-HQ](https://github.com/SysCV/sam-hq) achieves **49.6 mean AP** in [Segmentation in the Wild](https://eval.ai/web/challenges/challenge-page/1931/overview) competition zero-shot track, surpassing Grounded-SAM by **3.6 mean AP**, thanks for their great work!
|
32 |
+
- **`2023/06/28`** Combining Grounding-DINO with Efficient SAM variants including [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) and [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) in [EfficientSAM](./EfficientSAM/) for faster annotating, thanks a lot for their great work!
|
33 |
+
- **`2023/06/20`** By combining **Grounding-DINO-L** with **SAM-ViT-H**, Grounded-SAM achieves 46.0 mean AP in [Segmentation in the Wild](https://eval.ai/web/challenges/challenge-page/1931/overview) competition zero-shot track on [CVPR 2023 workshop](https://computer-vision-in-the-wild.github.io/cvpr-2023/), surpassing [UNINEXT (CVPR 2023)](https://github.com/MasterBin-IIAU/UNINEXT) by about **4 mean AP**.
|
34 |
+
- **`2023/06/16`** Release [RAM-Grounded-SAM Replicate Online Demo](https://replicate.com/cjwbw/ram-grounded-sam). Thanks a lot to [Chenxi](https://chenxwh.github.io/) for providing this nice demo 🌹.
|
35 |
+
- **`2023/06/14`** Support [RAM-Grounded-SAM & SAM-HQ](./automatic_label_ram_demo.py) and update [Simple Automatic Label Demo](./automatic_label_ram_demo.py) to support [RAM](https://github.com/OPPOMKLab/recognize-anything), setting up a strong automatic annotation pipeline.
|
36 |
+
- **`2023/06/13`** Checkout the [Autodistill: Train YOLOv8 with ZERO Annotations](https://youtu.be/gKTYMfwPo4M) tutorial to learn how to use Grounded-SAM + [Autodistill](https://github.com/autodistill/autodistill) for automated data labeling and real-time model training.
|
37 |
+
- **`2023/06/13`** Support [SAM-HQ](https://github.com/SysCV/sam-hq) in [Grounded-SAM Demo](#running_man-grounded-sam-detect-and-segment-everything-with-text-prompt) for higher quality prediction.
|
38 |
+
- **`2023/06/12`** Support [RAM-Grounded-SAM](#label-grounded-sam-with-ram-or-tag2text-for-automatic-labeling) for strong automatic labeling pipeline! Thanks for [Recognize-Anything](https://github.com/OPPOMKLab/recognize-anything).
|
39 |
+
- **`2023/06/01`** Our Grounded-SAM has been accepted to present a **demo** at [ICCV 2023](https://iccv2023.thecvf.com/)! See you in Paris!
|
40 |
+
- **`2023/05/23`**: Support `Image-Referring-Segment`, `Audio-Referring-Segment` and `Text-Referring-Segment` in [ImageBind-SAM](./playground/ImageBind_SAM/).
|
41 |
+
- **`2023/05/03`**: Checkout the [Automated Dataset Annotation and Evaluation with GroundingDINO and SAM](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/automated-dataset-annotation-and-evaluation-with-grounding-dino-and-sam.ipynb) which is an amazing tutorial on automatic labeling! Thanks a lot for [Piotr Skalski](https://github.com/SkalskiP) and [Roboflow](https://github.com/roboflow/notebooks)!
|
42 |
+
|
43 |
+
|
44 |
+
## Table of Contents
|
45 |
+
- [Grounded-Segment-Anything](#grounded-segment-anything)
|
46 |
+
- [Preliminary Works](#preliminary-works)
|
47 |
+
- [Highlighted Projects](#highlighted-projects)
|
48 |
+
- [Installation](#installation)
|
49 |
+
- [Install with Docker](#install-with-docker)
|
50 |
+
- [Install locally](#install-without-docker)
|
51 |
+
- [Grounded-SAM Playground](#grounded-sam-playground)
|
52 |
+
- [Step-by-Step Notebook Demo](#open_book-step-by-step-notebook-demo)
|
53 |
+
- [GroundingDINO: Detect Everything with Text Prompt](#running_man-groundingdino-detect-everything-with-text-prompt)
|
54 |
+
- [Grounded-SAM: Detect and Segment Everything with Text Prompt](#running_man-grounded-sam-detect-and-segment-everything-with-text-prompt)
|
55 |
+
- [Grounded-SAM with Inpainting: Detect, Segment and Generate Everything with Text Prompt](#skier-grounded-sam-with-inpainting-detect-segment-and-generate-everything-with-text-prompt)
|
56 |
+
- [Grounded-SAM and Inpaint Gradio APP](#golfing-grounded-sam-and-inpaint-gradio-app)
|
57 |
+
- [Grounded-SAM with RAM or Tag2Text for Automatic Labeling](#label-grounded-sam-with-ram-or-tag2text-for-automatic-labeling)
|
58 |
+
- [Grounded-SAM with BLIP & ChatGPT for Automatic Labeling](#robot-grounded-sam-with-blip-for-automatic-labeling)
|
59 |
+
- [Grounded-SAM with Whisper: Detect and Segment Anything with Audio](#open_mouth-grounded-sam-with-whisper-detect-and-segment-anything-with-audio)
|
60 |
+
- [Grounded-SAM ChatBot with Visual ChatGPT](#speech_balloon-grounded-sam-chatbot-demo)
|
61 |
+
- [Grounded-SAM with OSX for 3D Whole-Body Mesh Recovery](#man_dancing-run-grounded-segment-anything--osx-demo)
|
62 |
+
- [Grounded-SAM with VISAM for Tracking and Segment Anything](#man_dancing-run-grounded-segment-anything--visam-demo)
|
63 |
+
- [Interactive Fashion-Edit Playground: Click for Segmentation And Editing](#dancers-interactive-editing)
|
64 |
+
- [Interactive Human-face Editing Playground: Click And Editing Human Face](#dancers-interactive-editing)
|
65 |
+
- [3D Box Via Segment Anything](#camera-3d-box-via-segment-anything)
|
66 |
+
- [Playground: More Interesting and Imaginative Demos with Grounded-SAM](./playground/)
|
67 |
+
- [DeepFloyd: Image Generation with Text Prompt](./playground/DeepFloyd/)
|
68 |
+
- [PaintByExample: Exemplar-based Image Editing with Diffusion Models](./playground/PaintByExample/)
|
69 |
+
- [LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions](./playground/LaMa/)
|
70 |
+
- [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](./playground/RePaint/)
|
71 |
+
- [ImageBind with SAM: Segment with Different Modalities](./playground/ImageBind_SAM/)
|
72 |
+
- [Efficient SAM Series for Faster Annotation](./EfficientSAM/)
|
73 |
+
- [Grounded-FastSAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-fastsam-demo)
|
74 |
+
- [Grounded-MobileSAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-mobilesam-demo)
|
75 |
+
- [Grounded-Light-HQSAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-light-hqsam-demo)
|
76 |
+
- [Grounded-Efficient-SAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-efficient-sam-demo)
|
77 |
+
- [Grounded-Edge-SAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-edge-sam-demo)
|
78 |
+
- [Grounded-RepViT-SAM Demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM#run-grounded-repvit-sam-demo)
|
79 |
+
- [Citation](#citation)
|
80 |
+
|
81 |
+
## Preliminary Works
|
82 |
+
|
83 |
+
Here we provide some background knowledge that you may need to know before trying the demos.
|
84 |
+
|
85 |
+
<div align="center">
|
86 |
+
|
87 |
+
| Title | Intro | Description | Links |
|
88 |
+
|:----:|:----:|:----:|:----:|
|
89 |
+
| [Segment-Anything](https://arxiv.org/abs/2304.02643) |  | A strong foundation model aims to segment everything in an image, which needs prompts (as boxes/points/text) to generate masks | [[Github](https://github.com/facebookresearch/segment-anything)] <br> [[Page](https://segment-anything.com/)] <br> [[Demo](https://segment-anything.com/demo)] |
|
90 |
+
| [Grounding DINO](https://arxiv.org/abs/2303.05499) |  | A strong zero-shot detector which is capable of to generate high quality boxes and labels with free-form text. | [[Github](https://github.com/IDEA-Research/GroundingDINO)] <br> [[Demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)] |
|
91 |
+
| [OSX](http://arxiv.org/abs/2303.16160) |  | A strong and efficient one-stage motion capture method to generate high quality 3D human mesh from monucular image. OSX also releases a large-scale upper-body dataset UBody for a more accurate reconstrution in the upper-body scene. | [[Github](https://github.com/IDEA-Research/OSX)] <br> [[Page](https://osx-ubody.github.io/)] <br> [[Video](https://osx-ubody.github.io/)] <br> [[Data](https://docs.google.com/forms/d/e/1FAIpQLSehgBP7wdn_XznGAM2AiJPiPLTqXXHw5uX9l7qeQ1Dh9HoO_A/viewform)] |
|
92 |
+
| [Stable-Diffusion](https://arxiv.org/abs/2112.10752) |  | A super powerful open-source latent text-to-image diffusion model | [[Github](https://github.com/CompVis/stable-diffusion)] <br> [[Page](https://ommer-lab.com/research/latent-diffusion-models/)] |
|
93 |
+
| [RAM++](https://arxiv.org/abs/2310.15200) |  | RAM++ is the next generation of RAM, which can recognize any category with high accuracy. | [[Github](https://github.com/OPPOMKLab/recognize-anything)] |
|
94 |
+
| [RAM](https://recognize-anything.github.io/) |  | RAM is an image tagging model, which can recognize any common category with high accuracy. | [[Github](https://github.com/OPPOMKLab/recognize-anything)] <br> [[Demo](https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text)] |
|
95 |
+
| [BLIP](https://arxiv.org/abs/2201.12086) |  | A wonderful language-vision model for image understanding. | [[GitHub](https://github.com/salesforce/LAVIS)] |
|
96 |
+
| [Visual ChatGPT](https://arxiv.org/abs/2303.04671) |  | A wonderful tool that connects ChatGPT and a series of Visual Foundation Models to enable sending and receiving images during chatting. | [[Github](https://github.com/microsoft/TaskMatrix)] <br> [[Demo](https://huggingface.co/spaces/microsoft/visual_chatgpt)] |
|
97 |
+
| [Tag2Text](https://tag2text.github.io/) |  | An efficient and controllable vision-language model which can simultaneously output superior image captioning and image tagging. | [[Github](https://github.com/OPPOMKLab/recognize-anything)] <br> [[Demo](https://huggingface.co/spaces/xinyu1205/Tag2Text)] |
|
98 |
+
| [VoxelNeXt](https://arxiv.org/abs/2303.11301) |  | A clean, simple, and fully-sparse 3D object detector, which predicts objects directly upon sparse voxel features. | [[Github](https://github.com/dvlab-research/VoxelNeXt)]
|
99 |
+
|
100 |
+
</div>
|
101 |
+
|
102 |
+
## Highlighted Projects
|
103 |
+
|
104 |
+
Here we provide some impressive works you may find interesting:
|
105 |
+
|
106 |
+
<div align="center">
|
107 |
+
|
108 |
+
| Title | Description | Links |
|
109 |
+
|:---:|:---:|:---:|
|
110 |
+
| [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM) | A universal image segmentation model to enable segment and recognize anything at any desired granularity | [[Github](https://github.com/UX-Decoder/Semantic-SAM)] <br> [[Demo](https://github.com/UX-Decoder/Semantic-SAM)] |
|
111 |
+
| [SEEM: Segment Everything Everywhere All at Once](https://arxiv.org/pdf/2304.06718.pdf) | A powerful promptable segmentation model supports segmenting with various types of prompts (text, point, scribble, referring image, etc.) and any combination of prompts. | [[Github](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)] <br> [[Demo](https://huggingface.co/spaces/xdecoder/SEEM)] |
|
112 |
+
| [OpenSeeD](https://arxiv.org/pdf/2303.08131.pdf) | A simple framework for open-vocabulary segmentation and detection which supports interactive segmentation with box input to generate mask | [[Github](https://github.com/IDEA-Research/OpenSeeD)] |
|
113 |
+
| [LLaVA](https://arxiv.org/abs/2304.08485) | Visual instruction tuning with GPT-4 | [[Github](https://github.com/haotian-liu/LLaVA)] <br> [[Page](https://llava-vl.github.io/)] <br> [[Demo](https://llava.hliu.cc/)] <br> [[Data](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K)] <br> [[Model](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)] |
|
114 |
+
| [GenSAM](https://arxiv.org/abs/2312.07374) | Relaxing the instance-specific manual prompt requirement in SAM through training-free test-time adaptation | [[Github](https://github.com/jyLin8100/GenSAM)] <br> [[Page](https://lwpyh.github.io/GenSAM/)] |
|
115 |
+
|
116 |
+
</div>
|
117 |
+
|
118 |
+
We also list some awesome segment-anything extension projects here you may find interesting:
|
119 |
+
- [Computer Vision in the Wild (CVinW) Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set tasks in computer vision.
|
120 |
+
- [Zero-Shot Anomaly Detection](https://github.com/caoyunkang/GroundedSAM-zero-shot-anomaly-detection) by Yunkang Cao
|
121 |
+
- [EditAnything: ControlNet + StableDiffusion based on the SAM segmentation mask](https://github.com/sail-sg/EditAnything) by Shanghua Gao and Pan Zhou
|
122 |
+
- [IEA: Image Editing Anything](https://github.com/feizc/IEA) by Zhengcong Fei
|
123 |
+
- [SAM-MMRorate: Combining Rotated Object Detector and SAM](https://github.com/Li-Qingyun/sam-mmrotate) by Qingyun Li and Xue Yang
|
124 |
+
- [Awesome-Anything](https://github.com/VainF/Awesome-Anything) by Gongfan Fang
|
125 |
+
- [Prompt-Segment-Anything](https://github.com/RockeyCoss/Prompt-Segment-Anything) by Rockey
|
126 |
+
- [WebUI for Segment-Anything and Grounded-SAM](https://github.com/continue-revolution/sd-webui-segment-anything) by Chengsong Zhang
|
127 |
+
- [Inpainting Anything: Inpaint Anything with SAM + Inpainting models](https://github.com/geekyutao/Inpaint-Anything) by Tao Yu
|
128 |
+
- [Grounded Segment Anything From Objects to Parts: Combining Segment-Anything with VLPart & GLIP & Visual ChatGPT](https://github.com/Cheems-Seminar/segment-anything-and-name-it) by Peize Sun and Shoufa Chen
|
129 |
+
- [Narapi-SAM: Integration of Segment Anything into Narapi (A nice viewer for SAM)](https://github.com/MIC-DKFZ/napari-sam) by MIC-DKFZ
|
130 |
+
- [Grounded Segment Anything Colab](https://github.com/camenduru/grounded-segment-anything-colab) by camenduru
|
131 |
+
- [Optical Character Recognition with Segment Anything](https://github.com/yeungchenwa/OCR-SAM) by Zhenhua Yang
|
132 |
+
- [Transform Image into Unique Paragraph with ChatGPT, BLIP2, OFA, GRIT, Segment Anything, ControlNet](https://github.com/showlab/Image2Paragraph) by showlab
|
133 |
+
- [Lang-Segment-Anything: Another awesome demo for combining GroundingDINO with Segment-Anything](https://github.com/luca-medeiros/lang-segment-anything) by Luca Medeiros
|
134 |
+
- [🥳 🚀 **Playground: Integrate SAM and OpenMMLab!**](https://github.com/open-mmlab/playground)
|
135 |
+
- [3D-object via Segment Anything](https://github.com/dvlab-research/3D-Box-Segment-Anything) by Yukang Chen
|
136 |
+
- [Image2Paragraph: Transform Image Into Unique Paragraph](https://github.com/showlab/Image2Paragraph) by Show Lab
|
137 |
+
- [Zero-shot Scene Graph Generate with Grounded-SAM](https://github.com/showlab/Image2Paragraph) by JackWhite-rwx
|
138 |
+
- [CLIP Surgery for Better Explainability with Enhancement in Open-Vocabulary Tasks](https://github.com/xmed-lab/CLIP_Surgery) by Eli-YiLi
|
139 |
+
- [Panoptic-Segment-Anything: Zero-shot panoptic segmentation using SAM](https://github.com/segments-ai/panoptic-segment-anything) by segments-ai
|
140 |
+
- [Caption-Anything: Generates Descriptive Captions for Any Object within an Image](https://github.com/ttengwang/Caption-Anything) by Teng Wang
|
141 |
+
- [Segment-Anything-3D: Transferring Segmentation Information of 2D Images to 3D Space](https://github.com/Pointcept/SegmentAnything3D) by Yunhan Yang
|
142 |
+
- [Expediting SAM without Fine-tuning](https://github.com/Expedit-LargeScale-Vision-Transformer/Expedit-SAM) by Weicong Liang and Yuhui Yuan
|
143 |
+
- [Semantic Segment Anything: Providing Rich Semantic Category Annotations for SAM](https://github.com/fudan-zvg/Semantic-Segment-Anything) by Jiaqi Chen and Zeyu Yang and Li Zhang
|
144 |
+
- [Enhance Everything: Combining SAM with Image Restoration and Enhancement Tasks](https://github.com/lixinustc/Enhance-Anything) by Xin Li
|
145 |
+
- [DragGAN](https://github.com/Zeqiang-Lai/DragGAN) by Shanghai AI Lab.
|
146 |
+
- [Tabletop HandyBot: Robotic arm assistant that performs tabletop tasks using Grounded-SAM](https://github.com/ycheng517/tabletop-handybot) by Yifei Cheng
|
147 |
+
|
148 |
+
## Installation
|
149 |
+
The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
|
150 |
+
|
151 |
+
### Install with Docker
|
152 |
+
|
153 |
+
Open one terminal:
|
154 |
+
|
155 |
+
```
|
156 |
+
make build-image
|
157 |
+
```
|
158 |
+
|
159 |
+
```
|
160 |
+
make run
|
161 |
+
```
|
162 |
+
|
163 |
+
That's it.
|
164 |
+
|
165 |
+
If you would like to allow visualization across docker container, open another terminal and type:
|
166 |
+
|
167 |
+
```
|
168 |
+
xhost +
|
169 |
+
```
|
170 |
+
|
171 |
+
|
172 |
+
### Install without Docker
|
173 |
+
You should set the environment variable manually as follows if you want to build a local GPU environment for Grounded-SAM:
|
174 |
+
```bash
|
175 |
+
export AM_I_DOCKER=False
|
176 |
+
export BUILD_WITH_CUDA=True
|
177 |
+
export CUDA_HOME=/path/to/cuda-11.3/
|
178 |
+
```
|
179 |
+
|
180 |
+
Install Segment Anything:
|
181 |
+
|
182 |
+
```bash
|
183 |
+
python -m pip install -e segment_anything
|
184 |
+
```
|
185 |
+
|
186 |
+
Install Grounding DINO:
|
187 |
+
|
188 |
+
```bash
|
189 |
+
pip install --no-build-isolation -e GroundingDINO
|
190 |
+
```
|
191 |
+
|
192 |
+
|
193 |
+
Install diffusers:
|
194 |
+
|
195 |
+
```bash
|
196 |
+
pip install --upgrade diffusers[torch]
|
197 |
+
```
|
198 |
+
|
199 |
+
Install osx:
|
200 |
+
|
201 |
+
```bash
|
202 |
+
git submodule update --init --recursive
|
203 |
+
cd grounded-sam-osx && bash install.sh
|
204 |
+
```
|
205 |
+
|
206 |
+
Install RAM & Tag2Text:
|
207 |
+
|
208 |
+
```bash
|
209 |
+
git clone https://github.com/xinyu1205/recognize-anything.git
|
210 |
+
pip install -r ./recognize-anything/requirements.txt
|
211 |
+
pip install -e ./recognize-anything/
|
212 |
+
```
|
213 |
+
|
214 |
+
The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks.
|
215 |
+
|
216 |
+
```
|
217 |
+
pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel
|
218 |
+
```
|
219 |
+
|
220 |
+
More details can be found in [install segment anything](https://github.com/facebookresearch/segment-anything#installation) and [install GroundingDINO](https://github.com/IDEA-Research/GroundingDINO#install) and [install OSX](https://github.com/IDEA-Research/OSX)
|
221 |
+
|
222 |
+
|
223 |
+
## Grounded-SAM Playground
|
224 |
+
Let's start exploring our Grounding-SAM Playground and we will release more interesting demos in the future, stay tuned!
|
225 |
+
|
226 |
+
## :open_book: Step-by-Step Notebook Demo
|
227 |
+
Here we list some notebook demo provided in this project:
|
228 |
+
- [grounded_sam.ipynb](grounded_sam.ipynb)
|
229 |
+
- [grounded_sam_colab_demo.ipynb](grounded_sam_colab_demo.ipynb)
|
230 |
+
- [grounded_sam_3d_box.ipynb](grounded_sam_3d_box)
|
231 |
+
|
232 |
+
|
233 |
+
### :running_man: GroundingDINO: Detect Everything with Text Prompt
|
234 |
+
|
235 |
+
:grapes: [[arXiv Paper](https://arxiv.org/abs/2303.05499)] :rose:[[Try the Colab Demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)] :sunflower: [[Try Huggingface Demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)] :mushroom: [[Automated Dataset Annotation and Evaluation](https://youtu.be/C4NqaRBz_Kw)]
|
236 |
+
|
237 |
+
Here's the step-by-step tutorial on running `GroundingDINO` demo:
|
238 |
+
|
239 |
+
**Step 1: Download the pretrained weights**
|
240 |
+
|
241 |
+
```bash
|
242 |
+
cd Grounded-Segment-Anything
|
243 |
+
|
244 |
+
# download the pretrained groundingdino-swin-tiny model
|
245 |
+
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
246 |
+
```
|
247 |
+
|
248 |
+
**Step 2: Running the demo**
|
249 |
+
|
250 |
+
```bash
|
251 |
+
python grounding_dino_demo.py
|
252 |
+
```
|
253 |
+
|
254 |
+
<details>
|
255 |
+
<summary> <b> Running with Python (same as demo but you can run it anywhere after installing GroundingDINO) </b> </summary>
|
256 |
+
|
257 |
+
```python
|
258 |
+
from groundingdino.util.inference import load_model, load_image, predict, annotate
|
259 |
+
import cv2
|
260 |
+
|
261 |
+
model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "./groundingdino_swint_ogc.pth")
|
262 |
+
IMAGE_PATH = "assets/demo1.jpg"
|
263 |
+
TEXT_PROMPT = "bear."
|
264 |
+
BOX_THRESHOLD = 0.35
|
265 |
+
TEXT_THRESHOLD = 0.25
|
266 |
+
|
267 |
+
image_source, image = load_image(IMAGE_PATH)
|
268 |
+
|
269 |
+
boxes, logits, phrases = predict(
|
270 |
+
model=model,
|
271 |
+
image=image,
|
272 |
+
caption=TEXT_PROMPT,
|
273 |
+
box_threshold=BOX_THRESHOLD,
|
274 |
+
text_threshold=TEXT_THRESHOLD
|
275 |
+
)
|
276 |
+
|
277 |
+
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
|
278 |
+
cv2.imwrite("annotated_image.jpg", annotated_frame)
|
279 |
+
```
|
280 |
+
|
281 |
+
</details>
|
282 |
+
<br>
|
283 |
+
|
284 |
+
**Tips**
|
285 |
+
- If you want to detect multiple objects in one sentence with [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO), we suggest separating each name with `.` . An example: `cat . dog . chair .`
|
286 |
+
|
287 |
+
**Step 3: Check the annotated image**
|
288 |
+
|
289 |
+
The annotated image will be saved as `./annotated_image.jpg`.
|
290 |
+
|
291 |
+
<div align="center">
|
292 |
+
|
293 |
+
| Text Prompt | Demo Image | Annotated Image |
|
294 |
+
|:----:|:----:|:----:|
|
295 |
+
| `Bear.` |  |  |
|
296 |
+
| `Horse. Clouds. Grasses. Sky. Hill` |  | 
|
297 |
+
|
298 |
+
</div>
|
299 |
+
|
300 |
+
|
301 |
+
### :running_man: Grounded-SAM: Detect and Segment Everything with Text Prompt
|
302 |
+
|
303 |
+
Here's the step-by-step tutorial on running `Grounded-SAM` demo:
|
304 |
+
|
305 |
+
**Step 1: Download the pretrained weights**
|
306 |
+
|
307 |
+
```bash
|
308 |
+
cd Grounded-Segment-Anything
|
309 |
+
|
310 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
311 |
+
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
312 |
+
```
|
313 |
+
|
314 |
+
We provide two versions of Grounded-SAM demo here:
|
315 |
+
- [grounded_sam_demo.py](./grounded_sam_demo.py): our original implementation for Grounded-SAM.
|
316 |
+
- [grounded_sam_simple_demo.py](./grounded_sam_simple_demo.py) our updated more elegant version for Grounded-SAM.
|
317 |
+
|
318 |
+
**Step 2: Running original grounded-sam demo**
|
319 |
+
```bash
|
320 |
+
# depends on your device
|
321 |
+
export CUDA_VISIBLE_DEVICES=0
|
322 |
+
```
|
323 |
+
|
324 |
+
```python
|
325 |
+
|
326 |
+
python grounded_sam_demo.py \
|
327 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
328 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
329 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
330 |
+
--input_image assets/demo1.jpg \
|
331 |
+
--output_dir "outputs" \
|
332 |
+
--box_threshold 0.3 \
|
333 |
+
--text_threshold 0.25 \
|
334 |
+
--text_prompt "bear" \
|
335 |
+
--device "cuda"
|
336 |
+
```
|
337 |
+
|
338 |
+
The annotated results will be saved in `./outputs` as follows
|
339 |
+
|
340 |
+
<div align="center">
|
341 |
+
|
342 |
+
| Input Image | Annotated Image | Generated Mask |
|
343 |
+
|:----:|:----:|:----:|
|
344 |
+
|  |  |  |
|
345 |
+
|
346 |
+
</div>
|
347 |
+
|
348 |
+
**Step 3: Running grounded-sam demo with sam-hq**
|
349 |
+
- Download the demo image
|
350 |
+
```bash
|
351 |
+
wget https://github.com/IDEA-Research/detrex-storage/releases/download/grounded-sam-storage/sam_hq_demo_image.png
|
352 |
+
```
|
353 |
+
|
354 |
+
- Download SAM-HQ checkpoint [here](https://github.com/SysCV/sam-hq#model-checkpoints)
|
355 |
+
|
356 |
+
- Running grounded-sam-hq demo as follows:
|
357 |
+
```python
|
358 |
+
export CUDA_VISIBLE_DEVICES=0
|
359 |
+
python grounded_sam_demo.py \
|
360 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
361 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
362 |
+
--sam_hq_checkpoint ./sam_hq_vit_h.pth \ # path to sam-hq checkpoint
|
363 |
+
--use_sam_hq \ # set to use sam-hq model
|
364 |
+
--input_image sam_hq_demo_image.png \
|
365 |
+
--output_dir "outputs" \
|
366 |
+
--box_threshold 0.3 \
|
367 |
+
--text_threshold 0.25 \
|
368 |
+
--text_prompt "chair." \
|
369 |
+
--device "cuda"
|
370 |
+
```
|
371 |
+
|
372 |
+
The annotated results will be saved in `./outputs` as follows
|
373 |
+
|
374 |
+
<div align="center">
|
375 |
+
|
376 |
+
| Input Image | SAM Output | SAM-HQ Output |
|
377 |
+
|:----:|:----:|:----:|
|
378 |
+
|  |  |  |
|
379 |
+
|
380 |
+
</div>
|
381 |
+
|
382 |
+
**Step 4: Running the updated grounded-sam demo (optional)**
|
383 |
+
|
384 |
+
Note that this demo is almost same as the original demo, but **with more elegant code**.
|
385 |
+
|
386 |
+
```python
|
387 |
+
python grounded_sam_simple_demo.py
|
388 |
+
```
|
389 |
+
|
390 |
+
The annotated results will be saved as `./groundingdino_annotated_image.jpg` and `./grounded_sam_annotated_image.jpg`
|
391 |
+
|
392 |
+
<div align="center">
|
393 |
+
|
394 |
+
| Text Prompt | Input Image | GroundingDINO Annotated Image | Grounded-SAM Annotated Image |
|
395 |
+
|:----:|:----:|:----:|:----:|
|
396 |
+
| `The running dog` |  |  |  |
|
397 |
+
| `Horse. Clouds. Grasses. Sky. Hill` |  |  |  |
|
398 |
+
|
399 |
+
</div>
|
400 |
+
|
401 |
+
**Step 5: Running the Sam model with multi-gpu**
|
402 |
+
```bash
|
403 |
+
export CUDA_VISIBLE_DEVICES=0,1
|
404 |
+
```
|
405 |
+
```python
|
406 |
+
|
407 |
+
python grounded_sam_multi_gpu_demo.py \
|
408 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
409 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
410 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
411 |
+
--input_path assets/car \
|
412 |
+
--output_dir "outputs" \
|
413 |
+
--box_threshold 0.3 \
|
414 |
+
--text_threshold 0.25 \
|
415 |
+
--text_prompt "car" \
|
416 |
+
--device "cuda"
|
417 |
+
```
|
418 |
+
You will see that the model is loaded once per GPU 
|
419 |
+
|
420 |
+
### :skier: Grounded-SAM with Inpainting: Detect, Segment and Generate Everything with Text Prompt
|
421 |
+
|
422 |
+
**Step 1: Download the pretrained weights**
|
423 |
+
|
424 |
+
```bash
|
425 |
+
cd Grounded-Segment-Anything
|
426 |
+
|
427 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
428 |
+
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
429 |
+
```
|
430 |
+
|
431 |
+
**Step 2: Running grounded-sam inpainting demo**
|
432 |
+
|
433 |
+
```bash
|
434 |
+
CUDA_VISIBLE_DEVICES=0
|
435 |
+
python grounded_sam_inpainting_demo.py \
|
436 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
437 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
438 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
439 |
+
--input_image assets/inpaint_demo.jpg \
|
440 |
+
--output_dir "outputs" \
|
441 |
+
--box_threshold 0.3 \
|
442 |
+
--text_threshold 0.25 \
|
443 |
+
--det_prompt "bench" \
|
444 |
+
--inpaint_prompt "A sofa, high quality, detailed" \
|
445 |
+
--device "cuda"
|
446 |
+
```
|
447 |
+
|
448 |
+
The annotated and inpaint image will be saved in `./outputs`
|
449 |
+
|
450 |
+
**Step 3: Check the results**
|
451 |
+
|
452 |
+
|
453 |
+
<div align="center">
|
454 |
+
|
455 |
+
| Input Image | Det Prompt | Annotated Image | Inpaint Prompt | Inpaint Image |
|
456 |
+
|:---:|:---:|:---:|:---:|:---:|
|
457 |
+
| | `Bench` |  | `A sofa, high quality, detailed` |  |
|
458 |
+
|
459 |
+
</div>
|
460 |
+
|
461 |
+
### :golfing: Grounded-SAM and Inpaint Gradio APP
|
462 |
+
|
463 |
+
We support 6 tasks in the local Gradio APP:
|
464 |
+
|
465 |
+
1. **scribble**: Segmentation is achieved through Segment Anything and mouse click interaction (you need to click on the object with the mouse, no need to specify the prompt).
|
466 |
+
2. **automask**: Segment the entire image at once through Segment Anything (no need to specify a prompt).
|
467 |
+
3. **det**: Realize detection through Grounding DINO and text interaction (text prompt needs to be specified).
|
468 |
+
4. **seg**: Realize text interaction by combining Grounding DINO and Segment Anything to realize detection + segmentation (need to specify text prompt).
|
469 |
+
5. **inpainting**: By combining Grounding DINO + Segment Anything + Stable Diffusion to achieve text exchange and replace the target object (need to specify text prompt and inpaint prompt) .
|
470 |
+
6. **automatic**: By combining BLIP + Grounding DINO + Segment Anything to achieve non-interactive detection + segmentation (no need to specify prompt).
|
471 |
+
|
472 |
+
```bash
|
473 |
+
python gradio_app.py
|
474 |
+
```
|
475 |
+
|
476 |
+
- The gradio_app visualization as follows:
|
477 |
+
|
478 |
+

|
479 |
+
|
480 |
+
|
481 |
+
### :label: Grounded-SAM with RAM or Tag2Text for Automatic Labeling
|
482 |
+
[**The Recognize Anything Models**](https://github.com/OPPOMKLab/recognize-anything) are a series of open-source and strong fundamental image recognition models, including [RAM++](https://arxiv.org/abs/2310.15200), [RAM](https://arxiv.org/abs/2306.03514) and [Tag2text](https://arxiv.org/abs/2303.05657).
|
483 |
+
|
484 |
+
|
485 |
+
It is seamlessly linked to generate pseudo labels automatically as follows:
|
486 |
+
1. Use RAM/Tag2Text to generate tags.
|
487 |
+
2. Use Grounded-Segment-Anything to generate the boxes and masks.
|
488 |
+
|
489 |
+
|
490 |
+
**Step 1: Init submodule and download the pretrained checkpoint**
|
491 |
+
|
492 |
+
- Init submodule:
|
493 |
+
|
494 |
+
```bash
|
495 |
+
cd Grounded-Segment-Anything
|
496 |
+
git submodule init
|
497 |
+
git submodule update
|
498 |
+
```
|
499 |
+
|
500 |
+
- Download pretrained weights for `GroundingDINO`, `SAM` and `RAM/Tag2Text`:
|
501 |
+
|
502 |
+
```bash
|
503 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
504 |
+
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
505 |
+
|
506 |
+
|
507 |
+
wget https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/ram_swin_large_14m.pth
|
508 |
+
wget https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/tag2text_swin_14m.pth
|
509 |
+
```
|
510 |
+
|
511 |
+
**Step 2: Running the demo with RAM**
|
512 |
+
```bash
|
513 |
+
export CUDA_VISIBLE_DEVICES=0
|
514 |
+
python automatic_label_ram_demo.py \
|
515 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
516 |
+
--ram_checkpoint ram_swin_large_14m.pth \
|
517 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
518 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
519 |
+
--input_image assets/demo9.jpg \
|
520 |
+
--output_dir "outputs" \
|
521 |
+
--box_threshold 0.25 \
|
522 |
+
--text_threshold 0.2 \
|
523 |
+
--iou_threshold 0.5 \
|
524 |
+
--device "cuda"
|
525 |
+
```
|
526 |
+
|
527 |
+
|
528 |
+
**Step 2: Or Running the demo with Tag2Text**
|
529 |
+
```bash
|
530 |
+
export CUDA_VISIBLE_DEVICES=0
|
531 |
+
python automatic_label_tag2text_demo.py \
|
532 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
533 |
+
--tag2text_checkpoint tag2text_swin_14m.pth \
|
534 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
535 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
536 |
+
--input_image assets/demo9.jpg \
|
537 |
+
--output_dir "outputs" \
|
538 |
+
--box_threshold 0.25 \
|
539 |
+
--text_threshold 0.2 \
|
540 |
+
--iou_threshold 0.5 \
|
541 |
+
--device "cuda"
|
542 |
+
```
|
543 |
+
|
544 |
+
- RAM++ significantly improves the open-set capability of RAM, for [RAM++ inference on unseen categoreis](https://github.com/xinyu1205/recognize-anything#ram-inference-on-unseen-categories-open-set).
|
545 |
+
- Tag2Text also provides powerful captioning capabilities, and the process with captions can refer to [BLIP](#robot-run-grounded-segment-anything--blip-demo).
|
546 |
+
- The pseudo labels and model prediction visualization will be saved in `output_dir` as follows (right figure):
|
547 |
+
|
548 |
+

|
549 |
+
|
550 |
+
|
551 |
+
### :robot: Grounded-SAM with BLIP for Automatic Labeling
|
552 |
+
It is easy to generate pseudo labels automatically as follows:
|
553 |
+
1. Use BLIP (or other caption models) to generate a caption.
|
554 |
+
2. Extract tags from the caption. We use ChatGPT to handle the potential complicated sentences.
|
555 |
+
3. Use Grounded-Segment-Anything to generate the boxes and masks.
|
556 |
+
|
557 |
+
- Run Demo
|
558 |
+
```bash
|
559 |
+
export OPENAI_API_KEY=your_openai_key
|
560 |
+
export OPENAI_API_BASE=https://closeai.deno.dev/v1
|
561 |
+
export CUDA_VISIBLE_DEVICES=0
|
562 |
+
python automatic_label_demo.py \
|
563 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
564 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
565 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
566 |
+
--input_image assets/demo3.jpg \
|
567 |
+
--output_dir "outputs" \
|
568 |
+
--openai_key $OPENAI_API_KEY \
|
569 |
+
--box_threshold 0.25 \
|
570 |
+
--text_threshold 0.2 \
|
571 |
+
--iou_threshold 0.5 \
|
572 |
+
--device "cuda"
|
573 |
+
```
|
574 |
+
|
575 |
+
- When you don't have a paid Account for ChatGPT is also possible to use NLTK instead. Just don't include the ```openai_key``` Parameter when starting the Demo.
|
576 |
+
- The Script will automatically download the necessary NLTK Data.
|
577 |
+
- The pseudo labels and model prediction visualization will be saved in `output_dir` as follows:
|
578 |
+
|
579 |
+

|
580 |
+
|
581 |
+
|
582 |
+
### :open_mouth: Grounded-SAM with Whisper: Detect and Segment Anything with Audio
|
583 |
+
Detect and segment anything with speech!
|
584 |
+
|
585 |
+

|
586 |
+
|
587 |
+
**Install Whisper**
|
588 |
+
```bash
|
589 |
+
pip install -U openai-whisper
|
590 |
+
```
|
591 |
+
See the [whisper official page](https://github.com/openai/whisper#setup) if you have other questions for the installation.
|
592 |
+
|
593 |
+
**Run Voice-to-Label Demo**
|
594 |
+
|
595 |
+
Optional: Download the demo audio file
|
596 |
+
|
597 |
+
```bash
|
598 |
+
wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/demo_audio.mp3
|
599 |
+
```
|
600 |
+
|
601 |
+
|
602 |
+
```bash
|
603 |
+
export CUDA_VISIBLE_DEVICES=0
|
604 |
+
python grounded_sam_whisper_demo.py \
|
605 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
606 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
607 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
608 |
+
--input_image assets/demo4.jpg \
|
609 |
+
--output_dir "outputs" \
|
610 |
+
--box_threshold 0.3 \
|
611 |
+
--text_threshold 0.25 \
|
612 |
+
--speech_file "demo_audio.mp3" \
|
613 |
+
--device "cuda"
|
614 |
+
```
|
615 |
+
|
616 |
+

|
617 |
+
|
618 |
+
**Run Voice-to-inpaint Demo**
|
619 |
+
|
620 |
+
You can enable chatgpt to help you automatically detect the object and inpainting order with `--enable_chatgpt`.
|
621 |
+
|
622 |
+
Or you can specify the object you want to inpaint [stored in `args.det_speech_file`] and the text you want to inpaint with [stored in `args.inpaint_speech_file`].
|
623 |
+
|
624 |
+
```bash
|
625 |
+
export OPENAI_API_KEY=your_openai_key
|
626 |
+
export OPENAI_API_BASE=https://closeai.deno.dev/v1
|
627 |
+
# Example: enable chatgpt
|
628 |
+
export CUDA_VISIBLE_DEVICES=0
|
629 |
+
python grounded_sam_whisper_inpainting_demo.py \
|
630 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
631 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
632 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
633 |
+
--input_image assets/inpaint_demo.jpg \
|
634 |
+
--output_dir "outputs" \
|
635 |
+
--box_threshold 0.3 \
|
636 |
+
--text_threshold 0.25 \
|
637 |
+
--prompt_speech_file assets/acoustics/prompt_speech_file.mp3 \
|
638 |
+
--enable_chatgpt \
|
639 |
+
--openai_key $OPENAI_API_KEY\
|
640 |
+
--device "cuda"
|
641 |
+
```
|
642 |
+
|
643 |
+
```bash
|
644 |
+
# Example: without chatgpt
|
645 |
+
export CUDA_VISIBLE_DEVICES=0
|
646 |
+
python grounded_sam_whisper_inpainting_demo.py \
|
647 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
648 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
649 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
650 |
+
--input_image assets/inpaint_demo.jpg \
|
651 |
+
--output_dir "outputs" \
|
652 |
+
--box_threshold 0.3 \
|
653 |
+
--text_threshold 0.25 \
|
654 |
+
--det_speech_file "assets/acoustics/det_voice.mp3" \
|
655 |
+
--inpaint_speech_file "assets/acoustics/inpaint_voice.mp3" \
|
656 |
+
--device "cuda"
|
657 |
+
```
|
658 |
+
|
659 |
+

|
660 |
+
|
661 |
+
### :speech_balloon: Grounded-SAM ChatBot Demo
|
662 |
+
|
663 |
+
https://user-images.githubusercontent.com/24236723/231955561-2ae4ec1a-c75f-4cc5-9b7b-517aa1432123.mp4
|
664 |
+
|
665 |
+
Following [Visual ChatGPT](https://github.com/microsoft/visual-chatgpt), we add a ChatBot for our project. Currently, it supports:
|
666 |
+
1. "Describe the image."
|
667 |
+
2. "Detect the dog (and the cat) in the image."
|
668 |
+
3. "Segment anything in the image."
|
669 |
+
4. "Segment the dog (and the cat) in the image."
|
670 |
+
5. "Help me label the image."
|
671 |
+
6. "Replace the dog with a cat in the image."
|
672 |
+
|
673 |
+
To use the ChatBot:
|
674 |
+
- Install whisper if you want to use audio as input.
|
675 |
+
- Set the default model setting in the tool `Grounded_dino_sam_inpainting`.
|
676 |
+
- Run Demo
|
677 |
+
```bash
|
678 |
+
export OPENAI_API_KEY=your_openai_key
|
679 |
+
export OPENAI_API_BASE=https://closeai.deno.dev/v1
|
680 |
+
export CUDA_VISIBLE_DEVICES=0
|
681 |
+
python chatbot.py
|
682 |
+
```
|
683 |
+
|
684 |
+
### :man_dancing: Run Grounded-Segment-Anything + OSX Demo
|
685 |
+
|
686 |
+
<p align="middle">
|
687 |
+
<img src="assets/osx/grouned_sam_osx_demo.gif">
|
688 |
+
<br>
|
689 |
+
</p>
|
690 |
+
|
691 |
+
|
692 |
+
- Download the checkpoint `osx_l_wo_decoder.pth.tar` from [here](https://drive.google.com/drive/folders/1x7MZbB6eAlrq5PKC9MaeIm4GqkBpokow?usp=share_link) for OSX:
|
693 |
+
- Download the human model files and place it into `grounded-sam-osx/utils/human_model_files` following the instruction of [OSX](https://github.com/IDEA-Research/OSX).
|
694 |
+
|
695 |
+
- Run Demo
|
696 |
+
|
697 |
+
```shell
|
698 |
+
export CUDA_VISIBLE_DEVICES=0
|
699 |
+
python grounded_sam_osx_demo.py \
|
700 |
+
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
|
701 |
+
--grounded_checkpoint groundingdino_swint_ogc.pth \
|
702 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
703 |
+
--osx_checkpoint osx_l_wo_decoder.pth.tar \
|
704 |
+
--input_image assets/osx/grounded_sam_osx_demo.png \
|
705 |
+
--output_dir "outputs" \
|
706 |
+
--box_threshold 0.3 \
|
707 |
+
--text_threshold 0.25 \
|
708 |
+
--text_prompt "humans, chairs" \
|
709 |
+
--device "cuda"
|
710 |
+
```
|
711 |
+
|
712 |
+
- The model prediction visualization will be saved in `output_dir` as follows:
|
713 |
+
|
714 |
+
<img src="assets/osx/grounded_sam_osx_output.jpg" style="zoom: 49%;" />
|
715 |
+
|
716 |
+
- We also support promptable 3D whole-body mesh recovery. For example, you can track someone with a text prompt and estimate his 3D pose and shape :
|
717 |
+
|
718 |
+
|  |
|
719 |
+
| :---------------------------------------------------: |
|
720 |
+
| *A person with pink clothes* |
|
721 |
+
|
722 |
+
|  |
|
723 |
+
| :---------------------------------------------------: |
|
724 |
+
| *A man with a sunglasses* |
|
725 |
+
|
726 |
+
|
727 |
+
## :man_dancing: Run Grounded-Segment-Anything + VISAM Demo
|
728 |
+
|
729 |
+
- Download the checkpoint `motrv2_dancetrack.pth` from [here](https://drive.google.com/file/d/1EA4lndu2yQcVgBKR09KfMe5efbf631Th/view?usp=share_link) for MOTRv2:
|
730 |
+
- See the more thing if you have other questions for the installation.
|
731 |
+
|
732 |
+
- Run Demo
|
733 |
+
|
734 |
+
```shell
|
735 |
+
export CUDA_VISIBLE_DEVICES=0
|
736 |
+
python grounded_sam_visam.py \
|
737 |
+
--meta_arch motr \
|
738 |
+
--dataset_file e2e_dance \
|
739 |
+
--with_box_refine \
|
740 |
+
--query_interaction_layer QIMv2 \
|
741 |
+
--num_queries 10 \
|
742 |
+
--det_db det_db_motrv2.json \
|
743 |
+
--use_checkpoint \
|
744 |
+
--mot_path your_data_path \
|
745 |
+
--resume motrv2_dancetrack.pth \
|
746 |
+
--sam_checkpoint sam_vit_h_4b8939.pth \
|
747 |
+
--video_path DanceTrack/test/dancetrack0003
|
748 |
+
```
|
749 |
+
||
|
750 |
+
|
751 |
+
|
752 |
+
### :dancers: Interactive Editing
|
753 |
+
- Release the interactive fashion-edit playground in [here](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/humanFace). Run in the notebook, just click for annotating points for further segmentation. Enjoy it!
|
754 |
+
|
755 |
+
|
756 |
+
- Release human-face-edit branch [here](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/humanFace). We'll keep updating this branch with more interesting features. Here are some examples:
|
757 |
+
|
758 |
+

|
759 |
+
|
760 |
+
## :camera: 3D-Box via Segment Anything
|
761 |
+
We extend the scope to 3D world by combining Segment Anything and [VoxelNeXt](https://github.com/dvlab-research/VoxelNeXt). When we provide a prompt (e.g., a point / box), the result is not only 2D segmentation mask, but also 3D boxes. Please check [voxelnext_3d_box](./voxelnext_3d_box/) for more details.
|
762 |
+

|
763 |
+

|
764 |
+
|
765 |
+
|
766 |
+
|
767 |
+
|
768 |
+
## :cupid: Acknowledgements
|
769 |
+
|
770 |
+
- [Segment Anything](https://github.com/facebookresearch/segment-anything)
|
771 |
+
- [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)
|
772 |
+
|
773 |
+
|
774 |
+
## Contributors
|
775 |
+
|
776 |
+
Our project wouldn't be possible without the contributions of these amazing people! Thank you all for making this project better.
|
777 |
+
|
778 |
+
<a href="https://github.com/IDEA-Research/Grounded-Segment-Anything/graphs/contributors">
|
779 |
+
<img src="https://contrib.rocks/image?repo=IDEA-Research/Grounded-Segment-Anything" />
|
780 |
+
</a>
|
781 |
+
|
782 |
+
|
783 |
+
## Citation
|
784 |
+
If you find this project helpful for your research, please consider citing the following BibTeX entry.
|
785 |
+
```BibTex
|
786 |
+
@article{kirillov2023segany,
|
787 |
+
title={Segment Anything},
|
788 |
+
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
|
789 |
+
journal={arXiv:2304.02643},
|
790 |
+
year={2023}
|
791 |
+
}
|
792 |
+
|
793 |
+
@article{liu2023grounding,
|
794 |
+
title={Grounding dino: Marrying dino with grounded pre-training for open-set object detection},
|
795 |
+
author={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},
|
796 |
+
journal={arXiv preprint arXiv:2303.05499},
|
797 |
+
year={2023}
|
798 |
+
}
|
799 |
+
|
800 |
+
@misc{ren2024grounded,
|
801 |
+
title={Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks},
|
802 |
+
author={Tianhe Ren and Shilong Liu and Ailing Zeng and Jing Lin and Kunchang Li and He Cao and Jiayu Chen and Xinyu Huang and Yukang Chen and Feng Yan and Zhaoyang Zeng and Hao Zhang and Feng Li and Jie Yang and Hongyang Li and Qing Jiang and Lei Zhang},
|
803 |
+
year={2024},
|
804 |
+
eprint={2401.14159},
|
805 |
+
archivePrefix={arXiv},
|
806 |
+
primaryClass={cs.CV}
|
807 |
+
}
|
808 |
+
```
|