sanghan commited on
Commit
51dd778
·
1 Parent(s): 013bae8

load model in function call

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -40,6 +40,11 @@ def inference(video):
40
  if get_video_dimensions(video) > (1920, 1920):
41
  raise gr.Error("Video resolution must not be higher than 1920x1080")
42
 
 
 
 
 
 
43
  convert_video(
44
  model, # The loaded model, can be on any device (cpu or cuda).
45
  input_source=video, # A video file or an image sequence directory.
@@ -57,20 +62,15 @@ def inference(video):
57
 
58
 
59
  if __name__ == "__main__":
60
- model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
61
-
62
  if torch.cuda.is_available():
63
  free_memory = get_free_memory_gb()
64
  concurrency_count = int(free_memory // 7)
65
- model = model.cuda()
66
  print(f"Using GPU with concurrency: {concurrency_count}")
67
  print(f"Available video memory: {free_memory} GB")
68
  else:
69
  print("Using CPU")
70
  concurrency_count = 1
71
 
72
- convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
73
-
74
  with gr.Blocks(title="Robust Video Matting") as block:
75
  gr.Markdown("# Robust Video Matting")
76
  gr.Markdown(
 
40
  if get_video_dimensions(video) > (1920, 1920):
41
  raise gr.Error("Video resolution must not be higher than 1920x1080")
42
 
43
+ model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
44
+ if torch.cuda.is_available():
45
+ model = model.cuda()
46
+
47
+ convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
48
  convert_video(
49
  model, # The loaded model, can be on any device (cpu or cuda).
50
  input_source=video, # A video file or an image sequence directory.
 
62
 
63
 
64
  if __name__ == "__main__":
 
 
65
  if torch.cuda.is_available():
66
  free_memory = get_free_memory_gb()
67
  concurrency_count = int(free_memory // 7)
 
68
  print(f"Using GPU with concurrency: {concurrency_count}")
69
  print(f"Available video memory: {free_memory} GB")
70
  else:
71
  print("Using CPU")
72
  concurrency_count = 1
73
 
 
 
74
  with gr.Blocks(title="Robust Video Matting") as block:
75
  gr.Markdown("# Robust Video Matting")
76
  gr.Markdown(