Some questions about PaliGemma's segmentation capabilities
#1 opened by cslys1999
I downloaded checkpoints from these two collections:
https://huggingface.co/collections/google/paligemma-release-6643a9ffbf57de2ae0448dda
https://huggingface.co/collections/google/paligemma-ft-models-6643b03efb769dad650d2dda
and got the following three checkpoints:
google/paligemma-3b-ft-refcoco-seg-224
google/paligemma-3b-mix-224
google/paligemma-3b-ft-refcoco-seg-224
I want to use PaliGemma for an instance segmentation task. However, all three checkpoints perform worse at instance segmentation than the model in the Hugging Face demo. Why is this?
My inference code looks like this:
```python
from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch
from loguru import logger

"""
ref: https://colab.research.google.com/drive/1gOhRCFyt9yIoasJkd4VoaHcIqJPdJnlg?usp=sharing#scrollTo=u9kau5IOjNt9
"""

# Load the locally downloaded checkpoint in full precision and move it to the GPU
model = PaliGemmaForConditionalGeneration.from_pretrained("./paligemma_refcoco_ft/refcoco_ft", torch_dtype=torch.float32)
processor = AutoProcessor.from_pretrained("./paligemma_refcoco_ft/refcoco_ft")
model = model.to("cuda:0")

# url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
# image = Image.open(requests.get(url, stream=True).raw)

while True:
    try:
        image_path = input("the image file path: ")
        prompt = input("please enter the prompt:")
        image = Image.open(image_path).convert("RGB")
        # Tokenize the prompt and preprocess the image
        inputs = processor(text=prompt, images=image, return_tensors="pt", padding=True)
        inputs = inputs.to(dtype=model.dtype)  # match the model's dtype
        inputs = {key: value.to("cuda:0") for key, value in inputs.items()}
        # Generate
        # from IPython import embed; embed()  # debugging breakpoint, disabled
        with torch.inference_mode():
            generate_ids = model.generate(**inputs, max_new_tokens=100)
        result = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(result)
    except Exception as e:
        logger.error(e)
```
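When comparing against the demo, it may help to post-process the decoded string the way a segmentation viewer would instead of judging the raw text. The sketch below assumes the segmentation output for one object is four `<locXXXX>` tokens (y_min, x_min, y_max, x_max, quantized into 1024 bins) followed by 16 `<segXXX>` codebook tokens and a label. Turning those 16 codes into an actual mask requires PaliGemma's VQ-VAE mask decoder, which is not shown here; `paste_mask` only assumes you already have a small decoded mask (for example 64x64) and places it inside the predicted box. `parse_segmentation`, `paste_mask`, and `NUM_LOC_BINS` are hypothetical helper names, and only the first object in the string is handled.

```python
import re
from typing import List, Optional, Tuple

import numpy as np

# Assumption: location tokens are 4-digit bins and segmentation tokens are 3-digit codes
_LOC_RE = re.compile(r"<loc(\d{4})>")
_SEG_RE = re.compile(r"<seg(\d{3})>")
NUM_LOC_BINS = 1024  # assumption: boxes are quantized into 1024 bins per axis

Box = Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max) in pixels


def parse_segmentation(text: str, width: int, height: int) -> Tuple[Optional[Box], List[int]]:
    """Extract the first box and the segmentation codes from a decoded string."""
    locs = [int(v) for v in _LOC_RE.findall(text)]
    codes = [int(v) for v in _SEG_RE.findall(text)]
    if len(locs) < 4:
        return None, codes
    # Token order is assumed to be y_min, x_min, y_max, x_max (normalized by NUM_LOC_BINS)
    y1, x1, y2, x2 = (v / NUM_LOC_BINS for v in locs[:4])
    box = (round(x1 * width), round(y1 * height), round(x2 * width), round(y2 * height))
    return box, codes


def paste_mask(low_res: np.ndarray, box: Box, width: int, height: int) -> np.ndarray:
    """Nearest-neighbor resize a small decoded mask into `box` on a full-size canvas."""
    full = np.zeros((height, width), dtype=bool)
    x1, y1, x2, y2 = box
    # Clamp the box to the image bounds
    x1, x2 = max(0, x1), min(width, x2)
    y1, y2 = max(0, y1), min(height, y2)
    if x2 <= x1 or y2 <= y1:
        return full
    h, w = low_res.shape
    rows = np.arange(y2 - y1) * h // (y2 - y1)  # source row for each target row
    cols = np.arange(x2 - x1) * w // (x2 - x1)  # source column for each target column
    full[y1:y2, x1:x2] = low_res[rows[:, None], cols[None, :]] > 0.5
    return full
```

For example, `parse_segmentation(result[0], image.width, image.height)` would give the box in original-image pixels. Rendering that box (and, once decoded, the mask) on the original image makes it easier to tell whether the difference from the demo comes from the model's predictions or from how the output is post-processed.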
I look forward to your answers!