Francesco Capuano commited on
Commit
546bd37
·
1 Parent(s): b976944

fix: add a running attribute instead of sleeping for a day :')

Browse files
lerobot/scripts/server/policy_server.py CHANGED
@@ -62,6 +62,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
62
 
63
  self.actions_per_chunk = 20
64
  self.actions_overlap = 10
 
65
 
66
  def _setup_server(self) -> None:
67
  """Flushes server state when new client connects."""
@@ -325,27 +326,31 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
325
 
326
  return action
327
 
 
 
 
 
328
 
329
- def serve():
330
- import gradio as gr
331
-
332
- def greet(name):
333
- return "Hello " + name + "!"
334
 
335
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
336
- demo.launch()
 
337
 
 
338
  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
339
- async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server)
340
  server.add_insecure_port("[::]:50051")
341
  server.start()
342
  logger.info("PolicyServer started on port 50051")
343
-
344
-
345
  try:
346
- while True:
347
- time.sleep(86400) # Sleep for a day, or until interrupted
 
348
  except KeyboardInterrupt:
 
 
 
349
  server.stop(0)
350
  logger.info("Server stopped")
351
 
 
62
 
63
  self.actions_per_chunk = 20
64
  self.actions_overlap = 10
65
+ self.running = True # Add a running flag to control server lifetime
66
 
67
  def _setup_server(self) -> None:
68
  """Flushes server state when new client connects."""
 
326
 
327
  return action
328
 
329
+ def stop(self):
330
+ """Stop the server"""
331
+ self.running = False
332
+ logger.info("Server stopping...")
333
 
 
 
 
 
 
334
 
335
+ def serve():
336
+ # Create the server instance first
337
+ policy_server = PolicyServer()
338
 
339
+ # Setup and start gRPC server
340
  server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
341
+ async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
342
  server.add_insecure_port("[::]:50051")
343
  server.start()
344
  logger.info("PolicyServer started on port 50051")
345
+
 
346
  try:
347
+ # Use the running attribute to control server lifetime
348
+ while policy_server.running:
349
+ time.sleep(1) # Check every second instead of sleeping indefinitely
350
  except KeyboardInterrupt:
351
+ policy_server.stop()
352
+ logger.info("Keyboard interrupt received")
353
+ finally:
354
  server.stop(0)
355
  logger.info("Server stopped")
356