sezer91's picture
d
ee7b9e5
import {
SamModel,
AutoProcessor,
RawImage,
Tensor,
} from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]";
// DOM handles used by the demo, each looked up once by element id.
const [
  statusLabel,
  fileUpload,
  imageContainer,
  example,
  uploadButton,
  resetButton,
  clearButton,
  cutButton,
  maskCanvas,
] = [
  "status",
  "upload",
  "container",
  "example",
  "upload-button",
  "reset-image",
  "clear-points",
  "cut-mask",
  "mask-output",
].map((id) => document.getElementById(id));
const maskContext = maskCanvas.getContext("2d");
const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// Mutable application state shared by the functions and handlers below.
let isEncoding = false; // true while encode() is running
let imageInput = null; // RawImage loaded by encode()
let imageProcessed = null; // processor output for imageInput
let imageEmbeddings = null; // result of model.get_image_embeddings()
let model = null; // assigned once the model has loaded
let processor = null; // assigned once the processor has loaded
/**
 * Load the image at `url`, compute its image embeddings and then run
 * automatic segmentation on it.
 *
 * Calls made while a previous call is still running are ignored. Errors are
 * logged and shown in the status label instead of escaping: the event
 * handlers that call this do not await it, so a rejection would otherwise be
 * unhandled and leave `isEncoding` stuck at true.
 *
 * @param {string} url - Image URL or data URL to load.
 * @returns {Promise<void>}
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  try {
    statusLabel.textContent = "Extracting image embedding...";
    imageInput = await RawImage.fromURL(url);

    // Update UI, dropping the previous image's overlay so it is not shown
    // on top of the new image while processing.
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;
    maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);

    // Recompute image embeddings.
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);
    statusLabel.textContent = "Embedding extracted!";

    // Run automatic segmentation right away. isEncoding stays true until it
    // finishes so another image cannot replace the shared state mid-run.
    await autoSegment();
  } catch (err) {
    console.error(err);
    statusLabel.textContent = `Error: ${
      err instanceof Error ? err.message : String(err)
    }`;
  } finally {
    isEncoding = false;
  }
}
/**
 * Segment the current image automatically: prompt the model with one point
 * per cell of a regular grid, keep the masks whose best candidate is large
 * enough, and draw them with updateMaskOverlay().
 *
 * Each decoder output (`processedMasks[0][0]`) holds several candidate masks
 * stored plane by plane (all pixels of candidate 0, then candidate 1, ...),
 * with one predicted score per candidate in `iou_scores.data`.
 * NOTE(review): channel-major layout assumed from transformers.js
 * post_process_masks — confirm if the library changes.
 *
 * @param {number} [gridSize=50] - Prompt grid spacing in original-image
 *   pixels; points sit at the centre of each cell. Must be > 0.
 * @param {number} [minAreaFraction=0.01] - A mask is kept only if its best
 *   candidate covers more than this fraction of the image.
 * @returns {Promise<void>}
 * @throws {RangeError} If gridSize is not a positive number.
 */
