import logging
import re
import cairosvg
import torch
from transformers import AutoModelForCausalLM
from lxml import etree
import kagglehub
from gen_image import ImageGenerator
from starvector.data.util import process_and_rasterize_svg
# Kaggle-hosted metric package, imported at module load time; provides the
# SVGConstraints class that DLModel.__init__ instantiates.
svg_constraints = kagglehub.package_import('metric/svg-constraints')
class DLModel:
def __init__(
    self,
    model_id="starvector/starvector-8b-im2svg",
    device="cuda",
    *,
    diffusion_model_id="stabilityai/stable-diffusion-2-1-base",
    timeout_seconds=90,
    torch_dtype=None,
):
    """
    Initialize the text -> image -> SVG generation pipeline.

    A Stable Diffusion ``ImageGenerator`` produces an intermediate raster
    image from the text prompt, and the StarVector model converts that
    image into SVG markup.

    Args:
        model_id (str): The model identifier for the StarVector model.
        device (str | torch.device): The device to run the models on,
            e.g. "cuda" or "cpu".
        diffusion_model_id (str): Model identifier passed to
            ``ImageGenerator`` for the text-to-image step.
        timeout_seconds (int): Stored on ``self.timeout_seconds``; it is not
            used inside ``__init__`` itself (presumably consumed by other
            methods of this class — confirm).
        torch_dtype (torch.dtype | None): Dtype for the StarVector weights.
            When None, float16 is used on non-CPU devices (the previous
            hard-coded behavior) and float32 on CPU, where half precision is
            slow or unsupported for many operations.
    """
    self.image_generator = ImageGenerator(model_id=diffusion_model_id, device=device)
    # NOTE(review): this is an empty string literal; it looks like fallback
    # SVG markup may have been lost here — confirm the intended default.
    self.default_svg = """"""
    self.constraints = svg_constraints.SVGConstraints()
    self.timeout_seconds = timeout_seconds

    # Load StarVector model
    self.device = device
    if torch_dtype is None:
        # Keep fp16 on accelerators; fall back to fp32 on CPU, where many
        # half-precision kernels are missing or very slow.
        is_cpu = torch.device(device).type == "cpu"
        torch_dtype = torch.float32 if is_cpu else torch.float16
    self.starvector = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        trust_remote_code=True,  # required to load the model's custom code
    )
    # The image preprocessor is exposed by the loaded StarVector model.
    self.processor = self.starvector.model.processor
    self.starvector.to(device)
    self.starvector.eval()  # inference mode (e.g. disables dropout)
def predict(self, description):
"""
Generate an SVG from a text description.
Args:
description (str): The text description to generate an image from.
Returns:
str: The generated SVG content.
"""
try:
# Step 1: Generate image using diffusion model
images = self.image_generator.generate(description)
image = images[0]
# Save the generated image
image_path = "diff_image.png"
image.save(image_path)
logging.info(f"Intermediate image saved to {image_path}")
# Step 2: Convert image to SVG using StarVector
processed_image = self.processor(image, return_tensors="pt")['pixel_values'].to(self.device)
if not processed_image.shape[0] == 1:
processed_image = processed_image.squeeze(0)
batch = {"image": processed_image}
with torch.no_grad():
raw_svg = self.starvector.generate_im2svg(batch, max_length=4000)[0]
raw_svg, _ = process_and_rasterize_svg(raw_svg)
if 'viewBox' not in raw_svg:
raw_svg = raw_svg.replace('