freddyaboulton HF staff commited on
Commit
534fffb
·
verified ·
1 Parent(s): 023470c

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +28 -10
  2. index.html +45 -1
app.py CHANGED
@@ -3,8 +3,9 @@ from pathlib import Path
3
 
4
  import cv2
5
  import gradio as gr
 
6
  from fastapi.responses import HTMLResponse
7
- from fastrtc import Stream, get_twilio_turn_credentials
8
  from gradio.utils import get_space
9
  from huggingface_hub import hf_hub_download
10
  from pydantic import BaseModel, Field
@@ -12,7 +13,7 @@ from pydantic import BaseModel, Field
12
  try:
13
  from demo.object_detection.inference import YOLOv10
14
  except (ImportError, ModuleNotFoundError):
15
- from inference import YOLOv10
16
 
17
 
18
  cur_dir = Path(__file__).parent
@@ -25,10 +26,16 @@ model = YOLOv10(model_file)
25
 
26
 
27
  def detection(image, conf_threshold=0.3):
28
- image = cv2.resize(image, (model.input_width, model.input_height))
29
- print("conf_threshold", conf_threshold)
30
- new_image = model.detect_objects(image, conf_threshold)
31
- return cv2.resize(new_image, (500, 500))
 
 
 
 
 
 
32
 
33
 
34
  stream = Stream(
@@ -40,8 +47,12 @@ stream = Stream(
40
  concurrency_limit=20 if get_space() else None,
41
  )
42
 
 
 
 
43
 
44
- @stream.get("/")
 
45
  async def _():
46
  rtc_config = get_twilio_turn_credentials() if get_space() else None
47
  html_content = open(cur_dir / "index.html").read()
@@ -54,12 +65,19 @@ class InputData(BaseModel):
54
  conf_threshold: float = Field(ge=0, le=1)
55
 
56
 
57
- @stream.post("/input_hook")
58
  async def _(data: InputData):
59
  stream.set_input(data.webrtc_id, data.conf_threshold)
60
 
61
 
62
  if __name__ == "__main__":
63
- import uvicorn
 
 
 
 
 
 
 
64
 
65
- uvicorn.run(stream, host="0.0.0.0", port=7860)
 
3
 
4
  import cv2
5
  import gradio as gr
6
+ from fastapi import FastAPI
7
  from fastapi.responses import HTMLResponse
8
+ from fastrtc import Stream, WebRTCError, get_twilio_turn_credentials
9
  from gradio.utils import get_space
10
  from huggingface_hub import hf_hub_download
11
  from pydantic import BaseModel, Field
 
13
  try:
14
  from demo.object_detection.inference import YOLOv10
15
  except (ImportError, ModuleNotFoundError):
16
+ from .inference import YOLOv10
17
 
18
 
19
  cur_dir = Path(__file__).parent
 
26
 
27
 
28
  def detection(image, conf_threshold=0.3):
29
+ try:
30
+ image = cv2.resize(image, (model.input_width, model.input_height))
31
+ print("conf_threshold", conf_threshold)
32
+ new_image = model.detect_objects(image, conf_threshold)
33
+ return cv2.resize(new_image, (500, 500))
34
+ except Exception as e:
35
+ import traceback
36
+
37
+ traceback.print_exc()
38
+ raise WebRTCError(str(e))
39
 
40
 
41
  stream = Stream(
 
47
  concurrency_limit=20 if get_space() else None,
48
  )
49
 
50
+ app = FastAPI()
51
+
52
+ stream.mount(app)
53
 
54
+
55
+ @app.get("/")
56
  async def _():
57
  rtc_config = get_twilio_turn_credentials() if get_space() else None
58
  html_content = open(cur_dir / "index.html").read()
 
65
  conf_threshold: float = Field(ge=0, le=1)
66
 
67
 
68
+ @app.post("/input_hook")
69
  async def _(data: InputData):
70
  stream.set_input(data.webrtc_id, data.conf_threshold)
71
 
72
 
73
  if __name__ == "__main__":
74
+ import os
75
+
76
+ if (mode := os.getenv("MODE")) == "UI":
77
+ stream.ui.launch(server_port=7860, server_name="0.0.0.0")
78
+ elif mode == "PHONE":
79
+ stream.fastphone(host="0.0.0.0", port=7860)
80
+ else:
81
+ import uvicorn
82
 
83
+ uvicorn.run(app, host="0.0.0.0", port=7860)
index.html CHANGED
@@ -112,10 +112,28 @@
112
  border-radius: 50%;
113
  cursor: pointer;
114
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  </style>
116
  </head>
117
 
118
  <body>
 
 
119
  <div class="container">
120
  <h1>Real-time Object Detection</h1>
121
  <p>Using YOLOv10 to detect objects in your webcam feed</p>
@@ -160,6 +178,17 @@
160
  });
