Spaces:
Runtime error
Runtime error
import gradio as gr | |
import openai | |
import sys | |
import os | |
import json | |
import threading | |
import time | |
import requests | |
import argparse | |
import markdown2 | |
import uuid | |
import traceback | |
from pathlib import Path | |
from dotenv import load_dotenv | |
from IPython.display import Image | |
from moviepy.editor import VideoFileClip, concatenate_videoclips, ImageClip | |
from moviepy.video.fx.all import fadein, fadeout | |
from PIL import Image as PIL_Image | |
from pydub import AudioSegment | |
from moviepy.editor import VideoFileClip, AudioFileClip | |
from jinja2 import Template | |
# Load a local .env file first so that every os.getenv() below can see its values.
load_dotenv()
ENV = os.getenv("ENV")
# MODEL = "gpt-3.5-turbo"
MODEL = "gpt-4"
openai.api_key = os.getenv('OPENAI_API_KEY')
# Comma-separated Replicate API tokens; surrounding whitespace and empty entries are ignored.
REPLICATE_API_TOKEN_LIST = [
    token.strip()
    for token in os.getenv("REPLICATE_API_TOKEN_LIST", "").split(",")
    if token.strip()
]
if not REPLICATE_API_TOKEN_LIST:
    raise RuntimeError(
        "REPLICATE_API_TOKEN_LIST is not set; provide at least one comma-separated Replicate API token."
    )
# Scene count appended to the user prompt; None when the variable is unset.
NUMBER_OF_SCENES = os.getenv("NUMBER_OF_SCENES")
import replicate | |
from replicate.client import Client | |
class Replicate:
    """Base wrapper around a single Replicate prediction.

    Subclasses fill in ``REPLICATE_MODEL_PATH``, ``REPLICATE_MODEL_VERSION``, ``input`` and
    ``file_path_format``; ``run_replicate`` then creates the prediction through ``client``,
    waits for it and stores the model output in ``self.response``.
    """

    def __init__(self, id, client: Client, args, index=None):
        self.id = id  # run id, interpolated into output file paths
        self.client = client
        self.args = args
        self.index = index  # thread/scene index, used in log lines and file paths
        self.prompt = ""
        self.file_path_format = ""
        self.file_path = None
        self.REPLICATE_MODEL_PATH = ""
        self.REPLICATE_MODEL_VERSION = ""
        self.input = {}
        self.response = None
        self.prediction = None
        self.prediction_id = None
        self.lock = threading.Lock()

    def run_replicate(self, retries=0):
        """Create the prediction, wait for it and return its output.

        Args:
            retries: accepted for interface compatibility; currently unused.

        Returns:
            The prediction output when it succeeded, otherwise None. Exceptions are
            logged with their traceback and swallowed (None is returned), so callers
            must handle a None result.
        """
        start_time = time.time()
        try:
            # Only the first 10 characters of the token are logged.
            token_prefix = self.client.api_token[:10]
            print(f"Thread {self.index} token: {token_prefix}")
            model = self.client.models.get(self.REPLICATE_MODEL_PATH)
            version = model.versions.get(self.REPLICATE_MODEL_VERSION)
            self.prediction = self.client.predictions.create(
                version=version,
                input=self.input,
            )
            self.prediction_id = self.prediction.id
            print(f"Thread {self.index} token: {token_prefix} prediction.status: {self.prediction.status}")
            self.prediction.reload()
            print(f"Thread {self.index} token: {token_prefix} prediction.status: {self.prediction.status}")
            self.prediction.wait()
            print(f"Thread {self.index} token: {token_prefix} prediction.status: {self.prediction.status}")
            if self.prediction.status == "succeeded":
                self.response = self.prediction.output
                print(f"Thread {self.index} token: {token_prefix} prediction.output: {self.prediction.output}")
            else:
                self.response = None
                print(f"Thread {self.index} prediction failed: {getattr(self.prediction, 'error', None)}")
            self.file_path = self.file_path_format.format(
                id=self.id,
                class_name=self.__class__.__name__,
                index=self.index,
                prediction_id=self.prediction_id,
            )
            end_time = time.time()
            self.print_thread_info(start_time, end_time, end_time - start_time)
            return self.response
        except Exception as e:
            print(f"Error in thread {self.index}: {e}")
            print(traceback.format_exc())
            # Keep state consistent with the None we return, even if the failure happened
            # after the output had been stored.
            self.response = None
            return None

    def download_and_save(self, url, file_path, timeout=120):
        """Download ``url`` and write the response body to ``file_path``.

        Args:
            url: URL of the file to download.
            file_path: destination path on disk.
            timeout: seconds to wait for the connection and for each read.

        Raises:
            ValueError: if ``url`` or ``file_path`` is empty/None.
            requests.HTTPError: if the server answers with an error status, so an
                error page is never saved as a media file.
        """
        if not url or not file_path:
            raise ValueError(
                f"Thread {self.index}: nothing to download (url={url!r}, file_path={file_path!r})"
            )
        with self.lock:  # acquire the per-instance lock
            # Stream to disk so large videos are not buffered fully in memory.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(file_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=64 * 1024):
                        f.write(chunk)

    def print_thread_info(self, start_time, end_time, duration):
        """Log the response and timing of this job."""
        print(f"Thread {self.index} response: {self.response}")
        print(f"Thread {self.index} start time: {start_time}")
        print(f"Thread {self.index} end time: {end_time}")
        print(f"Thread {self.index} duration: {duration}")
