sezer91 committed on
Commit
ee7b9e5
·
1 Parent(s): 274e615
Files changed (1) hide show
  1. index.js +169 -201
index.js CHANGED
@@ -14,8 +14,6 @@ const uploadButton = document.getElementById("upload-button");
14
  const resetButton = document.getElementById("reset-image");
15
  const clearButton = document.getElementById("clear-points");
16
  const cutButton = document.getElementById("cut-mask");
17
- const starIcon = document.getElementById("star-icon");
18
- const crossIcon = document.getElementById("cross-icon");
19
  const maskCanvas = document.getElementById("mask-output");
20
  const maskContext = maskCanvas.getContext("2d");
21
 
@@ -24,161 +22,189 @@ const EXAMPLE_URL =
24
 
25
  // State variables
26
  let isEncoding = false;
27
- let isDecoding = false;
28
- let decodePending = false;
29
- let lastPoints = null;
30
- let isMultiMaskMode = false;
31
  let imageInput = null;
32
  let imageProcessed = null;
33
  let imageEmbeddings = null;
 
 
34
 
35
- async function decode() {
36
- // Only proceed if we are not already decoding
37
- if (isDecoding) {
38
- decodePending = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  return;
40
  }
41
- isDecoding = true;
42
 
43
- // Prepare inputs for decoding
 
 
 
44
  const reshaped = imageProcessed.reshaped_input_sizes[0];
45
- const points = lastPoints
46
- .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
47
- .flat(Infinity);
48
- const labels = lastPoints.map((x) => BigInt(x.label)).flat(Infinity);
49
-
50
- const num_points = lastPoints.length;
51
- const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
52
- const input_labels = new Tensor("int64", labels, [1, 1, num_points]);
53
-
54
- // Generate the mask
55
- const { pred_masks, iou_scores } = await model({
56
- ...imageEmbeddings,
57
- input_points,
58
- input_labels,
59
- });
60
-
61
- // Post-process the mask
62
- const masks = await processor.post_process_masks(
63
- pred_masks,
64
- imageProcessed.original_sizes,
65
- imageProcessed.reshaped_input_sizes,
66
- );
67
-
68
- isDecoding = false;
69
-
70
- updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
71
-
72
- // Check if another decode is pending
73
- if (decodePending) {
74
- decodePending = false;
75
- decode();
76
  }
77
- }
78
 
79
- function updateMaskOverlay(mask, scores) {
80
- // Update canvas dimensions (if different)
81
- if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
82
- maskCanvas.width = mask.width;
83
- maskCanvas.height = mask.height;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  }
85
 
86
- // Allocate buffer for pixel data
87
- const imageData = maskContext.createImageData(
88
- maskCanvas.width,
89
- maskCanvas.height,
90
- );
91
-
92
- // Select best mask
93
- const numMasks = scores.length; // 3
94
- let bestIndex = 0;
95
- for (let i = 1; i < numMasks; ++i) {
96
- if (scores[i] > scores[bestIndex]) {
97
- bestIndex = i;
98
  }
99
- }
100
- statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
101
-
102
- // Fill mask with colour
103
- const pixelData = imageData.data;
104
- for (let i = 0; i < pixelData.length; ++i) {
105
- if (mask.data[numMasks * i + bestIndex] === 1) {
106
- const offset = 4 * i;
107
- pixelData[offset] = 0; // red
108
- pixelData[offset + 1] = 114; // green
109
- pixelData[offset + 2] = 189; // blue
110
- pixelData[offset + 3] = 255; // alpha
111
  }
112
  }
113
 
114
- // Draw image data to context
115
- maskContext.putImageData(imageData, 0, 0);
116
- }
117
-
118
- function clearPointsAndMask() {
119
- // Reset state
120
- isMultiMaskMode = false;
121
- lastPoints = null;
122
-
123
- // Remove points from previous mask (if any)
124
- document.querySelectorAll(".icon").forEach((e) => e.remove());
125
 
126
- // Disable cut button
127
- cutButton.disabled = true;
128
-
129
- // Reset mask canvas
130
- maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
131
  }