161
  }
162
 
 
 
 
 
 
 
 
 
 
 
 
163
  async function setupWebRTC() {
164
  const config = __RTC_CONFIGURATION__;
165
  peerConnection = new RTCPeerConnection(config);
@@ -182,7 +211,9 @@
182
  const dataChannel = peerConnection.createDataChannel('text');
183
  dataChannel.onmessage = (event) => {
184
  const eventJson = JSON.parse(event.data);
185
- if (eventJson.type === "send_input") {
 
 
186
  updateConfThreshold(confThreshold.value);
187
  }
188
  };
@@ -217,6 +248,16 @@
217
  });
218
 
219
  const serverResponse = await response.json();
 
 
 
 
 
 
 
 
 
 
220
  await peerConnection.setRemoteDescription(serverResponse);
221
 
222
  // Send initial confidence threshold
@@ -224,6 +265,9 @@
224
 
225
  } catch (err) {
226
  console.error('Error setting up WebRTC:', err);
 
 
 
227
  }
228
  }
229
 
 
112
  border-radius: 50%;
113
  cursor: pointer;
114
  }
115
+
116
+ /* Add styles for toast notifications */
117
+ .toast {
118
+ position: fixed;
119
+ top: 20px;
120
+ left: 50%;
121
+ transform: translateX(-50%);
122
+ background-color: #f44336;
123
+ color: white;
124
+ padding: 16px 24px;
125
+ border-radius: 4px;
126
+ font-size: 14px;
127
+ z-index: 1000;
128
+ display: none;
129
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
130
+ }
131
  </style>
132
  </head>
133
 
134
  <body>
135
+ <!-- Add toast element after body opening tag -->
136
+ <div id="error-toast" class="toast"></div>
137
  <div class="container">
138
  <h1>Real-time Object Detection</h1>
139
  <p>Using YOLOv10 to detect objects in your webcam feed</p>
 
178
  });
179
  }
180
 
181
+ function showError(message) {
182
+ const toast = document.getElementById('error-toast');
183
+ toast.textContent = message;
184
+ toast.style.display = 'block';
185
+
186
+ // Hide toast after 5 seconds
187
+ setTimeout(() => {
188
+ toast.style.display = 'none';
189
+ }, 5000);
190
+ }
191
+
192
  async function setupWebRTC() {
193
  const config = __RTC_CONFIGURATION__;
194
  peerConnection = new RTCPeerConnection(config);
 
211
  const dataChannel = peerConnection.createDataChannel('text');
212
  dataChannel.onmessage = (event) => {
213
  const eventJson = JSON.parse(event.data);
214
+ if (eventJson.type === "error") {
215
+ showError(eventJson.message);
216
+ } else if (eventJson.type === "send_input") {
217
  updateConfThreshold(confThreshold.value);
218
  }
219
  };
 
248
  });
249
 
250
  const serverResponse = await response.json();
251
+
252
+ if (serverResponse.status === 'failed') {
253
+ showError(serverResponse.meta.error === 'concurrency_limit_reached'
254
+ ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
255
+ : serverResponse.meta.error);
256
+ stop();
257
+ startButton.textContent = 'Start';
258
+ return;
259
+ }
260
+
261
  await peerConnection.setRemoteDescription(serverResponse);
262
 
263
  // Send initial confidence threshold
 
265
 
266
  } catch (err) {
267
  console.error('Error setting up WebRTC:', err);
268
+ showError('Failed to establish connection. Please try again.');
269
+ stop();
270
+ startButton.textContent = 'Start';
271
  }
272
  }
273