class LucatacoAnimateDiff(Replicate):
    """Text-to-video clip for one scene via ``lucataco/animate-diff``.

    Not used by ``Video`` at the moment (it instantiates ``ZsxkibAnimateDiff``).
    """

    def __init__(self, id, client: Client, args, scene, index=None):
        super().__init__(id, client, args, index)
        self.REPLICATE_MODEL_PATH = "lucataco/animate-diff"
        self.REPLICATE_MODEL_VERSION = "beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f"
        self.scene = scene
        # Quality boosters + the scene's English visual prompt + its camera angle.
        self.prompt = (
            "masterpiece, awards, best quality, dramatic-lighting, "
            + scene.get("visual_prompt_in_en")
            + ", cinematic-angles-"
            + scene.get("cinematic_angles")
        )
        # (sic) attribute name kept as-is for compatibility.
        self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, deformed iris, deformed pupils, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
        self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
        self.file_path = None
        self.input = {
            "motion_module": "mm_sd_v14",
            "prompt": self.prompt,
            "n_prompt": self.nagative_prompt,
            "seed": 0,  # random
        }

    def run_replicate(self, retries=0):
        """Generate the clip and download it to ``self.file_path``.

        Returns:
            The model output (a video URL), or None when the prediction failed,
            in which case nothing is downloaded.
        """
        self.response = super().run_replicate(retries)
        if not self.response:
            print(f"Thread {self.index}: no video returned; skipping download.")
            return self.response
        # The base class has already set self.file_path for this prediction.
        self.download_and_save(url=self.response, file_path=self.file_path)
        return self.response
class ZsxkibAnimateDiff(Replicate):
    """Text-to-video clip for one scene via ``zsxkib/animate-diff``.

    After ``run_replicate`` succeeds, ``self.video`` holds the first output URL and the
    clip is saved to ``self.file_path``; on failure ``self.video`` stays None.
    """

    def __init__(self, id, client: Client, args, scene, index=None):
        super().__init__(id, client, args, index)
        self.REPLICATE_MODEL_PATH = "zsxkib/animate-diff"
        self.REPLICATE_MODEL_VERSION = "269a616c8b0c2bbc12fc15fd51bb202b11e94ff0f7786c026aa905305c4ed9fb"
        self.scene = scene
        # Quality boosters + the scene's English visual prompt + its camera angle.
        self.prompt = (
            "masterpiece, awards, best quality, dramatic-lighting, "
            + scene.get("visual_prompt_in_en")
            + ", cinematic-angles-"
            + scene.get("cinematic_angles")
        )
        # (sic) attribute name kept as-is for compatibility.
        self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, deformed iris, deformed pupils, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
        self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
        self.file_path = None
        self.video = None
        self.input = {
            "prompt": self.prompt,
            "negative_prompt": self.nagative_prompt,
            "base_model": "toonyou_beta3",  # Allowed values: realisticVisionV20_v20, lyriel_v16, majicmixRealistic_v5Preview, rcnzCartoon3d_v10, toonyou_beta3
        }

    def run_replicate(self, retries=0):
        """Generate the clip, remember its URL in ``self.video`` and download it.

        Returns:
            The raw model output, or None when the prediction failed (nothing is
            downloaded and ``self.video`` is None).
        """
        self.response = super().run_replicate(retries)
        if not self.response:
            self.video = None
            print(f"Thread {self.index}: no video returned; skipping download.")
            return self.response
        self.video = self.response[0]
        # The base class has already set self.file_path for this prediction.
        self.download_and_save(url=self.video, file_path=self.file_path)
        return self.response
