// Page-header residue from the Hugging Face Spaces viewer ("Spaces: Running"); not part of the script.
import { | |
SamModel, | |
AutoProcessor, | |
RawImage, | |
Tensor, | |
} from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]"; | |
// DOM handles used by the rest of this module.
const statusLabel = document.querySelector("#status");
const fileUpload = document.querySelector("#upload");
const imageContainer = document.querySelector("#container");
const example = document.querySelector("#example");
const uploadButton = document.querySelector("#upload-button");
const resetButton = document.querySelector("#reset-image");
// NOTE(review): clearButton is not referenced anywhere in the visible code.
const clearButton = document.querySelector("#clear-points");
const cutButton = document.querySelector("#cut-mask");

// Overlay canvas (and its 2D context) on which masks and labels are painted.
const maskCanvas = document.querySelector("#mask-output");
const maskContext = maskCanvas.getContext("2d");

// Image loaded when the "example" element is clicked.
const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// Mutable module state shared by the functions and listeners below.
let isEncoding = false; // true while encode() is running
let imageInput = null; // image returned by RawImage.fromURL()
let imageProcessed = null; // result of processor(imageInput)
let imageEmbeddings = null; // result of model.get_image_embeddings()
let model = null; // assigned once the model has loaded
let processor = null; // assigned once the processor has loaded
/**
 * Load the image at `url`, compute its embedding with the model, and then run
 * automatic segmentation on it.
 *
 * The call is ignored while another encode is in progress. The `isEncoding`
 * guard is held until segmentation has finished, so a second image cannot
 * replace the shared state while `autoSegment()` is still reading it, and it
 * is always released afterwards, even when a step fails.
 *
 * Failures are logged and shown in the status label instead of being thrown.
 *
 * @param {string} url - Location of the image (remote URL or data URL).
 * @returns {Promise<void>} Resolves when encoding and segmentation are done.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;

  try {
    statusLabel.textContent = "Extracting image embedding...";
    imageInput = await RawImage.fromURL(url);

    // Show the image and hide the upload prompt.
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;

    // Recompute the embeddings for the new image.
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);
    statusLabel.textContent = "Embedding extracted!";

    // Run automatic segmentation right away.
    await autoSegment();
  } catch (error) {
    console.error(error);
    statusLabel.textContent = `Failed to process image: ${error?.message ?? error}`;
  } finally {
    // Without this, one failed load would block every later upload.
    isEncoding = false;
  }
}
/**
 * Segment the current image automatically by prompting the model with a
 * regular grid of points and drawing every mask that is large enough.
 *
 * One model call is made per grid point. A mask is kept when the number of
 * entries equal to 1 in its data exceeds 1% of the image area; kept masks are
 * passed to `updateMaskOverlay()` together with their IoU scores.
 *
 * The shared image state is captured on entry. If it changes while a model
 * call is pending (another image is loaded or the image is reset), the run is
 * abandoned silently instead of mixing results from two images.
 *
 * @param {number} [gridSize=50] - Spacing between prompt points, in pixels of
 *   the original image.
 * @returns {Promise<void>}
 * @throws {RangeError} If `gridSize` is not a positive finite number.
 */
