update custom handler
Browse files- __pycache__/handler.cpython-38.pyc +0 -0
- handler.py +38 -12
- requirements.txt +2 -2
__pycache__/handler.cpython-38.pyc
CHANGED
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
|
|
handler.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
from typing import Dict, List, Any
|
2 |
-
import base64
|
3 |
|
|
|
|
|
4 |
import math
|
5 |
import numpy as np
|
6 |
import tensorflow as tf
|
7 |
from tensorflow import keras
|
8 |
-
from keras_cv.models.
|
9 |
-
from keras_cv.models.
|
|
|
10 |
|
11 |
class GroupNormalization(tf.keras.layers.Layer):
|
12 |
"""GroupNormalization layer.
|
@@ -184,7 +186,7 @@ class ImageEncoder(keras.Sequential):
|
|
184 |
self.load_weights(image_encoder_weights_fpath)
|
185 |
|
186 |
class EndpointHandler():
|
187 |
-
def __init__(self, path=""):
|
188 |
self.seed = None
|
189 |
|
190 |
img_height = 512
|
@@ -193,15 +195,33 @@ class EndpointHandler():
|
|
193 |
self.img_width = round(img_width / 128) * 128
|
194 |
|
195 |
self.MAX_PROMPT_LENGTH = 77
|
196 |
-
self.
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
)
|
201 |
-
self.diffusion_model.load_weights(diffusion_model_weights_fpath)
|
202 |
|
203 |
self.image_encoder = ImageEncoder()
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
def _get_initial_diffusion_noise(self, batch_size, seed):
|
206 |
if seed is not None:
|
207 |
return tf.random.stateless_normal(
|
@@ -266,11 +286,17 @@ class EndpointHandler():
|
|
266 |
|
267 |
context = base64.b64decode(inputs[0])
|
268 |
context = np.frombuffer(context, dtype="float32")
|
269 |
-
|
|
|
|
|
|
|
270 |
|
271 |
unconditional_context = base64.b64decode(inputs[1])
|
272 |
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
|
273 |
-
|
|
|
|
|
|
|
274 |
|
275 |
num_steps = data.pop("num_steps", 25)
|
276 |
unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
|
|
|
1 |
from typing import Dict, List, Any
|
|
|
2 |
|
3 |
+
import sys
|
4 |
+
import base64
|
5 |
import math
|
6 |
import numpy as np
|
7 |
import tensorflow as tf
|
8 |
from tensorflow import keras
|
9 |
+
from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
|
10 |
+
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
|
11 |
+
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModelV2
|
12 |
|
13 |
class GroupNormalization(tf.keras.layers.Layer):
|
14 |
"""GroupNormalization layer.
|
|
|
186 |
self.load_weights(image_encoder_weights_fpath)
|
187 |
|
188 |
class EndpointHandler():
|
189 |
+
def __init__(self, path="", version="2"):
|
190 |
self.seed = None
|
191 |
|
192 |
img_height = 512
|
|
|
195 |
self.img_width = round(img_width / 128) * 128
|
196 |
|
197 |
self.MAX_PROMPT_LENGTH = 77
|
198 |
+
self.version = version
|
199 |
+
self.diffusion_model = self._instantiate_diffusion_model(version)
|
200 |
+
if isinstance(self.diffusion_model, str):
|
201 |
+
sys.exit(self.diffusion_model)
|
|
|
|
|
202 |
|
203 |
self.image_encoder = ImageEncoder()
|
204 |
|
205 |
+
def _instantiate_diffusion_model(self, version: str):
|
206 |
+
if version == "1.4":
|
207 |
+
diffusion_model_weights_fpath = keras.utils.get_file(
|
208 |
+
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
|
209 |
+
file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
|
210 |
+
)
|
211 |
+
diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
|
212 |
+
diffusion_model.load_weights(diffusion_model_weights_fpath)
|
213 |
+
return diffusion_model
|
214 |
+
elif version == "2":
|
215 |
+
diffusion_model_weights_fpath = keras.utils.get_file(
|
216 |
+
origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5",
|
217 |
+
file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d",
|
218 |
+
)
|
219 |
+
diffusion_model = DiffusionModelV2(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
|
220 |
+
diffusion_model.load_weights(diffusion_model_weights_fpath)
|
221 |
+
return diffusion_model
|
222 |
+
else:
|
223 |
+
return f"v{version} is not supported"
|
224 |
+
|
225 |
def _get_initial_diffusion_noise(self, batch_size, seed):
|
226 |
if seed is not None:
|
227 |
return tf.random.stateless_normal(
|
|
|
286 |
|
287 |
context = base64.b64decode(inputs[0])
|
288 |
context = np.frombuffer(context, dtype="float32")
|
289 |
+
if self.version == "1.4":
|
290 |
+
context = np.reshape(context, (batch_size, 77, 768))
|
291 |
+
else:
|
292 |
+
context = np.reshape(context, (batch_size, 77, 1024))
|
293 |
|
294 |
unconditional_context = base64.b64decode(inputs[1])
|
295 |
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
|
296 |
+
if self.version == "1.4":
|
297 |
+
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
|
298 |
+
else:
|
299 |
+
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 1024))
|
300 |
|
301 |
num_steps = data.pop("num_steps", 25)
|
302 |
unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
keras-cv
|
2 |
-
tensorflow
|
3 |
tensorflow_datasets
|
|
|
1 |
+
keras-cv==0.4
|
2 |
+
tensorflow==2.11
|
3 |
tensorflow_datasets
|