class Interpolator(Replicate):
    """Frame interpolation of an existing clip via ``zsxkib/st-mfnet`` (4x frames, 24 fps output)."""

    def __init__(self, id, client: Client, args, video, index=None):
        super().__init__(id, client, args, index)
        self.REPLICATE_MODEL_PATH = "zsxkib/st-mfnet"
        self.REPLICATE_MODEL_VERSION = "faa7693430b0a4ac95d1b8e25165673c1d7a7263537a7c4bb9be82a3e2d130fb"
        self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4"
        self.file_path = None
        self.input = {
            "mp4": video,
            "framerate_multiplier": 4,
            "keep_original_duration": False,
            "custom_fps": 24,
        }

    def run_replicate(self, retries=0):
        """Interpolate the clip and download the last output item to ``self.file_path``.

        Returns:
            The raw model output, or None when the prediction failed. Nothing is
            downloaded when the output is None or empty.
        """
        self.response = super().run_replicate(retries)
        outputs = list(self.response) if self.response is not None else []
        if not outputs:
            print(f"Thread {self.index}: interpolation returned no output; skipping download.")
            return self.response
        # The base class has already set self.file_path for this prediction.
        self.download_and_save(url=outputs[-1], file_path=self.file_path)
        return self.response
class Video:
    """One scene: AnimateDiff text-to-video followed by ST-MFNet frame interpolation.

    After ``run_replicate``, ``file_path`` points at the interpolated clip; it stays
    None when the AnimateDiff stage produced no video.
    """

    def __init__(self, id, client: Client, args, scene, index=None):
        self.client = client
        self.index = index
        # Alternative backend: LucatacoAnimateDiff(id, client, args, scene, index)
        self.animatediff = ZsxkibAnimateDiff(id, client, args, scene, index)
        self.prompt = self.animatediff.prompt
        self.interpolator = None
        self.response = None
        self.file_path = None

    def run_replicate(self, retries=0):
        """Run both stages in sequence and return the interpolator's output (or None)."""
        self.animatediff.run_replicate(retries)
        source_video = getattr(self.animatediff, "video", None)
        if source_video is None:
            print(f"Thread {self.index}: AnimateDiff produced no video; skipping interpolation.")
            return None
        self.interpolator = Interpolator(
            self.animatediff.id,
            self.animatediff.client,
            self.animatediff.args,
            source_video,
            self.animatediff.index,
        )
        self.response = self.interpolator.run_replicate(retries)
        self.file_path = self.interpolator.file_path
        return self.response
class Music(Replicate):
    """Background music for the whole video, generated by ``facebookresearch/musicgen``."""

    def __init__(self, id, client: Client, args, duration):
        super().__init__(id, client, args)
        self.REPLICATE_MODEL_PATH = "facebookresearch/musicgen"
        # The version actually invoked by run_replicate. Previously this attribute was an
        # accidental 1-tuple (trailing comma) holding a different, unused version.
        self.REPLICATE_MODEL_VERSION = "7a76a8258b23fae65c5a22debb8841d1d7e816b75c2f24218cd2bd8573787906"
        self.prompt = "innovative, exceptional, captivating, " + (args.get("bgm_prompt_in_en") or "")
        self.file_path_format = "assets/{id}/{class_name}_{index}_request_{prediction_id}.mp3"
        self.file_path = None
        self.duration = duration  # requested length in seconds
        self.input = self._build_input()

    def _build_input(self):
        """Return the musicgen input for the current prompt and duration."""
        return {
            "model_version": "large",
            "prompt": self.prompt,
            "duration": self.duration,
            "output_format": "mp3",
            "seed": -1,  # random
        }

    def run_replicate(self, retries=0):
        """Generate the track, download it to ``self.file_path`` and return the output URL.

        Nothing is downloaded when musicgen returns no output.
        """
        start_time = time.time()
        print(f"Thread {self.index} token: {self.client.api_token[:10]}")
        # NOTE(review): replicate.run authenticates from the process-wide environment, so
        # this write is visible to every thread; confirm the per-token Client objects used
        # by the video jobs do not also fall back to REPLICATE_API_TOKEN.
        os.environ['REPLICATE_API_TOKEN'] = self.client.api_token
        # Rebuild so later changes to prompt/duration (e.g. set after construction) apply.
        self.input = self._build_input()
        output = replicate.run(
            f"{self.REPLICATE_MODEL_PATH}:{self.REPLICATE_MODEL_VERSION}",
            input=self.input,
        )
        print(output)
        self.response = output
        self.file_path = self.file_path_format.format(
            id=self.id,
            class_name=self.__class__.__name__,
            index=self.index,
            prediction_id=self.prediction_id,
        )
        end_time = time.time()
        if self.response:
            self.download_and_save(url=self.response, file_path=self.file_path)
        else:
            print(f"Thread {self.index}: musicgen returned no output; skipping download.")
        self.print_thread_info(start_time, end_time, end_time - start_time)
        return self.response