async function autoSegment(gridSize = 50) {
  if (!imageEmbeddings) {
    statusLabel.textContent = "No image embeddings available!";
    return;
  }
  if (!Number.isFinite(gridSize) || gridSize <= 0) {
    // A non-positive step would make the grid loops below run forever.
    throw new RangeError(`gridSize must be a positive number, got ${gridSize}`);
  }

  // Snapshot the shared state: it can be reassigned during any await below.
  const image = imageInput;
  const processed = imageProcessed;
  const embeddings = imageEmbeddings;
  const isStale = () =>
    imageInput !== image ||
    imageProcessed !== processed ||
    imageEmbeddings !== embeddings;

  statusLabel.textContent = "Generating automatic segments...";

  // Build the point grid, rescaled from original-image pixels to the
  // processor's reshaped input size (indexed as [height, width] below).
  const reshaped = processed.reshaped_input_sizes[0];
  const points = [];
  for (let y = gridSize / 2; y < image.height; y += gridSize) {
    for (let x = gridSize / 2; x < image.width; x += gridSize) {
      points.push([
        (x / image.width) * reshaped[1],
        (y / image.height) * reshaped[0],
      ]);
    }
  }

  // Masks covering 1% of the image or less are dropped.
  const minPixelCount = (image.width * image.height) / 100;
  const filteredMasks = [];
  const filteredScores = [];

  for (const point of points) {
    const input_points = new Tensor("float32", point, [1, 1, 1, 2]);
    const input_labels = new Tensor("int64", [1n], [1, 1, 1]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });
    if (isStale()) return;

    const processedMasks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );
    if (isStale()) return;

    const mask = processedMasks[0][0];

    // Filter immediately so that rejected masks are not kept in memory for
    // the whole run.
    // NOTE(review): this counts every entry of mask.data; if the mask holds
    // several candidate masks the count spans all of them. Confirm intent.
    let pixelCount = 0;
    for (let j = 0; j < mask.data.length; j++) {
      if (mask.data[j] === 1) pixelCount++;
    }
    if (pixelCount > minPixelCount) {
      filteredMasks.push(mask);
      filteredScores.push(iou_scores.data);
    }
  }

  // Draw the masks and their labels.
  updateMaskOverlay(filteredMasks, filteredScores);
  statusLabel.textContent = `Found ${filteredMasks.length} objects`;
}
/**
 * Paint every mask onto the overlay canvas in a random semi-transparent
 * colour and label each one "Object N" at the centroid of its pixels.
 *
 * All masks are composited into a single ImageData that is written once,
 * because `putImageData` replaces the canvas contents rather than blending
 * with them: writing one ImageData per mask would erase every earlier mask
 * and label. Where masks overlap, the later mask's colour wins. Labels are
 * drawn after the pixels so they stay on top.
 *
 * @param {Array<{data: ArrayLike<number>, dims?: number[]}>} masks - One entry
 *   per detected object; each may hold several candidate masks.
 * @param {Array<ArrayLike<number>>} scores - For each entry in `masks`, the
 *   score of every candidate; the highest-scoring candidate is drawn.
 * @returns {void}
 */
function updateMaskOverlay(masks, scores) {
  const width = imageInput.width;
  const height = imageInput.height;

  // Keep the canvas the same size as the image.
  if (maskCanvas.width !== width || maskCanvas.height !== height) {
    maskCanvas.width = width;
    maskCanvas.height = height;
  }

  // Start from a clean canvas.
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);

  const pixelCount = width * height;
  const imageData = maskContext.createImageData(width, height);
  const pixelData = imageData.data;
  const labels = [];

  for (let m = 0; m < masks.length; m++) {
    const mask = masks[m];
    const maskScores = scores[m];

    // Pick the candidate with the highest score.
    const numMasks = maskScores.length;
    let bestIndex = 0;
    for (let i = 1; i < numMasks; ++i) {
      if (maskScores[i] > maskScores[bestIndex]) {
        bestIndex = i;
      }
    }

    // NOTE(review): a mask exposing a 3-element `dims` is assumed to be a
    // tensor laid out as [numMasks, height, width] (one plane per candidate);
    // anything else is read as interleaved candidates per pixel. Confirm
    // against what processor.post_process_masks() returns.
    const isPlanar = Array.isArray(mask.dims) && mask.dims.length === 3;
    const base = isPlanar ? bestIndex * mask.dims[1] * mask.dims[2] : bestIndex;
    const stride = isPlanar ? 1 : numMasks;

    // Random colour for this mask.
    const r = Math.random() * 255;
    const g = Math.random() * 255;
    const b = Math.random() * 255;

    // Single pass: colour the pixel and accumulate the centroid.
    let sumX = 0;
    let sumY = 0;
    let covered = 0;
    for (let i = 0; i < pixelCount; ++i) {
      if (mask.data[base + stride * i] !== 1) continue;

      const offset = 4 * i;
      pixelData[offset] = r;
      pixelData[offset + 1] = g;
      pixelData[offset + 2] = b;
      pixelData[offset + 3] = 128; // semi-transparent

      sumX += i % width;
      sumY += Math.floor(i / width);
      covered += 1;
    }

    if (covered > 0) {
      labels.push({
        text: `Object ${m + 1}`,
        x: Math.floor(sumX / covered),
        y: Math.floor(sumY / covered),
      });
    }
  }

  maskContext.putImageData(imageData, 0, 0);

  // Draw the labels on top of the masks.
  maskContext.fillStyle = "white";
  maskContext.font = "16px Arial";
  maskContext.strokeStyle = "black";
  maskContext.lineWidth = 2;
  for (const { text, x, y } of labels) {
    maskContext.strokeText(text, x, y);
    maskContext.fillText(text, x, y);
  }

  // Enable the cut button.
  cutButton.disabled = false;
}
// Upload, example and reset listeners. Point-click interaction has been
// removed; segmentation now starts automatically once an image is encoded.

