shigeru saito commited on
Commit
59876db
·
1 Parent(s): f19a595

REPLICATE複数トークン対応、トークン数を取得できるように修正、シーン数を環境変数化、

Browse files
Files changed (3) hide show
  1. .env.example +1 -0
  2. app.py +76 -43
  3. schema.json +2 -2
.env.example CHANGED
@@ -1,3 +1,4 @@
1
  REPLICATE_API_TOKEN_LIST=key1,key2
2
  OPENAI_API_KEY=
3
  ENV=PRODUCTION
 
 
1
  REPLICATE_API_TOKEN_LIST=key1,key2
2
  OPENAI_API_KEY=
3
  ENV=PRODUCTION
4
+ NUMBER_OF_SCENES=4
app.py CHANGED
@@ -8,6 +8,8 @@ import time
8
  import requests
9
  import argparse
10
  import markdown2
 
 
11
 
12
  from dotenv import load_dotenv
13
  from IPython.display import Image
@@ -27,6 +29,7 @@ openai.api_key = os.getenv('OPENAI_API_KEY')
27
  # REPLICATE_API_TOKEN_LISTをロードし、カンマで分割してリストに変換
28
  REPLICATE_API_TOKEN_LIST = os.getenv("REPLICATE_API_TOKEN_LIST").split(',')
29
  REPLICATE_API_TOKEN_INDEX = 0 # トークンのインデックスを初期化
 
30
 
31
  if ENV == "PRODUCTION":
32
  import replicate
@@ -34,7 +37,8 @@ else:
34
  from stub import replicate
35
 
36
  class Video:
37
- def __init__(self, scene, index):
 
38
  self.scene = scene
39
  self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
40
  self.prompt = self.prompt + scene.get("visual_prompt_in_en")
@@ -42,33 +46,47 @@ class Video:
42
  self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, "
43
  self.index = index
44
  self.output_url = None
45
- self.file_path = f"assets/thread_{index}_video.mp4"
46
-
47
- def run_replicate(self):
48
- global REPLICATE_API_TOKEN_INDEX
49
- start_time = time.time()
50
-
51
- # 現在のトークンを取得し、次のトークンにインデックスを更新
52
- token = REPLICATE_API_TOKEN_LIST[REPLICATE_API_TOKEN_INDEX]
53
- REPLICATE_API_TOKEN_INDEX = (REPLICATE_API_TOKEN_INDEX + 1) % len(REPLICATE_API_TOKEN_LIST)
54
- os.environ['REPLICATE_API_TOKEN'] = token
55
-
56
- start_time = time.time()
57
-
58
- self.output_url = replicate.run(
59
- "lucataco/animate-diff:1531004ee4c98894ab11f8a4ce6206099e732c1da15121987a8eef54828f0663",
60
- input={
61
- "motion_module": "mm_sd_v14",
62
- "prompt": self.prompt,
63
- "n_prompt": self.nagative_prompt,
64
- }
65
- )
66
-
67
- end_time = time.time()
68
- duration = end_time - start_time
69
 
70
- self.download_and_save(url=self.output_url, file_path=self.file_path)
71
- self.print_thread_info(start_time, end_time, duration)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def download_and_save(self, url, file_path):
74
  response = requests.get(url)
@@ -84,18 +102,17 @@ class Video:
84
  class ThreadController:
85
  def __init__(self, args):
86
  self.args = args
87
- self.num_threads = len(args)
88
  scenes = args.get("scenes")
89
- # prompts = []
90
- # if scenes:
91
- # for scene_data in scenes:
92
- # prompt = scene_data.get("visual_prompt_in_en")
93
- # prompt = prompt + ", " + scene_data.get("cinematic_angles")
94
- # prompt = prompt + ", " + scene_data.get("visual_prompt_in_en")
95
- # prompts.append(prompt)
96
-
97
- self.videos = [Video(scene, index) for index, scene in enumerate(scenes)]
98
  self.threads = []
 
 
 
 
 
 
 
 
99
 
100
  def run_threads(self):
101
  os.makedirs("assets", exist_ok=True)
@@ -104,6 +121,8 @@ class ThreadController:
104
  thread = threading.Thread(target=video.run_replicate)
105
  self.threads.append(thread)
106
  thread.start()
 
 
107
 
108
  for thread in self.threads:
109
  thread.join()
@@ -111,12 +130,17 @@ class ThreadController:
111
  def merge_videos(self):