async function autoSegment(gridSize = 50, minAreaFraction = 0.01) {
  if (!imageEmbeddings) {
    statusLabel.textContent = "No image embeddings available!";
    return;
  }
  if (!(gridSize > 0)) {
    // A non-positive step would never advance the grid loops below.
    throw new RangeError(`gridSize must be a positive number, got ${gridSize}`);
  }

  // Snapshot the shared state: a reset or a new image may replace it while
  // the per-point decoding below is awaiting.
  const embeddings = imageEmbeddings;
  const image = imageInput;
  const processed = imageProcessed;
  statusLabel.textContent = "Generating automatic segments...";

  // Grid points in original-image pixels, scaled to the resized model input
  // (reshaped_input_sizes entries are [height, width]).
  const [reshapedHeight, reshapedWidth] = processed.reshaped_input_sizes[0];
  const points = [];
  for (let y = gridSize / 2; y < image.height; y += gridSize) {
    for (let x = gridSize / 2; x < image.width; x += gridSize) {
      points.push([
        (x / image.width) * reshapedWidth,
        (y / image.height) * reshapedHeight,
      ]);
    }
  }

  const planeSize = image.width * image.height;
  const minArea = planeSize * minAreaFraction;
  const keptMasks = [];
  const keptScores = [];

  // Run the decoder once per grid point.
  for (let i = 0; i < points.length; i++) {
    statusLabel.textContent = `Generating automatic segments... (${i + 1}/${points.length})`;
    const input_points = new Tensor("float32", points[i], [1, 1, 1, 2]);
    const input_labels = new Tensor("int64", [1n], [1, 1, 1]);
    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });
    const processedMasks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );
    // The image was reset or replaced while awaiting: drop these results.
    if (imageEmbeddings !== embeddings) return;

    const mask = processedMasks[0][0];
    const maskScores = iou_scores.data;

    // Only the best-scoring candidate is drawn, so only it is measured.
    let best = 0;
    for (let k = 1; k < maskScores.length; k++) {
      if (maskScores[k] > maskScores[best]) best = k;
    }
    let area = 0;
    const end = (best + 1) * planeSize;
    for (let p = best * planeSize; p < end; p++) {
      if (mask.data[p] === 1) area++;
    }

    // Skip masks that are too small.
    if (area > minArea) {
      keptMasks.push(mask);
      keptScores.push(maskScores);
    }
  }

  // Draw the masks and their labels.
  updateMaskOverlay(keptMasks, keptScores);
  statusLabel.textContent = `Found ${keptMasks.length} objects`;
}
/**
 * Draw every mask as a semi-transparent, randomly coloured overlay on
 * maskCanvas, label each one at its centroid, and enable the cut button.
 *
 * Each `masks[m].data` holds several candidate masks stored plane by plane
 * (all pixels of candidate 0, then candidate 1, ...), each plane the size of
 * the image. `scores[m][k]` is the score of candidate k; the highest-scoring
 * candidate is drawn. NOTE(review): this channel-major layout matches the
 * tensors autoSegment passes in (`processedMasks[0][0]`) — confirm if that
 * producer changes.
 *
 * All masks are composed into one ImageData and written with a single
 * putImageData call. putImageData replaces pixels rather than blending, so
 * writing one full-canvas image per mask would erase earlier masks/labels.
 *
 * @param {Array<{data: ArrayLike<number>}>} masks - Mask tensors, one per object.
 * @param {Array<ArrayLike<number>>} scores - Per-candidate scores, parallel to masks.
 */