class ThreadController:
    """Fan out the Replicate jobs for one request and merge their results into a video.

    Every token in REPLICATE_API_TOKEN_LIST gets its own Client, and each client renders
    every scene, so ``len(scenes) * len(tokens)`` clips are produced. The music track is
    requested once, through the first token's client.
    """

    # Seconds of music requested per generated clip, and the cap on the total.
    SECONDS_PER_CLIP = 2.1
    MAX_MUSIC_DURATION = 30

    def __init__(self, args):
        self.id = uuid.uuid4()
        self.args = args
        scenes = args.get("scenes") or []
        self.music = None
        self.videos = []
        self.threads = []
        self.lock = threading.Lock()
        self.replicate_client_list = {}
        # 2.1 s * number of scenes * number of API tokens, capped at 30 s.
        self.duration = min(
            int(self.SECONDS_PER_CLIP * len(scenes) * len(REPLICATE_API_TOKEN_LIST)),
            self.MAX_MUSIC_DURATION,
        )
        os.makedirs(f"assets/{self.id}", exist_ok=True)
        for token_index, token in enumerate(REPLICATE_API_TOKEN_LIST):
            client = Client()
            client.api_token = token
            # Index of the token after this one (wraps around).
            client.api_token_index = (token_index + 1) % len(REPLICATE_API_TOKEN_LIST)
            self.replicate_client_list[token] = client
            if token_index == 0:
                self.music = Music(self.id, client, args, self.duration)
            for index, scene in enumerate(scenes):
                self.videos.append(Video(self.id, client, args, scene, index))

    def run_threads(self):
        """Start the music job and every video job in its own thread, then wait for all."""
        previous_token = None
        if self.music is not None:
            music_thread = threading.Thread(target=self.music.run_replicate)
            self.threads.append(music_thread)
            music_thread.start()
            previous_token = self.music.client.api_token
        for video in self.videos:
            if previous_token is not None and video.client.api_token != previous_token:
                # When the token changes, wait 4 seconds before starting the next job.
                print(f"Thread {video.index} token changed. Waiting 4 seconds.")
                time.sleep(4)
            video_thread = threading.Thread(target=video.run_replicate)
            self.threads.append(video_thread)
            video_thread.start()
            previous_token = video.client.api_token
        for thread in self.threads:
            thread.join()

    def merge_videos(self):
        """Concatenate the clips, add the looped music and move the result into ``videos/``.

        Clips are ordered by scene index; missing clips are skipped with a message. When
        the music file is missing or empty, the video is written without an audio track.
        The temporary ``assets/<id>/`` directory is removed on success.

        Returns:
            Path of the final video, ``videos/<id>/final_concatenated_video_<id>.mp4``.

        Raises:
            RuntimeError: if no clip file exists at all.
        """
        import shutil
        import tempfile

        clips = self._load_clips()
        if not clips:
            raise RuntimeError(f"No video clips were generated for run {self.id}; nothing to merge.")

        # Every clip opened here is closed in `finally`, before the assets dir is removed.
        opened_clips = list(clips)
        temp_audio_path = None
        try:
            concatenated_path = f"assets/{self.id}/concatenated_video_{self.id}.mp4"
            final_clip = concatenate_videoclips(clips)
            opened_clips.append(final_clip)
            final_clip.write_videofile(concatenated_path, codec='libx264', fps=24)

            video_clip = VideoFileClip(concatenated_path)
            opened_clips.append(video_clip)
            final_video_clip = video_clip
            temp_audio_path = self._export_looped_music(video_clip.duration)
            if temp_audio_path is not None:
                final_audio_clip = AudioFileClip(temp_audio_path)
                opened_clips.append(final_audio_clip)
                final_video_clip = video_clip.set_audio(final_audio_clip)

            # Render into a temporary file, then move it into videos/ below.
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                output_path_with_audio_fixed = f.name
            final_video_clip.write_videofile(output_path_with_audio_fixed, codec="libx264", audio_codec="aac")
        finally:
            for clip in opened_clips:
                clip.close()
            if temp_audio_path is not None and os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)

        os.makedirs(f"videos/{self.id}/", exist_ok=True)
        output_path = f"videos/{self.id}/final_concatenated_video_{self.id}.mp4"
        shutil.move(output_path_with_audio_fixed, output_path)
        shutil.rmtree(f"assets/{self.id}/")
        return output_path

    def _load_clips(self):
        """Open the clip of every video whose file exists, ordered by scene index."""
        clips = []
        for video in sorted(self.videos, key=lambda x: x.index):
            file_path = getattr(video, "file_path", None)
            if file_path and Path(file_path).exists():
                clips.append(VideoFileClip(file_path))
            else:
                print(f"Error: Video file {file_path} could not be found! Skipping this file.")
        return clips

    def _export_looped_music(self, duration_s):
        """Loop/trim the music to ``duration_s`` seconds and export it to a temporary mp3.

        Returns:
            Path of the temporary mp3, or None when there is no usable music file.
        """
        import tempfile

        music_path = getattr(self.music, "file_path", None)
        if not music_path or not Path(music_path).exists():
            print(f"Warning: music file {music_path} not found; writing the video without audio.")
            return None
        audio_segment = AudioSegment.from_mp3(music_path)
        if len(audio_segment) == 0:  # pydub lengths are in milliseconds
            print(f"Warning: music file {music_path} is empty; writing the video without audio.")
            return None
        target_ms = int(duration_s * 1000)
        # Repeat the track enough times to cover the video, then cut it to the exact length.
        num_loops = target_ms // len(audio_segment) + 1
        looped = (audio_segment * num_loops)[:target_ms]
        with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as f:
            temp_audio_path = f.name
        # export() returns the still-open output file handle; close it.
        looped.export(temp_audio_path, format="mp3").close()
        return temp_audio_path

    def print_prompts(self):
        """Log the visual prompt used for each video."""
        for video in self.videos:
            print(f"Thread {video.index} prompt: {video.prompt}")