112
  clips = []
113
  for video in self.videos:
114
- clips.append(VideoFileClip(video.file_path))
 
 
 
 
 
115
 
116
  final_clip = concatenate_videoclips(clips)
117
 
118
  os.makedirs("videos", exist_ok=True)
119
- output_path = "videos/final_concatenated_video.mp4"
120
 
121
  final_clip.write_videofile(output_path, codec='libx264', fps=24)
122
 
@@ -126,6 +150,12 @@ class ThreadController:
126
  for video in self.videos:
127
  print(f"Thread {video.index} prompt: {video.prompt}")
128
 
 
 
 
 
 
 
129
  def main(args):
130
  thread_controller = ThreadController(args)
131
  thread_controller.run_threads()
@@ -173,8 +203,8 @@ class OpenAI:
173
  # ChatCompletion APIから返された結果を取得する
174
  message = response.choices[0].message
175
  print("chat completion message: " + json.dumps(message, indent=2))
176
-
177
- return message
178
 
179
  class NajiminoAI:
180
 
@@ -230,7 +260,7 @@ class NajiminoAI:
230
  def create_video(self):
231
  main_start_time = time.time()
232
 
233
- user_message = self.user_message + " 4シーン"
234
 
235
  messages = [
236
  {"role": "user", "content": user_message}
@@ -238,7 +268,10 @@ class NajiminoAI:
238
 
239
  functions = get_functions_from_schema('schema.json')
240
 
241
- message = OpenAI.chat_completion_with_function(prompt=user_message, messages=messages, functions=functions)
 
 
 
242
 
243
  video_path = None
244
  html = None
 
8
  import requests
9
  import argparse
10
  import markdown2
11
+ import uuid
12
+ from pathlib import Path
13
 
14
  from dotenv import load_dotenv
15
  from IPython.display import Image
 
29
  # REPLICATE_API_TOKEN_LISTをロードし、カンマで分割してリストに変換
30
  REPLICATE_API_TOKEN_LIST = os.getenv("REPLICATE_API_TOKEN_LIST").split(',')
31
  REPLICATE_API_TOKEN_INDEX = 0 # トークンのインデックスを初期化
32
+ NUMBER_OF_SCENES = os.getenv("NUMBER_OF_SCENES")
33
 
34
  if ENV == "PRODUCTION":
35
  import replicate
 
37
  from stub import replicate
38
 
39
  class Video:
40
+ def __init__(self, scene, index, token_controller):
41
+ self.token_controller = token_controller
42
  self.scene = scene
43
  self.prompt = "masterpiece, awards, best quality, dramatic-lighting, "
44
  self.prompt = self.prompt + scene.get("visual_prompt_in_en")
 
46
  self.nagative_prompt = "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, nsfw, "
47
  self.index = index
48
  self.output_url = None
49
+ self.video_id = uuid.uuid4()
50
+ self.file_path = f"assets/thread_{index}_request_{self.video_id}_video.mp4"
51
+
52
+ MAX_RETRIES = 2
53
+ def run_replicate(self, retries=0):
54
+ try:
55
+ self.token = self.token_controller.get_next_token()
56
+ start_time = time.time()
57
+
58
+ os.environ["REPLICATE_API_TOKEN"] = self.token
59
+ #tokenの最初の10文字だけ出力
60
+ print(f"Thread {self.index} token: {self.token[:10]}")
61
+
62
+ self.output_url = replicate.run(
63
+ "lucataco/animate-diff:1531004ee4c98894ab11f8a4ce6206099e732c1da15121987a8eef54828f0663",
64
+ input={
65
+ "motion_module": "mm_sd_v14",
66
+ "prompt": self.prompt,
67
+ "n_prompt": self.nagative_prompt,
68
+ "seed": 0,
69
+ }
70
+ )
 
 
71
 
72
+ end_time = time.time()
73
+ duration = end_time - start_time
74
+
75
+ self.download_and_save(url=self.output_url, file_path=self.file_path)
76
+ self.print_thread_info(start_time, end_time, duration)
77
+ except replicate.exceptions.ReplicateError as e:
78
+ if str(e) == "The requested resource could not be found." and retries < self.MAX_RETRIES:
79
+ print("リソースが見つからないエラーが発生しました。2秒後に再試行します。")
80
+ time.sleep(2)
81
+ self.run_replicate(retries + 1) # 再帰的に関数を呼び出して再試行
82
+ elif retries >= self.MAX_RETRIES:
83
+ print("最大再試行回数に達しました。スレッドを終了します。")
84
+ # 最大再試行回数に達した場合の追加処理
85
+ else:
86
+ print("予期しないエラーが発生しました。スレッドを終了します。")
87
+ # 予期しないエラーが発生した場合の追加処理
88
+ except Exception as e:
89
+ print(f"Error in thread {self.index}: {e}")
90
 
91
  def download_and_save(self, url, file_path):
92
  response = requests.get(url)
 
102
  class ThreadController:
103
  def __init__(self, args):
104
  self.args = args
 
105
  scenes = args.get("scenes")
106
+ self.videos = []
 
 
 
 
 
 
 
 
107
  self.threads = []
108
+ self.token_index = 0
109
+ self.lock = threading.Lock()
110
+ for index, scene in enumerate(scenes):
111
+ for _ in REPLICATE_API_TOKEN_LIST:
112
+ # token = REPLICATE_API_TOKEN_LIST[self.token_index]
113
+ video = Video(scene, index, self)
114
+ self.videos.append(video)
115
+ self.token_index = (self.token_index + 1) % len(REPLICATE_API_TOKEN_LIST)
116
 
117
  def run_threads(self):
118
  os.makedirs("assets", exist_ok=True)
 
121
  thread = threading.Thread(target=video.run_replicate)
122
  self.threads.append(thread)
123
  thread.start()
124
+ # 1秒待ってから実行
125
+ # time.sleep(1)
126
 
127
  for thread in self.threads:
128
  thread.join()
 
130
  def merge_videos(self):
131
  clips = []
132
  for video in self.videos:
133
+ video_path = Path(video.file_path)
134
+ if video_path.exists():
135
+ clips.append(VideoFileClip(video.file_path))
136
+ else:
137
+ print(f"Error: Video file {video.file_path} could not be found! Skipping this file.")
138
+ # 他のログ出力方法も使用可能、例: loggingモジュール
139
 
140
  final_clip = concatenate_videoclips(clips)
141
 
142
  os.makedirs("videos", exist_ok=True)
143
+ output_path = f"videos/final_concatenated_video_{uuid.uuid4()}.mp4"
144
 
145
  final_clip.write_videofile(output_path, codec='libx264', fps=24)
146
 
 
150
  for video in self.videos:
151
  print(f"Thread {video.index} prompt: {video.prompt}")
152
 
153
+ def get_next_token(self):
154
+ with self.lock:
155
+ token = REPLICATE_API_TOKEN_LIST[self.token_index]
156
+ self.token_index = (self.token_index + 1) % len(REPLICATE_API_TOKEN_LIST)
157
+ return token
158
+
159
  def main(args):
160
  thread_controller = ThreadController(args)
161
  thread_controller.run_threads()
 
203
  # ChatCompletion APIから返された結果を取得する
204
  message = response.choices[0].message
205
  print("chat completion message: " + json.dumps(message, indent=2))
206
+
207
+ return response
208
 
209
  class NajiminoAI:
210
 
 
260
  def create_video(self):
261
  main_start_time = time.time()
262
 
263
+ user_message = self.user_message + f" {NUMBER_OF_SCENES}シーン"
264
 
265
  messages = [
266
  {"role": "user", "content": user_message}
 
268
 
269
  functions = get_functions_from_schema('schema.json')
270
 
271
+ response = OpenAI.chat_completion_with_function(prompt=user_message, messages=messages, functions=functions)
272
+
273
+ message = response.choices[0].message
274
+ total_tokens = response.usage.total_tokens
275
 
276
  video_path = None
277
  html = None
schema.json CHANGED
@@ -16,11 +16,11 @@
16
  "properties": {
17
  "title": {
18
  "type": "string",
19
- "description": "動画のタイトル"
20
  },
21
  "story": {
22
  "type": "string",
23
- "description": "動画のストーリーを詳しく時系列に説明する"
24
  },
25
  "visual_style": {
26
  "type": "string",
 
16
  "properties": {
17
  "title": {
18
  "type": "string",
19
+ "description": "映画のタイトル"
20
  },
21
  "story": {
22
  "type": "string",
23
+ "description": "映画のあらすじを、起承転結を交えて時系列に詳しく説明する"
24
  },
25
  "visual_style": {
26
  "type": "string",