function updateMaskOverlay(masks, scores) {
  // Match the canvas to the source image resolution.
  if (
    maskCanvas.width !== imageInput.width ||
    maskCanvas.height !== imageInput.height
  ) {
    maskCanvas.width = imageInput.width;
    maskCanvas.height = imageInput.height;
  }
  const { width, height } = maskCanvas;
  const planeSize = width * height;

  maskContext.clearRect(0, 0, width, height);
  const imageData = maskContext.createImageData(width, height);
  const pixelData = imageData.data;
  const labels = [];

  for (let m = 0; m < masks.length; m++) {
    const maskData = masks[m].data;
    const maskScores = scores[m];

    // Pick the best-scoring candidate mask.
    let bestIndex = 0;
    for (let k = 1; k < maskScores.length; ++k) {
      if (maskScores[k] > maskScores[bestIndex]) bestIndex = k;
    }
    const planeOffset = bestIndex * planeSize;

    const r = Math.random() * 255;
    const g = Math.random() * 255;
    const b = Math.random() * 255;

    // Colour the mask and accumulate its centroid in one pass.
    let sumX = 0;
    let sumY = 0;
    let count = 0;
    for (let i = 0; i < planeSize; ++i) {
      if (maskData[planeOffset + i] !== 1) continue;
      const offset = 4 * i;
      pixelData[offset] = r;
      pixelData[offset + 1] = g;
      pixelData[offset + 2] = b;
      pixelData[offset + 3] = 128; // semi-transparent
      sumX += i % width;
      sumY += Math.floor(i / width);
      count++;
    }
    if (count > 0) {
      labels.push({
        text: `Object ${m + 1}`,
        x: Math.floor(sumX / count),
        y: Math.floor(sumY / count),
      });
    }
  }

  maskContext.putImageData(imageData, 0, 0);

  // Labels are drawn after putImageData, which would otherwise overwrite them.
  maskContext.fillStyle = "white";
  maskContext.font = "16px Arial";
  maskContext.strokeStyle = "black";
  maskContext.lineWidth = 2;
  for (const { text, x, y } of labels) {
    maskContext.strokeText(text, x, y);
    maskContext.fillText(text, x, y);
  }

  // Enable the cut button.
  cutButton.disabled = false;
}
// Keep the existing event listeners, but remove the click handlers.
// Choosing a file reads it as a data URL and passes that to encode().
fileUpload.addEventListener("change", function onFileChosen(event) {
  const chosen = event.target.files[0];
  if (chosen) {
    const reader = new FileReader();
    reader.onload = (loaded) => encode(loaded.target.result);
    reader.readAsDataURL(chosen);
  }
});
// The example link loads EXAMPLE_URL instead of navigating.
example.addEventListener("click", function onExampleClick(event) {
  event.preventDefault();
  encode(EXAMPLE_URL);
});
// Reset: forget the current image and its embeddings, clear the overlay and
// return the UI to its initial upload state.
resetButton.addEventListener("click", () => {
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  // Clear the file input; otherwise picking the same file again after a
  // reset does not fire a "change" event and nothing happens.
  fileUpload.value = "";
  statusLabel.textContent = "Ready";
});
// Cut: export the pixels of the original image covered by the overlay
// (any pixel with non-zero overlay alpha) as an opaque PNG download.
cutButton.addEventListener("click", async () => {
  const { width: w, height: h } = maskCanvas;
  // The overlay must describe the current image pixel-for-pixel.
  if (!imageInput || imageInput.width !== w || imageInput.height !== h) return;

  const maskImageData = maskContext.getImageData(0, 0, w, h);
  const maskPixelData = maskImageData.data; // RGBA, 4 values per pixel
  const imagePixelData = imageInput.data;
  // RawImage data is interleaved with `channels` values per pixel, which is
  // not always 3 (e.g. images with alpha). With fewer than 3 channels
  // (presumably greyscale) the first value is replicated into R, G and B.
  const channels = imageInput.channels ?? 3;

  for (let i = 0; i < w * h; ++i) {
    const targetOffset = 4 * i;
    if (maskPixelData[targetOffset + 3] === 0) continue; // not covered
    const sourceOffset = channels * i;
    for (let j = 0; j < 3; ++j) {
      maskPixelData[targetOffset + j] =
        imagePixelData[sourceOffset + (channels >= 3 ? j : 0)];
    }
    // The overlay is drawn semi-transparent; the cut-out itself is opaque.
    maskPixelData[targetOffset + 3] = 255;
  }

  const cutCanvas = new OffscreenCanvas(w, h);
  cutCanvas.getContext("2d").putImageData(maskImageData, 0, 0);
  const blobUrl = URL.createObjectURL(await cutCanvas.convertToBlob());

  const link = document.createElement("a");
  link.download = "image.png";
  link.href = blobUrl;
  link.click();
  // Release the blob later rather than synchronously, so the download has
  // time to start before the URL becomes invalid.
  setTimeout(() => URL.revokeObjectURL(blobUrl), 10_000);
});
// Load the model and its processor (independently, in parallel), then
// enable the UI.
const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";
try {
  [model, processor] = await Promise.all([
    SamModel.from_pretrained(model_id, {
      dtype: "fp16",
      device: "webgpu",
    }),
    AutoProcessor.from_pretrained(model_id),
  ]);
  statusLabel.textContent = "Ready";
  // Enable the UI now that encode() has a model and processor to use.
  fileUpload.disabled = false;
  uploadButton.style.opacity = 1;
  example.style.pointerEvents = "auto";
} catch (err) {
  // Without this the top-level await rejects, the UI is never enabled and
  // the status stays on "Loading model..." with no explanation.
  console.error(err);
  statusLabel.textContent = `Failed to load model: ${
    err instanceof Error ? err.message : String(err)
  }`;
}