File size: 6,688 Bytes
9bb0e80
 
 
 
 
 
 
9c6c16a
9bb0e80
 
43e677e
ed46d76
c01db76
9bb0e80
 
 
8221cd6
 
ed46d76
 
 
 
 
 
 
43e677e
 
 
 
 
 
ed46d76
 
 
 
 
 
 
 
43e677e
 
 
 
 
9bb0e80
 
 
 
8221cd6
ed46d76
 
 
 
 
 
 
 
c01db76
 
 
 
 
 
 
 
 
 
 
231fef0
 
ed46d76
 
c01db76
ed46d76
 
 
8221cd6
 
 
9bb0e80
 
8221cd6
9bb0e80
 
 
43e677e
 
9bb0e80
 
 
8221cd6
 
 
 
 
 
9bb0e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8221cd6
231fef0
 
 
8221cd6
43e677e
 
 
8221cd6
231fef0
 
 
20150ec
 
231fef0
 
 
 
 
 
 
 
1ae0e63
 
 
20150ec
 
 
 
 
 
 
 
231fef0
ed46d76
9bb0e80
ed46d76
43e677e
 
 
 
 
9bb0e80
43e677e
 
 
 
 
 
 
9bb0e80
 
 
9c6c16a
9bb0e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c6c16a
 
9bb0e80
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from __future__ import annotations

import os
import time
import sys
import PIL.Image
import numpy as np
import gradio as gr
import spaces
import cuid
import importlib.util
import glob
import shutil

from huggingface_hub import snapshot_download

# Logged before any of the setup below so the console shows the process started.
print("Starting application...")

def find_file(pattern, search_path="."):
    """Return the first path in *search_path* whose name matches *pattern*.

    ``pattern`` is a glob pattern (e.g. ``"*.py"``) evaluated relative to
    ``search_path``. ``search_path`` itself is taken literally, so directory
    names containing glob metacharacters such as ``[`` still work. Matches
    (files or directories, as returned by ``glob``) are sorted so the result
    does not depend on filesystem listing order.

    Returns the first matching path, or ``None`` when nothing matches.
    """
    # Escape only the directory part; the pattern part must stay a glob.
    matches = sorted(glob.glob(os.path.join(glob.escape(search_path), pattern)))
    return matches[0] if matches else None

def import_from_path(module_name, file_path):
    """Load *file_path* as a Python module named *module_name* and return it.

    The first five lines of the file are echoed to stdout as a debugging aid.
    The module is registered in ``sys.modules`` before its code runs (the
    recipe from the ``importlib`` documentation), so code that looks the
    module up by name, such as pickling or dataclass machinery, works. The
    entry is removed again if executing the module fails.

    Raises:
        ImportError: if *file_path* does not exist or cannot be loaded as a
            module by ``importlib`` (for example a non-``.py`` file).
        Exception: whatever the module's own top-level code raises.
    """
    if not os.path.exists(file_path):
        raise ImportError(f"File not found: {file_path}")

    print(f"Importing {module_name} from {file_path}")
    # The preview is diagnostic only; undecodable bytes must not abort the import.
    with open(file_path, "r", encoding="utf-8", errors="replace") as f:
        print("File contents (first 5 lines):")
        # zip stops on range exhaustion first, so at most five lines are read.
        for _, line in zip(range(5), f):
            print(line.strip())

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec is None when no loader handles the file's suffix.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {file_path} as a module")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: a module that failed to execute is not left registered.
        sys.modules.pop(module_name, None)
        raise
    return module

# Set up paths
ProjectDir = os.path.dirname(os.path.abspath(__file__))
CheckpointsDir = os.path.join(ProjectDir, "checkpoints")

print(f"Project directory: {ProjectDir}")
print(f"Directory contents:")
print(os.listdir(ProjectDir))

