import gradio as gr import openai import sys import os import json import threading import time import requests import argparse import markdown2 import uuid import traceback from pathlib import Path from dotenv import load_dotenv from IPython.display import Image from moviepy.editor import VideoFileClip, concatenate_videoclips, ImageClip from moviepy.video.fx.all import fadein, fadeout from PIL import Image as PIL_Image from pydub import AudioSegment from moviepy.editor import VideoFileClip, AudioFileClip from jinja2 import Template ENV = os.getenv("ENV") # MODEL = "gpt-3.5-turbo" MODEL = "gpt-4" load_dotenv() openai.api_key = os.getenv('OPENAI_API_KEY') REPLICATE_API_TOKEN_LIST = os.getenv("REPLICATE_API_TOKEN_LIST").split(',') NUMBER_OF_SCENES = os.getenv("NUMBER_OF_SCENES") import replicate from replicate.client import Client class Replicate: def __init__(self, id, client: Client, args, index=None): self.id = id self.client = client self.args = args self.index = index self.prompt = "" self.file_path_format = "" self.REPLICATE_MODEL_PATH = "" self.REPLICATE_MODEL_VERSION = "" self.input={} self.output_url = None self.response = None self.prediction_id = None def run_replicate(self, retries=0): try: # self.client.api_token = self.client.api_token_controller.get_next_token() start_time = time.time() # os.environ["REPLICATE_API_TOKEN"] = self.client.api_token #tokenの最初の10文字だけ出力 print(f"Thread {self.index} token: {self.client.api_token[:10]}") model = self.client.models.get(self.REPLICATE_MODEL_PATH) version = model.versions.get(self.REPLICATE_MODEL_VERSION) self.prediction = self.client.predictions.create( version=version, input=self.input ) self.prediction_id = self.prediction.id # print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction: {self.prediction}") print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}") self.prediction.reload() print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}") self.prediction.wait() print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}") if self.prediction.status == "succeeded": self.output_url = self.prediction.output print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}") else: self.output_url = None self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id) end_time = time.time() duration = end_time - start_time self.download_and_save(url=self.output_url, file_path=self.file_path) self.print_thread_info(start_time, end_time, duration) except replicate.exceptions.ReplicateError as e: print(f"Error fetching model or version: {e}") print(f"Model Path: {self.REPLICATE_MODEL_PATH}") print(f"Model Version: {self.REPLICATE_MODEL_VERSION}") if self.prediction_id and str(e) == "The requested resource could not be found.": predictions = self.client.predictions.list() self.prediction = next((p for p in predictions if p.id == self.prediction_id), None) if self.prediction: print(f"Found prediction with ID {self.prediction_id}: {self.prediction}") else: print(f"No prediction found with ID {self.prediction_id}") self.prediction.wait() print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.status: {self.prediction.status}") if self.prediction.status == "succeeded": self.output_url = self.prediction.output print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}") else: self.output_url = None print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: Error") print(f"Thread {self.index} token: {self.client.api_token[:10]} prediction.output: {self.prediction.output}") self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id) end_time = time.time() duration = end_time - start_time self.download_and_save(url=self.output_url, file_path=self.file_path) self.print_thread_info(start_time, end_time, duration) else: print(f"Error in thread {self.index}: {e}") print(traceback.format_exc()) print("予期しないエラーが発生しました。スレッドを終了します。") # 予期しないエラーが発生した場合の追加処理 raise e except Exception as e: print(f"Error in thread {self.index}: {e}") print(traceback.format_exc()) def download_and_save(self, url, file_path): response = requests.get(url) with open(file_path, "wb") as f: f.write(response.content) def print_thread_info(self, start_time, end_time, duration): print(f"Thread {self.index} output_url: {self.output_url}") print(f"Thread {self.index} start time: {start_time}") print(f"Thread {self.index} end time: {end_time}") print(f"Thread {self.index} duration: {duration}") class Video(Replicate): def __init__(self, id, client: Client, args, scene, index=None): super().__init__(id, client, args, index) self.REPLICATE_MODEL_PATH = "lucataco/animate-diff" self.REPLICATE_MODEL_VERSION = "1531004ee4c98894ab11f8a4ce6206099e732c1da15121987a8eef54828f0663" self.scene = scene self.prompt = "masterpiece, awards, best quality, dramatic-lighting, " self.prompt = self.prompt + scene.get("visual_prompt_in_en") self.prompt = self.prompt + ", cinematic-angles-" + scene.get("cinematic_angles") self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, " self.file_path_format = "assets/{id}/{class_name}_thread_{index}_request_{prediction_id}.mp4" self.file_path = None self.input={ "motion_module": "mm_sd_v14", "prompt": self.prompt, "n_prompt": self.nagative_prompt, "seed": 0, # random } def run_replicate(self, retries=0): self.response = super().run_replicate() self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id) return self.response class Music(Replicate): def __init__(self, id, client: Client, args): super().__init__(id, client, args) self.REPLICATE_MODEL_PATH = "facebookresearch/musicgen" self.REPLICATE_MODEL_VERSION = "f8578df960c345df7bc1f85dd152c5ae0b57ce45a6fc09511c467a62ad820ba3", self.prompt = "innovative, exceptional, captivating, " \ + args.get("bgm_prompt_in_en") self.file_path_format = "assets/{id}/{class_name}_{index}_request_{prediction_id}.mp3" self.file_path = None self.duration = args.get("") self.input = { "model_version": "large", "prompt": self.prompt, "duration": self.duration, "output_format": "mp3", "seed": -1, # random } def run_replicate(self, retries=0): start_time = time.time() print(f"Thread {self.index} token: {self.client.api_token[:10]}") os.environ['REPLICATE_API_TOKEN'] = self.client.api_token output = replicate.run( "facebookresearch/musicgen:7a76a8258b23fae65c5a22debb8841d1d7e816b75c2f24218cd2bd8573787906", input={ "model_version": "large", # "prompt": "The sound of samurai's footsteps marching across the field, the echo of the mountain, the fierce battle sound, and finally the triumphant fanfare as they claim victory." "prompt": self.prompt, "duration": self.duration, "output_format": "mp3", "seed": -1, # random } ) print(output) self.output_url = output self.response = output self.file_path = self.file_path_format.format(id=self.id, class_name=self.__class__.__name__, index=self.index, prediction_id=self.prediction_id) end_time = time.time() duration = end_time - start_time self.download_and_save(url=self.output_url, file_path=self.file_path) self.print_thread_info(start_time, end_time, duration) return self.response class ThreadController: def __init__(self, args): self.id = uuid.uuid4() self.args = args scenes = args.get("scenes") self.music = None self.videos = [] self.threads = [] self.lock = threading.Lock() self.replicate_client_list = {} self.duration = int(2.1 * len(scenes) * len(REPLICATE_API_TOKEN_LIST)) # 2.1秒 * シーン数 * APIトークン数 os.makedirs(f"assets/{self.id}", exist_ok=True) for token_index, token in enumerate(REPLICATE_API_TOKEN_LIST): client = Client() client.api_token = token self.replicate_client_list[token] = client if token_index == 0: self.music = Music(self.id, client, args) self.music.duration = self.duration for index, scene in enumerate(scenes): token = REPLICATE_API_TOKEN_LIST[token_index] video = Video(self.id, client, args, scene, index) self.videos.append(video) # client.api_token_index = (token_index + 1) % len(REPLICATE_API_TOKEN_LIST) def run_threads(self): thread = threading.Thread(target=self.music.run_replicate) self.threads.append(thread) thread.start() token = self.music.client.api_token for video in self.videos: if token is not None and video.client.api_token != token: # tokenが異なる場合、4秒待ってから次を実行 print(f"Thread {video.index} token changed. Waiting 4 seconds.") time.sleep(4) thread = threading.Thread(target=video.run_replicate) self.threads.append(thread) thread.start() token = video.client.api_token # time.sleep(5) for thread in self.threads: thread.join() def merge_videos(self): clips = [] for video in self.videos: video_path = Path(video.file_path) if video_path.exists(): clips.append(VideoFileClip(video.file_path)) else: print(f"Error: Video file {video.file_path} could not be found! Skipping this file.") # 他のログ出力方法も使用可能、例: loggingモジュール output_path = f"assets/{self.id}/concatenated_video_{self.id}.mp4" final_clip = concatenate_videoclips(clips) final_clip.write_videofile(output_path, codec='libx264', fps=24) # Load the video file using MoviePy video_clip = VideoFileClip(output_path) video_duration = video_clip.duration # Re-loading the audio file using pydub audio_segment = AudioSegment.from_mp3(self.music.file_path) # Calculating the number of loops needed to match the video duration num_loops = int(video_duration * 1000) // len(audio_segment) + 1 # Creating an audio segment that has the same duration as the video by looping the original audio final_audio_segment = audio_segment * num_loops # Trimming the final audio segment to match the video duration exactly final_audio_segment = final_audio_segment[:int(video_duration * 1000)] temp_audio_path = "/tmp/temp_audio.mp3" # Saving the final audio as a temporary WAV file final_audio_segment.export(temp_audio_path, format="mp3") # Loading the temporary audio file as a MoviePy AudioFileClip final_audio_clip = AudioFileClip(temp_audio_path) # Setting the audio to the video final_video_clip = video_clip.set_audio(final_audio_clip) # Path to save the final video with audio (different name to avoid confusion) output_path_with_audio_fixed = "/tmp/final_video_with_audio_fixed.mp4" # Saving the final video with audio final_video_clip.write_videofile(output_path_with_audio_fixed, codec="libx264", audio_codec="aac") # Path to the final video with audio (fixed version) output_path_with_audio_fixed os.makedirs(f"videos/{self.id}/", exist_ok=True) output_path = f"videos/{self.id}/final_concatenated_video_{self.id}.mp4" # final_clip.write_videofile(output_path, codec='libx264', fps=24) import shutil shutil.move(output_path_with_audio_fixed, output_path) return output_path def print_prompts(self): for video in self.videos: print(f"Thread {video.index} prompt: {video.prompt}") def main(args): thread_controller = ThreadController(args) thread_controller.run_threads() merged_video_path = thread_controller.merge_videos() thread_controller.print_prompts() return merged_video_path def load_prompts(file_path): with open(file_path, "r") as f: prompts = f.read().splitlines() return prompts def get_filetext(filename): with open(filename, "r") as file: filetext = file.read() return filetext def get_functions_from_schema(filename): schema = get_filetext(filename) schema_json = json.loads(schema) functions = schema_json.get("functions") return functions functions = get_functions_from_schema('schema.json') class OpenAI: @classmethod def chat_completion_with_function(cls, prompt, messages, functions): print("prompt:"+prompt) # 文章生成にかかる時間を計測する start = time.time() # ChatCompletion APIを呼び出す response = openai.ChatCompletion.create( model=MODEL, messages=messages, functions=functions, function_call={"name": "generate_video"} ) print("gpt generation time: "+str(time.time() - start)) # ChatCompletion APIから返された結果を取得する message = response.choices[0].message print("chat completion message: " + json.dumps(message, indent=2)) return response class NajiminoAI: def __init__(self, user_message): self.user_message = user_message def generate_markdown(self, args, generation_time): template_string = get_filetext(filename = "template.md") template = Template(template_string) result = template.render(args=args, generation_time=generation_time) print(result) return result @classmethod def generate(cls, user_message): najiminoai = NajiminoAI(user_message) return najiminoai.create_video() def create_video(self): main_start_time = time.time() user_message = self.user_message + f" {NUMBER_OF_SCENES}シーン" messages = [ {"role": "user", "content": user_message} ] functions = get_functions_from_schema('schema.json') response = OpenAI.chat_completion_with_function(prompt=user_message, messages=messages, functions=functions) message = response.choices[0].message total_tokens = response.usage.total_tokens video_path = None html = None if message.get("function_call") is None: print("message: " + json.dumps(message, indent=2)) return [video_path, html] function_name = message["function_call"]["name"] try: args = json.loads(message["function_call"]["arguments"]) except json.JSONDecodeError as e: print(f"JSON decode error at position {e.pos}: {e.msg}") print("message: " + json.dumps(message, indent=2)) raise e print("args: " + json.dumps(args, indent=2)) video_path = main(args) main_end_time = time.time() main_duration = main_end_time - main_start_time print("Thread Main start time:", main_start_time) print("Thread Main end time:", main_end_time) print("Thread Main duration:", main_duration) print("All threads finished.") function_response = self.generate_markdown(args, main_duration) html = ( "
" + markdown2.markdown(function_response,extras=["tables"]) + "