/**
 * Start encoding `url` and surface any rejection in the status label, so the
 * promise returned by encode() is never left unhandled.
 *
 * @param {string} url - Image location passed through to encode().
 * @returns {void}
 */
function startEncoding(url) {
  encode(url).catch((error) => {
    console.error(error);
    statusLabel.textContent = `Failed to process image: ${error?.message ?? error}`;
  });
}

fileUpload.addEventListener("change", (event) => {
  const input = event.target;
  const file = input.files[0];
  if (!file) return;

  // Clear the selection so that choosing the same file again (for example
  // after a reset) still fires a "change" event.
  input.value = "";

  const reader = new FileReader();
  reader.addEventListener("load", () => startEncoding(reader.result));
  reader.addEventListener("error", () => {
    console.error(reader.error);
    statusLabel.textContent = "Failed to read the selected file";
  });
  reader.readAsDataURL(file);
});

example.addEventListener("click", (event) => {
  event.preventDefault();
  startEncoding(EXAMPLE_URL);
});

resetButton.addEventListener("click", () => {
  // Drop everything derived from the current image.
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;

  // Restore the initial UI.
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
});
// Download the masked region of the image as "image.png".
//
// Every overlay pixel with non-zero alpha is replaced by the corresponding
// pixel of the source image; all other pixels stay transparent.
cutButton.addEventListener("click", async () => {
  if (!imageInput) return; // nothing to cut after a reset

  const width = maskCanvas.width;
  const height = maskCanvas.height;
  const maskImageData = maskContext.getImageData(0, 0, width, height);
  const maskPixelData = maskImageData.data;
  const imagePixelData = imageInput.data;

  // NOTE(review): assumes imageInput exposes its channel count as `channels`
  // (as a RawImage is expected to); falls back to 3 (RGB), which is what the
  // original code hard-coded. Confirm against the RawImage API.
  const channels = imageInput.channels ?? 3;

  for (let i = 0; i < width * height; ++i) {
    const targetOffset = 4 * i; // overlay is RGBA
    if (maskPixelData[targetOffset + 3] === 0) continue;

    const sourceOffset = channels * i;
    for (let j = 0; j < 3; ++j) {
      // With fewer than 3 channels, reuse the first one for R, G and B.
      const sourceChannel = channels >= 3 ? j : 0;
      maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + sourceChannel];
    }
    // The overlay is drawn half-transparent for display only; the exported
    // cut-out should be fully opaque.
    maskPixelData[targetOffset + 3] = 255;
  }

  const cutCanvas = new OffscreenCanvas(width, height);
  cutCanvas.getContext("2d").putImageData(maskImageData, 0, 0);

  const blobUrl = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = blobUrl;
  link.click();

  // Release the blob once the download has had time to start; revoking it
  // synchronously can cancel the download in some browsers.
  setTimeout(() => URL.revokeObjectURL(blobUrl), 1000);
});
// Load the model and its processor, then unlock the UI.
const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";
try {
  // The two downloads are independent, so run them in parallel.
  [model, processor] = await Promise.all([
    SamModel.from_pretrained(model_id, {
      dtype: "fp16",
      device: "webgpu",
    }),
    AutoProcessor.from_pretrained(model_id),
  ]);
} catch (error) {
  // Tell the user instead of leaving the status stuck on "Loading model...".
  statusLabel.textContent = `Failed to load model: ${error?.message ?? error}`;
  throw error;
}
statusLabel.textContent = "Ready";

// Enable the UI.
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";