# Find the MuseV checkout. The name is matched case-insensitively; only
# directories count, and an exact "MuseV" is preferred so an already
# correctly-named checkout wins over any lower-cased duplicate.
musev_candidates = sorted(
    d
    for d in os.listdir(ProjectDir)
    if d.lower() == "musev" and os.path.isdir(os.path.join(ProjectDir, d))
)
if not musev_candidates:
    print("Warning: Could not find MuseV directory")
    sys.exit(1)

_musev_name = "MuseV" if "MuseV" in musev_candidates else musev_candidates[0]
MuseVDir = os.path.join(ProjectDir, _musev_name)
print(f"Found MuseV directory: {MuseVDir}")

# Normalize a lower-case "musev" to "MuseV". This branch only runs when no
# "MuseV" directory was listed, so the final move cannot nest the checkout
# inside an existing directory. The detour through a temporary name lets the
# rename also work on case-insensitive filesystems.
if _musev_name == "musev":
    _target_dir = os.path.join(ProjectDir, "MuseV")
    _temp_dir = os.path.join(ProjectDir, "MuseV_temp")
    if os.path.exists(_temp_dir):
        # shutil.move into an existing directory would nest rather than rename.
        print(f"Warning: Could not rename directory: {_temp_dir} already exists")
    else:
        try:
            shutil.move(MuseVDir, _temp_dir)
            shutil.move(_temp_dir, _target_dir)
            MuseVDir = _target_dir
            print("Successfully renamed musev to MuseV")
        except Exception as e:
            print(f"Warning: Could not rename directory: {str(e)}")
            # The first move may have succeeded before the second failed, so
            # use whichever location actually holds the checkout now.
            for _where in (MuseVDir, _temp_dir, _target_dir):
                if os.path.isdir(_where):
                    MuseVDir = _where
                    break

GradioScriptsDir = os.path.join(MuseVDir, "scripts", "gradio")

print(f"MuseV directory: {MuseVDir}")
print(f"Gradio scripts directory: {GradioScriptsDir}")

# Directories the MuseV scripts need importable. Each existing entry that is
# not already on sys.path is pushed onto the FRONT of sys.path in list order,
# so entries later in this list end up with higher import precedence than
# earlier ones. Entries already present keep their current position.
paths_to_add = [
    ProjectDir,
    MuseVDir,
    os.path.join(MuseVDir, "MMCM"),
    os.path.join(MuseVDir, "diffusers", "src"),
    os.path.join(MuseVDir, "controlnet_aux", "src"),
    GradioScriptsDir
]

for candidate in paths_to_add:
    if not os.path.exists(candidate):
        print(f"Warning: Path does not exist: {candidate}")
        continue
    if candidate in sys.path:
        continue
    sys.path.insert(0, candidate)
    print(f"Added {candidate} to PYTHONPATH")

def download_model(local_dir=None, repo_id="TMElyralab/MuseV"):
    """Download the MuseV checkpoints from the Hugging Face Hub when missing.

    Args:
        local_dir: Destination directory. Defaults to the module-level
            ``CheckpointsDir``.
        repo_id: Hub repository passed to ``snapshot_download``.

    The download is skipped only when ``local_dir`` is an existing, non-empty
    directory. A missing or empty directory triggers a download.
    """
    if local_dir is None:
        local_dir = CheckpointsDir

    # An empty directory holds no checkpoints, so it must not count as downloaded.
    if os.path.isdir(local_dir) and os.listdir(local_dir):
        print("Already download the model.")
        return

    print("Checkpoint Not Downloaded, start downloading...")
    tic = time.time()
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        max_workers=8,
    )
    toc = time.time()
    print(f"download cost {toc-tic} seconds")

# Make sure the checkpoints are present before the gradio scripts are loaded.
print("Starting model download...")
download_model()
print("Model download complete.")

# Log the effective import search path for debugging.
print("Setting up paths...")
for search_dir in sys.path:
    print(f"Path: {search_dir}")

print("Attempting to import gradio modules...")

