import concurrent.futures
import io
import logging
import re

import cairosvg
import kagglehub
import torch
from lxml import etree
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

svg_constraints = kagglehub.package_import('metric/svg-constraints')

# Prompt sent to the model. The single ``{}`` placeholder receives the text
# description. The worked example only uses elements and attributes from the
# allow-lists stated in the prompt.
_PROMPT_TEMPLATE = """Generate SVG code to visually represent the following text description, while respecting the given constraints.

* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`

Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.

"A red circle with a blue square inside"
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
  <circle cx="128" cy="128" r="100" fill="red"/>
  <rect x="88" y="88" width="80" height="80" fill="blue"/>
</svg>
```

"{}"
"""

# Fallback returned whenever generation, sanitizing or rendering fails. It must
# itself be a valid, renderable SVG so callers can always convert the result.
_DEFAULT_SVG = (
    '<svg width="256" height="256" viewBox="0 0 256 256">'
    '<circle cx="50" cy="50" r="40" fill="red" /></svg>'
)


class NaiveModel:
    """Generate constraint-compliant SVG images from text with a Phi-4 model.

    The model is loaded once in ``__init__``. ``predict`` produces one SVG
    string per description and never raises: on any failure or timeout it
    returns ``self.default_svg``.
    """

    # A complete ``<svg>...</svg>`` document inside free-form model output.
    _SVG_REGEX = re.compile(r"<svg.*?</svg>", re.DOTALL | re.IGNORECASE)

    # Conservative syntax check for a path's ``d`` attribute. Compiled once
    # here instead of once per <path> element.
    _PATH_D_REGEX = re.compile(
        r'^'  # Start of string
        r'(?:'  # Non-capturing group for each command + numbers block
        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands
        r'\s*'  # Optional whitespace after command
        r'(?:'  # Non-capturing group for optional numbers
        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
        r')?'  # Numbers are optional (e.g. for Z command)
        r'\s*'  # Optional whitespace after numbers/command block
        r')+'  # One or more command blocks
        r'\s*'  # Optional trailing whitespace
        r'$'  # End of string
    )

    def __init__(
        self,
        model_name="unsloth/phi-4-unsloth-bnb-4bit",
        max_seq_length=2048,
        device="cuda",
    ):
        """Load the model and tokenizer and prepare them for inference.

        Parameters
        ----------
        model_name : str
            Model identifier passed to ``FastLanguageModel.from_pretrained``.
        max_seq_length : int
            Maximum sequence length passed to the model loader.
        device : str
            Device the tokenized prompt is moved to before generation.
        """
        self.device = device
        self.max_seq_length = max_seq_length
        self.load_in_4bit = True

        # Load the Unsloth Phi-4 model
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit,
        )

        # Set up chat template
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="phi-4",
        )

        # Prepare model for inference
        FastLanguageModel.for_inference(self.model)

        self.prompt_template = _PROMPT_TEMPLATE
        self.default_svg = _DEFAULT_SVG
        self.constraints = svg_constraints.SVGConstraints()
        self.timeout_seconds = 90

    def predict(self, description: str, max_new_tokens: int = 512) -> str:
        """Return SVG code depicting ``description``.

        Generation runs in a worker thread so the call can be bounded by
        ``self.timeout_seconds``.

        Parameters
        ----------
        description : str
            Text description of the image to draw.
        max_new_tokens : int
            Upper bound on the number of tokens the model may generate.

        Returns
        -------
        str
            The sanitized SVG, or ``self.default_svg`` on failure or timeout.
        """
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(
            self._generate_svg, description, max_new_tokens
        )
        try:
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning(
                "Prediction timed out after %s seconds.", self.timeout_seconds
            )
            return self.default_svg
        except Exception as e:
            logging.error("An unexpected error occurred: %s", e)
            return self.default_svg
        finally:
            # Do not wait for the worker: using the executor as a context
            # manager would block here until generation finished and defeat
            # the timeout. A running thread cannot be interrupted, so after a
            # timeout it keeps running in the background until it completes.
            executor.shutdown(wait=False)

    def _generate_svg(self, description: str, max_new_tokens: int) -> str:
        """Run the model once and return a validated SVG or the default SVG."""
        try:
            # Format the prompt
            prompt = self.prompt_template.format(description)

            # Create messages in the format expected by the chat template
            messages = [
                {"role": "user", "content": prompt},
            ]

            # Tokenize the messages
            inputs = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.device)

            # Generate the output
            outputs = self.model.generate(
                input_ids=inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                temperature=1.0,
                min_p=0.1,
                do_sample=True,
            )

            # Decode only the newly generated tokens. The returned sequence
            # starts with the prompt, whose example SVG would otherwise be
            # picked up by the search below when the model emits no SVG.
            generated_ids = outputs[0][inputs.shape[-1]:]
            output_decoded = self.tokenizer.decode(
                generated_ids, skip_special_tokens=True
            )
            logging.debug('Output decoded from model: %s', output_decoded)

            matches = self._SVG_REGEX.findall(output_decoded)
            if not matches:
                return self.default_svg
            svg = matches[-1]

            logging.debug('Unprocessed SVG: %s', svg)
            svg = self.enforce_constraints(svg)
            logging.debug('Processed SVG: %s', svg)

            # Ensure the generated code can be converted by cairosvg
            cairosvg.svg2png(bytestring=svg.encode('utf-8'))
            return svg
        except Exception as e:
            # Deliberate catch-all: prediction must never raise.
            logging.error('Exception during SVG generation: %s', e)
            return self.default_svg

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Elements whose tag is not a key of ``self.constraints.allowed_elements``
        are removed. Attributes are kept only if listed under the element's
        own tag or under the ``'common'`` key. ``path`` elements with a
        missing or malformed ``d`` attribute are removed as well.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if constraints
            cannot be satisfied.
        """
        logging.info('Sanitizing SVG...')

        try:
            parser = etree.XMLParser(
                remove_blank_text=True, remove_comments=True
            )
            root = etree.fromstring(svg_string, parser=parser)
        except (etree.ParseError, ValueError) as e:
            # ValueError: lxml refuses a ``str`` input that carries an XML
            # encoding declaration.
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            logging.error('SVG string: %s', svg_string)
            return self.default_svg

        allowed = self.constraints.allowed_elements
        elements_to_remove = []
        for element in root.iter():
            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in allowed:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes
            attrs_to_remove = [
                attr
                for attr in element.attrib
                if etree.QName(attr).localname not in allowed[tag_name]
                and etree.QName(attr).localname not in allowed['common']
            ]
            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes. Iterate a snapshot
            # because entries are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if (
                    etree.QName(attr).localname == 'href'
                    and not value.startswith('#')
                ):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".',
                        tag_name,
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning(
                        'Path element is missing "d" attribute. '
                        'Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element

                if not self._PATH_D_REGEX.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. '
                        'Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug(
                    'Path element "d" attribute validated (regex check).'
                )

        # Remove elements marked for removal. The root has no parent and is
        # therefore never removed.
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            return etree.tostring(root, encoding='unicode')
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg


def main() -> None:
    """Generate one sample SVG and save a PNG rendering as ``output.png``."""
    model = NaiveModel()
    svg = model.predict("a purple forest at dusk")

    # Convert SVG to PNG
    try:
        # Create a PNG in memory
        png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))

        # Save the PNG to a file
        with open("output.png", "wb") as f:
            f.write(png_data)
        print("SVG saved as output.png")
    except Exception as e:
        print(f"Error converting SVG to PNG: {e}")


if __name__ == "__main__":
    main()