Jinglong Xiong
run evaluation
da1ac10
import os
import tempfile
import logging
import re
import subprocess
import cairosvg
from lxml import etree
import kagglehub
from gen_image import ImageGenerator
import vtracer
# Load the 'metric/svg-constraints' Kaggle package. MLModel instantiates its
# SVGConstraints class and reads `allowed_elements` from it when sanitizing SVGs.
# NOTE(review): package_import presumably fetches the package through kagglehub
# when this module is imported — confirm network/caching behavior.
svg_constraints = kagglehub.package_import('metric/svg-constraints')
class MLModel:
    """Text-to-SVG pipeline.

    A description is turned into a raster image by ``ImageGenerator``, traced
    into SVG with ``vtracer``, and then sanitized against the element and
    attribute whitelist provided by ``svg_constraints.SVGConstraints``.
    """

    # Grammar check for a path's "d" attribute: one or more command letters,
    # each optionally followed by a separator-delimited list of numbers.
    # Compiled once at class level instead of once per <path> element.
    _PATH_DATA_RE = re.compile(
        r'^'  # Start of string
        r'(?:'  # Non-capturing group for each command + numbers block
        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands (adjusted to exclude extra letters)
        r'\s*'  # Optional whitespace after command
        r'(?:'  # Non-capturing group for optional numbers
        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
        r')?'  # Numbers are optional (e.g. for Z command)
        r'\s*'  # Optional whitespace after numbers/command block
        r')+'  # One or more command blocks
        r'\s*'  # Optional trailing whitespace
        r'$'  # End of string
    )

    def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"):
        """
        Initialize the SVG generation pipeline.

        Args:
            model_id (str): The model identifier for the stable diffusion model.
            device (str): The device to run the model on, either "cuda" or "cpu".
        """
        self.image_generator = ImageGenerator(model_id=model_id, device=device)
        # Fallback returned whenever generation or sanitizing fails.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        # Not referenced by any method of this class.
        self.timeout_seconds = 90

    def predict(self, description, simplify=True, color_precision=6,
                filter_speckle=4, path_precision=8):
        """
        Generate an SVG from a text description.

        The first image returned by the image generator is written to a
        scratch directory, traced with vtracer, read back and passed through
        ``enforce_constraints``. The scratch directory is removed on every
        exit path, including failures.

        Args:
            description (str): The text description to generate an image from.
            simplify (bool): Selects vtracer's 'stacked' hierarchical mode when
                true and 'cutout' otherwise.
            color_precision (int): Forwarded to vtracer as ``color_precision``.
            filter_speckle (int): Forwarded to vtracer as ``filter_speckle``.
            path_precision (int): Forwarded to vtracer as ``path_precision``.

        Returns:
            str: The sanitized SVG content, or ``self.default_svg`` if any
            step raises an ``Exception``.
        """
        try:
            # Step 1: Generate image using diffusion model
            images = self.image_generator.generate(description)
            image = images[0]

            # Steps 2-4 share one scratch directory so that both intermediate
            # files are deleted even when saving or tracing fails.
            with tempfile.TemporaryDirectory() as work_dir:
                temp_img_path = os.path.join(work_dir, 'input.png')
                temp_svg_path = os.path.join(work_dir, 'output.svg')

                # Step 2: Save image to a temporary file
                image.save(temp_img_path)

                # Step 3: Convert image to SVG using vtracer
                vtracer.convert_image_to_svg_py(
                    temp_img_path,
                    temp_svg_path,
                    colormode='color',
                    hierarchical='stacked' if simplify else 'cutout',
                    mode='spline',
                    filter_speckle=filter_speckle,
                    color_precision=color_precision,
                    path_precision=path_precision,
                    corner_threshold=60,
                    length_threshold=4.0,
                    max_iterations=10,
                    splice_threshold=45
                )

                # Step 4: Read the generated SVG
                with open(temp_svg_path, 'r', encoding='utf-8') as f:
                    svg_content = f.read()

            # Step 5: Enforce constraints
            return self.enforce_constraints(svg_content)
        except Exception as e:
            # Deliberate best-effort boundary: callers always get an SVG back.
            logging.error("Error generating SVG: %s", e)
            return self.default_svg

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Elements whose tag is not in ``self.constraints.allowed_elements`` are
        dropped, as are attributes allowed neither for the element nor under
        the 'common' entry, ``href`` attributes that do not start with '#',
        and ``path`` elements whose ``d`` attribute is missing or fails the
        path-data grammar check.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if constraints
            cannot be satisfied.
        """
        logging.info('Sanitizing SVG...')

        try:
            # Remove XML declaration if it exists
            svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip()
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            logging.error('SVG string: %s', svg_string)
            return self.default_svg

        allowed = self.constraints.allowed_elements
        # Removal is deferred so the tree is not mutated while it is walked.
        elements_to_remove = []
        for element in root.iter():
            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in allowed:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes
            attrs_to_remove = [
                attr
                for attr in element.attrib
                if etree.QName(attr).localname not in allowed[tag_name]
                and etree.QName(attr).localname not in allowed['common']
            ]
            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes. Iterate over a
            # snapshot because entries are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element

                # Use regex to validate 'd' attribute format
                if not self._PATH_DATA_RE.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal. An element without a parent
        # (the root) cannot be detached and is left in place.
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            return etree.tostring(root, encoding='unicode', xml_declaration=False)
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg

    def optimize_svg(self, svg_content):
        """
        Optimize the SVG content using SVGO.

        The SVG is written to a scratch directory, the ``svgo`` executable is
        run on it, and the result is read back. The scratch directory is
        removed on every exit path.

        Args:
            svg_content (str): The SVG content to optimize.

        Returns:
            str: The optimized SVG content, or ``svg_content`` unchanged when
            ``svgo`` cannot be found or exits with a non-zero status.
        """
        try:
            with tempfile.TemporaryDirectory() as work_dir:
                temp_svg_path = os.path.join(work_dir, 'input.svg')
                temp_out_path = os.path.join(work_dir, 'output.svg')

                # Written as bytes so no newline translation is applied.
                with open(temp_svg_path, 'wb') as f:
                    f.write(svg_content.encode('utf-8'))

                subprocess.run(["svgo", temp_svg_path, "-o", temp_out_path], check=True)

                # Read back with the same encoding the input was written in.
                with open(temp_out_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except (FileNotFoundError, subprocess.CalledProcessError):
            # Reported through logging, like the rest of the class.
            logging.warning("SVGO not found or failed. Returning unoptimized SVG.")
            return svg_content
# Example usage
if __name__ == "__main__":
    # Run the pipeline once with default settings on a sample prompt.
    model = MLModel()
    svg = model.predict("a purple forest at dusk")

    # Rasterize the resulting SVG and store it next to the script.
    try:
        svg_bytes = svg.encode('utf-8')
        rendered_png = cairosvg.svg2png(bytestring=svg_bytes)
        with open("output.png", "wb") as png_file:
            png_file.write(rendered_png)
        print("SVG saved as output.png")
    except Exception as err:
        print(f"Error converting SVG to PNG: {err}")