chansung committed on
Commit
2143543
·
1 Parent(s): 758a74d

update custom handler

Browse files
__pycache__/handler.cpython-38.pyc CHANGED
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
 
handler.py CHANGED
@@ -1,12 +1,14 @@
1
  from typing import Dict, List, Any
2
- import base64
3
 
 
 
4
  import math
5
  import numpy as np
6
  import tensorflow as tf
7
  from tensorflow import keras
8
- from keras_cv.models.generative.stable_diffusion.constants import _ALPHAS_CUMPROD
9
- from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel
 
10
 
11
  class GroupNormalization(tf.keras.layers.Layer):
12
  """GroupNormalization layer.
@@ -184,7 +186,7 @@ class ImageEncoder(keras.Sequential):
184
  self.load_weights(image_encoder_weights_fpath)
185
 
186
  class EndpointHandler():
187
- def __init__(self, path=""):
188
  self.seed = None
189
 
190
  img_height = 512
@@ -193,15 +195,33 @@ class EndpointHandler():
193
  self.img_width = round(img_width / 128) * 128
194
 
195
  self.MAX_PROMPT_LENGTH = 77
196
- self.diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
197
- diffusion_model_weights_fpath = keras.utils.get_file(
198
- origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
199
- file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
200
- )
201
- self.diffusion_model.load_weights(diffusion_model_weights_fpath)
202
 
203
  self.image_encoder = ImageEncoder()
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def _get_initial_diffusion_noise(self, batch_size, seed):
206
  if seed is not None:
207
  return tf.random.stateless_normal(
@@ -266,11 +286,17 @@ class EndpointHandler():
266
 
267
  context = base64.b64decode(inputs[0])
268
  context = np.frombuffer(context, dtype="float32")
269
- context = np.reshape(context, (batch_size, 77, 768))
 
 
 
270
 
271
  unconditional_context = base64.b64decode(inputs[1])
272
  unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
273
- unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
 
 
 
274
 
275
  num_steps = data.pop("num_steps", 25)
276
  unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
 
1
  from typing import Dict, List, Any
 
2
 
3
+ import sys
4
+ import base64
5
  import math
6
  import numpy as np
7
  import tensorflow as tf
8
  from tensorflow import keras
9
+ from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
10
+ from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
11
+ from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModelV2
12
 
13
  class GroupNormalization(tf.keras.layers.Layer):
14
  """GroupNormalization layer.
 
186
  self.load_weights(image_encoder_weights_fpath)
187
 
188
  class EndpointHandler():
189
+ def __init__(self, path="", version="2"):
190
  self.seed = None
191
 
192
  img_height = 512
 
195
  self.img_width = round(img_width / 128) * 128
196
 
197
  self.MAX_PROMPT_LENGTH = 77
198
+ self.version = version
199
+ self.diffusion_model = self._instantiate_diffusion_model(version)
200
+ if isinstance(self.diffusion_model, str):
201
+ sys.exit(self.diffusion_model)
 
 
202
 
203
  self.image_encoder = ImageEncoder()
204
 
205
+ def _instantiate_diffusion_model(self, version: str):
206
+ if version == "1.4":
207
+ diffusion_model_weights_fpath = keras.utils.get_file(
208
+ origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
209
+ file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
210
+ )
211
+ diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
212
+ diffusion_model.load_weights(diffusion_model_weights_fpath)
213
+ return diffusion_model
214
+ elif version == "2":
215
+ diffusion_model_weights_fpath = keras.utils.get_file(
216
+ origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5",
217
+ file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d",
218
+ )
219
+ diffusion_model = DiffusionModelV2(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
220
+ diffusion_model.load_weights(diffusion_model_weights_fpath)
221
+ return diffusion_model
222
+ else:
223
+ return f"v{version} is not supported"
224
+
225
  def _get_initial_diffusion_noise(self, batch_size, seed):
226
  if seed is not None:
227
  return tf.random.stateless_normal(
 
286
 
287
  context = base64.b64decode(inputs[0])
288
  context = np.frombuffer(context, dtype="float32")
289
+ if self.version == "1.4":
290
+ context = np.reshape(context, (batch_size, 77, 768))
291
+ else:
292
+ context = np.reshape(context, (batch_size, 77, 1024))
293
 
294
  unconditional_context = base64.b64decode(inputs[1])
295
  unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
296
+ if self.version == "1.4":
297
+ unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
298
+ else:
299
+ unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 1024))
300
 
301
  num_steps = data.pop("num_steps", 25)
302
  unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- keras-cv
2
- tensorflow
3
  tensorflow_datasets
 
1
+ keras-cv==0.4
2
+ tensorflow==2.11
3
  tensorflow_datasets