import concurrent
import concurrent.futures
import io
import logging
import re

import cairosvg
import kagglehub
import torch
from lxml import etree
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
# Import the Kaggle-hosted 'metric/svg-constraints' package through kagglehub.
# NaiveModel instantiates `svg_constraints.SVGConstraints()` from it to obtain
# the allowed-element / allowed-attribute tables used when sanitizing SVG.
svg_constraints = kagglehub.package_import('metric/svg-constraints')
class NaiveModel:
    """Generate constraint-compliant SVG from a text description with Phi-4.

    The model is loaded through Unsloth (4-bit), prompted with a fixed
    instruction template, and its output is post-processed: the last
    ``<svg>...</svg>`` document in the decoded text is extracted, sanitized
    against ``svg_constraints.SVGConstraints`` and test-rendered with
    cairosvg. Any failure along the way yields ``default_svg`` instead.
    """

    # One complete <svg>...</svg> document; non-greedy so that several
    # documents in the same text are matched separately.
    _SVG_DOCUMENT_RE = re.compile(r"<svg.*?</svg>", re.DOTALL | re.IGNORECASE)

    # Conservative validator for a path's "d" attribute. Compiled once here
    # instead of once per <path> element.
    _PATH_DATA_RE = re.compile(
        r'^'  # Start of string
        r'(?:'  # Non-capturing group for each command + numbers block
        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands (adjusted to exclude extra letters)
        r'\s*'  # Optional whitespace after command
        r'(?:'  # Non-capturing group for optional numbers
        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
        r')?'  # Numbers are optional (e.g. for Z command)
        r'\s*'  # Optional whitespace after numbers/command block
        r')+'  # One or more command blocks
        r'\s*'  # Optional trailing whitespace
        r'$'  # End of string
    )

    def __init__(
        self,
        model_name: str = "unsloth/phi-4-unsloth-bnb-4bit",
        max_seq_length: int = 2048,
        device: str = "cuda",
    ):
        """Load the model and tokenizer and prepare them for inference.

        Parameters
        ----------
        model_name : str
            Model identifier passed to ``FastLanguageModel.from_pretrained``.
        max_seq_length : int
            Maximum sequence length passed to the model loader.
        device : str
            Device the tokenized prompt is moved to before generation.
        """
        self.device = device
        self.max_seq_length = max_seq_length
        self.load_in_4bit = True

        # Load the Unsloth Phi-4 model
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit
        )

        # Set up chat template
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="phi-4",
        )

        # Prepare model for inference
        FastLanguageModel.for_inference(self.model)

        # NOTE(review): the example's ```svg fence below is empty, so the
        # prompt shows the model an example description without an example
        # answer. The template is left unchanged here - confirm whether the
        # example SVG was lost and should be restored.
        self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.
"A red circle with a blue square inside"
```svg
```
"{}"
"""
        # Fallback returned whenever generation, sanitizing or rendering
        # fails. It has to be a renderable document itself: an empty string
        # is not valid SVG and cannot be rasterized by callers.
        self.default_svg = (
            '<svg width="256" height="256" viewBox="0 0 256 256">'
            '<circle cx="50" cy="50" r="40" fill="red" /></svg>'
        )
        self.constraints = svg_constraints.SVGConstraints()
        self.timeout_seconds = 90

    @classmethod
    def _extract_last_svg(cls, text: str):
        """Return the last ``<svg>...</svg>`` document in *text*, or ``None``."""
        matches = cls._SVG_DOCUMENT_RE.findall(text)
        return matches[-1] if matches else None

    def predict(self, description: str, max_new_tokens=512) -> str:
        """Generate an SVG for *description*, bounded by ``timeout_seconds``.

        Parameters
        ----------
        description : str
            Text description inserted into the prompt template.
        max_new_tokens : int
            Generation budget passed to ``model.generate``.

        Returns
        -------
        str
            The sanitized SVG, or ``default_svg`` when generation fails,
            produces no SVG, cannot be rendered, or exceeds the timeout.
        """
        def generate_svg():
            try:
                # Format the prompt
                prompt = self.prompt_template.format(description)

                # Create messages in the format expected by the chat template
                messages = [
                    {"role": "user", "content": prompt},
                ]

                # Tokenize the messages
                inputs = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(self.device)

                # Generate the output
                outputs = self.model.generate(
                    input_ids=inputs,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                    temperature=1.0,
                    min_p=0.1,
                    do_sample=True,
                )

                # Decode the output (prompt included); taking the *last*
                # SVG document below selects the generated answer.
                output_decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                logging.debug('Output decoded from model: %s', output_decoded)

                svg = self._extract_last_svg(output_decoded)
                if svg is None:
                    return self.default_svg

                logging.debug('Unprocessed SVG: %s', svg)
                svg = self.enforce_constraints(svg)
                logging.debug('Processed SVG: %s', svg)

                # Ensure the generated code can be converted by cairosvg
                cairosvg.svg2png(bytestring=svg.encode('utf-8'))
                return svg
            except Exception as e:
                logging.error('Exception during SVG generation: %s', e)
                return self.default_svg

        # Execute SVG generation in a worker thread to enforce the time limit.
        # The executor is deliberately NOT used as a context manager: leaving
        # a `with` block calls shutdown(wait=True), which would block until
        # the worker finishes and make the timeout ineffective.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(generate_svg)
        try:
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
            return self.default_svg
        except Exception as e:
            logging.error("An unexpected error occurred: %s", e)
            return self.default_svg
        finally:
            # Threads cannot be killed; after a timeout the worker keeps
            # running in the background until generation ends on its own.
            executor.shutdown(wait=False)

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if constraints
            cannot be satisfied (unparseable input, disallowed root element,
            or serialization failure).
        """
        logging.info('Sanitizing SVG...')

        try:
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            logging.error('SVG string: %s', svg_string)
            return self.default_svg

        allowed = self.constraints.allowed_elements
        elements_to_remove = []
        for element in root.iter():
            # Processing instructions / entities have a non-string tag that
            # QName cannot handle; they are never allowed content.
            if not isinstance(element.tag, str):
                elements_to_remove.append(element)
                continue

            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in allowed:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes
            attrs_to_remove = [
                attr
                for attr in element.attrib
                if etree.QName(attr).localname not in allowed[tag_name]
                and etree.QName(attr).localname not in allowed['common']
            ]
            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes. Iterate a snapshot
            # because entries are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element

                # Use regex to validate 'd' attribute format
                if not self._PATH_DATA_RE.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is None:
                # The root itself is disallowed and cannot be detached, so
                # the document cannot be made compliant.
                logging.error(
                    'SVG could not be sanitized to meet constraints: %s',
                    'root element is not allowed',
                )
                return self.default_svg
            parent.remove(element)
            logging.debug('Removed element: %s', element.tag)

        try:
            return etree.tostring(root, encoding='unicode')
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg
if __name__ == "__main__":
    # Demo run: generate one SVG and rasterize it to output.png.
    svg = NaiveModel().predict("a purple forest at dusk")

    try:
        # Render to PNG bytes in memory, then persist them to disk.
        rendered_png = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
        with open("output.png", "wb") as png_file:
            png_file.write(rendered_png)
        print("SVG saved as output.png")
    except Exception as e:
        # Best-effort demo: report the failure instead of crashing.
        print(f"Error converting SVG to PNG: {e}")