def main(args):
    """Render all scenes and the music described by ``args``; return the merged video's path."""
    controller = ThreadController(args)
    controller.run_threads()
    output_path = controller.merge_videos()
    controller.print_prompts()
    return output_path
def load_prompts(file_path, encoding="utf-8"):
    """Return the lines of ``file_path`` (one prompt per line, line endings stripped).

    Args:
        file_path: path of the prompts file.
        encoding: text encoding; UTF-8 by default so Japanese prompts decode the same
            way regardless of the platform's locale.
    """
    with open(file_path, "r", encoding=encoding) as f:
        return f.read().splitlines()
def get_filetext(filename, encoding="utf-8"):
    """Return the full text of ``filename``.

    Args:
        filename: path of the file to read.
        encoding: text encoding; UTF-8 by default so the schema/template files decode
            the same way regardless of the platform's locale.
    """
    with open(filename, "r", encoding=encoding) as file:
        return file.read()
def get_functions_from_schema(filename):
    """Parse the JSON file ``filename`` and return its ``"functions"`` entry (None if absent)."""
    return json.loads(get_filetext(filename)).get("functions")


# Function-calling schema loaded at import time (path relative to the working directory).
# NajiminoAI.create_video reloads the schema itself and does not read this copy.
functions = get_functions_from_schema('schema.json')
class OpenAI:
    """Thin helper around the OpenAI ChatCompletion function-calling API."""

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions, function_name="generate_video"):
        """Call ChatCompletion, forcing a call to ``function_name``, and return the raw response.

        Must be a classmethod: callers invoke it on the class itself
        (``OpenAI.chat_completion_with_function(prompt=..., ...)``), which previously
        raised TypeError for the missing ``cls`` argument.

        Args:
            prompt: the user prompt (logged only).
            messages: chat messages passed to the API.
            functions: function schemas the model may call.
            function_name: name of the function the model is forced to call.
        """
        print("prompt:" + prompt)
        # Measure how long generation takes.
        start = time.time()
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=messages,
            functions=functions,
            function_call={"name": function_name},
        )
        print("gpt generation time: " + str(time.time() - start))
        message = response.choices[0].message
        print("chat completion message: " + json.dumps(message, indent=2, ensure_ascii=False))
        return response