132
- clearButton.addEventListener("click", clearPointsAndMask);
133
-
134
- resetButton.addEventListener("click", () => {
135
- // Reset the state
136
- imageInput = null;
137
- imageProcessed = null;
138
- imageEmbeddings = null;
139
- isEncoding = false;
140
- isDecoding = false;
141
 
142
- // Clear points and mask (if present)
143
- clearPointsAndMask();
144
-
145
- // Update UI
146
- cutButton.disabled = true;
147
- imageContainer.style.backgroundImage = "none";
148
- uploadButton.style.display = "flex";
149
- statusLabel.textContent = "Ready";
150
- });
151
-
152
- async function encode(url) {
153
- if (isEncoding) return;
154
- isEncoding = true;
155
- statusLabel.textContent = "Extracting image embedding...";
156
 
157
- imageInput = await RawImage.fromURL(url);
 
158
 
159
- // Update UI
160
- imageContainer.style.backgroundImage = `url(${url})`;
161
- uploadButton.style.display = "none";
162
- cutButton.disabled = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- // Recompute image embeddings
165
- imageProcessed = await processor(imageInput);
166
- imageEmbeddings = await model.get_image_embeddings(imageProcessed);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- statusLabel.textContent = "Embedding extracted!";
169
- isEncoding = false;
170
  }
171
 
172
- // Handle file selection
173
  fileUpload.addEventListener("change", function (e) {
174
  const file = e.target.files[0];
175
  if (!file) return;
176
 
177
  const reader = new FileReader();
178
-
179
- // Set up a callback when the file is loaded
180
  reader.onload = (e2) => encode(e2.target.result);
181
-
182
  reader.readAsDataURL(file);
183
  });
184
 
@@ -187,90 +213,32 @@ example.addEventListener("click", (e) => {
187
  encode(EXAMPLE_URL);
188
  });
189
 
190
- // Attach hover event to image container
191
- imageContainer.addEventListener("mousedown", (e) => {
192
- if (e.button !== 0 && e.button !== 2) {
193
- return; // Ignore other buttons
194
- }
195
- if (!imageEmbeddings) {
196
- return; // Ignore if not encoded yet
197
- }
198
- if (!isMultiMaskMode) {
199
- lastPoints = [];
200
- isMultiMaskMode = true;
201
- cutButton.disabled = false;
202
- }
203
-
204
- const point = getPoint(e);
205
- lastPoints.push(point);
206
-
207
- // add icon
208
- const icon = (point.label === 1 ? starIcon : crossIcon).cloneNode();
209
- icon.style.left = `${point.position[0] * 100}%`;
210
- icon.style.top = `${point.position[1] * 100}%`;
211
- imageContainer.appendChild(icon);
212
-
213
- // Run decode
214
- decode();
215
- });
216
-
217
- // Clamp a value inside a range [min, max]
218
- function clamp(x, min = 0, max = 1) {
219
- return Math.max(Math.min(x, max), min);
220
- }
221
-
222
- function getPoint(e) {
223
- // Get bounding box
224
- const bb = imageContainer.getBoundingClientRect();
225
-
226
- // Get the mouse coordinates relative to the container
227
- const mouseX = clamp((e.clientX - bb.left) / bb.width);
228
- const mouseY = clamp((e.clientY - bb.top) / bb.height);
229
-
230
- return {
231
- position: [mouseX, mouseY],
232
- label:
233
- e.button === 2 // right click
234
- ? 0 // negative prompt
235
- : 1, // positive prompt
236
- };
237
- }
238
-
239
- // Do not show context menu on right click
240
- imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());
241
-
242
- // Attach hover event to image container
243
- imageContainer.addEventListener("mousemove", (e) => {
244
- if (!imageEmbeddings || isMultiMaskMode) {
245
- // Ignore mousemove events if the image is not encoded yet,
246
- // or we are in multi-mask mode
247
- return;
248
- }
249
- lastPoints = [getPoint(e)];
250
 
251
- decode();
 
 
 
 
252
  });
253
 
