Spaces:
Runtime error
Runtime error
File size: 4,166 Bytes
f8d8ee5 fc869bc f8d8ee5 42505ce f8d8ee5 d622dcf dad1da5 55e9f56 f8d8ee5 089df3a f8d8ee5 089df3a f8d8ee5 f28db37 f8d8ee5 ff8ae44 f8d8ee5 ff8ae44 f8d8ee5 55e9f56 ff8ae44 f8d8ee5 9469756 f8d8ee5 bfb1ec4 f8d8ee5 4645982 f8d8ee5 55e9f56 c040c30 b128f7c f3929ed b128f7c 55e9f56 f8d8ee5 f28db37 f8d8ee5 7a1f182 e3469be f8d8ee5 7a1f182 f8d8ee5 e3469be 55e9f56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
import random
import numpy as np
import pythreejs as p3js
from skimage.measure import find_contours
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
from enum import Enum
import json
from flask import Flask
from flask import request
import base64
import io
import os
# Directory where sentence-transformers / transformers should cache downloaded weights.
CACHE_DIR = '/code/.cache'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = CACHE_DIR
# The key previously had a trailing space ('TRANSFORMERS_CACHE '), so it was never honoured.
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
device = torch.device("cpu")
# Hugging Face model id shared by the model and its preprocessor.
MODEL_NAME = "facebook/maskformer-swin-large-ade"
# transformers is imported above this point, so the environment variable may be read too
# late to take effect; pass cache_dir explicitly so the intended cache directory is used.
model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR).to(device)
model.eval()  # inference only
preprocessor = MaskFormerFeatureExtractor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
class LABEL_TYPE(Enum):
    """Class ids looked up in the per-pixel segmentation mask.

    NOTE(review): the ids are presumably indices into the MaskFormer ADE20K
    model's ``id2label`` mapping — confirm against the model config.
    Member order is significant for iteration and is kept as-is.
    """

    WINDOW = 8  # presumably ADE20K "windowpane" — TODO confirm
    WALL = 0  # presumably ADE20K "wall" — TODO confirm
    FLOOR = 3  # presumably ADE20K "floor" — TODO confirm
def query_image(img):
    """Segment *img* with the MaskFormer model and return per-pixel class ids.

    Args:
        img: image as a NumPy array indexed ``[row, col, ...]``; assumed to be
            an (H, W, 3) RGB array — TODO confirm for other channel layouts.

    Returns:
        NumPy array holding, for each pixel, the argmax over the class
        dimension of the post-processed segmentation scores. The scores are
        resized to ``(img.shape[0], img.shape[1])`` by the preprocessor.
    """
    target_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=img, return_tensors="pt")
    # Put the inputs on the same device as the model so this keeps working
    # if `device` is ever switched away from the CPU.
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        # Post-processing runs on CPU tensors.
        outputs.class_queries_logits = outputs.class_queries_logits.cpu()
        outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
        # [0]: the batch contains a single image.
        scores = preprocessor.post_process_segmentation(outputs=outputs, target_size=target_size)[0]
    return torch.argmax(scores, dim=0).numpy()
def find_boundary(label_value, mask):
    """Trace the outlines of every region of *mask* whose value equals *label_value*.

    The mask is reduced to a boolean image and handed to
    ``skimage.measure.find_contours`` at iso-level 0.5 with
    ``fully_connected="high"``. Per the scikit-image docs, the result is a
    list of (N, 2) arrays of (row, column) coordinates.
    """
    region = mask == label_value
    return find_contours(region, level=0.5, fully_connected="high")
def send_response(mask):
    """Flatten a list of contour arrays into one list of coordinate strings.

    Args:
        mask: iterable of NumPy arrays (despite the name, a list of contours,
            not a segmentation mask).

    Returns:
        List of every element of every contour, converted with
        ``astype(str)`` and flattened in row-major order, contour after contour.
    """
    return [value for contour in mask for value in contour.astype(str).ravel()]
def extract_window_edges(window_contours):
    """Return the axis-aligned bounding-box corners of each contour as strings.

    Column 0 of each contour is treated as x and column 1 as y. Each entry of
    the result is ``[[min_x, min_y], [min_x, max_y], [max_x, min_y],
    [max_x, max_y]]``, with every coordinate converted with ``str``. Note
    that this corner order is not a closed-polygon order.
    """
    def _corners(contour):
        # Column-wise extrema: lo = (min_x, min_y), hi = (max_x, max_y).
        lo = contour.min(axis=0)
        hi = contour.max(axis=0)
        left, top = str(lo[0]), str(lo[1])
        right, bottom = str(hi[0]), str(hi[1])
        return [
            [left, top],
            [left, bottom],
            [right, top],
            [right, bottom],
        ]

    return [_corners(contour) for contour in window_contours]
