Spaces:
Paused
Paused
Jinglong Xiong
commited on
Commit
Β·
6642f4e
1
Parent(s):
bb5a422
add models
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .gitignore +2 -0
- dl.py +203 -0
- gen_image.py +2 -2
- ml.py +246 -0
- naive.py +246 -0
- requirements.txt +6 -0
- starter.ipynb +0 -333
- starvector/__init__.py +0 -0
- starvector/adapter.py +53 -0
- starvector/clip_model.py +191 -0
- starvector/data/augmentation.py +250 -0
- starvector/data/base.py +71 -0
- starvector/data/dataset.py +42 -0
- starvector/data/emojisvg.py +27 -0
- starvector/data/figrsvg.py +27 -0
- starvector/data/fontsvg.py +28 -0
- starvector/data/iconsvg.py +38 -0
- starvector/data/stacksvg.py +59 -0
- starvector/data/util.py +389 -0
- starvector/image_encoder.py +119 -0
- starvector/metrics/base_metric.py +51 -0
- starvector/metrics/compute_LPIPS.py +56 -0
- starvector/metrics/compute_SSIM.py +35 -0
- starvector/metrics/compute_clip_score.py +55 -0
- starvector/metrics/compute_dino_score.py +55 -0
- starvector/metrics/compute_fid.py +145 -0
- starvector/metrics/compute_l2.py +37 -0
- starvector/metrics/count_token_length.py +54 -0
- starvector/metrics/inception.py +341 -0
- starvector/metrics/metrics.py +127 -0
- starvector/metrics/util.py +20 -0
- starvector/model/adapters/adapter.py +53 -0
- starvector/model/builder.py +49 -0
- starvector/model/gpt_bigcode/__init__.py +65 -0
- starvector/model/gpt_bigcode/configuration_gpt_bigcode.py +143 -0
- starvector/model/gpt_bigcode/modeling_gpt_bigcode.py +1502 -0
- starvector/model/image_encoder/clip_model.py +191 -0
- starvector/model/image_encoder/image_encoder.py +120 -0
- starvector/model/llm/starcoder.py +51 -0
- starvector/model/llm/starcoder2.py +61 -0
- starvector/model/models/starvector_base.py +339 -0
- starvector/model/models/starvector_v1.py +22 -0
- starvector/model/models/starvector_v2.py +63 -0
- starvector/model/starvector_arch.py +194 -0
- starvector/serve/__init__.py +0 -0
- starvector/serve/constants.py +16 -0
- starvector/serve/controller.py +293 -0
- starvector/serve/conversation.py +211 -0
- starvector/serve/gradio_demo_with_updated_gradio.py +432 -0
- starvector/serve/gradio_web_server.py +562 -0
.gitignore
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
star-vector/
|
2 |
SVGDreamer/
|
3 |
*.parquet
|
|
|
1 |
+
unsloth_compiled_cache/
|
2 |
+
*.ipynb
|
3 |
star-vector/
|
4 |
SVGDreamer/
|
5 |
*.parquet
|
dl.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
import cairosvg
|
4 |
+
import torch
|
5 |
+
from transformers import AutoModelForCausalLM
|
6 |
+
from lxml import etree
|
7 |
+
import kagglehub
|
8 |
+
from gen_image import ImageGenerator
|
9 |
+
from starvector.data.util import process_and_rasterize_svg
|
10 |
+
|
11 |
+
svg_constraints = kagglehub.package_import('metric/svg-constraints')
|
12 |
+
|
13 |
+
class DLModel:
|
14 |
+
def __init__(self, model_id="starvector/starvector-8b-im2svg", device="cuda"):
|
15 |
+
"""
|
16 |
+
Initialize the SVG generation pipeline using StarVector.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
model_id (str): The model identifier for the StarVector model.
|
20 |
+
device (str): The device to run the model on, either "cuda" or "cpu".
|
21 |
+
"""
|
22 |
+
self.image_generator = ImageGenerator(model_id="stabilityai/stable-diffusion-2-1-base", device=device)
|
23 |
+
self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
|
24 |
+
self.constraints = svg_constraints.SVGConstraints()
|
25 |
+
self.timeout_seconds = 90
|
26 |
+
|
27 |
+
# Load StarVector model
|
28 |
+
self.device = device
|
29 |
+
self.starvector = AutoModelForCausalLM.from_pretrained(
|
30 |
+
model_id,
|
31 |
+
torch_dtype=torch.float16,
|
32 |
+
trust_remote_code=True
|
33 |
+
)
|
34 |
+
self.processor = self.starvector.model.processor
|
35 |
+
self.starvector.to(device)
|
36 |
+
self.starvector.eval()
|
37 |
+
|
38 |
+
def predict(self, description):
|
39 |
+
"""
|
40 |
+
Generate an SVG from a text description.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
description (str): The text description to generate an image from.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
str: The generated SVG content.
|
47 |
+
"""
|
48 |
+
try:
|
49 |
+
# Step 1: Generate image using diffusion model
|
50 |
+
images = self.image_generator.generate(description)
|
51 |
+
image = images[0]
|
52 |
+
|
53 |
+
# Save the generated image
|
54 |
+
image_path = "diff_image.png"
|
55 |
+
image.save(image_path)
|
56 |
+
logging.info(f"Intermediate image saved to {image_path}")
|
57 |
+
|
58 |
+
# Step 2: Convert image to SVG using StarVector
|
59 |
+
processed_image = self.processor(image, return_tensors="pt")['pixel_values'].to(self.device)
|
60 |
+
if not processed_image.shape[0] == 1:
|
61 |
+
processed_image = processed_image.squeeze(0)
|
62 |
+
|
63 |
+
batch = {"image": processed_image}
|
64 |
+
with torch.no_grad():
|
65 |
+
raw_svg = self.starvector.generate_im2svg(batch, max_length=4000)[0]
|
66 |
+
raw_svg, _ = process_and_rasterize_svg(raw_svg)
|
67 |
+
|
68 |
+
if 'viewBox' not in raw_svg:
|
69 |
+
raw_svg = raw_svg.replace('<svg', f'<svg viewBox="0 0 384 384"')
|
70 |
+
|
71 |
+
# Step 3: Enforce constraints
|
72 |
+
svg_content = self.enforce_constraints(raw_svg)
|
73 |
+
|
74 |
+
return svg_content
|
75 |
+
except Exception as e:
|
76 |
+
logging.error(f"Error generating SVG: {e}")
|
77 |
+
return self.default_svg
|
78 |
+
|
79 |
+
def enforce_constraints(self, svg_string: str) -> str:
|
80 |
+
"""Enforces constraints on an SVG string, removing disallowed elements
|
81 |
+
and attributes.
|
82 |
+
|
83 |
+
Parameters
|
84 |
+
----------
|
85 |
+
svg_string : str
|
86 |
+
The SVG string to process.
|
87 |
+
|
88 |
+
Returns
|
89 |
+
-------
|
90 |
+
str
|
91 |
+
The processed SVG string, or the default SVG if constraints
|
92 |
+
cannot be satisfied.
|
93 |
+
"""
|
94 |
+
logging.info('Sanitizing SVG...')
|
95 |
+
|
96 |
+
try:
|
97 |
+
# Remove XML declaration if it exists
|
98 |
+
svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip()
|
99 |
+
|
100 |
+
parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
|
101 |
+
root = etree.fromstring(svg_string, parser=parser)
|
102 |
+
except etree.ParseError as e:
|
103 |
+
logging.error('SVG Parse Error: %s. Returning default SVG.', e)
|
104 |
+
logging.error('SVG string: %s', svg_string)
|
105 |
+
return self.default_svg
|
106 |
+
|
107 |
+
elements_to_remove = []
|
108 |
+
for element in root.iter():
|
109 |
+
tag_name = etree.QName(element.tag).localname
|
110 |
+
|
111 |
+
# Remove disallowed elements
|
112 |
+
if tag_name not in self.constraints.allowed_elements:
|
113 |
+
elements_to_remove.append(element)
|
114 |
+
continue # Skip attribute checks for removed elements
|
115 |
+
|
116 |
+
# Remove disallowed attributes
|
117 |
+
attrs_to_remove = []
|
118 |
+
for attr in element.attrib:
|
119 |
+
attr_name = etree.QName(attr).localname
|
120 |
+
if (
|
121 |
+
attr_name
|
122 |
+
not in self.constraints.allowed_elements[tag_name]
|
123 |
+
and attr_name
|
124 |
+
not in self.constraints.allowed_elements['common']
|
125 |
+
):
|
126 |
+
attrs_to_remove.append(attr)
|
127 |
+
|
128 |
+
for attr in attrs_to_remove:
|
129 |
+
logging.debug(
|
130 |
+
'Attribute "%s" for element "%s" not allowed. Removing.',
|
131 |
+
attr,
|
132 |
+
tag_name,
|
133 |
+
)
|
134 |
+
del element.attrib[attr]
|
135 |
+
|
136 |
+
# Check and remove invalid href attributes
|
137 |
+
for attr, value in element.attrib.items():
|
138 |
+
if etree.QName(attr).localname == 'href' and not value.startswith('#'):
|
139 |
+
logging.debug(
|
140 |
+
'Removing invalid href attribute in element "%s".', tag_name
|
141 |
+
)
|
142 |
+
del element.attrib[attr]
|
143 |
+
|
144 |
+
# Validate path elements to help ensure SVG conversion
|
145 |
+
if tag_name == 'path':
|
146 |
+
d_attribute = element.get('d')
|
147 |
+
if not d_attribute:
|
148 |
+
logging.warning('Path element is missing "d" attribute. Removing path.')
|
149 |
+
elements_to_remove.append(element)
|
150 |
+
continue # Skip further checks for this removed element
|
151 |
+
# Use regex to validate 'd' attribute format
|
152 |
+
path_regex = re.compile(
|
153 |
+
r'^' # Start of string
|
154 |
+
r'(?:' # Non-capturing group for each command + numbers block
|
155 |
+
r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters)
|
156 |
+
r'\s*' # Optional whitespace after command
|
157 |
+
r'(?:' # Non-capturing group for optional numbers
|
158 |
+
r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number
|
159 |
+
r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s)
|
160 |
+
r')?' # Numbers are optional (e.g. for Z command)
|
161 |
+
r'\s*' # Optional whitespace after numbers/command block
|
162 |
+
r')+' # One or more command blocks
|
163 |
+
r'\s*' # Optional trailing whitespace
|
164 |
+
r'$' # End of string
|
165 |
+
)
|
166 |
+
if not path_regex.match(d_attribute):
|
167 |
+
logging.warning(
|
168 |
+
'Path element has malformed "d" attribute format. Removing path.'
|
169 |
+
)
|
170 |
+
elements_to_remove.append(element)
|
171 |
+
continue
|
172 |
+
logging.debug('Path element "d" attribute validated (regex check).')
|
173 |
+
|
174 |
+
# Remove elements marked for removal
|
175 |
+
for element in elements_to_remove:
|
176 |
+
if element.getparent() is not None:
|
177 |
+
element.getparent().remove(element)
|
178 |
+
logging.debug('Removed element: %s', element.tag)
|
179 |
+
|
180 |
+
try:
|
181 |
+
cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False)
|
182 |
+
return cleaned_svg_string
|
183 |
+
except ValueError as e:
|
184 |
+
logging.error(
|
185 |
+
'SVG could not be sanitized to meet constraints: %s', e
|
186 |
+
)
|
187 |
+
return self.default_svg
|
188 |
+
|
189 |
+
# Example usage
|
190 |
+
if __name__ == "__main__":
|
191 |
+
model = DLModel()
|
192 |
+
svg = model.predict("a purple forest at dusk")
|
193 |
+
# Convert SVG to PNG
|
194 |
+
try:
|
195 |
+
# Create a PNG in memory
|
196 |
+
png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
|
197 |
+
|
198 |
+
# Save the PNG to a file
|
199 |
+
with open("output.png", "wb") as f:
|
200 |
+
f.write(png_data)
|
201 |
+
print("SVG saved as output.png")
|
202 |
+
except Exception as e:
|
203 |
+
print(f"Error converting SVG to PNG: {e}")
|
gen_image.py
CHANGED
@@ -35,7 +35,7 @@ class ImageGenerator:
|
|
35 |
num_images (int, optional): Number of images to generate.
|
36 |
|
37 |
Returns:
|
38 |
-
PIL.Image.Image: The generated
|
39 |
"""
|
40 |
prompt = f"{prompt}, {self.positive_prompt}"
|
41 |
if negative_prompt is None:
|
@@ -51,7 +51,7 @@ class ImageGenerator:
|
|
51 |
for i, image in enumerate(images):
|
52 |
image.save(f".cache/{output_path.replace('.png', f'_{i}.png')}")
|
53 |
|
54 |
-
return
|
55 |
|
56 |
# Example usage
|
57 |
if __name__ == "__main__":
|
|
|
35 |
num_images (int, optional): Number of images to generate.
|
36 |
|
37 |
Returns:
|
38 |
+
list[PIL.Image.Image]: The generated images.
|
39 |
"""
|
40 |
prompt = f"{prompt}, {self.positive_prompt}"
|
41 |
if negative_prompt is None:
|
|
|
51 |
for i, image in enumerate(images):
|
52 |
image.save(f".cache/{output_path.replace('.png', f'_{i}.png')}")
|
53 |
|
54 |
+
return images
|
55 |
|
56 |
# Example usage
|
57 |
if __name__ == "__main__":
|
ml.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
import subprocess
|
6 |
+
import cairosvg
|
7 |
+
from lxml import etree
|
8 |
+
import kagglehub
|
9 |
+
from gen_image import ImageGenerator
|
10 |
+
import vtracer
|
11 |
+
|
12 |
+
svg_constraints = kagglehub.package_import('metric/svg-constraints')
|
13 |
+
|
14 |
+
class MLModel:
|
15 |
+
def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"):
|
16 |
+
"""
|
17 |
+
Initialize the SVG generation pipeline.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
model_id (str): The model identifier for the stable diffusion model.
|
21 |
+
device (str): The device to run the model on, either "cuda" or "cpu".
|
22 |
+
"""
|
23 |
+
self.image_generator = ImageGenerator(model_id=model_id, device=device)
|
24 |
+
self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
|
25 |
+
self.constraints = svg_constraints.SVGConstraints()
|
26 |
+
self.timeout_seconds = 90
|
27 |
+
|
28 |
+
def predict(self, description, simplify=True, color_precision=6,
|
29 |
+
gradient_step=10, filter_speckle=4, path_precision=8):
|
30 |
+
"""
|
31 |
+
Generate an SVG from a text description.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
description (str): The text description to generate an image from.
|
35 |
+
simplify (bool): Whether to simplify the SVG paths.
|
36 |
+
color_precision (int): Color quantization precision.
|
37 |
+
gradient_step (int): Gradient step for color quantization (not used by vtracer).
|
38 |
+
filter_speckle (int): Filter speckle size.
|
39 |
+
path_precision (int): Path fitting precision.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
str: The generated SVG content.
|
43 |
+
"""
|
44 |
+
try:
|
45 |
+
# Step 1: Generate image using diffusion model
|
46 |
+
images = self.image_generator.generate(description)
|
47 |
+
image = images[0]
|
48 |
+
|
49 |
+
# Step 2: Save image to a temporary file
|
50 |
+
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
|
51 |
+
temp_img_path = temp_img.name
|
52 |
+
image.save(temp_img_path)
|
53 |
+
|
54 |
+
# Step 3: Convert image to SVG using vtracer
|
55 |
+
with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg:
|
56 |
+
temp_svg_path = temp_svg.name
|
57 |
+
|
58 |
+
# Process the image with vtracer using parameters directly
|
59 |
+
vtracer.convert_image_to_svg_py(
|
60 |
+
temp_img_path,
|
61 |
+
temp_svg_path,
|
62 |
+
colormode='color',
|
63 |
+
hierarchical='stacked' if simplify else 'cutout',
|
64 |
+
mode='spline',
|
65 |
+
filter_speckle=filter_speckle,
|
66 |
+
color_precision=color_precision,
|
67 |
+
path_precision=path_precision,
|
68 |
+
corner_threshold=60,
|
69 |
+
length_threshold=4.0,
|
70 |
+
max_iterations=10,
|
71 |
+
splice_threshold=45
|
72 |
+
)
|
73 |
+
|
74 |
+
# Step 4: Read the generated SVG
|
75 |
+
with open(temp_svg_path, 'r') as f:
|
76 |
+
svg_content = f.read()
|
77 |
+
|
78 |
+
# Clean up temporary files
|
79 |
+
os.unlink(temp_img_path)
|
80 |
+
os.unlink(temp_svg_path)
|
81 |
+
|
82 |
+
# Step 5: Enforce constraints
|
83 |
+
svg_content = self.enforce_constraints(svg_content)
|
84 |
+
|
85 |
+
return svg_content
|
86 |
+
except Exception as e:
|
87 |
+
logging.error(f"Error generating SVG: {e}")
|
88 |
+
return self.default_svg
|
89 |
+
|
90 |
+
def enforce_constraints(self, svg_string: str) -> str:
|
91 |
+
"""Enforces constraints on an SVG string, removing disallowed elements
|
92 |
+
and attributes.
|
93 |
+
|
94 |
+
Parameters
|
95 |
+
----------
|
96 |
+
svg_string : str
|
97 |
+
The SVG string to process.
|
98 |
+
|
99 |
+
Returns
|
100 |
+
-------
|
101 |
+
str
|
102 |
+
The processed SVG string, or the default SVG if constraints
|
103 |
+
cannot be satisfied.
|
104 |
+
"""
|
105 |
+
logging.info('Sanitizing SVG...')
|
106 |
+
|
107 |
+
try:
|
108 |
+
# Remove XML declaration if it exists
|
109 |
+
svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip()
|
110 |
+
|
111 |
+
parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
|
112 |
+
root = etree.fromstring(svg_string, parser=parser)
|
113 |
+
except etree.ParseError as e:
|
114 |
+
logging.error('SVG Parse Error: %s. Returning default SVG.', e)
|
115 |
+
logging.error('SVG string: %s', svg_string)
|
116 |
+
return self.default_svg
|
117 |
+
|
118 |
+
elements_to_remove = []
|
119 |
+
for element in root.iter():
|
120 |
+
tag_name = etree.QName(element.tag).localname
|
121 |
+
|
122 |
+
# Remove disallowed elements
|
123 |
+
if tag_name not in self.constraints.allowed_elements:
|
124 |
+
elements_to_remove.append(element)
|
125 |
+
continue # Skip attribute checks for removed elements
|
126 |
+
|
127 |
+
# Remove disallowed attributes
|
128 |
+
attrs_to_remove = []
|
129 |
+
for attr in element.attrib:
|
130 |
+
attr_name = etree.QName(attr).localname
|
131 |
+
if (
|
132 |
+
attr_name
|
133 |
+
not in self.constraints.allowed_elements[tag_name]
|
134 |
+
and attr_name
|
135 |
+
not in self.constraints.allowed_elements['common']
|
136 |
+
):
|
137 |
+
attrs_to_remove.append(attr)
|
138 |
+
|
139 |
+
for attr in attrs_to_remove:
|
140 |
+
logging.debug(
|
141 |
+
'Attribute "%s" for element "%s" not allowed. Removing.',
|
142 |
+
attr,
|
143 |
+
tag_name,
|
144 |
+
)
|
145 |
+
del element.attrib[attr]
|
146 |
+
|
147 |
+
# Check and remove invalid href attributes
|
148 |
+
for attr, value in element.attrib.items():
|
149 |
+
if etree.QName(attr).localname == 'href' and not value.startswith('#'):
|
150 |
+
logging.debug(
|
151 |
+
'Removing invalid href attribute in element "%s".', tag_name
|
152 |
+
)
|
153 |
+
del element.attrib[attr]
|
154 |
+
|
155 |
+
# Validate path elements to help ensure SVG conversion
|
156 |
+
if tag_name == 'path':
|
157 |
+
d_attribute = element.get('d')
|
158 |
+
if not d_attribute:
|
159 |
+
logging.warning('Path element is missing "d" attribute. Removing path.')
|
160 |
+
elements_to_remove.append(element)
|
161 |
+
continue # Skip further checks for this removed element
|
162 |
+
# Use regex to validate 'd' attribute format
|
163 |
+
path_regex = re.compile(
|
164 |
+
r'^' # Start of string
|
165 |
+
r'(?:' # Non-capturing group for each command + numbers block
|
166 |
+
r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters)
|
167 |
+
r'\s*' # Optional whitespace after command
|
168 |
+
r'(?:' # Non-capturing group for optional numbers
|
169 |
+
r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number
|
170 |
+
r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s)
|
171 |
+
r')?' # Numbers are optional (e.g. for Z command)
|
172 |
+
r'\s*' # Optional whitespace after numbers/command block
|
173 |
+
r')+' # One or more command blocks
|
174 |
+
r'\s*' # Optional trailing whitespace
|
175 |
+
r'$' # End of string
|
176 |
+
)
|
177 |
+
if not path_regex.match(d_attribute):
|
178 |
+
logging.warning(
|
179 |
+
'Path element has malformed "d" attribute format. Removing path.'
|
180 |
+
)
|
181 |
+
elements_to_remove.append(element)
|
182 |
+
continue
|
183 |
+
logging.debug('Path element "d" attribute validated (regex check).')
|
184 |
+
|
185 |
+
# Remove elements marked for removal
|
186 |
+
for element in elements_to_remove:
|
187 |
+
if element.getparent() is not None:
|
188 |
+
element.getparent().remove(element)
|
189 |
+
logging.debug('Removed element: %s', element.tag)
|
190 |
+
|
191 |
+
try:
|
192 |
+
cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False)
|
193 |
+
return cleaned_svg_string
|
194 |
+
except ValueError as e:
|
195 |
+
logging.error(
|
196 |
+
'SVG could not be sanitized to meet constraints: %s', e
|
197 |
+
)
|
198 |
+
return self.default_svg
|
199 |
+
|
200 |
+
def optimize_svg(self, svg_content):
|
201 |
+
"""
|
202 |
+
Optimize the SVG content using SVGO.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
svg_content (str): The SVG content to optimize.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
str: The optimized SVG content.
|
209 |
+
"""
|
210 |
+
try:
|
211 |
+
with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg:
|
212 |
+
temp_svg_path = temp_svg.name
|
213 |
+
temp_svg.write(svg_content.encode('utf-8'))
|
214 |
+
|
215 |
+
with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_out:
|
216 |
+
temp_out_path = temp_out.name
|
217 |
+
|
218 |
+
subprocess.run(["svgo", temp_svg_path, "-o", temp_out_path], check=True)
|
219 |
+
|
220 |
+
with open(temp_out_path, 'r') as f:
|
221 |
+
optimized_svg = f.read()
|
222 |
+
|
223 |
+
os.unlink(temp_svg_path)
|
224 |
+
os.unlink(temp_out_path)
|
225 |
+
|
226 |
+
return optimized_svg
|
227 |
+
except (FileNotFoundError, subprocess.CalledProcessError):
|
228 |
+
print("Warning: SVGO not found or failed. Returning unoptimized SVG.")
|
229 |
+
return svg_content
|
230 |
+
|
231 |
+
|
232 |
+
# Example usage
|
233 |
+
if __name__ == "__main__":
|
234 |
+
model = MLModel()
|
235 |
+
svg = model.predict("a purple forest at dusk")
|
236 |
+
# Convert SVG to PNG
|
237 |
+
try:
|
238 |
+
# Create a PNG in memory
|
239 |
+
png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
|
240 |
+
|
241 |
+
# Save the PNG to a file
|
242 |
+
with open("output.png", "wb") as f:
|
243 |
+
f.write(png_data)
|
244 |
+
print("SVG saved as output.png")
|
245 |
+
except Exception as e:
|
246 |
+
print(f"Error converting SVG to PNG: {e}")
|
naive.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import concurrent
|
2 |
+
import io
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
|
6 |
+
import cairosvg
|
7 |
+
import kagglehub
|
8 |
+
import torch
|
9 |
+
from lxml import etree
|
10 |
+
from unsloth import FastLanguageModel
|
11 |
+
from unsloth.chat_templates import get_chat_template
|
12 |
+
|
13 |
+
svg_constraints = kagglehub.package_import('metric/svg-constraints')
|
14 |
+
|
15 |
+
class NaiveModel:
|
16 |
+
def __init__(self, model_name="unsloth/phi-4-unsloth-bnb-4bit", max_seq_length=2048, device="cuda"):
|
17 |
+
self.device = device
|
18 |
+
self.max_seq_length = max_seq_length
|
19 |
+
self.load_in_4bit = True
|
20 |
+
|
21 |
+
# Load the Unsloth Phi-4 model
|
22 |
+
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
|
23 |
+
model_name=model_name,
|
24 |
+
max_seq_length=self.max_seq_length,
|
25 |
+
load_in_4bit=self.load_in_4bit
|
26 |
+
)
|
27 |
+
|
28 |
+
# Set up chat template
|
29 |
+
self.tokenizer = get_chat_template(
|
30 |
+
self.tokenizer,
|
31 |
+
chat_template="phi-4",
|
32 |
+
)
|
33 |
+
|
34 |
+
# Prepare model for inference
|
35 |
+
FastLanguageModel.for_inference(self.model)
|
36 |
+
|
37 |
+
self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
|
38 |
+
<constraints>
|
39 |
+
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
|
40 |
+
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
|
41 |
+
</constraints>
|
42 |
+
|
43 |
+
Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.
|
44 |
+
|
45 |
+
<description>"A red circle with a blue square inside"</description>
|
46 |
+
```svg
|
47 |
+
<svg viewBox="0 0 256 256" width="256" height="256">
|
48 |
+
<circle cx="50" cy="50" r="40" fill="red"/>
|
49 |
+
<rect x="30" y="30" width="40" height="40" fill="blue"/>
|
50 |
+
</svg>
|
51 |
+
```
|
52 |
+
|
53 |
+
<description>"{}"</description>
|
54 |
+
"""
|
55 |
+
self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
|
56 |
+
self.constraints = svg_constraints.SVGConstraints()
|
57 |
+
self.timeout_seconds = 90
|
58 |
+
|
59 |
+
def predict(self, description: str, max_new_tokens=512) -> str:
|
60 |
+
def generate_svg():
|
61 |
+
try:
|
62 |
+
# Format the prompt
|
63 |
+
prompt = self.prompt_template.format(description)
|
64 |
+
|
65 |
+
# Create messages in the format expected by the chat template
|
66 |
+
messages = [
|
67 |
+
{"role": "user", "content": prompt},
|
68 |
+
]
|
69 |
+
|
70 |
+
# Tokenize the messages
|
71 |
+
inputs = self.tokenizer.apply_chat_template(
|
72 |
+
messages,
|
73 |
+
tokenize=True,
|
74 |
+
add_generation_prompt=True,
|
75 |
+
return_tensors="pt",
|
76 |
+
).to(self.device)
|
77 |
+
|
78 |
+
# Generate the output
|
79 |
+
outputs = self.model.generate(
|
80 |
+
input_ids=inputs,
|
81 |
+
max_new_tokens=max_new_tokens,
|
82 |
+
use_cache=True,
|
83 |
+
temperature=1.0,
|
84 |
+
min_p=0.1,
|
85 |
+
do_sample=True,
|
86 |
+
)
|
87 |
+
|
88 |
+
# Decode the output
|
89 |
+
output_decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
90 |
+
|
91 |
+
# Extract only the generated text (skip the prompt)
|
92 |
+
generated_text = output_decoded.split("```svg")[-1].split("```")[0] if "```svg" in output_decoded else ""
|
93 |
+
|
94 |
+
logging.debug('Output decoded from model: %s', output_decoded)
|
95 |
+
|
96 |
+
matches = re.findall(r"<svg.*?</svg>", output_decoded, re.DOTALL | re.IGNORECASE)
|
97 |
+
if matches:
|
98 |
+
svg = matches[-1]
|
99 |
+
else:
|
100 |
+
return self.default_svg
|
101 |
+
|
102 |
+
logging.debug('Unprocessed SVG: %s', svg)
|
103 |
+
svg = self.enforce_constraints(svg)
|
104 |
+
logging.debug('Processed SVG: %s', svg)
|
105 |
+
|
106 |
+
# Ensure the generated code can be converted by cairosvg
|
107 |
+
cairosvg.svg2png(bytestring=svg.encode('utf-8'))
|
108 |
+
return svg
|
109 |
+
except Exception as e:
|
110 |
+
logging.error('Exception during SVG generation: %s', e)
|
111 |
+
return self.default_svg
|
112 |
+
|
113 |
+
# Execute SVG generation in a new thread to enforce time constraints
|
114 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
115 |
+
future = executor.submit(generate_svg)
|
116 |
+
try:
|
117 |
+
return future.result(timeout=self.timeout_seconds)
|
118 |
+
except concurrent.futures.TimeoutError:
|
119 |
+
logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
|
120 |
+
return self.default_svg
|
121 |
+
except Exception as e:
|
122 |
+
logging.error(f"An unexpected error occurred: {e}")
|
123 |
+
return self.default_svg
|
124 |
+
|
125 |
+
def enforce_constraints(self, svg_string: str) -> str:
|
126 |
+
"""Enforces constraints on an SVG string, removing disallowed elements
|
127 |
+
and attributes.
|
128 |
+
|
129 |
+
Parameters
|
130 |
+
----------
|
131 |
+
svg_string : str
|
132 |
+
The SVG string to process.
|
133 |
+
|
134 |
+
Returns
|
135 |
+
-------
|
136 |
+
str
|
137 |
+
The processed SVG string, or the default SVG if constraints
|
138 |
+
cannot be satisfied.
|
139 |
+
"""
|
140 |
+
logging.info('Sanitizing SVG...')
|
141 |
+
|
142 |
+
try:
|
143 |
+
parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
|
144 |
+
root = etree.fromstring(svg_string, parser=parser)
|
145 |
+
except etree.ParseError as e:
|
146 |
+
logging.error('SVG Parse Error: %s. Returning default SVG.', e)
|
147 |
+
logging.error('SVG string: %s', svg_string)
|
148 |
+
return self.default_svg
|
149 |
+
|
150 |
+
elements_to_remove = []
|
151 |
+
for element in root.iter():
|
152 |
+
tag_name = etree.QName(element.tag).localname
|
153 |
+
|
154 |
+
# Remove disallowed elements
|
155 |
+
if tag_name not in self.constraints.allowed_elements:
|
156 |
+
elements_to_remove.append(element)
|
157 |
+
continue # Skip attribute checks for removed elements
|
158 |
+
|
159 |
+
# Remove disallowed attributes
|
160 |
+
attrs_to_remove = []
|
161 |
+
for attr in element.attrib:
|
162 |
+
attr_name = etree.QName(attr).localname
|
163 |
+
if (
|
164 |
+
attr_name
|
165 |
+
not in self.constraints.allowed_elements[tag_name]
|
166 |
+
and attr_name
|
167 |
+
not in self.constraints.allowed_elements['common']
|
168 |
+
):
|
169 |
+
attrs_to_remove.append(attr)
|
170 |
+
|
171 |
+
for attr in attrs_to_remove:
|
172 |
+
logging.debug(
|
173 |
+
'Attribute "%s" for element "%s" not allowed. Removing.',
|
174 |
+
attr,
|
175 |
+
tag_name,
|
176 |
+
)
|
177 |
+
del element.attrib[attr]
|
178 |
+
|
179 |
+
# Check and remove invalid href attributes
|
180 |
+
for attr, value in element.attrib.items():
|
181 |
+
if etree.QName(attr).localname == 'href' and not value.startswith('#'):
|
182 |
+
logging.debug(
|
183 |
+
'Removing invalid href attribute in element "%s".', tag_name
|
184 |
+
)
|
185 |
+
del element.attrib[attr]
|
186 |
+
|
187 |
+
# Validate path elements to help ensure SVG conversion
|
188 |
+
if tag_name == 'path':
|
189 |
+
d_attribute = element.get('d')
|
190 |
+
if not d_attribute:
|
191 |
+
logging.warning('Path element is missing "d" attribute. Removing path.')
|
192 |
+
elements_to_remove.append(element)
|
193 |
+
continue # Skip further checks for this removed element
|
194 |
+
# Use regex to validate 'd' attribute format
|
195 |
+
path_regex = re.compile(
|
196 |
+
r'^' # Start of string
|
197 |
+
r'(?:' # Non-capturing group for each command + numbers block
|
198 |
+
r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters)
|
199 |
+
r'\s*' # Optional whitespace after command
|
200 |
+
r'(?:' # Non-capturing group for optional numbers
|
201 |
+
r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number
|
202 |
+
r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s)
|
203 |
+
r')?' # Numbers are optional (e.g. for Z command)
|
204 |
+
r'\s*' # Optional whitespace after numbers/command block
|
205 |
+
r')+' # One or more command blocks
|
206 |
+
r'\s*' # Optional trailing whitespace
|
207 |
+
r'$' # End of string
|
208 |
+
)
|
209 |
+
if not path_regex.match(d_attribute):
|
210 |
+
logging.warning(
|
211 |
+
'Path element has malformed "d" attribute format. Removing path.'
|
212 |
+
)
|
213 |
+
elements_to_remove.append(element)
|
214 |
+
continue
|
215 |
+
logging.debug('Path element "d" attribute validated (regex check).')
|
216 |
+
|
217 |
+
# Remove elements marked for removal
|
218 |
+
for element in elements_to_remove:
|
219 |
+
if element.getparent() is not None:
|
220 |
+
element.getparent().remove(element)
|
221 |
+
logging.debug('Removed element: %s', element.tag)
|
222 |
+
|
223 |
+
try:
|
224 |
+
cleaned_svg_string = etree.tostring(root, encoding='unicode')
|
225 |
+
return cleaned_svg_string
|
226 |
+
except ValueError as e:
|
227 |
+
logging.error(
|
228 |
+
'SVG could not be sanitized to meet constraints: %s', e
|
229 |
+
)
|
230 |
+
return self.default_svg
|
231 |
+
|
232 |
+
|
233 |
+
if __name__ == "__main__":
|
234 |
+
model = NaiveModel()
|
235 |
+
svg = model.predict("a purple forest at dusk")
|
236 |
+
# Convert SVG to PNG
|
237 |
+
try:
|
238 |
+
# Create a PNG in memory
|
239 |
+
png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
|
240 |
+
|
241 |
+
# Save the PNG to a file
|
242 |
+
with open("output.png", "wb") as f:
|
243 |
+
f.write(png_data)
|
244 |
+
print("SVG saved as output.png")
|
245 |
+
except Exception as e:
|
246 |
+
print(f"Error converting SVG to PNG: {e}")
|
requirements.txt
CHANGED
@@ -19,6 +19,12 @@ dotenv
|
|
19 |
diffusers
|
20 |
safetensors
|
21 |
xformers
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# pip install 'tensorflow[and-cuda]'
|
24 |
# pip install git+https://github.com/openai/CLIP.git
|
|
|
19 |
diffusers
|
20 |
safetensors
|
21 |
xformers
|
22 |
+
unsloth
|
23 |
+
tf-keras
|
24 |
+
vtracer
|
25 |
+
deepspeed
|
26 |
+
torch==2.5.1
|
27 |
+
torchvision==0.20.1
|
28 |
|
29 |
# pip install 'tensorflow[and-cuda]'
|
30 |
# pip install git+https://github.com/openai/CLIP.git
|
starter.ipynb
DELETED
@@ -1,333 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 2,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [
|
8 |
-
{
|
9 |
-
"name": "stderr",
|
10 |
-
"output_type": "stream",
|
11 |
-
"text": [
|
12 |
-
"/home/user/miniconda3/envs/dwl/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
-
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
-
]
|
15 |
-
},
|
16 |
-
{
|
17 |
-
"data": {
|
18 |
-
"text/html": [
|
19 |
-
"<div><style>\n",
|
20 |
-
".dataframe > thead > tr,\n",
|
21 |
-
".dataframe > tbody > tr {\n",
|
22 |
-
" text-align: right;\n",
|
23 |
-
" white-space: pre-wrap;\n",
|
24 |
-
"}\n",
|
25 |
-
"</style>\n",
|
26 |
-
"<small>shape: (5, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>description</th></tr><tr><td>str</td><td>str</td></tr></thead><tbody><tr><td>"02d892"</td><td>"a purple forest at dusk"</td></tr><tr><td>"0dcd2e"</td><td>"gray wool coat with a faux furβ¦</td></tr><tr><td>"1e9ac1"</td><td>"a lighthouse overlooking the oβ¦</td></tr><tr><td>"2b25db"</td><td>"burgundy corduroy pants with pβ¦</td></tr><tr><td>"4e6a54"</td><td>"orange corduroy overalls"</td></tr></tbody></table></div>"
|
27 |
-
],
|
28 |
-
"text/plain": [
|
29 |
-
"shape: (5, 2)\n",
|
30 |
-
"ββββββββββ¬ββββββββββββββββββββββββββββββββββ\n",
|
31 |
-
"β id β description β\n",
|
32 |
-
"β --- β --- β\n",
|
33 |
-
"β str β str β\n",
|
34 |
-
"ββββββββββͺββββββββββββββββββββββββββββββββββ‘\n",
|
35 |
-
"β 02d892 β a purple forest at dusk β\n",
|
36 |
-
"β 0dcd2e β gray wool coat with a faux furβ¦ β\n",
|
37 |
-
"β 1e9ac1 β a lighthouse overlooking the oβ¦ β\n",
|
38 |
-
"β 2b25db β burgundy corduroy pants with pβ¦ β\n",
|
39 |
-
"β 4e6a54 β orange corduroy overalls β\n",
|
40 |
-
"ββββββββββ΄ββββββββββββββββββββββββββββββββββ"
|
41 |
-
]
|
42 |
-
},
|
43 |
-
"execution_count": 2,
|
44 |
-
"metadata": {},
|
45 |
-
"output_type": "execute_result"
|
46 |
-
}
|
47 |
-
],
|
48 |
-
"source": [
|
49 |
-
"# We can load and explore the competition's train set to get a feel for the data.\n",
|
50 |
-
"# We're not going to export this cell as it's not needed for our exported inferenceable model.\n",
|
51 |
-
"\n",
|
52 |
-
"import kagglehub\n",
|
53 |
-
"import polars as pl\n",
|
54 |
-
"\n",
|
55 |
-
"train_path = kagglehub.competition_download('drawing-with-llms', 'train.csv')\n",
|
56 |
-
"train = pl.read_csv(train_path)\n",
|
57 |
-
"\n",
|
58 |
-
"train.head()"
|
59 |
-
]
|
60 |
-
},
|
61 |
-
{
|
62 |
-
"cell_type": "code",
|
63 |
-
"execution_count": 3,
|
64 |
-
"metadata": {},
|
65 |
-
"outputs": [],
|
66 |
-
"source": [
|
67 |
-
"class Model:\n",
|
68 |
-
" def __init__(self):\n",
|
69 |
-
" '''Optional constructor, performs any setup logic, model instantiation, etc.'''\n",
|
70 |
-
" pass\n",
|
71 |
-
" \n",
|
72 |
-
" def predict(self, prompt: str) -> str:\n",
|
73 |
-
" '''Generates SVG which produces an image described by the prompt.\n",
|
74 |
-
"\n",
|
75 |
-
" Args:\n",
|
76 |
-
" prompt (str): A prompt describing an image\n",
|
77 |
-
" Returns:\n",
|
78 |
-
" String of valid SVG code.\n",
|
79 |
-
" '''\n",
|
80 |
-
" # Renders a simple circle regardless of input\n",
|
81 |
-
" return '<svg width=\"100\" height=\"100\" viewBox=\"0 0 100 100\"><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\" /></svg>'"
|
82 |
-
]
|
83 |
-
},
|
84 |
-
{
|
85 |
-
"cell_type": "code",
|
86 |
-
"execution_count": 4,
|
87 |
-
"metadata": {},
|
88 |
-
"outputs": [
|
89 |
-
{
|
90 |
-
"name": "stdout",
|
91 |
-
"output_type": "stream",
|
92 |
-
"text": [
|
93 |
-
"<svg width=\"100\" height=\"100\" viewBox=\"0 0 100 100\"><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\" /></svg>\n"
|
94 |
-
]
|
95 |
-
},
|
96 |
-
{
|
97 |
-
"data": {
|
98 |
-
"image/svg+xml": [
|
99 |
-
"<svg width=\"100\" height=\"100\" viewBox=\"0 0 100 100\"><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\"/></svg>"
|
100 |
-
],
|
101 |
-
"text/plain": [
|
102 |
-
"<IPython.core.display.SVG object>"
|
103 |
-
]
|
104 |
-
},
|
105 |
-
"metadata": {},
|
106 |
-
"output_type": "display_data"
|
107 |
-
}
|
108 |
-
],
|
109 |
-
"source": [
|
110 |
-
"from IPython.display import SVG\n",
|
111 |
-
"\n",
|
112 |
-
"model = Model()\n",
|
113 |
-
"svg = model.predict('a goose winning a gold medal')\n",
|
114 |
-
"\n",
|
115 |
-
"print(svg)\n",
|
116 |
-
"display(SVG(svg))"
|
117 |
-
]
|
118 |
-
},
|
119 |
-
{
|
120 |
-
"cell_type": "code",
|
121 |
-
"execution_count": 6,
|
122 |
-
"metadata": {},
|
123 |
-
"outputs": [
|
124 |
-
{
|
125 |
-
"data": {
|
126 |
-
"text/plain": [
|
127 |
-
"['RN50',\n",
|
128 |
-
" 'RN101',\n",
|
129 |
-
" 'RN50x4',\n",
|
130 |
-
" 'RN50x16',\n",
|
131 |
-
" 'RN50x64',\n",
|
132 |
-
" 'ViT-B/32',\n",
|
133 |
-
" 'ViT-B/16',\n",
|
134 |
-
" 'ViT-L/14',\n",
|
135 |
-
" 'ViT-L/14@336px']"
|
136 |
-
]
|
137 |
-
},
|
138 |
-
"execution_count": 6,
|
139 |
-
"metadata": {},
|
140 |
-
"output_type": "execute_result"
|
141 |
-
}
|
142 |
-
],
|
143 |
-
"source": [
|
144 |
-
"import clip\n",
|
145 |
-
"clip.available_models()"
|
146 |
-
]
|
147 |
-
},
|
148 |
-
{
|
149 |
-
"cell_type": "code",
|
150 |
-
"execution_count": 7,
|
151 |
-
"metadata": {},
|
152 |
-
"outputs": [
|
153 |
-
{
|
154 |
-
"name": "stderr",
|
155 |
-
"output_type": "stream",
|
156 |
-
"text": [
|
157 |
-
"2025-04-20 13:55:34.589770: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
158 |
-
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
|
159 |
-
"E0000 00:00:1745171734.600777 13214 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
160 |
-
"E0000 00:00:1745171734.603957 13214 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
161 |
-
"W0000 00:00:1745171734.615566 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
|
162 |
-
"W0000 00:00:1745171734.615584 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
|
163 |
-
"W0000 00:00:1745171734.615585 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
|
164 |
-
"W0000 00:00:1745171734.615586 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
|
165 |
-
"2025-04-20 13:55:34.618659: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
166 |
-
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
167 |
-
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n",
|
168 |
-
"Loading checkpoint shards: 100%|ββββββββββ| 4/4 [00:18<00:00, 4.68s/it]\n"
|
169 |
-
]
|
170 |
-
}
|
171 |
-
],
|
172 |
-
"source": [
|
173 |
-
"import pandas as pd\n",
|
174 |
-
"import importlib\n",
|
175 |
-
"metric = importlib.import_module('metric')\n",
|
176 |
-
"importlib.reload(metric)\n",
|
177 |
-
"\n",
|
178 |
-
"vqa_evaluator = metric.VQAEvaluator()\n",
|
179 |
-
"aesthetic_evaluator = metric.AestheticEvaluator()"
|
180 |
-
]
|
181 |
-
},
|
182 |
-
{
|
183 |
-
"cell_type": "code",
|
184 |
-
"execution_count": 11,
|
185 |
-
"metadata": {},
|
186 |
-
"outputs": [
|
187 |
-
{
|
188 |
-
"name": "stdout",
|
189 |
-
"output_type": "stream",
|
190 |
-
"text": [
|
191 |
-
"VQA Score: 0.9996758976500401\n",
|
192 |
-
"Aesthetic Score: 0.5749330520629883\n",
|
193 |
-
"Final Fidelity Score: 0.8709845773271212\n"
|
194 |
-
]
|
195 |
-
}
|
196 |
-
],
|
197 |
-
"source": [
|
198 |
-
"# score gpt4o generated images\n",
|
199 |
-
"import ast\n",
|
200 |
-
"import numpy as np\n",
|
201 |
-
"from PIL import Image\n",
|
202 |
-
"\n",
|
203 |
-
"# Load the first sample from descriptions.csv\n",
|
204 |
-
"descriptions_df = pd.read_csv('data/descriptions.csv')\n",
|
205 |
-
"first_description = descriptions_df.iloc[1]\n",
|
206 |
-
"\n",
|
207 |
-
"eval_df = pd.read_csv('data/eval.csv')\n",
|
208 |
-
"first_eval = eval_df.iloc[1]\n",
|
209 |
-
"\n",
|
210 |
-
"# Load the image\n",
|
211 |
-
"image_path = 'data/gray_coat.png' # Assuming the image is saved with this name\n",
|
212 |
-
"image = Image.open(image_path)\n",
|
213 |
-
"\n",
|
214 |
-
"# Prepare the inputs for scoring - need to parse the string representations\n",
|
215 |
-
"questions = ast.literal_eval(first_eval['question'])\n",
|
216 |
-
"choices = ast.literal_eval(first_eval['choices'])\n",
|
217 |
-
"answers = ast.literal_eval(first_eval['answer'])\n",
|
218 |
-
"\n",
|
219 |
-
"# Calculate VQA score - don't wrap in additional lists\n",
|
220 |
-
"vqa_score = vqa_evaluator.score(questions, choices, answers, image)\n",
|
221 |
-
"\n",
|
222 |
-
"# Calculate aesthetic score\n",
|
223 |
-
"aesthetic_score = aesthetic_evaluator.score(image)\n",
|
224 |
-
"\n",
|
225 |
-
"# Apply image processing as done in the metric.score function\n",
|
226 |
-
"image_processor = metric.ImageProcessor(image=image, seed=0).apply()\n",
|
227 |
-
"processed_image = image_processor.image.copy()\n",
|
228 |
-
"\n",
|
229 |
-
"# Calculate final fidelity score\n",
|
230 |
-
"instance_score = metric.harmonic_mean(vqa_score, aesthetic_score, beta=0.5)\n",
|
231 |
-
"\n",
|
232 |
-
"print(f\"VQA Score: {vqa_score}\")\n",
|
233 |
-
"print(f\"Aesthetic Score: {aesthetic_score}\")\n",
|
234 |
-
"print(f\"Final Fidelity Score: {instance_score}\")"
|
235 |
-
]
|
236 |
-
},
|
237 |
-
{
|
238 |
-
"cell_type": "code",
|
239 |
-
"execution_count": 13,
|
240 |
-
"metadata": {},
|
241 |
-
"outputs": [
|
242 |
-
{
|
243 |
-
"name": "stdout",
|
244 |
-
"output_type": "stream",
|
245 |
-
"text": [
|
246 |
-
"No duplicate IDs found in data/descriptions.csv\n",
|
247 |
-
"Sorted rows by ID\n",
|
248 |
-
"Fixed and sorted CSV saved to data/descriptions.csv\n",
|
249 |
-
"No duplicate IDs found in data/eval.csv\n",
|
250 |
-
"Sorted data/eval.csv by ID\n"
|
251 |
-
]
|
252 |
-
}
|
253 |
-
],
|
254 |
-
"source": [
|
255 |
-
"# Fix duplicate IDs in descriptions.csv and order rows by id\n",
|
256 |
-
"def fix_duplicate_ids(csv_path):\n",
|
257 |
-
" \"\"\"\n",
|
258 |
-
" Fix duplicate IDs in a CSV file by assigning new unique IDs to duplicates.\n",
|
259 |
-
" Then order rows by ID.\n",
|
260 |
-
" \"\"\"\n",
|
261 |
-
" # Read the CSV file\n",
|
262 |
-
" df = pd.read_csv(csv_path)\n",
|
263 |
-
" \n",
|
264 |
-
" # Check for duplicate IDs\n",
|
265 |
-
" duplicate_mask = df['id'].duplicated(keep='first')\n",
|
266 |
-
" duplicate_count = duplicate_mask.sum()\n",
|
267 |
-
" \n",
|
268 |
-
" if duplicate_count > 0:\n",
|
269 |
-
" print(f\"Found {duplicate_count} duplicate IDs in {csv_path}\")\n",
|
270 |
-
" \n",
|
271 |
-
" # Get the maximum ID value\n",
|
272 |
-
" max_id = df['id'].max()\n",
|
273 |
-
" \n",
|
274 |
-
" # Assign new IDs to duplicates\n",
|
275 |
-
" new_ids = list(range(max_id + 1, max_id + 1 + duplicate_count))\n",
|
276 |
-
" df.loc[duplicate_mask, 'id'] = new_ids\n",
|
277 |
-
" \n",
|
278 |
-
" print(f\"Assigned new IDs to duplicates\")\n",
|
279 |
-
" else:\n",
|
280 |
-
" print(f\"No duplicate IDs found in {csv_path}\")\n",
|
281 |
-
" \n",
|
282 |
-
" # Sort the dataframe by ID\n",
|
283 |
-
" df = df.sort_values(by='id')\n",
|
284 |
-
" print(f\"Sorted rows by ID\")\n",
|
285 |
-
" \n",
|
286 |
-
" # Save the fixed and sorted CSV\n",
|
287 |
-
" df.to_csv(csv_path, index=False)\n",
|
288 |
-
" print(f\"Fixed and sorted CSV saved to {csv_path}\")\n",
|
289 |
-
" \n",
|
290 |
-
" # Return the fixed dataframe\n",
|
291 |
-
" return df\n",
|
292 |
-
"\n",
|
293 |
-
"# Fix descriptions.csv\n",
|
294 |
-
"fixed_descriptions_df = fix_duplicate_ids('data/descriptions.csv')\n",
|
295 |
-
"\n",
|
296 |
-
"# Fix eval.csv if needed\n",
|
297 |
-
"# First check if eval.csv has the same issue\n",
|
298 |
-
"eval_df = pd.read_csv('data/eval.csv')\n",
|
299 |
-
"duplicate_eval_ids = eval_df['id'].duplicated(keep='first').sum()\n",
|
300 |
-
"\n",
|
301 |
-
"if duplicate_eval_ids > 0:\n",
|
302 |
-
" fixed_eval_df = fix_duplicate_ids('data/eval.csv')\n",
|
303 |
-
"else:\n",
|
304 |
-
" print(\"No duplicate IDs found in data/eval.csv\")\n",
|
305 |
-
" # Still sort by ID even if no duplicates\n",
|
306 |
-
" eval_df = eval_df.sort_values(by='id')\n",
|
307 |
-
" eval_df.to_csv('data/eval.csv', index=False)\n",
|
308 |
-
" print(\"Sorted data/eval.csv by ID\")\n"
|
309 |
-
]
|
310 |
-
}
|
311 |
-
],
|
312 |
-
"metadata": {
|
313 |
-
"kernelspec": {
|
314 |
-
"display_name": "dwl",
|
315 |
-
"language": "python",
|
316 |
-
"name": "python3"
|
317 |
-
},
|
318 |
-
"language_info": {
|
319 |
-
"codemirror_mode": {
|
320 |
-
"name": "ipython",
|
321 |
-
"version": 3
|
322 |
-
},
|
323 |
-
"file_extension": ".py",
|
324 |
-
"mimetype": "text/x-python",
|
325 |
-
"name": "python",
|
326 |
-
"nbconvert_exporter": "python",
|
327 |
-
"pygments_lexer": "ipython3",
|
328 |
-
"version": "3.11.11"
|
329 |
-
}
|
330 |
-
},
|
331 |
-
"nbformat": 4,
|
332 |
-
"nbformat_minor": 2
|
333 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
starvector/__init__.py
ADDED
File without changes
|
starvector/adapter.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.init as init
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class Swish(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super(Swish, self).__init__()
|
8 |
+
|
9 |
+
def forward(self, x):
|
10 |
+
return x * torch.sigmoid(x)
|
11 |
+
|
12 |
+
class Adapter(nn.Module):
|
13 |
+
def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1):
|
14 |
+
super().__init__()
|
15 |
+
self.query_length = query_length
|
16 |
+
self.dropout_prob = dropout_prob
|
17 |
+
self.adapter_norm = adapter_norm
|
18 |
+
|
19 |
+
self.dropout = nn.Dropout(p=self.dropout_prob)
|
20 |
+
|
21 |
+
self.c_fc = nn.Linear(input_size, input_size*2)
|
22 |
+
self.act = Swish()
|
23 |
+
self.c_proj = nn.Linear(input_size*2, output_size)
|
24 |
+
|
25 |
+
if adapter_norm == "layer_norm":
|
26 |
+
self.norm = nn.LayerNorm([self.query_length, output_size])
|
27 |
+
elif adapter_norm == "batch_norm":
|
28 |
+
self.norm = nn.BatchNorm1d(self.query_length)
|
29 |
+
|
30 |
+
self.init_type = init_type.lower()
|
31 |
+
self._initialize_weights()
|
32 |
+
|
33 |
+
def forward(self, hidden_states):
|
34 |
+
hidden_states = self.dropout(hidden_states)
|
35 |
+
hidden_states = self.c_fc(hidden_states)
|
36 |
+
hidden_states = self.act(hidden_states)
|
37 |
+
hidden_states = self.c_proj(hidden_states)
|
38 |
+
hidden_states = self.norm(hidden_states)
|
39 |
+
return hidden_states
|
40 |
+
|
41 |
+
def _initialize_weights(self):
|
42 |
+
for m in self.modules():
|
43 |
+
if isinstance(m, nn.Linear):
|
44 |
+
if self.init_type == "glorot":
|
45 |
+
init.xavier_uniform_(m.weight)
|
46 |
+
if m.bias is not None:
|
47 |
+
init.constant_(m.bias, 0)
|
48 |
+
elif self.init_type == "normal":
|
49 |
+
init.normal_(m.weight, mean=0, std=0.01)
|
50 |
+
if m.bias is not None:
|
51 |
+
init.constant_(m.bias, 0)
|
52 |
+
else:
|
53 |
+
raise ValueError("Invalid initialization type specified.")
|
starvector/clip_model.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py
|
2 |
+
|
3 |
+
from collections import OrderedDict
|
4 |
+
from itertools import repeat
|
5 |
+
import collections.abc
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
11 |
+
|
12 |
+
def convert_weights_to_precision(model: nn.Module, precision: torch.dtype):
|
13 |
+
"""Convert applicable model parameters to the specified precision"""
|
14 |
+
|
15 |
+
def _convert_weights_to_precision(l):
|
16 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
17 |
+
l.weight.data = l.weight.data.to(precision)
|
18 |
+
if l.bias is not None:
|
19 |
+
l.bias.data = l.bias.data.to(precision)
|
20 |
+
|
21 |
+
elif isinstance(l, (nn.MultiheadAttention)):
|
22 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
23 |
+
tensor = getattr(l, attr)
|
24 |
+
if tensor is not None:
|
25 |
+
tensor.data = tensor.data.to(precision)
|
26 |
+
else:
|
27 |
+
for _, p in l.named_parameters():
|
28 |
+
p.data = p.data.to(precision)
|
29 |
+
|
30 |
+
model.apply(_convert_weights_to_precision)
|
31 |
+
|
32 |
+
class Bottleneck(nn.Module):
|
33 |
+
expansion = 4
|
34 |
+
|
35 |
+
def __init__(self, inplanes, planes, stride=1):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
39 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
40 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
41 |
+
self.relu1 = nn.ReLU(inplace=True)
|
42 |
+
|
43 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
44 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
45 |
+
self.relu2 = nn.ReLU(inplace=True)
|
46 |
+
|
47 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
48 |
+
|
49 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
50 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
51 |
+
self.relu3 = nn.ReLU(inplace=True)
|
52 |
+
|
53 |
+
self.downsample = None
|
54 |
+
self.stride = stride
|
55 |
+
|
56 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
57 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
58 |
+
self.downsample = nn.Sequential(OrderedDict([
|
59 |
+
("-1", nn.AvgPool2d(stride)),
|
60 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
61 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
62 |
+
]))
|
63 |
+
|
64 |
+
def forward(self, x: torch.Tensor):
|
65 |
+
identity = x
|
66 |
+
|
67 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
68 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
69 |
+
out = self.avgpool(out)
|
70 |
+
out = self.bn3(self.conv3(out))
|
71 |
+
|
72 |
+
if self.downsample is not None:
|
73 |
+
identity = self.downsample(x)
|
74 |
+
|
75 |
+
out += identity
|
76 |
+
out = self.relu3(out)
|
77 |
+
return out
|
78 |
+
|
79 |
+
|
80 |
+
class AttentionPool2d(nn.Module):
|
81 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
82 |
+
super().__init__()
|
83 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
84 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
85 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
86 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
87 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
88 |
+
self.num_heads = num_heads
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
92 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
93 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
94 |
+
x, _ = F.multi_head_attention_forward(
|
95 |
+
query=x, key=x, value=x,
|
96 |
+
embed_dim_to_check=x.shape[-1],
|
97 |
+
num_heads=self.num_heads,
|
98 |
+
q_proj_weight=self.q_proj.weight,
|
99 |
+
k_proj_weight=self.k_proj.weight,
|
100 |
+
v_proj_weight=self.v_proj.weight,
|
101 |
+
in_proj_weight=None,
|
102 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
103 |
+
bias_k=None,
|
104 |
+
bias_v=None,
|
105 |
+
add_zero_attn=False,
|
106 |
+
dropout_p=0,
|
107 |
+
out_proj_weight=self.c_proj.weight,
|
108 |
+
out_proj_bias=self.c_proj.bias,
|
109 |
+
use_separate_proj_weight=True,
|
110 |
+
training=self.training,
|
111 |
+
need_weights=False
|
112 |
+
)
|
113 |
+
|
114 |
+
return x[0]
|
115 |
+
|
116 |
+
|
117 |
+
class LayerNorm(nn.LayerNorm):
|
118 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
119 |
+
|
120 |
+
def forward(self, x: torch.Tensor):
|
121 |
+
orig_type = x.dtype
|
122 |
+
layernorm_dtype = self.weight.dtype
|
123 |
+
ret = super().forward(x.type(layernorm_dtype))
|
124 |
+
return ret.type(orig_type)
|
125 |
+
|
126 |
+
class QuickGELU(nn.Module):
|
127 |
+
def forward(self, x: torch.Tensor):
|
128 |
+
return x * torch.sigmoid(1.702 * x)
|
129 |
+
|
130 |
+
class ResidualAttentionBlock(nn.Module):
|
131 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
135 |
+
self.ln_1 = LayerNorm(d_model)
|
136 |
+
self.mlp = nn.Sequential(OrderedDict([
|
137 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
138 |
+
("gelu", QuickGELU()),
|
139 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
140 |
+
]))
|
141 |
+
self.ln_2 = LayerNorm(d_model)
|
142 |
+
self.attn_mask = attn_mask
|
143 |
+
|
144 |
+
if use_grad_checkpointing:
|
145 |
+
self.attn = checkpoint_wrapper(self.attn)
|
146 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
147 |
+
|
148 |
+
def attention(self, x: torch.Tensor):
|
149 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
150 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
151 |
+
|
152 |
+
def forward(self, x: torch.Tensor):
|
153 |
+
x = x + self.attention(self.ln_1(x))
|
154 |
+
x = x + self.mlp(self.ln_2(x))
|
155 |
+
return x
|
156 |
+
|
157 |
+
class Transformer(nn.Module):
|
158 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
159 |
+
super().__init__()
|
160 |
+
self.width = width
|
161 |
+
self.layers = layers
|
162 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
|
163 |
+
|
164 |
+
def forward(self, x: torch.Tensor):
|
165 |
+
return self.resblocks(x)
|
166 |
+
|
167 |
+
class VisionTransformer(nn.Module):
|
168 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
|
169 |
+
super().__init__()
|
170 |
+
self.input_resolution = input_resolution
|
171 |
+
self.num_features = width
|
172 |
+
self.num_heads = heads
|
173 |
+
self.num_patches = (input_resolution // patch_size) ** 2
|
174 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
175 |
+
scale = width ** -0.5
|
176 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
177 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
|
178 |
+
self.ln_pre = LayerNorm(width)
|
179 |
+
self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
|
180 |
+
|
181 |
+
def forward(self, x: torch.Tensor):
|
182 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
183 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
184 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
185 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
186 |
+
x = x + self.positional_embedding.to(x.dtype)
|
187 |
+
x = self.ln_pre(x)
|
188 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
189 |
+
x = self.transformer(x)
|
190 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
191 |
+
return x
|
starvector/data/augmentation.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from svgpathtools import (
|
4 |
+
Path, Arc, CubicBezier, QuadraticBezier,
|
5 |
+
svgstr2paths)
|
6 |
+
import os
|
7 |
+
from noise import pnoise1
|
8 |
+
import re
|
9 |
+
import matplotlib.colors as mcolors
|
10 |
+
from bs4 import BeautifulSoup
|
11 |
+
from starvector.data.util import rasterize_svg
|
12 |
+
|
13 |
+
class SVGTransforms:
|
14 |
+
def __init__(self, transformations):
|
15 |
+
self.transformations = transformations
|
16 |
+
self.noise_std = self.transformations.get('noise_std', False)
|
17 |
+
self.noise_type = self.transformations.get('noise_type', False)
|
18 |
+
self.rotate = self.transformations.get('rotate', False)
|
19 |
+
self.shift_re = self.transformations.get('shift_re', False)
|
20 |
+
self.shift_im = self.transformations.get('shift_im', False)
|
21 |
+
self.scale = self.transformations.get('scale', False)
|
22 |
+
self.color_noise = self.transformations.get('color_noise', False)
|
23 |
+
self.p = self.transformations.get('p', 0.5)
|
24 |
+
self.color_change = self.transformations.get('color_change', False)
|
25 |
+
self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000'])
|
26 |
+
|
27 |
+
def sample_transformations(self):
|
28 |
+
if self.rotate:
|
29 |
+
a, b = self.rotate['from'], self.rotate['to']
|
30 |
+
rotation_angle = np.random.uniform(a, b)
|
31 |
+
self.rotation_angle = rotation_angle
|
32 |
+
|
33 |
+
if self.shift_re or self.shift_im:
|
34 |
+
self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to'])
|
35 |
+
self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to'])
|
36 |
+
|
37 |
+
if self.scale:
|
38 |
+
self.scale = np.random.uniform(self.scale['from'], self.scale['to'])
|
39 |
+
|
40 |
+
if self.color_noise:
|
41 |
+
self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to'])
|
42 |
+
|
43 |
+
|
44 |
+
def paths2str(self, groupped_paths, svg_opening_tag='<svg xmlns="http://www.w3.org/2000/svg" version="1.1">'):
|
45 |
+
|
46 |
+
keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry']
|
47 |
+
all_groups_srt = ''
|
48 |
+
for group, elements in groupped_paths.items():
|
49 |
+
group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', [])
|
50 |
+
group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items())
|
51 |
+
path_strings = []
|
52 |
+
path_str = ''
|
53 |
+
for path, attributes in paths_and_attributes:
|
54 |
+
path_attr_str = ''
|
55 |
+
d_str = path.d()
|
56 |
+
|
57 |
+
for key, value in attributes.items():
|
58 |
+
if key not in keys_to_exclude:
|
59 |
+
path_attr_str += f' {key}="{value}"'
|
60 |
+
|
61 |
+
path_strings.append(f'<path d="{d_str}"{path_attr_str} />')
|
62 |
+
path_str = "\n".join(path_strings)
|
63 |
+
if 'no_group'in group:
|
64 |
+
group_str = path_str
|
65 |
+
else:
|
66 |
+
group_str = f'<g {group_attr_str}>\n{path_str}\n</g>\n'
|
67 |
+
all_groups_srt += group_str
|
68 |
+
svg = f'{svg_opening_tag}\n{all_groups_srt}</svg>'
|
69 |
+
return svg
|
70 |
+
|
71 |
+
def add_noise(self, seg):
|
72 |
+
noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to'])
|
73 |
+
if self.noise_type == 'gaussian':
|
74 |
+
noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \
|
75 |
+
1j * np.random.normal(loc=0.0, scale=noise_scale)
|
76 |
+
elif self.noise_type == 'perlin':
|
77 |
+
noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale
|
78 |
+
|
79 |
+
if isinstance(seg, CubicBezier):
|
80 |
+
seg.control1 = seg.control1 + noise_sample
|
81 |
+
seg.control2 = seg.control2 + noise_sample
|
82 |
+
elif isinstance(seg, QuadraticBezier):
|
83 |
+
seg.control = seg.control + noise_sample
|
84 |
+
elif isinstance(seg, Arc):
|
85 |
+
seg.radius = seg.radius + noise_sample
|
86 |
+
|
87 |
+
|
88 |
+
return seg
|
89 |
+
|
90 |
+
def do_rotate(self, path, viewbox_width, viewbox_height):
|
91 |
+
if self.rotate:
|
92 |
+
new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2))
|
93 |
+
return new_path
|
94 |
+
else:
|
95 |
+
return path
|
96 |
+
|
97 |
+
def do_shift(self, path):
|
98 |
+
if self.shift_re or self.shift_im:
|
99 |
+
return path.translated(complex(self.shift_real, self.shift_imag))
|
100 |
+
else:
|
101 |
+
return path
|
102 |
+
|
103 |
+
def do_scale(self, path):
|
104 |
+
if self.scale:
|
105 |
+
return path.scaled(self.scale)
|
106 |
+
else:
|
107 |
+
return path
|
108 |
+
|
109 |
+
def add_color_noise(self, source_color):
|
110 |
+
# Convert color to RGB
|
111 |
+
if source_color.startswith("#"):
|
112 |
+
base_color = mcolors.hex2color(source_color)
|
113 |
+
else:
|
114 |
+
base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF'))
|
115 |
+
|
116 |
+
# Add noise to each RGB component
|
117 |
+
noise = np.random.normal(0, self.color_noise_std, 3)
|
118 |
+
noisy_color = np.clip(np.array(base_color) + noise, 0, 1)
|
119 |
+
|
120 |
+
# Convert the RGB color back to hex
|
121 |
+
hex_color = mcolors.rgb2hex(noisy_color)
|
122 |
+
|
123 |
+
return hex_color
|
124 |
+
|
125 |
+
def do_color_change(self, attr):
|
126 |
+
if 'fill' in attr:
|
127 |
+
if self.color_noise or self.color_change:
|
128 |
+
fill_value = attr['fill']
|
129 |
+
if fill_value == 'none':
|
130 |
+
new_fill_value = 'none'
|
131 |
+
else:
|
132 |
+
if self.color_noise:
|
133 |
+
new_fill_value = self.add_color_noise(fill_value)
|
134 |
+
elif self.color_change:
|
135 |
+
new_fill_value = np.random.choice(self.colors)
|
136 |
+
attr['fill'] = new_fill_value
|
137 |
+
return attr
|
138 |
+
|
139 |
+
def clean_attributes(self, attr):
|
140 |
+
attr_out = {}
|
141 |
+
if 'fill' in attr:
|
142 |
+
attr_out = attr
|
143 |
+
elif 'style' in attr:
|
144 |
+
fill_values = re.findall('fill:[^;]+', attr['style'])
|
145 |
+
if fill_values:
|
146 |
+
fill_value = fill_values[0].replace('fill:', '').strip()
|
147 |
+
attr_out['fill'] = fill_value
|
148 |
+
else:
|
149 |
+
attr_out = attr
|
150 |
+
else:
|
151 |
+
attr_out = attr
|
152 |
+
|
153 |
+
return attr_out
|
154 |
+
|
155 |
+
def get_viewbox_size(self, svg):
|
156 |
+
# Try to extract viewBox attribute
|
157 |
+
match = re.search(r'viewBox="([^"]+)"', svg)
|
158 |
+
if match:
|
159 |
+
viewbox = match.group(1)
|
160 |
+
else:
|
161 |
+
# If viewBox is not found, try to extract width and height attributes
|
162 |
+
match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg)
|
163 |
+
if match:
|
164 |
+
width, height = match.groups()
|
165 |
+
viewbox = f"0 0 {width} {height}"
|
166 |
+
else:
|
167 |
+
viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found
|
168 |
+
|
169 |
+
viewbox = [float(x) for x in viewbox.split()]
|
170 |
+
viewbox_width, viewbox_height = viewbox[2], viewbox[3]
|
171 |
+
return viewbox_width, viewbox_height
|
172 |
+
|
173 |
+
def augment(self, svg):
|
174 |
+
if os.path.isfile(svg):
|
175 |
+
# open svg file
|
176 |
+
with open(svg, 'r') as f:
|
177 |
+
svg = f.read()
|
178 |
+
|
179 |
+
# Sample transformations for this sample
|
180 |
+
self.sample_transformations()
|
181 |
+
|
182 |
+
|
183 |
+
# Parse the SVG content
|
184 |
+
soup = BeautifulSoup(svg, 'xml')
|
185 |
+
|
186 |
+
# Get opening tag
|
187 |
+
svg_opening_tag = re.findall('<svg[^>]+>', svg)[0]
|
188 |
+
|
189 |
+
viewbox_width, viewbox_height = self.get_viewbox_size(svg)
|
190 |
+
|
191 |
+
# Get all svg parents
|
192 |
+
groups = soup.findAll()
|
193 |
+
|
194 |
+
# Create the groups of paths based on their original <g> tag
|
195 |
+
grouped_paths = {}
|
196 |
+
for i, g in enumerate(groups):
|
197 |
+
if g.name == 'g':
|
198 |
+
group_id = group_id = g.get('id') if g.get('id') else f'none_{i}'
|
199 |
+
group_attrs = g.attrs
|
200 |
+
|
201 |
+
elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs':
|
202 |
+
continue
|
203 |
+
|
204 |
+
else:
|
205 |
+
group_id = f'no_group_{i}'
|
206 |
+
group_attrs = {}
|
207 |
+
|
208 |
+
group_svg_string = f'{svg_opening_tag}{str(g)}</svg>'
|
209 |
+
try:
|
210 |
+
paths, attributes = svgstr2paths(group_svg_string)
|
211 |
+
except:
|
212 |
+
return svg, rasterize_svg(svg)
|
213 |
+
if not paths:
|
214 |
+
continue
|
215 |
+
|
216 |
+
paths_and_attributes = []
|
217 |
+
|
218 |
+
# Rotation, shift, scale, noise addition
|
219 |
+
new_paths = []
|
220 |
+
new_attributes = []
|
221 |
+
for path, attribute in zip(paths, attributes):
|
222 |
+
attr = self.clean_attributes(attribute)
|
223 |
+
|
224 |
+
new_path = self.do_rotate(path, viewbox_width, viewbox_height)
|
225 |
+
new_path = self.do_shift(new_path)
|
226 |
+
new_path = self.do_scale(new_path)
|
227 |
+
|
228 |
+
if self.noise_std:
|
229 |
+
# Add noise to path to deform svg
|
230 |
+
noisy_path = []
|
231 |
+
for seg in new_path:
|
232 |
+
noisy_seg = self.add_noise(seg)
|
233 |
+
noisy_path.append(noisy_seg)
|
234 |
+
new_paths.append(Path(*noisy_path))
|
235 |
+
else:
|
236 |
+
new_paths.append(new_path)
|
237 |
+
|
238 |
+
# Color change
|
239 |
+
attr = self.do_color_change(attr)
|
240 |
+
paths_and_attributes.append((new_path, attr))
|
241 |
+
|
242 |
+
grouped_paths[group_id] = {
|
243 |
+
'paths': paths_and_attributes,
|
244 |
+
'attrs': group_attrs
|
245 |
+
}
|
246 |
+
|
247 |
+
svg = self.paths2str(grouped_paths, svg_opening_tag)
|
248 |
+
image = rasterize_svg(svg)
|
249 |
+
|
250 |
+
return svg, image
|
starvector/data/base.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg
|
3 |
+
from starvector.util import instantiate_from_config
|
4 |
+
import numpy as np
|
5 |
+
from datasets import load_dataset
|
6 |
+
|
7 |
+
class SVGDatasetBase(Dataset):
|
8 |
+
def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
|
9 |
+
self.split = split
|
10 |
+
self.im_size = im_size
|
11 |
+
|
12 |
+
transforms = kwargs.get('transforms', False)
|
13 |
+
if transforms:
|
14 |
+
self.transforms = instantiate_from_config(transforms)
|
15 |
+
self.p = self.transforms.p
|
16 |
+
else:
|
17 |
+
self.transforms = None
|
18 |
+
self.p = 0.0
|
19 |
+
|
20 |
+
normalization = kwargs.get('normalize', False)
|
21 |
+
if normalization:
|
22 |
+
mean = tuple(normalization.get('mean', None))
|
23 |
+
std = tuple(normalization.get('std', None))
|
24 |
+
else:
|
25 |
+
mean = None
|
26 |
+
std = None
|
27 |
+
|
28 |
+
self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std)
|
29 |
+
self.data = load_dataset(dataset_name, split=split)
|
30 |
+
|
31 |
+
print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split")
|
32 |
+
|
33 |
+
def __len__(self):
|
34 |
+
return len(self.data_json)
|
35 |
+
|
36 |
+
def get_svg_and_image(self, svg_str, sample_id):
|
37 |
+
do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p])
|
38 |
+
svg, image = None, None
|
39 |
+
|
40 |
+
# Try to augment the image if conditions are met
|
41 |
+
if self.transforms is not None and do_augment:
|
42 |
+
try:
|
43 |
+
svg, image = self.transforms.augment(svg_str)
|
44 |
+
except Exception as e:
|
45 |
+
print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG")
|
46 |
+
|
47 |
+
# If augmentation failed or wasn't attempted, try to rasterize the SVG
|
48 |
+
if svg is None or image is None:
|
49 |
+
try:
|
50 |
+
svg, image = svg_str, rasterize_svg(svg_str, self.im_size)
|
51 |
+
except Exception as e:
|
52 |
+
print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image")
|
53 |
+
svg = use_placeholder()
|
54 |
+
image = rasterize_svg(svg, self.im_size)
|
55 |
+
|
56 |
+
# If the image is completely white, use a placeholder image
|
57 |
+
if np.array(image).mean() == 255.0:
|
58 |
+
print(f"Image is full white, using placeholder image for {sample_id}")
|
59 |
+
svg = use_placeholder()
|
60 |
+
image = rasterize_svg(svg)
|
61 |
+
|
62 |
+
# Process the image
|
63 |
+
if 'siglip' in self.image_processor:
|
64 |
+
image = self.processor(image).pixel_values[0]
|
65 |
+
else:
|
66 |
+
image = self.processor(image)
|
67 |
+
|
68 |
+
return svg, image
|
69 |
+
|
70 |
+
def __getitem__(self, idx):
|
71 |
+
raise NotImplementedError("This method should be implemented by subclasses")
|
starvector/data/dataset.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from starvector.data.base import SVGDatasetBase
|
3 |
+
from starvector.data.augmentation import SVGTransforms
|
4 |
+
from starvector.data.util import ImageTrainProcessor
|
5 |
+
from transformers import AutoProcessor
|
6 |
+
|
7 |
+
class SVGDataset(SVGDatasetBase):
|
8 |
+
def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs):
|
9 |
+
super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
|
10 |
+
|
11 |
+
self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']})
|
12 |
+
select_dataset_name = kwargs.get('select_dataset_name', False)
|
13 |
+
|
14 |
+
if select_dataset_name:
|
15 |
+
self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name)
|
16 |
+
|
17 |
+
self.num_samples = num_samples
|
18 |
+
if self.num_samples != -1:
|
19 |
+
self.data = self.data.select(range(self.num_samples))
|
20 |
+
|
21 |
+
self.image_processor = kwargs.get('image_processor', None)
|
22 |
+
if 'siglip' in self.image_processor:
|
23 |
+
model_name = {'siglip_512': 'google/siglip-base-patch16-512',
|
24 |
+
'siglip_384': 'google/siglip-large-patch16-384',
|
25 |
+
'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
|
26 |
+
self.processor = AutoProcessor.from_pretrained(model_name).image_processor
|
27 |
+
else:
|
28 |
+
self.processor = ImageTrainProcessor(size=self.im_size)
|
29 |
+
def __len__(self):
|
30 |
+
return len(self.data)
|
31 |
+
|
32 |
+
def __getitem__(self, idx):
|
33 |
+
svg_str = self.data[idx]['Svg']
|
34 |
+
sample_id = self.data[idx]['Filename']
|
35 |
+
svg, image = self.get_svg_and_image(svg_str, sample_id)
|
36 |
+
caption = self.data[idx].get('Caption', "")
|
37 |
+
return {
|
38 |
+
'svg': svg,
|
39 |
+
'image': image,
|
40 |
+
'id': sample_id,
|
41 |
+
'caption': caption
|
42 |
+
}
|
starvector/data/emojisvg.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from starvector.data.base import SVGDatasetBase
|
3 |
+
|
4 |
+
|
5 |
+
class EmojiSVGDataset(SVGDatasetBase):
|
6 |
+
def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs):
|
7 |
+
super().__init__(dataset_name, split, im_size, **kwargs)
|
8 |
+
|
9 |
+
self.num_samples = num_samples
|
10 |
+
if self.num_samples != -1:
|
11 |
+
self.data = self.data.select(range(self.num_samples))
|
12 |
+
|
13 |
+
def __len__(self):
|
14 |
+
return len(self.data)
|
15 |
+
|
16 |
+
def __getitem__(self, idx):
|
17 |
+
|
18 |
+
svg_str = self.data[idx]['Svg']
|
19 |
+
sample_id = self.data[idx]['Filename']
|
20 |
+
svg, image = self.get_svg_and_image(svg_str, sample_id)
|
21 |
+
caption = self.data[idx].get('Caption', "")
|
22 |
+
return {
|
23 |
+
'svg': svg,
|
24 |
+
'image': image,
|
25 |
+
'id': sample_id,
|
26 |
+
'caption': caption
|
27 |
+
}
|
starvector/data/figrsvg.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from starvector.data.base import SVGDatasetBase
|
3 |
+
from transformers import AutoProcessor
|
4 |
+
from starvector.data.util import ImageTrainProcessor
|
5 |
+
|
6 |
+
class FigrSVGDataset(SVGDatasetBase):
|
7 |
+
def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
|
8 |
+
super().__init__(dataset_name, split, im_size, **kwargs)
|
9 |
+
|
10 |
+
self.num_samples = num_samples
|
11 |
+
if self.num_samples != -1:
|
12 |
+
self.data = self.data.select(range(self.num_samples))
|
13 |
+
|
14 |
+
def __len__(self):
|
15 |
+
return len(self.data)
|
16 |
+
|
17 |
+
def __getitem__(self, idx):
|
18 |
+
svg_str = self.data[idx]['Svg']
|
19 |
+
sample_id = self.data[idx]['Id']
|
20 |
+
svg, image = self.get_svg_and_image(svg_str, sample_id)
|
21 |
+
caption = self.data[idx].get('Caption', "")
|
22 |
+
return {
|
23 |
+
'svg': svg,
|
24 |
+
'image': image,
|
25 |
+
'id': sample_id,
|
26 |
+
'caption': caption
|
27 |
+
}
|
starvector/data/fontsvg.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from starvector.data.base import SVGDatasetBase
|
3 |
+
from transformers import AutoProcessor
|
4 |
+
from starvector.data.util import ImageTrainProcessor
|
5 |
+
|
6 |
+
class FontSVGDataset(SVGDatasetBase):
|
7 |
+
def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
|
8 |
+
super().__init__(dataset_name, split, im_size, **kwargs)
|
9 |
+
|
10 |
+
self.num_samples = num_samples
|
11 |
+
if self.num_samples != -1:
|
12 |
+
self.data = self.data.select(range(self.num_samples))
|
13 |
+
|
14 |
+
def __len__(self):
|
15 |
+
return len(self.data)
|
16 |
+
|
17 |
+
def __getitem__(self, idx):
|
18 |
+
|
19 |
+
svg_str = self.data[idx]['Svg']
|
20 |
+
sample_id = self.data[idx]['Filename']
|
21 |
+
svg, image = self.get_svg_and_image(svg_str, sample_id)
|
22 |
+
caption = self.data[idx].get('Caption', "")
|
23 |
+
return {
|
24 |
+
'svg': svg,
|
25 |
+
'image': image,
|
26 |
+
'id': sample_id,
|
27 |
+
'caption': caption
|
28 |
+
}
|
starvector/data/iconsvg.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from starvector.data.base import SVGDatasetBase
|
3 |
+
from starvector.data.util import ImageTrainProcessor
|
4 |
+
from transformers import AutoProcessor
|
5 |
+
|
6 |
+
class SVGIconsDataset(SVGDatasetBase):
|
7 |
+
def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
|
8 |
+
super().__init__(dataset_name, split, im_size, **kwargs)
|
9 |
+
|
10 |
+
self.num_samples = num_samples
|
11 |
+
if self.num_samples != -1:
|
12 |
+
self.data = self.data.select(range(self.num_samples))
|
13 |
+
|
14 |
+
self.image_processor = kwargs.get('image_processor', None)
|
15 |
+
if 'siglip' in self.image_processor:
|
16 |
+
model_name = {'siglip_512': 'google/siglip-base-patch16-512',
|
17 |
+
'siglip_384': 'google/siglip-large-patch16-384',
|
18 |
+
'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
|
19 |
+
self.processor = AutoProcessor.from_pretrained(model_name).image_processor
|
20 |
+
else:
|
21 |
+
self.processor = ImageTrainProcessor(size=self.im_size)
|
22 |
+
|
23 |
+
|
24 |
+
def __len__(self):
|
25 |
+
return len(self.data)
|
26 |
+
|
27 |
+
def __getitem__(self, idx):
|
28 |
+
|
29 |
+
svg_str = self.data[idx]['Svg']
|
30 |
+
sample_id = self.data[idx]['Filename']
|
31 |
+
svg, image = self.get_svg_and_image(svg_str, sample_id)
|
32 |
+
caption = self.data[idx].get('Caption', "")
|
33 |
+
return {
|
34 |
+
'svg': svg,
|
35 |
+
'image': image,
|
36 |
+
'id': sample_id,
|
37 |
+
'caption': caption
|
38 |
+
}
|
starvector/data/stacksvg.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from starvector.data.base import SVGDatasetBase
|
3 |
+
from starvector.data.augmentation import SVGTransforms
|
4 |
+
import random
|
5 |
+
from transformers import AutoProcessor
|
6 |
+
from starvector.data.util import ImageTrainProcessor
|
7 |
+
|
8 |
+
text2svg_captions = [
|
9 |
+
"Draw an SVG of ",
|
10 |
+
"Draw an SVG image of ",
|
11 |
+
"Draw an SVG picture of ",
|
12 |
+
"Generate an SVG of ",
|
13 |
+
"Create an SVG of ",
|
14 |
+
"Design an SVG of ",
|
15 |
+
"Make an SVG of ",
|
16 |
+
]
|
17 |
+
|
18 |
+
class SVGStackDataset(SVGDatasetBase):
|
19 |
+
def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
|
20 |
+
super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
|
21 |
+
self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']})
|
22 |
+
|
23 |
+
# Text2SVG specific
|
24 |
+
self.random_caption = kwargs.get('random_caption', True)
|
25 |
+
select_dataset_name = kwargs.get('select_dataset_name', False)
|
26 |
+
if select_dataset_name:
|
27 |
+
self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name)
|
28 |
+
|
29 |
+
self.num_samples = num_samples
|
30 |
+
if self.num_samples != -1:
|
31 |
+
self.data = self.data.select(range(self.num_samples))
|
32 |
+
|
33 |
+
self.image_processor = kwargs.get('image_processor', None)
|
34 |
+
if self.image_processor and 'siglip' in self.image_processor:
|
35 |
+
model_name = {'siglip_512': 'google/siglip-base-patch16-512',
|
36 |
+
'siglip_384': 'google/siglip-large-patch16-384',
|
37 |
+
'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
|
38 |
+
self.processor = AutoProcessor.from_pretrained(model_name).image_processor
|
39 |
+
else:
|
40 |
+
self.processor = ImageTrainProcessor(size=self.im_size)
|
41 |
+
|
42 |
+
|
43 |
+
def __len__(self):
|
44 |
+
return len(self.data)
|
45 |
+
|
46 |
+
def __getitem__(self, idx):
|
47 |
+
svg_str = self.data[idx]['Svg']
|
48 |
+
sample_id = self.data[idx]['Filename']
|
49 |
+
svg, image = self.get_svg_and_image(svg_str, sample_id)
|
50 |
+
|
51 |
+
# Randomly choose between 'caption_blip' and 'caption_llava'
|
52 |
+
caption_column = random.choice(['caption_blip2', 'caption_llava'])
|
53 |
+
caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "")
|
54 |
+
return {
|
55 |
+
'svg': svg,
|
56 |
+
'image': image,
|
57 |
+
'id': sample_id,
|
58 |
+
'caption': caption,
|
59 |
+
}
|
starvector/data/util.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from torchvision import transforms
|
3 |
+
from torchvision.transforms.functional import InterpolationMode, pad
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from bs4 import BeautifulSoup
|
7 |
+
import re
|
8 |
+
from svgpathtools import svgstr2paths
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
import cairosvg
|
12 |
+
from io import BytesIO
|
13 |
+
import numpy as np
|
14 |
+
import textwrap
|
15 |
+
import os
|
16 |
+
import base64
|
17 |
+
import io
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
CIRCLE_SVG = "<svg><circle cx='50%' cy='50%' r='50%' /></svg>"
|
22 |
+
VOID_SVF = "<svg></svg>"
|
23 |
+
|
24 |
+
def load_transforms():
|
25 |
+
transforms = {
|
26 |
+
'train': None,
|
27 |
+
'eval': None
|
28 |
+
}
|
29 |
+
return transforms
|
30 |
+
|
31 |
+
class ImageBaseProcessor():
|
32 |
+
def __init__(self, mean=None, std=None):
|
33 |
+
if mean is None:
|
34 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
35 |
+
if std is None:
|
36 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
37 |
+
|
38 |
+
self.normalize = transforms.Normalize(mean=mean, std=std)
|
39 |
+
|
40 |
+
class ImageTrainProcessor(ImageBaseProcessor):
|
41 |
+
def __init__(self, mean=None, std=None, size=224, **kwargs):
|
42 |
+
super().__init__(mean, std)
|
43 |
+
|
44 |
+
self.size = size
|
45 |
+
|
46 |
+
self.transform = transforms.Compose([
|
47 |
+
transforms.Lambda(lambda img: self._rgba_to_rgb_white(img) if img.mode == "RGBA" else img),
|
48 |
+
transforms.Lambda(lambda img: self._pad_to_square(img)),
|
49 |
+
transforms.Resize(self.size, interpolation=InterpolationMode.BICUBIC),
|
50 |
+
transforms.ToTensor(),
|
51 |
+
self.normalize
|
52 |
+
])
|
53 |
+
|
54 |
+
def __call__(self, item):
|
55 |
+
return self.transform(item)
|
56 |
+
|
57 |
+
def _pad_to_square(self, img):
|
58 |
+
# Calculate padding to make the image square
|
59 |
+
width, height = img.size
|
60 |
+
max_dim = max(width, height)
|
61 |
+
padding = [(max_dim - width) // 2, (max_dim - height) // 2]
|
62 |
+
padding += [max_dim - width - padding[0], max_dim - height - padding[1]]
|
63 |
+
return pad(img, padding, fill=255) # Assuming white padding
|
64 |
+
|
65 |
+
def _rgba_to_rgb_white(self, img):
|
66 |
+
background = Image.new("RGB", img.size, (255, 255, 255))
|
67 |
+
background.paste(img, mask=img.split()[3])
|
68 |
+
return background
|
69 |
+
|
70 |
+
|
71 |
+
def encode_image_base64(pil_image):
|
72 |
+
if pil_image.mode == 'RGBA':
|
73 |
+
pil_image = pil_image.convert('RGB') # Convert RGBA to RGB
|
74 |
+
buffered = io.BytesIO()
|
75 |
+
pil_image.save(buffered, format="JPEG")
|
76 |
+
base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
77 |
+
return base64_image
|
78 |
+
|
79 |
+
# -------------- Generation utils --------------
|
80 |
+
def is_valid_svg(svg_text):
|
81 |
+
try:
|
82 |
+
svgstr2paths(svg_text)
|
83 |
+
return True
|
84 |
+
except Exception as e:
|
85 |
+
print(f"Invalid SVG: {str(e)}")
|
86 |
+
return False
|
87 |
+
|
88 |
+
def clean_svg(svg_text, output_width=None, output_height=None):
|
89 |
+
soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml
|
90 |
+
svg_bs4 = soup.prettify() # Prettify to get a string
|
91 |
+
|
92 |
+
# Store the original signal handler
|
93 |
+
import signal
|
94 |
+
original_handler = signal.getsignal(signal.SIGALRM)
|
95 |
+
|
96 |
+
try:
|
97 |
+
# Set a timeout to prevent hanging
|
98 |
+
def timeout_handler(signum, frame):
|
99 |
+
raise TimeoutError("SVG processing timed out")
|
100 |
+
|
101 |
+
# Set timeout
|
102 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
103 |
+
signal.alarm(5)
|
104 |
+
|
105 |
+
# Try direct conversion without BeautifulSoup
|
106 |
+
svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode()
|
107 |
+
|
108 |
+
except TimeoutError:
|
109 |
+
print("SVG conversion timed out, using fallback method")
|
110 |
+
svg_cairo = """<svg></svg>"""
|
111 |
+
finally:
|
112 |
+
# Always cancel the alarm and restore original handler, regardless of success or failure
|
113 |
+
signal.alarm(0)
|
114 |
+
signal.signal(signal.SIGALRM, original_handler)
|
115 |
+
|
116 |
+
svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("<?xml")]) # Remove xml header
|
117 |
+
return svg_clean
|
118 |
+
|
119 |
+
|
120 |
+
def use_placeholder():
|
121 |
+
return VOID_SVF
|
122 |
+
|
123 |
+
def process_and_rasterize_svg(svg_string, resolution=256, dpi = 128, scale=2):
|
124 |
+
try:
|
125 |
+
svgstr2paths(svg_string) # This will raise an exception if the svg is still not valid
|
126 |
+
out_svg = svg_string
|
127 |
+
except:
|
128 |
+
try:
|
129 |
+
svg = clean_svg(svg_string)
|
130 |
+
svgstr2paths(svg) # This will raise an exception if the svg is still not valid
|
131 |
+
out_svg = svg
|
132 |
+
except Exception as e:
|
133 |
+
out_svg = use_placeholder()
|
134 |
+
|
135 |
+
raster_image = rasterize_svg(out_svg, resolution, dpi, scale)
|
136 |
+
return out_svg, raster_image
|
137 |
+
|
138 |
+
def rasterize_svg(svg_string, resolution=224, dpi = 128, scale=2):
|
139 |
+
try:
|
140 |
+
svg_raster_bytes = cairosvg.svg2png(
|
141 |
+
bytestring=svg_string,
|
142 |
+
background_color='white',
|
143 |
+
output_width=resolution,
|
144 |
+
output_height=resolution,
|
145 |
+
dpi=dpi,
|
146 |
+
scale=scale)
|
147 |
+
svg_raster = Image.open(BytesIO(svg_raster_bytes))
|
148 |
+
except:
|
149 |
+
try:
|
150 |
+
svg = clean_svg(svg_string)
|
151 |
+
svg_raster_bytes = cairosvg.svg2png(
|
152 |
+
bytestring=svg,
|
153 |
+
background_color='white',
|
154 |
+
output_width=resolution,
|
155 |
+
output_height=resolution,
|
156 |
+
dpi=dpi,
|
157 |
+
scale=scale)
|
158 |
+
svg_raster = Image.open(BytesIO(svg_raster_bytes))
|
159 |
+
except:
|
160 |
+
svg_raster = Image.new('RGB', (resolution, resolution), color = 'white')
|
161 |
+
return svg_raster
|
162 |
+
|
163 |
+
def find_unclosed_tags(svg_content):
|
164 |
+
all_tags_pattern = r"<(\w+)"
|
165 |
+
self_closing_pattern = r"<\w+[^>]*\/>"
|
166 |
+
all_tags = re.findall(all_tags_pattern, svg_content)
|
167 |
+
self_closing_matches = re.findall(self_closing_pattern, svg_content)
|
168 |
+
self_closing_tags = []
|
169 |
+
|
170 |
+
for match in self_closing_matches:
|
171 |
+
tag = re.search(all_tags_pattern, match)
|
172 |
+
if tag:
|
173 |
+
self_closing_tags.append(tag.group(1))
|
174 |
+
unclosed_tags = []
|
175 |
+
|
176 |
+
for tag in all_tags:
|
177 |
+
if all_tags.count(tag) > self_closing_tags.count(tag) + svg_content.count('</' + tag + '>'):
|
178 |
+
unclosed_tags.append(tag)
|
179 |
+
unclosed_tags = list(dict.fromkeys(unclosed_tags))
|
180 |
+
|
181 |
+
return unclosed_tags
|
182 |
+
|
183 |
+
|
184 |
+
# -------------- Plotting utils --------------
|
185 |
+
def plot_images_side_by_side_with_metrics(image1, image2, l2_dist, CD, post_processed, out_path):
|
186 |
+
array1 = np.array(image1).astype(np.float32)
|
187 |
+
array2 = np.array(image2).astype(np.float32)
|
188 |
+
diff = np.abs(array1 - array2).astype(np.uint8)
|
189 |
+
|
190 |
+
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
|
191 |
+
axes[0].imshow(image1)
|
192 |
+
axes[0].set_title('generated_svg')
|
193 |
+
axes[0].axis('off')
|
194 |
+
axes[1].imshow(image2)
|
195 |
+
axes[1].set_title('gt')
|
196 |
+
axes[1].axis('off')
|
197 |
+
axes[2].imshow(diff)
|
198 |
+
axes[2].set_title('Difference')
|
199 |
+
axes[2].axis('off')
|
200 |
+
plt.suptitle(f"MSE: {l2_dist:.4f}, CD: {CD:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05)
|
201 |
+
plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
|
202 |
+
image = Image.open(out_path)
|
203 |
+
plt.close(fig)
|
204 |
+
return image
|
205 |
+
|
206 |
+
def plot_images_side_by_side(image1, image2, out_path):
|
207 |
+
array1 = np.array(image1).astype(np.float32)
|
208 |
+
array2 = np.array(image2).astype(np.float32)
|
209 |
+
diff = np.abs(array1 - array2).astype(np.uint8)
|
210 |
+
|
211 |
+
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
|
212 |
+
axes[0].imshow(image1)
|
213 |
+
axes[0].set_title('generated_svg')
|
214 |
+
axes[0].axis('off')
|
215 |
+
axes[1].imshow(image2)
|
216 |
+
axes[1].set_title('gt')
|
217 |
+
axes[1].axis('off')
|
218 |
+
axes[2].imshow(diff)
|
219 |
+
axes[2].set_title('Difference')
|
220 |
+
axes[2].axis('off')
|
221 |
+
plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
|
222 |
+
image = Image.open(out_path)
|
223 |
+
plt.close(fig)
|
224 |
+
return image
|
225 |
+
|
226 |
+
def plot_images_side_by_side_temperatures(samples_temp, metrics, sample_dir, outpath_filename):
|
227 |
+
# Create a plot with the original image and different temperature results
|
228 |
+
num_temps = len(samples_temp)
|
229 |
+
fig, axes = plt.subplots(2, num_temps + 1, figsize=(15, 4), gridspec_kw={'height_ratios': [10, 2]})
|
230 |
+
|
231 |
+
# Plot the original image
|
232 |
+
gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png')
|
233 |
+
gt_image = Image.open(gt_image_path)
|
234 |
+
axes[0, 0].imshow(gt_image)
|
235 |
+
axes[0, 0].set_title('Original')
|
236 |
+
axes[0, 0].axis('off')
|
237 |
+
axes[1, 0].text(0.5, 0.5, 'Original', horizontalalignment='center', verticalalignment='center', fontsize=16)
|
238 |
+
axes[1, 0].axis('off')
|
239 |
+
|
240 |
+
# Plot the generated images for different temperatures and metrics
|
241 |
+
for idx, (temp, sample) in enumerate(samples_temp.items()):
|
242 |
+
gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png')
|
243 |
+
gen_image = Image.open(gen_image_path)
|
244 |
+
axes[0, idx + 1].imshow(gen_image)
|
245 |
+
axes[0, idx + 1].set_title(f'Temp {temp}')
|
246 |
+
axes[0, idx + 1].axis('off')
|
247 |
+
axes[1, idx + 1].text(0.5, 0.5, f'MSE: {metrics[temp]["mse"]:.2f}\nCD: {metrics[temp]["cd"]:.2f}',
|
248 |
+
horizontalalignment='center', verticalalignment='center', fontsize=12)
|
249 |
+
axes[1, idx + 1].axis('off')
|
250 |
+
|
251 |
+
# Save the comparison plot
|
252 |
+
comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png')
|
253 |
+
plt.tight_layout()
|
254 |
+
plt.savefig(comparison_path)
|
255 |
+
plt.close()
|
256 |
+
|
257 |
+
def plot_images_and_prompt(prompt, svg_raster, gt_svg_raster, out_path):
|
258 |
+
# First col shows caption, second col shows generated svg, third col shows gt svg
|
259 |
+
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
|
260 |
+
|
261 |
+
# Split the prompt into multiple lines if it exceeds a certain length
|
262 |
+
prompt_lines = textwrap.wrap(prompt, width=30)
|
263 |
+
prompt_text = '\n'.join(prompt_lines)
|
264 |
+
|
265 |
+
# Display the prompt in the first cell
|
266 |
+
axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
|
267 |
+
axes[0].axis('off')
|
268 |
+
axes[1].imshow(svg_raster)
|
269 |
+
axes[1].set_title('generated_svg')
|
270 |
+
axes[1].axis('off')
|
271 |
+
axes[2].imshow(gt_svg_raster)
|
272 |
+
axes[2].set_title('gt')
|
273 |
+
axes[2].axis('off')
|
274 |
+
plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
|
275 |
+
image = Image.open(out_path)
|
276 |
+
plt.close(fig)
|
277 |
+
return image
|
278 |
+
|
279 |
+
def plot_images_and_prompt_with_metrics(prompt, svg_raster, gt_svg_raster, clip_score, post_processed, out_path):
|
280 |
+
# First col shows caption, second col shows generated svg, third col shows gt svg
|
281 |
+
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
|
282 |
+
|
283 |
+
# Split the prompt into multiple lines if it exceeds a certain length
|
284 |
+
prompt_lines = textwrap.wrap(prompt, width=30)
|
285 |
+
prompt_text = '\n'.join(prompt_lines)
|
286 |
+
|
287 |
+
# Display the prompt in the first cell
|
288 |
+
axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
|
289 |
+
axes[0].axis('off')
|
290 |
+
axes[1].imshow(svg_raster)
|
291 |
+
axes[1].set_title('generated_svg')
|
292 |
+
axes[1].axis('off')
|
293 |
+
axes[2].imshow(gt_svg_raster)
|
294 |
+
axes[2].set_title('gt')
|
295 |
+
axes[2].axis('off')
|
296 |
+
plt.suptitle(f"CLIP Score: {clip_score:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05)
|
297 |
+
plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
|
298 |
+
image = Image.open(out_path)
|
299 |
+
plt.close(fig)
|
300 |
+
return image
|
301 |
+
|
302 |
+
def plot_images_and_prompt_temperatures(prompt, samples_temp, metrics, sample_dir, outpath_filename):
|
303 |
+
# Calculate the number of temperature variations
|
304 |
+
num_temps = len(samples_temp)
|
305 |
+
|
306 |
+
# Create a plot with text, the original image, and different temperature results
|
307 |
+
fig, axes = plt.subplots(1, num_temps + 2, figsize=(5 + 3 * (num_temps + 1), 6))
|
308 |
+
|
309 |
+
# Split the prompt into multiple lines if it exceeds a certain length
|
310 |
+
prompt_lines = textwrap.wrap(prompt, width=30)
|
311 |
+
prompt_text = '\n'.join(prompt_lines)
|
312 |
+
|
313 |
+
# Display the prompt in the first cell
|
314 |
+
axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
|
315 |
+
axes[0].axis('off')
|
316 |
+
|
317 |
+
# Plot the GT (ground truth) image in the second cell
|
318 |
+
gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png')
|
319 |
+
gt_image = Image.open(gt_image_path)
|
320 |
+
axes[1].imshow(gt_image)
|
321 |
+
axes[1].set_title('GT Image')
|
322 |
+
axes[1].axis('off')
|
323 |
+
|
324 |
+
# Plot the generated images for different temperatures and display metrics
|
325 |
+
for idx, (temp, sample) in enumerate(samples_temp.items()):
|
326 |
+
gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png')
|
327 |
+
gen_image = Image.open(gen_image_path)
|
328 |
+
axes[idx + 2].imshow(gen_image)
|
329 |
+
axes[idx + 2].set_title(f'Temp {temp}')
|
330 |
+
axes[idx + 2].axis('off')
|
331 |
+
clip_score = metrics[temp]["clip_score"]
|
332 |
+
axes[idx + 2].text(0.5, -0.1, f'CLIP: {clip_score:.4f}', horizontalalignment='center', verticalalignment='center', fontsize=12, transform=axes[idx + 2].transAxes)
|
333 |
+
|
334 |
+
# Save the comparison plot
|
335 |
+
comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png')
|
336 |
+
plt.tight_layout()
|
337 |
+
plt.savefig(comparison_path)
|
338 |
+
plt.close()
|
339 |
+
|
340 |
+
return comparison_path
|
341 |
+
|
342 |
+
|
343 |
+
def plot_image_tensor(image):
|
344 |
+
import numpy as np
|
345 |
+
from PIL import Image
|
346 |
+
tensor = image[0].cpu().float()
|
347 |
+
tensor = tensor.permute(1, 2, 0)
|
348 |
+
array = (tensor.numpy() * 255).astype(np.uint8)
|
349 |
+
im = Image.fromarray(array)
|
350 |
+
im.save("tmp/output_image.jpg")
|
351 |
+
|
352 |
+
|
353 |
+
def plot_grid_samples(images, num_cols=5, out_path = 'grid.png'):
|
354 |
+
# Calculate the number of rows required for the grid
|
355 |
+
num_images = len(images)
|
356 |
+
num_rows = (num_images + num_cols - 1) // num_cols
|
357 |
+
|
358 |
+
# Create a new figure
|
359 |
+
fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8))
|
360 |
+
|
361 |
+
# Loop through the image files and plot them
|
362 |
+
for i, image in enumerate(images):
|
363 |
+
row = i // num_cols
|
364 |
+
col = i % num_cols
|
365 |
+
|
366 |
+
# Open and display the image using Pillow
|
367 |
+
if type(image) == str:
|
368 |
+
img = Image.open(image)
|
369 |
+
else:
|
370 |
+
img = image
|
371 |
+
axes[row, col].imshow(img)
|
372 |
+
# axes[row, col].set_title(os.path.basename(image_file))
|
373 |
+
axes[row, col].axis('off')
|
374 |
+
|
375 |
+
# Remove empty subplots
|
376 |
+
for i in range(num_images, num_rows * num_cols):
|
377 |
+
row = i // num_cols
|
378 |
+
col = i % num_cols
|
379 |
+
fig.delaxes(axes[row, col])
|
380 |
+
|
381 |
+
# Adjust spacing between subplots
|
382 |
+
plt.tight_layout()
|
383 |
+
|
384 |
+
# save image
|
385 |
+
plt.savefig(out_path, dpi=300)
|
386 |
+
image = Image.open(out_path)
|
387 |
+
plt.close(fig)
|
388 |
+
|
389 |
+
return image
|
starvector/image_encoder.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import os
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from starvector.model.image_encoder.clip_model import convert_weights_to_precision
|
7 |
+
from starvector.data.util import ImageTrainProcessor
|
8 |
+
|
9 |
+
class ImageEncoder(nn.Module):
|
10 |
+
def __init__(self, config, **kwargs):
|
11 |
+
super(ImageEncoder, self).__init__()
|
12 |
+
|
13 |
+
image_size = config.image_size
|
14 |
+
torch_dtype = kwargs.get('model_precision', config.torch_dtype)
|
15 |
+
self.image_encoder_type = config.image_encoder_type
|
16 |
+
if self.image_encoder_type == 'clip':
|
17 |
+
self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size)
|
18 |
+
convert_weights_to_precision(self, torch_dtype)
|
19 |
+
self.processor = ImageTrainProcessor(size=config.image_size)
|
20 |
+
|
21 |
+
elif self.image_encoder_type == 'vqgan':
|
22 |
+
self.visual_encoder = self.build_vqgan_encoder()
|
23 |
+
self.ln_vision = None
|
24 |
+
self.processor = ImageTrainProcessor(size=config.image_size)
|
25 |
+
|
26 |
+
elif self.image_encoder_type == 'convnext':
|
27 |
+
self.visual_encoder = self.build_vqgan_encoder()
|
28 |
+
self.ln_vision = None
|
29 |
+
self.processor = ImageTrainProcessor(size=config.image_size)
|
30 |
+
|
31 |
+
elif 'siglip' in self.image_encoder_type:
|
32 |
+
if self.image_encoder_type == 'siglip_512':
|
33 |
+
model_name = "google/siglip-base-patch16-512"
|
34 |
+
elif self.image_encoder_type == 'siglip_384':
|
35 |
+
model_name = "google/siglip-large-patch16-384"
|
36 |
+
elif self.image_encoder_type == 'siglip_256':
|
37 |
+
model_name = "google/siglip-base-patch16-256"
|
38 |
+
|
39 |
+
from transformers import AutoProcessor, AutoModel
|
40 |
+
|
41 |
+
self.visual_encoder = AutoModel.from_pretrained(
|
42 |
+
model_name, torch_dtype = torch_dtype
|
43 |
+
).vision_model
|
44 |
+
|
45 |
+
self.processor = AutoProcessor.from_pretrained(
|
46 |
+
model_name, torch_dtype = torch_dtype
|
47 |
+
)
|
48 |
+
|
49 |
+
def build_clip_encoder(self, image_size):
|
50 |
+
from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm
|
51 |
+
visual_encoder = VisionTransformer(
|
52 |
+
input_resolution=image_size,
|
53 |
+
patch_size=14,
|
54 |
+
width=1024,
|
55 |
+
layers=23,
|
56 |
+
heads=16,
|
57 |
+
use_grad_checkpointing=False)
|
58 |
+
|
59 |
+
ln_vision = LayerNorm(visual_encoder.num_features)
|
60 |
+
return visual_encoder, ln_vision
|
61 |
+
|
62 |
+
def build_vqgan_encoder(self):
|
63 |
+
from taming.modules.diffusionmodules.model import Encoder
|
64 |
+
VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md
|
65 |
+
vqgan_chkp_path = VQGAN_CHECKPOINT
|
66 |
+
files_in_directory = os.listdir(vqgan_chkp_path + '/configs')
|
67 |
+
vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0]
|
68 |
+
vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file))
|
69 |
+
visual_encoder = Encoder(**vqgan_config.model.params.ddconfig)
|
70 |
+
|
71 |
+
# Load checkpoint weights
|
72 |
+
checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict']
|
73 |
+
|
74 |
+
# Create a new state_dict with modified keys
|
75 |
+
new_state_dict = {}
|
76 |
+
for key, value in checkpoint.items():
|
77 |
+
if key.startswith('encoder.'):
|
78 |
+
new_key = key[len('encoder.'):]
|
79 |
+
new_state_dict[new_key] = value
|
80 |
+
|
81 |
+
# Load weights
|
82 |
+
visual_encoder.load_state_dict(new_state_dict)
|
83 |
+
return visual_encoder
|
84 |
+
|
85 |
+
def build_convnext_encoder(self):
|
86 |
+
import open_clip
|
87 |
+
model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
|
88 |
+
return model.visual
|
89 |
+
|
90 |
+
def forward(self, image):
|
91 |
+
if self.image_encoder_type == 'clip':
|
92 |
+
embeds = self.visual_encoder(image)
|
93 |
+
out = self.ln_vision(embeds)
|
94 |
+
elif self.image_encoder_type == 'open-clip':
|
95 |
+
out = self.visual_encoder(image)[1]
|
96 |
+
out = self.ln_vision(out)
|
97 |
+
elif self.image_encoder_type == 'vqgan':
|
98 |
+
out = self.visual_encoder(image)
|
99 |
+
size = out.size()
|
100 |
+
out = out.view(size[0], size[1], -1)
|
101 |
+
out = out.permute(0, 2, 1)
|
102 |
+
elif self.image_encoder_type == 'convnext':
|
103 |
+
out = self.visual_encoder.trunk.forward_features(image)
|
104 |
+
size = out.size()
|
105 |
+
out = out.view(size[0], size[1], -1)
|
106 |
+
out = out.permute(0, 2, 1)
|
107 |
+
elif 'siglip' in self.image_encoder_type:
|
108 |
+
out = self.visual_encoder(image)["last_hidden_state"]
|
109 |
+
return out
|
110 |
+
|
111 |
+
def process_images(self, images):
|
112 |
+
if self.image_encoder_type == 'clip':
|
113 |
+
res = []
|
114 |
+
for image in images:
|
115 |
+
res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W
|
116 |
+
return res
|
117 |
+
else:
|
118 |
+
return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0)
|
119 |
+
|
starvector/metrics/base_metric.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from starvector.metrics.util import AverageMeter
|
2 |
+
from tqdm import tqdm
|
3 |
+
import math
|
4 |
+
|
5 |
+
class BaseMetric:
|
6 |
+
def __init__(self):
|
7 |
+
self.meter = AverageMeter()
|
8 |
+
|
9 |
+
def reset(self):
|
10 |
+
self.meter.reset()
|
11 |
+
|
12 |
+
def calculate_score(self, batch, update=True):
|
13 |
+
"""
|
14 |
+
Batch: {"gt_im": [PIL Image], "gen_im": [Image]}
|
15 |
+
"""
|
16 |
+
values = []
|
17 |
+
batch_size = len(next(iter(batch.values())))
|
18 |
+
for index in tqdm(range(batch_size)):
|
19 |
+
kwargs = {}
|
20 |
+
for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]:
|
21 |
+
if key in batch:
|
22 |
+
kwargs[key] = batch[key][index]
|
23 |
+
try:
|
24 |
+
measure = self.metric(**kwargs)
|
25 |
+
except Exception as e:
|
26 |
+
print("Error calculating metric: {}".format(e))
|
27 |
+
continue
|
28 |
+
if math.isnan(measure):
|
29 |
+
continue
|
30 |
+
values.append(measure)
|
31 |
+
|
32 |
+
if not values:
|
33 |
+
print("No valid values found for metric calculation.")
|
34 |
+
return float("nan")
|
35 |
+
|
36 |
+
score = sum(values) / len(values)
|
37 |
+
if update:
|
38 |
+
self.meter.update(score, len(values))
|
39 |
+
return self.meter.avg, values
|
40 |
+
else:
|
41 |
+
return score, values
|
42 |
+
|
43 |
+
def metric(self, **kwargs):
|
44 |
+
"""
|
45 |
+
This method should be overridden by subclasses to provide the specific metric computation.
|
46 |
+
"""
|
47 |
+
raise NotImplementedError("The metric method must be implemented by subclasses.")
|
48 |
+
|
49 |
+
def get_average_score(self):
|
50 |
+
return self.meter.avg
|
51 |
+
|
starvector/metrics/compute_LPIPS.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import ToTensor, Normalize
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
from starvector.metrics.base_metric import BaseMetric
|
5 |
+
import lpips
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
class LPIPSDistanceCalculator(BaseMetric):
|
10 |
+
def __init__(self, config=None, device='cuda'):
|
11 |
+
super().__init__()
|
12 |
+
self.class_name = self.__class__.__name__
|
13 |
+
self.config = config
|
14 |
+
self.model = lpips.LPIPS(net='vgg').to(device)
|
15 |
+
self.metric = self.LPIPS
|
16 |
+
self.to_tensor = ToTensor()
|
17 |
+
self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
18 |
+
self.device = device
|
19 |
+
|
20 |
+
def LPIPS(self, tensor_image1, tensor_image2):
|
21 |
+
tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device)
|
22 |
+
return self.model(tensor_image1, tensor_image2)
|
23 |
+
|
24 |
+
def to_tensor_transform(self, pil_img):
|
25 |
+
return self.normalize(self.to_tensor(pil_img))
|
26 |
+
|
27 |
+
def collate_fn(self, batch):
|
28 |
+
gt_imgs, gen_imgs = zip(*batch)
|
29 |
+
tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs])
|
30 |
+
tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs])
|
31 |
+
return tensor_gt_imgs, tensor_gen_imgs
|
32 |
+
|
33 |
+
def calculate_score(self, batch, batch_size=8, update=True):
|
34 |
+
gt_images = batch['gt_im']
|
35 |
+
gen_images = batch['gen_im']
|
36 |
+
|
37 |
+
# Create DataLoader with custom collate function
|
38 |
+
data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False)
|
39 |
+
|
40 |
+
values = []
|
41 |
+
for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader):
|
42 |
+
# Compute LPIPS
|
43 |
+
lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch)
|
44 |
+
values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist())
|
45 |
+
|
46 |
+
if not values:
|
47 |
+
print("No valid values found for metric calculation.")
|
48 |
+
return float("nan")
|
49 |
+
|
50 |
+
avg_score = sum(values) / len(values)
|
51 |
+
if update:
|
52 |
+
self.meter.update(avg_score, len(values))
|
53 |
+
return self.meter.avg, values
|
54 |
+
else:
|
55 |
+
return avg_score, values
|
56 |
+
|
starvector/metrics/compute_SSIM.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from starvector.metrics.base_metric import BaseMetric
|
2 |
+
from skimage.metrics import structural_similarity as ssim
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
class SSIMDistanceCalculator(BaseMetric):
|
6 |
+
def __init__(self, config=None):
|
7 |
+
super().__init__()
|
8 |
+
self.class_name = self.__class__.__name__
|
9 |
+
self.config = config
|
10 |
+
self.metric = self.compute_SSIM
|
11 |
+
|
12 |
+
def compute_SSIM(self, **kwargs):
|
13 |
+
image1 = kwargs.get('gt_im')
|
14 |
+
image2 = kwargs.get('gen_im')
|
15 |
+
win_size = kwargs.get('win_size', 11) # Increase win_size for more accuracy
|
16 |
+
channel_axis = kwargs.get('channel_axis', -1) # Default channel_axis to -1
|
17 |
+
sigma = kwargs.get('sigma', 1.5) # Add sigma parameter for Gaussian filter
|
18 |
+
|
19 |
+
# Convert images to numpy arrays if they aren't already
|
20 |
+
img1_np = np.array(image1)
|
21 |
+
img2_np = np.array(image2)
|
22 |
+
|
23 |
+
# Check if images are grayscale or RGB
|
24 |
+
if len(img1_np.shape) == 3 and img1_np.shape[2] == 3:
|
25 |
+
# Compute SSIM for RGB images
|
26 |
+
score, _ = ssim(img1_np, img2_np, win_size=win_size, channel_axis=channel_axis, sigma=sigma, full=True)
|
27 |
+
else:
|
28 |
+
# Convert to grayscale if not already
|
29 |
+
if len(img1_np.shape) == 3:
|
30 |
+
img1_np = np.mean(img1_np, axis=2)
|
31 |
+
img2_np = np.mean(img2_np, axis=2)
|
32 |
+
|
33 |
+
score, _ = ssim(img1_np, img2_np, win_size=win_size, sigma=sigma, full=True)
|
34 |
+
|
35 |
+
return score
|
starvector/metrics/compute_clip_score.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import ToTensor
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from starvector.metrics.base_metric import BaseMetric
|
4 |
+
import torch
|
5 |
+
from torchmetrics.multimodal.clip_score import CLIPScore
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from tqdm import tqdm
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from torchmetrics.functional.multimodal.clip_score import _clip_score_update
|
10 |
+
|
11 |
+
class CLIPScoreCalculator(BaseMetric):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
self.class_name = self.__class__.__name__
|
15 |
+
self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32")
|
16 |
+
self.clip_score.to('cuda')
|
17 |
+
|
18 |
+
def CLIP_Score(self, images, captions):
|
19 |
+
all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor)
|
20 |
+
return all_scores
|
21 |
+
|
22 |
+
def collate_fn(self, batch):
|
23 |
+
gen_imgs, captions = zip(*batch)
|
24 |
+
tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs]
|
25 |
+
return tensor_gen_imgs, captions
|
26 |
+
|
27 |
+
def calculate_score(self, batch, batch_size=512, update=True):
|
28 |
+
gen_images = batch['gen_im']
|
29 |
+
captions = batch['caption']
|
30 |
+
|
31 |
+
# Create DataLoader with custom collate function
|
32 |
+
data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
|
33 |
+
|
34 |
+
all_scores = []
|
35 |
+
for batch_eval in tqdm(data_loader):
|
36 |
+
images, captions = batch_eval
|
37 |
+
images = [img.to('cuda', non_blocking=True) * 255 for img in images]
|
38 |
+
list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist()
|
39 |
+
all_scores.extend(list_scores)
|
40 |
+
|
41 |
+
if not all_scores:
|
42 |
+
print("No valid scores found for metric calculation.")
|
43 |
+
return float("nan"), []
|
44 |
+
|
45 |
+
avg_score = sum(all_scores) / len(all_scores)
|
46 |
+
if update:
|
47 |
+
self.meter.update(avg_score, len(all_scores))
|
48 |
+
return self.meter.avg, all_scores
|
49 |
+
else:
|
50 |
+
return avg_score, all_scores
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
import multiprocessing
|
54 |
+
multiprocessing.set_start_method('spawn')
|
55 |
+
# Rest of your code...
|
starvector/metrics/compute_dino_score.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from starvector.metrics.base_metric import BaseMetric
|
4 |
+
from tqdm import tqdm
|
5 |
+
from transformers import AutoModel, AutoImageProcessor
|
6 |
+
from PIL import Image
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
class DINOScoreCalculator(BaseMetric):
|
10 |
+
def __init__(self, config=None, device='cuda'):
|
11 |
+
super().__init__()
|
12 |
+
self.class_name = self.__class__.__name__
|
13 |
+
self.config = config
|
14 |
+
self.model, self.processor = self.get_DINOv2_model("base")
|
15 |
+
self.model = self.model.to(device)
|
16 |
+
self.device = device
|
17 |
+
|
18 |
+
self.metric = self.calculate_DINOv2_similarity_score
|
19 |
+
|
20 |
+
def get_DINOv2_model(self, model_size):
|
21 |
+
if model_size == "small":
|
22 |
+
model_size = "facebook/dinov2-small"
|
23 |
+
elif model_size == "base":
|
24 |
+
model_size = "facebook/dinov2-base"
|
25 |
+
elif model_size == "large":
|
26 |
+
model_size = "facebook/dinov2-large"
|
27 |
+
else:
|
28 |
+
raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}")
|
29 |
+
return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size)
|
30 |
+
|
31 |
+
def process_input(self, image, processor):
|
32 |
+
if isinstance(image, str):
|
33 |
+
image = Image.open(image)
|
34 |
+
if isinstance(image, Image.Image):
|
35 |
+
with torch.no_grad():
|
36 |
+
inputs = processor(images=image, return_tensors="pt").to(self.device)
|
37 |
+
outputs = self.model(**inputs)
|
38 |
+
features = outputs.last_hidden_state.mean(dim=1)
|
39 |
+
elif isinstance(image, torch.Tensor):
|
40 |
+
features = image.unsqueeze(0) if image.dim() == 1 else image
|
41 |
+
else:
|
42 |
+
raise ValueError("Input must be a file path, PIL Image, or tensor of features")
|
43 |
+
return features
|
44 |
+
|
45 |
+
def calculate_DINOv2_similarity_score(self, **kwargs):
|
46 |
+
image1 = kwargs.get('gt_im')
|
47 |
+
image2 = kwargs.get('gen_im')
|
48 |
+
features1 = self.process_input(image1, self.processor)
|
49 |
+
features2 = self.process_input(image2, self.processor)
|
50 |
+
|
51 |
+
cos = nn.CosineSimilarity(dim=1)
|
52 |
+
sim = cos(features1, features2).item()
|
53 |
+
sim = (sim + 1) / 2
|
54 |
+
|
55 |
+
return sim
|
starvector/metrics/compute_fid.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html
|
2 |
+
# from torchmetrics.image.fid import FrechetInceptionDistance
|
3 |
+
from PIL import Image
|
4 |
+
from starvector.metrics.base_metric import BaseMetric
|
5 |
+
import torch
|
6 |
+
from torchvision import transforms
|
7 |
+
import clip
|
8 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
9 |
+
from starvector.metrics.inception import InceptionV3
|
10 |
+
import numpy as np
|
11 |
+
from tqdm import tqdm
|
12 |
+
from scipy import linalg
|
13 |
+
import torchvision.transforms as TF
|
14 |
+
|
15 |
+
class FIDCalculator(BaseMetric):
|
16 |
+
def __init__(self, model_name = 'InceptionV3',):
|
17 |
+
self.class_name = self.__class__.__name__
|
18 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
19 |
+
self.model_name = model_name
|
20 |
+
if self.model_name == 'ViT-B/32':
|
21 |
+
self.dims = 512
|
22 |
+
model, preprocess = clip.load('ViT-B/32')
|
23 |
+
|
24 |
+
elif self.model_name == 'InceptionV3':
|
25 |
+
self.dims = 2048
|
26 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
|
27 |
+
model = InceptionV3([block_idx]).to(self.device)
|
28 |
+
preprocess = TF.Compose([TF.ToTensor()])
|
29 |
+
|
30 |
+
self.model = model.cuda()
|
31 |
+
self.preprocess = preprocess
|
32 |
+
|
33 |
+
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
|
34 |
+
"""Numpy implementation of the Frechet Distance.
|
35 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
36 |
+
and X_2 ~ N(mu_2, C_2) is
|
37 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
38 |
+
|
39 |
+
Stable version by Dougal J. Sutherland.
|
40 |
+
|
41 |
+
Params:
|
42 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
43 |
+
inception net (like returned by the function 'get_predictions')
|
44 |
+
for generated samples.
|
45 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
46 |
+
representative data set.
|
47 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
48 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
49 |
+
representative data set.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
-- : The Frechet Distance.
|
53 |
+
"""
|
54 |
+
|
55 |
+
mu1 = np.atleast_1d(mu1)
|
56 |
+
mu2 = np.atleast_1d(mu2)
|
57 |
+
|
58 |
+
sigma1 = np.atleast_2d(sigma1)
|
59 |
+
sigma2 = np.atleast_2d(sigma2)
|
60 |
+
|
61 |
+
assert mu1.shape == mu2.shape, \
|
62 |
+
'Training and test mean vectors have different lengths'
|
63 |
+
assert sigma1.shape == sigma2.shape, \
|
64 |
+
'Training and test covariances have different dimensions'
|
65 |
+
|
66 |
+
diff = mu1 - mu2
|
67 |
+
|
68 |
+
# Product might be almost singular
|
69 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
70 |
+
if not np.isfinite(covmean).all():
|
71 |
+
msg = ('fid calculation produces singular product; '
|
72 |
+
'adding %s to diagonal of cov estimates') % eps
|
73 |
+
print(msg)
|
74 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
75 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
76 |
+
|
77 |
+
# Numerical error might give slight imaginary component
|
78 |
+
if np.iscomplexobj(covmean):
|
79 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
80 |
+
m = np.max(np.abs(covmean.imag))
|
81 |
+
raise ValueError('Imaginary component {}'.format(m))
|
82 |
+
covmean = covmean.real
|
83 |
+
|
84 |
+
tr_covmean = np.trace(covmean)
|
85 |
+
|
86 |
+
return (diff.dot(diff) + np.trace(sigma1)
|
87 |
+
+ np.trace(sigma2) - 2 * tr_covmean)
|
88 |
+
|
89 |
+
def get_activations(self, images):
|
90 |
+
dataset = ImageDataset(images, self.preprocess)
|
91 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4)
|
92 |
+
pred_arr = np.empty((len(images), self.dims))
|
93 |
+
start_idx = 0
|
94 |
+
for batch in tqdm(dataloader):
|
95 |
+
batch = batch.to(self.device)
|
96 |
+
|
97 |
+
with torch.no_grad():
|
98 |
+
if self.model_name == 'ViT-B/32':
|
99 |
+
pred = self.model.encode_image(batch).cpu().numpy()
|
100 |
+
elif self.model_name == 'InceptionV3':
|
101 |
+
pred = self.model(batch)[0]
|
102 |
+
|
103 |
+
# If model output is not scalar, apply global spatial average pooling.
|
104 |
+
# This happens if you choose a dimensionality not equal 2048.
|
105 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
106 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
107 |
+
|
108 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
109 |
+
pred_arr[start_idx:start_idx + pred.shape[0]] = pred
|
110 |
+
start_idx = start_idx + pred.shape[0]
|
111 |
+
|
112 |
+
return pred_arr
|
113 |
+
|
114 |
+
def calculate_activation_statistics(self, images):
|
115 |
+
act = self.get_activations(images)
|
116 |
+
mu = np.mean(act, axis=0)
|
117 |
+
sigma = np.cov(act, rowvar=False)
|
118 |
+
return mu, sigma
|
119 |
+
|
120 |
+
def pil_images_to_tensor(self, images_list):
|
121 |
+
"""Convert a list of PIL Images to a torch.Tensor."""
|
122 |
+
tensors_list = [self.preprocess(img) for img in images_list]
|
123 |
+
return torch.stack(tensors_list).cuda() # BxCxHxW format
|
124 |
+
|
125 |
+
def calculate_score(self, batch):
|
126 |
+
m1, s1 = self.calculate_activation_statistics(batch['gt_im'])
|
127 |
+
m2, s2 = self.calculate_activation_statistics(batch['gen_im'])
|
128 |
+
fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
|
129 |
+
return fid_value
|
130 |
+
|
131 |
+
def reset(self):
|
132 |
+
pass
|
133 |
+
|
134 |
+
class ImageDataset(torch.utils.data.Dataset):
|
135 |
+
def __init__(self, images, processor=None):
|
136 |
+
self.images = images
|
137 |
+
self.processor = processor
|
138 |
+
|
139 |
+
def __len__(self):
|
140 |
+
return len(self.images)
|
141 |
+
|
142 |
+
def __getitem__(self, i):
|
143 |
+
img = self.images[i]
|
144 |
+
img = self.processor(img)
|
145 |
+
return img
|
starvector/metrics/compute_l2.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import ToTensor
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from starvector.metrics.base_metric import BaseMetric
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class L2DistanceCalculator(BaseMetric):
|
7 |
+
def __init__(self, config=None, masked_l2=False):
|
8 |
+
super().__init__()
|
9 |
+
self.class_name = self.__class__.__name__
|
10 |
+
self.config = config
|
11 |
+
self.metric = self.l2_distance
|
12 |
+
self.masked_l2 = masked_l2
|
13 |
+
|
14 |
+
def l2_distance(self, **kwargs):
|
15 |
+
image1 = kwargs.get('gt_im')
|
16 |
+
image2 = kwargs.get('gen_im')
|
17 |
+
image1_tensor = ToTensor()(image1)
|
18 |
+
image2_tensor = ToTensor()(image2)
|
19 |
+
|
20 |
+
if self.masked_l2:
|
21 |
+
# Create binary masks: 0 for white pixels, 1 for non-white pixels
|
22 |
+
mask1 = (image1_tensor != 1).any(dim=0).float()
|
23 |
+
mask2 = (image2_tensor != 1).any(dim=0).float()
|
24 |
+
|
25 |
+
# Create a combined mask for overlapping non-white pixels
|
26 |
+
combined_mask = mask1 * mask2
|
27 |
+
|
28 |
+
# Apply the combined mask to both images
|
29 |
+
image1_tensor = image1_tensor * combined_mask.unsqueeze(0)
|
30 |
+
image2_tensor = image2_tensor * combined_mask.unsqueeze(0)
|
31 |
+
|
32 |
+
# Compute mean squared error
|
33 |
+
mse = F.mse_loss(image1_tensor, image2_tensor)
|
34 |
+
return mse.item()
|
35 |
+
|
36 |
+
|
37 |
+
|
starvector/metrics/count_token_length.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from starvector.metrics.base_metric import BaseMetric
|
4 |
+
from tqdm import tqdm
|
5 |
+
from starvector.metrics.util import AverageMeter
|
6 |
+
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
|
9 |
+
class CountTokenLength(BaseMetric):
|
10 |
+
def __init__(self, config=None, device='cuda'):
|
11 |
+
super().__init__()
|
12 |
+
self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
|
13 |
+
self.metric = self.calculate_token_length
|
14 |
+
self.meter_gt_tokens = AverageMeter()
|
15 |
+
self.meter_gen_tokens = AverageMeter()
|
16 |
+
self.meter_diff = AverageMeter()
|
17 |
+
|
18 |
+
def calculate_token_length(self, **kwargs):
|
19 |
+
svg = kwargs.get('gt_svg')
|
20 |
+
tokens = self.tokenizer.encode(svg)
|
21 |
+
gen_svg = kwargs.get('gen_svg')
|
22 |
+
gen_tokens = self.tokenizer.encode(gen_svg)
|
23 |
+
diff = len(gen_tokens) - len(tokens)
|
24 |
+
return len(tokens), len(gen_tokens), diff
|
25 |
+
|
26 |
+
def calculate_score(self, batch, update=None):
|
27 |
+
gt_svgs = batch['gt_svg']
|
28 |
+
gen_svgs = batch['gen_svg']
|
29 |
+
values = []
|
30 |
+
for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"):
|
31 |
+
gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg)
|
32 |
+
self.meter_gt_tokens.update(gt_tokens, 1)
|
33 |
+
self.meter_gen_tokens.update(gen_tokens, 1)
|
34 |
+
self.meter_diff.update(diff, 1)
|
35 |
+
values.append({
|
36 |
+
'gt_tokens': gt_tokens,
|
37 |
+
'gen_tokens': gen_tokens,
|
38 |
+
'diff': diff
|
39 |
+
})
|
40 |
+
avg_score = {
|
41 |
+
'gt_tokens': self.meter_gt_tokens.avg,
|
42 |
+
'gen_tokens': self.meter_gen_tokens.avg,
|
43 |
+
'diff': self.meter_diff.avg
|
44 |
+
}
|
45 |
+
if not values:
|
46 |
+
print("No valid values found for metric calculation.")
|
47 |
+
return float("nan")
|
48 |
+
|
49 |
+
return avg_score, values
|
50 |
+
|
51 |
+
def reset(self):
|
52 |
+
self.meter_gt_tokens.reset()
|
53 |
+
self.meter_gen_tokens.reset()
|
54 |
+
self.meter_diff.reset()
|
starvector/metrics/inception.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
8 |
+
except ImportError:
|
9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
14 |
+
|
15 |
+
|
16 |
+
class InceptionV3(nn.Module):
|
17 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
18 |
+
|
19 |
+
# Index of default block of inception to return,
|
20 |
+
# corresponds to output of final average pooling
|
21 |
+
DEFAULT_BLOCK_INDEX = 3
|
22 |
+
|
23 |
+
# Maps feature dimensionality to their output blocks indices
|
24 |
+
BLOCK_INDEX_BY_DIM = {
|
25 |
+
64: 0, # First max pooling features
|
26 |
+
192: 1, # Second max pooling featurs
|
27 |
+
768: 2, # Pre-aux classifier features
|
28 |
+
2048: 3 # Final average pooling features
|
29 |
+
}
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
33 |
+
resize_input=True,
|
34 |
+
normalize_input=True,
|
35 |
+
requires_grad=False,
|
36 |
+
use_fid_inception=True):
|
37 |
+
"""Build pretrained InceptionV3
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
output_blocks : list of int
|
42 |
+
Indices of blocks to return features of. Possible values are:
|
43 |
+
- 0: corresponds to output of first max pooling
|
44 |
+
- 1: corresponds to output of second max pooling
|
45 |
+
- 2: corresponds to output which is fed to aux classifier
|
46 |
+
- 3: corresponds to output of final average pooling
|
47 |
+
resize_input : bool
|
48 |
+
If true, bilinearly resizes input to width and height 299 before
|
49 |
+
feeding input to model. As the network without fully connected
|
50 |
+
layers is fully convolutional, it should be able to handle inputs
|
51 |
+
of arbitrary size, so resizing might not be strictly needed
|
52 |
+
normalize_input : bool
|
53 |
+
If true, scales the input from range (0, 1) to the range the
|
54 |
+
pretrained Inception network expects, namely (-1, 1)
|
55 |
+
requires_grad : bool
|
56 |
+
If true, parameters of the model require gradients. Possibly useful
|
57 |
+
for finetuning the network
|
58 |
+
use_fid_inception : bool
|
59 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
60 |
+
FID implementation. If false, uses the pretrained Inception model
|
61 |
+
available in torchvision. The FID Inception model has different
|
62 |
+
weights and a slightly different structure from torchvision's
|
63 |
+
Inception model. If you want to compute FID scores, you are
|
64 |
+
strongly advised to set this parameter to true to get comparable
|
65 |
+
results.
|
66 |
+
"""
|
67 |
+
super(InceptionV3, self).__init__()
|
68 |
+
|
69 |
+
self.resize_input = resize_input
|
70 |
+
self.normalize_input = normalize_input
|
71 |
+
self.output_blocks = sorted(output_blocks)
|
72 |
+
self.last_needed_block = max(output_blocks)
|
73 |
+
|
74 |
+
assert self.last_needed_block <= 3, \
|
75 |
+
'Last possible output block index is 3'
|
76 |
+
|
77 |
+
self.blocks = nn.ModuleList()
|
78 |
+
|
79 |
+
if use_fid_inception:
|
80 |
+
inception = fid_inception_v3()
|
81 |
+
else:
|
82 |
+
inception = _inception_v3(weights='DEFAULT')
|
83 |
+
|
84 |
+
# Block 0: input to maxpool1
|
85 |
+
block0 = [
|
86 |
+
inception.Conv2d_1a_3x3,
|
87 |
+
inception.Conv2d_2a_3x3,
|
88 |
+
inception.Conv2d_2b_3x3,
|
89 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
90 |
+
]
|
91 |
+
self.blocks.append(nn.Sequential(*block0))
|
92 |
+
|
93 |
+
# Block 1: maxpool1 to maxpool2
|
94 |
+
if self.last_needed_block >= 1:
|
95 |
+
block1 = [
|
96 |
+
inception.Conv2d_3b_1x1,
|
97 |
+
inception.Conv2d_4a_3x3,
|
98 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
99 |
+
]
|
100 |
+
self.blocks.append(nn.Sequential(*block1))
|
101 |
+
|
102 |
+
# Block 2: maxpool2 to aux classifier
|
103 |
+
if self.last_needed_block >= 2:
|
104 |
+
block2 = [
|
105 |
+
inception.Mixed_5b,
|
106 |
+
inception.Mixed_5c,
|
107 |
+
inception.Mixed_5d,
|
108 |
+
inception.Mixed_6a,
|
109 |
+
inception.Mixed_6b,
|
110 |
+
inception.Mixed_6c,
|
111 |
+
inception.Mixed_6d,
|
112 |
+
inception.Mixed_6e,
|
113 |
+
]
|
114 |
+
self.blocks.append(nn.Sequential(*block2))
|
115 |
+
|
116 |
+
# Block 3: aux classifier to final avgpool
|
117 |
+
if self.last_needed_block >= 3:
|
118 |
+
block3 = [
|
119 |
+
inception.Mixed_7a,
|
120 |
+
inception.Mixed_7b,
|
121 |
+
inception.Mixed_7c,
|
122 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
123 |
+
]
|
124 |
+
self.blocks.append(nn.Sequential(*block3))
|
125 |
+
|
126 |
+
for param in self.parameters():
|
127 |
+
param.requires_grad = requires_grad
|
128 |
+
|
129 |
+
def forward(self, inp):
|
130 |
+
"""Get Inception feature maps
|
131 |
+
|
132 |
+
Parameters
|
133 |
+
----------
|
134 |
+
inp : torch.autograd.Variable
|
135 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
136 |
+
range (0, 1)
|
137 |
+
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
141 |
+
block, sorted ascending by index
|
142 |
+
"""
|
143 |
+
outp = []
|
144 |
+
x = inp
|
145 |
+
|
146 |
+
if self.resize_input:
|
147 |
+
x = F.interpolate(x,
|
148 |
+
size=(299, 299),
|
149 |
+
mode='bilinear',
|
150 |
+
align_corners=False)
|
151 |
+
|
152 |
+
if self.normalize_input:
|
153 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
154 |
+
|
155 |
+
for idx, block in enumerate(self.blocks):
|
156 |
+
x = block(x)
|
157 |
+
if idx in self.output_blocks:
|
158 |
+
outp.append(x)
|
159 |
+
|
160 |
+
if idx == self.last_needed_block:
|
161 |
+
break
|
162 |
+
|
163 |
+
return outp
|
164 |
+
|
165 |
+
|
166 |
+
def _inception_v3(*args, **kwargs):
|
167 |
+
"""Wraps `torchvision.models.inception_v3`"""
|
168 |
+
try:
|
169 |
+
version = tuple(map(int, torchvision.__version__.split('.')[:2]))
|
170 |
+
except ValueError:
|
171 |
+
# Just a caution against weird version strings
|
172 |
+
version = (0,)
|
173 |
+
|
174 |
+
# Skips default weight inititialization if supported by torchvision
|
175 |
+
# version. See https://github.com/mseitzer/pytorch-fid/issues/28.
|
176 |
+
if version >= (0, 6):
|
177 |
+
kwargs['init_weights'] = False
|
178 |
+
|
179 |
+
# Backwards compatibility: `weights` argument was handled by `pretrained`
|
180 |
+
# argument prior to version 0.13.
|
181 |
+
if version < (0, 13) and 'weights' in kwargs:
|
182 |
+
if kwargs['weights'] == 'DEFAULT':
|
183 |
+
kwargs['pretrained'] = True
|
184 |
+
elif kwargs['weights'] is None:
|
185 |
+
kwargs['pretrained'] = False
|
186 |
+
else:
|
187 |
+
raise ValueError(
|
188 |
+
'weights=={} not supported in torchvision {}'.format(
|
189 |
+
kwargs['weights'], torchvision.__version__
|
190 |
+
)
|
191 |
+
)
|
192 |
+
del kwargs['weights']
|
193 |
+
|
194 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
195 |
+
|
196 |
+
|
197 |
+
def fid_inception_v3():
|
198 |
+
"""Build pretrained Inception model for FID computation
|
199 |
+
|
200 |
+
The Inception model for FID computation uses a different set of weights
|
201 |
+
and has a slightly different structure than torchvision's Inception.
|
202 |
+
|
203 |
+
This method first constructs torchvision's Inception and then patches the
|
204 |
+
necessary parts that are different in the FID Inception model.
|
205 |
+
"""
|
206 |
+
inception = _inception_v3(num_classes=1008,
|
207 |
+
aux_logits=False,
|
208 |
+
weights=None)
|
209 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
210 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
211 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
212 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
213 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
214 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
215 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
216 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
217 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
218 |
+
|
219 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
220 |
+
inception.load_state_dict(state_dict)
|
221 |
+
return inception
|
222 |
+
|
223 |
+
|
224 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
225 |
+
"""InceptionA block patched for FID computation"""
|
226 |
+
def __init__(self, in_channels, pool_features):
|
227 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
branch1x1 = self.branch1x1(x)
|
231 |
+
|
232 |
+
branch5x5 = self.branch5x5_1(x)
|
233 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
234 |
+
|
235 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
236 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
237 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
238 |
+
|
239 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
240 |
+
# its average calculation
|
241 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
242 |
+
count_include_pad=False)
|
243 |
+
branch_pool = self.branch_pool(branch_pool)
|
244 |
+
|
245 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
246 |
+
return torch.cat(outputs, 1)
|
247 |
+
|
248 |
+
|
249 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
250 |
+
"""InceptionC block patched for FID computation"""
|
251 |
+
def __init__(self, in_channels, channels_7x7):
|
252 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
253 |
+
|
254 |
+
def forward(self, x):
|
255 |
+
branch1x1 = self.branch1x1(x)
|
256 |
+
|
257 |
+
branch7x7 = self.branch7x7_1(x)
|
258 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
259 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
260 |
+
|
261 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
262 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
263 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
264 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
265 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
266 |
+
|
267 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
268 |
+
# its average calculation
|
269 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
270 |
+
count_include_pad=False)
|
271 |
+
branch_pool = self.branch_pool(branch_pool)
|
272 |
+
|
273 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
274 |
+
return torch.cat(outputs, 1)
|
275 |
+
|
276 |
+
|
277 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
278 |
+
"""First InceptionE block patched for FID computation"""
|
279 |
+
def __init__(self, in_channels):
|
280 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
281 |
+
|
282 |
+
def forward(self, x):
|
283 |
+
branch1x1 = self.branch1x1(x)
|
284 |
+
|
285 |
+
branch3x3 = self.branch3x3_1(x)
|
286 |
+
branch3x3 = [
|
287 |
+
self.branch3x3_2a(branch3x3),
|
288 |
+
self.branch3x3_2b(branch3x3),
|
289 |
+
]
|
290 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
291 |
+
|
292 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
293 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
294 |
+
branch3x3dbl = [
|
295 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
296 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
297 |
+
]
|
298 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
299 |
+
|
300 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
301 |
+
# its average calculation
|
302 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
303 |
+
count_include_pad=False)
|
304 |
+
branch_pool = self.branch_pool(branch_pool)
|
305 |
+
|
306 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
307 |
+
return torch.cat(outputs, 1)
|
308 |
+
|
309 |
+
|
310 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
311 |
+
"""Second InceptionE block patched for FID computation"""
|
312 |
+
def __init__(self, in_channels):
|
313 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
314 |
+
|
315 |
+
def forward(self, x):
|
316 |
+
branch1x1 = self.branch1x1(x)
|
317 |
+
|
318 |
+
branch3x3 = self.branch3x3_1(x)
|
319 |
+
branch3x3 = [
|
320 |
+
self.branch3x3_2a(branch3x3),
|
321 |
+
self.branch3x3_2b(branch3x3),
|
322 |
+
]
|
323 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
324 |
+
|
325 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
326 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
327 |
+
branch3x3dbl = [
|
328 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
329 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
330 |
+
]
|
331 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
332 |
+
|
333 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
334 |
+
# pooling. This is likely an error in this specific Inception
|
335 |
+
# implementation, as other Inception models use average pooling here
|
336 |
+
# (which matches the description in the paper).
|
337 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
338 |
+
branch_pool = self.branch_pool(branch_pool)
|
339 |
+
|
340 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
341 |
+
return torch.cat(outputs, 1)
|
starvector/metrics/metrics.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from starvector.metrics.compute_l2 import L2DistanceCalculator
|
2 |
+
from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator
|
3 |
+
from starvector.metrics.compute_SSIM import SSIMDistanceCalculator
|
4 |
+
from starvector.metrics.compute_fid import FIDCalculator
|
5 |
+
from starvector.metrics.compute_clip_score import CLIPScoreCalculator
|
6 |
+
from starvector.data.util import rasterize_svg
|
7 |
+
from starvector.metrics.util import AverageMeter
|
8 |
+
from starvector.metrics.compute_dino_score import DINOScoreCalculator
|
9 |
+
from starvector.metrics.count_token_length import CountTokenLength
|
10 |
+
import os
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
class SVGMetrics:
|
14 |
+
def __init__(self, config=None):
|
15 |
+
self.class_name = self.__class__.__name__
|
16 |
+
|
17 |
+
default_config = {
|
18 |
+
'L2': True,
|
19 |
+
'Masked-L2': False,
|
20 |
+
'LPIPS': False,
|
21 |
+
'SSIM': False,
|
22 |
+
'FID': False,
|
23 |
+
'FID_clip': False,
|
24 |
+
'CLIPScore': False,
|
25 |
+
'CountTokenLength': False,
|
26 |
+
'ratio_post_processed': True,
|
27 |
+
'ratio_non_compiling': True,
|
28 |
+
'DinoScore': True,
|
29 |
+
}
|
30 |
+
self.config = config or default_config
|
31 |
+
|
32 |
+
self.metrics = {
|
33 |
+
'L2': L2DistanceCalculator,
|
34 |
+
'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True),
|
35 |
+
'LPIPS': LPIPSDistanceCalculator,
|
36 |
+
'SSIM': SSIMDistanceCalculator,
|
37 |
+
'FID': lambda: FIDCalculator(model_name='InceptionV3'),
|
38 |
+
'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'),
|
39 |
+
'CLIPScore': CLIPScoreCalculator,
|
40 |
+
'CountTokenLength': CountTokenLength,
|
41 |
+
'ratio_post_processed': AverageMeter,
|
42 |
+
'ratio_non_compiling': AverageMeter,
|
43 |
+
'DinoScore': DINOScoreCalculator,
|
44 |
+
}
|
45 |
+
|
46 |
+
self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)}
|
47 |
+
|
48 |
+
def reset(self):
|
49 |
+
for metric in self.active_metrics.values():
|
50 |
+
metric.reset()
|
51 |
+
|
52 |
+
def batch_contains_raster(self, batch):
|
53 |
+
return "gt_im" in batch and "gen_im" in batch
|
54 |
+
|
55 |
+
def batch_contains_svg(self, batch):
|
56 |
+
return "gt_svg" in batch and "gen_svg" in batch
|
57 |
+
|
58 |
+
def calculate_metrics(self, batch, update=True):
|
59 |
+
if not self.batch_contains_raster(batch):
|
60 |
+
batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]]
|
61 |
+
batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]]
|
62 |
+
|
63 |
+
avg_results_dict = {}
|
64 |
+
all_results_dict = {}
|
65 |
+
|
66 |
+
def get_sample_id(json_item):
|
67 |
+
return json_item.get('outpath_filename') or json_item.get('sample_id')
|
68 |
+
|
69 |
+
# initialize all_results_dict
|
70 |
+
for i, json_item in enumerate(batch['json']):
|
71 |
+
sample_id = get_sample_id(json_item)
|
72 |
+
if sample_id is None:
|
73 |
+
raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]")
|
74 |
+
all_results_dict[sample_id] = {}
|
75 |
+
|
76 |
+
for metric_name, metric in self.active_metrics.items():
|
77 |
+
print(f"Calculating {metric_name}...")
|
78 |
+
|
79 |
+
# Handle metrics that return both average and per-sample results
|
80 |
+
if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']:
|
81 |
+
avg_result, list_result = metric.calculate_score(batch, update=update)
|
82 |
+
avg_results_dict[metric_name] = avg_result
|
83 |
+
|
84 |
+
# Store individual results
|
85 |
+
for i, result in enumerate(list_result):
|
86 |
+
sample_id = get_sample_id(batch['json'][i])
|
87 |
+
all_results_dict[sample_id][metric_name] = result
|
88 |
+
|
89 |
+
# Handle FID metrics that only return average
|
90 |
+
elif metric_name in ['FID', 'FID_clip']:
|
91 |
+
avg_results_dict[metric_name] = metric.calculate_score(batch)
|
92 |
+
|
93 |
+
# Handle other metrics (ratio metrics)
|
94 |
+
else:
|
95 |
+
self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict)
|
96 |
+
|
97 |
+
metric.reset()
|
98 |
+
print("Average results: \n", avg_results_dict)
|
99 |
+
return avg_results_dict, all_results_dict
|
100 |
+
|
101 |
+
def calculate_fid(self, batch):
|
102 |
+
if not self.batch_contains_raster(batch):
|
103 |
+
batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]]
|
104 |
+
batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]]
|
105 |
+
|
106 |
+
return self.active_metrics['FID'].calculate_score(batch).item()
|
107 |
+
|
108 |
+
def get_average_metrics(self):
|
109 |
+
metrics = {}
|
110 |
+
for metric_name, metric in self.active_metrics.items():
|
111 |
+
if hasattr(metric, 'avg'):
|
112 |
+
metrics[metric_name] = metric.avg
|
113 |
+
elif hasattr(metric, 'get_average_score'):
|
114 |
+
metrics[metric_name] = metric.get_average_score()
|
115 |
+
return metrics
|
116 |
+
|
117 |
+
def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict):
|
118 |
+
"""Helper method to handle ratio-based metrics."""
|
119 |
+
metric_key = metric_name.replace('avg_', '').replace('ratio_', '')
|
120 |
+
|
121 |
+
for item in batch['json']:
|
122 |
+
sample_id = get_sample_id(item)
|
123 |
+
value = item[metric_key]
|
124 |
+
all_results_dict[sample_id][metric_name] = value
|
125 |
+
metric.update(value, 1)
|
126 |
+
|
127 |
+
avg_results_dict[metric_name] = metric.avg
|
starvector/metrics/util.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -------------- Metrics --------------
|
3 |
+
class AverageMeter(object):
|
4 |
+
"""Computes and stores the average and current value"""
|
5 |
+
|
6 |
+
def __init__(self):
|
7 |
+
self.reset()
|
8 |
+
|
9 |
+
def reset(self):
|
10 |
+
self.val = 0
|
11 |
+
self.avg = 0
|
12 |
+
self.sum = 0
|
13 |
+
self.count = 0
|
14 |
+
|
15 |
+
def update(self, val, n=1):
|
16 |
+
self.val = val
|
17 |
+
self.sum += val * n
|
18 |
+
self.count += n
|
19 |
+
self.avg = self.sum / self.count
|
20 |
+
|
starvector/model/adapters/adapter.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.init as init
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class Swish(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super(Swish, self).__init__()
|
8 |
+
|
9 |
+
def forward(self, x):
|
10 |
+
return x * torch.sigmoid(x)
|
11 |
+
|
12 |
+
class Adapter(nn.Module):
|
13 |
+
def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1):
|
14 |
+
super().__init__()
|
15 |
+
self.query_length = query_length
|
16 |
+
self.dropout_prob = dropout_prob
|
17 |
+
self.adapter_norm = adapter_norm
|
18 |
+
|
19 |
+
self.dropout = nn.Dropout(p=self.dropout_prob)
|
20 |
+
|
21 |
+
self.c_fc = nn.Linear(input_size, input_size*2)
|
22 |
+
self.act = Swish()
|
23 |
+
self.c_proj = nn.Linear(input_size*2, output_size)
|
24 |
+
|
25 |
+
if adapter_norm == "layer_norm":
|
26 |
+
self.norm = nn.LayerNorm([self.query_length, output_size])
|
27 |
+
elif adapter_norm == "batch_norm":
|
28 |
+
self.norm = nn.BatchNorm1d(self.query_length)
|
29 |
+
|
30 |
+
self.init_type = init_type.lower()
|
31 |
+
self._initialize_weights()
|
32 |
+
|
33 |
+
def forward(self, hidden_states):
|
34 |
+
hidden_states = self.dropout(hidden_states)
|
35 |
+
hidden_states = self.c_fc(hidden_states)
|
36 |
+
hidden_states = self.act(hidden_states)
|
37 |
+
hidden_states = self.c_proj(hidden_states)
|
38 |
+
hidden_states = self.norm(hidden_states)
|
39 |
+
return hidden_states
|
40 |
+
|
41 |
+
def _initialize_weights(self):
|
42 |
+
for m in self.modules():
|
43 |
+
if isinstance(m, nn.Linear):
|
44 |
+
if self.init_type == "glorot":
|
45 |
+
init.xavier_uniform_(m.weight)
|
46 |
+
if m.bias is not None:
|
47 |
+
init.constant_(m.bias, 0)
|
48 |
+
elif self.init_type == "normal":
|
49 |
+
init.normal_(m.weight, mean=0, std=0.01)
|
50 |
+
if m.bias is not None:
|
51 |
+
init.constant_(m.bias, 0)
|
52 |
+
else:
|
53 |
+
raise ValueError("Invalid initialization type specified.")
|
starvector/model/builder.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig
|
3 |
+
from starvector.data.base import ImageTrainProcessor
|
4 |
+
from starvector.util import dtype_mapping
|
5 |
+
from transformers import AutoConfig
|
6 |
+
|
7 |
+
def load_pretrained_model(model_path, device="cuda", **kwargs):
|
8 |
+
model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device)
|
9 |
+
tokenizer = model.model.svg_transformer.tokenizer
|
10 |
+
image_processor = ImageTrainProcessor()
|
11 |
+
context_len = model.model.query_length + model.model.max_length
|
12 |
+
return tokenizer, model, image_processor, context_len
|
13 |
+
|
14 |
+
def model_builder(config):
|
15 |
+
model_name = config.model.get("model_name", False)
|
16 |
+
|
17 |
+
args = {
|
18 |
+
"task": config.model.task,
|
19 |
+
"train_image_encoder": config.training.train_image_encoder,
|
20 |
+
"ignore_mismatched_sizes": True,
|
21 |
+
"starcoder_model_name": config.model.starcoder_model_name,
|
22 |
+
"train_LLM": config.training.train_LLM,
|
23 |
+
"torch_dtype": dtype_mapping[config.training.model_precision],
|
24 |
+
"transformer_layer_cls": config.model.get("transformer_layer_cls", False),
|
25 |
+
"use_cache": config.model.use_cache,
|
26 |
+
}
|
27 |
+
if model_name:
|
28 |
+
model = StarVectorForCausalLM.from_pretrained(model_name, **args)
|
29 |
+
else:
|
30 |
+
starcoder_model_config = AutoConfig.from_pretrained(config.model.starcoder_model_name)
|
31 |
+
|
32 |
+
starvector_config = StarVectorConfig(
|
33 |
+
max_length_train=config.model.max_length,
|
34 |
+
image_encoder_type=config.model.image_encoder_type,
|
35 |
+
use_flash_attn=config.model.use_flash_attn,
|
36 |
+
adapter_norm=config.model.adapter_norm,
|
37 |
+
starcoder_model_name=config.model.starcoder_model_name,
|
38 |
+
torch_dtype=dtype_mapping[config.training.model_precision],
|
39 |
+
num_attention_heads=starcoder_model_config.num_attention_heads,
|
40 |
+
num_hidden_layers=starcoder_model_config.num_hidden_layers,
|
41 |
+
vocab_size=starcoder_model_config.vocab_size,
|
42 |
+
hidden_size=starcoder_model_config.hidden_size,
|
43 |
+
num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None),
|
44 |
+
)
|
45 |
+
model = StarVectorForCausalLM(starvector_config, **args)
|
46 |
+
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
starvector/model/gpt_bigcode/__init__.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import TYPE_CHECKING
|
16 |
+
|
17 |
+
from transformers.utils import (
|
18 |
+
OptionalDependencyNotAvailable,
|
19 |
+
_LazyModule,
|
20 |
+
is_torch_available,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
_import_structure = {
|
25 |
+
"configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"],
|
26 |
+
}
|
27 |
+
|
28 |
+
try:
|
29 |
+
if not is_torch_available():
|
30 |
+
raise OptionalDependencyNotAvailable()
|
31 |
+
except OptionalDependencyNotAvailable:
|
32 |
+
pass
|
33 |
+
else:
|
34 |
+
_import_structure["modeling_gpt_bigcode"] = [
|
35 |
+
"GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST",
|
36 |
+
"GPTBigCodeForSequenceClassification",
|
37 |
+
"GPTBigCodeForTokenClassification",
|
38 |
+
"GPTBigCodeForCausalLM",
|
39 |
+
"GPTBigCodeModel",
|
40 |
+
"GPTBigCodePreTrainedModel",
|
41 |
+
]
|
42 |
+
|
43 |
+
if TYPE_CHECKING:
|
44 |
+
from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig
|
45 |
+
|
46 |
+
try:
|
47 |
+
if not is_torch_available():
|
48 |
+
raise OptionalDependencyNotAvailable()
|
49 |
+
except OptionalDependencyNotAvailable:
|
50 |
+
pass
|
51 |
+
else:
|
52 |
+
from .modeling_gpt_bigcode import (
|
53 |
+
GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST,
|
54 |
+
GPTBigCodeForCausalLM,
|
55 |
+
GPTBigCodeForSequenceClassification,
|
56 |
+
GPTBigCodeForTokenClassification,
|
57 |
+
GPTBigCodeModel,
|
58 |
+
GPTBigCodePreTrainedModel,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
else:
|
63 |
+
import sys
|
64 |
+
|
65 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
starvector/model/gpt_bigcode/configuration_gpt_bigcode.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The BigCode team and HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" GPTBigCode configuration"""
|
16 |
+
from transformers.configuration_utils import PretrainedConfig
|
17 |
+
from transformers.utils import logging
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
class GPTBigCodeConfig(PretrainedConfig):
|
27 |
+
"""
|
28 |
+
This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a
|
29 |
+
GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a
|
30 |
+
configuration with the defaults will yield a similar configuration to that of the GPTBigCode
|
31 |
+
[gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture.
|
32 |
+
|
33 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
34 |
+
documentation from [`PretrainedConfig`] for more information.
|
35 |
+
|
36 |
+
|
37 |
+
Args:
|
38 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
39 |
+
Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
|
40 |
+
`inputs_ids` passed when calling [`GPTBigCodeModel`].
|
41 |
+
n_positions (`int`, *optional*, defaults to 1024):
|
42 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
43 |
+
just in case (e.g., 512 or 1024 or 2048).
|
44 |
+
n_embd (`int`, *optional*, defaults to 768):
|
45 |
+
Dimensionality of the embeddings and hidden states.
|
46 |
+
n_layer (`int`, *optional*, defaults to 12):
|
47 |
+
Number of hidden layers in the Transformer encoder.
|
48 |
+
n_head (`int`, *optional*, defaults to 12):
|
49 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
50 |
+
n_inner (`int`, *optional*, defaults to None):
|
51 |
+
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
|
52 |
+
activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
53 |
+
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
|
54 |
+
"gelu_pytorch_tanh"]`.
|
55 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
56 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
57 |
+
embd_pdrop (`float`, *optional*, defaults to 0.1):
|
58 |
+
The dropout ratio for the embeddings.
|
59 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
60 |
+
The dropout ratio for the attention.
|
61 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
62 |
+
The epsilon to use in the layer normalization layers.
|
63 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
64 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
65 |
+
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
66 |
+
Scale attention weights by dividing by sqrt(hidden_size)..
|
67 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
69 |
+
attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
|
70 |
+
Whether to call the fused softmax in float32.
|
71 |
+
scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
|
72 |
+
Whether to scale the attention softmax in float32.
|
73 |
+
attention_type (`bool`, *optional*, defaults to `True`):
|
74 |
+
Whether to use Multi-Query Attion (`True`) or Multi-Head Attention (`False`).
|
75 |
+
Example:
|
76 |
+
|
77 |
+
```python
|
78 |
+
>>> from transformers import GPTBigCodeConfig, GPTBigCodeModel
|
79 |
+
|
80 |
+
>>> # Initializing a GPTBigCode configuration
|
81 |
+
>>> configuration = GPTBigCodeConfig()
|
82 |
+
|
83 |
+
>>> # Initializing a model (with random weights) from the configuration
|
84 |
+
>>> model = GPTBigCodeModel(configuration)
|
85 |
+
|
86 |
+
>>> # Accessing the model configuration
|
87 |
+
>>> configuration = model.config
|
88 |
+
```"""
|
89 |
+
|
90 |
+
model_type = "gpt_bigcode"
|
91 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
92 |
+
attribute_map = {
|
93 |
+
"hidden_size": "n_embd",
|
94 |
+
"max_position_embeddings": "n_positions",
|
95 |
+
"num_attention_heads": "n_head",
|
96 |
+
"num_hidden_layers": "n_layer",
|
97 |
+
}
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
vocab_size=50257,
|
102 |
+
n_positions=1024,
|
103 |
+
n_embd=768,
|
104 |
+
n_layer=12,
|
105 |
+
n_head=12,
|
106 |
+
n_inner=None,
|
107 |
+
activation_function="gelu_pytorch_tanh",
|
108 |
+
resid_pdrop=0.1,
|
109 |
+
embd_pdrop=0.1,
|
110 |
+
attn_pdrop=0.1,
|
111 |
+
layer_norm_epsilon=1e-5,
|
112 |
+
initializer_range=0.02,
|
113 |
+
scale_attn_weights=True,
|
114 |
+
use_cache=True,
|
115 |
+
bos_token_id=50256,
|
116 |
+
eos_token_id=50256,
|
117 |
+
attention_softmax_in_fp32=True,
|
118 |
+
scale_attention_softmax_in_fp32=True,
|
119 |
+
multi_query=True,
|
120 |
+
**kwargs,
|
121 |
+
):
|
122 |
+
self.vocab_size = vocab_size
|
123 |
+
self.n_positions = n_positions
|
124 |
+
self.n_embd = n_embd
|
125 |
+
self.n_layer = n_layer
|
126 |
+
self.n_head = n_head
|
127 |
+
self.n_inner = n_inner
|
128 |
+
self.activation_function = activation_function
|
129 |
+
self.resid_pdrop = resid_pdrop
|
130 |
+
self.embd_pdrop = embd_pdrop
|
131 |
+
self.attn_pdrop = attn_pdrop
|
132 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
133 |
+
self.initializer_range = initializer_range
|
134 |
+
self.scale_attn_weights = scale_attn_weights
|
135 |
+
self.use_cache = use_cache
|
136 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
137 |
+
self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
|
138 |
+
self.multi_query = multi_query
|
139 |
+
|
140 |
+
self.bos_token_id = bos_token_id
|
141 |
+
self.eos_token_id = eos_token_id
|
142 |
+
|
143 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
starvector/model/gpt_bigcode/modeling_gpt_bigcode.py
ADDED
@@ -0,0 +1,1502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""PyTorch GPTBigCode model."""
|
15 |
+
import math
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
26 |
+
from transformers.modeling_outputs import (
|
27 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
28 |
+
CausalLMOutputWithCrossAttentions,
|
29 |
+
SequenceClassifierOutputWithPast,
|
30 |
+
TokenClassifierOutput,
|
31 |
+
)
|
32 |
+
from transformers.modeling_utils import PreTrainedModel
|
33 |
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
34 |
+
from transformers.utils import (
|
35 |
+
add_code_sample_docstrings,
|
36 |
+
add_start_docstrings,
|
37 |
+
add_start_docstrings_to_model_forward,
|
38 |
+
is_flash_attn_2_available,
|
39 |
+
is_flash_attn_greater_or_equal_2_10,
|
40 |
+
logging,
|
41 |
+
)
|
42 |
+
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig
|
43 |
+
|
44 |
+
|
45 |
+
if is_flash_attn_2_available():
|
46 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
47 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
48 |
+
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder"
|
53 |
+
_CONFIG_FOR_DOC = "GPTBigCodeConfig"
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
# Fused kernels
|
58 |
+
# Use separate functions for each case because conditionals prevent kernel fusion.
|
59 |
+
# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
|
60 |
+
# Is it doable without writing 32 functions?
|
61 |
+
@torch.jit.script
|
62 |
+
def upcast_masked_softmax(
|
63 |
+
x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
|
64 |
+
):
|
65 |
+
input_dtype = x.dtype
|
66 |
+
x = x.to(softmax_dtype) * scale
|
67 |
+
x = torch.where(mask, x, mask_value)
|
68 |
+
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
@torch.jit.script
|
73 |
+
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
|
74 |
+
input_dtype = x.dtype
|
75 |
+
x = x.to(softmax_dtype) * scale
|
76 |
+
x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
@torch.jit.script
|
81 |
+
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
|
82 |
+
x = torch.where(mask, x, mask_value)
|
83 |
+
x = torch.nn.functional.softmax(x, dim=-1)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
88 |
+
def _get_unpad_data(attention_mask):
|
89 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
90 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
91 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
92 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
93 |
+
return (
|
94 |
+
indices,
|
95 |
+
cu_seqlens,
|
96 |
+
max_seqlen_in_batch,
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
class GPTBigCodeAttention(nn.Module):
|
101 |
+
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
102 |
+
super().__init__()
|
103 |
+
self.config = config
|
104 |
+
|
105 |
+
self.mask_value = None
|
106 |
+
self.multi_query = config.multi_query
|
107 |
+
self.embed_dim = config.hidden_size
|
108 |
+
self.num_heads = config.num_attention_heads
|
109 |
+
self.head_dim = self.embed_dim // self.num_heads
|
110 |
+
self.kv_heads = 1 if self.multi_query else self.num_heads
|
111 |
+
self.kv_dim = self.kv_heads * self.head_dim
|
112 |
+
self.split_size = self.embed_dim
|
113 |
+
self.is_causal = True
|
114 |
+
|
115 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
116 |
+
raise ValueError(
|
117 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
118 |
+
f" {self.num_heads})."
|
119 |
+
)
|
120 |
+
|
121 |
+
self.scale_attn_weights = config.scale_attn_weights
|
122 |
+
self.is_cross_attention = is_cross_attention
|
123 |
+
|
124 |
+
self.layer_idx = layer_idx
|
125 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
126 |
+
self.scale_attention_softmax_in_fp32 = (
|
127 |
+
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
|
128 |
+
)
|
129 |
+
self.attn_pdrop = config.attn_pdrop
|
130 |
+
|
131 |
+
if self.is_cross_attention:
|
132 |
+
if self.multi_query:
|
133 |
+
raise NotImplementedError("Multi-Query Attention not supported for cross_attention")
|
134 |
+
|
135 |
+
self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim)
|
136 |
+
self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
|
137 |
+
else:
|
138 |
+
self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
|
139 |
+
|
140 |
+
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
141 |
+
|
142 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
143 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
144 |
+
|
145 |
+
def _get_mask_value(self, device, dtype):
|
146 |
+
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
|
147 |
+
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
|
148 |
+
self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
|
149 |
+
return self.mask_value
|
150 |
+
|
151 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
152 |
+
dtype = query.dtype
|
153 |
+
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
|
154 |
+
upcast = dtype != softmax_dtype
|
155 |
+
|
156 |
+
unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
|
157 |
+
scale_factor = unscale**-1
|
158 |
+
if self.scale_attn_weights:
|
159 |
+
scale_factor /= self.head_dim**0.5
|
160 |
+
|
161 |
+
# MQA models: (batch_size, query_length, num_heads * head_dim)
|
162 |
+
# MHA models: (batch_size, num_heads, query_length, head_dim)
|
163 |
+
query_shape = query.shape
|
164 |
+
batch_size = query_shape[0]
|
165 |
+
key_length = key.size(-1)
|
166 |
+
if self.multi_query:
|
167 |
+
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
|
168 |
+
# -> (batch_size, query_length, num_heads, key_length)
|
169 |
+
query_length = query_shape[1]
|
170 |
+
attn_shape = (batch_size, query_length, self.num_heads, key_length)
|
171 |
+
attn_view = (batch_size, query_length * self.num_heads, key_length)
|
172 |
+
# No copy needed for MQA 2, or when layer_past is provided.
|
173 |
+
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
|
174 |
+
else:
|
175 |
+
# (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
|
176 |
+
# -> (batch_size, num_heads, query_length, key_length)
|
177 |
+
query_length = query_shape[2]
|
178 |
+
attn_shape = (batch_size, self.num_heads, query_length, key_length)
|
179 |
+
attn_view = (batch_size * self.num_heads, query_length, key_length)
|
180 |
+
# Always copies
|
181 |
+
query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
|
182 |
+
# No copy when layer_past is provided.
|
183 |
+
key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
|
184 |
+
|
185 |
+
attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
|
186 |
+
if query.device.type == "cpu":
|
187 |
+
# This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
|
188 |
+
# The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
|
189 |
+
# but the fix has not been released as of pytorch version 2.0.0.
|
190 |
+
attn_weights = torch.zeros_like(attn_weights)
|
191 |
+
beta = 1
|
192 |
+
else:
|
193 |
+
beta = 0
|
194 |
+
attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
|
195 |
+
|
196 |
+
if upcast:
|
197 |
+
# Use a fused kernel to prevent a large overhead from casting and scaling.
|
198 |
+
# Sub-optimal when the key length is not a multiple of 8.
|
199 |
+
if attention_mask is None:
|
200 |
+
attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
|
201 |
+
else:
|
202 |
+
mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
|
203 |
+
attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
|
204 |
+
else:
|
205 |
+
if attention_mask is not None:
|
206 |
+
mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
|
207 |
+
|
208 |
+
# The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
|
209 |
+
attn_weights = torch.where(attention_mask, attn_weights, mask_value)
|
210 |
+
|
211 |
+
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
212 |
+
|
213 |
+
attn_weights = self.attn_dropout(attn_weights)
|
214 |
+
|
215 |
+
# Mask heads if we want to
|
216 |
+
if head_mask is not None:
|
217 |
+
if self.multi_query:
|
218 |
+
head_mask = head_mask.transpose(1, 2)
|
219 |
+
attn_weights = attn_weights * head_mask
|
220 |
+
|
221 |
+
if self.multi_query:
|
222 |
+
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
|
223 |
+
else:
|
224 |
+
attn_output = torch.matmul(attn_weights, value)
|
225 |
+
|
226 |
+
return attn_output, attn_weights
|
227 |
+
|
228 |
+
def forward(
|
229 |
+
self,
|
230 |
+
hidden_states: torch.Tensor,
|
231 |
+
layer_past: Optional[torch.Tensor] = None,
|
232 |
+
attention_mask: Optional[torch.Tensor] = None,
|
233 |
+
head_mask: Optional[torch.Tensor] = None,
|
234 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
235 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
236 |
+
use_cache: Optional[bool] = False,
|
237 |
+
output_attentions: Optional[bool] = False,
|
238 |
+
) -> Union[
|
239 |
+
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
240 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
241 |
+
]:
|
242 |
+
if encoder_hidden_states is not None:
|
243 |
+
if not hasattr(self, "q_attn") or not self.is_cross_attention:
|
244 |
+
raise ValueError(
|
245 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
246 |
+
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
|
247 |
+
)
|
248 |
+
|
249 |
+
query = self.q_attn(hidden_states)
|
250 |
+
key_value = self.c_attn(encoder_hidden_states)
|
251 |
+
attention_mask = encoder_attention_mask
|
252 |
+
elif self.multi_query:
|
253 |
+
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
|
254 |
+
else:
|
255 |
+
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
|
256 |
+
# i.e., the memory layout is not the same as GPT2.
|
257 |
+
# This makes the concatenation with past_key_value more efficient.
|
258 |
+
query, key_value = (
|
259 |
+
self.c_attn(hidden_states)
|
260 |
+
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
|
261 |
+
.transpose(1, 2)
|
262 |
+
.split((self.head_dim, 2 * self.head_dim), dim=3)
|
263 |
+
)
|
264 |
+
|
265 |
+
if layer_past is not None:
|
266 |
+
key_value = torch.cat((layer_past, key_value), dim=-2)
|
267 |
+
present = key_value if use_cache else None
|
268 |
+
|
269 |
+
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
270 |
+
|
271 |
+
attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
|
272 |
+
|
273 |
+
if not self.multi_query:
|
274 |
+
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
|
275 |
+
attn_output = self.c_proj(attn_output)
|
276 |
+
attn_output = self.resid_dropout(attn_output)
|
277 |
+
|
278 |
+
outputs = (attn_output, present)
|
279 |
+
if output_attentions:
|
280 |
+
if self.multi_query:
|
281 |
+
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
|
282 |
+
attn_weights = attn_weights.transpose(1, 2)
|
283 |
+
outputs += (attn_weights,)
|
284 |
+
|
285 |
+
return outputs # a, present, (attentions)
|
286 |
+
|
287 |
+
|
288 |
+
class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
|
289 |
+
"""
|
290 |
+
GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
|
291 |
+
stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
|
292 |
+
API of flash attention and deal with padding tokens in case the input contains any of them.
|
293 |
+
"""
|
294 |
+
|
295 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
296 |
+
def __init__(self, *args, **kwargs):
|
297 |
+
super().__init__(*args, **kwargs)
|
298 |
+
|
299 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
300 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
301 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
302 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
303 |
+
|
304 |
+
def forward(
|
305 |
+
self,
|
306 |
+
hidden_states: torch.Tensor,
|
307 |
+
layer_past: Optional[torch.Tensor] = None,
|
308 |
+
attention_mask: Optional[torch.Tensor] = None,
|
309 |
+
head_mask: Optional[torch.Tensor] = None,
|
310 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
311 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
312 |
+
use_cache: Optional[bool] = False,
|
313 |
+
output_attentions: Optional[bool] = False,
|
314 |
+
) -> Union[
|
315 |
+
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
316 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
317 |
+
]:
|
318 |
+
if encoder_hidden_states is not None:
|
319 |
+
if not hasattr(self, "q_attn") or not self.is_cross_attention:
|
320 |
+
raise ValueError(
|
321 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
322 |
+
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
|
323 |
+
)
|
324 |
+
|
325 |
+
query = self.q_attn(hidden_states)
|
326 |
+
key_value = self.c_attn(encoder_hidden_states)
|
327 |
+
attention_mask = encoder_attention_mask
|
328 |
+
elif self.multi_query:
|
329 |
+
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
|
330 |
+
else:
|
331 |
+
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
|
332 |
+
# i.e., the memory layout is not the same as GPT2.
|
333 |
+
# This makes the concatenation with past_key_value more efficient.
|
334 |
+
query, key_value = (
|
335 |
+
self.c_attn(hidden_states)
|
336 |
+
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
|
337 |
+
.transpose(1, 2)
|
338 |
+
.split((self.head_dim, 2 * self.head_dim), dim=3)
|
339 |
+
)
|
340 |
+
|
341 |
+
if layer_past is not None:
|
342 |
+
key_value = torch.cat((layer_past, key_value), dim=-2)
|
343 |
+
present = key_value if use_cache else None
|
344 |
+
|
345 |
+
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
346 |
+
|
347 |
+
# Flash attention requires the input to have the shape
|
348 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
349 |
+
if self.multi_query:
|
350 |
+
batch_size, query_length, _ = query.shape
|
351 |
+
query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
|
352 |
+
key = key.unsqueeze(2)
|
353 |
+
value = value.unsqueeze(2)
|
354 |
+
else:
|
355 |
+
query_length = query.shape[2]
|
356 |
+
batch_size, _, tgt, _ = key.shape
|
357 |
+
query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
|
358 |
+
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
|
359 |
+
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
|
360 |
+
|
361 |
+
attn_dropout = self.attn_pdrop if self.training else 0.0
|
362 |
+
|
363 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
364 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
365 |
+
# cast them back in float16 just to be sure everything works as expected.
|
366 |
+
input_dtype = query.dtype
|
367 |
+
if input_dtype == torch.float32:
|
368 |
+
if torch.is_autocast_enabled():
|
369 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
370 |
+
# Handle the case where the model is quantized
|
371 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
372 |
+
target_dtype = self.config._pre_quantization_dtype
|
373 |
+
else:
|
374 |
+
target_dtype = self.c_attn.weight.dtype
|
375 |
+
|
376 |
+
logger.warning_once(
|
377 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
378 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
379 |
+
f" {target_dtype}."
|
380 |
+
)
|
381 |
+
query = query.to(target_dtype)
|
382 |
+
key = key.to(target_dtype)
|
383 |
+
value = value.to(target_dtype)
|
384 |
+
|
385 |
+
attn_output = self._flash_attention_forward(
|
386 |
+
query, key, value, attention_mask, query_length, dropout=attn_dropout
|
387 |
+
)
|
388 |
+
|
389 |
+
attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
|
390 |
+
attn_output = self.c_proj(attn_weights_reshaped)
|
391 |
+
attn_output = self.resid_dropout(attn_output)
|
392 |
+
|
393 |
+
outputs = (attn_output, present)
|
394 |
+
|
395 |
+
if output_attentions:
|
396 |
+
if self.multi_query:
|
397 |
+
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
|
398 |
+
attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2)
|
399 |
+
else:
|
400 |
+
attn_weights_reshaped = None
|
401 |
+
|
402 |
+
outputs += (attn_weights_reshaped,)
|
403 |
+
|
404 |
+
return outputs # a, present, (attentions)
|
405 |
+
|
406 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
407 |
+
def _flash_attention_forward(
|
408 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
409 |
+
):
|
410 |
+
"""
|
411 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
412 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
413 |
+
|
414 |
+
Args:
|
415 |
+
query_states (`torch.Tensor`):
|
416 |
+
Input query states to be passed to Flash Attention API
|
417 |
+
key_states (`torch.Tensor`):
|
418 |
+
Input key states to be passed to Flash Attention API
|
419 |
+
value_states (`torch.Tensor`):
|
420 |
+
Input value states to be passed to Flash Attention API
|
421 |
+
attention_mask (`torch.Tensor`):
|
422 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
423 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
424 |
+
dropout (`float`):
|
425 |
+
Attention dropout
|
426 |
+
softmax_scale (`float`, *optional*):
|
427 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
428 |
+
"""
|
429 |
+
if not self._flash_attn_uses_top_left_mask:
|
430 |
+
causal = self.is_causal
|
431 |
+
else:
|
432 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
433 |
+
causal = self.is_causal and query_length != 1
|
434 |
+
|
435 |
+
# Contains at least one padding token in the sequence
|
436 |
+
if attention_mask is not None:
|
437 |
+
batch_size = query_states.shape[0]
|
438 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
439 |
+
query_states, key_states, value_states, attention_mask, query_length
|
440 |
+
)
|
441 |
+
|
442 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
443 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
444 |
+
|
445 |
+
attn_output_unpad = flash_attn_varlen_func(
|
446 |
+
query_states,
|
447 |
+
key_states,
|
448 |
+
value_states,
|
449 |
+
cu_seqlens_q=cu_seqlens_q,
|
450 |
+
cu_seqlens_k=cu_seqlens_k,
|
451 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
452 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
453 |
+
dropout_p=dropout,
|
454 |
+
softmax_scale=softmax_scale,
|
455 |
+
causal=causal,
|
456 |
+
)
|
457 |
+
|
458 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
459 |
+
else:
|
460 |
+
attn_output = flash_attn_func(
|
461 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
462 |
+
)
|
463 |
+
|
464 |
+
return attn_output
|
465 |
+
|
466 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
467 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
468 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
469 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
470 |
+
|
471 |
+
key_layer = index_first_axis(
|
472 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
473 |
+
)
|
474 |
+
value_layer = index_first_axis(
|
475 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
476 |
+
)
|
477 |
+
if query_length == kv_seq_len:
|
478 |
+
query_layer = index_first_axis(
|
479 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
480 |
+
)
|
481 |
+
cu_seqlens_q = cu_seqlens_k
|
482 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
483 |
+
indices_q = indices_k
|
484 |
+
elif query_length == 1:
|
485 |
+
max_seqlen_in_batch_q = 1
|
486 |
+
cu_seqlens_q = torch.arange(
|
487 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
488 |
+
) # There is a memcpy here, that is very bad.
|
489 |
+
indices_q = cu_seqlens_q[:-1]
|
490 |
+
query_layer = query_layer.squeeze(1)
|
491 |
+
else:
|
492 |
+
# The -q_len: slice assumes left padding.
|
493 |
+
attention_mask = attention_mask[:, -query_length:]
|
494 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
495 |
+
|
496 |
+
return (
|
497 |
+
query_layer,
|
498 |
+
key_layer,
|
499 |
+
value_layer,
|
500 |
+
indices_q,
|
501 |
+
(cu_seqlens_q, cu_seqlens_k),
|
502 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
503 |
+
)
|
504 |
+
|
505 |
+
|
506 |
+
class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
|
507 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
508 |
+
if head_mask is not None:
|
509 |
+
# The super dispatch is done in the forward.
|
510 |
+
raise ValueError(
|
511 |
+
"PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
|
512 |
+
)
|
513 |
+
|
514 |
+
scale = None
|
515 |
+
if not self.scale_attn_weights:
|
516 |
+
scale = 1
|
517 |
+
|
518 |
+
# MQA models: (batch_size, query_length, num_heads * head_dim)
|
519 |
+
# MHA models: (batch_size, num_heads, query_length, head_dim)
|
520 |
+
query_shape = query.shape
|
521 |
+
batch_size = query_shape[0]
|
522 |
+
key.shape[-2]
|
523 |
+
|
524 |
+
if self.multi_query:
|
525 |
+
query_length = query_shape[1]
|
526 |
+
|
527 |
+
# SDPA requires the dimension [..., sequence_length, head_dim].
|
528 |
+
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
|
529 |
+
|
530 |
+
# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
|
531 |
+
key = key.unsqueeze(1)
|
532 |
+
value = value.unsqueeze(1)
|
533 |
+
|
534 |
+
# Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
|
535 |
+
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
|
536 |
+
# query = [batch_size, num_heads, query_length, head_dim]
|
537 |
+
# key = [batch_size, 1, past_length, head_dim]
|
538 |
+
# value = [batch_size, 1, past_length, head_dim]
|
539 |
+
#
|
540 |
+
# torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
|
541 |
+
if is_torch_greater_or_equal_than_2_2:
|
542 |
+
key = key.expand(-1, self.num_heads, -1, -1)
|
543 |
+
value = value.expand(-1, self.num_heads, -1, -1)
|
544 |
+
else:
|
545 |
+
query_length = query_shape[-1]
|
546 |
+
|
547 |
+
# See the comment above.
|
548 |
+
if query.device.type == "cuda" and attention_mask is not None:
|
549 |
+
query = query.contiguous()
|
550 |
+
key = key.contiguous()
|
551 |
+
value = value.contiguous()
|
552 |
+
|
553 |
+
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
|
554 |
+
query,
|
555 |
+
key,
|
556 |
+
value,
|
557 |
+
attn_mask=attention_mask,
|
558 |
+
dropout_p=self.attn_pdrop if self.training else 0.0,
|
559 |
+
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
|
560 |
+
is_causal=self.is_causal and attention_mask is None and query_length > 1,
|
561 |
+
scale=scale,
|
562 |
+
)
|
563 |
+
|
564 |
+
if self.multi_query:
|
565 |
+
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
|
566 |
+
sdpa_result = sdpa_result.transpose(1, 2)
|
567 |
+
|
568 |
+
# Reshape is kind of expensive here, as it does a memory copy,
|
569 |
+
# but I did not manage to make away without it (logits do not match when using view)
|
570 |
+
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
|
571 |
+
sdpa_result = sdpa_result.reshape(query_shape)
|
572 |
+
|
573 |
+
return sdpa_result, None
|
574 |
+
|
575 |
+
def forward(
|
576 |
+
self,
|
577 |
+
hidden_states: torch.Tensor,
|
578 |
+
layer_past: Optional[torch.Tensor] = None,
|
579 |
+
attention_mask: Optional[torch.Tensor] = None,
|
580 |
+
head_mask: Optional[torch.Tensor] = None,
|
581 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
582 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
583 |
+
use_cache: Optional[bool] = False,
|
584 |
+
output_attentions: Optional[bool] = False,
|
585 |
+
) -> Union[
|
586 |
+
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
587 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
|
588 |
+
]:
|
589 |
+
if encoder_hidden_states is not None:
|
590 |
+
if not hasattr(self, "q_attn") or not self.is_cross_attention:
|
591 |
+
raise ValueError(
|
592 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
593 |
+
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
|
594 |
+
)
|
595 |
+
|
596 |
+
query = self.q_attn(hidden_states)
|
597 |
+
key_value = self.c_attn(encoder_hidden_states)
|
598 |
+
attention_mask = encoder_attention_mask
|
599 |
+
elif self.multi_query:
|
600 |
+
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
|
601 |
+
else:
|
602 |
+
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
|
603 |
+
# i.e., the memory layout is not the same as GPT2.
|
604 |
+
# This makes the concatenation with past_key_value more efficient.
|
605 |
+
query, key_value = (
|
606 |
+
self.c_attn(hidden_states)
|
607 |
+
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
|
608 |
+
.transpose(1, 2)
|
609 |
+
.split((self.head_dim, 2 * self.head_dim), dim=3)
|
610 |
+
)
|
611 |
+
|
612 |
+
if layer_past is not None:
|
613 |
+
key_value = torch.cat((layer_past, key_value), dim=-2)
|
614 |
+
present = key_value if use_cache else None
|
615 |
+
|
616 |
+
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
617 |
+
|
618 |
+
if not output_attentions and head_mask is None:
|
619 |
+
# Difference with the original implementation: there is no need to transpose the key here,
|
620 |
+
# as SDPA expects seq_length to be at index -2 for the key as well
|
621 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
622 |
+
else:
|
623 |
+
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
624 |
+
logger.warning_once(
|
625 |
+
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
|
626 |
+
' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
627 |
+
)
|
628 |
+
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
|
629 |
+
|
630 |
+
if not self.multi_query:
|
631 |
+
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
|
632 |
+
attn_output = self.c_proj(attn_output)
|
633 |
+
attn_output = self.resid_dropout(attn_output)
|
634 |
+
|
635 |
+
outputs = (attn_output, present)
|
636 |
+
if output_attentions:
|
637 |
+
if self.multi_query:
|
638 |
+
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
|
639 |
+
attn_weights = attn_weights.transpose(1, 2)
|
640 |
+
outputs += (attn_weights,)
|
641 |
+
|
642 |
+
return outputs
|
643 |
+
|
644 |
+
|
645 |
+
class GPTBigCodeMLP(nn.Module):
|
646 |
+
def __init__(self, intermediate_size, config):
|
647 |
+
super().__init__()
|
648 |
+
embed_dim = config.hidden_size
|
649 |
+
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
650 |
+
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
651 |
+
self.act = ACT2FN[config.activation_function]
|
652 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
653 |
+
|
654 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
|
655 |
+
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
656 |
+
hidden_states = self.c_fc(hidden_states)
|
657 |
+
hidden_states = self.act(hidden_states)
|
658 |
+
hidden_states = self.c_proj(hidden_states)
|
659 |
+
hidden_states = self.dropout(hidden_states)
|
660 |
+
return hidden_states
|
661 |
+
|
662 |
+
|
663 |
+
GPTBIGCODE_ATTENTION_CLASSES = {
|
664 |
+
"eager": GPTBigCodeAttention,
|
665 |
+
"flash_attention_2": GPTBigCodeFlashAttention2,
|
666 |
+
"sdpa": GPTBigCodeSdpaAttention,
|
667 |
+
}
|
668 |
+
|
669 |
+
|
670 |
+
class GPTBigCodeBlock(nn.Module):
|
671 |
+
def __init__(self, config, layer_idx=None):
|
672 |
+
super().__init__()
|
673 |
+
hidden_size = config.hidden_size
|
674 |
+
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
675 |
+
|
676 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
677 |
+
|
678 |
+
self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
679 |
+
|
680 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
681 |
+
|
682 |
+
if config.add_cross_attention:
|
683 |
+
if config.multi_query:
|
684 |
+
raise NotImplementedError("Cross-attention not implemented for MQA")
|
685 |
+
|
686 |
+
self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
|
687 |
+
config, is_cross_attention=True, layer_idx=layer_idx
|
688 |
+
)
|
689 |
+
|
690 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
691 |
+
|
692 |
+
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
|
693 |
+
|
694 |
+
def forward(
|
695 |
+
self,
|
696 |
+
hidden_states: Optional[Tuple[torch.Tensor]],
|
697 |
+
layer_past: Optional[torch.Tensor] = None,
|
698 |
+
attention_mask: Optional[torch.Tensor] = None,
|
699 |
+
head_mask: Optional[torch.Tensor] = None,
|
700 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
701 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
702 |
+
use_cache: Optional[bool] = False,
|
703 |
+
output_attentions: Optional[bool] = False,
|
704 |
+
) -> Union[
|
705 |
+
Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
706 |
+
]:
|
707 |
+
residual = hidden_states
|
708 |
+
hidden_states = self.ln_1(hidden_states)
|
709 |
+
attn_outputs = self.attn(
|
710 |
+
hidden_states,
|
711 |
+
layer_past=layer_past,
|
712 |
+
attention_mask=attention_mask,
|
713 |
+
head_mask=head_mask,
|
714 |
+
use_cache=use_cache,
|
715 |
+
output_attentions=output_attentions,
|
716 |
+
)
|
717 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
718 |
+
outputs = attn_outputs[1:]
|
719 |
+
# residual connection
|
720 |
+
hidden_states = attn_output + residual
|
721 |
+
|
722 |
+
if encoder_hidden_states is not None:
|
723 |
+
# add one self-attention block for cross-attention
|
724 |
+
if not hasattr(self, "crossattention"):
|
725 |
+
raise ValueError(
|
726 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
727 |
+
"cross-attention layers by setting `config.add_cross_attention=True`"
|
728 |
+
)
|
729 |
+
residual = hidden_states
|
730 |
+
hidden_states = self.ln_cross_attn(hidden_states)
|
731 |
+
cross_attn_outputs = self.crossattention(
|
732 |
+
hidden_states,
|
733 |
+
attention_mask=attention_mask,
|
734 |
+
head_mask=head_mask,
|
735 |
+
encoder_hidden_states=encoder_hidden_states,
|
736 |
+
encoder_attention_mask=encoder_attention_mask,
|
737 |
+
output_attentions=output_attentions,
|
738 |
+
)
|
739 |
+
attn_output = cross_attn_outputs[0]
|
740 |
+
# residual connection
|
741 |
+
hidden_states = residual + attn_output
|
742 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
743 |
+
|
744 |
+
residual = hidden_states
|
745 |
+
hidden_states = self.ln_2(hidden_states)
|
746 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
747 |
+
# residual connection
|
748 |
+
hidden_states = residual + feed_forward_hidden_states
|
749 |
+
|
750 |
+
if use_cache:
|
751 |
+
outputs = (hidden_states,) + outputs
|
752 |
+
else:
|
753 |
+
outputs = (hidden_states,) + outputs[1:]
|
754 |
+
|
755 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
756 |
+
|
757 |
+
|
758 |
+
class GPTBigCodePreTrainedModel(PreTrainedModel):
|
759 |
+
"""
|
760 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
761 |
+
models.
|
762 |
+
"""
|
763 |
+
|
764 |
+
config_class = GPTBigCodeConfig
|
765 |
+
base_model_prefix = "transformer"
|
766 |
+
supports_gradient_checkpointing = True
|
767 |
+
_no_split_modules = ["GPTBigCodeBlock"]
|
768 |
+
_skip_keys_device_placement = "past_key_values"
|
769 |
+
_supports_flash_attn_2 = True
|
770 |
+
_supports_sdpa = True
|
771 |
+
|
772 |
+
def __init__(self, *inputs, **kwargs):
|
773 |
+
super().__init__(*inputs, **kwargs)
|
774 |
+
|
775 |
+
def _init_weights(self, module):
|
776 |
+
"""Initialize the weights."""
|
777 |
+
if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
|
778 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
779 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
780 |
+
# > the weights of residual layers at initialization by a factor of 1/βN where N is the # of residual layers.
|
781 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
782 |
+
#
|
783 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
784 |
+
module.c_proj.weight.data.normal_(
|
785 |
+
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
|
786 |
+
)
|
787 |
+
module.c_proj._is_hf_initialized = True
|
788 |
+
elif isinstance(module, nn.Linear):
|
789 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
790 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
791 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
792 |
+
if module.bias is not None:
|
793 |
+
module.bias.data.zero_()
|
794 |
+
elif isinstance(module, nn.Embedding):
|
795 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
796 |
+
if module.padding_idx is not None:
|
797 |
+
module.weight.data[module.padding_idx].zero_()
|
798 |
+
elif isinstance(module, nn.LayerNorm):
|
799 |
+
module.bias.data.zero_()
|
800 |
+
module.weight.data.fill_(1.0)
|
801 |
+
|
802 |
+
|
803 |
+
GPT_BIGCODE_START_DOCSTRING = r"""
|
804 |
+
|
805 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
806 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
807 |
+
etc.)
|
808 |
+
|
809 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
810 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
811 |
+
and behavior.
|
812 |
+
|
813 |
+
Parameters:
|
814 |
+
config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model.
|
815 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
816 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
817 |
+
"""
|
818 |
+
|
819 |
+
GPT_BIGCODE_INPUTS_DOCSTRING = r"""
|
820 |
+
Args:
|
821 |
+
input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
|
822 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
823 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
824 |
+
sequence tokens in the vocabulary.
|
825 |
+
|
826 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
827 |
+
`input_ids`.
|
828 |
+
|
829 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
830 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
831 |
+
|
832 |
+
[What are input IDs?](../glossary#input-ids)
|
833 |
+
past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`):
|
834 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
835 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
836 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
837 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
838 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
839 |
+
|
840 |
+
- 1 for tokens that are **not masked**,
|
841 |
+
- 0 for tokens that are **masked**.
|
842 |
+
|
843 |
+
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
|
844 |
+
`past_key_values`. In other words, the `attention_mask` always has to have the length:
|
845 |
+
`len(past_key_values) + len(input_ids)`
|
846 |
+
|
847 |
+
[What are attention masks?](../glossary#attention-mask)
|
848 |
+
token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
849 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
850 |
+
1]`:
|
851 |
+
|
852 |
+
- 0 corresponds to a *sentence A* token,
|
853 |
+
- 1 corresponds to a *sentence B* token.
|
854 |
+
|
855 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
856 |
+
position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
857 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
858 |
+
config.max_position_embeddings - 1]`.
|
859 |
+
|
860 |
+
[What are position IDs?](../glossary#position-ids)
|
861 |
+
head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
862 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
863 |
+
|
864 |
+
- 1 indicates the head is **not masked**,
|
865 |
+
- 0 indicates the head is **masked**.
|
866 |
+
|
867 |
+
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
868 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
869 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
870 |
+
model's internal embedding lookup matrix.
|
871 |
+
|
872 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
873 |
+
`past_key_values`).
|
874 |
+
use_cache (`bool`, *optional*):
|
875 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
876 |
+
`past_key_values`).
|
877 |
+
output_attentions (`bool`, *optional*):
|
878 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
879 |
+
tensors for more detail.
|
880 |
+
output_hidden_states (`bool`, *optional*):
|
881 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
882 |
+
more detail.
|
883 |
+
return_dict (`bool`, *optional*):
|
884 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
885 |
+
"""
|
886 |
+
|
887 |
+
|
888 |
+
@add_start_docstrings(
|
889 |
+
"The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.",
|
890 |
+
GPT_BIGCODE_START_DOCSTRING,
|
891 |
+
)
|
892 |
+
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
|
893 |
+
def __init__(self, config):
|
894 |
+
super().__init__(config)
|
895 |
+
self.multi_query = config.multi_query
|
896 |
+
self.embed_dim = config.hidden_size
|
897 |
+
|
898 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
899 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
900 |
+
|
901 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
902 |
+
self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
903 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
904 |
+
|
905 |
+
max_positions = config.max_position_embeddings
|
906 |
+
self.register_buffer(
|
907 |
+
"bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
|
908 |
+
)
|
909 |
+
|
910 |
+
self.gradient_checkpointing = False
|
911 |
+
|
912 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
913 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
914 |
+
|
915 |
+
# Initialize weights and apply final processing
|
916 |
+
self.post_init()
|
917 |
+
|
918 |
+
def get_input_embeddings(self):
|
919 |
+
return self.wte
|
920 |
+
|
921 |
+
def set_input_embeddings(self, new_embeddings):
|
922 |
+
self.wte = new_embeddings
|
923 |
+
|
924 |
+
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
|
925 |
+
@add_code_sample_docstrings(
|
926 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
927 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
928 |
+
config_class=_CONFIG_FOR_DOC,
|
929 |
+
)
|
930 |
+
def forward(
|
931 |
+
self,
|
932 |
+
input_ids: Optional[torch.Tensor] = None,
|
933 |
+
past_key_values: Optional[List[torch.Tensor]] = None,
|
934 |
+
attention_mask: Optional[torch.Tensor] = None,
|
935 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
936 |
+
position_ids: Optional[torch.Tensor] = None,
|
937 |
+
head_mask: Optional[torch.Tensor] = None,
|
938 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
939 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
940 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
941 |
+
use_cache: Optional[bool] = None,
|
942 |
+
output_attentions: Optional[bool] = None,
|
943 |
+
output_hidden_states: Optional[bool] = None,
|
944 |
+
return_dict: Optional[bool] = None,
|
945 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
946 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
947 |
+
output_hidden_states = (
|
948 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
949 |
+
)
|
950 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
951 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
952 |
+
|
953 |
+
if input_ids is not None and inputs_embeds is not None:
|
954 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
955 |
+
elif input_ids is not None:
|
956 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
957 |
+
input_shape = input_ids.size()
|
958 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
959 |
+
batch_size = input_ids.shape[0]
|
960 |
+
elif inputs_embeds is not None:
|
961 |
+
input_shape = inputs_embeds.size()[:-1]
|
962 |
+
batch_size = inputs_embeds.shape[0]
|
963 |
+
else:
|
964 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
965 |
+
|
966 |
+
if batch_size <= 0:
|
967 |
+
raise ValueError("batch_size has to be defined and > 0")
|
968 |
+
|
969 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
970 |
+
|
971 |
+
if token_type_ids is not None:
|
972 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
973 |
+
|
974 |
+
if past_key_values is None:
|
975 |
+
past_length = 0
|
976 |
+
past_key_values = tuple([None] * len(self.h))
|
977 |
+
else:
|
978 |
+
past_length = past_key_values[0].size(-2)
|
979 |
+
|
980 |
+
if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
|
981 |
+
# create position_ids on the fly for batch generation
|
982 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
983 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
984 |
+
if past_length > 0:
|
985 |
+
position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
|
986 |
+
elif position_ids is None:
|
987 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
988 |
+
position_ids = position_ids.unsqueeze(0)
|
989 |
+
|
990 |
+
# Self-attention mask.
|
991 |
+
query_length = input_shape[-1]
|
992 |
+
key_length = past_length + query_length
|
993 |
+
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
|
994 |
+
|
995 |
+
if self._use_flash_attention_2:
|
996 |
+
# 2d mask is passed through the layers
|
997 |
+
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
|
998 |
+
encoder_attention_mask = (
|
999 |
+
encoder_attention_mask.bool()
|
1000 |
+
if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
|
1001 |
+
else None
|
1002 |
+
)
|
1003 |
+
else:
|
1004 |
+
# 4d mask is passed through the layers
|
1005 |
+
if attention_mask is not None:
|
1006 |
+
self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
|
1007 |
+
dtype=torch.bool, device=self_attention_mask.device
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
# MQA models: (batch_size, query_length, n_heads, key_length)
|
1011 |
+
# MHA models: (batch_size, n_heads, query_length, key_length)
|
1012 |
+
self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
|
1013 |
+
|
1014 |
+
if self._use_sdpa and head_mask is None and not output_attentions:
|
1015 |
+
# SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
|
1016 |
+
dtype = self.wte.weight.dtype
|
1017 |
+
min_dtype = torch.finfo(dtype).min
|
1018 |
+
self_attention_mask = torch.where(
|
1019 |
+
self_attention_mask,
|
1020 |
+
torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
|
1021 |
+
torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
1025 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
1026 |
+
if self.multi_query:
|
1027 |
+
# gpt_bigcode using MQA has the bad taste to use a causal mask with shape
|
1028 |
+
# [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
|
1029 |
+
self_attention_mask = self_attention_mask.transpose(1, 2)
|
1030 |
+
|
1031 |
+
if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda":
|
1032 |
+
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
1033 |
+
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
1034 |
+
self_attention_mask = AttentionMaskConverter._unmask_unattended(
|
1035 |
+
self_attention_mask, min_dtype=min_dtype
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
attention_mask = self_attention_mask
|
1039 |
+
|
1040 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
1041 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
1042 |
+
if (
|
1043 |
+
self.config.add_cross_attention
|
1044 |
+
and encoder_hidden_states is not None
|
1045 |
+
and encoder_attention_mask is not None
|
1046 |
+
):
|
1047 |
+
if encoder_attention_mask.dim() == 2:
|
1048 |
+
encoder_attention_mask.unsqueeze(1)
|
1049 |
+
assert encoder_attention_mask.dim() == 3
|
1050 |
+
encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
|
1051 |
+
else:
|
1052 |
+
encoder_attention_mask = None
|
1053 |
+
|
1054 |
+
# Prepare head mask if needed
|
1055 |
+
# 1.0 in head_mask indicate we keep the head
|
1056 |
+
# attention_probs has shape bsz x n_heads x N x N
|
1057 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
1058 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
1059 |
+
|
1060 |
+
if inputs_embeds is None:
|
1061 |
+
inputs_embeds = self.wte(input_ids)
|
1062 |
+
position_embeds = self.wpe(position_ids)
|
1063 |
+
hidden_states = inputs_embeds + position_embeds
|
1064 |
+
|
1065 |
+
if token_type_ids is not None:
|
1066 |
+
token_type_embeds = self.wte(token_type_ids)
|
1067 |
+
hidden_states = hidden_states + token_type_embeds
|
1068 |
+
|
1069 |
+
hidden_states = self.drop(hidden_states)
|
1070 |
+
|
1071 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
1072 |
+
|
1073 |
+
presents = [] if use_cache else None
|
1074 |
+
all_self_attentions = () if output_attentions else None
|
1075 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
1076 |
+
all_hidden_states = () if output_hidden_states else None
|
1077 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
1078 |
+
if output_hidden_states:
|
1079 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1080 |
+
|
1081 |
+
if self.gradient_checkpointing and self.training:
|
1082 |
+
outputs = self._gradient_checkpointing_func(
|
1083 |
+
block.__call__,
|
1084 |
+
hidden_states,
|
1085 |
+
None,
|
1086 |
+
attention_mask,
|
1087 |
+
head_mask[i],
|
1088 |
+
encoder_hidden_states,
|
1089 |
+
encoder_attention_mask,
|
1090 |
+
use_cache,
|
1091 |
+
output_attentions,
|
1092 |
+
)
|
1093 |
+
else:
|
1094 |
+
outputs = block(
|
1095 |
+
hidden_states,
|
1096 |
+
layer_past=layer_past,
|
1097 |
+
attention_mask=attention_mask,
|
1098 |
+
head_mask=head_mask[i],
|
1099 |
+
encoder_hidden_states=encoder_hidden_states,
|
1100 |
+
encoder_attention_mask=encoder_attention_mask,
|
1101 |
+
use_cache=use_cache,
|
1102 |
+
output_attentions=output_attentions,
|
1103 |
+
)
|
1104 |
+
|
1105 |
+
hidden_states = outputs[0]
|
1106 |
+
if use_cache:
|
1107 |
+
presents.append(outputs[1])
|
1108 |
+
|
1109 |
+
if output_attentions:
|
1110 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
1111 |
+
if self.config.add_cross_attention:
|
1112 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
1113 |
+
|
1114 |
+
hidden_states = self.ln_f(hidden_states)
|
1115 |
+
|
1116 |
+
hidden_states = hidden_states.view(output_shape)
|
1117 |
+
# Add last hidden state
|
1118 |
+
if output_hidden_states:
|
1119 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1120 |
+
|
1121 |
+
if not return_dict:
|
1122 |
+
return tuple(
|
1123 |
+
v
|
1124 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
1125 |
+
if v is not None
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1129 |
+
last_hidden_state=hidden_states,
|
1130 |
+
past_key_values=presents,
|
1131 |
+
hidden_states=all_hidden_states,
|
1132 |
+
attentions=all_self_attentions,
|
1133 |
+
cross_attentions=all_cross_attentions,
|
1134 |
+
)
|
1135 |
+
|
1136 |
+
|
1137 |
+
@add_start_docstrings(
|
1138 |
+
"""
|
1139 |
+
The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
1140 |
+
embeddings).
|
1141 |
+
""",
|
1142 |
+
GPT_BIGCODE_START_DOCSTRING,
|
1143 |
+
)
|
1144 |
+
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
1145 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1146 |
+
|
1147 |
+
def __init__(self, config):
|
1148 |
+
super().__init__(config)
|
1149 |
+
self.transformer = GPTBigCodeModel(config)
|
1150 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
1151 |
+
|
1152 |
+
# Initialize weights and apply final processing
|
1153 |
+
self.post_init()
|
1154 |
+
|
1155 |
+
def get_output_embeddings(self):
|
1156 |
+
return self.lm_head
|
1157 |
+
|
1158 |
+
def set_output_embeddings(self, new_embeddings):
|
1159 |
+
self.lm_head = new_embeddings
|
1160 |
+
|
1161 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
1162 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
1163 |
+
# Omit tokens covered by past_key_values
|
1164 |
+
if past_key_values:
|
1165 |
+
if self.config.multi_query:
|
1166 |
+
past_length = past_key_values[0].shape[1]
|
1167 |
+
else:
|
1168 |
+
past_length = past_key_values[0].shape[2]
|
1169 |
+
|
1170 |
+
# Some generation methods already pass only the last input ID
|
1171 |
+
if input_ids.shape[1] > past_length:
|
1172 |
+
remove_prefix_length = past_length
|
1173 |
+
else:
|
1174 |
+
# Default to old behavior: keep only final ID
|
1175 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
1176 |
+
|
1177 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
1178 |
+
if token_type_ids is not None:
|
1179 |
+
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
1180 |
+
|
1181 |
+
attention_mask = kwargs.get("attention_mask", None)
|
1182 |
+
position_ids = kwargs.get("position_ids", None)
|
1183 |
+
|
1184 |
+
if attention_mask is not None and position_ids is None:
|
1185 |
+
# create position_ids on the fly for batch generation
|
1186 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1187 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1188 |
+
if past_key_values:
|
1189 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1190 |
+
else:
|
1191 |
+
position_ids = None
|
1192 |
+
|
1193 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1194 |
+
if inputs_embeds is not None and past_key_values is None:
|
1195 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1196 |
+
else:
|
1197 |
+
model_inputs = {"input_ids": input_ids}
|
1198 |
+
|
1199 |
+
model_inputs.update(
|
1200 |
+
{
|
1201 |
+
"past_key_values": past_key_values,
|
1202 |
+
"use_cache": kwargs.get("use_cache"),
|
1203 |
+
"position_ids": position_ids,
|
1204 |
+
"attention_mask": attention_mask,
|
1205 |
+
"token_type_ids": token_type_ids,
|
1206 |
+
}
|
1207 |
+
)
|
1208 |
+
return model_inputs
|
1209 |
+
|
1210 |
+
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
|
1211 |
+
@add_code_sample_docstrings(
|
1212 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1213 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
1214 |
+
config_class=_CONFIG_FOR_DOC,
|
1215 |
+
)
|
1216 |
+
def forward(
|
1217 |
+
self,
|
1218 |
+
input_ids: Optional[torch.Tensor] = None,
|
1219 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1220 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1221 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1222 |
+
position_ids: Optional[torch.Tensor] = None,
|
1223 |
+
head_mask: Optional[torch.Tensor] = None,
|
1224 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1225 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1226 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1227 |
+
labels: Optional[torch.Tensor] = None,
|
1228 |
+
use_cache: Optional[bool] = None,
|
1229 |
+
output_attentions: Optional[bool] = None,
|
1230 |
+
output_hidden_states: Optional[bool] = None,
|
1231 |
+
return_dict: Optional[bool] = None,
|
1232 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
1233 |
+
r"""
|
1234 |
+
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1235 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1236 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
1237 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
1238 |
+
"""
|
1239 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1240 |
+
|
1241 |
+
transformer_outputs = self.transformer(
|
1242 |
+
input_ids,
|
1243 |
+
past_key_values=past_key_values,
|
1244 |
+
attention_mask=attention_mask,
|
1245 |
+
token_type_ids=token_type_ids,
|
1246 |
+
position_ids=position_ids,
|
1247 |
+
head_mask=head_mask,
|
1248 |
+
inputs_embeds=inputs_embeds,
|
1249 |
+
encoder_hidden_states=encoder_hidden_states,
|
1250 |
+
encoder_attention_mask=encoder_attention_mask,
|
1251 |
+
use_cache=use_cache,
|
1252 |
+
output_attentions=output_attentions,
|
1253 |
+
output_hidden_states=output_hidden_states,
|
1254 |
+
return_dict=return_dict,
|
1255 |
+
)
|
1256 |
+
hidden_states = transformer_outputs[0]
|
1257 |
+
|
1258 |
+
lm_logits = self.lm_head(hidden_states)
|
1259 |
+
|
1260 |
+
loss = None
|
1261 |
+
if labels is not None:
|
1262 |
+
# Shift so that tokens < n predict n
|
1263 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1264 |
+
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
|
1265 |
+
# Flatten the tokens
|
1266 |
+
loss_fct = CrossEntropyLoss()
|
1267 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1268 |
+
|
1269 |
+
if not return_dict:
|
1270 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1271 |
+
return ((loss,) + output) if loss is not None else output
|
1272 |
+
|
1273 |
+
return CausalLMOutputWithCrossAttentions(
|
1274 |
+
loss=loss,
|
1275 |
+
logits=lm_logits,
|
1276 |
+
past_key_values=transformer_outputs.past_key_values,
|
1277 |
+
hidden_states=transformer_outputs.hidden_states,
|
1278 |
+
attentions=transformer_outputs.attentions,
|
1279 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
1280 |
+
)
|
1281 |
+
|
1282 |
+
@staticmethod
|
1283 |
+
def _reorder_cache(
|
1284 |
+
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
1285 |
+
) -> Tuple[Tuple[torch.Tensor]]:
|
1286 |
+
"""
|
1287 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1288 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1289 |
+
beam_idx at every generation step.
|
1290 |
+
"""
|
1291 |
+
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
|
1292 |
+
|
1293 |
+
|
1294 |
+
@add_start_docstrings(
|
1295 |
+
"""
|
1296 |
+
The GPTBigCode Model transformer with a sequence classification head on top (linear layer).
|
1297 |
+
|
1298 |
+
[`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal
|
1299 |
+
models (e.g. GPT-1) do.
|
1300 |
+
|
1301 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1302 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1303 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1304 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1305 |
+
each row of the batch).
|
1306 |
+
""",
|
1307 |
+
GPT_BIGCODE_START_DOCSTRING,
|
1308 |
+
)
|
1309 |
+
class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
|
1310 |
+
def __init__(self, config):
|
1311 |
+
super().__init__(config)
|
1312 |
+
self.num_labels = config.num_labels
|
1313 |
+
self.transformer = GPTBigCodeModel(config)
|
1314 |
+
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
1315 |
+
|
1316 |
+
# Initialize weights and apply final processing
|
1317 |
+
self.post_init()
|
1318 |
+
|
1319 |
+
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
|
1320 |
+
def forward(
|
1321 |
+
self,
|
1322 |
+
input_ids: Optional[torch.Tensor] = None,
|
1323 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1324 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1325 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1326 |
+
position_ids: Optional[torch.Tensor] = None,
|
1327 |
+
head_mask: Optional[torch.Tensor] = None,
|
1328 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1329 |
+
labels: Optional[torch.Tensor] = None,
|
1330 |
+
use_cache: Optional[bool] = None,
|
1331 |
+
output_attentions: Optional[bool] = None,
|
1332 |
+
output_hidden_states: Optional[bool] = None,
|
1333 |
+
return_dict: Optional[bool] = None,
|
1334 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1335 |
+
r"""
|
1336 |
+
labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
|
1337 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1338 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1339 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1340 |
+
"""
|
1341 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1342 |
+
|
1343 |
+
transformer_outputs = self.transformer(
|
1344 |
+
input_ids,
|
1345 |
+
past_key_values=past_key_values,
|
1346 |
+
attention_mask=attention_mask,
|
1347 |
+
token_type_ids=token_type_ids,
|
1348 |
+
position_ids=position_ids,
|
1349 |
+
head_mask=head_mask,
|
1350 |
+
inputs_embeds=inputs_embeds,
|
1351 |
+
use_cache=use_cache,
|
1352 |
+
output_attentions=output_attentions,
|
1353 |
+
output_hidden_states=output_hidden_states,
|
1354 |
+
return_dict=return_dict,
|
1355 |
+
)
|
1356 |
+
hidden_states = transformer_outputs[0]
|
1357 |
+
logits = self.score(hidden_states)
|
1358 |
+
|
1359 |
+
if input_ids is not None:
|
1360 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
1361 |
+
else:
|
1362 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
1363 |
+
|
1364 |
+
assert (
|
1365 |
+
self.config.pad_token_id is not None or batch_size == 1
|
1366 |
+
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
1367 |
+
if self.config.pad_token_id is None:
|
1368 |
+
sequence_lengths = -1
|
1369 |
+
else:
|
1370 |
+
if input_ids is not None:
|
1371 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1372 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1373 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1374 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1375 |
+
else:
|
1376 |
+
sequence_lengths = -1
|
1377 |
+
logger.warning(
|
1378 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
1379 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
1380 |
+
)
|
1381 |
+
|
1382 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1383 |
+
|
1384 |
+
loss = None
|
1385 |
+
if labels is not None:
|
1386 |
+
labels = labels.to(logits.device)
|
1387 |
+
|
1388 |
+
if self.config.problem_type is None:
|
1389 |
+
if self.num_labels == 1:
|
1390 |
+
self.config.problem_type = "regression"
|
1391 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1392 |
+
self.config.problem_type = "single_label_classification"
|
1393 |
+
else:
|
1394 |
+
self.config.problem_type = "multi_label_classification"
|
1395 |
+
|
1396 |
+
if self.config.problem_type == "regression":
|
1397 |
+
loss_fct = MSELoss()
|
1398 |
+
if self.num_labels == 1:
|
1399 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1400 |
+
else:
|
1401 |
+
loss = loss_fct(pooled_logits, labels)
|
1402 |
+
elif self.config.problem_type == "single_label_classification":
|
1403 |
+
loss_fct = CrossEntropyLoss()
|
1404 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1405 |
+
elif self.config.problem_type == "multi_label_classification":
|
1406 |
+
loss_fct = BCEWithLogitsLoss()
|
1407 |
+
loss = loss_fct(pooled_logits, labels)
|
1408 |
+
if not return_dict:
|
1409 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1410 |
+
return ((loss,) + output) if loss is not None else output
|
1411 |
+
|
1412 |
+
return SequenceClassifierOutputWithPast(
|
1413 |
+
loss=loss,
|
1414 |
+
logits=pooled_logits,
|
1415 |
+
past_key_values=transformer_outputs.past_key_values,
|
1416 |
+
hidden_states=transformer_outputs.hidden_states,
|
1417 |
+
attentions=transformer_outputs.attentions,
|
1418 |
+
)
|
1419 |
+
|
1420 |
+
|
1421 |
+
@add_start_docstrings(
|
1422 |
+
"""
|
1423 |
+
GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
|
1424 |
+
for Named-Entity-Recognition (NER) tasks.
|
1425 |
+
""",
|
1426 |
+
GPT_BIGCODE_START_DOCSTRING,
|
1427 |
+
)
|
1428 |
+
class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
|
1429 |
+
def __init__(self, config):
|
1430 |
+
super().__init__(config)
|
1431 |
+
self.num_labels = config.num_labels
|
1432 |
+
|
1433 |
+
self.transformer = GPTBigCodeModel(config)
|
1434 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
1435 |
+
classifier_dropout = config.classifier_dropout
|
1436 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
1437 |
+
classifier_dropout = config.hidden_dropout
|
1438 |
+
else:
|
1439 |
+
classifier_dropout = 0.1
|
1440 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1441 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1442 |
+
|
1443 |
+
# Initialize weights and apply final processing
|
1444 |
+
self.post_init()
|
1445 |
+
|
1446 |
+
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
|
1447 |
+
def forward(
|
1448 |
+
self,
|
1449 |
+
input_ids: Optional[torch.Tensor] = None,
|
1450 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1451 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1452 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1453 |
+
position_ids: Optional[torch.Tensor] = None,
|
1454 |
+
head_mask: Optional[torch.Tensor] = None,
|
1455 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1456 |
+
labels: Optional[torch.Tensor] = None,
|
1457 |
+
use_cache: Optional[bool] = None,
|
1458 |
+
output_attentions: Optional[bool] = None,
|
1459 |
+
output_hidden_states: Optional[bool] = None,
|
1460 |
+
return_dict: Optional[bool] = None,
|
1461 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1462 |
+
r"""
|
1463 |
+
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1464 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1465 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1466 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1467 |
+
"""
|
1468 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1469 |
+
|
1470 |
+
transformer_outputs = self.transformer(
|
1471 |
+
input_ids,
|
1472 |
+
past_key_values=past_key_values,
|
1473 |
+
attention_mask=attention_mask,
|
1474 |
+
token_type_ids=token_type_ids,
|
1475 |
+
position_ids=position_ids,
|
1476 |
+
head_mask=head_mask,
|
1477 |
+
inputs_embeds=inputs_embeds,
|
1478 |
+
use_cache=use_cache,
|
1479 |
+
output_attentions=output_attentions,
|
1480 |
+
output_hidden_states=output_hidden_states,
|
1481 |
+
return_dict=return_dict,
|
1482 |
+
)
|
1483 |
+
|
1484 |
+
hidden_states = transformer_outputs[0]
|
1485 |
+
hidden_states = self.dropout(hidden_states)
|
1486 |
+
logits = self.classifier(hidden_states)
|
1487 |
+
|
1488 |
+
loss = None
|
1489 |
+
if labels is not None:
|
1490 |
+
loss_fct = CrossEntropyLoss()
|
1491 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))
|
1492 |
+
|
1493 |
+
if not return_dict:
|
1494 |
+
output = (logits,) + transformer_outputs[2:]
|
1495 |
+
return ((loss,) + output) if loss is not None else output
|
1496 |
+
|
1497 |
+
return TokenClassifierOutput(
|
1498 |
+
loss=loss,
|
1499 |
+
logits=logits,
|
1500 |
+
hidden_states=transformer_outputs.hidden_states,
|
1501 |
+
attentions=transformer_outputs.attentions,
|
1502 |
+
)
|
starvector/model/image_encoder/clip_model.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py
|
2 |
+
|
3 |
+
from collections import OrderedDict
|
4 |
+
from itertools import repeat
|
5 |
+
import collections.abc
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
11 |
+
|
12 |
+
def convert_weights_to_precision(model: nn.Module, precision: torch.dtype):
|
13 |
+
"""Convert applicable model parameters to the specified precision"""
|
14 |
+
|
15 |
+
def _convert_weights_to_precision(l):
|
16 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
17 |
+
l.weight.data = l.weight.data.to(precision)
|
18 |
+
if l.bias is not None:
|
19 |
+
l.bias.data = l.bias.data.to(precision)
|
20 |
+
|
21 |
+
elif isinstance(l, (nn.MultiheadAttention)):
|
22 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
23 |
+
tensor = getattr(l, attr)
|
24 |
+
if tensor is not None:
|
25 |
+
tensor.data = tensor.data.to(precision)
|
26 |
+
else:
|
27 |
+
for _, p in l.named_parameters():
|
28 |
+
p.data = p.data.to(precision)
|
29 |
+
|
30 |
+
model.apply(_convert_weights_to_precision)
|
31 |
+
|
32 |
+
class Bottleneck(nn.Module):
|
33 |
+
expansion = 4
|
34 |
+
|
35 |
+
def __init__(self, inplanes, planes, stride=1):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
39 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
40 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
41 |
+
self.relu1 = nn.ReLU(inplace=True)
|
42 |
+
|
43 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
44 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
45 |
+
self.relu2 = nn.ReLU(inplace=True)
|
46 |
+
|
47 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
48 |
+
|
49 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
50 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
51 |
+
self.relu3 = nn.ReLU(inplace=True)
|
52 |
+
|
53 |
+
self.downsample = None
|
54 |
+
self.stride = stride
|
55 |
+
|
56 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
57 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
58 |
+
self.downsample = nn.Sequential(OrderedDict([
|
59 |
+
("-1", nn.AvgPool2d(stride)),
|
60 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
61 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
62 |
+
]))
|
63 |
+
|
64 |
+
def forward(self, x: torch.Tensor):
|
65 |
+
identity = x
|
66 |
+
|
67 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
68 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
69 |
+
out = self.avgpool(out)
|
70 |
+
out = self.bn3(self.conv3(out))
|
71 |
+
|
72 |
+
if self.downsample is not None:
|
73 |
+
identity = self.downsample(x)
|
74 |
+
|
75 |
+
out += identity
|
76 |
+
out = self.relu3(out)
|
77 |
+
return out
|
78 |
+
|
79 |
+
|
80 |
+
class AttentionPool2d(nn.Module):
|
81 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
82 |
+
super().__init__()
|
83 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
84 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
85 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
86 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
87 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
88 |
+
self.num_heads = num_heads
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
92 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
93 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
94 |
+
x, _ = F.multi_head_attention_forward(
|
95 |
+
query=x, key=x, value=x,
|
96 |
+
embed_dim_to_check=x.shape[-1],
|
97 |
+
num_heads=self.num_heads,
|
98 |
+
q_proj_weight=self.q_proj.weight,
|
99 |
+
k_proj_weight=self.k_proj.weight,
|
100 |
+
v_proj_weight=self.v_proj.weight,
|
101 |
+
in_proj_weight=None,
|
102 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
103 |
+
bias_k=None,
|
104 |
+
bias_v=None,
|
105 |
+
add_zero_attn=False,
|
106 |
+
dropout_p=0,
|
107 |
+
out_proj_weight=self.c_proj.weight,
|
108 |
+
out_proj_bias=self.c_proj.bias,
|
109 |
+
use_separate_proj_weight=True,
|
110 |
+
training=self.training,
|
111 |
+
need_weights=False
|
112 |
+
)
|
113 |
+
|
114 |
+
return x[0]
|
115 |
+
|
116 |
+
|
117 |
+
class LayerNorm(nn.LayerNorm):
|
118 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
119 |
+
|
120 |
+
def forward(self, x: torch.Tensor):
|
121 |
+
orig_type = x.dtype
|
122 |
+
layernorm_dtype = self.weight.dtype
|
123 |
+
ret = super().forward(x.type(layernorm_dtype))
|
124 |
+
return ret.type(orig_type)
|
125 |
+
|
126 |
+
class QuickGELU(nn.Module):
|
127 |
+
def forward(self, x: torch.Tensor):
|
128 |
+
return x * torch.sigmoid(1.702 * x)
|
129 |
+
|
130 |
+
class ResidualAttentionBlock(nn.Module):
|
131 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
135 |
+
self.ln_1 = LayerNorm(d_model)
|
136 |
+
self.mlp = nn.Sequential(OrderedDict([
|
137 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
138 |
+
("gelu", QuickGELU()),
|
139 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
140 |
+
]))
|
141 |
+
self.ln_2 = LayerNorm(d_model)
|
142 |
+
self.attn_mask = attn_mask
|
143 |
+
|
144 |
+
if use_grad_checkpointing:
|
145 |
+
self.attn = checkpoint_wrapper(self.attn)
|
146 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
147 |
+
|
148 |
+
def attention(self, x: torch.Tensor):
|
149 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
150 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
151 |
+
|
152 |
+
def forward(self, x: torch.Tensor):
|
153 |
+
x = x + self.attention(self.ln_1(x))
|
154 |
+
x = x + self.mlp(self.ln_2(x))
|
155 |
+
return x
|
156 |
+
|
157 |
+
class Transformer(nn.Module):
|
158 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
159 |
+
super().__init__()
|
160 |
+
self.width = width
|
161 |
+
self.layers = layers
|
162 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
|
163 |
+
|
164 |
+
def forward(self, x: torch.Tensor):
|
165 |
+
return self.resblocks(x)
|
166 |
+
|
167 |
+
class VisionTransformer(nn.Module):
|
168 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
|
169 |
+
super().__init__()
|
170 |
+
self.input_resolution = input_resolution
|
171 |
+
self.num_features = width
|
172 |
+
self.num_heads = heads
|
173 |
+
self.num_patches = (input_resolution // patch_size) ** 2
|
174 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
175 |
+
scale = width ** -0.5
|
176 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
177 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
|
178 |
+
self.ln_pre = LayerNorm(width)
|
179 |
+
self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
|
180 |
+
|
181 |
+
def forward(self, x: torch.Tensor):
|
182 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
183 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
184 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
185 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
186 |
+
x = x + self.positional_embedding.to(x.dtype)
|
187 |
+
x = self.ln_pre(x)
|
188 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
189 |
+
x = self.transformer(x)
|
190 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
191 |
+
return x
|
starvector/model/image_encoder/image_encoder.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import os
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from starvector.model.image_encoder.clip_model import convert_weights_to_precision
|
7 |
+
from starvector.data.util import ImageTrainProcessor
|
8 |
+
|
9 |
+
class ImageEncoder(nn.Module):
|
10 |
+
def __init__(self, config, **kwargs):
|
11 |
+
super(ImageEncoder, self).__init__()
|
12 |
+
|
13 |
+
image_size = config.image_size
|
14 |
+
torch_dtype = kwargs.get('model_precision', config.torch_dtype)
|
15 |
+
# torch_dtype = torch.float32
|
16 |
+
self.image_encoder_type = config.image_encoder_type
|
17 |
+
if self.image_encoder_type == 'clip':
|
18 |
+
self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size)
|
19 |
+
convert_weights_to_precision(self, torch_dtype)
|
20 |
+
self.processor = ImageTrainProcessor(size=config.image_size)
|
21 |
+
|
22 |
+
elif self.image_encoder_type == 'vqgan':
|
23 |
+
self.visual_encoder = self.build_vqgan_encoder()
|
24 |
+
self.ln_vision = None
|
25 |
+
self.processor = ImageTrainProcessor(size=config.image_size)
|
26 |
+
|
27 |
+
elif self.image_encoder_type == 'convnext':
|
28 |
+
self.visual_encoder = self.build_vqgan_encoder()
|
29 |
+
self.ln_vision = None
|
30 |
+
self.processor = ImageTrainProcessor(size=config.image_size)
|
31 |
+
|
32 |
+
elif 'siglip' in self.image_encoder_type:
|
33 |
+
if self.image_encoder_type == 'siglip_512':
|
34 |
+
model_name = "google/siglip-base-patch16-512"
|
35 |
+
elif self.image_encoder_type == 'siglip_384':
|
36 |
+
model_name = "google/siglip-large-patch16-384"
|
37 |
+
elif self.image_encoder_type == 'siglip_256':
|
38 |
+
model_name = "google/siglip-base-patch16-256"
|
39 |
+
|
40 |
+
from transformers import AutoProcessor, AutoModel
|
41 |
+
|
42 |
+
self.visual_encoder = AutoModel.from_pretrained(
|
43 |
+
model_name, torch_dtype = torch_dtype
|
44 |
+
).vision_model
|
45 |
+
|
46 |
+
self.processor = AutoProcessor.from_pretrained(
|
47 |
+
model_name, torch_dtype = torch_dtype
|
48 |
+
)
|
49 |
+
|
50 |
+
def build_clip_encoder(self, image_size):
|
51 |
+
from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm
|
52 |
+
visual_encoder = VisionTransformer(
|
53 |
+
input_resolution=image_size,
|
54 |
+
patch_size=14,
|
55 |
+
width=1024,
|
56 |
+
layers=23,
|
57 |
+
heads=16,
|
58 |
+
use_grad_checkpointing=False)
|
59 |
+
|
60 |
+
ln_vision = LayerNorm(visual_encoder.num_features)
|
61 |
+
return visual_encoder, ln_vision
|
62 |
+
|
63 |
+
def build_vqgan_encoder(self):
|
64 |
+
from taming.modules.diffusionmodules.model import Encoder
|
65 |
+
VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md
|
66 |
+
vqgan_chkp_path = VQGAN_CHECKPOINT
|
67 |
+
files_in_directory = os.listdir(vqgan_chkp_path + '/configs')
|
68 |
+
vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0]
|
69 |
+
vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file))
|
70 |
+
visual_encoder = Encoder(**vqgan_config.model.params.ddconfig)
|
71 |
+
|
72 |
+
# Load checkpoint weights
|
73 |
+
checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict']
|
74 |
+
|
75 |
+
# Create a new state_dict with modified keys
|
76 |
+
new_state_dict = {}
|
77 |
+
for key, value in checkpoint.items():
|
78 |
+
if key.startswith('encoder.'):
|
79 |
+
new_key = key[len('encoder.'):]
|
80 |
+
new_state_dict[new_key] = value
|
81 |
+
|
82 |
+
# Load weights
|
83 |
+
visual_encoder.load_state_dict(new_state_dict)
|
84 |
+
return visual_encoder
|
85 |
+
|
86 |
+
def build_convnext_encoder(self):
|
87 |
+
import open_clip
|
88 |
+
model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
|
89 |
+
return model.visual
|
90 |
+
|
91 |
+
def forward(self, image):
|
92 |
+
if self.image_encoder_type == 'clip':
|
93 |
+
embeds = self.visual_encoder(image)
|
94 |
+
out = self.ln_vision(embeds)
|
95 |
+
elif self.image_encoder_type == 'open-clip':
|
96 |
+
out = self.visual_encoder(image)[1]
|
97 |
+
out = self.ln_vision(out)
|
98 |
+
elif self.image_encoder_type == 'vqgan':
|
99 |
+
out = self.visual_encoder(image)
|
100 |
+
size = out.size()
|
101 |
+
out = out.view(size[0], size[1], -1)
|
102 |
+
out = out.permute(0, 2, 1)
|
103 |
+
elif self.image_encoder_type == 'convnext':
|
104 |
+
out = self.visual_encoder.trunk.forward_features(image)
|
105 |
+
size = out.size()
|
106 |
+
out = out.view(size[0], size[1], -1)
|
107 |
+
out = out.permute(0, 2, 1)
|
108 |
+
elif 'siglip' in self.image_encoder_type:
|
109 |
+
out = self.visual_encoder(image)["last_hidden_state"]
|
110 |
+
return out
|
111 |
+
|
112 |
+
def process_images(self, images):
|
113 |
+
if self.image_encoder_type == 'clip':
|
114 |
+
res = []
|
115 |
+
for image in images:
|
116 |
+
res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W
|
117 |
+
return res
|
118 |
+
else:
|
119 |
+
return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0)
|
120 |
+
|
starvector/model/llm/starcoder.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from transformers import (
|
3 |
+
AutoConfig,
|
4 |
+
AutoModelForCausalLM,
|
5 |
+
AutoTokenizer,
|
6 |
+
)
|
7 |
+
|
8 |
+
class StarCoderModel(nn.Module):
|
9 |
+
def __init__(self, config, **kwargs):
|
10 |
+
super(StarCoderModel, self).__init__()
|
11 |
+
|
12 |
+
self.init_tokenizer(config.starcoder_model_name)
|
13 |
+
|
14 |
+
self.max_length = config.max_length
|
15 |
+
model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True)
|
16 |
+
kwargs = {}
|
17 |
+
kwargs['trust_remote_code'] = True
|
18 |
+
kwargs['torch_dtype'] = config.torch_dtype
|
19 |
+
|
20 |
+
# Configure special tokens for generation
|
21 |
+
model_config.eos_token_id = self.tokenizer.eos_token_id
|
22 |
+
model_config.pad_token_id = self.tokenizer.pad_token_id
|
23 |
+
model_config.bos_token_id = self.tokenizer.bos_token_id
|
24 |
+
try:
|
25 |
+
model_config.flash_attention = config.use_flash_attn
|
26 |
+
model_config._attn_implementation = "flash_attention_2"
|
27 |
+
except ImportError:
|
28 |
+
config.use_flash_attn = False
|
29 |
+
|
30 |
+
# model = GPTBigCodeForCausalLM(config=model_config)
|
31 |
+
model = AutoModelForCausalLM.from_pretrained(config.starcoder_model_name, config=model_config, **kwargs)
|
32 |
+
model.resize_token_embeddings(len(self.tokenizer))
|
33 |
+
self.transformer = model
|
34 |
+
|
35 |
+
# Prompt the model after image
|
36 |
+
self.prompt = '<svg'
|
37 |
+
|
38 |
+
def init_tokenizer(self, model_name):
|
39 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
40 |
+
# Incude padding and eos tokens in the vocabulary
|
41 |
+
if self.tokenizer.eos_token_id is None:
|
42 |
+
self.tokenizer.add_special_tokens({"eos_token": "[EOS]"})
|
43 |
+
if self.tokenizer.pad_token_id is None:
|
44 |
+
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
45 |
+
|
46 |
+
self.svg_start_token = "<svg-start>"
|
47 |
+
self.image_start_token = "<image-start>"
|
48 |
+
self.text_start_token = "<caption-start>"
|
49 |
+
|
50 |
+
self.tokenizer.add_tokens([self.svg_start_token, self.image_start_token, self.text_start_token])
|
51 |
+
self.svg_start_token_id = self.tokenizer.encode(self.svg_start_token)[0]
|
starvector/model/llm/starcoder2.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from transformers import (
|
3 |
+
AutoConfig,
|
4 |
+
AutoModelForCausalLM,
|
5 |
+
AutoTokenizer,
|
6 |
+
)
|
7 |
+
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
8 |
+
from functools import partial
|
9 |
+
from starvector.train.util import get_module_class_from_name
|
10 |
+
import torch
|
11 |
+
|
12 |
+
class StarCoderModel(nn.Module):
|
13 |
+
def __init__(self, config, **kwargs):
|
14 |
+
super(StarCoderModel, self).__init__()
|
15 |
+
|
16 |
+
self.init_tokenizer(config.starcoder_model_name)
|
17 |
+
|
18 |
+
self.max_length = config.max_length
|
19 |
+
model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True)
|
20 |
+
model_config.use_cache = config.use_cache
|
21 |
+
model_config.use_bfloat16 = True
|
22 |
+
model = AutoModelForCausalLM.from_pretrained(
|
23 |
+
config.starcoder_model_name,
|
24 |
+
config=model_config,
|
25 |
+
attn_implementation="flash_attention_2",
|
26 |
+
torch_dtype=torch.bfloat16,
|
27 |
+
trust_remote_code=True)
|
28 |
+
model.resize_token_embeddings(len(self.tokenizer))
|
29 |
+
self.transformer = model
|
30 |
+
|
31 |
+
# Prompt the model after image
|
32 |
+
self.prompt = '<svg'
|
33 |
+
|
34 |
+
transformer_layer_cls = kwargs.get('transformer_layer_cls', 'Starcoder2DecoderLayer')
|
35 |
+
self.transformer_layer_cls = get_module_class_from_name(self, transformer_layer_cls)
|
36 |
+
|
37 |
+
def init_tokenizer(self, model_name):
|
38 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
39 |
+
# Incude padding and eos tokens in the vocabulary
|
40 |
+
if self.tokenizer.eos_token_id is None:
|
41 |
+
self.tokenizer.add_special_tokens({"eos_token": "[EOS]"})
|
42 |
+
if self.tokenizer.pad_token_id is None:
|
43 |
+
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
44 |
+
|
45 |
+
self.svg_start_token = "<svg-start>"
|
46 |
+
self.svg_end_token = "<svg-end>"
|
47 |
+
self.image_start_token = "<image-start>"
|
48 |
+
self.text_start_token = "<caption-start>"
|
49 |
+
|
50 |
+
self.tokenizer.add_tokens([self.svg_start_token, self.image_start_token, self.text_start_token, self.svg_end_token])
|
51 |
+
self.svg_start_token_id = self.tokenizer.encode(self.svg_start_token)[0]
|
52 |
+
self.svg_end_token_id = self.tokenizer.encode(self.svg_end_token)[0]
|
53 |
+
self.tokenizer.padding_side = "left"
|
54 |
+
|
55 |
+
def get_fsdp_wrapping_policy(self):
|
56 |
+
"""Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
|
57 |
+
transformer_block_policy = partial(
|
58 |
+
transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
|
59 |
+
)
|
60 |
+
|
61 |
+
return transformer_block_policy
|
starvector/model/models/starvector_base.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from starvector.model.adapters.adapter import Adapter
|
5 |
+
from starvector.model.image_encoder.image_encoder import ImageEncoder
|
6 |
+
from starvector.util import print_trainable_parameters
|
7 |
+
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
|
8 |
+
|
9 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
10 |
+
|
11 |
+
def __init__(self, stops=[]):
|
12 |
+
super().__init__() # Correct super() call
|
13 |
+
self.stops = stops
|
14 |
+
|
15 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
16 |
+
# Check if any of the stop sequences are in the input_ids
|
17 |
+
for stop_ids in self.stops:
|
18 |
+
if input_ids[0][-len(stop_ids):].tolist() == stop_ids:
|
19 |
+
return True
|
20 |
+
return False
|
21 |
+
|
22 |
+
class StarVectorBase(nn.Module, ABC):
|
23 |
+
def __init__(self, config, **kwargs):
|
24 |
+
super().__init__()
|
25 |
+
# Task-specific layers
|
26 |
+
self.task = kwargs.get('task', 'im2svg')
|
27 |
+
self.model_precision = kwargs.get('model_precision', config.torch_dtype)
|
28 |
+
# Build Code LLM (StarCoder)
|
29 |
+
self.svg_transformer = self._get_svg_transformer(config, **kwargs)
|
30 |
+
|
31 |
+
if self.use_image_encoder():
|
32 |
+
# Build Image Encoder
|
33 |
+
self.image_encoder = ImageEncoder(config, **kwargs)
|
34 |
+
|
35 |
+
# Build Adapter
|
36 |
+
self.image_projection = self.get_adapter(config, **kwargs).to(dtype=self.model_precision)
|
37 |
+
else:
|
38 |
+
self.query_length = 0
|
39 |
+
|
40 |
+
self.max_length = config.max_length_train - self.query_length - 4 # for added special tokens
|
41 |
+
|
42 |
+
self.train_image_encoder = kwargs.get('train_image_encoder', False)
|
43 |
+
self.train_LLM = kwargs.get('train_LLM', False)
|
44 |
+
self.train_connector = kwargs.get('train_connector', False)
|
45 |
+
|
46 |
+
# Freeze parameters
|
47 |
+
self.freze_parameters(self.train_image_encoder, self.train_LLM, self.train_connector)
|
48 |
+
print_trainable_parameters(self)
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def _get_svg_transformer(self, config, **kwargs):
|
52 |
+
"""Get SVG transformer model - implementation differs between versions"""
|
53 |
+
pass
|
54 |
+
|
55 |
+
def freze_parameters(self, train_image_encoder, train_LLM, train_connector):
|
56 |
+
"""V2 implementation of parameter freezing"""
|
57 |
+
if self.use_image_encoder():
|
58 |
+
for _, param in self.image_encoder.named_parameters():
|
59 |
+
param.requires_grad = train_image_encoder
|
60 |
+
|
61 |
+
# adapter trainable
|
62 |
+
for _, param in self.image_projection.named_parameters():
|
63 |
+
param.requires_grad = train_connector
|
64 |
+
|
65 |
+
for _, param in self.svg_transformer.named_parameters():
|
66 |
+
param.requires_grad = train_LLM
|
67 |
+
|
68 |
+
def use_image_encoder(self):
|
69 |
+
"""Determine if image encoder should be used based on task"""
|
70 |
+
return self.task == 'im2svg'
|
71 |
+
|
72 |
+
def get_adapter(self, config, **kwargs):
|
73 |
+
"""Get adapter layer for image projection"""
|
74 |
+
vision_hidden_size, self.query_length = self.get_hidden_size_and_query_length(config.image_encoder_type)
|
75 |
+
llm_hidden_size = self.svg_transformer.transformer.config.hidden_size
|
76 |
+
image_projection = Adapter(
|
77 |
+
vision_hidden_size,
|
78 |
+
llm_hidden_size,
|
79 |
+
adapter_norm=config.adapter_norm,
|
80 |
+
query_length=self.query_length,
|
81 |
+
dropout_prob=kwargs.get('dropout', 0.1)
|
82 |
+
)
|
83 |
+
return image_projection
|
84 |
+
|
85 |
+
def get_hidden_size_and_query_length(self, image_encoder_type):
|
86 |
+
"""Get hidden size and query length based on encoder type"""
|
87 |
+
if image_encoder_type == 'clip':
|
88 |
+
hidden_size = self.image_encoder.visual_encoder.num_features
|
89 |
+
query_length = 257
|
90 |
+
elif image_encoder_type == 'open-clip':
|
91 |
+
hidden_size = self.image_encoder.visual_encoder.transformer.width
|
92 |
+
query_length = 256
|
93 |
+
elif image_encoder_type == 'vqgan':
|
94 |
+
hidden_size = 256
|
95 |
+
query_length = 196
|
96 |
+
elif image_encoder_type == 'convnext':
|
97 |
+
hidden_size = 1024
|
98 |
+
query_length = 49
|
99 |
+
elif 'siglip' in image_encoder_type:
|
100 |
+
hidden_size = self.image_encoder.visual_encoder.head.mlp.fc2.out_features
|
101 |
+
if '512' in image_encoder_type:
|
102 |
+
query_length = 1024
|
103 |
+
elif '384' in image_encoder_type:
|
104 |
+
query_length = 576
|
105 |
+
|
106 |
+
return hidden_size, query_length
|
107 |
+
|
108 |
+
def _tokenize(self, text, max_length, device, add_special_tokens=True):
|
109 |
+
"""Common tokenization logic"""
|
110 |
+
tokens = self.svg_transformer.tokenizer(
|
111 |
+
text,
|
112 |
+
truncation=True,
|
113 |
+
add_special_tokens=add_special_tokens,
|
114 |
+
padding='longest',
|
115 |
+
max_length=max_length,
|
116 |
+
return_tensors="pt"
|
117 |
+
).to(device)
|
118 |
+
return tokens
|
119 |
+
|
120 |
+
def _create_targets(self, tokens):
|
121 |
+
"""Create targets with padding mask"""
|
122 |
+
target_mask = (tokens.input_ids == self.svg_transformer.tokenizer.pad_token_id)
|
123 |
+
return tokens.input_ids.masked_fill(target_mask, -100)
|
124 |
+
|
125 |
+
@abstractmethod
|
126 |
+
def _get_embeddings(self, input_ids):
|
127 |
+
"""Get embeddings from input ids - implementation differs between v1 and v2"""
|
128 |
+
pass
|
129 |
+
|
130 |
+
def embed_text_to_svg(self, batch, device):
|
131 |
+
"""Common text to SVG embedding logic"""
|
132 |
+
captions = batch["caption"]
|
133 |
+
svgs = batch["svg"]
|
134 |
+
samples = [captions[i] + self.svg_transformer.svg_start_token + svgs[i] + self.svg_transformer.tokenizer.eos_token
|
135 |
+
for i in range(len(captions))]
|
136 |
+
|
137 |
+
tokens = self._tokenize(samples, self.max_length, device)
|
138 |
+
targets = self._create_targets(tokens)
|
139 |
+
inputs_embeds = self._get_embeddings(tokens.input_ids)
|
140 |
+
|
141 |
+
return inputs_embeds, tokens.attention_mask, targets
|
142 |
+
|
143 |
+
def get_image_embeddings(self, batch, device):
|
144 |
+
"""Get image embeddings"""
|
145 |
+
image = batch["image"].to(dtype=self.model_precision)
|
146 |
+
embedded_image = self.image_encoder(image)
|
147 |
+
conditioning_embeds = self.image_projection(embedded_image)
|
148 |
+
return conditioning_embeds
|
149 |
+
|
150 |
+
def embed_im_to_svg(self, batch, device):
|
151 |
+
"""Common image to SVG embedding logic"""
|
152 |
+
# Process image
|
153 |
+
image = batch["image"].to(dtype=self.model_precision)
|
154 |
+
embedded_image = self.image_encoder(image)
|
155 |
+
conditioning_embeds = self.image_projection(embedded_image)
|
156 |
+
conditioning_embeds_att = torch.ones(conditioning_embeds.size()[:-1], dtype=torch.long).to(device)
|
157 |
+
|
158 |
+
# Get SVG text with appropriate end tokens (implemented by subclasses)
|
159 |
+
svg_text = self._get_svg_text(batch["svg"])
|
160 |
+
|
161 |
+
svg_tokens = self._tokenize(svg_text, self.max_length, device)
|
162 |
+
svg_tokens_embeds = self._get_embeddings(svg_tokens.input_ids)
|
163 |
+
|
164 |
+
inputs_embeds = torch.cat([conditioning_embeds, svg_tokens_embeds], dim=1)
|
165 |
+
|
166 |
+
svg_targets = self._create_targets(svg_tokens)
|
167 |
+
empty_targets = torch.ones(conditioning_embeds_att.size(), dtype=torch.long).to(device).fill_(-100)
|
168 |
+
targets = torch.cat([empty_targets, svg_targets], dim=1)
|
169 |
+
|
170 |
+
attention_mask = torch.cat([conditioning_embeds_att, svg_tokens.attention_mask], dim=1)
|
171 |
+
|
172 |
+
return inputs_embeds, attention_mask, targets
|
173 |
+
|
174 |
+
def forward(self, batch):
|
175 |
+
"""Forward pass"""
|
176 |
+
device = batch["image"].device
|
177 |
+
task = self.task
|
178 |
+
|
179 |
+
# Depending
|
180 |
+
if task == 'text2svg':
|
181 |
+
inputs_embeds, attention_mask, targets = self.embed_text_to_svg(batch, device)
|
182 |
+
elif task == 'im2svg':
|
183 |
+
inputs_embeds, attention_mask, targets = self.embed_im_to_svg(batch, device)
|
184 |
+
|
185 |
+
outputs = self.svg_transformer.transformer(
|
186 |
+
inputs_embeds=inputs_embeds,
|
187 |
+
attention_mask=attention_mask,
|
188 |
+
labels=targets,
|
189 |
+
return_dict=True,
|
190 |
+
output_hidden_states=True,
|
191 |
+
use_cache=False,
|
192 |
+
)
|
193 |
+
loss = outputs.loss
|
194 |
+
return loss
|
195 |
+
|
196 |
+
|
197 |
+
@abstractmethod
|
198 |
+
def _get_svg_text(self, svg_list):
|
199 |
+
"""Get SVG text with appropriate end tokens - implementation differs between v1 and v2"""
|
200 |
+
pass
|
201 |
+
|
202 |
+
|
203 |
+
def _prepare_generation_inputs(self, batch, prompt, device):
|
204 |
+
"""Common preparation for generation inputs"""
|
205 |
+
image = batch["image"]
|
206 |
+
image = image.to(device).to(self.model_precision)
|
207 |
+
|
208 |
+
embedded_image = self.image_encoder(image)
|
209 |
+
embedded_image = self.image_projection(embedded_image)
|
210 |
+
embedded_att = torch.ones(embedded_image.size()[:-1], dtype=torch.long).to(device)
|
211 |
+
|
212 |
+
if prompt is None:
|
213 |
+
prompt = self.svg_transformer.prompt
|
214 |
+
prompt = [prompt] * image.size(0)
|
215 |
+
|
216 |
+
prompt_tokens = self._tokenize(prompt, None, device, add_special_tokens=False)
|
217 |
+
attention_mask = torch.cat([embedded_att, prompt_tokens.attention_mask], dim=1)
|
218 |
+
inputs_embeds = self._get_embeddings(prompt_tokens.input_ids)
|
219 |
+
inputs_embeds = torch.cat([embedded_image, inputs_embeds], dim=1)
|
220 |
+
|
221 |
+
return inputs_embeds, attention_mask, prompt_tokens
|
222 |
+
|
223 |
+
def _get_generation_kwargs(self, base_kwargs):
|
224 |
+
"""Common generation kwargs preparation"""
|
225 |
+
# Get token IDs for "</svg>"
|
226 |
+
end_sequence = self.svg_transformer.tokenizer("</svg>", add_special_tokens=False)['input_ids']
|
227 |
+
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[end_sequence])])
|
228 |
+
return {
|
229 |
+
'inputs_embeds': base_kwargs['inputs_embeds'],
|
230 |
+
'attention_mask': base_kwargs['attention_mask'],
|
231 |
+
'do_sample': base_kwargs.get('use_nucleus_sampling', True),
|
232 |
+
'top_p': base_kwargs.get('top_p', 0.9),
|
233 |
+
'temperature': base_kwargs.get('temperature', 1),
|
234 |
+
'num_beams': base_kwargs.get('num_beams', 2),
|
235 |
+
'max_length': base_kwargs.get('max_length', 30),
|
236 |
+
'min_length': base_kwargs.get('min_length', 1),
|
237 |
+
'repetition_penalty': base_kwargs.get('repetition_penalty', 1.0),
|
238 |
+
'length_penalty': base_kwargs.get('length_penalty', 1.0),
|
239 |
+
'use_cache': base_kwargs.get('use_cache', True),
|
240 |
+
'stopping_criteria': stopping_criteria
|
241 |
+
}
|
242 |
+
|
243 |
+
def generate_im2svg(self, batch, **kwargs):
|
244 |
+
"""Base implementation of image to SVG generation"""
|
245 |
+
inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
|
246 |
+
batch, kwargs.get('prompt'), batch["image"].device
|
247 |
+
)
|
248 |
+
|
249 |
+
generation_kwargs = self._get_generation_kwargs(
|
250 |
+
{**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
|
251 |
+
)
|
252 |
+
# Let subclasses override these defaults if needed
|
253 |
+
generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
|
254 |
+
|
255 |
+
outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
|
256 |
+
outputs = torch.cat([prompt_tokens.input_ids, outputs], dim=1)
|
257 |
+
raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
258 |
+
|
259 |
+
return raw_svg
|
260 |
+
|
261 |
+
def generate_im2svg_grpo(self, batch, **kwargs):
|
262 |
+
"""Base implementation of image to SVG generation"""
|
263 |
+
inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
|
264 |
+
batch, kwargs.get('prompt'), batch["image"].device
|
265 |
+
)
|
266 |
+
|
267 |
+
generation_kwargs = self._get_generation_kwargs(
|
268 |
+
{**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
|
269 |
+
)
|
270 |
+
# Let subclasses override these defaults if needed
|
271 |
+
generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
|
272 |
+
|
273 |
+
num_return_sequences = kwargs.get('num_return_sequences', 1)
|
274 |
+
if num_return_sequences > 1:
|
275 |
+
generation_kwargs['num_return_sequences'] = num_return_sequences
|
276 |
+
generation_kwargs['num_beams'] = 1
|
277 |
+
|
278 |
+
outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
|
279 |
+
outputs = torch.cat([prompt_tokens.input_ids.repeat(num_return_sequences, 1), outputs], dim=1)
|
280 |
+
raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
281 |
+
|
282 |
+
return {
|
283 |
+
"raw_svg": raw_svg,
|
284 |
+
"outputs": outputs,
|
285 |
+
"inputs_embeds": inputs_embeds,
|
286 |
+
}
|
287 |
+
|
288 |
+
|
289 |
+
def _get_im2svg_specific_kwargs(self, kwargs):
|
290 |
+
"""Default implementation of im2svg specific generation kwargs.
|
291 |
+
Subclasses can override this to customize generation behavior."""
|
292 |
+
return {
|
293 |
+
'early_stopping': True,
|
294 |
+
'pad_token_id': self.svg_transformer.tokenizer.pad_token_id
|
295 |
+
}
|
296 |
+
|
297 |
+
def generate_text2svg(self, batch, **kwargs):
|
298 |
+
"""Base implementation of text to SVG generation"""
|
299 |
+
device = batch["image"].device
|
300 |
+
prompt = batch["caption"]
|
301 |
+
|
302 |
+
prompt_tokens = self._tokenize(
|
303 |
+
prompt,
|
304 |
+
max_length=kwargs.get('max_length', 30),
|
305 |
+
device=device,
|
306 |
+
add_special_tokens=False
|
307 |
+
)
|
308 |
+
|
309 |
+
trigger_token = self._tokenize(
|
310 |
+
[self.svg_transformer.svg_start_token for _ in batch["caption"]],
|
311 |
+
max_length=None,
|
312 |
+
device=device,
|
313 |
+
add_special_tokens=False
|
314 |
+
)
|
315 |
+
|
316 |
+
input_tokens = torch.cat([prompt_tokens.input_ids, trigger_token.input_ids], dim=1)
|
317 |
+
attention_mask = torch.cat([prompt_tokens.attention_mask, trigger_token.attention_mask], dim=1)
|
318 |
+
inputs_embeds = self._get_embeddings(input_tokens)
|
319 |
+
max_length = kwargs.get('max_length', 30) - input_tokens.size(1)
|
320 |
+
|
321 |
+
generation_kwargs = self._get_generation_kwargs(
|
322 |
+
{**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask},
|
323 |
+
input_tokens.size(1)
|
324 |
+
)
|
325 |
+
# Let subclasses override these defaults if needed
|
326 |
+
generation_kwargs.update(self._get_text2svg_specific_kwargs(kwargs))
|
327 |
+
generation_kwargs['max_length'] = max_length
|
328 |
+
|
329 |
+
outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
|
330 |
+
return outputs
|
331 |
+
|
332 |
+
def _get_text2svg_specific_kwargs(self, kwargs):
|
333 |
+
"""Default implementation of text2svg specific generation kwargs.
|
334 |
+
Subclasses can override this to customize generation behavior."""
|
335 |
+
return {
|
336 |
+
'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
|
337 |
+
'early_stopping': True,
|
338 |
+
'length_penalty': kwargs.get('length_penalty', 1.0)
|
339 |
+
}
|
starvector/model/models/starvector_v1.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from starvector.model.models.starvector_base import StarVectorBase
|
4 |
+
from transformers import AutoProcessor
|
5 |
+
|
6 |
+
class StarVectorStarCoder(StarVectorBase):
|
7 |
+
def __init__(self, config, **kwargs):
|
8 |
+
super().__init__(config, **kwargs)
|
9 |
+
|
10 |
+
self.processor = AutoProcessor.from_pretrained(config._name_or_path)
|
11 |
+
|
12 |
+
def _get_svg_transformer(self, config, **kwargs):
|
13 |
+
from starvector.model.llm.starcoder import StarCoderModel # This uses StarCoder (V1)
|
14 |
+
return StarCoderModel(config, **kwargs)
|
15 |
+
|
16 |
+
def _get_embeddings(self, input_ids):
|
17 |
+
"""V1 specific embedding method"""
|
18 |
+
return self.svg_transformer.transformer.transformer.wte(input_ids)
|
19 |
+
|
20 |
+
def _get_svg_text(self, svg_list):
|
21 |
+
"""V1 specific SVG text preparation"""
|
22 |
+
return [t + self.svg_transformer.tokenizer.eos_token for t in svg_list]
|
starvector/model/models/starvector_v2.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
|
4 |
+
from functools import partial
|
5 |
+
from starvector.model.models.starvector_base import StarVectorBase
|
6 |
+
from transformers import AutoImageProcessor
|
7 |
+
|
8 |
+
class StarVectorStarCoder2(StarVectorBase):
|
9 |
+
def __init__(self, config, **kwargs):
|
10 |
+
super().__init__(config, **kwargs)
|
11 |
+
|
12 |
+
self.processor = AutoImageProcessor.from_pretrained(config._name_or_path, trust_remote_code=True)
|
13 |
+
|
14 |
+
def _get_svg_transformer(self, config, **kwargs):
|
15 |
+
from starvector.model.llm.starcoder2 import StarCoderModel # This is a different model than V1, uses StarCoder2
|
16 |
+
return StarCoderModel(config, **kwargs)
|
17 |
+
|
18 |
+
|
19 |
+
def get_fsdp_wrapping_policy(self):
|
20 |
+
"""V2 specific FSDP wrapping policy"""
|
21 |
+
from starvector.model.image_encoder.image_encoder import ImageEncoder
|
22 |
+
|
23 |
+
image_encoder_wrapping_policy = partial(
|
24 |
+
_module_wrap_policy,
|
25 |
+
module_classes={ImageEncoder},
|
26 |
+
)
|
27 |
+
|
28 |
+
llm_fsdp_wrapping_policy = self.svg_transformer.get_fsdp_wrapping_policy()
|
29 |
+
from starvector.model.adapters.adapter import Adapter
|
30 |
+
|
31 |
+
adapter_wrapping_policy = partial(
|
32 |
+
_module_wrap_policy,
|
33 |
+
module_classes={Adapter},
|
34 |
+
)
|
35 |
+
|
36 |
+
return partial(
|
37 |
+
_or_policy,
|
38 |
+
policies=[
|
39 |
+
image_encoder_wrapping_policy,
|
40 |
+
llm_fsdp_wrapping_policy,
|
41 |
+
adapter_wrapping_policy,
|
42 |
+
],
|
43 |
+
)
|
44 |
+
|
45 |
+
def _get_embeddings(self, input_ids):
|
46 |
+
"""V2 specific embedding method"""
|
47 |
+
return self.svg_transformer.transformer.model.embed_tokens(input_ids)
|
48 |
+
|
49 |
+
def _get_svg_text(self, svg_list):
|
50 |
+
"""V2 specific SVG text preparation"""
|
51 |
+
return [t + self.svg_transformer.svg_end_token + self.svg_transformer.tokenizer.eos_token for t in svg_list]
|
52 |
+
|
53 |
+
def _get_im2svg_specific_kwargs(self, kwargs):
|
54 |
+
"""V2 specific generation kwargs"""
|
55 |
+
return {
|
56 |
+
# 'eos_token_id': self.svg_transformer.svg_end_token_id,
|
57 |
+
}
|
58 |
+
|
59 |
+
def _get_text2svg_specific_kwargs(self, kwargs):
|
60 |
+
"""V2 specific text2svg generation kwargs"""
|
61 |
+
return {
|
62 |
+
'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
|
63 |
+
}
|
starvector/model/starvector_arch.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import (
|
2 |
+
PretrainedConfig,
|
3 |
+
PreTrainedModel
|
4 |
+
)
|
5 |
+
from torch.nn import CrossEntropyLoss
|
6 |
+
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import CausalLMOutputWithCrossAttentions
|
7 |
+
from typing import Optional, Tuple, Union
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from transformers.processing_utils import ProcessorMixin
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode, pad
|
13 |
+
from transformers.feature_extraction_sequence_utils import BatchFeature
|
14 |
+
from transformers import AutoProcessor
|
15 |
+
|
16 |
+
class SimpleStarVectorProcessor(ProcessorMixin):
|
17 |
+
attributes = ["tokenizer"] # Only include tokenizer in attributes
|
18 |
+
valid_kwargs = ["size", "mean", "std"] # Add other parameters as valid kwargs
|
19 |
+
image_processor_class = "AutoImageProcessor"
|
20 |
+
tokenizer_class = "AutoTokenizer"
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
tokenizer=None, # Make tokenizer the first argument
|
24 |
+
size=224,
|
25 |
+
mean=None,
|
26 |
+
std=None,
|
27 |
+
**kwargs,
|
28 |
+
):
|
29 |
+
if mean is None:
|
30 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
31 |
+
if std is None:
|
32 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
33 |
+
|
34 |
+
# Store these as instance variables
|
35 |
+
self.mean = mean
|
36 |
+
self.std = std
|
37 |
+
self.size = size
|
38 |
+
self.normalize = transforms.Normalize(mean=mean, std=std)
|
39 |
+
|
40 |
+
self.transform = transforms.Compose([
|
41 |
+
transforms.Lambda(lambda img: img.convert("RGB") if img.mode == "RGBA" else img),
|
42 |
+
transforms.Lambda(lambda img: self._pad_to_square(img)),
|
43 |
+
transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
|
44 |
+
transforms.ToTensor(),
|
45 |
+
self.normalize
|
46 |
+
])
|
47 |
+
|
48 |
+
# Initialize parent class with tokenizer
|
49 |
+
super().__init__(tokenizer=tokenizer)
|
50 |
+
|
51 |
+
|
52 |
+
def __call__(self, images=None, text=None, max_length=None, **kwargs) -> BatchFeature:
|
53 |
+
"""
|
54 |
+
Process images and/or text inputs.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
images: Optional image input(s)
|
58 |
+
text: Optional text input(s)
|
59 |
+
**kwargs: Additional arguments
|
60 |
+
"""
|
61 |
+
if images is None and text is None:
|
62 |
+
raise ValueError("You have to specify at least one of `images` or `text`.")
|
63 |
+
|
64 |
+
image_inputs = {}
|
65 |
+
if images is not None:
|
66 |
+
if isinstance(images, (list, tuple)):
|
67 |
+
images_ = torch.stack([self.transform(img) for img in images])
|
68 |
+
else:
|
69 |
+
images_ = self.transform(images)
|
70 |
+
image_inputs = {"pixel_values": images_}
|
71 |
+
|
72 |
+
text_inputs = {}
|
73 |
+
if text is not None:
|
74 |
+
text_inputs = self.tokenizer(
|
75 |
+
text, truncation=True,
|
76 |
+
add_special_tokens=True,
|
77 |
+
padding='longest',
|
78 |
+
max_length=max_length,
|
79 |
+
return_tensors="pt"
|
80 |
+
)
|
81 |
+
|
82 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|
83 |
+
|
84 |
+
def _pad_to_square(self, img):
|
85 |
+
# Calculate padding to make the image square
|
86 |
+
width, height = img.size
|
87 |
+
max_dim = max(width, height)
|
88 |
+
padding = [(max_dim - width) // 2, (max_dim - height) // 2]
|
89 |
+
padding += [max_dim - width - padding[0], max_dim - height - padding[1]]
|
90 |
+
return pad(img, padding, fill=255) # Assuming white padding
|
91 |
+
|
92 |
+
|
93 |
+
AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)
|
94 |
+
|
95 |
+
|
96 |
+
class StarVectorConfig(PretrainedConfig):
|
97 |
+
model_type = "starvector"
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
starcoder_model_name: str = "bigcode/starcoderbase-1b",
|
102 |
+
image_encoder_type: str = "clip",
|
103 |
+
adapter_norm: str = "layer_norm",
|
104 |
+
image_size: int = 224,
|
105 |
+
max_length: int = 8192,
|
106 |
+
max_length_train: int = 8192,
|
107 |
+
use_flash_attn: bool = True,
|
108 |
+
use_cache: bool = True,
|
109 |
+
num_attention_heads: int = 16,
|
110 |
+
num_hidden_layers: int = 24,
|
111 |
+
vocab_size: int = 49152,
|
112 |
+
hidden_size: int = 2048,
|
113 |
+
num_kv_heads: int = 4,
|
114 |
+
torch_dtype: str = "bfloat16",
|
115 |
+
**kwargs,
|
116 |
+
):
|
117 |
+
kwargs["torch_dtype"] = torch_dtype
|
118 |
+
self.starcoder_model_name = starcoder_model_name
|
119 |
+
self.image_encoder_type = image_encoder_type
|
120 |
+
self.adapter_norm = adapter_norm
|
121 |
+
self.image_size = image_size
|
122 |
+
self.max_length = max_length
|
123 |
+
self.max_length_train = max_length_train
|
124 |
+
self.use_flash_attn = use_flash_attn
|
125 |
+
self.use_cache = use_cache
|
126 |
+
self.num_attention_heads = num_attention_heads
|
127 |
+
self.num_hidden_layers = num_hidden_layers
|
128 |
+
self.vocab_size = vocab_size
|
129 |
+
self.hidden_size = hidden_size
|
130 |
+
self.num_kv_heads = num_kv_heads
|
131 |
+
super().__init__(**kwargs)
|
132 |
+
|
133 |
+
class StarVectorForCausalLM(PreTrainedModel):
|
134 |
+
config_class = StarVectorConfig
|
135 |
+
_no_split_modules = []
|
136 |
+
|
137 |
+
def __init__(self, config: StarVectorConfig, **kwargs):
|
138 |
+
super().__init__(config)
|
139 |
+
starcoder_model_name = config.starcoder_model_name
|
140 |
+
if 'starcoder2' in starcoder_model_name:
|
141 |
+
from starvector.model.models.starvector_v2 import StarVectorStarCoder2
|
142 |
+
self.model = StarVectorStarCoder2(config=config, **kwargs)
|
143 |
+
else:
|
144 |
+
from starvector.model.models.starvector_v1 import StarVectorStarCoder
|
145 |
+
self.model = StarVectorStarCoder(config=config, **kwargs)
|
146 |
+
|
147 |
+
|
148 |
+
@property
|
149 |
+
def supports_gradient_checkpointing(self):
|
150 |
+
# If the underlying transformer (e.g., the one in StarCoderModel)
|
151 |
+
# supports gradient checkpointing, delegate to it.
|
152 |
+
if hasattr(self.model, 'svg_transformer'):
|
153 |
+
return getattr(self.model.svg_transformer, 'supports_gradient_checkpointing', False)
|
154 |
+
return False
|
155 |
+
|
156 |
+
def gradient_checkpointing_enable(self):
|
157 |
+
# Optionally, forward this call to the internal transformer.
|
158 |
+
if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
|
159 |
+
self.model.svg_transformer.gradient_checkpointing_enable()
|
160 |
+
|
161 |
+
def forward(self, vision_embeds, input_ids, num_generations, attention_mask, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
162 |
+
completion_embeds = self.model._get_embeddings(input_ids)
|
163 |
+
inputs_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1)
|
164 |
+
|
165 |
+
transformer_outputs = self.model.svg_transformer.transformer.transformer(
|
166 |
+
inputs_embeds=inputs_embeds,
|
167 |
+
attention_mask=attention_mask,
|
168 |
+
)
|
169 |
+
hidden_states = transformer_outputs[0]
|
170 |
+
|
171 |
+
if num_logits_to_keep > 0:
|
172 |
+
lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
173 |
+
else:
|
174 |
+
lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states)
|
175 |
+
|
176 |
+
loss = None
|
177 |
+
return CausalLMOutputWithCrossAttentions(
|
178 |
+
loss=loss,
|
179 |
+
logits=lm_logits,
|
180 |
+
past_key_values=transformer_outputs.past_key_values,
|
181 |
+
hidden_states=transformer_outputs.hidden_states,
|
182 |
+
attentions=transformer_outputs.attentions,
|
183 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
184 |
+
)
|
185 |
+
|
186 |
+
def generate_im2svg(self, batch, **kwargs):
|
187 |
+
return self.model.generate_im2svg(batch, **kwargs)
|
188 |
+
|
189 |
+
def generate_im2text(self, batch, **kwargs):
|
190 |
+
return self.model.generate_im2text(batch, **kwargs)
|
191 |
+
|
192 |
+
def process_images(self, images):
|
193 |
+
return self.model.image_encoder.process_images(images)
|
194 |
+
|
starvector/serve/__init__.py
ADDED
File without changes
|
starvector/serve/constants.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
13 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
14 |
+
|
15 |
+
CLIP_QUERY_LENGTH = 257
|
16 |
+
|
starvector/serve/controller.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import asyncio
|
7 |
+
import dataclasses
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import time
|
12 |
+
from typing import List, Union
|
13 |
+
import threading
|
14 |
+
|
15 |
+
from fastapi import FastAPI, Request
|
16 |
+
from fastapi.responses import StreamingResponse
|
17 |
+
import numpy as np
|
18 |
+
import requests
|
19 |
+
import uvicorn
|
20 |
+
|
21 |
+
from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
22 |
+
from starvector.serve.util import build_logger, server_error_msg
|
23 |
+
|
24 |
+
logger = build_logger("controller", "controller.log")
|
25 |
+
|
26 |
+
class DispatchMethod(Enum):
|
27 |
+
LOTTERY = auto()
|
28 |
+
SHORTEST_QUEUE = auto()
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def from_str(cls, name):
|
32 |
+
if name == "lottery":
|
33 |
+
return cls.LOTTERY
|
34 |
+
elif name == "shortest_queue":
|
35 |
+
return cls.SHORTEST_QUEUE
|
36 |
+
else:
|
37 |
+
raise ValueError(f"Invalid dispatch method")
|
38 |
+
|
39 |
+
|
40 |
+
@dataclasses.dataclass
|
41 |
+
class WorkerInfo:
|
42 |
+
model_names: List[str]
|
43 |
+
speed: int
|
44 |
+
queue_length: int
|
45 |
+
check_heart_beat: bool
|
46 |
+
last_heart_beat: str
|
47 |
+
|
48 |
+
|
49 |
+
def heart_beat_controller(controller):
|
50 |
+
while True:
|
51 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
52 |
+
controller.remove_stable_workers_by_expiration()
|
53 |
+
|
54 |
+
|
55 |
+
class Controller:
|
56 |
+
def __init__(self, dispatch_method: str):
|
57 |
+
# Dict[str -> WorkerInfo]
|
58 |
+
self.worker_info = {}
|
59 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
60 |
+
|
61 |
+
self.heart_beat_thread = threading.Thread(
|
62 |
+
target=heart_beat_controller, args=(self,))
|
63 |
+
self.heart_beat_thread.start()
|
64 |
+
|
65 |
+
logger.info("Init controller")
|
66 |
+
|
67 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
68 |
+
worker_status: dict):
|
69 |
+
if worker_name not in self.worker_info:
|
70 |
+
logger.info(f"Register a new worker: {worker_name}")
|
71 |
+
else:
|
72 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
73 |
+
|
74 |
+
if not worker_status:
|
75 |
+
worker_status = self.get_worker_status(worker_name)
|
76 |
+
if not worker_status:
|
77 |
+
return False
|
78 |
+
|
79 |
+
self.worker_info[worker_name] = WorkerInfo(
|
80 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
81 |
+
check_heart_beat, time.time())
|
82 |
+
|
83 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
84 |
+
return True
|
85 |
+
|
86 |
+
def get_worker_status(self, worker_name: str):
|
87 |
+
try:
|
88 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
89 |
+
except requests.exceptions.RequestException as e:
|
90 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
91 |
+
return None
|
92 |
+
|
93 |
+
if r.status_code != 200:
|
94 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
95 |
+
return None
|
96 |
+
|
97 |
+
return r.json()
|
98 |
+
|
99 |
+
def remove_worker(self, worker_name: str):
|
100 |
+
del self.worker_info[worker_name]
|
101 |
+
|
102 |
+
def refresh_all_workers(self):
|
103 |
+
old_info = dict(self.worker_info)
|
104 |
+
self.worker_info = {}
|
105 |
+
|
106 |
+
for w_name, w_info in old_info.items():
|
107 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
108 |
+
logger.info(f"Remove stale worker: {w_name}")
|
109 |
+
|
110 |
+
def list_models(self):
|
111 |
+
model_names = set()
|
112 |
+
|
113 |
+
for w_name, w_info in self.worker_info.items():
|
114 |
+
model_names.update(w_info.model_names)
|
115 |
+
|
116 |
+
return list(model_names)
|
117 |
+
|
118 |
+
def get_worker_address(self, model_name: str):
|
119 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
120 |
+
worker_names = []
|
121 |
+
worker_speeds = []
|
122 |
+
for w_name, w_info in self.worker_info.items():
|
123 |
+
if model_name in w_info.model_names:
|
124 |
+
worker_names.append(w_name)
|
125 |
+
worker_speeds.append(w_info.speed)
|
126 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
127 |
+
norm = np.sum(worker_speeds)
|
128 |
+
if norm < 1e-4:
|
129 |
+
return ""
|
130 |
+
worker_speeds = worker_speeds / norm
|
131 |
+
if True: # Directly return address
|
132 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
133 |
+
p=worker_speeds)
|
134 |
+
worker_name = worker_names[pt]
|
135 |
+
return worker_name
|
136 |
+
|
137 |
+
# Check status before returning
|
138 |
+
while True:
|
139 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
140 |
+
p=worker_speeds)
|
141 |
+
worker_name = worker_names[pt]
|
142 |
+
|
143 |
+
if self.get_worker_status(worker_name):
|
144 |
+
break
|
145 |
+
else:
|
146 |
+
self.remove_worker(worker_name)
|
147 |
+
worker_speeds[pt] = 0
|
148 |
+
norm = np.sum(worker_speeds)
|
149 |
+
if norm < 1e-4:
|
150 |
+
return ""
|
151 |
+
worker_speeds = worker_speeds / norm
|
152 |
+
continue
|
153 |
+
return worker_name
|
154 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
155 |
+
worker_names = []
|
156 |
+
worker_qlen = []
|
157 |
+
for w_name, w_info in self.worker_info.items():
|
158 |
+
if model_name in w_info.model_names:
|
159 |
+
worker_names.append(w_name)
|
160 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
161 |
+
if len(worker_names) == 0:
|
162 |
+
return ""
|
163 |
+
min_index = np.argmin(worker_qlen)
|
164 |
+
w_name = worker_names[min_index]
|
165 |
+
self.worker_info[w_name].queue_length += 1
|
166 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
167 |
+
return w_name
|
168 |
+
else:
|
169 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
170 |
+
|
171 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
172 |
+
if worker_name not in self.worker_info:
|
173 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
174 |
+
return False
|
175 |
+
|
176 |
+
self.worker_info[worker_name].queue_length = queue_length
|
177 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
178 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
179 |
+
return True
|
180 |
+
|
181 |
+
def remove_stable_workers_by_expiration(self):
|
182 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
183 |
+
to_delete = []
|
184 |
+
for worker_name, w_info in self.worker_info.items():
|
185 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
186 |
+
to_delete.append(worker_name)
|
187 |
+
|
188 |
+
for worker_name in to_delete:
|
189 |
+
self.remove_worker(worker_name)
|
190 |
+
|
191 |
+
def worker_api_generate_stream(self, params):
|
192 |
+
worker_addr = self.get_worker_address(params["model"])
|
193 |
+
if not worker_addr:
|
194 |
+
logger.info(f"no worker: {params['model']}")
|
195 |
+
ret = {
|
196 |
+
"text": server_error_msg,
|
197 |
+
"error_code": 2,
|
198 |
+
}
|
199 |
+
yield json.dumps(ret).encode() + b"\0"
|
200 |
+
|
201 |
+
try:
|
202 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
203 |
+
json=params, stream=True, timeout=5)
|
204 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
205 |
+
if chunk:
|
206 |
+
yield chunk + b"\0"
|
207 |
+
except requests.exceptions.RequestException as e:
|
208 |
+
logger.info(f"worker timeout: {worker_addr}")
|
209 |
+
ret = {
|
210 |
+
"text": server_error_msg,
|
211 |
+
"error_code": 3,
|
212 |
+
}
|
213 |
+
yield json.dumps(ret).encode() + b"\0"
|
214 |
+
|
215 |
+
|
216 |
+
# Let the controller act as a worker to achieve hierarchical
|
217 |
+
# management. This can be used to connect isolated sub networks.
|
218 |
+
def worker_api_get_status(self):
|
219 |
+
model_names = set()
|
220 |
+
speed = 0
|
221 |
+
queue_length = 0
|
222 |
+
|
223 |
+
for w_name in self.worker_info:
|
224 |
+
worker_status = self.get_worker_status(w_name)
|
225 |
+
if worker_status is not None:
|
226 |
+
model_names.update(worker_status["model_names"])
|
227 |
+
speed += worker_status["speed"]
|
228 |
+
queue_length += worker_status["queue_length"]
|
229 |
+
|
230 |
+
return {
|
231 |
+
"model_names": list(model_names),
|
232 |
+
"speed": speed,
|
233 |
+
"queue_length": queue_length,
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
app = FastAPI()
|
238 |
+
|
239 |
+
@app.post("/register_worker")
|
240 |
+
async def register_worker(request: Request):
|
241 |
+
data = await request.json()
|
242 |
+
controller.register_worker(
|
243 |
+
data["worker_name"], data["check_heart_beat"],
|
244 |
+
data.get("worker_status", None))
|
245 |
+
|
246 |
+
@app.post("/refresh_all_workers")
|
247 |
+
async def refresh_all_workers():
|
248 |
+
models = controller.refresh_all_workers()
|
249 |
+
|
250 |
+
|
251 |
+
@app.post("/list_models")
|
252 |
+
async def list_models():
|
253 |
+
models = controller.list_models()
|
254 |
+
return {"models": models}
|
255 |
+
|
256 |
+
|
257 |
+
@app.post("/get_worker_address")
|
258 |
+
async def get_worker_address(request: Request):
|
259 |
+
data = await request.json()
|
260 |
+
addr = controller.get_worker_address(data["model"])
|
261 |
+
return {"address": addr}
|
262 |
+
|
263 |
+
@app.post("/receive_heart_beat")
|
264 |
+
async def receive_heart_beat(request: Request):
|
265 |
+
data = await request.json()
|
266 |
+
exist = controller.receive_heart_beat(
|
267 |
+
data["worker_name"], data["queue_length"])
|
268 |
+
return {"exist": exist}
|
269 |
+
|
270 |
+
|
271 |
+
@app.post("/worker_generate_stream")
|
272 |
+
async def worker_api_generate_stream(request: Request):
|
273 |
+
params = await request.json()
|
274 |
+
generator = controller.worker_api_generate_stream(params)
|
275 |
+
return StreamingResponse(generator)
|
276 |
+
|
277 |
+
|
278 |
+
@app.post("/worker_get_status")
|
279 |
+
async def worker_api_get_status(request: Request):
|
280 |
+
return controller.worker_api_get_status()
|
281 |
+
|
282 |
+
|
283 |
+
if __name__ == "__main__":
|
284 |
+
parser = argparse.ArgumentParser()
|
285 |
+
parser.add_argument("--host", type=str, default="localhost")
|
286 |
+
parser.add_argument("--port", type=int, default=21001)
|
287 |
+
parser.add_argument("--dispatch-method", type=str, choices=[
|
288 |
+
"lottery", "shortest_queue"], default="shortest_queue")
|
289 |
+
args = parser.parse_args()
|
290 |
+
logger.info(f"args: {args}")
|
291 |
+
|
292 |
+
controller = Controller(args.dispatch_method)
|
293 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
starvector/serve/conversation.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from typing import List
|
3 |
+
from PIL import Image
|
4 |
+
import concurrent.futures
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
import cairosvg
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
@dataclasses.dataclass
|
10 |
+
class Conversation:
|
11 |
+
"""A class that keeps all conversation history."""
|
12 |
+
system: str
|
13 |
+
image_prompt: str
|
14 |
+
roles: List[str]
|
15 |
+
messages: List[List[str]]
|
16 |
+
offset: int
|
17 |
+
version: str = "Unknown"
|
18 |
+
stop_sampling: bool = False
|
19 |
+
skip_next: bool = False
|
20 |
+
display_images: bool = False
|
21 |
+
task: str = "Im2SVG"
|
22 |
+
|
23 |
+
def set_task(self, task):
|
24 |
+
self.task = task
|
25 |
+
|
26 |
+
def get_image_prompt(self):
|
27 |
+
return self.image_prompt
|
28 |
+
|
29 |
+
def get_images(self, return_pil=False):
|
30 |
+
images = []
|
31 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
32 |
+
if i % 2 == 0:
|
33 |
+
if type(msg) is tuple:
|
34 |
+
import base64
|
35 |
+
from io import BytesIO
|
36 |
+
from PIL import Image
|
37 |
+
image, image_process_mode = msg
|
38 |
+
if image_process_mode == "Pad":
|
39 |
+
def expand2square(pil_img, background_color=(255, 255, 255)):
|
40 |
+
width, height = pil_img.size
|
41 |
+
if width == height:
|
42 |
+
return pil_img
|
43 |
+
elif width > height:
|
44 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
45 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
46 |
+
return result
|
47 |
+
else:
|
48 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
49 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
50 |
+
return result
|
51 |
+
image = expand2square(image)
|
52 |
+
elif image_process_mode in ["Default", "Crop"]:
|
53 |
+
pass
|
54 |
+
elif image_process_mode == "Resize":
|
55 |
+
image = image.resize((224, 224))
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
58 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
59 |
+
aspect_ratio = max_hw / min_hw
|
60 |
+
max_len, min_len = 800, 400
|
61 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
62 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
63 |
+
W, H = image.size
|
64 |
+
if longest_edge != max(image.size):
|
65 |
+
if H > W:
|
66 |
+
H, W = longest_edge, shortest_edge
|
67 |
+
else:
|
68 |
+
H, W = shortest_edge, longest_edge
|
69 |
+
image = image.resize((W, H))
|
70 |
+
if return_pil:
|
71 |
+
images.append(image)
|
72 |
+
else:
|
73 |
+
buffered = BytesIO()
|
74 |
+
image.save(buffered, format="PNG")
|
75 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
76 |
+
images.append(img_b64_str)
|
77 |
+
return images
|
78 |
+
|
79 |
+
def append_message(self, role, message):
|
80 |
+
self.messages.append([role, message])
|
81 |
+
|
82 |
+
def download_files(self):
|
83 |
+
svg_string = self.messages[-1][-1][:-1]
|
84 |
+
image = self.render_svg(svg_string)
|
85 |
+
svg_out = clean_svg(svg_string)
|
86 |
+
|
87 |
+
return image, svg_out
|
88 |
+
|
89 |
+
def rasterize_svg(self, svg_string, resolution=224, dpi = 128, scale=2):
|
90 |
+
try:
|
91 |
+
svg_raster_bytes = cairosvg.svg2png(
|
92 |
+
bytestring=svg_string,
|
93 |
+
background_color='white',
|
94 |
+
output_width=resolution,
|
95 |
+
output_height=resolution,
|
96 |
+
dpi=dpi,
|
97 |
+
scale=scale)
|
98 |
+
svg_raster = Image.open(BytesIO(svg_raster_bytes))
|
99 |
+
except:
|
100 |
+
try:
|
101 |
+
svg = self.clean_svg(svg_string)
|
102 |
+
svg_raster_bytes = cairosvg.svg2png(
|
103 |
+
bytestring=svg,
|
104 |
+
background_color='white',
|
105 |
+
output_width=resolution,
|
106 |
+
output_height=resolution,
|
107 |
+
dpi=dpi,
|
108 |
+
scale=scale)
|
109 |
+
svg_raster = Image.open(BytesIO(svg_raster_bytes))
|
110 |
+
except:
|
111 |
+
svg_raster = Image.new('RGB', (resolution, resolution), color = 'white')
|
112 |
+
return svg_raster
|
113 |
+
|
114 |
+
def clean_svg(self, svg_text, output_width=None, output_height=None):
|
115 |
+
soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml
|
116 |
+
svg_bs4 = soup.prettify() # Prettify to get a string
|
117 |
+
svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode()
|
118 |
+
svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("<?xml")]) # Remove xml header
|
119 |
+
return svg_clean
|
120 |
+
|
121 |
+
def render_svg(self, svg_string):
|
122 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
123 |
+
future = executor.submit(self.rasterize_svg, svg_string, resolution = 512)
|
124 |
+
try:
|
125 |
+
result = future.result(timeout=0.1) # Specify the timeout duration in seconds
|
126 |
+
except concurrent.futures.TimeoutError:
|
127 |
+
print("Timeout occurred!")
|
128 |
+
result = None
|
129 |
+
return result
|
130 |
+
|
131 |
+
def to_gradio_svg_render(self):
|
132 |
+
svg_string = self.messages[-1][-1][:-1]
|
133 |
+
result = self.render_svg(svg_string)
|
134 |
+
return result
|
135 |
+
|
136 |
+
def to_gradio_svg_code(self):
|
137 |
+
ret = []
|
138 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
139 |
+
if i % 2 == 0:
|
140 |
+
if type(msg) is tuple:
|
141 |
+
import base64
|
142 |
+
from io import BytesIO
|
143 |
+
image, image_process_mode = msg
|
144 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
145 |
+
aspect_ratio = max_hw / min_hw
|
146 |
+
max_len, min_len = 800, 400
|
147 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
148 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
149 |
+
W, H = image.size
|
150 |
+
if H > W:
|
151 |
+
H, W = longest_edge, shortest_edge
|
152 |
+
else:
|
153 |
+
H, W = shortest_edge, longest_edge
|
154 |
+
image = image.resize((W, H))
|
155 |
+
buffered = BytesIO()
|
156 |
+
image.save(buffered, format="JPEG")
|
157 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
158 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
159 |
+
msg = img_str
|
160 |
+
ret.append([msg, None])
|
161 |
+
else:
|
162 |
+
ret.append([msg, None])
|
163 |
+
else:
|
164 |
+
ret[-1][-1] = msg
|
165 |
+
return ret
|
166 |
+
|
167 |
+
def copy(self):
|
168 |
+
return Conversation(
|
169 |
+
system=self.system,
|
170 |
+
image_prompt=self.image_prompt,
|
171 |
+
roles=self.roles,
|
172 |
+
messages=[[x, y] for x, y in self.messages],
|
173 |
+
offset=self.offset,
|
174 |
+
version=self.version
|
175 |
+
|
176 |
+
)
|
177 |
+
def dict(self):
|
178 |
+
if len(self.get_images()) > 0:
|
179 |
+
return {
|
180 |
+
"system": self.system,
|
181 |
+
"image_prompt": self.image_prompt,
|
182 |
+
"roles": self.roles,
|
183 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
184 |
+
"offset": self.offset,
|
185 |
+
}
|
186 |
+
return {
|
187 |
+
"system": self.system,
|
188 |
+
"image_prompt": self.image_prompt,
|
189 |
+
"roles": self.roles,
|
190 |
+
"messages": self.messages,
|
191 |
+
"offset": self.offset,
|
192 |
+
}
|
193 |
+
|
194 |
+
starvector_v1 = Conversation(
|
195 |
+
system="StarVector",
|
196 |
+
# prompt='<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 32 32" version="1.1">',
|
197 |
+
image_prompt='<svg',
|
198 |
+
roles=("Human", "StarVector"),
|
199 |
+
version="v1",
|
200 |
+
messages=(
|
201 |
+
),
|
202 |
+
offset=0,
|
203 |
+
task="Im2SVG",
|
204 |
+
)
|
205 |
+
default_conversation = starvector_v1
|
206 |
+
conv_templates = {
|
207 |
+
"default": default_conversation,
|
208 |
+
}
|
209 |
+
|
210 |
+
if __name__ == "__main__":
|
211 |
+
print(default_conversation.get_image_prompt())
|
starvector/serve/gradio_demo_with_updated_gradio.py
ADDED
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
from starvector.serve.conversation import default_conversation
|
9 |
+
from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
|
10 |
+
from starvector.serve.util import (build_logger, server_error_msg)
|
11 |
+
|
12 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
13 |
+
headers = {"User-Agent": "StarVector Client"}
|
14 |
+
|
15 |
+
no_change_btn = gr.Button()
|
16 |
+
enable_btn = gr.Button(interactive=True)
|
17 |
+
disable_btn = gr.Button(interactive=False)
|
18 |
+
|
19 |
+
priority = {
|
20 |
+
"starvector-1.4b": "aaaaaaa",
|
21 |
+
}
|
22 |
+
|
23 |
+
def get_conv_log_filename():
|
24 |
+
t = datetime.datetime.now()
|
25 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
26 |
+
return name
|
27 |
+
|
28 |
+
def get_model_list():
|
29 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
30 |
+
assert ret.status_code == 200
|
31 |
+
ret = requests.post(args.controller_url + "/list_models")
|
32 |
+
models = ret.json()["models"]
|
33 |
+
models.sort(key=lambda x: priority.get(x, x))
|
34 |
+
logger.info(f"Models: {models}")
|
35 |
+
return models
|
36 |
+
|
37 |
+
get_window_url_params = """
|
38 |
+
function() {
|
39 |
+
const params = new URLSearchParams(window.location.search);
|
40 |
+
url_params = Object.fromEntries(params);
|
41 |
+
console.log(url_params);
|
42 |
+
return url_params;
|
43 |
+
}
|
44 |
+
"""
|
45 |
+
|
46 |
+
def load_demo(url_params, request: gr.Request):
|
47 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
48 |
+
|
49 |
+
dropdown_update = gr.Dropdown(visible=True)
|
50 |
+
if "model" in url_params:
|
51 |
+
model = url_params["model"]
|
52 |
+
if model in models:
|
53 |
+
dropdown_update = gr.Dropdown(value=model, visible=True)
|
54 |
+
|
55 |
+
state = default_conversation.copy()
|
56 |
+
return state, dropdown_update
|
57 |
+
|
58 |
+
|
59 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
60 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
61 |
+
models = get_model_list()
|
62 |
+
state = default_conversation.copy()
|
63 |
+
dropdown_update = gr.Dropdown(
|
64 |
+
choices=models,
|
65 |
+
value=models[0] if len(models) > 0 else ""
|
66 |
+
)
|
67 |
+
return state, dropdown_update
|
68 |
+
|
69 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
70 |
+
with open(get_conv_log_filename(), "a") as fout:
|
71 |
+
data = {
|
72 |
+
"tstamp": round(time.time(), 4),
|
73 |
+
"type": vote_type,
|
74 |
+
"model": model_selector,
|
75 |
+
"state": state.dict(),
|
76 |
+
"ip": request.client.host,
|
77 |
+
}
|
78 |
+
fout.write(json.dumps(data) + "\n")
|
79 |
+
|
80 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
81 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
82 |
+
vote_last_response(state, "upvote", model_selector, request)
|
83 |
+
return ("",) + (disable_btn,) * 3
|
84 |
+
|
85 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
86 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
87 |
+
vote_last_response(state, "downvote", model_selector, request)
|
88 |
+
return ("",) + (disable_btn,) * 3
|
89 |
+
|
90 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
91 |
+
logger.info(f"flag. ip: {request.client.host}")
|
92 |
+
vote_last_response(state, "flag", model_selector, request)
|
93 |
+
return ("",) + (disable_btn,) * 3
|
94 |
+
|
95 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
96 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
97 |
+
state.messages[-1][-1] = None
|
98 |
+
prev_human_msg = state.messages[-2]
|
99 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
100 |
+
prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
|
101 |
+
state.skip_next = False
|
102 |
+
return (state, None, None, None) + (disable_btn,) * 6
|
103 |
+
|
104 |
+
def clear_history(request: gr.Request):
|
105 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
106 |
+
state = default_conversation.copy()
|
107 |
+
return (state, None, None) + (disable_btn,) * 6
|
108 |
+
|
109 |
+
def send_image(state, image, image_process_mode, request: gr.Request):
|
110 |
+
logger.info(f"send_image. ip: {request.client.host}.")
|
111 |
+
state.stop_sampling = False
|
112 |
+
if image is None:
|
113 |
+
state.skip_next = True
|
114 |
+
return (state, None, None, image) + (no_change_btn,) * 6
|
115 |
+
|
116 |
+
if image is not None:
|
117 |
+
text = (image, image_process_mode)
|
118 |
+
state.append_message(state.roles[0], text)
|
119 |
+
state.append_message(state.roles[1], "β")
|
120 |
+
state.skip_next = False
|
121 |
+
msg = state.to_gradio_svg_code()[0][1]
|
122 |
+
return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6
|
123 |
+
|
124 |
+
def stop_sampling(state, image, request: gr.Request):
|
125 |
+
logger.info(f"stop_sampling. ip: {request.client.host}")
|
126 |
+
state.stop_sampling = True
|
127 |
+
return (state, None, None, image) + (disable_btn,) * 6
|
128 |
+
|
129 |
+
def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
|
130 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
131 |
+
start_tstamp = time.time()
|
132 |
+
model_name = model_selector
|
133 |
+
|
134 |
+
if state.skip_next:
|
135 |
+
# This generate call is skipped due to invalid inputs
|
136 |
+
yield (state, None, None) + (no_change_btn,) * 6
|
137 |
+
return
|
138 |
+
|
139 |
+
# Query worker address
|
140 |
+
controller_url = args.controller_url
|
141 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
142 |
+
json={"model": model_name})
|
143 |
+
worker_addr = ret.json()["address"]
|
144 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
145 |
+
|
146 |
+
# No available worker
|
147 |
+
if worker_addr == "":
|
148 |
+
state.messages[-1][-1] = server_error_msg
|
149 |
+
yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
150 |
+
return
|
151 |
+
|
152 |
+
# Construct prompt
|
153 |
+
prompt = state.get_prompt()
|
154 |
+
|
155 |
+
# Make requests
|
156 |
+
pload = {
|
157 |
+
"model": model_name,
|
158 |
+
"prompt": prompt,
|
159 |
+
"num_beams": int(num_beams),
|
160 |
+
"temperature": float(temperature),
|
161 |
+
"len_penalty": float(len_penalty),
|
162 |
+
"top_p": float(top_p),
|
163 |
+
"max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
|
164 |
+
}
|
165 |
+
logger.info(f"==== request ====\n{pload}")
|
166 |
+
|
167 |
+
pload['images'] = state.get_images()
|
168 |
+
|
169 |
+
state.messages[-1][-1] = "β"
|
170 |
+
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn)
|
171 |
+
|
172 |
+
try:
|
173 |
+
# Stream output
|
174 |
+
if state.stop_sampling:
|
175 |
+
state.messages[1][-1] = "β"
|
176 |
+
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
177 |
+
return
|
178 |
+
|
179 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
180 |
+
headers=headers, json=pload, stream=True, timeout=100)
|
181 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
182 |
+
if chunk:
|
183 |
+
data = json.loads(chunk.decode())
|
184 |
+
if data["error_code"] == 0:
|
185 |
+
# output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered
|
186 |
+
output = data["text"].strip()
|
187 |
+
state.messages[-1][-1] = output + "β"
|
188 |
+
st = state.to_gradio_svg_code()
|
189 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)
|
190 |
+
else:
|
191 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
192 |
+
state.messages[-1][-1] = output
|
193 |
+
|
194 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
195 |
+
return
|
196 |
+
time.sleep(0.03)
|
197 |
+
except requests.exceptions.RequestException as e:
|
198 |
+
state.messages[-1][-1] = server_error_msg
|
199 |
+
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
200 |
+
return
|
201 |
+
|
202 |
+
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6
|
203 |
+
|
204 |
+
finish_tstamp = time.time()
|
205 |
+
logger.info(f"{output}")
|
206 |
+
|
207 |
+
with open(get_conv_log_filename(), "a") as fout:
|
208 |
+
data = {
|
209 |
+
"tstamp": round(finish_tstamp, 4),
|
210 |
+
"type": "chat",
|
211 |
+
"model": model_name,
|
212 |
+
"start": round(start_tstamp, 4),
|
213 |
+
"finish": round(finish_tstamp, 4),
|
214 |
+
"svg": state.messages[-1][-1],
|
215 |
+
"ip": request.client.host,
|
216 |
+
}
|
217 |
+
fout.write(json.dumps(data) + "\n")
|
218 |
+
|
219 |
+
title_markdown = ("""
|
220 |
+
# π« StarVector: Generating Scalable Vector Graphics Code from Images and Text
|
221 |
+
[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | π [[StarVector](https://arxiv.org/abs/2312.11556)]
|
222 |
+
""")
|
223 |
+
|
224 |
+
sub_title_markdown = (""" Throw an image and vectorize it! The model expects vector-like images to generate the corresponding svg code.""")
|
225 |
+
tos_markdown = ("""
|
226 |
+
### Terms of use
|
227 |
+
By using this service, users are required to agree to the following terms:
|
228 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
229 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
230 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
231 |
+
""")
|
232 |
+
|
233 |
+
|
234 |
+
learn_more_markdown = ("""
|
235 |
+
### License
|
236 |
+
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
|
237 |
+
""")
|
238 |
+
|
239 |
+
block_css = """
|
240 |
+
|
241 |
+
#buttons button {
|
242 |
+
min-width: min(120px,100%);
|
243 |
+
}
|
244 |
+
|
245 |
+
.gradio-container{
|
246 |
+
max-width: 1200px!important
|
247 |
+
}
|
248 |
+
|
249 |
+
#svg_render{
|
250 |
+
padding: 20px !important;
|
251 |
+
}
|
252 |
+
|
253 |
+
#svg_code{
|
254 |
+
height: 200px !important;
|
255 |
+
overflow: scroll !important;
|
256 |
+
white-space: unset !important;
|
257 |
+
flex-shrink: unset !important;
|
258 |
+
}
|
259 |
+
|
260 |
+
|
261 |
+
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
|
262 |
+
*{transition: width 0.5s ease, flex-grow 0.5s ease}
|
263 |
+
"""
|
264 |
+
|
265 |
+
def build_demo(embed_mode, concurrency_count=10):
|
266 |
+
with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
|
267 |
+
state = gr.State()
|
268 |
+
if not embed_mode:
|
269 |
+
gr.Markdown(title_markdown)
|
270 |
+
gr.Markdown(sub_title_markdown)
|
271 |
+
with gr.Row():
|
272 |
+
with gr.Column(scale=3):
|
273 |
+
with gr.Row(elem_id="model_selector_row"):
|
274 |
+
model_selector = gr.Dropdown(
|
275 |
+
choices=models,
|
276 |
+
value=models[0] if len(models) > 0 else "",
|
277 |
+
interactive=True,
|
278 |
+
show_label=False,
|
279 |
+
container=False)
|
280 |
+
imagebox = gr.Image(type="pil")
|
281 |
+
image_process_mode = gr.Radio(
|
282 |
+
["Resize", "Pad", "Default"],
|
283 |
+
value="Pad",
|
284 |
+
label="Preprocess for non-square image", visible=False)
|
285 |
+
|
286 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
287 |
+
gr.Examples(examples=[
|
288 |
+
[f"{cur_dir}/examples/sample-4.png"],
|
289 |
+
[f"{cur_dir}/examples/sample-7.png"],
|
290 |
+
[f"{cur_dir}/examples/sample-16.png"],
|
291 |
+
[f"{cur_dir}/examples/sample-17.png"],
|
292 |
+
[f"{cur_dir}/examples/sample-18.png"],
|
293 |
+
[f"{cur_dir}/examples/sample-0.png"],
|
294 |
+
[f"{cur_dir}/examples/sample-1.png"],
|
295 |
+
[f"{cur_dir}/examples/sample-6.png"],
|
296 |
+
], inputs=[imagebox])
|
297 |
+
|
298 |
+
with gr.Column(scale=1, min_width=50):
|
299 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
300 |
+
|
301 |
+
with gr.Accordion("Parameters", open=True) as parameter_row:
|
302 |
+
num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
|
303 |
+
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",)
|
304 |
+
len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
|
305 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",)
|
306 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",)
|
307 |
+
|
308 |
+
with gr.Column(scale=8):
|
309 |
+
with gr.Row():
|
310 |
+
svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5)
|
311 |
+
with gr.Row():
|
312 |
+
gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render')
|
313 |
+
with gr.Row(elem_id="buttons") as button_row:
|
314 |
+
upvote_btn = gr.Button(value="π Upvote", interactive=False)
|
315 |
+
downvote_btn = gr.Button(value="π Downvote", interactive=False)
|
316 |
+
flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
|
317 |
+
stop_btn = gr.Button(value="βΉοΈ Stop Generation", interactive=False, visible=False)
|
318 |
+
regenerate_btn = gr.Button(value="π Regenerate", interactive=False, visible=False)
|
319 |
+
clear_btn = gr.Button(value="ποΈ Clear", interactive=False)
|
320 |
+
|
321 |
+
if not embed_mode:
|
322 |
+
gr.Markdown(tos_markdown)
|
323 |
+
gr.Markdown(learn_more_markdown)
|
324 |
+
url_params = gr.JSON(visible=False)
|
325 |
+
|
326 |
+
# Register listeners
|
327 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn]
|
328 |
+
upvote_btn.click(
|
329 |
+
upvote_last_response,
|
330 |
+
[state, model_selector],
|
331 |
+
[upvote_btn, downvote_btn, flag_btn],
|
332 |
+
queue=False
|
333 |
+
)
|
334 |
+
downvote_btn.click(
|
335 |
+
downvote_last_response,
|
336 |
+
[state, model_selector],
|
337 |
+
[upvote_btn, downvote_btn, flag_btn],
|
338 |
+
queue=False
|
339 |
+
)
|
340 |
+
flag_btn.click(
|
341 |
+
flag_last_response,
|
342 |
+
[state, model_selector],
|
343 |
+
[upvote_btn, downvote_btn, flag_btn],
|
344 |
+
queue=False
|
345 |
+
)
|
346 |
+
|
347 |
+
regenerate_btn.click(
|
348 |
+
regenerate,
|
349 |
+
[state, image_process_mode],
|
350 |
+
[state, svg_code, svg_render, imagebox] + btn_list,
|
351 |
+
queue=False
|
352 |
+
).then(
|
353 |
+
http_bot,
|
354 |
+
[state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
|
355 |
+
[state, svg_code, svg_render] + btn_list,
|
356 |
+
concurrency_limit=concurrency_count
|
357 |
+
)
|
358 |
+
|
359 |
+
submit_btn.click(
|
360 |
+
send_image,
|
361 |
+
[state, imagebox, image_process_mode],
|
362 |
+
[state, svg_code, svg_render, imagebox] + btn_list,
|
363 |
+
queue=False
|
364 |
+
).then(
|
365 |
+
http_bot,
|
366 |
+
[state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
|
367 |
+
[state, svg_code, svg_render] + btn_list,
|
368 |
+
concurrency_limit=concurrency_count
|
369 |
+
)
|
370 |
+
|
371 |
+
clear_btn.click(
|
372 |
+
clear_history,
|
373 |
+
None,
|
374 |
+
[state, svg_code, svg_render] + btn_list,
|
375 |
+
queue=False
|
376 |
+
)
|
377 |
+
|
378 |
+
stop_btn.click(
|
379 |
+
stop_sampling,
|
380 |
+
[state, imagebox],
|
381 |
+
[state, imagebox] + btn_list,
|
382 |
+
queue=False
|
383 |
+
).then(
|
384 |
+
clear_history,
|
385 |
+
None,
|
386 |
+
[state, svg_code, svg_render] + btn_list,
|
387 |
+
queue=False
|
388 |
+
)
|
389 |
+
|
390 |
+
if args.model_list_mode == "once":
|
391 |
+
demo.load(
|
392 |
+
load_demo,
|
393 |
+
[url_params],
|
394 |
+
[state, model_selector],
|
395 |
+
_js=get_window_url_params,
|
396 |
+
)
|
397 |
+
elif args.model_list_mode == "reload":
|
398 |
+
demo.load(
|
399 |
+
load_demo_refresh_model_list,
|
400 |
+
None,
|
401 |
+
[state, model_selector],
|
402 |
+
queue=False
|
403 |
+
)
|
404 |
+
else:
|
405 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
406 |
+
|
407 |
+
return demo
|
408 |
+
|
409 |
+
if __name__ == "__main__":
|
410 |
+
parser = argparse.ArgumentParser()
|
411 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
412 |
+
parser.add_argument("--port", type=int)
|
413 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
414 |
+
parser.add_argument("--concurrency-count", type=int, default=15)
|
415 |
+
parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
|
416 |
+
parser.add_argument("--share", action="store_true")
|
417 |
+
parser.add_argument("--moderate", action="store_true")
|
418 |
+
parser.add_argument("--embed", action="store_true")
|
419 |
+
args = parser.parse_args()
|
420 |
+
logger.info(f"args: {args}")
|
421 |
+
|
422 |
+
models = get_model_list()
|
423 |
+
|
424 |
+
logger.info(args)
|
425 |
+
demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
|
426 |
+
demo.queue(
|
427 |
+
api_open=False
|
428 |
+
).launch(
|
429 |
+
server_name=args.host,
|
430 |
+
server_port=args.port,
|
431 |
+
share=args.share
|
432 |
+
)
|
starvector/serve/gradio_web_server.py
ADDED
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
from starvector.serve.conversation import default_conversation
|
9 |
+
from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
|
10 |
+
from starvector.serve.util import (build_logger, server_error_msg)
|
11 |
+
|
12 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
13 |
+
headers = {"User-Agent": "StarVector Client"}
|
14 |
+
|
15 |
+
no_change_btn = gr.Button.update()
|
16 |
+
enable_btn = gr.Button.update(interactive=True)
|
17 |
+
disable_btn = gr.Button.update(interactive=False)
|
18 |
+
|
19 |
+
priority = {
|
20 |
+
"starvector-1b-im2svg": "aaaaaaa",
|
21 |
+
}
|
22 |
+
|
23 |
+
def get_conv_log_filename():
|
24 |
+
t = datetime.datetime.now()
|
25 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
26 |
+
return name
|
27 |
+
|
28 |
+
def get_model_list():
|
29 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
30 |
+
assert ret.status_code == 200
|
31 |
+
ret = requests.post(args.controller_url + "/list_models")
|
32 |
+
models = ret.json()["models"]
|
33 |
+
models.sort(key=lambda x: priority.get(x, x))
|
34 |
+
logger.info(f"Models: {models}")
|
35 |
+
return models
|
36 |
+
|
37 |
+
def load_demo(url_params, request: gr.Request):
|
38 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
39 |
+
|
40 |
+
dropdown_update = gr.Dropdown.update(visible=True)
|
41 |
+
if "model" in url_params:
|
42 |
+
model = url_params["model"]
|
43 |
+
if model in models:
|
44 |
+
dropdown_update = gr.Dropdown.update(
|
45 |
+
value=model, visible=True)
|
46 |
+
|
47 |
+
state = default_conversation.copy()
|
48 |
+
return state, dropdown_update
|
49 |
+
|
50 |
+
mapping_model_task = {
|
51 |
+
'Image2SVG': 'im2svg',
|
52 |
+
'Text2SVG': 'text2svg'
|
53 |
+
}
|
54 |
+
|
55 |
+
def get_models_dropdown_from_task(task):
|
56 |
+
models = get_model_list()
|
57 |
+
models = [model for model in models if mapping_model_task[task] in model]
|
58 |
+
dropdown_update = gr.Dropdown.update(
|
59 |
+
choices=models,
|
60 |
+
value=models[0] if len(models) > 0 else ""
|
61 |
+
)
|
62 |
+
return dropdown_update
|
63 |
+
|
64 |
+
|
65 |
+
def load_demo_refresh_model_list(task, request: gr.Request):
|
66 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
67 |
+
dropdown_update = get_models_dropdown_from_task(task)
|
68 |
+
state = default_conversation.copy()
|
69 |
+
return state, dropdown_update
|
70 |
+
|
71 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
72 |
+
with open(get_conv_log_filename(), "a") as fout:
|
73 |
+
data = {
|
74 |
+
"tstamp": round(time.time(), 4),
|
75 |
+
"type": vote_type,
|
76 |
+
"model": model_selector,
|
77 |
+
"state": state.dict(),
|
78 |
+
"ip": request.client.host,
|
79 |
+
}
|
80 |
+
fout.write(json.dumps(data) + "\n")
|
81 |
+
|
82 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
83 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
84 |
+
vote_last_response(state, "upvote", model_selector, request)
|
85 |
+
return ("",) + (disable_btn,) * 7
|
86 |
+
|
87 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
88 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
89 |
+
vote_last_response(state, "downvote", model_selector, request)
|
90 |
+
return ("",) + (disable_btn,) * 7
|
91 |
+
|
92 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
93 |
+
logger.info(f"flag. ip: {request.client.host}")
|
94 |
+
vote_last_response(state, "flag", model_selector, request)
|
95 |
+
return ("",) + (disable_btn,) * 7
|
96 |
+
|
97 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
98 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
99 |
+
state.messages[-1][-1] = None
|
100 |
+
prev_human_msg = state.messages[-2]
|
101 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
102 |
+
prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
|
103 |
+
state.skip_next = False
|
104 |
+
return (state, None, None, None) + (disable_btn,) * 7
|
105 |
+
|
106 |
+
def clear_history(request: gr.Request):
|
107 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
108 |
+
state = default_conversation.copy()
|
109 |
+
return (state, None, None) + (disable_btn,) * 7
|
110 |
+
|
111 |
+
def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request):
|
112 |
+
logger.info(f"send_data. ip: {request.client.host}.")
|
113 |
+
if task == 'Image2SVG':
|
114 |
+
if image is None:
|
115 |
+
state.skip_next = True
|
116 |
+
return (state, None, None, image) + (no_change_btn,) * 7
|
117 |
+
|
118 |
+
if image is not None:
|
119 |
+
image_message = (image, image_process_mode)
|
120 |
+
state.append_message(state.roles[0], image_message)
|
121 |
+
state.append_message(state.roles[1], "β")
|
122 |
+
state.skip_next = False
|
123 |
+
msg = state.to_gradio_svg_code()[0][1]
|
124 |
+
return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
|
125 |
+
else:
|
126 |
+
if text_caption is None:
|
127 |
+
state.skip_next = True
|
128 |
+
return (state, None, None, image) + (no_change_btn,) * 7
|
129 |
+
|
130 |
+
state.append_message(state.roles[0], text_caption)
|
131 |
+
state.append_message(state.roles[1], "β")
|
132 |
+
state.skip_next = False
|
133 |
+
msg = state.to_gradio_svg_code()[0][1]
|
134 |
+
return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
|
135 |
+
|
136 |
+
def download_files(state, request: gr.Request):
|
137 |
+
logger.info(f"download_files. ip: {request.client.host}")
|
138 |
+
svg_str, image = state.download_files()
|
139 |
+
|
140 |
+
# TODO: Figure out how to download the SVG in the users browser, idk how to do it now
|
141 |
+
|
142 |
+
def update_task(task):
|
143 |
+
dropdown_update = get_models_dropdown_from_task(task)
|
144 |
+
|
145 |
+
if task == "Text2SVG":
|
146 |
+
return 1.0, 0.9, 0.95, dropdown_update
|
147 |
+
else:
|
148 |
+
return 0.6, 0.9, 0.95, dropdown_update
|
149 |
+
|
150 |
+
|
151 |
+
def stop_sampling(state, image, request: gr.Request):
|
152 |
+
logger.info(f"stop_sampling. ip: {request.client.host}")
|
153 |
+
state.stop_sampling = True
|
154 |
+
return (state, None, None, image) + (disable_btn,) * 7
|
155 |
+
|
156 |
+
def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
|
157 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
158 |
+
start_tstamp = time.time()
|
159 |
+
model_name = model_selector
|
160 |
+
|
161 |
+
if state.skip_next:
|
162 |
+
# This generate call is skipped due to invalid inputs
|
163 |
+
yield (state, None, None) + (no_change_btn,) * 7
|
164 |
+
return
|
165 |
+
|
166 |
+
# Query worker address
|
167 |
+
controller_url = args.controller_url
|
168 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
169 |
+
json={"model": model_name})
|
170 |
+
worker_addr = ret.json()["address"]
|
171 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
172 |
+
|
173 |
+
# No available worker
|
174 |
+
if worker_addr == "":
|
175 |
+
state.messages[-1][-1] = server_error_msg
|
176 |
+
yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
177 |
+
return
|
178 |
+
|
179 |
+
# Construct prompt
|
180 |
+
if task_selector == "Image2SVG":
|
181 |
+
prompt = state.get_image_prompt()
|
182 |
+
else:
|
183 |
+
prompt = text_caption
|
184 |
+
|
185 |
+
# Make requests
|
186 |
+
pload = {
|
187 |
+
"model": model_name,
|
188 |
+
"prompt": prompt,
|
189 |
+
"num_beams": int(num_beams),
|
190 |
+
"temperature": float(temperature),
|
191 |
+
"len_penalty": float(len_penalty),
|
192 |
+
"top_p": float(top_p),
|
193 |
+
"max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
|
194 |
+
}
|
195 |
+
logger.info(f"==== request ====\n{pload}")
|
196 |
+
|
197 |
+
pload['images'] = state.get_images()
|
198 |
+
|
199 |
+
state.messages[-1][-1] = "β"
|
200 |
+
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
201 |
+
|
202 |
+
try:
|
203 |
+
# Stream output
|
204 |
+
if state.stop_sampling:
|
205 |
+
state.messages[1][-1] = "β"
|
206 |
+
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn)
|
207 |
+
return
|
208 |
+
|
209 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
210 |
+
headers=headers, json=pload, stream=True, timeout=10)
|
211 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
212 |
+
if chunk:
|
213 |
+
data = json.loads(chunk.decode())
|
214 |
+
if data["error_code"] == 0:
|
215 |
+
# output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered
|
216 |
+
output = data["text"].strip()
|
217 |
+
state.messages[-1][-1] = output + "β"
|
218 |
+
st = state.to_gradio_svg_code()
|
219 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
|
220 |
+
else:
|
221 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
222 |
+
state.messages[-1][-1] = output
|
223 |
+
|
224 |
+
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
225 |
+
return
|
226 |
+
time.sleep(0.03)
|
227 |
+
except requests.exceptions.RequestException as e:
|
228 |
+
state.messages[-1][-1] = server_error_msg
|
229 |
+
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
230 |
+
return
|
231 |
+
|
232 |
+
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7
|
233 |
+
|
234 |
+
finish_tstamp = time.time()
|
235 |
+
logger.info(f"{output}")
|
236 |
+
|
237 |
+
with open(get_conv_log_filename(), "a") as fout:
|
238 |
+
data = {
|
239 |
+
"tstamp": round(finish_tstamp, 4),
|
240 |
+
"type": "chat",
|
241 |
+
"model": model_name,
|
242 |
+
"start": round(start_tstamp, 4),
|
243 |
+
"finish": round(finish_tstamp, 4),
|
244 |
+
"svg": state.messages[-1][-1],
|
245 |
+
"ip": request.client.host,
|
246 |
+
}
|
247 |
+
fout.write(json.dumps(data) + "\n")
|
248 |
+
|
249 |
+
title_markdown = ("""
|
250 |
+
# π« StarVector: Generating Scalable Vector Graphics Code from Images and Text
|
251 |
+
|
252 |
+
[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | π [[StarVector](https://arxiv.org/abs/2312.11556)]""")
|
253 |
+
|
254 |
+
sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \
|
255 |
+
**Note**: The current model works on vector-like images like icons and or vector-like designs.""")
|
256 |
+
tos_markdown = ("""
|
257 |
+
### Terms of use
|
258 |
+
By using this service, users are required to agree to the following terms:
|
259 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
260 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
261 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
262 |
+
""")
|
263 |
+
|
264 |
+
learn_more_markdown = ("""
|
265 |
+
### License
|
266 |
+
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
|
267 |
+
""")
|
268 |
+
|
269 |
+
block_css = """
|
270 |
+
|
271 |
+
#buttons button {
|
272 |
+
min-width: min(120px,100%);
|
273 |
+
}
|
274 |
+
|
275 |
+
.gradio-container{
|
276 |
+
max-width: 1200px!important
|
277 |
+
}
|
278 |
+
|
279 |
+
.ΝΌ1 .cm-content {
|
280 |
+
white-space: unset !important;
|
281 |
+
flex-shrink: unset !important;
|
282 |
+
}
|
283 |
+
|
284 |
+
.ΝΌ2p .cm-scroller {
|
285 |
+
max-height: 200px;
|
286 |
+
overflow: scroll;
|
287 |
+
}
|
288 |
+
|
289 |
+
#svg_render{
|
290 |
+
padding: 20px !important;
|
291 |
+
}
|
292 |
+
|
293 |
+
#submit_btn{
|
294 |
+
max-height: 40px;
|
295 |
+
}
|
296 |
+
|
297 |
+
.selector{
|
298 |
+
max-height: 100px;
|
299 |
+
}
|
300 |
+
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
|
301 |
+
*{transition: width 0.5s ease, flex-grow 0.5s ease}
|
302 |
+
"""
|
303 |
+
def build_demo(embed_mode):
|
304 |
+
svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300)
|
305 |
+
svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5)
|
306 |
+
|
307 |
+
with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
|
308 |
+
state = gr.State()
|
309 |
+
if not embed_mode:
|
310 |
+
gr.Markdown(title_markdown)
|
311 |
+
gr.Markdown(sub_title_markdown)
|
312 |
+
with gr.Row():
|
313 |
+
with gr.Column(scale=4):
|
314 |
+
task_selector = gr.Dropdown(
|
315 |
+
choices=["Image2SVG", "Text2SVG"],
|
316 |
+
value="Image2SVG",
|
317 |
+
label="Task",
|
318 |
+
interactive=True,
|
319 |
+
show_label=True,
|
320 |
+
container=True,
|
321 |
+
elem_id="task_selector",
|
322 |
+
elem_classes=["selector"],
|
323 |
+
)
|
324 |
+
model_selector = gr.Dropdown(
|
325 |
+
choices=models,
|
326 |
+
value=models[0] if len(models) > 0 else "",
|
327 |
+
label="Model",
|
328 |
+
interactive=True,
|
329 |
+
show_label=True,
|
330 |
+
container=True,
|
331 |
+
elem_classes=["selector"],
|
332 |
+
)
|
333 |
+
|
334 |
+
imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox")
|
335 |
+
image_process_mode = gr.Radio(
|
336 |
+
["Resize", "Pad", "Default"],
|
337 |
+
value="Pad",
|
338 |
+
label="Preprocess for non-square image", visible=False)
|
339 |
+
|
340 |
+
# Text input
|
341 |
+
text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption")
|
342 |
+
|
343 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
344 |
+
gr.Examples(examples=[
|
345 |
+
[f"{cur_dir}/examples/sample-4.png"],
|
346 |
+
[f"{cur_dir}/examples/sample-7.png"],
|
347 |
+
[f"{cur_dir}/examples/sample-16.png"],
|
348 |
+
[f"{cur_dir}/examples/sample-17.png"],
|
349 |
+
[f"{cur_dir}/examples/sample-18.png"],
|
350 |
+
[f"{cur_dir}/examples/sample-0.png"],
|
351 |
+
[f"{cur_dir}/examples/sample-1.png"],
|
352 |
+
[f"{cur_dir}/examples/sample-6.png"],
|
353 |
+
], inputs=[imagebox], elem_id="examples")
|
354 |
+
|
355 |
+
submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True)
|
356 |
+
|
357 |
+
with gr.Accordion("Parameters", open=False):
|
358 |
+
num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
|
359 |
+
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",)
|
360 |
+
len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
|
361 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
|
362 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=8192, step=64, interactive=True, label="Max output tokens",)
|
363 |
+
|
364 |
+
with gr.Column(scale=9):
|
365 |
+
with gr.Row():
|
366 |
+
svg_code.render()
|
367 |
+
with gr.Row():
|
368 |
+
svg_render.render()
|
369 |
+
|
370 |
+
with gr.Row(elem_id="buttons") as button_row:
|
371 |
+
upvote_btn = gr.Button(value="π Upvote", interactive=False)
|
372 |
+
downvote_btn = gr.Button(value="π Downvote", interactive=False)
|
373 |
+
flag_btn = gr.Button(value="β οΈ Flag", interactive=False)
|
374 |
+
stop_btn = gr.Button(value="βΉοΈ Stop Generation", interactive=False, visible=False)
|
375 |
+
regenerate_btn = gr.Button(value="π Regenerate", interactive=False, visible=False)
|
376 |
+
clear_btn = gr.Button(value="ποΈ Clear", interactive=False)
|
377 |
+
download_btn = gr.Button(value="Download SVG", interactive=False, visible=False)
|
378 |
+
|
379 |
+
if not embed_mode:
|
380 |
+
gr.Markdown(tos_markdown)
|
381 |
+
gr.Markdown(learn_more_markdown)
|
382 |
+
url_params = gr.JSON(visible=False)
|
383 |
+
|
384 |
+
# Register listeners
|
385 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn]
|
386 |
+
upvote_btn.click(
|
387 |
+
upvote_last_response,
|
388 |
+
[state, model_selector],
|
389 |
+
[upvote_btn, downvote_btn, flag_btn],
|
390 |
+
queue=False
|
391 |
+
)
|
392 |
+
downvote_btn.click(
|
393 |
+
downvote_last_response,
|
394 |
+
[state, model_selector],
|
395 |
+
[upvote_btn, downvote_btn, flag_btn],
|
396 |
+
queue=False
|
397 |
+
)
|
398 |
+
flag_btn.click(
|
399 |
+
flag_last_response,
|
400 |
+
[state, model_selector],
|
401 |
+
[upvote_btn, downvote_btn, flag_btn],
|
402 |
+
queue=False
|
403 |
+
)
|
404 |
+
|
405 |
+
regenerate_btn.click(
|
406 |
+
regenerate,
|
407 |
+
[state, image_process_mode],
|
408 |
+
[state, svg_code, svg_render, imagebox] + btn_list,
|
409 |
+
queue=False
|
410 |
+
).then(
|
411 |
+
http_bot,
|
412 |
+
[state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
|
413 |
+
[state, svg_code, svg_render] + btn_list)
|
414 |
+
|
415 |
+
submit_btn.click(
|
416 |
+
send_data,
|
417 |
+
[state, imagebox, image_process_mode, text_caption, task_selector],
|
418 |
+
[state, svg_code, svg_render, imagebox] + btn_list,
|
419 |
+
queue=False
|
420 |
+
).then(
|
421 |
+
http_bot,
|
422 |
+
[state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
|
423 |
+
[state, svg_code, svg_render] + btn_list
|
424 |
+
)
|
425 |
+
|
426 |
+
clear_btn.click(
|
427 |
+
clear_history,
|
428 |
+
None,
|
429 |
+
[state, svg_code, svg_render] + btn_list,
|
430 |
+
queue=False
|
431 |
+
)
|
432 |
+
|
433 |
+
stop_btn.click(
|
434 |
+
stop_sampling,
|
435 |
+
[state, imagebox],
|
436 |
+
[state, imagebox] + btn_list,
|
437 |
+
queue=False
|
438 |
+
).then(
|
439 |
+
clear_history,
|
440 |
+
None,
|
441 |
+
[state, svg_code, svg_render] + btn_list,
|
442 |
+
queue=False
|
443 |
+
)
|
444 |
+
|
445 |
+
download_btn.click(
|
446 |
+
download_files,
|
447 |
+
[state],
|
448 |
+
None,
|
449 |
+
queue=False
|
450 |
+
)
|
451 |
+
task_selector.change(
|
452 |
+
update_task,
|
453 |
+
inputs=[task_selector],
|
454 |
+
outputs=[len_penalty, temperature, top_p, model_selector],
|
455 |
+
queue=False,
|
456 |
+
_js="""
|
457 |
+
function(task) {
|
458 |
+
var imageBoxElement = document.getElementById("imagebox");
|
459 |
+
var textCaptionElement = document.getElementById("text_caption");
|
460 |
+
var examplesElement = document.getElementById("examples");
|
461 |
+
if (task === "Text2SVG") {
|
462 |
+
imageBoxElement.style.display = "none";
|
463 |
+
textCaptionElement.style.display = "block";
|
464 |
+
examplesElement.style.display = "none";
|
465 |
+
} else if (task === "Image2SVG") {
|
466 |
+
imageBoxElement.style.display = "block";
|
467 |
+
textCaptionElement.style.display = "none";
|
468 |
+
examplesElement.style.display = "block";
|
469 |
+
}
|
470 |
+
return task;
|
471 |
+
}
|
472 |
+
"""
|
473 |
+
)
|
474 |
+
|
475 |
+
if args.model_list_mode == "once":
|
476 |
+
demo.load(
|
477 |
+
load_demo,
|
478 |
+
[url_params, task_selector],
|
479 |
+
[state, model_selector],
|
480 |
+
_js="""
|
481 |
+
function() {
|
482 |
+
const params = new URLSearchParams(window.location.search);
|
483 |
+
url_params = Object.fromEntries(params);
|
484 |
+
console.log(url_params);
|
485 |
+
return url_params;
|
486 |
+
|
487 |
+
}
|
488 |
+
""",
|
489 |
+
queue=False
|
490 |
+
)
|
491 |
+
elif args.model_list_mode == "reload":
|
492 |
+
demo.load(
|
493 |
+
load_demo_refresh_model_list,
|
494 |
+
[task_selector],
|
495 |
+
[state, model_selector],
|
496 |
+
_js="""
|
497 |
+
function(task) {
|
498 |
+
var textCaptionElement = document.getElementById("text_caption");
|
499 |
+
var autoScrollBottom = true;
|
500 |
+
textCaptionElement.style.display = "none";
|
501 |
+
function updateScroll(){
|
502 |
+
if (autoScrollBottom) {
|
503 |
+
var element = document.getElementsByClassName("cm-scroller")[0];
|
504 |
+
element.scrollTop = element.scrollHeight;
|
505 |
+
}
|
506 |
+
}
|
507 |
+
function handleScroll() {
|
508 |
+
var element = document.getElementsByClassName("cm-scroller")[0];
|
509 |
+
//if (element.scrollHeight - element.scrollTop === element.clientHeight) {
|
510 |
+
if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) {
|
511 |
+
// User has scrolled to the bottom, enable auto-scrolling
|
512 |
+
autoScrollBottom = true;
|
513 |
+
console.log("bottom");
|
514 |
+
} else {
|
515 |
+
console.log("not bottom");
|
516 |
+
// User has scrolled away from the bottom, disable auto-scrolling
|
517 |
+
autoScrollBottom = false;
|
518 |
+
}
|
519 |
+
}
|
520 |
+
setInterval(updateScroll,500);
|
521 |
+
var element = document.getElementsByClassName("cm-scroller")[0];
|
522 |
+
element.addEventListener("scroll", handleScroll);
|
523 |
+
|
524 |
+
return task;
|
525 |
+
}
|
526 |
+
|
527 |
+
""",
|
528 |
+
queue=False,
|
529 |
+
)
|
530 |
+
|
531 |
+
else:
|
532 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
533 |
+
|
534 |
+
return demo
|
535 |
+
|
536 |
+
if __name__ == "__main__":
|
537 |
+
|
538 |
+
parser = argparse.ArgumentParser()
|
539 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
540 |
+
parser.add_argument("--port", type=int)
|
541 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
542 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
543 |
+
parser.add_argument("--model-list-mode", type=str, default="once",
|
544 |
+
choices=["once", "reload"])
|
545 |
+
parser.add_argument("--share", action="store_true")
|
546 |
+
parser.add_argument("--moderate", action="store_true")
|
547 |
+
parser.add_argument("--embed", action="store_true")
|
548 |
+
args = parser.parse_args()
|
549 |
+
logger.info(f"args: {args}")
|
550 |
+
|
551 |
+
models = get_model_list()
|
552 |
+
|
553 |
+
logger.info(args)
|
554 |
+
demo = build_demo(args.embed)
|
555 |
+
demo.queue(
|
556 |
+
concurrency_count=args.concurrency_count,
|
557 |
+
api_open=False
|
558 |
+
).launch(
|
559 |
+
server_name=args.host,
|
560 |
+
server_port=args.port,
|
561 |
+
share=args.share
|
562 |
+
)
|