254
- // Handle cut button click
255
  cutButton.addEventListener("click", async () => {
256
  const [w, h] = [maskCanvas.width, maskCanvas.height];
257
-
258
- // Get the mask pixel data (and use this as a buffer)
259
  const maskImageData = maskContext.getImageData(0, 0, w, h);
260
 
261
- // Create a new canvas to hold the cut-out
262
  const cutCanvas = new OffscreenCanvas(w, h);
263
  const cutContext = cutCanvas.getContext("2d");
264
 
265
- // Copy the image pixel data to the cut canvas
266
  const maskPixelData = maskImageData.data;
267
  const imagePixelData = imageInput.data;
268
  for (let i = 0; i < w * h; ++i) {
269
- const sourceOffset = 3 * i; // RGB
270
- const targetOffset = 4 * i; // RGBA
271
-
272
  if (maskPixelData[targetOffset + 3] > 0) {
273
- // Only copy opaque pixels
274
  for (let j = 0; j < 3; ++j) {
275
  maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + j];
276
  }
@@ -278,7 +246,6 @@ cutButton.addEventListener("click", async () => {
278
  }
279
  cutContext.putImageData(maskImageData, 0, 0);
280
 
281
- // Download image
282
  const link = document.createElement("a");
283
  link.download = "image.png";
284
  link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
@@ -286,16 +253,17 @@ cutButton.addEventListener("click", async () => {
286
  link.remove();
287
  });
288
 
 
289
  const model_id = "Xenova/slimsam-77-uniform";
290
  statusLabel.textContent = "Loading model...";
291
- const model = await SamModel.from_pretrained(model_id, {
292
- dtype: "fp16", // or "fp32"
293
  device: "webgpu",
294
  });
295
- const processor = await AutoProcessor.from_pretrained(model_id);
296
  statusLabel.textContent = "Ready";
297
 
298
- // Enable the user interface
299
  fileUpload.disabled = false;
300
  uploadButton.style.opacity = 1;
301
  example.style.pointerEvents = "auto";
 
14
  const resetButton = document.getElementById("reset-image");
15
  const clearButton = document.getElementById("clear-points");
16
  const cutButton = document.getElementById("cut-mask");
 
 
17
  const maskCanvas = document.getElementById("mask-output");
18
  const maskContext = maskCanvas.getContext("2d");
19
 
 
22
 
23
  // State variables
24
  let isEncoding = false;
 
 
 
 
25
  let imageInput = null;
26
  let imageProcessed = null;
27
  let imageEmbeddings = null;
28
+ let model = null;
29
+ let processor = null;
30
 
31
/**
 * Load an image, compute its embeddings with the loaded model, and then run
 * automatic segmentation on it.
 *
 * Calls made while another encode is in progress are ignored. The
 * `isEncoding` guard is always released, even when loading or embedding
 * fails, so a single bad image cannot lock the UI until a page reload.
 *
 * @param {string} url - URL (or data URL) of the image to load.
 * @returns {Promise<void>} Resolves once encoding and segmentation finished,
 *   or once a loading/embedding failure has been reported in the status label.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    imageInput = await RawImage.fromURL(url);

    // Update UI
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;

    // Recompute image embeddings
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);

    statusLabel.textContent = "Embedding extracted!";
  } catch (err) {
    // Every caller is a fire-and-forget event handler, so report the failure
    // here instead of leaving an unhandled rejection behind.
    console.error("Failed to extract image embedding:", err);
    statusLabel.textContent = "Failed to extract image embedding";
    return;
  } finally {
    // Release the guard on success AND failure.
    isEncoding = false;
  }

  // Run automatic segmentation right away.
  await autoSegment();
}
53
+
54
/**
 * Segment the current image automatically by prompting the model with a
 * regular grid of positive points, then draw every sufficiently large mask.
 *
 * For each grid point the model returns several candidate masks with one
 * score each; only the best-scoring candidate is considered. A mask is kept
 * when that candidate covers more than 1% of the image.
 *
 * Reads the module state `imageInput`, `imageProcessed` and `imageEmbeddings`;
 * writes the status label and calls `updateMaskOverlay` with the kept masks
 * and their score arrays.
 *
 * @returns {Promise<void>}
 */
async function autoSegment() {
  if (!imageEmbeddings) {
    statusLabel.textContent = "No image embeddings available!";
    return;
  }

  statusLabel.textContent = "Generating automatic segments...";

  // Snapshot the state: the reset handler (or a new encode) can replace these
  // module-level variables while the decoding loop below is awaiting.
  const image = imageInput;
  const processed = imageProcessed;
  const embeddings = imageEmbeddings;

  // Build grid-based prompt points, scaled from image pixels to the
  // processor's reshaped input size ([height, width]).
  const gridSize = 50; // grid spacing in image pixels
  const reshaped = processed.reshaped_input_sizes[0];
  const points = [];
  for (let y = gridSize / 2; y < image.height; y += gridSize) {
    for (let x = gridSize / 2; x < image.width; x += gridSize) {
      points.push([
        (x / image.width) * reshaped[1],
        (y / image.height) * reshaped[0],
      ]);
    }
  }

  // Masks covering 1% of the image or less are discarded.
  const minPixels = (image.width * image.height) / 100;
  const filteredMasks = [];
  const filteredScores = [];

  // Decode one mask set per grid point.
  for (const point of points) {
    const input_points = new Tensor("float32", point, [1, 1, 1, 2]);
    const input_labels = new Tensor("int64", [1n], [1, 1, 1]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    const processedMasks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // The image was reset or replaced while we were decoding: drop this run.
    if (imageEmbeddings !== embeddings) return;

    // updateMaskOverlay indexes mask data as `numMasks * pixel + candidate`
    // (interleaved), so convert the tensor to a RawImage as the previous
    // version of this file did before drawing.
    const mask = RawImage.fromTensor(processedMasks[0][0]);
    const candidateScores = iou_scores.data;
    const numCandidates = candidateScores.length;
    if (numCandidates === 0) continue;

    // Pick the best-scoring candidate.
    let bestIndex = 0;
    for (let i = 1; i < numCandidates; ++i) {
      if (candidateScores[i] > candidateScores[bestIndex]) {
        bestIndex = i;
      }
    }

    // Count only the pixels of the candidate that will actually be drawn.
    let pixelCount = 0;
    for (let j = bestIndex; j < mask.data.length; j += numCandidates) {
      if (mask.data[j] === 1) pixelCount++;
    }

    if (pixelCount > minPixels) {
      filteredMasks.push(mask);
      filteredScores.push(candidateScores);
    }
  }

  // Draw the masks and their labels.
  updateMaskOverlay(filteredMasks, filteredScores);

  statusLabel.textContent = `Found ${filteredMasks.length} objects`;
}
 
 
 
 
 
 
 
 
 
122
 
123
/**
 * Draw every mask on the mask canvas in a random semi-transparent colour and
 * label each one with "Object N" at its centroid.
 *
 * All masks are painted into a single pixel buffer that is written to the
 * canvas once. `putImageData` replaces canvas pixels instead of blending, so
 * writing one full-size buffer per mask would erase every earlier mask and
 * label. Where masks overlap, the later mask wins.
 *
 * @param {Array<{data: ArrayLike<number>}>} masks - Masks whose `data` holds
 *   the candidate masks interleaved per pixel
 *   (`data[numCandidates * pixel + candidate]`, value 1 = inside the mask).
 * @param {Array<ArrayLike<number>>} scores - For each mask, one score per
 *   candidate; the highest-scoring candidate is the one that is drawn.
 * @returns {void}
 */
function updateMaskOverlay(masks, scores) {
  const { width, height } = imageInput;

  // Update canvas dimensions (if different)
  if (maskCanvas.width !== width || maskCanvas.height !== height) {
    maskCanvas.width = width;
    maskCanvas.height = height;
  }

  // Clear the canvas first
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);

  if (masks.length > 0) {
    const pixelCount = width * height;
    const imageData = maskContext.createImageData(width, height);
    const pixelData = imageData.data;
    const labels = [];

    masks.forEach((mask, m) => {
      // Select the best-scoring candidate for this mask
      const candidateScores = scores[m];
      const numMasks = candidateScores.length;
      let bestIndex = 0;
      for (let i = 1; i < numMasks; ++i) {
        if (candidateScores[i] > candidateScores[bestIndex]) {
          bestIndex = i;
        }
      }

      // Colour the mask and accumulate its centroid in the same pass
      const r = Math.random() * 255;
      const g = Math.random() * 255;
      const b = Math.random() * 255;
      let sumX = 0;
      let sumY = 0;
      let covered = 0;
      for (let i = 0; i < pixelCount; ++i) {
        if (mask.data[numMasks * i + bestIndex] !== 1) continue;

        const offset = 4 * i;
        pixelData[offset] = r;
        pixelData[offset + 1] = g;
        pixelData[offset + 2] = b;
        pixelData[offset + 3] = 128; // semi-transparent

        sumX += i % width;
        sumY += Math.floor(i / width);
        covered += 1;
      }

      if (covered > 0) {
        labels.push({
          text: `Object ${m + 1}`,
          x: Math.floor(sumX / covered),
          y: Math.floor(sumY / covered),
        });
      }
    });

    // Write all masks at once, then draw the labels on top so they survive.
    maskContext.putImageData(imageData, 0, 0);

    maskContext.fillStyle = "white";
    maskContext.font = "16px Arial";
    maskContext.strokeStyle = "black";
    maskContext.lineWidth = 2;
    for (const { text, x, y } of labels) {
      maskContext.strokeText(text, x, y);
      maskContext.fillText(text, x, y);
    }
  }

  // Enable the cut button
  cutButton.disabled = false;
}
200
 
201
+ // Keep the existing event listeners, but remove the click events
202
  fileUpload.addEventListener("change", function (e) {
203
  const file = e.target.files[0];
204
  if (!file) return;
205
 
206
  const reader = new FileReader();
 
 
207
  reader.onload = (e2) => encode(e2.target.result);
 
208
  reader.readAsDataURL(file);
209
  });
210
 
 
213
  encode(EXAMPLE_URL);
214
  });
215
 
216
+ resetButton.addEventListener("click", () => {
217
+ imageInput = null;
218
+ imageProcessed = null;
219
+ imageEmbeddings = null;
220
+ isEncoding = false;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
+ maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
223
+ cutButton.disabled = true;
224
+ imageContainer.style.backgroundImage = "none";
225
+ uploadButton.style.display = "flex";
226
+ statusLabel.textContent = "Ready";
227
  });
228
 
 
229
  cutButton.addEventListener("click", async () => {
230
  const [w, h] = [maskCanvas.width, maskCanvas.height];
 
 
231
  const maskImageData = maskContext.getImageData(0, 0, w, h);
232
 
 
233
  const cutCanvas = new OffscreenCanvas(w, h);
234
  const cutContext = cutCanvas.getContext("2d");
235
 
 
236
  const maskPixelData = maskImageData.data;
237
  const imagePixelData = imageInput.data;
238
  for (let i = 0; i < w * h; ++i) {
239
+ const sourceOffset = 3 * i;
240
+ const targetOffset = 4 * i;
 
241
  if (maskPixelData[targetOffset + 3] > 0) {
 
242
  for (let j = 0; j < 3; ++j) {
243
  maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + j];
244
  }
 
246
  }
247
  cutContext.putImageData(maskImageData, 0, 0);
248
 
 
249
  const link = document.createElement("a");
250
  link.download = "image.png";
251
  link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
 
253
  link.remove();
254
  });
255
 
256
+ // Load the model
257
  const model_id = "Xenova/slimsam-77-uniform";
258
  statusLabel.textContent = "Loading model...";
259
+ model = await SamModel.from_pretrained(model_id, {
260
+ dtype: "fp16",
261
  device: "webgpu",
262
  });
263
+ processor = await AutoProcessor.from_pretrained(model_id);
264
  statusLabel.textContent = "Ready";
265
 
266
+ // Enable the UI
267
  fileUpload.disabled = false;
268
  uploadButton.style.opacity = 1;
269
  example.style.pointerEvents = "auto";