# Preferred location: copies of both gradio scripts next to this file.
video2video_path = os.path.join(ProjectDir, "gradio_video2video.py")
text2video_path = os.path.join(ProjectDir, "gradio_text2video.py")

print("Looking for modules at:")
print(f"video2video: {video2video_path}")
print(f"text2video: {text2video_path}")

local_copies_present = os.path.exists(video2video_path) and os.path.exists(text2video_path)
if not local_copies_present:
    # Fall back to the MuseV checkout and copy BOTH scripts here (overwriting
    # whichever local copy may already exist).
    print("Modules not found in current directory, checking MuseV directory...")
    musev_gradio_dir = os.path.join(MuseVDir, "scripts", "gradio")
    copy_plan = {
        os.path.join(musev_gradio_dir, "gradio_video2video.py"): video2video_path,
        os.path.join(musev_gradio_dir, "gradio_text2video.py"): text2video_path,
    }

    if all(os.path.exists(src) for src in copy_plan):
        print("Found modules in MuseV directory, copying to current directory...")
        for src, dst in copy_plan.items():
            shutil.copy2(src, dst)
        print("Successfully copied modules")
    else:
        # Dump what is actually on disk to help diagnose the layout, then abort.
        print("Error: Could not find modules in MuseV directory")
        print("MuseV directory contents:")
        if os.path.exists(MuseVDir):
            print(os.listdir(MuseVDir))
        print("\nScripts directory contents:")
        scripts_dir = os.path.join(MuseVDir, "scripts")
        gradio_dir = os.path.join(scripts_dir, "gradio")
        if os.path.exists(scripts_dir):
            print(os.listdir(scripts_dir))
            if os.path.exists(gradio_dir):
                print("\nGradio scripts directory contents:")
                print(os.listdir(gradio_dir))
        sys.exit(1)

try:
    print("Attempting to import modules...")
    video2video = import_from_path("gradio_video2video", video2video_path)
    text2video = import_from_path("gradio_text2video", text2video_path)
    online_v2v_inference = video2video.online_v2v_inference
    online_t2v_inference = text2video.online_t2v_inference
    print("Successfully imported modules")
except Exception as e:
    # import_from_path executes each script's top-level code, so the real
    # failure may be raised several frames deep inside the script's own
    # imports; keep the full traceback instead of only the message.
    import traceback

    print(f"Error importing modules: {str(e)}")
    traceback.print_exc()
    print("\nDirectory contents:")
    print(os.listdir(ProjectDir))
    if os.path.exists(GradioScriptsDir):
        print("\nGradio scripts directory contents:")
        print(os.listdir(GradioScriptsDir))
    sys.exit(1)

# NOTE(review): neither setting is read anywhere in this file; they are kept
# only so the module attributes remain available.
ignore_video2video = False
max_image_edge = 1280

print("Setting up Gradio interface...")
# gr.Interface passes the input values to `fn` positionally, so this order must
# match online_t2v_inference's parameters (defined in gradio_text2video.py,
# not visible here - confirm there). gr.Number yields float unless
# precision=0, which rounds and converts to int; it is applied to the fields
# that are whole numbers (seed, pixel sizes, frame count). FPS and the edge
# ratio are left as floats.
demo = gr.Interface(
    fn=online_t2v_inference,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(label="Reference Image"),
        gr.Number(label="Seed", value=-1, precision=0),
        gr.Number(label="FPS", value=6),
        gr.Number(label="Width", value=-1, precision=0),
        gr.Number(label="Height", value=-1, precision=0),
        gr.Number(label="Video Length", value=12, precision=0),
        gr.Number(label="Image Edge Ratio", value=1.0),
    ],
    outputs=gr.Video(),
    title="MuseV Demo",
    description="Generate videos from text and reference images"
)

print("Launching Gradio interface...")
demo.queue().launch(server_name="0.0.0.0", server_port=7860)