# Flask application that the routes below are registered on.
app = Flask(import_name=__name__)
def process(image, items):
    """Run the zero-shot ``detector`` on *image* with *items* as candidate labels.

    NOTE(review): ``detector`` is not defined anywhere in the visible source,
    and this function is not registered on any route. It looks like a leftover
    from an earlier revision, so confirm whether it is still needed.

    Args:
        image: image passed straight to ``detector``. It was previously
            ignored in favour of a hard-coded sample URL.
        items: candidate label strings.

    Returns:
        Whatever ``detector`` returns.

    Raises:
        RuntimeError: if no module-level ``detector`` is defined.
    """
    try:
        zero_shot_detector = detector  # noqa: F821 - expected at module level
    except NameError as err:
        raise RuntimeError(
            "process() requires a module-level `detector`, which is not defined"
        ) from err
    return zero_shot_detector(image, candidate_labels=items)
@app.route("/", methods=["OPTIONS"])
def cors():
    """Answer CORS preflight (OPTIONS) requests on ``/``.

    Returns a plain ``"ok"`` response that sets
    Access-Control-Allow-Origin, -Headers and -Methods to ``*``.
    """
    preflight = app.make_response("ok")
    allow_all = (
        "Access-Control-Allow-Origin",
        "Access-Control-Allow-Headers",
        "Access-Control-Allow-Methods",
    )
    for header_name in allow_all:
        preflight.headers.add(header_name, "*")
    return preflight
def _cors_json_response(payload, status=200):
    """Serialize *payload* to a JSON response carrying this service's permissive CORS headers."""
    response = app.make_response((json.dumps(payload), status))
    response.headers.add("Access-Control-Allow-Origin", "*")
    response.headers.add("Access-Control-Allow-Headers", "*")
    response.headers.add("Access-Control-Allow-Methods", "*")
    response.content_type = "application/json"
    return response


def _decode_image(base64_str):
    """Decode bare base64 or a ``data:...;base64,`` URL into an RGB PIL image."""
    if base64_str.startswith("data:"):
        # Drop the "data:image/...;base64," prefix; keep only the payload.
        base64_str = base64_str.partition(",")[2]
    img = Image.open(io.BytesIO(base64.decodebytes(bytes(base64_str, "utf-8"))))
    # RGBA, palette, grayscale or CMYK input would otherwise produce an array
    # with the wrong number of channels for the model's preprocessor.
    return img.convert("RGB")


@app.route("/", methods=["POST"])
def detect():
    """Segment an uploaded image and return window/floor outlines as JSON.

    Expects a JSON body ``{"img": <base64 image>}``. The value may be bare
    base64 or a ``data:`` URL.

    The response is JSON with these keys:
      * ``window_contours`` / ``floor_contours``: flattened lists of contour
        coordinate strings (x, y pairs after the column flip below).
      * ``windows_vertices``: bounding-box corners of each window contour.

    Returns a JSON ``{"error": ...}`` with status 400 when the body or the
    image cannot be decoded.
    """
    body = request.get_json(silent=True)
    base64_str = body.get("img") if isinstance(body, dict) else None
    if not isinstance(base64_str, str):
        return _cors_json_response(
            {"error": "request body must be JSON with a base64 string field 'img'"}, 400
        )
    try:
        img = _decode_image(base64_str)
    except (ValueError, OSError):
        # binascii.Error (bad base64) is a ValueError;
        # PIL.UnidentifiedImageError and truncated-read errors are OSErrors.
        return _cors_json_response(
            {"error": "'img' is not a decodable base64-encoded image"}, 400
        )

    mask = query_image(np.asarray(img))
    # find_contours yields (row, col) points; flip them to (x, y) integer points.
    window_contours = [
        np.fliplr(ctr).astype(np.int32)
        for ctr in find_boundary(LABEL_TYPE.WINDOW.value, mask)
    ]
    floor_contours = [
        np.fliplr(ctr).astype(np.int32)
        for ctr in find_boundary(LABEL_TYPE.FLOOR.value, mask)
    ]
    app.logger.debug(
        "detect: %d window contours, %d floor contours",
        len(window_contours),
        len(floor_contours),
    )
    return _cors_json_response({
        "window_contours": send_response(window_contours),
        "floor_contours": send_response(floor_contours),
        "windows_vertices": extract_window_edges(window_contours),
    })
if __name__ == "__main__":
    # debug=True turns on Werkzeug's interactive debugger, which is dangerous
    # when the server is reachable from outside. It also turns on the
    # reloader, which re-runs this module and loads the model a second time.
    # Debug mode is therefore opt-in via FLASK_DEBUG=1.
    # Flask binds to 127.0.0.1 by default, which is unreachable from outside a
    # container, so default to all interfaces here.
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "7860")),
        debug=os.environ.get("FLASK_DEBUG") == "1",
    )