Spaces:
Running
on
Zero
Running
on
Zero
Native safety checker
#918
by
multimodalart
HF Staff
- opened
app.py
CHANGED
|
@@ -23,7 +23,7 @@ import user_history
|
|
| 23 |
from illusion_style import css
|
| 24 |
import os
|
| 25 |
from transformers import CLIPImageProcessor
|
| 26 |
-
from safety_checker import StableDiffusionSafetyChecker
|
| 27 |
|
| 28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
| 29 |
|
|
@@ -49,16 +49,16 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
| 49 |
).to("cuda")
|
| 50 |
|
| 51 |
# Function to check NSFW images
|
| 52 |
-
def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
| 53 |
-
if SAFETY_CHECKER_ENABLED:
|
| 54 |
-
safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
| 55 |
-
has_nsfw_concepts = safety_checker(
|
| 56 |
-
images=[images],
|
| 57 |
-
clip_input=safety_checker_input.pixel_values.to("cuda")
|
| 58 |
-
)
|
| 59 |
-
return images, has_nsfw_concepts
|
| 60 |
-
else:
|
| 61 |
-
return images, [False] * len(images)
|
| 62 |
|
| 63 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 64 |
#main_pipe.unet.to(memory_format=torch.channels_last)
|
|
@@ -284,4 +284,4 @@ with gr.Blocks(css=css) as app_with_history:
|
|
| 284 |
app_with_history.queue(max_size=20,api_open=False )
|
| 285 |
|
| 286 |
if __name__ == "__main__":
|
| 287 |
-
app_with_history.launch(max_threads=400)
|
|
|
|
| 23 |
from illusion_style import css
|
| 24 |
import os
|
| 25 |
from transformers import CLIPImageProcessor
|
| 26 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 27 |
|
| 28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
| 29 |
|
|
|
|
| 49 |
).to("cuda")
|
| 50 |
|
| 51 |
# Function to check NSFW images
|
| 52 |
+
#def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
| 53 |
+
# if SAFETY_CHECKER_ENABLED:
|
| 54 |
+
# safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
| 55 |
+
# has_nsfw_concepts = safety_checker(
|
| 56 |
+
# images=[images],
|
| 57 |
+
# clip_input=safety_checker_input.pixel_values.to("cuda")
|
| 58 |
+
# )
|
| 59 |
+
# return images, has_nsfw_concepts
|
| 60 |
+
# else:
|
| 61 |
+
# return images, [False] * len(images)
|
| 62 |
|
| 63 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 64 |
#main_pipe.unet.to(memory_format=torch.channels_last)
|
|
|
|
| 284 |
app_with_history.queue(max_size=20,api_open=False )
|
| 285 |
|
| 286 |
if __name__ == "__main__":
|
| 287 |
+
app_with_history.launch(max_threads=400)
|