class NajiminoAI:
    """Turn a free-text request into a generated video plus an HTML summary."""

    def __init__(self, user_message):
        self.user_message = user_message

    def generate_markdown(self, args, generation_time):
        """Render ``template.md`` with the LLM-produced ``args`` and the elapsed seconds."""
        template_string = get_filetext(filename="template.md")
        template = Template(template_string)
        result = template.render(args=args, generation_time=generation_time)
        print(result)
        return result

    @classmethod
    def generate(cls, user_message):
        """Entry point for Gradio and the CLI; returns ``[video_path, html]``.

        Must be a classmethod: it is called on the class (``NajiminoAI.generate(text)``),
        which previously bound ``text`` to ``cls`` and raised TypeError.
        """
        return cls(user_message).create_video()

    def create_video(self):
        """Ask the LLM for a scene plan, render it and return ``[video_path, html]``.

        Returns ``[None, None]`` when the model does not produce a function call.

        Raises:
            json.JSONDecodeError: if the function-call arguments are not valid JSON.
        """
        main_start_time = time.time()
        user_message = self.user_message
        if NUMBER_OF_SCENES:
            # Ask for a specific number of scenes ("N シーン" = "N scenes").
            user_message += f" {NUMBER_OF_SCENES}シーン"
        messages = [
            {"role": "user", "content": user_message}
        ]
        functions = get_functions_from_schema('schema.json')
        response = OpenAI.chat_completion_with_function(prompt=user_message, messages=messages, functions=functions)
        message = response.choices[0].message
        if message.get("function_call") is None:
            print("message: " + json.dumps(message, indent=2))
            return [None, None]
        try:
            args = json.loads(message["function_call"]["arguments"])
        except json.JSONDecodeError as e:
            print(f"JSON decode error at position {e.pos}: {e.msg}")
            print("message: " + json.dumps(message, indent=2))
            raise
        print("args: " + json.dumps(args, indent=2, ensure_ascii=False))
        video_path = main(args)
        main_end_time = time.time()
        main_duration = main_end_time - main_start_time
        print("Thread Main start time:", main_start_time)
        print("Thread Main end time:", main_end_time)
        print("Thread Main duration:", main_duration)
        print("All threads finished.")
        function_response = self.generate_markdown(args, main_duration)
        html = (
            "<div style='max-width:100%; overflow:auto'>"
            + markdown2.markdown(function_response, extras=["tables"])
            + "</div>"
        )
        return [video_path, html]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate videos from text prompts")
    parser.add_argument("--prompts_file", type=str, help="File containing prompts (one per line)")
    args = parser.parse_args()
    if args.prompts_file:
        # Batch mode: generate one video for each non-empty line of the prompts file.
        for prompt in load_prompts(args.prompts_file):
            if not prompt.strip():
                continue
            video_path, _html = NajiminoAI(prompt).create_video()
            print(f"{prompt} -> {video_path}")
    else:
        description = (
            "\n"
            "入力されたテキストプロンプトに基づいてビデオを生成します\n"
            "Generate a video based on the text prompt you enter.\n"
        )
        iface = gr.Interface(
            fn=NajiminoAI.generate,
            # gr.inputs.* is deprecated (removed in Gradio 4); use the top-level component.
            inputs=gr.Textbox(lines=2, placeholder="Enter your prompt"),
            outputs=[
                gr.Video(),
                "html",
            ],
            title="najimino Video Generator (β)",
            description=description,
            examples=[
                ["侍たちは野を超え山を超え、敵軍大将を討ち取り、天下の大将軍となった!"],
                ["子どもたちが笑ったり怒ったり泣いたり楽しんだりする"],
                ["日は昇り、大地を照らし、日は沈む。闇夜を照らし、陽はまた昇る。 "],
            ],
        )
        iface.launch()