Jinglong Xiong committed on
Commit
6642f4e
·
1 Parent(s): bb5a422

add models

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +2 -0
  2. dl.py +203 -0
  3. gen_image.py +2 -2
  4. ml.py +246 -0
  5. naive.py +246 -0
  6. requirements.txt +6 -0
  7. starter.ipynb +0 -333
  8. starvector/__init__.py +0 -0
  9. starvector/adapter.py +53 -0
  10. starvector/clip_model.py +191 -0
  11. starvector/data/augmentation.py +250 -0
  12. starvector/data/base.py +71 -0
  13. starvector/data/dataset.py +42 -0
  14. starvector/data/emojisvg.py +27 -0
  15. starvector/data/figrsvg.py +27 -0
  16. starvector/data/fontsvg.py +28 -0
  17. starvector/data/iconsvg.py +38 -0
  18. starvector/data/stacksvg.py +59 -0
  19. starvector/data/util.py +389 -0
  20. starvector/image_encoder.py +119 -0
  21. starvector/metrics/base_metric.py +51 -0
  22. starvector/metrics/compute_LPIPS.py +56 -0
  23. starvector/metrics/compute_SSIM.py +35 -0
  24. starvector/metrics/compute_clip_score.py +55 -0
  25. starvector/metrics/compute_dino_score.py +55 -0
  26. starvector/metrics/compute_fid.py +145 -0
  27. starvector/metrics/compute_l2.py +37 -0
  28. starvector/metrics/count_token_length.py +54 -0
  29. starvector/metrics/inception.py +341 -0
  30. starvector/metrics/metrics.py +127 -0
  31. starvector/metrics/util.py +20 -0
  32. starvector/model/adapters/adapter.py +53 -0
  33. starvector/model/builder.py +49 -0
  34. starvector/model/gpt_bigcode/__init__.py +65 -0
  35. starvector/model/gpt_bigcode/configuration_gpt_bigcode.py +143 -0
  36. starvector/model/gpt_bigcode/modeling_gpt_bigcode.py +1502 -0
  37. starvector/model/image_encoder/clip_model.py +191 -0
  38. starvector/model/image_encoder/image_encoder.py +120 -0
  39. starvector/model/llm/starcoder.py +51 -0
  40. starvector/model/llm/starcoder2.py +61 -0
  41. starvector/model/models/starvector_base.py +339 -0
  42. starvector/model/models/starvector_v1.py +22 -0
  43. starvector/model/models/starvector_v2.py +63 -0
  44. starvector/model/starvector_arch.py +194 -0
  45. starvector/serve/__init__.py +0 -0
  46. starvector/serve/constants.py +16 -0
  47. starvector/serve/controller.py +293 -0
  48. starvector/serve/conversation.py +211 -0
  49. starvector/serve/gradio_demo_with_updated_gradio.py +432 -0
  50. starvector/serve/gradio_web_server.py +562 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
+unsloth_compiled_cache/
+*.ipynb
 star-vector/
 SVGDreamer/
 *.parquet
dl.py ADDED
@@ -0,0 +1,203 @@
1
+ import logging
2
+ import re
3
+ import cairosvg
4
+ import torch
5
+ from transformers import AutoModelForCausalLM
6
+ from lxml import etree
7
+ import kagglehub
8
+ from gen_image import ImageGenerator
9
+ from starvector.data.util import process_and_rasterize_svg
10
+
11
+ svg_constraints = kagglehub.package_import('metric/svg-constraints')
12
+
13
+ class DLModel:
14
+ def __init__(self, model_id="starvector/starvector-8b-im2svg", device="cuda"):
15
+ """
16
+ Initialize the SVG generation pipeline using StarVector.
17
+
18
+ Args:
19
+ model_id (str): The model identifier for the StarVector model.
20
+ device (str): The device to run the model on, either "cuda" or "cpu".
21
+ """
22
+ self.image_generator = ImageGenerator(model_id="stabilityai/stable-diffusion-2-1-base", device=device)
23
+ self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
24
+ self.constraints = svg_constraints.SVGConstraints()
25
+ self.timeout_seconds = 90
26
+
27
+ # Load StarVector model
28
+ self.device = device
29
+ self.starvector = AutoModelForCausalLM.from_pretrained(
30
+ model_id,
31
+ torch_dtype=torch.float16,
32
+ trust_remote_code=True
33
+ )
34
+ self.processor = self.starvector.model.processor
35
+ self.starvector.to(device)
36
+ self.starvector.eval()
37
+
38
+ def predict(self, description):
39
+ """
40
+ Generate an SVG from a text description.
41
+
42
+ Args:
43
+ description (str): The text description to generate an image from.
44
+
45
+ Returns:
46
+ str: The generated SVG content.
47
+ """
48
+ try:
49
+ # Step 1: Generate image using diffusion model
50
+ images = self.image_generator.generate(description)
51
+ image = images[0]
52
+
53
+ # Save the generated image
54
+ image_path = "diff_image.png"
55
+ image.save(image_path)
56
+ logging.info(f"Intermediate image saved to {image_path}")
57
+
58
+ # Step 2: Convert image to SVG using StarVector
59
+ processed_image = self.processor(image, return_tensors="pt")['pixel_values'].to(self.device)
60
+ if not processed_image.shape[0] == 1:
61
+ processed_image = processed_image.squeeze(0)
62
+
63
+ batch = {"image": processed_image}
64
+ with torch.no_grad():
65
+ raw_svg = self.starvector.generate_im2svg(batch, max_length=4000)[0]
66
+ raw_svg, _ = process_and_rasterize_svg(raw_svg)
67
+
68
+ if 'viewBox' not in raw_svg:
69
+ raw_svg = raw_svg.replace('<svg', f'<svg viewBox="0 0 384 384"')
70
+
71
+ # Step 3: Enforce constraints
72
+ svg_content = self.enforce_constraints(raw_svg)
73
+
74
+ return svg_content
75
+ except Exception as e:
76
+ logging.error(f"Error generating SVG: {e}")
77
+ return self.default_svg
78
+
79
+ def enforce_constraints(self, svg_string: str) -> str:
80
+ """Enforces constraints on an SVG string, removing disallowed elements
81
+ and attributes.
82
+
83
+ Parameters
84
+ ----------
85
+ svg_string : str
86
+ The SVG string to process.
87
+
88
+ Returns
89
+ -------
90
+ str
91
+ The processed SVG string, or the default SVG if constraints
92
+ cannot be satisfied.
93
+ """
94
+ logging.info('Sanitizing SVG...')
95
+
96
+ try:
97
+ # Remove XML declaration if it exists
98
+ svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip()
99
+
100
+ parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
101
+ root = etree.fromstring(svg_string, parser=parser)
102
+ except etree.ParseError as e:
103
+ logging.error('SVG Parse Error: %s. Returning default SVG.', e)
104
+ logging.error('SVG string: %s', svg_string)
105
+ return self.default_svg
106
+
107
+ elements_to_remove = []
108
+ for element in root.iter():
109
+ tag_name = etree.QName(element.tag).localname
110
+
111
+ # Remove disallowed elements
112
+ if tag_name not in self.constraints.allowed_elements:
113
+ elements_to_remove.append(element)
114
+ continue # Skip attribute checks for removed elements
115
+
116
+ # Remove disallowed attributes
117
+ attrs_to_remove = []
118
+ for attr in element.attrib:
119
+ attr_name = etree.QName(attr).localname
120
+ if (
121
+ attr_name
122
+ not in self.constraints.allowed_elements[tag_name]
123
+ and attr_name
124
+ not in self.constraints.allowed_elements['common']
125
+ ):
126
+ attrs_to_remove.append(attr)
127
+
128
+ for attr in attrs_to_remove:
129
+ logging.debug(
130
+ 'Attribute "%s" for element "%s" not allowed. Removing.',
131
+ attr,
132
+ tag_name,
133
+ )
134
+ del element.attrib[attr]
135
+
136
+ # Check and remove invalid href attributes
137
+ for attr, value in element.attrib.items():
138
+ if etree.QName(attr).localname == 'href' and not value.startswith('#'):
139
+ logging.debug(
140
+ 'Removing invalid href attribute in element "%s".', tag_name
141
+ )
142
+ del element.attrib[attr]
143
+
144
+ # Validate path elements to help ensure SVG conversion
145
+ if tag_name == 'path':
146
+ d_attribute = element.get('d')
147
+ if not d_attribute:
148
+ logging.warning('Path element is missing "d" attribute. Removing path.')
149
+ elements_to_remove.append(element)
150
+ continue # Skip further checks for this removed element
151
+ # Use regex to validate 'd' attribute format
152
+ path_regex = re.compile(
153
+ r'^' # Start of string
154
+ r'(?:' # Non-capturing group for each command + numbers block
155
+ r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters)
156
+ r'\s*' # Optional whitespace after command
157
+ r'(?:' # Non-capturing group for optional numbers
158
+ r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number
159
+ r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s)
160
+ r')?' # Numbers are optional (e.g. for Z command)
161
+ r'\s*' # Optional whitespace after numbers/command block
162
+ r')+' # One or more command blocks
163
+ r'\s*' # Optional trailing whitespace
164
+ r'$' # End of string
165
+ )
166
+ if not path_regex.match(d_attribute):
167
+ logging.warning(
168
+ 'Path element has malformed "d" attribute format. Removing path.'
169
+ )
170
+ elements_to_remove.append(element)
171
+ continue
172
+ logging.debug('Path element "d" attribute validated (regex check).')
173
+
174
+ # Remove elements marked for removal
175
+ for element in elements_to_remove:
176
+ if element.getparent() is not None:
177
+ element.getparent().remove(element)
178
+ logging.debug('Removed element: %s', element.tag)
179
+
180
+ try:
181
+ cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False)
182
+ return cleaned_svg_string
183
+ except ValueError as e:
184
+ logging.error(
185
+ 'SVG could not be sanitized to meet constraints: %s', e
186
+ )
187
+ return self.default_svg
188
+
189
+ # Example usage
190
+ if __name__ == "__main__":
191
+ model = DLModel()
192
+ svg = model.predict("a purple forest at dusk")
193
+ # Convert SVG to PNG
194
+ try:
195
+ # Create a PNG in memory
196
+ png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
197
+
198
+ # Save the PNG to a file
199
+ with open("output.png", "wb") as f:
200
+ f.write(png_data)
201
+ print("SVG saved as output.png")
202
+ except Exception as e:
203
+ print(f"Error converting SVG to PNG: {e}")
gen_image.py CHANGED
@@ -35,7 +35,7 @@ class ImageGenerator:
             num_images (int, optional): Number of images to generate.
 
         Returns:
-            PIL.Image.Image: The generated image.
+            list[PIL.Image.Image]: The generated images.
         """
         prompt = f"{prompt}, {self.positive_prompt}"
         if negative_prompt is None:
@@ -51,7 +51,7 @@
         for i, image in enumerate(images):
            image.save(f".cache/{output_path.replace('.png', f'_{i}.png')}")
 
-        return image
+        return images
 
 # Example usage
 if __name__ == "__main__":
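With this change, generate() returns the full list of images rather than only the last one, which is what dl.py and ml.py rely on when they take images[0]. A minimal caller sketch under that assumption (the output filename is illustrative):

from gen_image import ImageGenerator

# Assumes a CUDA device and the same SD 2.1 base checkpoint used elsewhere in this commit.
generator = ImageGenerator(model_id="stabilityai/stable-diffusion-2-1-base", device="cuda")
images = generator.generate("a purple forest at dusk")  # now list[PIL.Image.Image]
images[0].save("forest.png")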
ml.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ import tempfile
3
+ import logging
4
+ import re
5
+ import subprocess
6
+ import cairosvg
7
+ from lxml import etree
8
+ import kagglehub
9
+ from gen_image import ImageGenerator
10
+ import vtracer
11
+
12
+ svg_constraints = kagglehub.package_import('metric/svg-constraints')
13
+
14
+ class MLModel:
15
+ def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"):
16
+ """
17
+ Initialize the SVG generation pipeline.
18
+
19
+ Args:
20
+ model_id (str): The model identifier for the stable diffusion model.
21
+ device (str): The device to run the model on, either "cuda" or "cpu".
22
+ """
23
+ self.image_generator = ImageGenerator(model_id=model_id, device=device)
24
+ self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
25
+ self.constraints = svg_constraints.SVGConstraints()
26
+ self.timeout_seconds = 90
27
+
28
+ def predict(self, description, simplify=True, color_precision=6,
29
+ gradient_step=10, filter_speckle=4, path_precision=8):
30
+ """
31
+ Generate an SVG from a text description.
32
+
33
+ Args:
34
+ description (str): The text description to generate an image from.
35
+ simplify (bool): Whether to simplify the SVG paths.
36
+ color_precision (int): Color quantization precision.
37
+ gradient_step (int): Gradient step for color quantization (not used by vtracer).
38
+ filter_speckle (int): Filter speckle size.
39
+ path_precision (int): Path fitting precision.
40
+
41
+ Returns:
42
+ str: The generated SVG content.
43
+ """
44
+ try:
45
+ # Step 1: Generate image using diffusion model
46
+ images = self.image_generator.generate(description)
47
+ image = images[0]
48
+
49
+ # Step 2: Save image to a temporary file
50
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
51
+ temp_img_path = temp_img.name
52
+ image.save(temp_img_path)
53
+
54
+ # Step 3: Convert image to SVG using vtracer
55
+ with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg:
56
+ temp_svg_path = temp_svg.name
57
+
58
+ # Process the image with vtracer using parameters directly
59
+ vtracer.convert_image_to_svg_py(
60
+ temp_img_path,
61
+ temp_svg_path,
62
+ colormode='color',
63
+ hierarchical='stacked' if simplify else 'cutout',
64
+ mode='spline',
65
+ filter_speckle=filter_speckle,
66
+ color_precision=color_precision,
67
+ path_precision=path_precision,
68
+ corner_threshold=60,
69
+ length_threshold=4.0,
70
+ max_iterations=10,
71
+ splice_threshold=45
72
+ )
73
+
74
+ # Step 4: Read the generated SVG
75
+ with open(temp_svg_path, 'r') as f:
76
+ svg_content = f.read()
77
+
78
+ # Clean up temporary files
79
+ os.unlink(temp_img_path)
80
+ os.unlink(temp_svg_path)
81
+
82
+ # Step 5: Enforce constraints
83
+ svg_content = self.enforce_constraints(svg_content)
84
+
85
+ return svg_content
86
+ except Exception as e:
87
+ logging.error(f"Error generating SVG: {e}")
88
+ return self.default_svg
89
+
90
+ def enforce_constraints(self, svg_string: str) -> str:
91
+ """Enforces constraints on an SVG string, removing disallowed elements
92
+ and attributes.
93
+
94
+ Parameters
95
+ ----------
96
+ svg_string : str
97
+ The SVG string to process.
98
+
99
+ Returns
100
+ -------
101
+ str
102
+ The processed SVG string, or the default SVG if constraints
103
+ cannot be satisfied.
104
+ """
105
+ logging.info('Sanitizing SVG...')
106
+
107
+ try:
108
+ # Remove XML declaration if it exists
109
+ svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip()
110
+
111
+ parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
112
+ root = etree.fromstring(svg_string, parser=parser)
113
+ except etree.ParseError as e:
114
+ logging.error('SVG Parse Error: %s. Returning default SVG.', e)
115
+ logging.error('SVG string: %s', svg_string)
116
+ return self.default_svg
117
+
118
+ elements_to_remove = []
119
+ for element in root.iter():
120
+ tag_name = etree.QName(element.tag).localname
121
+
122
+ # Remove disallowed elements
123
+ if tag_name not in self.constraints.allowed_elements:
124
+ elements_to_remove.append(element)
125
+ continue # Skip attribute checks for removed elements
126
+
127
+ # Remove disallowed attributes
128
+ attrs_to_remove = []
129
+ for attr in element.attrib:
130
+ attr_name = etree.QName(attr).localname
131
+ if (
132
+ attr_name
133
+ not in self.constraints.allowed_elements[tag_name]
134
+ and attr_name
135
+ not in self.constraints.allowed_elements['common']
136
+ ):
137
+ attrs_to_remove.append(attr)
138
+
139
+ for attr in attrs_to_remove:
140
+ logging.debug(
141
+ 'Attribute "%s" for element "%s" not allowed. Removing.',
142
+ attr,
143
+ tag_name,
144
+ )
145
+ del element.attrib[attr]
146
+
147
+ # Check and remove invalid href attributes
148
+ for attr, value in element.attrib.items():
149
+ if etree.QName(attr).localname == 'href' and not value.startswith('#'):
150
+ logging.debug(
151
+ 'Removing invalid href attribute in element "%s".', tag_name
152
+ )
153
+ del element.attrib[attr]
154
+
155
+ # Validate path elements to help ensure SVG conversion
156
+ if tag_name == 'path':
157
+ d_attribute = element.get('d')
158
+ if not d_attribute:
159
+ logging.warning('Path element is missing "d" attribute. Removing path.')
160
+ elements_to_remove.append(element)
161
+ continue # Skip further checks for this removed element
162
+ # Use regex to validate 'd' attribute format
163
+ path_regex = re.compile(
164
+ r'^' # Start of string
165
+ r'(?:' # Non-capturing group for each command + numbers block
166
+ r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters)
167
+ r'\s*' # Optional whitespace after command
168
+ r'(?:' # Non-capturing group for optional numbers
169
+ r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number
170
+ r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s)
171
+ r')?' # Numbers are optional (e.g. for Z command)
172
+ r'\s*' # Optional whitespace after numbers/command block
173
+ r')+' # One or more command blocks
174
+ r'\s*' # Optional trailing whitespace
175
+ r'$' # End of string
176
+ )
177
+ if not path_regex.match(d_attribute):
178
+ logging.warning(
179
+ 'Path element has malformed "d" attribute format. Removing path.'
180
+ )
181
+ elements_to_remove.append(element)
182
+ continue
183
+ logging.debug('Path element "d" attribute validated (regex check).')
184
+
185
+ # Remove elements marked for removal
186
+ for element in elements_to_remove:
187
+ if element.getparent() is not None:
188
+ element.getparent().remove(element)
189
+ logging.debug('Removed element: %s', element.tag)
190
+
191
+ try:
192
+ cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False)
193
+ return cleaned_svg_string
194
+ except ValueError as e:
195
+ logging.error(
196
+ 'SVG could not be sanitized to meet constraints: %s', e
197
+ )
198
+ return self.default_svg
199
+
200
+ def optimize_svg(self, svg_content):
201
+ """
202
+ Optimize the SVG content using SVGO.
203
+
204
+ Args:
205
+ svg_content (str): The SVG content to optimize.
206
+
207
+ Returns:
208
+ str: The optimized SVG content.
209
+ """
210
+ try:
211
+ with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg:
212
+ temp_svg_path = temp_svg.name
213
+ temp_svg.write(svg_content.encode('utf-8'))
214
+
215
+ with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_out:
216
+ temp_out_path = temp_out.name
217
+
218
+ subprocess.run(["svgo", temp_svg_path, "-o", temp_out_path], check=True)
219
+
220
+ with open(temp_out_path, 'r') as f:
221
+ optimized_svg = f.read()
222
+
223
+ os.unlink(temp_svg_path)
224
+ os.unlink(temp_out_path)
225
+
226
+ return optimized_svg
227
+ except (FileNotFoundError, subprocess.CalledProcessError):
228
+ print("Warning: SVGO not found or failed. Returning unoptimized SVG.")
229
+ return svg_content
230
+
231
+
232
+ # Example usage
233
+ if __name__ == "__main__":
234
+ model = MLModel()
235
+ svg = model.predict("a purple forest at dusk")
236
+ # Convert SVG to PNG
237
+ try:
238
+ # Create a PNG in memory
239
+ png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
240
+
241
+ # Save the PNG to a file
242
+ with open("output.png", "wb") as f:
243
+ f.write(png_data)
244
+ print("SVG saved as output.png")
245
+ except Exception as e:
246
+ print(f"Error converting SVG to PNG: {e}")
naive.py ADDED
@@ -0,0 +1,246 @@
1
+ import concurrent.futures
2
+ import io
3
+ import logging
4
+ import re
5
+
6
+ import cairosvg
7
+ import kagglehub
8
+ import torch
9
+ from lxml import etree
10
+ from unsloth import FastLanguageModel
11
+ from unsloth.chat_templates import get_chat_template
12
+
13
+ svg_constraints = kagglehub.package_import('metric/svg-constraints')
14
+
15
+ class NaiveModel:
16
+ def __init__(self, model_name="unsloth/phi-4-unsloth-bnb-4bit", max_seq_length=2048, device="cuda"):
17
+ self.device = device
18
+ self.max_seq_length = max_seq_length
19
+ self.load_in_4bit = True
20
+
21
+ # Load the Unsloth Phi-4 model
22
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
23
+ model_name=model_name,
24
+ max_seq_length=self.max_seq_length,
25
+ load_in_4bit=self.load_in_4bit
26
+ )
27
+
28
+ # Set up chat template
29
+ self.tokenizer = get_chat_template(
30
+ self.tokenizer,
31
+ chat_template="phi-4",
32
+ )
33
+
34
+ # Prepare model for inference
35
+ FastLanguageModel.for_inference(self.model)
36
+
37
+ self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
38
+ <constraints>
39
+ * **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
40
+ * **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
41
+ </constraints>
42
+
43
+ Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.
44
+
45
+ <description>"A red circle with a blue square inside"</description>
46
+ ```svg
47
+ <svg viewBox="0 0 256 256" width="256" height="256">
48
+ <circle cx="50" cy="50" r="40" fill="red"/>
49
+ <rect x="30" y="30" width="40" height="40" fill="blue"/>
50
+ </svg>
51
+ ```
52
+
53
+ <description>"{}"</description>
54
+ """
55
+ self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
56
+ self.constraints = svg_constraints.SVGConstraints()
57
+ self.timeout_seconds = 90
58
+
59
+ def predict(self, description: str, max_new_tokens=512) -> str:
60
+ def generate_svg():
61
+ try:
62
+ # Format the prompt
63
+ prompt = self.prompt_template.format(description)
64
+
65
+ # Create messages in the format expected by the chat template
66
+ messages = [
67
+ {"role": "user", "content": prompt},
68
+ ]
69
+
70
+ # Tokenize the messages
71
+ inputs = self.tokenizer.apply_chat_template(
72
+ messages,
73
+ tokenize=True,
74
+ add_generation_prompt=True,
75
+ return_tensors="pt",
76
+ ).to(self.device)
77
+
78
+ # Generate the output
79
+ outputs = self.model.generate(
80
+ input_ids=inputs,
81
+ max_new_tokens=max_new_tokens,
82
+ use_cache=True,
83
+ temperature=1.0,
84
+ min_p=0.1,
85
+ do_sample=True,
86
+ )
87
+
88
+ # Decode the output
89
+ output_decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
90
+
91
+ # Extract only the generated text (skip the prompt)
92
+ generated_text = output_decoded.split("```svg")[-1].split("```")[0] if "```svg" in output_decoded else ""
93
+
94
+ logging.debug('Output decoded from model: %s', output_decoded)
95
+
96
+ matches = re.findall(r"<svg.*?</svg>", output_decoded, re.DOTALL | re.IGNORECASE)
97
+ if matches:
98
+ svg = matches[-1]
99
+ else:
100
+ return self.default_svg
101
+
102
+ logging.debug('Unprocessed SVG: %s', svg)
103
+ svg = self.enforce_constraints(svg)
104
+ logging.debug('Processed SVG: %s', svg)
105
+
106
+ # Ensure the generated code can be converted by cairosvg
107
+ cairosvg.svg2png(bytestring=svg.encode('utf-8'))
108
+ return svg
109
+ except Exception as e:
110
+ logging.error('Exception during SVG generation: %s', e)
111
+ return self.default_svg
112
+
113
+ # Execute SVG generation in a new thread to enforce time constraints
114
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
115
+ future = executor.submit(generate_svg)
116
+ try:
117
+ return future.result(timeout=self.timeout_seconds)
118
+ except concurrent.futures.TimeoutError:
119
+ logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
120
+ return self.default_svg
121
+ except Exception as e:
122
+ logging.error(f"An unexpected error occurred: {e}")
123
+ return self.default_svg
124
+
125
+ def enforce_constraints(self, svg_string: str) -> str:
126
+ """Enforces constraints on an SVG string, removing disallowed elements
127
+ and attributes.
128
+
129
+ Parameters
130
+ ----------
131
+ svg_string : str
132
+ The SVG string to process.
133
+
134
+ Returns
135
+ -------
136
+ str
137
+ The processed SVG string, or the default SVG if constraints
138
+ cannot be satisfied.
139
+ """
140
+ logging.info('Sanitizing SVG...')
141
+
142
+ try:
143
+ parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
144
+ root = etree.fromstring(svg_string, parser=parser)
145
+ except etree.ParseError as e:
146
+ logging.error('SVG Parse Error: %s. Returning default SVG.', e)
147
+ logging.error('SVG string: %s', svg_string)
148
+ return self.default_svg
149
+
150
+ elements_to_remove = []
151
+ for element in root.iter():
152
+ tag_name = etree.QName(element.tag).localname
153
+
154
+ # Remove disallowed elements
155
+ if tag_name not in self.constraints.allowed_elements:
156
+ elements_to_remove.append(element)
157
+ continue # Skip attribute checks for removed elements
158
+
159
+ # Remove disallowed attributes
160
+ attrs_to_remove = []
161
+ for attr in element.attrib:
162
+ attr_name = etree.QName(attr).localname
163
+ if (
164
+ attr_name
165
+ not in self.constraints.allowed_elements[tag_name]
166
+ and attr_name
167
+ not in self.constraints.allowed_elements['common']
168
+ ):
169
+ attrs_to_remove.append(attr)
170
+
171
+ for attr in attrs_to_remove:
172
+ logging.debug(
173
+ 'Attribute "%s" for element "%s" not allowed. Removing.',
174
+ attr,
175
+ tag_name,
176
+ )
177
+ del element.attrib[attr]
178
+
179
+ # Check and remove invalid href attributes
180
+ for attr, value in element.attrib.items():
181
+ if etree.QName(attr).localname == 'href' and not value.startswith('#'):
182
+ logging.debug(
183
+ 'Removing invalid href attribute in element "%s".', tag_name
184
+ )
185
+ del element.attrib[attr]
186
+
187
+ # Validate path elements to help ensure SVG conversion
188
+ if tag_name == 'path':
189
+ d_attribute = element.get('d')
190
+ if not d_attribute:
191
+ logging.warning('Path element is missing "d" attribute. Removing path.')
192
+ elements_to_remove.append(element)
193
+ continue # Skip further checks for this removed element
194
+ # Use regex to validate 'd' attribute format
195
+ path_regex = re.compile(
196
+ r'^' # Start of string
197
+ r'(?:' # Non-capturing group for each command + numbers block
198
+ r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters)
199
+ r'\s*' # Optional whitespace after command
200
+ r'(?:' # Non-capturing group for optional numbers
201
+ r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number
202
+ r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s)
203
+ r')?' # Numbers are optional (e.g. for Z command)
204
+ r'\s*' # Optional whitespace after numbers/command block
205
+ r')+' # One or more command blocks
206
+ r'\s*' # Optional trailing whitespace
207
+ r'$' # End of string
208
+ )
209
+ if not path_regex.match(d_attribute):
210
+ logging.warning(
211
+ 'Path element has malformed "d" attribute format. Removing path.'
212
+ )
213
+ elements_to_remove.append(element)
214
+ continue
215
+ logging.debug('Path element "d" attribute validated (regex check).')
216
+
217
+ # Remove elements marked for removal
218
+ for element in elements_to_remove:
219
+ if element.getparent() is not None:
220
+ element.getparent().remove(element)
221
+ logging.debug('Removed element: %s', element.tag)
222
+
223
+ try:
224
+ cleaned_svg_string = etree.tostring(root, encoding='unicode')
225
+ return cleaned_svg_string
226
+ except ValueError as e:
227
+ logging.error(
228
+ 'SVG could not be sanitized to meet constraints: %s', e
229
+ )
230
+ return self.default_svg
231
+
232
+
233
+ if __name__ == "__main__":
234
+ model = NaiveModel()
235
+ svg = model.predict("a purple forest at dusk")
236
+ # Convert SVG to PNG
237
+ try:
238
+ # Create a PNG in memory
239
+ png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
240
+
241
+ # Save the PNG to a file
242
+ with open("output.png", "wb") as f:
243
+ f.write(png_data)
244
+ print("SVG saved as output.png")
245
+ except Exception as e:
246
+ print(f"Error converting SVG to PNG: {e}")
requirements.txt CHANGED
@@ -19,6 +19,12 @@ dotenv
 diffusers
 safetensors
 xformers
+unsloth
+tf-keras
+vtracer
+deepspeed
+torch==2.5.1
+torchvision==0.20.1
 
 # pip install 'tensorflow[and-cuda]'
 # pip install git+https://github.com/openai/CLIP.git
starter.ipynb DELETED
@@ -1,333 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 2,
6
- "metadata": {},
7
- "outputs": [
8
- {
9
- "name": "stderr",
10
- "output_type": "stream",
11
- "text": [
12
- "/home/user/miniconda3/envs/dwl/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
- " from .autonotebook import tqdm as notebook_tqdm\n"
14
- ]
15
- },
16
- {
17
- "data": {
18
- "text/html": [
19
- "<div><style>\n",
20
- ".dataframe > thead > tr,\n",
21
- ".dataframe > tbody > tr {\n",
22
- " text-align: right;\n",
23
- " white-space: pre-wrap;\n",
24
- "}\n",
25
- "</style>\n",
26
- "<small>shape: (5, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>id</th><th>description</th></tr><tr><td>str</td><td>str</td></tr></thead><tbody><tr><td>&quot;02d892&quot;</td><td>&quot;a purple forest at dusk&quot;</td></tr><tr><td>&quot;0dcd2e&quot;</td><td>&quot;gray wool coat with a faux fur…</td></tr><tr><td>&quot;1e9ac1&quot;</td><td>&quot;a lighthouse overlooking the o…</td></tr><tr><td>&quot;2b25db&quot;</td><td>&quot;burgundy corduroy pants with p…</td></tr><tr><td>&quot;4e6a54&quot;</td><td>&quot;orange corduroy overalls&quot;</td></tr></tbody></table></div>"
27
- ],
28
- "text/plain": [
29
- "shape: (5, 2)\n",
30
- "┌────────┬─────────────────────────────────┐\n",
31
- "│ id     ┆ description                     │\n",
32
- "│ ---    ┆ ---                             │\n",
33
- "│ str    ┆ str                             │\n",
34
- "╞════════╪═════════════════════════════════╡\n",
35
- "│ 02d892 ┆ a purple forest at dusk         │\n",
36
- "│ 0dcd2e ┆ gray wool coat with a faux fur… │\n",
37
- "│ 1e9ac1 ┆ a lighthouse overlooking the o… │\n",
38
- "│ 2b25db ┆ burgundy corduroy pants with p… │\n",
39
- "│ 4e6a54 ┆ orange corduroy overalls        │\n",
40
- "└────────┴─────────────────────────────────┘"
41
- ]
42
- },
43
- "execution_count": 2,
44
- "metadata": {},
45
- "output_type": "execute_result"
46
- }
47
- ],
48
- "source": [
49
- "# We can load and explore the competition's train set to get a feel for the data.\n",
50
- "# We're not going to export this cell as it's not needed for our exported inferenceable model.\n",
51
- "\n",
52
- "import kagglehub\n",
53
- "import polars as pl\n",
54
- "\n",
55
- "train_path = kagglehub.competition_download('drawing-with-llms', 'train.csv')\n",
56
- "train = pl.read_csv(train_path)\n",
57
- "\n",
58
- "train.head()"
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 3,
64
- "metadata": {},
65
- "outputs": [],
66
- "source": [
67
- "class Model:\n",
68
- " def __init__(self):\n",
69
- " '''Optional constructor, performs any setup logic, model instantiation, etc.'''\n",
70
- " pass\n",
71
- " \n",
72
- " def predict(self, prompt: str) -> str:\n",
73
- " '''Generates SVG which produces an image described by the prompt.\n",
74
- "\n",
75
- " Args:\n",
76
- " prompt (str): A prompt describing an image\n",
77
- " Returns:\n",
78
- " String of valid SVG code.\n",
79
- " '''\n",
80
- " # Renders a simple circle regardless of input\n",
81
- " return '<svg width=\"100\" height=\"100\" viewBox=\"0 0 100 100\"><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\" /></svg>'"
82
- ]
83
- },
84
- {
85
- "cell_type": "code",
86
- "execution_count": 4,
87
- "metadata": {},
88
- "outputs": [
89
- {
90
- "name": "stdout",
91
- "output_type": "stream",
92
- "text": [
93
- "<svg width=\"100\" height=\"100\" viewBox=\"0 0 100 100\"><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\" /></svg>\n"
94
- ]
95
- },
96
- {
97
- "data": {
98
- "image/svg+xml": [
99
- "<svg width=\"100\" height=\"100\" viewBox=\"0 0 100 100\"><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\"/></svg>"
100
- ],
101
- "text/plain": [
102
- "<IPython.core.display.SVG object>"
103
- ]
104
- },
105
- "metadata": {},
106
- "output_type": "display_data"
107
- }
108
- ],
109
- "source": [
110
- "from IPython.display import SVG\n",
111
- "\n",
112
- "model = Model()\n",
113
- "svg = model.predict('a goose winning a gold medal')\n",
114
- "\n",
115
- "print(svg)\n",
116
- "display(SVG(svg))"
117
- ]
118
- },
119
- {
120
- "cell_type": "code",
121
- "execution_count": 6,
122
- "metadata": {},
123
- "outputs": [
124
- {
125
- "data": {
126
- "text/plain": [
127
- "['RN50',\n",
128
- " 'RN101',\n",
129
- " 'RN50x4',\n",
130
- " 'RN50x16',\n",
131
- " 'RN50x64',\n",
132
- " 'ViT-B/32',\n",
133
- " 'ViT-B/16',\n",
134
- " 'ViT-L/14',\n",
135
- " 'ViT-L/14@336px']"
136
- ]
137
- },
138
- "execution_count": 6,
139
- "metadata": {},
140
- "output_type": "execute_result"
141
- }
142
- ],
143
- "source": [
144
- "import clip\n",
145
- "clip.available_models()"
146
- ]
147
- },
148
- {
149
- "cell_type": "code",
150
- "execution_count": 7,
151
- "metadata": {},
152
- "outputs": [
153
- {
154
- "name": "stderr",
155
- "output_type": "stream",
156
- "text": [
157
- "2025-04-20 13:55:34.589770: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
158
- "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
159
- "E0000 00:00:1745171734.600777 13214 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
160
- "E0000 00:00:1745171734.603957 13214 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
161
- "W0000 00:00:1745171734.615566 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
162
- "W0000 00:00:1745171734.615584 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
163
- "W0000 00:00:1745171734.615585 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
164
- "W0000 00:00:1745171734.615586 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
165
- "2025-04-20 13:55:34.618659: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
166
- "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
167
- "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n",
168
- "Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00, 4.68s/it]\n"
169
- ]
170
- }
171
- ],
172
- "source": [
173
- "import pandas as pd\n",
174
- "import importlib\n",
175
- "metric = importlib.import_module('metric')\n",
176
- "importlib.reload(metric)\n",
177
- "\n",
178
- "vqa_evaluator = metric.VQAEvaluator()\n",
179
- "aesthetic_evaluator = metric.AestheticEvaluator()"
180
- ]
181
- },
182
- {
183
- "cell_type": "code",
184
- "execution_count": 11,
185
- "metadata": {},
186
- "outputs": [
187
- {
188
- "name": "stdout",
189
- "output_type": "stream",
190
- "text": [
191
- "VQA Score: 0.9996758976500401\n",
192
- "Aesthetic Score: 0.5749330520629883\n",
193
- "Final Fidelity Score: 0.8709845773271212\n"
194
- ]
195
- }
196
- ],
197
- "source": [
198
- "# score gpt4o generated images\n",
199
- "import ast\n",
200
- "import numpy as np\n",
201
- "from PIL import Image\n",
202
- "\n",
203
- "# Load the first sample from descriptions.csv\n",
204
- "descriptions_df = pd.read_csv('data/descriptions.csv')\n",
205
- "first_description = descriptions_df.iloc[1]\n",
206
- "\n",
207
- "eval_df = pd.read_csv('data/eval.csv')\n",
208
- "first_eval = eval_df.iloc[1]\n",
209
- "\n",
210
- "# Load the image\n",
211
- "image_path = 'data/gray_coat.png' # Assuming the image is saved with this name\n",
212
- "image = Image.open(image_path)\n",
213
- "\n",
214
- "# Prepare the inputs for scoring - need to parse the string representations\n",
215
- "questions = ast.literal_eval(first_eval['question'])\n",
216
- "choices = ast.literal_eval(first_eval['choices'])\n",
217
- "answers = ast.literal_eval(first_eval['answer'])\n",
218
- "\n",
219
- "# Calculate VQA score - don't wrap in additional lists\n",
220
- "vqa_score = vqa_evaluator.score(questions, choices, answers, image)\n",
221
- "\n",
222
- "# Calculate aesthetic score\n",
223
- "aesthetic_score = aesthetic_evaluator.score(image)\n",
224
- "\n",
225
- "# Apply image processing as done in the metric.score function\n",
226
- "image_processor = metric.ImageProcessor(image=image, seed=0).apply()\n",
227
- "processed_image = image_processor.image.copy()\n",
228
- "\n",
229
- "# Calculate final fidelity score\n",
230
- "instance_score = metric.harmonic_mean(vqa_score, aesthetic_score, beta=0.5)\n",
231
- "\n",
232
- "print(f\"VQA Score: {vqa_score}\")\n",
233
- "print(f\"Aesthetic Score: {aesthetic_score}\")\n",
234
- "print(f\"Final Fidelity Score: {instance_score}\")"
235
- ]
236
- },
237
- {
238
- "cell_type": "code",
239
- "execution_count": 13,
240
- "metadata": {},
241
- "outputs": [
242
- {
243
- "name": "stdout",
244
- "output_type": "stream",
245
- "text": [
246
- "No duplicate IDs found in data/descriptions.csv\n",
247
- "Sorted rows by ID\n",
248
- "Fixed and sorted CSV saved to data/descriptions.csv\n",
249
- "No duplicate IDs found in data/eval.csv\n",
250
- "Sorted data/eval.csv by ID\n"
251
- ]
252
- }
253
- ],
254
- "source": [
255
- "# Fix duplicate IDs in descriptions.csv and order rows by id\n",
256
- "def fix_duplicate_ids(csv_path):\n",
257
- " \"\"\"\n",
258
- " Fix duplicate IDs in a CSV file by assigning new unique IDs to duplicates.\n",
259
- " Then order rows by ID.\n",
260
- " \"\"\"\n",
261
- " # Read the CSV file\n",
262
- " df = pd.read_csv(csv_path)\n",
263
- " \n",
264
- " # Check for duplicate IDs\n",
265
- " duplicate_mask = df['id'].duplicated(keep='first')\n",
266
- " duplicate_count = duplicate_mask.sum()\n",
267
- " \n",
268
- " if duplicate_count > 0:\n",
269
- " print(f\"Found {duplicate_count} duplicate IDs in {csv_path}\")\n",
270
- " \n",
271
- " # Get the maximum ID value\n",
272
- " max_id = df['id'].max()\n",
273
- " \n",
274
- " # Assign new IDs to duplicates\n",
275
- " new_ids = list(range(max_id + 1, max_id + 1 + duplicate_count))\n",
276
- " df.loc[duplicate_mask, 'id'] = new_ids\n",
277
- " \n",
278
- " print(f\"Assigned new IDs to duplicates\")\n",
279
- " else:\n",
280
- " print(f\"No duplicate IDs found in {csv_path}\")\n",
281
- " \n",
282
- " # Sort the dataframe by ID\n",
283
- " df = df.sort_values(by='id')\n",
284
- " print(f\"Sorted rows by ID\")\n",
285
- " \n",
286
- " # Save the fixed and sorted CSV\n",
287
- " df.to_csv(csv_path, index=False)\n",
288
- " print(f\"Fixed and sorted CSV saved to {csv_path}\")\n",
289
- " \n",
290
- " # Return the fixed dataframe\n",
291
- " return df\n",
292
- "\n",
293
- "# Fix descriptions.csv\n",
294
- "fixed_descriptions_df = fix_duplicate_ids('data/descriptions.csv')\n",
295
- "\n",
296
- "# Fix eval.csv if needed\n",
297
- "# First check if eval.csv has the same issue\n",
298
- "eval_df = pd.read_csv('data/eval.csv')\n",
299
- "duplicate_eval_ids = eval_df['id'].duplicated(keep='first').sum()\n",
300
- "\n",
301
- "if duplicate_eval_ids > 0:\n",
302
- " fixed_eval_df = fix_duplicate_ids('data/eval.csv')\n",
303
- "else:\n",
304
- " print(\"No duplicate IDs found in data/eval.csv\")\n",
305
- " # Still sort by ID even if no duplicates\n",
306
- " eval_df = eval_df.sort_values(by='id')\n",
307
- " eval_df.to_csv('data/eval.csv', index=False)\n",
308
- " print(\"Sorted data/eval.csv by ID\")\n"
309
- ]
310
- }
311
- ],
312
- "metadata": {
313
- "kernelspec": {
314
- "display_name": "dwl",
315
- "language": "python",
316
- "name": "python3"
317
- },
318
- "language_info": {
319
- "codemirror_mode": {
320
- "name": "ipython",
321
- "version": 3
322
- },
323
- "file_extension": ".py",
324
- "mimetype": "text/x-python",
325
- "name": "python",
326
- "nbconvert_exporter": "python",
327
- "pygments_lexer": "ipython3",
328
- "version": "3.11.11"
329
- }
330
- },
331
- "nbformat": 4,
332
- "nbformat_minor": 2
333
- }
starvector/__init__.py ADDED
File without changes
starvector/adapter.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch.nn as nn
2
+ import torch.nn.init as init
3
+ import torch
4
+
5
+ class Swish(nn.Module):
6
+ def __init__(self):
7
+ super(Swish, self).__init__()
8
+
9
+ def forward(self, x):
10
+ return x * torch.sigmoid(x)
11
+
12
+ class Adapter(nn.Module):
13
+ def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1):
14
+ super().__init__()
15
+ self.query_length = query_length
16
+ self.dropout_prob = dropout_prob
17
+ self.adapter_norm = adapter_norm
18
+
19
+ self.dropout = nn.Dropout(p=self.dropout_prob)
20
+
21
+ self.c_fc = nn.Linear(input_size, input_size*2)
22
+ self.act = Swish()
23
+ self.c_proj = nn.Linear(input_size*2, output_size)
24
+
25
+ if adapter_norm == "layer_norm":
26
+ self.norm = nn.LayerNorm([self.query_length, output_size])
27
+ elif adapter_norm == "batch_norm":
28
+ self.norm = nn.BatchNorm1d(self.query_length)
29
+
30
+ self.init_type = init_type.lower()
31
+ self._initialize_weights()
32
+
33
+ def forward(self, hidden_states):
34
+ hidden_states = self.dropout(hidden_states)
35
+ hidden_states = self.c_fc(hidden_states)
36
+ hidden_states = self.act(hidden_states)
37
+ hidden_states = self.c_proj(hidden_states)
38
+ hidden_states = self.norm(hidden_states)
39
+ return hidden_states
40
+
41
+ def _initialize_weights(self):
42
+ for m in self.modules():
43
+ if isinstance(m, nn.Linear):
44
+ if self.init_type == "glorot":
45
+ init.xavier_uniform_(m.weight)
46
+ if m.bias is not None:
47
+ init.constant_(m.bias, 0)
48
+ elif self.init_type == "normal":
49
+ init.normal_(m.weight, mean=0, std=0.01)
50
+ if m.bias is not None:
51
+ init.constant_(m.bias, 0)
52
+ else:
53
+ raise ValueError("Invalid initialization type specified.")
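A quick shape check for the Adapter defined above; the dimensions are illustrative rather than taken from a StarVector config:

import torch
from starvector.adapter import Adapter

# Hypothetical sizes: 32 query tokens of width 1280 projected to a 2048-wide LLM embedding.
adapter = Adapter(input_size=1280, output_size=2048,
                  adapter_norm="layer_norm", query_length=32)
x = torch.randn(4, 32, 1280)  # (batch, query_length, input_size)
y = adapter(x)
print(y.shape)                # torch.Size([4, 32, 2048])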
starvector/clip_model.py ADDED
@@ -0,0 +1,191 @@
1
+ # Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py
2
+
3
+ from collections import OrderedDict
4
+ from itertools import repeat
5
+ import collections.abc
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
11
+
12
+ def convert_weights_to_precision(model: nn.Module, precision: torch.dtype):
13
+ """Convert applicable model parameters to the specified precision"""
14
+
15
+ def _convert_weights_to_precision(l):
16
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
17
+ l.weight.data = l.weight.data.to(precision)
18
+ if l.bias is not None:
19
+ l.bias.data = l.bias.data.to(precision)
20
+
21
+ elif isinstance(l, (nn.MultiheadAttention)):
22
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
23
+ tensor = getattr(l, attr)
24
+ if tensor is not None:
25
+ tensor.data = tensor.data.to(precision)
26
+ else:
27
+ for _, p in l.named_parameters():
28
+ p.data = p.data.to(precision)
29
+
30
+ model.apply(_convert_weights_to_precision)
31
+
32
+ class Bottleneck(nn.Module):
33
+ expansion = 4
34
+
35
+ def __init__(self, inplanes, planes, stride=1):
36
+ super().__init__()
37
+
38
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
39
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
40
+ self.bn1 = nn.BatchNorm2d(planes)
41
+ self.relu1 = nn.ReLU(inplace=True)
42
+
43
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
44
+ self.bn2 = nn.BatchNorm2d(planes)
45
+ self.relu2 = nn.ReLU(inplace=True)
46
+
47
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
48
+
49
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
50
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
51
+ self.relu3 = nn.ReLU(inplace=True)
52
+
53
+ self.downsample = None
54
+ self.stride = stride
55
+
56
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
57
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
58
+ self.downsample = nn.Sequential(OrderedDict([
59
+ ("-1", nn.AvgPool2d(stride)),
60
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
61
+ ("1", nn.BatchNorm2d(planes * self.expansion))
62
+ ]))
63
+
64
+ def forward(self, x: torch.Tensor):
65
+ identity = x
66
+
67
+ out = self.relu1(self.bn1(self.conv1(x)))
68
+ out = self.relu2(self.bn2(self.conv2(out)))
69
+ out = self.avgpool(out)
70
+ out = self.bn3(self.conv3(out))
71
+
72
+ if self.downsample is not None:
73
+ identity = self.downsample(x)
74
+
75
+ out += identity
76
+ out = self.relu3(out)
77
+ return out
78
+
79
+
80
+ class AttentionPool2d(nn.Module):
81
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
82
+ super().__init__()
83
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
84
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
85
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
86
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
87
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
88
+ self.num_heads = num_heads
89
+
90
+ def forward(self, x):
91
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
92
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
93
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
94
+ x, _ = F.multi_head_attention_forward(
95
+ query=x, key=x, value=x,
96
+ embed_dim_to_check=x.shape[-1],
97
+ num_heads=self.num_heads,
98
+ q_proj_weight=self.q_proj.weight,
99
+ k_proj_weight=self.k_proj.weight,
100
+ v_proj_weight=self.v_proj.weight,
101
+ in_proj_weight=None,
102
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
103
+ bias_k=None,
104
+ bias_v=None,
105
+ add_zero_attn=False,
106
+ dropout_p=0,
107
+ out_proj_weight=self.c_proj.weight,
108
+ out_proj_bias=self.c_proj.bias,
109
+ use_separate_proj_weight=True,
110
+ training=self.training,
111
+ need_weights=False
112
+ )
113
+
114
+ return x[0]
115
+
116
+
117
+ class LayerNorm(nn.LayerNorm):
118
+ """Subclass torch's LayerNorm to handle fp16."""
119
+
120
+ def forward(self, x: torch.Tensor):
121
+ orig_type = x.dtype
122
+ layernorm_dtype = self.weight.dtype
123
+ ret = super().forward(x.type(layernorm_dtype))
124
+ return ret.type(orig_type)
125
+
126
+ class QuickGELU(nn.Module):
127
+ def forward(self, x: torch.Tensor):
128
+ return x * torch.sigmoid(1.702 * x)
129
+
130
+ class ResidualAttentionBlock(nn.Module):
131
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
132
+ super().__init__()
133
+
134
+ self.attn = nn.MultiheadAttention(d_model, n_head)
135
+ self.ln_1 = LayerNorm(d_model)
136
+ self.mlp = nn.Sequential(OrderedDict([
137
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
138
+ ("gelu", QuickGELU()),
139
+ ("c_proj", nn.Linear(d_model * 4, d_model))
140
+ ]))
141
+ self.ln_2 = LayerNorm(d_model)
142
+ self.attn_mask = attn_mask
143
+
144
+ if use_grad_checkpointing:
145
+ self.attn = checkpoint_wrapper(self.attn)
146
+ self.mlp = checkpoint_wrapper(self.mlp)
147
+
148
+ def attention(self, x: torch.Tensor):
149
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
150
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
151
+
152
+ def forward(self, x: torch.Tensor):
153
+ x = x + self.attention(self.ln_1(x))
154
+ x = x + self.mlp(self.ln_2(x))
155
+ return x
156
+
157
+ class Transformer(nn.Module):
158
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
159
+ super().__init__()
160
+ self.width = width
161
+ self.layers = layers
162
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
163
+
164
+ def forward(self, x: torch.Tensor):
165
+ return self.resblocks(x)
166
+
167
+ class VisionTransformer(nn.Module):
168
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
169
+ super().__init__()
170
+ self.input_resolution = input_resolution
171
+ self.num_features = width
172
+ self.num_heads = heads
173
+ self.num_patches = (input_resolution // patch_size) ** 2
174
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
175
+ scale = width ** -0.5
176
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
177
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
178
+ self.ln_pre = LayerNorm(width)
179
+ self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
180
+
181
+ def forward(self, x: torch.Tensor):
182
+ x = self.conv1(x) # shape = [*, width, grid, grid]
183
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
184
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
185
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
186
+ x = x + self.positional_embedding.to(x.dtype)
187
+ x = self.ln_pre(x)
188
+ x = x.permute(1, 0, 2) # NLD -> LND
189
+ x = self.transformer(x)
190
+ x = x.permute(1, 0, 2) # LND -> NLD
191
+ return x
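And a smoke test for the VisionTransformer; the hyperparameters below are illustrative (a shallow, roughly ViT-L-sized stack), not the values StarVector actually loads. Note that the module itself imports fairscale's checkpoint_wrapper at the top level, so fairscale must be installed even with checkpointing disabled.

import torch
from starvector.clip_model import VisionTransformer

vit = VisionTransformer(
    input_resolution=224, patch_size=14, width=1024,
    layers=2, heads=16, use_grad_checkpointing=False,  # two layers keep the test cheap
)
images = torch.randn(1, 3, 224, 224)
tokens = vit(images)
print(tokens.shape)  # torch.Size([1, 257, 1024]): 16*16 patches plus one class token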
starvector/data/augmentation.py ADDED
@@ -0,0 +1,250 @@
1
+
2
+ import numpy as np
3
+ from svgpathtools import (
4
+ Path, Arc, CubicBezier, QuadraticBezier,
5
+ svgstr2paths)
6
+ import os
7
+ from noise import pnoise1
8
+ import re
9
+ import matplotlib.colors as mcolors
10
+ from bs4 import BeautifulSoup
11
+ from starvector.data.util import rasterize_svg
12
+
13
+ class SVGTransforms:
14
+ def __init__(self, transformations):
15
+ self.transformations = transformations
16
+ self.noise_std = self.transformations.get('noise_std', False)
17
+ self.noise_type = self.transformations.get('noise_type', False)
18
+ self.rotate = self.transformations.get('rotate', False)
19
+ self.shift_re = self.transformations.get('shift_re', False)
20
+ self.shift_im = self.transformations.get('shift_im', False)
21
+ self.scale = self.transformations.get('scale', False)
22
+ self.color_noise = self.transformations.get('color_noise', False)
23
+ self.p = self.transformations.get('p', 0.5)
24
+ self.color_change = self.transformations.get('color_change', False)
25
+ self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000'])
26
+
27
+ def sample_transformations(self):
28
+ if self.rotate:
29
+ a, b = self.rotate['from'], self.rotate['to']
30
+ rotation_angle = np.random.uniform(a, b)
31
+ self.rotation_angle = rotation_angle
32
+
33
+ if self.shift_re or self.shift_im:
34
+ self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to'])
35
+ self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to'])
36
+
37
+ if self.scale:
38
+ self.scale = np.random.uniform(self.scale['from'], self.scale['to'])
39
+
40
+ if self.color_noise:
41
+ self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to'])
42
+
43
+
44
+ def paths2str(self, groupped_paths, svg_opening_tag='<svg xmlns="http://www.w3.org/2000/svg" version="1.1">'):
45
+
46
+ keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry']
47
+ all_groups_srt = ''
48
+ for group, elements in groupped_paths.items():
49
+ group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', [])
50
+ group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items())
51
+ path_strings = []
52
+ path_str = ''
53
+ for path, attributes in paths_and_attributes:
54
+ path_attr_str = ''
55
+ d_str = path.d()
56
+
57
+ for key, value in attributes.items():
58
+ if key not in keys_to_exclude:
59
+ path_attr_str += f' {key}="{value}"'
60
+
61
+ path_strings.append(f'<path d="{d_str}"{path_attr_str} />')
62
+ path_str = "\n".join(path_strings)
63
+ if 'no_group'in group:
64
+ group_str = path_str
65
+ else:
66
+ group_str = f'<g {group_attr_str}>\n{path_str}\n</g>\n'
67
+ all_groups_srt += group_str
68
+ svg = f'{svg_opening_tag}\n{all_groups_srt}</svg>'
69
+ return svg
70
+
71
+ def add_noise(self, seg):
72
+ noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to'])
73
+ if self.noise_type == 'gaussian':
74
+ noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \
75
+ 1j * np.random.normal(loc=0.0, scale=noise_scale)
76
+ elif self.noise_type == 'perlin':
77
+ noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale
78
+
79
+ if isinstance(seg, CubicBezier):
80
+ seg.control1 = seg.control1 + noise_sample
81
+ seg.control2 = seg.control2 + noise_sample
82
+ elif isinstance(seg, QuadraticBezier):
83
+ seg.control = seg.control + noise_sample
84
+ elif isinstance(seg, Arc):
85
+ seg.radius = seg.radius + noise_sample
86
+
87
+
88
+ return seg
89
+
90
+ def do_rotate(self, path, viewbox_width, viewbox_height):
91
+ if self.rotate:
92
+ new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2))
93
+ return new_path
94
+ else:
95
+ return path
96
+
97
+ def do_shift(self, path):
98
+ if self.shift_re or self.shift_im:
99
+ return path.translated(complex(self.shift_real, self.shift_imag))
100
+ else:
101
+ return path
102
+
103
+ def do_scale(self, path):
104
+ if self.scale:
105
+ return path.scaled(self.scale)
106
+ else:
107
+ return path
108
+
109
+ def add_color_noise(self, source_color):
110
+ # Convert color to RGB
111
+ if source_color.startswith("#"):
112
+ base_color = mcolors.hex2color(source_color)
113
+ else:
114
+ base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF'))
115
+
116
+ # Add noise to each RGB component
117
+ noise = np.random.normal(0, self.color_noise_std, 3)
118
+ noisy_color = np.clip(np.array(base_color) + noise, 0, 1)
119
+
120
+ # Convert the RGB color back to hex
121
+ hex_color = mcolors.rgb2hex(noisy_color)
122
+
123
+ return hex_color
124
+
125
+ def do_color_change(self, attr):
126
+ if 'fill' in attr:
127
+ if self.color_noise or self.color_change:
128
+ fill_value = attr['fill']
129
+ if fill_value == 'none':
130
+ new_fill_value = 'none'
131
+ else:
132
+ if self.color_noise:
133
+ new_fill_value = self.add_color_noise(fill_value)
134
+ elif self.color_change:
135
+ new_fill_value = np.random.choice(self.colors)
136
+ attr['fill'] = new_fill_value
137
+ return attr
138
+
139
+ def clean_attributes(self, attr):
140
+ attr_out = {}
141
+ if 'fill' in attr:
142
+ attr_out = attr
143
+ elif 'style' in attr:
144
+ fill_values = re.findall('fill:[^;]+', attr['style'])
145
+ if fill_values:
146
+ fill_value = fill_values[0].replace('fill:', '').strip()
147
+ attr_out['fill'] = fill_value
148
+ else:
149
+ attr_out = attr
150
+ else:
151
+ attr_out = attr
152
+
153
+ return attr_out
154
+
155
+ def get_viewbox_size(self, svg):
156
+ # Try to extract viewBox attribute
157
+ match = re.search(r'viewBox="([^"]+)"', svg)
158
+ if match:
159
+ viewbox = match.group(1)
160
+ else:
161
+ # If viewBox is not found, try to extract width and height attributes
162
+ match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg)
163
+ if match:
164
+ width, height = match.groups()
165
+ viewbox = f"0 0 {width} {height}"
166
+ else:
167
+ viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found
168
+
169
+ viewbox = [float(x) for x in viewbox.split()]
170
+ viewbox_width, viewbox_height = viewbox[2], viewbox[3]
171
+ return viewbox_width, viewbox_height
172
+
173
+ def augment(self, svg):
174
+ if os.path.isfile(svg):
175
+ # open svg file
176
+ with open(svg, 'r') as f:
177
+ svg = f.read()
178
+
179
+ # Sample transformations for this sample
180
+ self.sample_transformations()
181
+
182
+
183
+ # Parse the SVG content
184
+ soup = BeautifulSoup(svg, 'xml')
185
+
186
+ # Get opening tag
187
+ svg_opening_tag = re.findall('<svg[^>]+>', svg)[0]
188
+
189
+ viewbox_width, viewbox_height = self.get_viewbox_size(svg)
190
+
191
+ # Get all svg parents
192
+ groups = soup.findAll()
193
+
194
+ # Create the groups of paths based on their original <g> tag
195
+ grouped_paths = {}
196
+ for i, g in enumerate(groups):
197
+ if g.name == 'g':
198
+ group_id = g.get('id') if g.get('id') else f'none_{i}'
199
+ group_attrs = g.attrs
200
+
201
+ elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs':
202
+ continue
203
+
204
+ else:
205
+ group_id = f'no_group_{i}'
206
+ group_attrs = {}
207
+
208
+ group_svg_string = f'{svg_opening_tag}{str(g)}</svg>'
209
+ try:
210
+ paths, attributes = svgstr2paths(group_svg_string)
211
+ except:
212
+ return svg, rasterize_svg(svg)
213
+ if not paths:
214
+ continue
215
+
216
+ paths_and_attributes = []
217
+
218
+ # Rotation, shift, scale, noise addition
219
+ new_paths = []
220
+ new_attributes = []
221
+ for path, attribute in zip(paths, attributes):
222
+ attr = self.clean_attributes(attribute)
223
+
224
+ new_path = self.do_rotate(path, viewbox_width, viewbox_height)
225
+ new_path = self.do_shift(new_path)
226
+ new_path = self.do_scale(new_path)
227
+
228
+ if self.noise_std:
229
+ # Add noise to path to deform svg
230
+ noisy_path = []
231
+ for seg in new_path:
232
+ noisy_seg = self.add_noise(seg)
233
+ noisy_path.append(noisy_seg)
234
+ new_paths.append(Path(*noisy_path))
235
+ else:
236
+ new_paths.append(new_path)
237
+
238
+ # Color change
239
+ attr = self.do_color_change(attr)
240
+ paths_and_attributes.append((new_path, attr))
241
+
242
+ grouped_paths[group_id] = {
243
+ 'paths': paths_and_attributes,
244
+ 'attrs': group_attrs
245
+ }
246
+
247
+ svg = self.paths2str(grouped_paths, svg_opening_tag)
248
+ image = rasterize_svg(svg)
249
+
250
+ return svg, image
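A minimal usage sketch for the augmentation module above (not part of the commit): it builds an SVGTransforms with an illustrative config and augments an inline SVG string. The config keys follow the constructor above, but the ranges are assumptions; svgpathtools, noise, bs4, lxml and cairosvg are required.

# Hedged sketch, not part of the commit: exercise SVGTransforms.augment end to end.
from starvector.data.augmentation import SVGTransforms

transforms = SVGTransforms({
    'rotate': {'from': -10, 'to': 10},        # degrees, sampled uniformly per call
    'scale': {'from': 0.9, 'to': 1.1},
    'color_noise': {'from': 0.01, 'to': 0.05},
    'p': 0.5,                                  # augmentation probability used by the datasets
})
svg = ('<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">'
       '<path d="M2 2 L22 2 L12 22 Z" fill="#ff0000"/></svg>')
aug_svg, aug_image = transforms.augment(svg)   # returns (augmented SVG string, rasterized PIL.Image)
aug_image.save("augmented.png")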
starvector/data/base.py ADDED
@@ -0,0 +1,71 @@
1
+ from torch.utils.data import Dataset
2
+ from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg
3
+ from starvector.util import instantiate_from_config
4
+ import numpy as np
5
+ from datasets import load_dataset
6
+
7
+ class SVGDatasetBase(Dataset):
8
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
9
+ self.split = split
10
+ self.im_size = im_size
11
+
12
+ transforms = kwargs.get('transforms', False)
13
+ if transforms:
14
+ self.transforms = instantiate_from_config(transforms)
15
+ self.p = self.transforms.p
16
+ else:
17
+ self.transforms = None
18
+ self.p = 0.0
19
+
20
+ normalization = kwargs.get('normalize', False)
21
+ if normalization:
22
+ mean = tuple(normalization.get('mean', None))
23
+ std = tuple(normalization.get('std', None))
24
+ else:
25
+ mean = None
26
+ std = None
27
+
28
+ self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std)
29
+ self.data = load_dataset(dataset_name, split=split)
30
+
31
+ print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split")
32
+
33
+ def __len__(self):
34
+ return len(self.data)
35
+
36
+ def get_svg_and_image(self, svg_str, sample_id):
37
+ do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p])
38
+ svg, image = None, None
39
+
40
+ # Try to augment the image if conditions are met
41
+ if self.transforms is not None and do_augment:
42
+ try:
43
+ svg, image = self.transforms.augment(svg_str)
44
+ except Exception as e:
45
+ print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG")
46
+
47
+ # If augmentation failed or wasn't attempted, try to rasterize the SVG
48
+ if svg is None or image is None:
49
+ try:
50
+ svg, image = svg_str, rasterize_svg(svg_str, self.im_size)
51
+ except Exception as e:
52
+ print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image")
53
+ svg = use_placeholder()
54
+ image = rasterize_svg(svg, self.im_size)
55
+
56
+ # If the image is completely white, use a placeholder image
57
+ if np.array(image).mean() == 255.0:
58
+ print(f"Image is full white, using placeholder image for {sample_id}")
59
+ svg = use_placeholder()
60
+ image = rasterize_svg(svg)
61
+
62
+ # Process the image
63
+ if getattr(self, 'image_processor', None) and 'siglip' in self.image_processor:
64
+ image = self.processor(image).pixel_values[0]
65
+ else:
66
+ image = self.processor(image)
67
+
68
+ return svg, image
69
+
70
+ def __getitem__(self, idx):
71
+ raise NotImplementedError("This method should be implemented by subclasses")
starvector/data/dataset.py ADDED
@@ -0,0 +1,42 @@
1
+ import os
2
+ from starvector.data.base import SVGDatasetBase
3
+ from starvector.data.augmentation import SVGTransforms
4
+ from starvector.data.util import ImageTrainProcessor
5
+ from transformers import AutoProcessor
6
+
7
+ class SVGDataset(SVGDatasetBase):
8
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
9
+ super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
10
+
11
+ self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']})
12
+ select_dataset_name = kwargs.get('select_dataset_name', False)
13
+
14
+ if select_dataset_name:
15
+ self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name)
16
+
17
+ self.num_samples = num_samples
18
+ if self.num_samples != -1:
19
+ self.data = self.data.select(range(self.num_samples))
20
+
21
+ self.image_processor = kwargs.get('image_processor', None)
22
+ if self.image_processor and 'siglip' in self.image_processor:
23
+ model_name = {'siglip_512': 'google/siglip-base-patch16-512',
24
+ 'siglip_384': 'google/siglip-large-patch16-384',
25
+ 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
26
+ self.processor = AutoProcessor.from_pretrained(model_name).image_processor
27
+ else:
28
+ self.processor = ImageTrainProcessor(size=self.im_size)
29
+ def __len__(self):
30
+ return len(self.data)
31
+
32
+ def __getitem__(self, idx):
33
+ svg_str = self.data[idx]['Svg']
34
+ sample_id = self.data[idx]['Filename']
35
+ svg, image = self.get_svg_and_image(svg_str, sample_id)
36
+ caption = self.data[idx].get('Caption', "")
37
+ return {
38
+ 'svg': svg,
39
+ 'image': image,
40
+ 'id': sample_id,
41
+ 'caption': caption
42
+ }
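For context, a hedged sketch of how SVGDataset might be driven. The dataset id below is a placeholder assumption, not something this commit asserts; it only presumes a Hugging Face dataset whose rows expose 'Svg' and 'Filename' columns.

# Hedged sketch, not part of the commit: wrap SVGDataset in a DataLoader.
from torch.utils.data import DataLoader
from starvector.data.dataset import SVGDataset

dataset = SVGDataset(
    dataset_name="your-org/your-svg-dataset",  # placeholder HF dataset id with 'Svg'/'Filename' columns
    split="train",
    im_size=224,
    num_samples=-1,          # -1 keeps the full split
    image_processor=None,    # falls back to ImageTrainProcessor
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["id"], batch["image"].shape)   # images come back as 3x224x224 tensors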
starvector/data/emojisvg.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+ from starvector.data.base import SVGDatasetBase
3
+
4
+
5
+ class EmojiSVGDataset(SVGDatasetBase):
6
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
7
+ super().__init__(dataset_name, split, im_size, **kwargs)
8
+
9
+ self.num_samples = num_samples
10
+ if self.num_samples != -1:
11
+ self.data = self.data.select(range(self.num_samples))
12
+
13
+ def __len__(self):
14
+ return len(self.data)
15
+
16
+ def __getitem__(self, idx):
17
+
18
+ svg_str = self.data[idx]['Svg']
19
+ sample_id = self.data[idx]['Filename']
20
+ svg, image = self.get_svg_and_image(svg_str, sample_id)
21
+ caption = self.data[idx].get('Caption', "")
22
+ return {
23
+ 'svg': svg,
24
+ 'image': image,
25
+ 'id': sample_id,
26
+ 'caption': caption
27
+ }
starvector/data/figrsvg.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+ from starvector.data.base import SVGDatasetBase
3
+ from transformers import AutoProcessor
4
+ from starvector.data.util import ImageTrainProcessor
5
+
6
+ class FigrSVGDataset(SVGDatasetBase):
7
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
8
+ super().__init__(dataset_name, split, im_size, **kwargs)
9
+
10
+ self.num_samples = num_samples
11
+ if self.num_samples != -1:
12
+ self.data = self.data.select(range(self.num_samples))
13
+
14
+ def __len__(self):
15
+ return len(self.data)
16
+
17
+ def __getitem__(self, idx):
18
+ svg_str = self.data[idx]['Svg']
19
+ sample_id = self.data[idx]['Id']
20
+ svg, image = self.get_svg_and_image(svg_str, sample_id)
21
+ caption = self.data[idx].get('Caption', "")
22
+ return {
23
+ 'svg': svg,
24
+ 'image': image,
25
+ 'id': sample_id,
26
+ 'caption': caption
27
+ }
starvector/data/fontsvg.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ from starvector.data.base import SVGDatasetBase
3
+ from transformers import AutoProcessor
4
+ from starvector.data.util import ImageTrainProcessor
5
+
6
+ class FontSVGDataset(SVGDatasetBase):
7
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
8
+ super().__init__(dataset_name, split, im_size, **kwargs)
9
+
10
+ self.num_samples = num_samples
11
+ if self.num_samples != -1:
12
+ self.data = self.data.select(range(self.num_samples))
13
+
14
+ def __len__(self):
15
+ return len(self.data)
16
+
17
+ def __getitem__(self, idx):
18
+
19
+ svg_str = self.data[idx]['Svg']
20
+ sample_id = self.data[idx]['Filename']
21
+ svg, image = self.get_svg_and_image(svg_str, sample_id)
22
+ caption = self.data[idx].get('Caption', "")
23
+ return {
24
+ 'svg': svg,
25
+ 'image': image,
26
+ 'id': sample_id,
27
+ 'caption': caption
28
+ }
starvector/data/iconsvg.py ADDED
@@ -0,0 +1,38 @@
1
+ import os
2
+ from starvector.data.base import SVGDatasetBase
3
+ from starvector.data.util import ImageTrainProcessor
4
+ from transformers import AutoProcessor
5
+
6
+ class SVGIconsDataset(SVGDatasetBase):
7
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
8
+ super().__init__(dataset_name, split, im_size, **kwargs)
9
+
10
+ self.num_samples = num_samples
11
+ if self.num_samples != -1:
12
+ self.data = self.data.select(range(self.num_samples))
13
+
14
+ self.image_processor = kwargs.get('image_processor', None)
15
+ if 'siglip' in self.image_processor:
16
+ model_name = {'siglip_512': 'google/siglip-base-patch16-512',
17
+ 'siglip_384': 'google/siglip-large-patch16-384',
18
+ 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
19
+ self.processor = AutoProcessor.from_pretrained(model_name).image_processor
20
+ else:
21
+ self.processor = ImageTrainProcessor(size=self.im_size)
22
+
23
+
24
+ def __len__(self):
25
+ return len(self.data)
26
+
27
+ def __getitem__(self, idx):
28
+
29
+ svg_str = self.data[idx]['Svg']
30
+ sample_id = self.data[idx]['Filename']
31
+ svg, image = self.get_svg_and_image(svg_str, sample_id)
32
+ caption = self.data[idx].get('Caption', "")
33
+ return {
34
+ 'svg': svg,
35
+ 'image': image,
36
+ 'id': sample_id,
37
+ 'caption': caption
38
+ }
starvector/data/stacksvg.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ from starvector.data.base import SVGDatasetBase
3
+ from starvector.data.augmentation import SVGTransforms
4
+ import random
5
+ from transformers import AutoProcessor
6
+ from starvector.data.util import ImageTrainProcessor
7
+
8
+ text2svg_captions = [
9
+ "Draw an SVG of ",
10
+ "Draw an SVG image of ",
11
+ "Draw an SVG picture of ",
12
+ "Generate an SVG of ",
13
+ "Create an SVG of ",
14
+ "Design an SVG of ",
15
+ "Make an SVG of ",
16
+ ]
17
+
18
+ class SVGStackDataset(SVGDatasetBase):
19
+ def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
20
+ super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
21
+ self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']})
22
+
23
+ # Text2SVG specific
24
+ self.random_caption = kwargs.get('random_caption', True)
25
+ select_dataset_name = kwargs.get('select_dataset_name', False)
26
+ if select_dataset_name:
27
+ self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name)
28
+
29
+ self.num_samples = num_samples
30
+ if self.num_samples != -1:
31
+ self.data = self.data.select(range(self.num_samples))
32
+
33
+ self.image_processor = kwargs.get('image_processor', None)
34
+ if self.image_processor and 'siglip' in self.image_processor:
35
+ model_name = {'siglip_512': 'google/siglip-base-patch16-512',
36
+ 'siglip_384': 'google/siglip-large-patch16-384',
37
+ 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
38
+ self.processor = AutoProcessor.from_pretrained(model_name).image_processor
39
+ else:
40
+ self.processor = ImageTrainProcessor(size=self.im_size)
41
+
42
+
43
+ def __len__(self):
44
+ return len(self.data)
45
+
46
+ def __getitem__(self, idx):
47
+ svg_str = self.data[idx]['Svg']
48
+ sample_id = self.data[idx]['Filename']
49
+ svg, image = self.get_svg_and_image(svg_str, sample_id)
50
+
51
+ # Randomly choose between 'caption_blip' and 'caption_llava'
52
+ caption_column = random.choice(['caption_blip2', 'caption_llava'])
53
+ caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "")
54
+ return {
55
+ 'svg': svg,
56
+ 'image': image,
57
+ 'id': sample_id,
58
+ 'caption': caption,
59
+ }
starvector/data/util.py ADDED
@@ -0,0 +1,389 @@
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+ from torchvision.transforms.functional import InterpolationMode, pad
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from bs4 import BeautifulSoup
7
+ import re
8
+ from svgpathtools import svgstr2paths
9
+ import numpy as np
10
+ from PIL import Image
11
+ import cairosvg
12
+ from io import BytesIO
13
+ import numpy as np
14
+ import textwrap
15
+ import os
16
+ import base64
17
+ import io
18
+
19
+
20
+
21
+ CIRCLE_SVG = "<svg><circle cx='50%' cy='50%' r='50%' /></svg>"
22
+ VOID_SVG = "<svg></svg>"
23
+
24
+ def load_transforms():
25
+ transforms = {
26
+ 'train': None,
27
+ 'eval': None
28
+ }
29
+ return transforms
30
+
31
+ class ImageBaseProcessor():
32
+ def __init__(self, mean=None, std=None):
33
+ if mean is None:
34
+ mean = (0.48145466, 0.4578275, 0.40821073)
35
+ if std is None:
36
+ std = (0.26862954, 0.26130258, 0.27577711)
37
+
38
+ self.normalize = transforms.Normalize(mean=mean, std=std)
39
+
40
+ class ImageTrainProcessor(ImageBaseProcessor):
41
+ def __init__(self, mean=None, std=None, size=224, **kwargs):
42
+ super().__init__(mean, std)
43
+
44
+ self.size = size
45
+
46
+ self.transform = transforms.Compose([
47
+ transforms.Lambda(lambda img: self._rgba_to_rgb_white(img) if img.mode == "RGBA" else img),
48
+ transforms.Lambda(lambda img: self._pad_to_square(img)),
49
+ transforms.Resize(self.size, interpolation=InterpolationMode.BICUBIC),
50
+ transforms.ToTensor(),
51
+ self.normalize
52
+ ])
53
+
54
+ def __call__(self, item):
55
+ return self.transform(item)
56
+
57
+ def _pad_to_square(self, img):
58
+ # Calculate padding to make the image square
59
+ width, height = img.size
60
+ max_dim = max(width, height)
61
+ padding = [(max_dim - width) // 2, (max_dim - height) // 2]
62
+ padding += [max_dim - width - padding[0], max_dim - height - padding[1]]
63
+ return pad(img, padding, fill=255) # Assuming white padding
64
+
65
+ def _rgba_to_rgb_white(self, img):
66
+ background = Image.new("RGB", img.size, (255, 255, 255))
67
+ background.paste(img, mask=img.split()[3])
68
+ return background
69
+
70
+
71
+ def encode_image_base64(pil_image):
72
+ if pil_image.mode == 'RGBA':
73
+ pil_image = pil_image.convert('RGB') # Convert RGBA to RGB
74
+ buffered = io.BytesIO()
75
+ pil_image.save(buffered, format="JPEG")
76
+ base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
77
+ return base64_image
78
+
79
+ # -------------- Generation utils --------------
80
+ def is_valid_svg(svg_text):
81
+ try:
82
+ svgstr2paths(svg_text)
83
+ return True
84
+ except Exception as e:
85
+ print(f"Invalid SVG: {str(e)}")
86
+ return False
87
+
88
+ def clean_svg(svg_text, output_width=None, output_height=None):
89
+ soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml
90
+ svg_bs4 = soup.prettify() # Prettify to get a string
91
+
92
+ # Store the original signal handler
93
+ import signal
94
+ original_handler = signal.getsignal(signal.SIGALRM)
95
+
96
+ try:
97
+ # Set a timeout to prevent hanging
98
+ def timeout_handler(signum, frame):
99
+ raise TimeoutError("SVG processing timed out")
100
+
101
+ # Set timeout
102
+ signal.signal(signal.SIGALRM, timeout_handler)
103
+ signal.alarm(5)
104
+
105
+ # Convert the prettified SVG with CairoSVG (normalizes size and structure)
106
+ svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode()
107
+
108
+ except TimeoutError:
109
+ print("SVG conversion timed out, using fallback method")
110
+ svg_cairo = """<svg></svg>"""
111
+ finally:
112
+ # Always cancel the alarm and restore original handler, regardless of success or failure
113
+ signal.alarm(0)
114
+ signal.signal(signal.SIGALRM, original_handler)
115
+
116
+ svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("<?xml")]) # Remove xml header
117
+ return svg_clean
118
+
119
+
120
+ def use_placeholder():
121
+ return VOID_SVG
122
+
123
+ def process_and_rasterize_svg(svg_string, resolution=256, dpi = 128, scale=2):
124
+ try:
125
+ svgstr2paths(svg_string) # This will raise an exception if the svg is still not valid
126
+ out_svg = svg_string
127
+ except:
128
+ try:
129
+ svg = clean_svg(svg_string)
130
+ svgstr2paths(svg) # This will raise an exception if the svg is still not valid
131
+ out_svg = svg
132
+ except Exception as e:
133
+ out_svg = use_placeholder()
134
+
135
+ raster_image = rasterize_svg(out_svg, resolution, dpi, scale)
136
+ return out_svg, raster_image
137
+
138
+ def rasterize_svg(svg_string, resolution=224, dpi = 128, scale=2):
139
+ try:
140
+ svg_raster_bytes = cairosvg.svg2png(
141
+ bytestring=svg_string,
142
+ background_color='white',
143
+ output_width=resolution,
144
+ output_height=resolution,
145
+ dpi=dpi,
146
+ scale=scale)
147
+ svg_raster = Image.open(BytesIO(svg_raster_bytes))
148
+ except:
149
+ try:
150
+ svg = clean_svg(svg_string)
151
+ svg_raster_bytes = cairosvg.svg2png(
152
+ bytestring=svg,
153
+ background_color='white',
154
+ output_width=resolution,
155
+ output_height=resolution,
156
+ dpi=dpi,
157
+ scale=scale)
158
+ svg_raster = Image.open(BytesIO(svg_raster_bytes))
159
+ except:
160
+ svg_raster = Image.new('RGB', (resolution, resolution), color = 'white')
161
+ return svg_raster
162
+
163
+ def find_unclosed_tags(svg_content):
164
+ all_tags_pattern = r"<(\w+)"
165
+ self_closing_pattern = r"<\w+[^>]*\/>"
166
+ all_tags = re.findall(all_tags_pattern, svg_content)
167
+ self_closing_matches = re.findall(self_closing_pattern, svg_content)
168
+ self_closing_tags = []
169
+
170
+ for match in self_closing_matches:
171
+ tag = re.search(all_tags_pattern, match)
172
+ if tag:
173
+ self_closing_tags.append(tag.group(1))
174
+ unclosed_tags = []
175
+
176
+ for tag in all_tags:
177
+ if all_tags.count(tag) > self_closing_tags.count(tag) + svg_content.count('</' + tag + '>'):
178
+ unclosed_tags.append(tag)
179
+ unclosed_tags = list(dict.fromkeys(unclosed_tags))
180
+
181
+ return unclosed_tags
182
+
183
+
184
+ # -------------- Plotting utils --------------
185
+ def plot_images_side_by_side_with_metrics(image1, image2, l2_dist, CD, post_processed, out_path):
186
+ array1 = np.array(image1).astype(np.float32)
187
+ array2 = np.array(image2).astype(np.float32)
188
+ diff = np.abs(array1 - array2).astype(np.uint8)
189
+
190
+ fig, axes = plt.subplots(1, 3, figsize=(10, 5))
191
+ axes[0].imshow(image1)
192
+ axes[0].set_title('generated_svg')
193
+ axes[0].axis('off')
194
+ axes[1].imshow(image2)
195
+ axes[1].set_title('gt')
196
+ axes[1].axis('off')
197
+ axes[2].imshow(diff)
198
+ axes[2].set_title('Difference')
199
+ axes[2].axis('off')
200
+ plt.suptitle(f"MSE: {l2_dist:.4f}, CD: {CD:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05)
201
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
202
+ image = Image.open(out_path)
203
+ plt.close(fig)
204
+ return image
205
+
206
+ def plot_images_side_by_side(image1, image2, out_path):
207
+ array1 = np.array(image1).astype(np.float32)
208
+ array2 = np.array(image2).astype(np.float32)
209
+ diff = np.abs(array1 - array2).astype(np.uint8)
210
+
211
+ fig, axes = plt.subplots(1, 3, figsize=(10, 5))
212
+ axes[0].imshow(image1)
213
+ axes[0].set_title('generated_svg')
214
+ axes[0].axis('off')
215
+ axes[1].imshow(image2)
216
+ axes[1].set_title('gt')
217
+ axes[1].axis('off')
218
+ axes[2].imshow(diff)
219
+ axes[2].set_title('Difference')
220
+ axes[2].axis('off')
221
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
222
+ image = Image.open(out_path)
223
+ plt.close(fig)
224
+ return image
225
+
226
+ def plot_images_side_by_side_temperatures(samples_temp, metrics, sample_dir, outpath_filename):
227
+ # Create a plot with the original image and different temperature results
228
+ num_temps = len(samples_temp)
229
+ fig, axes = plt.subplots(2, num_temps + 1, figsize=(15, 4), gridspec_kw={'height_ratios': [10, 2]})
230
+
231
+ # Plot the original image
232
+ gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png')
233
+ gt_image = Image.open(gt_image_path)
234
+ axes[0, 0].imshow(gt_image)
235
+ axes[0, 0].set_title('Original')
236
+ axes[0, 0].axis('off')
237
+ axes[1, 0].text(0.5, 0.5, 'Original', horizontalalignment='center', verticalalignment='center', fontsize=16)
238
+ axes[1, 0].axis('off')
239
+
240
+ # Plot the generated images for different temperatures and metrics
241
+ for idx, (temp, sample) in enumerate(samples_temp.items()):
242
+ gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png')
243
+ gen_image = Image.open(gen_image_path)
244
+ axes[0, idx + 1].imshow(gen_image)
245
+ axes[0, idx + 1].set_title(f'Temp {temp}')
246
+ axes[0, idx + 1].axis('off')
247
+ axes[1, idx + 1].text(0.5, 0.5, f'MSE: {metrics[temp]["mse"]:.2f}\nCD: {metrics[temp]["cd"]:.2f}',
248
+ horizontalalignment='center', verticalalignment='center', fontsize=12)
249
+ axes[1, idx + 1].axis('off')
250
+
251
+ # Save the comparison plot
252
+ comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png')
253
+ plt.tight_layout()
254
+ plt.savefig(comparison_path)
255
+ plt.close()
256
+
257
+ def plot_images_and_prompt(prompt, svg_raster, gt_svg_raster, out_path):
258
+ # First col shows caption, second col shows generated svg, third col shows gt svg
259
+ fig, axes = plt.subplots(1, 3, figsize=(10, 5))
260
+
261
+ # Split the prompt into multiple lines if it exceeds a certain length
262
+ prompt_lines = textwrap.wrap(prompt, width=30)
263
+ prompt_text = '\n'.join(prompt_lines)
264
+
265
+ # Display the prompt in the first cell
266
+ axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
267
+ axes[0].axis('off')
268
+ axes[1].imshow(svg_raster)
269
+ axes[1].set_title('generated_svg')
270
+ axes[1].axis('off')
271
+ axes[2].imshow(gt_svg_raster)
272
+ axes[2].set_title('gt')
273
+ axes[2].axis('off')
274
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
275
+ image = Image.open(out_path)
276
+ plt.close(fig)
277
+ return image
278
+
279
+ def plot_images_and_prompt_with_metrics(prompt, svg_raster, gt_svg_raster, clip_score, post_processed, out_path):
280
+ # First col shows caption, second col shows generated svg, third col shows gt svg
281
+ fig, axes = plt.subplots(1, 3, figsize=(10, 5))
282
+
283
+ # Split the prompt into multiple lines if it exceeds a certain length
284
+ prompt_lines = textwrap.wrap(prompt, width=30)
285
+ prompt_text = '\n'.join(prompt_lines)
286
+
287
+ # Display the prompt in the first cell
288
+ axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
289
+ axes[0].axis('off')
290
+ axes[1].imshow(svg_raster)
291
+ axes[1].set_title('generated_svg')
292
+ axes[1].axis('off')
293
+ axes[2].imshow(gt_svg_raster)
294
+ axes[2].set_title('gt')
295
+ axes[2].axis('off')
296
+ plt.suptitle(f"CLIP Score: {clip_score:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05)
297
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
298
+ image = Image.open(out_path)
299
+ plt.close(fig)
300
+ return image
301
+
302
+ def plot_images_and_prompt_temperatures(prompt, samples_temp, metrics, sample_dir, outpath_filename):
303
+ # Calculate the number of temperature variations
304
+ num_temps = len(samples_temp)
305
+
306
+ # Create a plot with text, the original image, and different temperature results
307
+ fig, axes = plt.subplots(1, num_temps + 2, figsize=(5 + 3 * (num_temps + 1), 6))
308
+
309
+ # Split the prompt into multiple lines if it exceeds a certain length
310
+ prompt_lines = textwrap.wrap(prompt, width=30)
311
+ prompt_text = '\n'.join(prompt_lines)
312
+
313
+ # Display the prompt in the first cell
314
+ axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
315
+ axes[0].axis('off')
316
+
317
+ # Plot the GT (ground truth) image in the second cell
318
+ gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png')
319
+ gt_image = Image.open(gt_image_path)
320
+ axes[1].imshow(gt_image)
321
+ axes[1].set_title('GT Image')
322
+ axes[1].axis('off')
323
+
324
+ # Plot the generated images for different temperatures and display metrics
325
+ for idx, (temp, sample) in enumerate(samples_temp.items()):
326
+ gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png')
327
+ gen_image = Image.open(gen_image_path)
328
+ axes[idx + 2].imshow(gen_image)
329
+ axes[idx + 2].set_title(f'Temp {temp}')
330
+ axes[idx + 2].axis('off')
331
+ clip_score = metrics[temp]["clip_score"]
332
+ axes[idx + 2].text(0.5, -0.1, f'CLIP: {clip_score:.4f}', horizontalalignment='center', verticalalignment='center', fontsize=12, transform=axes[idx + 2].transAxes)
333
+
334
+ # Save the comparison plot
335
+ comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png')
336
+ plt.tight_layout()
337
+ plt.savefig(comparison_path)
338
+ plt.close()
339
+
340
+ return comparison_path
341
+
342
+
343
+ def plot_image_tensor(image):
344
+ import numpy as np
345
+ from PIL import Image
346
+ tensor = image[0].cpu().float()
347
+ tensor = tensor.permute(1, 2, 0)
348
+ array = (tensor.numpy() * 255).astype(np.uint8)
349
+ im = Image.fromarray(array)
350
+ im.save("tmp/output_image.jpg")
351
+
352
+
353
+ def plot_grid_samples(images, num_cols=5, out_path = 'grid.png'):
354
+ # Calculate the number of rows required for the grid
355
+ num_images = len(images)
356
+ num_rows = (num_images + num_cols - 1) // num_cols
357
+
358
+ # Create a new figure
359
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8))
360
+
361
+ # Loop through the image files and plot them
362
+ for i, image in enumerate(images):
363
+ row = i // num_cols
364
+ col = i % num_cols
365
+
366
+ # Open and display the image using Pillow
367
+ if type(image) == str:
368
+ img = Image.open(image)
369
+ else:
370
+ img = image
371
+ axes[row, col].imshow(img)
372
+ # axes[row, col].set_title(os.path.basename(image_file))
373
+ axes[row, col].axis('off')
374
+
375
+ # Remove empty subplots
376
+ for i in range(num_images, num_rows * num_cols):
377
+ row = i // num_cols
378
+ col = i % num_cols
379
+ fig.delaxes(axes[row, col])
380
+
381
+ # Adjust spacing between subplots
382
+ plt.tight_layout()
383
+
384
+ # save image
385
+ plt.savefig(out_path, dpi=300)
386
+ image = Image.open(out_path)
387
+ plt.close(fig)
388
+
389
+ return image
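A hedged sketch of the central helpers in this file: process_and_rasterize_svg validates (and, if needed, cleans) an SVG string before rasterizing it, and rasterize_svg degrades to a white canvas on malformed input. Since clean_svg relies on SIGALRM, this assumes a POSIX system; the toy SVG strings are illustrative.

# Hedged sketch, not part of the commit: rasterization helpers on a toy SVG.
from starvector.data.util import process_and_rasterize_svg, rasterize_svg

svg = ('<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 10 10">'
       '<circle cx="5" cy="5" r="4" fill="blue"/></svg>')
clean, raster = process_and_rasterize_svg(svg, resolution=256)   # (possibly cleaned SVG, PIL.Image)
raster.save("circle.png")

fallback = rasterize_svg("<svg><rect", resolution=64)   # malformed input falls back to a white canvas
print(fallback.size)                                     # (64, 64)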
starvector/image_encoder.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+ from omegaconf import OmegaConf
6
+ from starvector.model.image_encoder.clip_model import convert_weights_to_precision
7
+ from starvector.data.util import ImageTrainProcessor
8
+
9
+ class ImageEncoder(nn.Module):
10
+ def __init__(self, config, **kwargs):
11
+ super(ImageEncoder, self).__init__()
12
+
13
+ image_size = config.image_size
14
+ torch_dtype = kwargs.get('model_precision', config.torch_dtype)
15
+ self.image_encoder_type = config.image_encoder_type
16
+ if self.image_encoder_type == 'clip':
17
+ self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size)
18
+ convert_weights_to_precision(self, torch_dtype)
19
+ self.processor = ImageTrainProcessor(size=config.image_size)
20
+
21
+ elif self.image_encoder_type == 'vqgan':
22
+ self.visual_encoder = self.build_vqgan_encoder()
23
+ self.ln_vision = None
24
+ self.processor = ImageTrainProcessor(size=config.image_size)
25
+
26
+ elif self.image_encoder_type == 'convnext':
27
+ self.visual_encoder = self.build_convnext_encoder()
28
+ self.ln_vision = None
29
+ self.processor = ImageTrainProcessor(size=config.image_size)
30
+
31
+ elif 'siglip' in self.image_encoder_type:
32
+ if self.image_encoder_type == 'siglip_512':
33
+ model_name = "google/siglip-base-patch16-512"
34
+ elif self.image_encoder_type == 'siglip_384':
35
+ model_name = "google/siglip-large-patch16-384"
36
+ elif self.image_encoder_type == 'siglip_256':
37
+ model_name = "google/siglip-base-patch16-256"
38
+
39
+ from transformers import AutoProcessor, AutoModel
40
+
41
+ self.visual_encoder = AutoModel.from_pretrained(
42
+ model_name, torch_dtype = torch_dtype
43
+ ).vision_model
44
+
45
+ self.processor = AutoProcessor.from_pretrained(
46
+ model_name, torch_dtype = torch_dtype
47
+ )
48
+
49
+ def build_clip_encoder(self, image_size):
50
+ from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm
51
+ visual_encoder = VisionTransformer(
52
+ input_resolution=image_size,
53
+ patch_size=14,
54
+ width=1024,
55
+ layers=23,
56
+ heads=16,
57
+ use_grad_checkpointing=False)
58
+
59
+ ln_vision = LayerNorm(visual_encoder.num_features)
60
+ return visual_encoder, ln_vision
61
+
62
+ def build_vqgan_encoder(self):
63
+ from taming.modules.diffusionmodules.model import Encoder
64
+ VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md
65
+ vqgan_chkp_path = VQGAN_CHECKPOINT
66
+ files_in_directory = os.listdir(vqgan_chkp_path + '/configs')
67
+ vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0]
68
+ vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file))
69
+ visual_encoder = Encoder(**vqgan_config.model.params.ddconfig)
70
+
71
+ # Load checkpoint weights
72
+ checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict']
73
+
74
+ # Create a new state_dict with modified keys
75
+ new_state_dict = {}
76
+ for key, value in checkpoint.items():
77
+ if key.startswith('encoder.'):
78
+ new_key = key[len('encoder.'):]
79
+ new_state_dict[new_key] = value
80
+
81
+ # Load weights
82
+ visual_encoder.load_state_dict(new_state_dict)
83
+ return visual_encoder
84
+
85
+ def build_convnext_encoder(self):
86
+ import open_clip
87
+ model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
88
+ return model.visual
89
+
90
+ def forward(self, image):
91
+ if self.image_encoder_type == 'clip':
92
+ embeds = self.visual_encoder(image)
93
+ out = self.ln_vision(embeds)
94
+ elif self.image_encoder_type == 'open-clip':
95
+ out = self.visual_encoder(image)[1]
96
+ out = self.ln_vision(out)
97
+ elif self.image_encoder_type == 'vqgan':
98
+ out = self.visual_encoder(image)
99
+ size = out.size()
100
+ out = out.view(size[0], size[1], -1)
101
+ out = out.permute(0, 2, 1)
102
+ elif self.image_encoder_type == 'convnext':
103
+ out = self.visual_encoder.trunk.forward_features(image)
104
+ size = out.size()
105
+ out = out.view(size[0], size[1], -1)
106
+ out = out.permute(0, 2, 1)
107
+ elif 'siglip' in self.image_encoder_type:
108
+ out = self.visual_encoder(image)["last_hidden_state"]
109
+ return out
110
+
111
+ def process_images(self, images):
112
+ if self.image_encoder_type == 'clip':
113
+ res = []
114
+ for image in images:
115
+ res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W
116
+ return res
117
+ else:
118
+ return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0)
119
+
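A hedged sketch of constructing the encoder above with a SigLIP backbone. The config is a bare namespace carrying only the fields __init__ reads, and it assumes the google/siglip-base-patch16-256 weights can be downloaded from the Hub on first use.

# Hedged sketch, not part of the commit: ImageEncoder with the siglip_256 backbone.
from types import SimpleNamespace
import torch
from starvector.image_encoder import ImageEncoder

config = SimpleNamespace(image_size=256, torch_dtype=torch.float32, image_encoder_type="siglip_256")
encoder = ImageEncoder(config).eval()
pixels = torch.randn(1, 3, 256, 256)          # already-preprocessed pixel values
with torch.no_grad():
    feats = encoder(pixels)                   # [1, num_patches, hidden_dim], e.g. [1, 256, 768]
print(feats.shape)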
starvector/metrics/base_metric.py ADDED
@@ -0,0 +1,51 @@
1
+ from starvector.metrics.util import AverageMeter
2
+ from tqdm import tqdm
3
+ import math
4
+
5
+ class BaseMetric:
6
+ def __init__(self):
7
+ self.meter = AverageMeter()
8
+
9
+ def reset(self):
10
+ self.meter.reset()
11
+
12
+ def calculate_score(self, batch, update=True):
13
+ """
14
+ Batch: {"gt_im": [PIL Image], "gen_im": [Image]}
15
+ """
16
+ values = []
17
+ batch_size = len(next(iter(batch.values())))
18
+ for index in tqdm(range(batch_size)):
19
+ kwargs = {}
20
+ for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]:
21
+ if key in batch:
22
+ kwargs[key] = batch[key][index]
23
+ try:
24
+ measure = self.metric(**kwargs)
25
+ except Exception as e:
26
+ print("Error calculating metric: {}".format(e))
27
+ continue
28
+ if math.isnan(measure):
29
+ continue
30
+ values.append(measure)
31
+
32
+ if not values:
33
+ print("No valid values found for metric calculation.")
34
+ return float("nan")
35
+
36
+ score = sum(values) / len(values)
37
+ if update:
38
+ self.meter.update(score, len(values))
39
+ return self.meter.avg, values
40
+ else:
41
+ return score, values
42
+
43
+ def metric(self, **kwargs):
44
+ """
45
+ This method should be overridden by subclasses to provide the specific metric computation.
46
+ """
47
+ raise NotImplementedError("The metric method must be implemented by subclasses.")
48
+
49
+ def get_average_score(self):
50
+ return self.meter.avg
51
+
starvector/metrics/compute_LPIPS.py ADDED
@@ -0,0 +1,56 @@
1
+ from torchvision.transforms import ToTensor, Normalize
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from starvector.metrics.base_metric import BaseMetric
5
+ import lpips
6
+ from tqdm import tqdm
7
+
8
+
9
+ class LPIPSDistanceCalculator(BaseMetric):
10
+ def __init__(self, config=None, device='cuda'):
11
+ super().__init__()
12
+ self.class_name = self.__class__.__name__
13
+ self.config = config
14
+ self.model = lpips.LPIPS(net='vgg').to(device)
15
+ self.metric = self.LPIPS
16
+ self.to_tensor = ToTensor()
17
+ self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
18
+ self.device = device
19
+
20
+ def LPIPS(self, tensor_image1, tensor_image2):
21
+ tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device)
22
+ return self.model(tensor_image1, tensor_image2)
23
+
24
+ def to_tensor_transform(self, pil_img):
25
+ return self.normalize(self.to_tensor(pil_img))
26
+
27
+ def collate_fn(self, batch):
28
+ gt_imgs, gen_imgs = zip(*batch)
29
+ tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs])
30
+ tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs])
31
+ return tensor_gt_imgs, tensor_gen_imgs
32
+
33
+ def calculate_score(self, batch, batch_size=8, update=True):
34
+ gt_images = batch['gt_im']
35
+ gen_images = batch['gen_im']
36
+
37
+ # Create DataLoader with custom collate function
38
+ data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False)
39
+
40
+ values = []
41
+ for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader):
42
+ # Compute LPIPS
43
+ lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch)
44
+ values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist())
45
+
46
+ if not values:
47
+ print("No valid values found for metric calculation.")
48
+ return float("nan")
49
+
50
+ avg_score = sum(values) / len(values)
51
+ if update:
52
+ self.meter.update(avg_score, len(values))
53
+ return self.meter.avg, values
54
+ else:
55
+ return avg_score, values
56
+
starvector/metrics/compute_SSIM.py ADDED
@@ -0,0 +1,35 @@
1
+ from starvector.metrics.base_metric import BaseMetric
2
+ from skimage.metrics import structural_similarity as ssim
3
+ import numpy as np
4
+
5
+ class SSIMDistanceCalculator(BaseMetric):
6
+ def __init__(self, config=None):
7
+ super().__init__()
8
+ self.class_name = self.__class__.__name__
9
+ self.config = config
10
+ self.metric = self.compute_SSIM
11
+
12
+ def compute_SSIM(self, **kwargs):
13
+ image1 = kwargs.get('gt_im')
14
+ image2 = kwargs.get('gen_im')
15
+ win_size = kwargs.get('win_size', 11) # Increase win_size for more accuracy
16
+ channel_axis = kwargs.get('channel_axis', -1) # Default channel_axis to -1
17
+ sigma = kwargs.get('sigma', 1.5) # Add sigma parameter for Gaussian filter
18
+
19
+ # Convert images to numpy arrays if they aren't already
20
+ img1_np = np.array(image1)
21
+ img2_np = np.array(image2)
22
+
23
+ # Check if images are grayscale or RGB
24
+ if len(img1_np.shape) == 3 and img1_np.shape[2] == 3:
25
+ # Compute SSIM for RGB images
26
+ score, _ = ssim(img1_np, img2_np, win_size=win_size, channel_axis=channel_axis, sigma=sigma, full=True)
27
+ else:
28
+ # Convert to grayscale if not already
29
+ if len(img1_np.shape) == 3:
30
+ img1_np = np.mean(img1_np, axis=2)
31
+ img2_np = np.mean(img2_np, axis=2)
32
+
33
+ score, _ = ssim(img1_np, img2_np, win_size=win_size, sigma=sigma, full=True)
34
+
35
+ return score
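A hedged sketch of the metric interface: BaseMetric.calculate_score iterates a batch dict of parallel lists, so the SSIM calculator above can be exercised on in-memory PIL images like this.

# Hedged sketch, not part of the commit: per-sample SSIM through the BaseMetric loop.
from PIL import Image
from starvector.metrics.compute_SSIM import SSIMDistanceCalculator

gt = Image.new("RGB", (224, 224), "white")
gen = Image.new("RGB", (224, 224), "white")
metric = SSIMDistanceCalculator()
avg, per_sample = metric.calculate_score({"gt_im": [gt], "gen_im": [gen]})
print(avg, per_sample)   # identical images yield SSIM == 1.0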
starvector/metrics/compute_clip_score.py ADDED
@@ -0,0 +1,55 @@
1
+ from torchvision.transforms import ToTensor
2
+ import torch.nn.functional as F
3
+ from starvector.metrics.base_metric import BaseMetric
4
+ import torch
5
+ from torchmetrics.multimodal.clip_score import CLIPScore
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+ import torchvision.transforms as transforms
9
+ from torchmetrics.functional.multimodal.clip_score import _clip_score_update
10
+
11
+ class CLIPScoreCalculator(BaseMetric):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.class_name = self.__class__.__name__
15
+ self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32")
16
+ self.clip_score.to('cuda')
17
+
18
+ def CLIP_Score(self, images, captions):
19
+ all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor)
20
+ return all_scores
21
+
22
+ def collate_fn(self, batch):
23
+ gen_imgs, captions = zip(*batch)
24
+ tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs]
25
+ return tensor_gen_imgs, captions
26
+
27
+ def calculate_score(self, batch, batch_size=512, update=True):
28
+ gen_images = batch['gen_im']
29
+ captions = batch['caption']
30
+
31
+ # Create DataLoader with custom collate function
32
+ data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
33
+
34
+ all_scores = []
35
+ for batch_eval in tqdm(data_loader):
36
+ images, captions = batch_eval
37
+ images = [img.to('cuda', non_blocking=True) * 255 for img in images]
38
+ list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist()
39
+ all_scores.extend(list_scores)
40
+
41
+ if not all_scores:
42
+ print("No valid scores found for metric calculation.")
43
+ return float("nan"), []
44
+
45
+ avg_score = sum(all_scores) / len(all_scores)
46
+ if update:
47
+ self.meter.update(avg_score, len(all_scores))
48
+ return self.meter.avg, all_scores
49
+ else:
50
+ return avg_score, all_scores
51
+
52
+ if __name__ == '__main__':
53
+ import multiprocessing
54
+ multiprocessing.set_start_method('spawn')
55
+ # Rest of your code...
starvector/metrics/compute_dino_score.py ADDED
@@ -0,0 +1,55 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from starvector.metrics.base_metric import BaseMetric
4
+ from tqdm import tqdm
5
+ from transformers import AutoModel, AutoImageProcessor
6
+ from PIL import Image
7
+ import torch.nn as nn
8
+
9
+ class DINOScoreCalculator(BaseMetric):
10
+ def __init__(self, config=None, device='cuda'):
11
+ super().__init__()
12
+ self.class_name = self.__class__.__name__
13
+ self.config = config
14
+ self.model, self.processor = self.get_DINOv2_model("base")
15
+ self.model = self.model.to(device)
16
+ self.device = device
17
+
18
+ self.metric = self.calculate_DINOv2_similarity_score
19
+
20
+ def get_DINOv2_model(self, model_size):
21
+ if model_size == "small":
22
+ model_size = "facebook/dinov2-small"
23
+ elif model_size == "base":
24
+ model_size = "facebook/dinov2-base"
25
+ elif model_size == "large":
26
+ model_size = "facebook/dinov2-large"
27
+ else:
28
+ raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}")
29
+ return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size)
30
+
31
+ def process_input(self, image, processor):
32
+ if isinstance(image, str):
33
+ image = Image.open(image)
34
+ if isinstance(image, Image.Image):
35
+ with torch.no_grad():
36
+ inputs = processor(images=image, return_tensors="pt").to(self.device)
37
+ outputs = self.model(**inputs)
38
+ features = outputs.last_hidden_state.mean(dim=1)
39
+ elif isinstance(image, torch.Tensor):
40
+ features = image.unsqueeze(0) if image.dim() == 1 else image
41
+ else:
42
+ raise ValueError("Input must be a file path, PIL Image, or tensor of features")
43
+ return features
44
+
45
+ def calculate_DINOv2_similarity_score(self, **kwargs):
46
+ image1 = kwargs.get('gt_im')
47
+ image2 = kwargs.get('gen_im')
48
+ features1 = self.process_input(image1, self.processor)
49
+ features2 = self.process_input(image2, self.processor)
50
+
51
+ cos = nn.CosineSimilarity(dim=1)
52
+ sim = cos(features1, features2).item()
53
+ sim = (sim + 1) / 2
54
+
55
+ return sim
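A hedged sketch of the DINOv2 similarity score above; it assumes a CUDA device is available and that the facebook/dinov2-base weights can be fetched from the Hub.

# Hedged sketch, not part of the commit: DINOv2 feature similarity between two images.
from PIL import Image
from starvector.metrics.compute_dino_score import DINOScoreCalculator

calc = DINOScoreCalculator(device="cuda")
img_a = Image.new("RGB", (224, 224), "white")
img_b = Image.new("RGB", (224, 224), "black")
score = calc.calculate_DINOv2_similarity_score(gt_im=img_a, gen_im=img_b)
print(score)   # cosine similarity mapped to [0, 1]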
starvector/metrics/compute_fid.py ADDED
@@ -0,0 +1,145 @@
1
+ # Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html
2
+ # from torchmetrics.image.fid import FrechetInceptionDistance
3
+ from PIL import Image
4
+ from starvector.metrics.base_metric import BaseMetric
5
+ import torch
6
+ from torchvision import transforms
7
+ import clip
8
+ from torch.nn.functional import adaptive_avg_pool2d
9
+ from starvector.metrics.inception import InceptionV3
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ from scipy import linalg
13
+ import torchvision.transforms as TF
14
+
15
+ class FIDCalculator(BaseMetric):
16
+ def __init__(self, model_name='InceptionV3'):
17
+ self.class_name = self.__class__.__name__
18
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
+ self.model_name = model_name
20
+ if self.model_name == 'ViT-B/32':
21
+ self.dims = 512
22
+ model, preprocess = clip.load('ViT-B/32')
23
+
24
+ elif self.model_name == 'InceptionV3':
25
+ self.dims = 2048
26
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
27
+ model = InceptionV3([block_idx]).to(self.device)
28
+ preprocess = TF.Compose([TF.ToTensor()])
29
+
30
+ self.model = model.cuda()
31
+ self.preprocess = preprocess
32
+
33
+ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
34
+ """Numpy implementation of the Frechet Distance.
35
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
36
+ and X_2 ~ N(mu_2, C_2) is
37
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
38
+
39
+ Stable version by Dougal J. Sutherland.
40
+
41
+ Params:
42
+ -- mu1 : Numpy array containing the activations of a layer of the
43
+ inception net (like returned by the function 'get_predictions')
44
+ for generated samples.
45
+ -- mu2 : The sample mean over activations, precalculated on an
46
+ representative data set.
47
+ -- sigma1: The covariance matrix over activations for generated samples.
48
+ -- sigma2: The covariance matrix over activations, precalculated on an
49
+ representative data set.
50
+
51
+ Returns:
52
+ -- : The Frechet Distance.
53
+ """
54
+
55
+ mu1 = np.atleast_1d(mu1)
56
+ mu2 = np.atleast_1d(mu2)
57
+
58
+ sigma1 = np.atleast_2d(sigma1)
59
+ sigma2 = np.atleast_2d(sigma2)
60
+
61
+ assert mu1.shape == mu2.shape, \
62
+ 'Training and test mean vectors have different lengths'
63
+ assert sigma1.shape == sigma2.shape, \
64
+ 'Training and test covariances have different dimensions'
65
+
66
+ diff = mu1 - mu2
67
+
68
+ # Product might be almost singular
69
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
70
+ if not np.isfinite(covmean).all():
71
+ msg = ('fid calculation produces singular product; '
72
+ 'adding %s to diagonal of cov estimates') % eps
73
+ print(msg)
74
+ offset = np.eye(sigma1.shape[0]) * eps
75
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
76
+
77
+ # Numerical error might give slight imaginary component
78
+ if np.iscomplexobj(covmean):
79
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
80
+ m = np.max(np.abs(covmean.imag))
81
+ raise ValueError('Imaginary component {}'.format(m))
82
+ covmean = covmean.real
83
+
84
+ tr_covmean = np.trace(covmean)
85
+
86
+ return (diff.dot(diff) + np.trace(sigma1)
87
+ + np.trace(sigma2) - 2 * tr_covmean)
88
+
89
+ def get_activations(self, images):
90
+ dataset = ImageDataset(images, self.preprocess)
91
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4)
92
+ pred_arr = np.empty((len(images), self.dims))
93
+ start_idx = 0
94
+ for batch in tqdm(dataloader):
95
+ batch = batch.to(self.device)
96
+
97
+ with torch.no_grad():
98
+ if self.model_name == 'ViT-B/32':
99
+ pred = self.model.encode_image(batch).cpu().numpy()
100
+ elif self.model_name == 'InceptionV3':
101
+ pred = self.model(batch)[0]
102
+
103
+ # If model output is not scalar, apply global spatial average pooling.
104
+ # This happens if you choose a dimensionality not equal 2048.
105
+ if pred.size(2) != 1 or pred.size(3) != 1:
106
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
107
+
108
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
109
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
110
+ start_idx = start_idx + pred.shape[0]
111
+
112
+ return pred_arr
113
+
114
+ def calculate_activation_statistics(self, images):
115
+ act = self.get_activations(images)
116
+ mu = np.mean(act, axis=0)
117
+ sigma = np.cov(act, rowvar=False)
118
+ return mu, sigma
119
+
120
+ def pil_images_to_tensor(self, images_list):
121
+ """Convert a list of PIL Images to a torch.Tensor."""
122
+ tensors_list = [self.preprocess(img) for img in images_list]
123
+ return torch.stack(tensors_list).cuda() # BxCxHxW format
124
+
125
+ def calculate_score(self, batch):
126
+ m1, s1 = self.calculate_activation_statistics(batch['gt_im'])
127
+ m2, s2 = self.calculate_activation_statistics(batch['gen_im'])
128
+ fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
129
+ return fid_value
130
+
131
+ def reset(self):
132
+ pass
133
+
134
+ class ImageDataset(torch.utils.data.Dataset):
135
+ def __init__(self, images, processor=None):
136
+ self.images = images
137
+ self.processor = processor
138
+
139
+ def __len__(self):
140
+ return len(self.images)
141
+
142
+ def __getitem__(self, i):
143
+ img = self.images[i]
144
+ img = self.processor(img)
145
+ return img
starvector/metrics/compute_l2.py ADDED
@@ -0,0 +1,37 @@
1
+ from torchvision.transforms import ToTensor
2
+ import torch.nn.functional as F
3
+ from starvector.metrics.base_metric import BaseMetric
4
+ import torch
5
+
6
+ class L2DistanceCalculator(BaseMetric):
7
+ def __init__(self, config=None, masked_l2=False):
8
+ super().__init__()
9
+ self.class_name = self.__class__.__name__
10
+ self.config = config
11
+ self.metric = self.l2_distance
12
+ self.masked_l2 = masked_l2
13
+
14
+ def l2_distance(self, **kwargs):
15
+ image1 = kwargs.get('gt_im')
16
+ image2 = kwargs.get('gen_im')
17
+ image1_tensor = ToTensor()(image1)
18
+ image2_tensor = ToTensor()(image2)
19
+
20
+ if self.masked_l2:
21
+ # Create binary masks: 0 for white pixels, 1 for non-white pixels
22
+ mask1 = (image1_tensor != 1).any(dim=0).float()
23
+ mask2 = (image2_tensor != 1).any(dim=0).float()
24
+
25
+ # Create a combined mask for overlapping non-white pixels
26
+ combined_mask = mask1 * mask2
27
+
28
+ # Apply the combined mask to both images
29
+ image1_tensor = image1_tensor * combined_mask.unsqueeze(0)
30
+ image2_tensor = image2_tensor * combined_mask.unsqueeze(0)
31
+
32
+ # Compute mean squared error
33
+ mse = F.mse_loss(image1_tensor, image2_tensor)
34
+ return mse.item()
35
+
36
+
37
+
starvector/metrics/count_token_length.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from starvector.metrics.base_metric import BaseMetric
4
+ from tqdm import tqdm
5
+ from starvector.metrics.util import AverageMeter
6
+
7
+ from transformers import AutoTokenizer
8
+
9
+ class CountTokenLength(BaseMetric):
10
+ def __init__(self, config=None, device='cuda'):
11
+ super().__init__()
12
+ self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
13
+ self.metric = self.calculate_token_length
14
+ self.meter_gt_tokens = AverageMeter()
15
+ self.meter_gen_tokens = AverageMeter()
16
+ self.meter_diff = AverageMeter()
17
+
18
+ def calculate_token_length(self, **kwargs):
19
+ svg = kwargs.get('gt_svg')
20
+ tokens = self.tokenizer.encode(svg)
21
+ gen_svg = kwargs.get('gen_svg')
22
+ gen_tokens = self.tokenizer.encode(gen_svg)
23
+ diff = len(gen_tokens) - len(tokens)
24
+ return len(tokens), len(gen_tokens), diff
25
+
26
+ def calculate_score(self, batch, update=None):
27
+ gt_svgs = batch['gt_svg']
28
+ gen_svgs = batch['gen_svg']
29
+ values = []
30
+ for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"):
31
+ gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg)
32
+ self.meter_gt_tokens.update(gt_tokens, 1)
33
+ self.meter_gen_tokens.update(gen_tokens, 1)
34
+ self.meter_diff.update(diff, 1)
35
+ values.append({
36
+ 'gt_tokens': gt_tokens,
37
+ 'gen_tokens': gen_tokens,
38
+ 'diff': diff
39
+ })
40
+ avg_score = {
41
+ 'gt_tokens': self.meter_gt_tokens.avg,
42
+ 'gen_tokens': self.meter_gen_tokens.avg,
43
+ 'diff': self.meter_diff.avg
44
+ }
45
+ if not values:
46
+ print("No valid values found for metric calculation.")
47
+ return float("nan")
48
+
49
+ return avg_score, values
50
+
51
+ def reset(self):
52
+ self.meter_gt_tokens.reset()
53
+ self.meter_gen_tokens.reset()
54
+ self.meter_diff.reset()
starvector/metrics/inception.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling featurs
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(weights='DEFAULT')
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`"""
168
+ try:
169
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
170
+ except ValueError:
171
+ # Just a caution against weird version strings
172
+ version = (0,)
173
+
174
+ # Skips default weight inititialization if supported by torchvision
175
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
176
+ if version >= (0, 6):
177
+ kwargs['init_weights'] = False
178
+
179
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
180
+ # argument prior to version 0.13.
181
+ if version < (0, 13) and 'weights' in kwargs:
182
+ if kwargs['weights'] == 'DEFAULT':
183
+ kwargs['pretrained'] = True
184
+ elif kwargs['weights'] is None:
185
+ kwargs['pretrained'] = False
186
+ else:
187
+ raise ValueError(
188
+ 'weights=={} not supported in torchvision {}'.format(
189
+ kwargs['weights'], torchvision.__version__
190
+ )
191
+ )
192
+ del kwargs['weights']
193
+
194
+ return torchvision.models.inception_v3(*args, **kwargs)
195
+
196
+
197
+ def fid_inception_v3():
198
+ """Build pretrained Inception model for FID computation
199
+
200
+ The Inception model for FID computation uses a different set of weights
201
+ and has a slightly different structure than torchvision's Inception.
202
+
203
+ This method first constructs torchvision's Inception and then patches the
204
+ necessary parts that are different in the FID Inception model.
205
+ """
206
+ inception = _inception_v3(num_classes=1008,
207
+ aux_logits=False,
208
+ weights=None)
209
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
210
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
211
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
212
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
213
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
214
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
215
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
216
+ inception.Mixed_7b = FIDInceptionE_1(1280)
217
+ inception.Mixed_7c = FIDInceptionE_2(2048)
218
+
219
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
220
+ inception.load_state_dict(state_dict)
221
+ return inception
222
+
223
+
224
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
225
+ """InceptionA block patched for FID computation"""
226
+ def __init__(self, in_channels, pool_features):
227
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
228
+
229
+ def forward(self, x):
230
+ branch1x1 = self.branch1x1(x)
231
+
232
+ branch5x5 = self.branch5x5_1(x)
233
+ branch5x5 = self.branch5x5_2(branch5x5)
234
+
235
+ branch3x3dbl = self.branch3x3dbl_1(x)
236
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
237
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
238
+
239
+ # Patch: Tensorflow's average pool does not use the padded zero's in
240
+ # its average calculation
241
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
242
+ count_include_pad=False)
243
+ branch_pool = self.branch_pool(branch_pool)
244
+
245
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
246
+ return torch.cat(outputs, 1)
247
+
248
+
249
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
250
+ """InceptionC block patched for FID computation"""
251
+ def __init__(self, in_channels, channels_7x7):
252
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
253
+
254
+ def forward(self, x):
255
+ branch1x1 = self.branch1x1(x)
256
+
257
+ branch7x7 = self.branch7x7_1(x)
258
+ branch7x7 = self.branch7x7_2(branch7x7)
259
+ branch7x7 = self.branch7x7_3(branch7x7)
260
+
261
+ branch7x7dbl = self.branch7x7dbl_1(x)
262
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
263
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
264
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
265
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
266
+
267
+ # Patch: Tensorflow's average pool does not use the padded zero's in
268
+ # its average calculation
269
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
270
+ count_include_pad=False)
271
+ branch_pool = self.branch_pool(branch_pool)
272
+
273
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
274
+ return torch.cat(outputs, 1)
275
+
276
+
277
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
278
+ """First InceptionE block patched for FID computation"""
279
+ def __init__(self, in_channels):
280
+ super(FIDInceptionE_1, self).__init__(in_channels)
281
+
282
+ def forward(self, x):
283
+ branch1x1 = self.branch1x1(x)
284
+
285
+ branch3x3 = self.branch3x3_1(x)
286
+ branch3x3 = [
287
+ self.branch3x3_2a(branch3x3),
288
+ self.branch3x3_2b(branch3x3),
289
+ ]
290
+ branch3x3 = torch.cat(branch3x3, 1)
291
+
292
+ branch3x3dbl = self.branch3x3dbl_1(x)
293
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
294
+ branch3x3dbl = [
295
+ self.branch3x3dbl_3a(branch3x3dbl),
296
+ self.branch3x3dbl_3b(branch3x3dbl),
297
+ ]
298
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
299
+
300
+ # Patch: Tensorflow's average pool does not use the padded zero's in
301
+ # its average calculation
302
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
303
+ count_include_pad=False)
304
+ branch_pool = self.branch_pool(branch_pool)
305
+
306
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
307
+ return torch.cat(outputs, 1)
308
+
309
+
310
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
311
+ """Second InceptionE block patched for FID computation"""
312
+ def __init__(self, in_channels):
313
+ super(FIDInceptionE_2, self).__init__(in_channels)
314
+
315
+ def forward(self, x):
316
+ branch1x1 = self.branch1x1(x)
317
+
318
+ branch3x3 = self.branch3x3_1(x)
319
+ branch3x3 = [
320
+ self.branch3x3_2a(branch3x3),
321
+ self.branch3x3_2b(branch3x3),
322
+ ]
323
+ branch3x3 = torch.cat(branch3x3, 1)
324
+
325
+ branch3x3dbl = self.branch3x3dbl_1(x)
326
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
327
+ branch3x3dbl = [
328
+ self.branch3x3dbl_3a(branch3x3dbl),
329
+ self.branch3x3dbl_3b(branch3x3dbl),
330
+ ]
331
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
332
+
333
+ # Patch: The FID Inception model uses max pooling instead of average
334
+ # pooling. This is likely an error in this specific Inception
335
+ # implementation, as other Inception models use average pooling here
336
+ # (which matches the description in the paper).
337
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
338
+ branch_pool = self.branch_pool(branch_pool)
339
+
340
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
341
+ return torch.cat(outputs, 1)
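For reference, a sketch of how the wrapper above is typically driven when collecting pool features for FID statistics, assuming the FID weights can be downloaded:

```python
import torch
from starvector.metrics.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]  # final average pooling features
model = InceptionV3([block_idx]).eval()           # fetches the FID weights on first use

images = torch.rand(4, 3, 299, 299)               # Bx3xHxW, values in (0, 1)
with torch.no_grad():
    feats = model(images)[0]                      # (4, 2048, 1, 1)
feats = feats.squeeze(-1).squeeze(-1)             # (4, 2048), ready for FID mean/covariance
```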
starvector/metrics/metrics.py ADDED
@@ -0,0 +1,127 @@
1
+ from starvector.metrics.compute_l2 import L2DistanceCalculator
2
+ from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator
3
+ from starvector.metrics.compute_SSIM import SSIMDistanceCalculator
4
+ from starvector.metrics.compute_fid import FIDCalculator
5
+ from starvector.metrics.compute_clip_score import CLIPScoreCalculator
6
+ from starvector.data.util import rasterize_svg
7
+ from starvector.metrics.util import AverageMeter
8
+ from starvector.metrics.compute_dino_score import DINOScoreCalculator
9
+ from starvector.metrics.count_token_length import CountTokenLength
10
+ import os
11
+ from tqdm import tqdm
12
+
13
+ class SVGMetrics:
14
+ def __init__(self, config=None):
15
+ self.class_name = self.__class__.__name__
16
+
17
+ default_config = {
18
+ 'L2': True,
19
+ 'Masked-L2': False,
20
+ 'LPIPS': False,
21
+ 'SSIM': False,
22
+ 'FID': False,
23
+ 'FID_clip': False,
24
+ 'CLIPScore': False,
25
+ 'CountTokenLength': False,
26
+ 'ratio_post_processed': True,
27
+ 'ratio_non_compiling': True,
28
+ 'DinoScore': True,
29
+ }
30
+ self.config = config or default_config
31
+
32
+ self.metrics = {
33
+ 'L2': L2DistanceCalculator,
34
+ 'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True),
35
+ 'LPIPS': LPIPSDistanceCalculator,
36
+ 'SSIM': SSIMDistanceCalculator,
37
+ 'FID': lambda: FIDCalculator(model_name='InceptionV3'),
38
+ 'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'),
39
+ 'CLIPScore': CLIPScoreCalculator,
40
+ 'CountTokenLength': CountTokenLength,
41
+ 'ratio_post_processed': AverageMeter,
42
+ 'ratio_non_compiling': AverageMeter,
43
+ 'DinoScore': DINOScoreCalculator,
44
+ }
45
+
46
+ self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)}
47
+
48
+ def reset(self):
49
+ for metric in self.active_metrics.values():
50
+ metric.reset()
51
+
52
+ def batch_contains_raster(self, batch):
53
+ return "gt_im" in batch and "gen_im" in batch
54
+
55
+ def batch_contains_svg(self, batch):
56
+ return "gt_svg" in batch and "gen_svg" in batch
57
+
58
+ def calculate_metrics(self, batch, update=True):
59
+ if not self.batch_contains_raster(batch):
60
+ batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]]
61
+ batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]]
62
+
63
+ avg_results_dict = {}
64
+ all_results_dict = {}
65
+
66
+ def get_sample_id(json_item):
67
+ return json_item.get('outpath_filename') or json_item.get('sample_id')
68
+
69
+ # initialize all_results_dict
70
+ for i, json_item in enumerate(batch['json']):
71
+ sample_id = get_sample_id(json_item)
72
+ if sample_id is None:
73
+ raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]")
74
+ all_results_dict[sample_id] = {}
75
+
76
+ for metric_name, metric in self.active_metrics.items():
77
+ print(f"Calculating {metric_name}...")
78
+
79
+ # Handle metrics that return both average and per-sample results
80
+ if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']:
81
+ avg_result, list_result = metric.calculate_score(batch, update=update)
82
+ avg_results_dict[metric_name] = avg_result
83
+
84
+ # Store individual results
85
+ for i, result in enumerate(list_result):
86
+ sample_id = get_sample_id(batch['json'][i])
87
+ all_results_dict[sample_id][metric_name] = result
88
+
89
+ # Handle FID metrics that only return average
90
+ elif metric_name in ['FID', 'FID_clip']:
91
+ avg_results_dict[metric_name] = metric.calculate_score(batch)
92
+
93
+ # Handle other metrics (ratio metrics)
94
+ else:
95
+ self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict)
96
+
97
+ metric.reset()
98
+ print("Average results: \n", avg_results_dict)
99
+ return avg_results_dict, all_results_dict
100
+
101
+ def calculate_fid(self, batch):
102
+ if not self.batch_contains_raster(batch):
103
+ batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]]
104
+ batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]]
105
+
106
+ return self.active_metrics['FID'].calculate_score(batch).item()
107
+
108
+ def get_average_metrics(self):
109
+ metrics = {}
110
+ for metric_name, metric in self.active_metrics.items():
111
+ if hasattr(metric, 'avg'):
112
+ metrics[metric_name] = metric.avg
113
+ elif hasattr(metric, 'get_average_score'):
114
+ metrics[metric_name] = metric.get_average_score()
115
+ return metrics
116
+
117
+ def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict):
118
+ """Helper method to handle ratio-based metrics."""
119
+ metric_key = metric_name.replace('avg_', '').replace('ratio_', '')
120
+
121
+ for item in batch['json']:
122
+ sample_id = item.get('outpath_filename') or item.get('sample_id')  # get_sample_id is local to calculate_metrics
123
+ value = item[metric_key]
124
+ all_results_dict[sample_id][metric_name] = value
125
+ metric.update(value, 1)
126
+
127
+ avg_results_dict[metric_name] = metric.avg
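A sketch of the aggregator above on a two-SVG toy batch, enabling only the cheap pixel metric so nothing beyond the rasterizer is needed (the heavier metrics pull in their own pretrained models); it assumes `BaseMetric.calculate_score` iterates over the rasterized image pairs:

```python
from starvector.metrics.metrics import SVGMetrics

batch = {
    "gt_svg": ['<svg xmlns="http://www.w3.org/2000/svg"><rect width="10" height="10" fill="red"/></svg>'],
    "gen_svg": ['<svg xmlns="http://www.w3.org/2000/svg"><rect width="9" height="9" fill="red"/></svg>'],
    "json": [{"sample_id": "sample_0"}],
}

metrics = SVGMetrics(config={"L2": True})     # only L2 active; images are rasterized on the fly
avg, per_sample = metrics.calculate_metrics(batch)
print(avg)                     # {'L2': ...}
print(per_sample["sample_0"])  # per-sample breakdown keyed by sample_id
```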
starvector/metrics/util.py ADDED
@@ -0,0 +1,20 @@
1
+
2
+ # -------------- Metrics --------------
3
+ class AverageMeter(object):
4
+ """Computes and stores the average and current value"""
5
+
6
+ def __init__(self):
7
+ self.reset()
8
+
9
+ def reset(self):
10
+ self.val = 0
11
+ self.avg = 0
12
+ self.sum = 0
13
+ self.count = 0
14
+
15
+ def update(self, val, n=1):
16
+ self.val = val
17
+ self.sum += val * n
18
+ self.count += n
19
+ self.avg = self.sum / self.count
20
+
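Usage is the standard running-average pattern:

```python
from starvector.metrics.util import AverageMeter

meter = AverageMeter()
meter.update(0.5)        # one sample with value 0.5
meter.update(1.0, n=3)   # three samples with value 1.0
print(meter.avg)         # (0.5 + 3 * 1.0) / 4 = 0.875
```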
starvector/model/adapters/adapter.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch.nn as nn
2
+ import torch.nn.init as init
3
+ import torch
4
+
5
+ class Swish(nn.Module):
6
+ def __init__(self):
7
+ super(Swish, self).__init__()
8
+
9
+ def forward(self, x):
10
+ return x * torch.sigmoid(x)
11
+
12
+ class Adapter(nn.Module):
13
+ def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1):
14
+ super().__init__()
15
+ self.query_length = query_length
16
+ self.dropout_prob = dropout_prob
17
+ self.adapter_norm = adapter_norm
18
+
19
+ self.dropout = nn.Dropout(p=self.dropout_prob)
20
+
21
+ self.c_fc = nn.Linear(input_size, input_size*2)
22
+ self.act = Swish()
23
+ self.c_proj = nn.Linear(input_size*2, output_size)
24
+
25
+ if adapter_norm == "layer_norm":
26
+ self.norm = nn.LayerNorm([self.query_length, output_size])
27
+ elif adapter_norm == "batch_norm":
28
+ self.norm = nn.BatchNorm1d(self.query_length)
29
+
30
+ self.init_type = init_type.lower()
31
+ self._initialize_weights()
32
+
33
+ def forward(self, hidden_states):
34
+ hidden_states = self.dropout(hidden_states)
35
+ hidden_states = self.c_fc(hidden_states)
36
+ hidden_states = self.act(hidden_states)
37
+ hidden_states = self.c_proj(hidden_states)
38
+ hidden_states = self.norm(hidden_states)
39
+ return hidden_states
40
+
41
+ def _initialize_weights(self):
42
+ for m in self.modules():
43
+ if isinstance(m, nn.Linear):
44
+ if self.init_type == "glorot":
45
+ init.xavier_uniform_(m.weight)
46
+ if m.bias is not None:
47
+ init.constant_(m.bias, 0)
48
+ elif self.init_type == "normal":
49
+ init.normal_(m.weight, mean=0, std=0.01)
50
+ if m.bias is not None:
51
+ init.constant_(m.bias, 0)
52
+ else:
53
+ raise ValueError("Invalid initialization type specified.")
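A shape-only sketch of the adapter above with hypothetical sizes (32 visual query tokens projected from a 1024-d image encoder into a 2048-d LLM embedding space); with the default layer norm, `query_length` must match the token count of the input:

```python
import torch
from starvector.model.adapters.adapter import Adapter

adapter = Adapter(input_size=1024, output_size=2048, query_length=32)

visual_tokens = torch.randn(4, 32, 1024)   # (batch, query_length, input_size)
projected = adapter(visual_tokens)         # (4, 32, 2048), layer-normed
print(projected.shape)
```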
starvector/model/builder.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+ from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig
3
+ from starvector.data.base import ImageTrainProcessor
4
+ from starvector.util import dtype_mapping
5
+ from transformers import AutoConfig
6
+
7
+ def load_pretrained_model(model_path, device="cuda", **kwargs):
8
+ model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device)
9
+ tokenizer = model.model.svg_transformer.tokenizer
10
+ image_processor = ImageTrainProcessor()
11
+ context_len = model.model.query_length + model.model.max_length
12
+ return tokenizer, model, image_processor, context_len
13
+
14
+ def model_builder(config):
15
+ model_name = config.model.get("model_name", False)
16
+
17
+ args = {
18
+ "task": config.model.task,
19
+ "train_image_encoder": config.training.train_image_encoder,
20
+ "ignore_mismatched_sizes": True,
21
+ "starcoder_model_name": config.model.starcoder_model_name,
22
+ "train_LLM": config.training.train_LLM,
23
+ "torch_dtype": dtype_mapping[config.training.model_precision],
24
+ "transformer_layer_cls": config.model.get("transformer_layer_cls", False),
25
+ "use_cache": config.model.use_cache,
26
+ }
27
+ if model_name:
28
+ model = StarVectorForCausalLM.from_pretrained(model_name, **args)
29
+ else:
30
+ starcoder_model_config = AutoConfig.from_pretrained(config.model.starcoder_model_name)
31
+
32
+ starvector_config = StarVectorConfig(
33
+ max_length_train=config.model.max_length,
34
+ image_encoder_type=config.model.image_encoder_type,
35
+ use_flash_attn=config.model.use_flash_attn,
36
+ adapter_norm=config.model.adapter_norm,
37
+ starcoder_model_name=config.model.starcoder_model_name,
38
+ torch_dtype=dtype_mapping[config.training.model_precision],
39
+ num_attention_heads=starcoder_model_config.num_attention_heads,
40
+ num_hidden_layers=starcoder_model_config.num_hidden_layers,
41
+ vocab_size=starcoder_model_config.vocab_size,
42
+ hidden_size=starcoder_model_config.hidden_size,
43
+ num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None),
44
+ )
45
+ model = StarVectorForCausalLM(starvector_config, **args)
46
+
47
+ return model
48
+
49
+
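A sketch of the convenience loader above; the checkpoint id is assumed to be the published StarVector image-to-SVG model, and loading it requires a CUDA GPU with enough memory:

```python
from starvector.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    "starvector/starvector-8b-im2svg",  # assumed checkpoint id
    device="cuda",
)
print(type(model).__name__, context_len)
```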
starvector/model/gpt_bigcode/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from transformers.utils import (
18
+ OptionalDependencyNotAvailable,
19
+ _LazyModule,
20
+ is_torch_available,
21
+ )
22
+
23
+
24
+ _import_structure = {
25
+ "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"],
26
+ }
27
+
28
+ try:
29
+ if not is_torch_available():
30
+ raise OptionalDependencyNotAvailable()
31
+ except OptionalDependencyNotAvailable:
32
+ pass
33
+ else:
34
+ _import_structure["modeling_gpt_bigcode"] = [
35
+ "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST",
36
+ "GPTBigCodeForSequenceClassification",
37
+ "GPTBigCodeForTokenClassification",
38
+ "GPTBigCodeForCausalLM",
39
+ "GPTBigCodeModel",
40
+ "GPTBigCodePreTrainedModel",
41
+ ]
42
+
43
+ if TYPE_CHECKING:
44
+ from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig
45
+
46
+ try:
47
+ if not is_torch_available():
48
+ raise OptionalDependencyNotAvailable()
49
+ except OptionalDependencyNotAvailable:
50
+ pass
51
+ else:
52
+ from .modeling_gpt_bigcode import (
53
+ GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST,
54
+ GPTBigCodeForCausalLM,
55
+ GPTBigCodeForSequenceClassification,
56
+ GPTBigCodeForTokenClassification,
57
+ GPTBigCodeModel,
58
+ GPTBigCodePreTrainedModel,
59
+ )
60
+
61
+
62
+ else:
63
+ import sys
64
+
65
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
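Because of the `_LazyModule` indirection, the package-level and module-level imports resolve to the same objects, and `modeling_gpt_bigcode` is only imported when one of its symbols is first touched (torch must be installed for the model classes to be exposed):

```python
from starvector.model.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig as DirectConfig

assert GPTBigCodeConfig is DirectConfig
print(GPTBigCodeForCausalLM.__module__)  # starvector.model.gpt_bigcode.modeling_gpt_bigcode
```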
starvector/model/gpt_bigcode/configuration_gpt_bigcode.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The BigCode team and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ GPTBigCode configuration"""
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+
25
+
26
+ class GPTBigCodeConfig(PretrainedConfig):
27
+ """
28
+ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a
29
+ GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a
30
+ configuration with the defaults will yield a similar configuration to that of the GPTBigCode
31
+ [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 50257):
39
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`GPTBigCodeModel`].
41
+ n_positions (`int`, *optional*, defaults to 1024):
42
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
43
+ just in case (e.g., 512 or 1024 or 2048).
44
+ n_embd (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the embeddings and hidden states.
46
+ n_layer (`int`, *optional*, defaults to 12):
47
+ Number of hidden layers in the Transformer encoder.
48
+ n_head (`int`, *optional*, defaults to 12):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ n_inner (`int`, *optional*, defaults to None):
51
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
52
+ activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
53
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
54
+ "gelu_pytorch_tanh"]`.
55
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
56
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
57
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
58
+ The dropout ratio for the embeddings.
59
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
60
+ The dropout ratio for the attention.
61
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
62
+ The epsilon to use in the layer normalization layers.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
66
+ Scale attention weights by dividing by sqrt(hidden_size).
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models).
69
+ attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
70
+ Whether to call the fused softmax in float32.
71
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
72
+ Whether to scale the attention softmax in float32.
73
+ multi_query (`bool`, *optional*, defaults to `True`):
74
+ Whether to use Multi-Query Attention (`True`) or Multi-Head Attention (`False`).
75
+ Example:
76
+
77
+ ```python
78
+ >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel
79
+
80
+ >>> # Initializing a GPTBigCode configuration
81
+ >>> configuration = GPTBigCodeConfig()
82
+
83
+ >>> # Initializing a model (with random weights) from the configuration
84
+ >>> model = GPTBigCodeModel(configuration)
85
+
86
+ >>> # Accessing the model configuration
87
+ >>> configuration = model.config
88
+ ```"""
89
+
90
+ model_type = "gpt_bigcode"
91
+ keys_to_ignore_at_inference = ["past_key_values"]
92
+ attribute_map = {
93
+ "hidden_size": "n_embd",
94
+ "max_position_embeddings": "n_positions",
95
+ "num_attention_heads": "n_head",
96
+ "num_hidden_layers": "n_layer",
97
+ }
98
+
99
+ def __init__(
100
+ self,
101
+ vocab_size=50257,
102
+ n_positions=1024,
103
+ n_embd=768,
104
+ n_layer=12,
105
+ n_head=12,
106
+ n_inner=None,
107
+ activation_function="gelu_pytorch_tanh",
108
+ resid_pdrop=0.1,
109
+ embd_pdrop=0.1,
110
+ attn_pdrop=0.1,
111
+ layer_norm_epsilon=1e-5,
112
+ initializer_range=0.02,
113
+ scale_attn_weights=True,
114
+ use_cache=True,
115
+ bos_token_id=50256,
116
+ eos_token_id=50256,
117
+ attention_softmax_in_fp32=True,
118
+ scale_attention_softmax_in_fp32=True,
119
+ multi_query=True,
120
+ **kwargs,
121
+ ):
122
+ self.vocab_size = vocab_size
123
+ self.n_positions = n_positions
124
+ self.n_embd = n_embd
125
+ self.n_layer = n_layer
126
+ self.n_head = n_head
127
+ self.n_inner = n_inner
128
+ self.activation_function = activation_function
129
+ self.resid_pdrop = resid_pdrop
130
+ self.embd_pdrop = embd_pdrop
131
+ self.attn_pdrop = attn_pdrop
132
+ self.layer_norm_epsilon = layer_norm_epsilon
133
+ self.initializer_range = initializer_range
134
+ self.scale_attn_weights = scale_attn_weights
135
+ self.use_cache = use_cache
136
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
137
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
138
+ self.multi_query = multi_query
139
+
140
+ self.bos_token_id = bos_token_id
141
+ self.eos_token_id = eos_token_id
142
+
143
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
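A small configuration sketch showing the attribute map in action: the GPT-2 style argument names can be read back through the generic Transformers aliases.

```python
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig

config = GPTBigCodeConfig(n_embd=256, n_layer=4, n_head=8, n_positions=512)
print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)  # 256 4 8
print(config.multi_query)  # True: a single shared key/value head per layer (MQA)
```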
starvector/model/gpt_bigcode/modeling_gpt_bigcode.py ADDED
@@ -0,0 +1,1502 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Bigcode team and HuggingFace Inc. team.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch GPTBigCode model."""
15
+ import math
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPastAndCrossAttentions,
28
+ CausalLMOutputWithCrossAttentions,
29
+ SequenceClassifierOutputWithPast,
30
+ TokenClassifierOutput,
31
+ )
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
34
+ from transformers.utils import (
35
+ add_code_sample_docstrings,
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ is_flash_attn_2_available,
39
+ is_flash_attn_greater_or_equal_2_10,
40
+ logging,
41
+ )
42
+ from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig
43
+
44
+
45
+ if is_flash_attn_2_available():
46
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
47
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder"
53
+ _CONFIG_FOR_DOC = "GPTBigCodeConfig"
54
+
55
+
56
+
57
+ # Fused kernels
58
+ # Use separate functions for each case because conditionals prevent kernel fusion.
59
+ # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
60
+ # Is it doable without writing 32 functions?
61
+ @torch.jit.script
62
+ def upcast_masked_softmax(
63
+ x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
64
+ ):
65
+ input_dtype = x.dtype
66
+ x = x.to(softmax_dtype) * scale
67
+ x = torch.where(mask, x, mask_value)
68
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
69
+ return x
70
+
71
+
72
+ @torch.jit.script
73
+ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
74
+ input_dtype = x.dtype
75
+ x = x.to(softmax_dtype) * scale
76
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
77
+ return x
78
+
79
+
80
+ @torch.jit.script
81
+ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
82
+ x = torch.where(mask, x, mask_value)
83
+ x = torch.nn.functional.softmax(x, dim=-1)
84
+ return x
85
+
86
+
87
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
88
+ def _get_unpad_data(attention_mask):
89
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
90
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
91
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
92
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
93
+ return (
94
+ indices,
95
+ cu_seqlens,
96
+ max_seqlen_in_batch,
97
+ )
98
+
99
+
100
+ class GPTBigCodeAttention(nn.Module):
101
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
102
+ super().__init__()
103
+ self.config = config
104
+
105
+ self.mask_value = None
106
+ self.multi_query = config.multi_query
107
+ self.embed_dim = config.hidden_size
108
+ self.num_heads = config.num_attention_heads
109
+ self.head_dim = self.embed_dim // self.num_heads
110
+ self.kv_heads = 1 if self.multi_query else self.num_heads
111
+ self.kv_dim = self.kv_heads * self.head_dim
112
+ self.split_size = self.embed_dim
113
+ self.is_causal = True
114
+
115
+ if self.head_dim * self.num_heads != self.embed_dim:
116
+ raise ValueError(
117
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
118
+ f" {self.num_heads})."
119
+ )
120
+
121
+ self.scale_attn_weights = config.scale_attn_weights
122
+ self.is_cross_attention = is_cross_attention
123
+
124
+ self.layer_idx = layer_idx
125
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
126
+ self.scale_attention_softmax_in_fp32 = (
127
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
128
+ )
129
+ self.attn_pdrop = config.attn_pdrop
130
+
131
+ if self.is_cross_attention:
132
+ if self.multi_query:
133
+ raise NotImplementedError("Multi-Query Attention not supported for cross_attention")
134
+
135
+ self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim)
136
+ self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
137
+ else:
138
+ self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
139
+
140
+ self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
141
+
142
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
143
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
144
+
145
+ def _get_mask_value(self, device, dtype):
146
+ # torch.where expects a tensor. We use a cache to avoid recreating it every time.
147
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
148
+ self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
149
+ return self.mask_value
150
+
151
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
152
+ dtype = query.dtype
153
+ softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
154
+ upcast = dtype != softmax_dtype
155
+
156
+ unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
157
+ scale_factor = unscale**-1
158
+ if self.scale_attn_weights:
159
+ scale_factor /= self.head_dim**0.5
160
+
161
+ # MQA models: (batch_size, query_length, num_heads * head_dim)
162
+ # MHA models: (batch_size, num_heads, query_length, head_dim)
163
+ query_shape = query.shape
164
+ batch_size = query_shape[0]
165
+ key_length = key.size(-1)
166
+ if self.multi_query:
167
+ # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
168
+ # -> (batch_size, query_length, num_heads, key_length)
169
+ query_length = query_shape[1]
170
+ attn_shape = (batch_size, query_length, self.num_heads, key_length)
171
+ attn_view = (batch_size, query_length * self.num_heads, key_length)
172
+ # No copy needed for MQA 2, or when layer_past is provided.
173
+ query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
174
+ else:
175
+ # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
176
+ # -> (batch_size, num_heads, query_length, key_length)
177
+ query_length = query_shape[2]
178
+ attn_shape = (batch_size, self.num_heads, query_length, key_length)
179
+ attn_view = (batch_size * self.num_heads, query_length, key_length)
180
+ # Always copies
181
+ query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
182
+ # No copy when layer_past is provided.
183
+ key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
184
+
185
+ attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
186
+ if query.device.type == "cpu":
187
+ # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
188
+ # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
189
+ # but the fix has not been released as of pytorch version 2.0.0.
190
+ attn_weights = torch.zeros_like(attn_weights)
191
+ beta = 1
192
+ else:
193
+ beta = 0
194
+ attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
195
+
196
+ if upcast:
197
+ # Use a fused kernel to prevent a large overhead from casting and scaling.
198
+ # Sub-optimal when the key length is not a multiple of 8.
199
+ if attention_mask is None:
200
+ attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
201
+ else:
202
+ mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
203
+ attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
204
+ else:
205
+ if attention_mask is not None:
206
+ mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
207
+
208
+ # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
209
+ attn_weights = torch.where(attention_mask, attn_weights, mask_value)
210
+
211
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
212
+
213
+ attn_weights = self.attn_dropout(attn_weights)
214
+
215
+ # Mask heads if we want to
216
+ if head_mask is not None:
217
+ if self.multi_query:
218
+ head_mask = head_mask.transpose(1, 2)
219
+ attn_weights = attn_weights * head_mask
220
+
221
+ if self.multi_query:
222
+ attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
223
+ else:
224
+ attn_output = torch.matmul(attn_weights, value)
225
+
226
+ return attn_output, attn_weights
227
+
228
+ def forward(
229
+ self,
230
+ hidden_states: torch.Tensor,
231
+ layer_past: Optional[torch.Tensor] = None,
232
+ attention_mask: Optional[torch.Tensor] = None,
233
+ head_mask: Optional[torch.Tensor] = None,
234
+ encoder_hidden_states: Optional[torch.Tensor] = None,
235
+ encoder_attention_mask: Optional[torch.Tensor] = None,
236
+ use_cache: Optional[bool] = False,
237
+ output_attentions: Optional[bool] = False,
238
+ ) -> Union[
239
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
240
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
241
+ ]:
242
+ if encoder_hidden_states is not None:
243
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
244
+ raise ValueError(
245
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
246
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
247
+ )
248
+
249
+ query = self.q_attn(hidden_states)
250
+ key_value = self.c_attn(encoder_hidden_states)
251
+ attention_mask = encoder_attention_mask
252
+ elif self.multi_query:
253
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
254
+ else:
255
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
256
+ # i.e., the memory layout is not the same as GPT2.
257
+ # This makes the concatenation with past_key_value more efficient.
258
+ query, key_value = (
259
+ self.c_attn(hidden_states)
260
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
261
+ .transpose(1, 2)
262
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
263
+ )
264
+
265
+ if layer_past is not None:
266
+ key_value = torch.cat((layer_past, key_value), dim=-2)
267
+ present = key_value if use_cache else None
268
+
269
+ key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
270
+
271
+ attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
272
+
273
+ if not self.multi_query:
274
+ attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
275
+ attn_output = self.c_proj(attn_output)
276
+ attn_output = self.resid_dropout(attn_output)
277
+
278
+ outputs = (attn_output, present)
279
+ if output_attentions:
280
+ if self.multi_query:
281
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
282
+ attn_weights = attn_weights.transpose(1, 2)
283
+ outputs += (attn_weights,)
284
+
285
+ return outputs # a, present, (attentions)
286
+
287
+
288
+ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
289
+ """
290
+ GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
291
+ stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
292
+ API of flash attention and deal with padding tokens in case the input contains any of them.
293
+ """
294
+
295
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
296
+ def __init__(self, *args, **kwargs):
297
+ super().__init__(*args, **kwargs)
298
+
299
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
300
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
301
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
302
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ layer_past: Optional[torch.Tensor] = None,
308
+ attention_mask: Optional[torch.Tensor] = None,
309
+ head_mask: Optional[torch.Tensor] = None,
310
+ encoder_hidden_states: Optional[torch.Tensor] = None,
311
+ encoder_attention_mask: Optional[torch.Tensor] = None,
312
+ use_cache: Optional[bool] = False,
313
+ output_attentions: Optional[bool] = False,
314
+ ) -> Union[
315
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
316
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
317
+ ]:
318
+ if encoder_hidden_states is not None:
319
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
320
+ raise ValueError(
321
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
322
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
323
+ )
324
+
325
+ query = self.q_attn(hidden_states)
326
+ key_value = self.c_attn(encoder_hidden_states)
327
+ attention_mask = encoder_attention_mask
328
+ elif self.multi_query:
329
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
330
+ else:
331
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
332
+ # i.e., the memory layout is not the same as GPT2.
333
+ # This makes the concatenation with past_key_value more efficient.
334
+ query, key_value = (
335
+ self.c_attn(hidden_states)
336
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
337
+ .transpose(1, 2)
338
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
339
+ )
340
+
341
+ if layer_past is not None:
342
+ key_value = torch.cat((layer_past, key_value), dim=-2)
343
+ present = key_value if use_cache else None
344
+
345
+ key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
346
+
347
+ # Flash attention requires the input to have the shape
348
+ # batch_size x seq_length x head_dim x hidden_dim
349
+ if self.multi_query:
350
+ batch_size, query_length, _ = query.shape
351
+ query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
352
+ key = key.unsqueeze(2)
353
+ value = value.unsqueeze(2)
354
+ else:
355
+ query_length = query.shape[2]
356
+ batch_size, _, tgt, _ = key.shape
357
+ query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
358
+ key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
359
+ value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
360
+
361
+ attn_dropout = self.attn_pdrop if self.training else 0.0
362
+
363
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
364
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
365
+ # cast them back in float16 just to be sure everything works as expected.
366
+ input_dtype = query.dtype
367
+ if input_dtype == torch.float32:
368
+ if torch.is_autocast_enabled():
369
+ target_dtype = torch.get_autocast_gpu_dtype()
370
+ # Handle the case where the model is quantized
371
+ elif hasattr(self.config, "_pre_quantization_dtype"):
372
+ target_dtype = self.config._pre_quantization_dtype
373
+ else:
374
+ target_dtype = self.c_attn.weight.dtype
375
+
376
+ logger.warning_once(
377
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
378
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
379
+ f" {target_dtype}."
380
+ )
381
+ query = query.to(target_dtype)
382
+ key = key.to(target_dtype)
383
+ value = value.to(target_dtype)
384
+
385
+ attn_output = self._flash_attention_forward(
386
+ query, key, value, attention_mask, query_length, dropout=attn_dropout
387
+ )
388
+
389
+ attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
390
+ attn_output = self.c_proj(attn_weights_reshaped)
391
+ attn_output = self.resid_dropout(attn_output)
392
+
393
+ outputs = (attn_output, present)
394
+
395
+ if output_attentions:
396
+ if self.multi_query:
397
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
398
+ attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2)
399
+ else:
400
+ attn_weights_reshaped = None
401
+
402
+ outputs += (attn_weights_reshaped,)
403
+
404
+ return outputs # a, present, (attentions)
405
+
406
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
407
+ def _flash_attention_forward(
408
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
409
+ ):
410
+ """
411
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
412
+ first unpad the input, then computes the attention scores and pad the final attention scores.
413
+
414
+ Args:
415
+ query_states (`torch.Tensor`):
416
+ Input query states to be passed to Flash Attention API
417
+ key_states (`torch.Tensor`):
418
+ Input key states to be passed to Flash Attention API
419
+ value_states (`torch.Tensor`):
420
+ Input value states to be passed to Flash Attention API
421
+ attention_mask (`torch.Tensor`):
422
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
423
+ position of padding tokens and 1 for the position of non-padding tokens.
424
+ dropout (`float`):
425
+ Attention dropout
426
+ softmax_scale (`float`, *optional*):
427
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
428
+ """
429
+ if not self._flash_attn_uses_top_left_mask:
430
+ causal = self.is_causal
431
+ else:
432
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
433
+ causal = self.is_causal and query_length != 1
434
+
435
+ # Contains at least one padding token in the sequence
436
+ if attention_mask is not None:
437
+ batch_size = query_states.shape[0]
438
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
439
+ query_states, key_states, value_states, attention_mask, query_length
440
+ )
441
+
442
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
443
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
444
+
445
+ attn_output_unpad = flash_attn_varlen_func(
446
+ query_states,
447
+ key_states,
448
+ value_states,
449
+ cu_seqlens_q=cu_seqlens_q,
450
+ cu_seqlens_k=cu_seqlens_k,
451
+ max_seqlen_q=max_seqlen_in_batch_q,
452
+ max_seqlen_k=max_seqlen_in_batch_k,
453
+ dropout_p=dropout,
454
+ softmax_scale=softmax_scale,
455
+ causal=causal,
456
+ )
457
+
458
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
459
+ else:
460
+ attn_output = flash_attn_func(
461
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
462
+ )
463
+
464
+ return attn_output
465
+
466
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
467
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
468
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
469
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
470
+
471
+ key_layer = index_first_axis(
472
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
473
+ )
474
+ value_layer = index_first_axis(
475
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
476
+ )
477
+ if query_length == kv_seq_len:
478
+ query_layer = index_first_axis(
479
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
480
+ )
481
+ cu_seqlens_q = cu_seqlens_k
482
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
483
+ indices_q = indices_k
484
+ elif query_length == 1:
485
+ max_seqlen_in_batch_q = 1
486
+ cu_seqlens_q = torch.arange(
487
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
488
+ ) # There is a memcpy here, that is very bad.
489
+ indices_q = cu_seqlens_q[:-1]
490
+ query_layer = query_layer.squeeze(1)
491
+ else:
492
+ # The -q_len: slice assumes left padding.
493
+ attention_mask = attention_mask[:, -query_length:]
494
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
495
+
496
+ return (
497
+ query_layer,
498
+ key_layer,
499
+ value_layer,
500
+ indices_q,
501
+ (cu_seqlens_q, cu_seqlens_k),
502
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
503
+ )
504
+
505
+
506
+ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
507
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
508
+ if head_mask is not None:
509
+ # The super dispatch is done in the forward.
510
+ raise ValueError(
511
+ "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
512
+ )
513
+
514
+ scale = None
515
+ if not self.scale_attn_weights:
516
+ scale = 1
517
+
518
+ # MQA models: (batch_size, query_length, num_heads * head_dim)
519
+ # MHA models: (batch_size, num_heads, query_length, head_dim)
520
+ query_shape = query.shape
521
+ batch_size = query_shape[0]
522
+ key.shape[-2]
523
+
524
+ if self.multi_query:
525
+ query_length = query_shape[1]
526
+
527
+ # SDPA requires the dimension [..., sequence_length, head_dim].
528
+ query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
529
+
530
+ # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
531
+ key = key.unsqueeze(1)
532
+ value = value.unsqueeze(1)
533
+
534
+ # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
535
+ # and flash attention backend (No available kernel. Aborting execution.) from the shapes
536
+ # query = [batch_size, num_heads, query_length, head_dim]
537
+ # key = [batch_size, 1, past_length, head_dim]
538
+ # value = [batch_size, 1, past_length, head_dim]
539
+ #
540
+ # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
541
+ if is_torch_greater_or_equal_than_2_2:
542
+ key = key.expand(-1, self.num_heads, -1, -1)
543
+ value = value.expand(-1, self.num_heads, -1, -1)
544
+ else:
545
+ query_length = query_shape[-1]
546
+
547
+ # See the comment above.
548
+ if query.device.type == "cuda" and attention_mask is not None:
549
+ query = query.contiguous()
550
+ key = key.contiguous()
551
+ value = value.contiguous()
552
+
553
+ sdpa_result = torch.nn.functional.scaled_dot_product_attention(
554
+ query,
555
+ key,
556
+ value,
557
+ attn_mask=attention_mask,
558
+ dropout_p=self.attn_pdrop if self.training else 0.0,
559
+ # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
560
+ is_causal=self.is_causal and attention_mask is None and query_length > 1,
561
+ scale=scale,
562
+ )
563
+
564
+ if self.multi_query:
565
+ # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
566
+ sdpa_result = sdpa_result.transpose(1, 2)
567
+
568
+ # Reshape is kind of expensive here, as it does a memory copy,
569
+ # but I did not manage to do away with it (logits do not match when using view)
570
+ # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
571
+ sdpa_result = sdpa_result.reshape(query_shape)
572
+
573
+ return sdpa_result, None
574
+
575
+ def forward(
576
+ self,
577
+ hidden_states: torch.Tensor,
578
+ layer_past: Optional[torch.Tensor] = None,
579
+ attention_mask: Optional[torch.Tensor] = None,
580
+ head_mask: Optional[torch.Tensor] = None,
581
+ encoder_hidden_states: Optional[torch.Tensor] = None,
582
+ encoder_attention_mask: Optional[torch.Tensor] = None,
583
+ use_cache: Optional[bool] = False,
584
+ output_attentions: Optional[bool] = False,
585
+ ) -> Union[
586
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
587
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
588
+ ]:
589
+ if encoder_hidden_states is not None:
590
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
591
+ raise ValueError(
592
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
593
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
594
+ )
595
+
596
+ query = self.q_attn(hidden_states)
597
+ key_value = self.c_attn(encoder_hidden_states)
598
+ attention_mask = encoder_attention_mask
599
+ elif self.multi_query:
600
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
601
+ else:
602
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
603
+ # i.e., the memory layout is not the same as GPT2.
604
+ # This makes the concatenation with past_key_value more efficient.
605
+ query, key_value = (
606
+ self.c_attn(hidden_states)
607
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
608
+ .transpose(1, 2)
609
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
610
+ )
611
+
612
+ if layer_past is not None:
613
+ key_value = torch.cat((layer_past, key_value), dim=-2)
614
+ present = key_value if use_cache else None
615
+
616
+ key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
617
+
618
+ if not output_attentions and head_mask is None:
619
+ # Difference with the original implementation: there is no need to transpose the key here,
620
+ # as SDPA expects seq_length to be at index -2 for the key as well
621
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
622
+ else:
623
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
624
+ logger.warning_once(
625
+ "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
626
+ ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
627
+ )
628
+ attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
629
+
630
+ if not self.multi_query:
631
+ attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
632
+ attn_output = self.c_proj(attn_output)
633
+ attn_output = self.resid_dropout(attn_output)
634
+
635
+ outputs = (attn_output, present)
636
+ if output_attentions:
637
+ if self.multi_query:
638
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
639
+ attn_weights = attn_weights.transpose(1, 2)
640
+ outputs += (attn_weights,)
641
+
642
+ return outputs
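A small sketch of the multi-query shape handling this class depends on: a single shared key/value head is unsqueezed and expanded (a stride-0 broadcast, no copy) so that `scaled_dot_product_attention` sees matching head dimensions. Assumes PyTorch >= 2.0; sizes are illustrative.

import torch
import torch.nn.functional as F

batch, num_heads, q_len, kv_len, head_dim = 2, 8, 5, 7, 64

query = torch.randn(batch, num_heads, q_len, head_dim)
key   = torch.randn(batch, 1, kv_len, head_dim)   # single KV head (MQA), as after key.unsqueeze(1)
value = torch.randn(batch, 1, kv_len, head_dim)

# Expand the shared KV head across all query heads before calling SDPA.
key = key.expand(-1, num_heads, -1, -1)
value = value.expand(-1, num_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 5, 64])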
643
+
644
+
645
+ class GPTBigCodeMLP(nn.Module):
646
+ def __init__(self, intermediate_size, config):
647
+ super().__init__()
648
+ embed_dim = config.hidden_size
649
+ self.c_fc = nn.Linear(embed_dim, intermediate_size)
650
+ self.c_proj = nn.Linear(intermediate_size, embed_dim)
651
+ self.act = ACT2FN[config.activation_function]
652
+ self.dropout = nn.Dropout(config.resid_pdrop)
653
+
654
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
655
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
656
+ hidden_states = self.c_fc(hidden_states)
657
+ hidden_states = self.act(hidden_states)
658
+ hidden_states = self.c_proj(hidden_states)
659
+ hidden_states = self.dropout(hidden_states)
660
+ return hidden_states
661
+
662
+
663
+ GPTBIGCODE_ATTENTION_CLASSES = {
664
+ "eager": GPTBigCodeAttention,
665
+ "flash_attention_2": GPTBigCodeFlashAttention2,
666
+ "sdpa": GPTBigCodeSdpaAttention,
667
+ }
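This mapping is what `config._attn_implementation` indexes into when the blocks are built. A hedged usage sketch using the upstream transformers entry point rather than this vendored copy; the checkpoint name is only an illustrative GPTBigCode model, not something taken from this repo.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoderbase-1b",        # illustrative GPTBigCode checkpoint
    attn_implementation="sdpa",        # or "eager" / "flash_attention_2"
)
print(type(model.transformer.h[0].attn).__name__)  # e.g. GPTBigCodeSdpaAttention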
668
+
669
+
670
+ class GPTBigCodeBlock(nn.Module):
671
+ def __init__(self, config, layer_idx=None):
672
+ super().__init__()
673
+ hidden_size = config.hidden_size
674
+ self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
675
+
676
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
677
+
678
+ self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
679
+
680
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
681
+
682
+ if config.add_cross_attention:
683
+ if config.multi_query:
684
+ raise NotImplementedError("Cross-attention not implemented for MQA")
685
+
686
+ self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
687
+ config, is_cross_attention=True, layer_idx=layer_idx
688
+ )
689
+
690
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
691
+
692
+ self.mlp = GPTBigCodeMLP(self.inner_dim, config)
693
+
694
+ def forward(
695
+ self,
696
+ hidden_states: Optional[Tuple[torch.Tensor]],
697
+ layer_past: Optional[torch.Tensor] = None,
698
+ attention_mask: Optional[torch.Tensor] = None,
699
+ head_mask: Optional[torch.Tensor] = None,
700
+ encoder_hidden_states: Optional[torch.Tensor] = None,
701
+ encoder_attention_mask: Optional[torch.Tensor] = None,
702
+ use_cache: Optional[bool] = False,
703
+ output_attentions: Optional[bool] = False,
704
+ ) -> Union[
705
+ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
706
+ ]:
707
+ residual = hidden_states
708
+ hidden_states = self.ln_1(hidden_states)
709
+ attn_outputs = self.attn(
710
+ hidden_states,
711
+ layer_past=layer_past,
712
+ attention_mask=attention_mask,
713
+ head_mask=head_mask,
714
+ use_cache=use_cache,
715
+ output_attentions=output_attentions,
716
+ )
717
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
718
+ outputs = attn_outputs[1:]
719
+ # residual connection
720
+ hidden_states = attn_output + residual
721
+
722
+ if encoder_hidden_states is not None:
723
+ # add one self-attention block for cross-attention
724
+ if not hasattr(self, "crossattention"):
725
+ raise ValueError(
726
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
727
+ "cross-attention layers by setting `config.add_cross_attention=True`"
728
+ )
729
+ residual = hidden_states
730
+ hidden_states = self.ln_cross_attn(hidden_states)
731
+ cross_attn_outputs = self.crossattention(
732
+ hidden_states,
733
+ attention_mask=attention_mask,
734
+ head_mask=head_mask,
735
+ encoder_hidden_states=encoder_hidden_states,
736
+ encoder_attention_mask=encoder_attention_mask,
737
+ output_attentions=output_attentions,
738
+ )
739
+ attn_output = cross_attn_outputs[0]
740
+ # residual connection
741
+ hidden_states = residual + attn_output
742
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
743
+
744
+ residual = hidden_states
745
+ hidden_states = self.ln_2(hidden_states)
746
+ feed_forward_hidden_states = self.mlp(hidden_states)
747
+ # residual connection
748
+ hidden_states = residual + feed_forward_hidden_states
749
+
750
+ if use_cache:
751
+ outputs = (hidden_states,) + outputs
752
+ else:
753
+ outputs = (hidden_states,) + outputs[1:]
754
+
755
+ return outputs # hidden_states, present, (attentions, cross_attentions)
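A compact, self-contained sketch of the pre-LayerNorm residual pattern the block implements (normalize, apply sublayer, add back), with a plain multi-head attention and MLP standing in for the GPTBigCode modules.

import torch
import torch.nn as nn

class TinyPreLNBlock(nn.Module):
    """Minimal pre-LN transformer block: x + Attn(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        h = self.ln_1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # residual around attention
        x = x + self.mlp(self.ln_2(x))                      # residual around MLP
        return x

x = torch.randn(2, 10, 64)
print(TinyPreLNBlock()(x).shape)  # torch.Size([2, 10, 64])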
756
+
757
+
758
+ class GPTBigCodePreTrainedModel(PreTrainedModel):
759
+ """
760
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
761
+ models.
762
+ """
763
+
764
+ config_class = GPTBigCodeConfig
765
+ base_model_prefix = "transformer"
766
+ supports_gradient_checkpointing = True
767
+ _no_split_modules = ["GPTBigCodeBlock"]
768
+ _skip_keys_device_placement = "past_key_values"
769
+ _supports_flash_attn_2 = True
770
+ _supports_sdpa = True
771
+
772
+ def __init__(self, *inputs, **kwargs):
773
+ super().__init__(*inputs, **kwargs)
774
+
775
+ def _init_weights(self, module):
776
+ """Initialize the weights."""
777
+ if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
778
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
779
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
780
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
781
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
782
+ #
783
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
784
+ module.c_proj.weight.data.normal_(
785
+ mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
786
+ )
787
+ module.c_proj._is_hf_initialized = True
788
+ elif isinstance(module, nn.Linear):
789
+ # Slightly different from the TF version which uses truncated_normal for initialization
790
+ # cf https://github.com/pytorch/pytorch/pull/5617
791
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
792
+ if module.bias is not None:
793
+ module.bias.data.zero_()
794
+ elif isinstance(module, nn.Embedding):
795
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
796
+ if module.padding_idx is not None:
797
+ module.weight.data[module.padding_idx].zero_()
798
+ elif isinstance(module, nn.LayerNorm):
799
+ module.bias.data.zero_()
800
+ module.weight.data.fill_(1.0)
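The residual-projection initialization above scales the standard deviation by 1/sqrt(2 * n_layer), so deeper stacks start with smaller residual weights. A one-line sketch with made-up, illustrative config values:

import math

initializer_range, n_layer = 0.02, 40                 # illustrative values, not from a real config
std_c_proj = initializer_range / math.sqrt(2 * n_layer)
print(round(std_c_proj, 5))                           # 0.00224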
801
+
802
+
803
+ GPT_BIGCODE_START_DOCSTRING = r"""
804
+
805
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
806
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
807
+ etc.)
808
+
809
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
810
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
811
+ and behavior.
812
+
813
+ Parameters:
814
+ config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model.
815
+ Initializing with a config file does not load the weights associated with the model, only the
816
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
817
+ """
818
+
819
+ GPT_BIGCODE_INPUTS_DOCSTRING = r"""
820
+ Args:
821
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
822
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
823
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
824
+ sequence tokens in the vocabulary.
825
+
826
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
827
+ `input_ids`.
828
+
829
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
830
+ [`PreTrainedTokenizer.__call__`] for details.
831
+
832
+ [What are input IDs?](../glossary#input-ids)
833
+ past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`):
834
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
835
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
836
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
837
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
838
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
839
+
840
+ - 1 for tokens that are **not masked**,
841
+ - 0 for tokens that are **masked**.
842
+
843
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
844
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
845
+ `len(past_key_values) + len(input_ids)`
846
+
847
+ [What are attention masks?](../glossary#attention-mask)
848
+ token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
849
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
850
+ 1]`:
851
+
852
+ - 0 corresponds to a *sentence A* token,
853
+ - 1 corresponds to a *sentence B* token.
854
+
855
+ [What are token type IDs?](../glossary#token-type-ids)
856
+ position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
857
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
858
+ config.max_position_embeddings - 1]`.
859
+
860
+ [What are position IDs?](../glossary#position-ids)
861
+ head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
862
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
863
+
864
+ - 1 indicates the head is **not masked**,
865
+ - 0 indicates the head is **masked**.
866
+
867
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
868
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
869
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
870
+ model's internal embedding lookup matrix.
871
+
872
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
873
+ `past_key_values`).
874
+ use_cache (`bool`, *optional*):
875
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
876
+ `past_key_values`).
877
+ output_attentions (`bool`, *optional*):
878
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
879
+ tensors for more detail.
880
+ output_hidden_states (`bool`, *optional*):
881
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
882
+ more detail.
883
+ return_dict (`bool`, *optional*):
884
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
885
+ """
886
+
887
+
888
+ @add_start_docstrings(
889
+ "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.",
890
+ GPT_BIGCODE_START_DOCSTRING,
891
+ )
892
+ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
893
+ def __init__(self, config):
894
+ super().__init__(config)
895
+ self.multi_query = config.multi_query
896
+ self.embed_dim = config.hidden_size
897
+
898
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
899
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
900
+
901
+ self.drop = nn.Dropout(config.embd_pdrop)
902
+ self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
903
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
904
+
905
+ max_positions = config.max_position_embeddings
906
+ self.register_buffer(
907
+ "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
908
+ )
909
+
910
+ self.gradient_checkpointing = False
911
+
912
+ self._use_sdpa = config._attn_implementation == "sdpa"
913
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
914
+
915
+ # Initialize weights and apply final processing
916
+ self.post_init()
917
+
918
+ def get_input_embeddings(self):
919
+ return self.wte
920
+
921
+ def set_input_embeddings(self, new_embeddings):
922
+ self.wte = new_embeddings
923
+
924
+ @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
925
+ @add_code_sample_docstrings(
926
+ checkpoint=_CHECKPOINT_FOR_DOC,
927
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
928
+ config_class=_CONFIG_FOR_DOC,
929
+ )
930
+ def forward(
931
+ self,
932
+ input_ids: Optional[torch.Tensor] = None,
933
+ past_key_values: Optional[List[torch.Tensor]] = None,
934
+ attention_mask: Optional[torch.Tensor] = None,
935
+ token_type_ids: Optional[torch.Tensor] = None,
936
+ position_ids: Optional[torch.Tensor] = None,
937
+ head_mask: Optional[torch.Tensor] = None,
938
+ inputs_embeds: Optional[torch.Tensor] = None,
939
+ encoder_hidden_states: Optional[torch.Tensor] = None,
940
+ encoder_attention_mask: Optional[torch.Tensor] = None,
941
+ use_cache: Optional[bool] = None,
942
+ output_attentions: Optional[bool] = None,
943
+ output_hidden_states: Optional[bool] = None,
944
+ return_dict: Optional[bool] = None,
945
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
946
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
947
+ output_hidden_states = (
948
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
949
+ )
950
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
951
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
952
+
953
+ if input_ids is not None and inputs_embeds is not None:
954
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
955
+ elif input_ids is not None:
956
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
957
+ input_shape = input_ids.size()
958
+ input_ids = input_ids.view(-1, input_shape[-1])
959
+ batch_size = input_ids.shape[0]
960
+ elif inputs_embeds is not None:
961
+ input_shape = inputs_embeds.size()[:-1]
962
+ batch_size = inputs_embeds.shape[0]
963
+ else:
964
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
965
+
966
+ if batch_size <= 0:
967
+ raise ValueError("batch_size has to be defined and > 0")
968
+
969
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
970
+
971
+ if token_type_ids is not None:
972
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
973
+
974
+ if past_key_values is None:
975
+ past_length = 0
976
+ past_key_values = tuple([None] * len(self.h))
977
+ else:
978
+ past_length = past_key_values[0].size(-2)
979
+
980
+ if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
981
+ # create position_ids on the fly for batch generation
982
+ position_ids = attention_mask.long().cumsum(-1) - 1
983
+ position_ids.masked_fill_(attention_mask == 0, 1)
984
+ if past_length > 0:
985
+ position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
986
+ elif position_ids is None:
987
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
988
+ position_ids = position_ids.unsqueeze(0)
989
+
990
+ # Self-attention mask.
991
+ query_length = input_shape[-1]
992
+ key_length = past_length + query_length
993
+ self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
994
+
995
+ if self._use_flash_attention_2:
996
+ # 2d mask is passed through the layers
997
+ attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
998
+ encoder_attention_mask = (
999
+ encoder_attention_mask.bool()
1000
+ if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
1001
+ else None
1002
+ )
1003
+ else:
1004
+ # 4d mask is passed through the layers
1005
+ if attention_mask is not None:
1006
+ self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
1007
+ dtype=torch.bool, device=self_attention_mask.device
1008
+ )
1009
+
1010
+ # MQA models: (batch_size, query_length, n_heads, key_length)
1011
+ # MHA models: (batch_size, n_heads, query_length, key_length)
1012
+ self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
1013
+
1014
+ if self._use_sdpa and head_mask is None and not output_attentions:
1015
+ # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
1016
+ dtype = self.wte.weight.dtype
1017
+ min_dtype = torch.finfo(dtype).min
1018
+ self_attention_mask = torch.where(
1019
+ self_attention_mask,
1020
+ torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
1021
+ torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
1022
+ )
1023
+
1024
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1025
+ # the manual implementation that requires a 4D causal mask in all cases.
1026
+ if self.multi_query:
1027
+ # gpt_bigcode using MQA has the bad taste to use a causal mask with shape
1028
+ # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
1029
+ self_attention_mask = self_attention_mask.transpose(1, 2)
1030
+
1031
+ if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda":
1032
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
1033
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
1034
+ self_attention_mask = AttentionMaskConverter._unmask_unattended(
1035
+ self_attention_mask, min_dtype=min_dtype
1036
+ )
1037
+
1038
+ attention_mask = self_attention_mask
1039
+
1040
+ # If a 2D or 3D attention mask is provided for the cross-attention
1041
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1042
+ if (
1043
+ self.config.add_cross_attention
1044
+ and encoder_hidden_states is not None
1045
+ and encoder_attention_mask is not None
1046
+ ):
1047
+ if encoder_attention_mask.dim() == 2:
1048
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1049
+ assert encoder_attention_mask.dim() == 3
1050
+ encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
1051
+ else:
1052
+ encoder_attention_mask = None
1053
+
1054
+ # Prepare head mask if needed
1055
+ # 1.0 in head_mask indicate we keep the head
1056
+ # attention_probs has shape bsz x n_heads x N x N
1057
+ # head_mask has shape n_layer x batch x n_heads x N x N
1058
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1059
+
1060
+ if inputs_embeds is None:
1061
+ inputs_embeds = self.wte(input_ids)
1062
+ position_embeds = self.wpe(position_ids)
1063
+ hidden_states = inputs_embeds + position_embeds
1064
+
1065
+ if token_type_ids is not None:
1066
+ token_type_embeds = self.wte(token_type_ids)
1067
+ hidden_states = hidden_states + token_type_embeds
1068
+
1069
+ hidden_states = self.drop(hidden_states)
1070
+
1071
+ output_shape = input_shape + (hidden_states.size(-1),)
1072
+
1073
+ presents = [] if use_cache else None
1074
+ all_self_attentions = () if output_attentions else None
1075
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
1076
+ all_hidden_states = () if output_hidden_states else None
1077
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1078
+ if output_hidden_states:
1079
+ all_hidden_states = all_hidden_states + (hidden_states,)
1080
+
1081
+ if self.gradient_checkpointing and self.training:
1082
+ outputs = self._gradient_checkpointing_func(
1083
+ block.__call__,
1084
+ hidden_states,
1085
+ None,
1086
+ attention_mask,
1087
+ head_mask[i],
1088
+ encoder_hidden_states,
1089
+ encoder_attention_mask,
1090
+ use_cache,
1091
+ output_attentions,
1092
+ )
1093
+ else:
1094
+ outputs = block(
1095
+ hidden_states,
1096
+ layer_past=layer_past,
1097
+ attention_mask=attention_mask,
1098
+ head_mask=head_mask[i],
1099
+ encoder_hidden_states=encoder_hidden_states,
1100
+ encoder_attention_mask=encoder_attention_mask,
1101
+ use_cache=use_cache,
1102
+ output_attentions=output_attentions,
1103
+ )
1104
+
1105
+ hidden_states = outputs[0]
1106
+ if use_cache:
1107
+ presents.append(outputs[1])
1108
+
1109
+ if output_attentions:
1110
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
1111
+ if self.config.add_cross_attention:
1112
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
1113
+
1114
+ hidden_states = self.ln_f(hidden_states)
1115
+
1116
+ hidden_states = hidden_states.view(output_shape)
1117
+ # Add last hidden state
1118
+ if output_hidden_states:
1119
+ all_hidden_states = all_hidden_states + (hidden_states,)
1120
+
1121
+ if not return_dict:
1122
+ return tuple(
1123
+ v
1124
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
1125
+ if v is not None
1126
+ )
1127
+
1128
+ return BaseModelOutputWithPastAndCrossAttentions(
1129
+ last_hidden_state=hidden_states,
1130
+ past_key_values=presents,
1131
+ hidden_states=all_hidden_states,
1132
+ attentions=all_self_attentions,
1133
+ cross_attentions=all_cross_attentions,
1134
+ )
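A hedged smoke test of the bare backbone: it builds a tiny randomly initialized model from a small config (all sizes made up for illustration) using the upstream transformers classes that mirror the ones defined above, and checks the hidden-state shape.

import torch
from transformers import GPTBigCodeConfig, GPTBigCodeModel

config = GPTBigCodeConfig(vocab_size=128, n_positions=64, n_embd=32, n_layer=2, n_head=4, multi_query=True)
model = GPTBigCodeModel(config)

input_ids = torch.randint(0, 128, (2, 10))
attention_mask = torch.ones_like(input_ids)
out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # torch.Size([2, 10, 32])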
1135
+
1136
+
1137
+ @add_start_docstrings(
1138
+ """
1139
+ The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
1140
+ embeddings).
1141
+ """,
1142
+ GPT_BIGCODE_START_DOCSTRING,
1143
+ )
1144
+ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
1145
+ _tied_weights_keys = ["lm_head.weight"]
1146
+
1147
+ def __init__(self, config):
1148
+ super().__init__(config)
1149
+ self.transformer = GPTBigCodeModel(config)
1150
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1151
+
1152
+ # Initialize weights and apply final processing
1153
+ self.post_init()
1154
+
1155
+ def get_output_embeddings(self):
1156
+ return self.lm_head
1157
+
1158
+ def set_output_embeddings(self, new_embeddings):
1159
+ self.lm_head = new_embeddings
1160
+
1161
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
1162
+ token_type_ids = kwargs.get("token_type_ids", None)
1163
+ # Omit tokens covered by past_key_values
1164
+ if past_key_values:
1165
+ if self.config.multi_query:
1166
+ past_length = past_key_values[0].shape[1]
1167
+ else:
1168
+ past_length = past_key_values[0].shape[2]
1169
+
1170
+ # Some generation methods already pass only the last input ID
1171
+ if input_ids.shape[1] > past_length:
1172
+ remove_prefix_length = past_length
1173
+ else:
1174
+ # Default to old behavior: keep only final ID
1175
+ remove_prefix_length = input_ids.shape[1] - 1
1176
+
1177
+ input_ids = input_ids[:, remove_prefix_length:]
1178
+ if token_type_ids is not None:
1179
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1180
+
1181
+ attention_mask = kwargs.get("attention_mask", None)
1182
+ position_ids = kwargs.get("position_ids", None)
1183
+
1184
+ if attention_mask is not None and position_ids is None:
1185
+ # create position_ids on the fly for batch generation
1186
+ position_ids = attention_mask.long().cumsum(-1) - 1
1187
+ position_ids.masked_fill_(attention_mask == 0, 1)
1188
+ if past_key_values:
1189
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1190
+ else:
1191
+ position_ids = None
1192
+
1193
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1194
+ if inputs_embeds is not None and past_key_values is None:
1195
+ model_inputs = {"inputs_embeds": inputs_embeds}
1196
+ else:
1197
+ model_inputs = {"input_ids": input_ids}
1198
+
1199
+ model_inputs.update(
1200
+ {
1201
+ "past_key_values": past_key_values,
1202
+ "use_cache": kwargs.get("use_cache"),
1203
+ "position_ids": position_ids,
1204
+ "attention_mask": attention_mask,
1205
+ "token_type_ids": token_type_ids,
1206
+ }
1207
+ )
1208
+ return model_inputs
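A sketch of what `prepare_inputs_for_generation` does on the second decoding step: once a cache exists, only the tokens not yet covered by `past_key_values` are kept.

import torch

input_ids = torch.tensor([[5, 6, 7, 8]])   # 3-token prompt plus 1 newly sampled token
past_length = 3                             # tokens already stored in past_key_values

if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length      # keep only the uncached suffix
else:
    remove_prefix_length = input_ids.shape[1] - 1

print(input_ids[:, remove_prefix_length:])  # tensor([[8]]) -> only the new token is fed to the model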
1209
+
1210
+ @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
1211
+ @add_code_sample_docstrings(
1212
+ checkpoint=_CHECKPOINT_FOR_DOC,
1213
+ output_type=CausalLMOutputWithCrossAttentions,
1214
+ config_class=_CONFIG_FOR_DOC,
1215
+ )
1216
+ def forward(
1217
+ self,
1218
+ input_ids: Optional[torch.Tensor] = None,
1219
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1220
+ attention_mask: Optional[torch.Tensor] = None,
1221
+ token_type_ids: Optional[torch.Tensor] = None,
1222
+ position_ids: Optional[torch.Tensor] = None,
1223
+ head_mask: Optional[torch.Tensor] = None,
1224
+ inputs_embeds: Optional[torch.Tensor] = None,
1225
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1226
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1227
+ labels: Optional[torch.Tensor] = None,
1228
+ use_cache: Optional[bool] = None,
1229
+ output_attentions: Optional[bool] = None,
1230
+ output_hidden_states: Optional[bool] = None,
1231
+ return_dict: Optional[bool] = None,
1232
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1233
+ r"""
1234
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1235
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1236
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1237
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1238
+ """
1239
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1240
+
1241
+ transformer_outputs = self.transformer(
1242
+ input_ids,
1243
+ past_key_values=past_key_values,
1244
+ attention_mask=attention_mask,
1245
+ token_type_ids=token_type_ids,
1246
+ position_ids=position_ids,
1247
+ head_mask=head_mask,
1248
+ inputs_embeds=inputs_embeds,
1249
+ encoder_hidden_states=encoder_hidden_states,
1250
+ encoder_attention_mask=encoder_attention_mask,
1251
+ use_cache=use_cache,
1252
+ output_attentions=output_attentions,
1253
+ output_hidden_states=output_hidden_states,
1254
+ return_dict=return_dict,
1255
+ )
1256
+ hidden_states = transformer_outputs[0]
1257
+
1258
+ lm_logits = self.lm_head(hidden_states)
1259
+
1260
+ loss = None
1261
+ if labels is not None:
1262
+ # Shift so that tokens < n predict n
1263
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1264
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
1265
+ # Flatten the tokens
1266
+ loss_fct = CrossEntropyLoss()
1267
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1268
+
1269
+ if not return_dict:
1270
+ output = (lm_logits,) + transformer_outputs[1:]
1271
+ return ((loss,) + output) if loss is not None else output
1272
+
1273
+ return CausalLMOutputWithCrossAttentions(
1274
+ loss=loss,
1275
+ logits=lm_logits,
1276
+ past_key_values=transformer_outputs.past_key_values,
1277
+ hidden_states=transformer_outputs.hidden_states,
1278
+ attentions=transformer_outputs.attentions,
1279
+ cross_attentions=transformer_outputs.cross_attentions,
1280
+ )
1281
+
1282
+ @staticmethod
1283
+ def _reorder_cache(
1284
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1285
+ ) -> Tuple[Tuple[torch.Tensor]]:
1286
+ """
1287
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1288
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1289
+ beam_idx at every generation step.
1290
+ """
1291
+ return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
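`_reorder_cache` simply permutes the batch (beam) dimension of each layer's cache with `index_select`; a toy example:

import torch

layer_past = torch.arange(3 * 4 * 2).reshape(3, 4, 2)   # toy (num_beams, seq_len, head_dim) cache
beam_idx = torch.tensor([2, 2, 0])                       # beams 2 and 0 survive, beam 2 is duplicated

reordered = layer_past.index_select(0, beam_idx)
print(reordered[:, 0, 0])  # tensor([16, 16, 0]) -> rows now follow the selected beams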
1292
+
1293
+
1294
+ @add_start_docstrings(
1295
+ """
1296
+ The GPTBigCode Model transformer with a sequence classification head on top (linear layer).
1297
+
1298
+ [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal
1299
+ models (e.g. GPT-1) do.
1300
+
1301
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1302
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1303
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1304
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1305
+ each row of the batch).
1306
+ """,
1307
+ GPT_BIGCODE_START_DOCSTRING,
1308
+ )
1309
+ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
1310
+ def __init__(self, config):
1311
+ super().__init__(config)
1312
+ self.num_labels = config.num_labels
1313
+ self.transformer = GPTBigCodeModel(config)
1314
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1315
+
1316
+ # Initialize weights and apply final processing
1317
+ self.post_init()
1318
+
1319
+ @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
1320
+ def forward(
1321
+ self,
1322
+ input_ids: Optional[torch.Tensor] = None,
1323
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1324
+ attention_mask: Optional[torch.Tensor] = None,
1325
+ token_type_ids: Optional[torch.Tensor] = None,
1326
+ position_ids: Optional[torch.Tensor] = None,
1327
+ head_mask: Optional[torch.Tensor] = None,
1328
+ inputs_embeds: Optional[torch.Tensor] = None,
1329
+ labels: Optional[torch.Tensor] = None,
1330
+ use_cache: Optional[bool] = None,
1331
+ output_attentions: Optional[bool] = None,
1332
+ output_hidden_states: Optional[bool] = None,
1333
+ return_dict: Optional[bool] = None,
1334
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1335
+ r"""
1336
+ labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
1337
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1338
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1339
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1340
+ """
1341
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1342
+
1343
+ transformer_outputs = self.transformer(
1344
+ input_ids,
1345
+ past_key_values=past_key_values,
1346
+ attention_mask=attention_mask,
1347
+ token_type_ids=token_type_ids,
1348
+ position_ids=position_ids,
1349
+ head_mask=head_mask,
1350
+ inputs_embeds=inputs_embeds,
1351
+ use_cache=use_cache,
1352
+ output_attentions=output_attentions,
1353
+ output_hidden_states=output_hidden_states,
1354
+ return_dict=return_dict,
1355
+ )
1356
+ hidden_states = transformer_outputs[0]
1357
+ logits = self.score(hidden_states)
1358
+
1359
+ if input_ids is not None:
1360
+ batch_size, sequence_length = input_ids.shape[:2]
1361
+ else:
1362
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1363
+
1364
+ assert (
1365
+ self.config.pad_token_id is not None or batch_size == 1
1366
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1367
+ if self.config.pad_token_id is None:
1368
+ sequence_lengths = -1
1369
+ else:
1370
+ if input_ids is not None:
1371
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1372
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1373
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1374
+ sequence_lengths = sequence_lengths.to(logits.device)
1375
+ else:
1376
+ sequence_lengths = -1
1377
+ logger.warning(
1378
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1379
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1380
+ )
1381
+
1382
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1383
+
1384
+ loss = None
1385
+ if labels is not None:
1386
+ labels = labels.to(logits.device)
1387
+
1388
+ if self.config.problem_type is None:
1389
+ if self.num_labels == 1:
1390
+ self.config.problem_type = "regression"
1391
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1392
+ self.config.problem_type = "single_label_classification"
1393
+ else:
1394
+ self.config.problem_type = "multi_label_classification"
1395
+
1396
+ if self.config.problem_type == "regression":
1397
+ loss_fct = MSELoss()
1398
+ if self.num_labels == 1:
1399
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1400
+ else:
1401
+ loss = loss_fct(pooled_logits, labels)
1402
+ elif self.config.problem_type == "single_label_classification":
1403
+ loss_fct = CrossEntropyLoss()
1404
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1405
+ elif self.config.problem_type == "multi_label_classification":
1406
+ loss_fct = BCEWithLogitsLoss()
1407
+ loss = loss_fct(pooled_logits, labels)
1408
+ if not return_dict:
1409
+ output = (pooled_logits,) + transformer_outputs[1:]
1410
+ return ((loss,) + output) if loss is not None else output
1411
+
1412
+ return SequenceClassifierOutputWithPast(
1413
+ loss=loss,
1414
+ logits=pooled_logits,
1415
+ past_key_values=transformer_outputs.past_key_values,
1416
+ hidden_states=transformer_outputs.hidden_states,
1417
+ attentions=transformer_outputs.attentions,
1418
+ )
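A sketch of the pooling trick used above: the classification logit is taken at the last non-padding position of each row, found via the first pad token, with the modulo keeping the index valid when a row has no padding at all.

import torch

pad_token_id = 0
input_ids = torch.tensor([[7, 8, 9, 0],     # padded row: last real token at index 2
                          [4, 5, 6, 3]])    # full row: no padding

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)                      # tensor([2, 3])

logits = torch.randn(2, 4, 5)                # (batch, seq_len, num_labels)
pooled = logits[torch.arange(2), sequence_lengths]
print(pooled.shape)                          # torch.Size([2, 5])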
1419
+
1420
+
1421
+ @add_start_docstrings(
1422
+ """
1423
+ GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
1424
+ for Named-Entity-Recognition (NER) tasks.
1425
+ """,
1426
+ GPT_BIGCODE_START_DOCSTRING,
1427
+ )
1428
+ class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
1429
+ def __init__(self, config):
1430
+ super().__init__(config)
1431
+ self.num_labels = config.num_labels
1432
+
1433
+ self.transformer = GPTBigCodeModel(config)
1434
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1435
+ classifier_dropout = config.classifier_dropout
1436
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1437
+ classifier_dropout = config.hidden_dropout
1438
+ else:
1439
+ classifier_dropout = 0.1
1440
+ self.dropout = nn.Dropout(classifier_dropout)
1441
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1442
+
1443
+ # Initialize weights and apply final processing
1444
+ self.post_init()
1445
+
1446
+ @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
1447
+ def forward(
1448
+ self,
1449
+ input_ids: Optional[torch.Tensor] = None,
1450
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1451
+ attention_mask: Optional[torch.Tensor] = None,
1452
+ token_type_ids: Optional[torch.Tensor] = None,
1453
+ position_ids: Optional[torch.Tensor] = None,
1454
+ head_mask: Optional[torch.Tensor] = None,
1455
+ inputs_embeds: Optional[torch.Tensor] = None,
1456
+ labels: Optional[torch.Tensor] = None,
1457
+ use_cache: Optional[bool] = None,
1458
+ output_attentions: Optional[bool] = None,
1459
+ output_hidden_states: Optional[bool] = None,
1460
+ return_dict: Optional[bool] = None,
1461
+ ) -> Union[Tuple, TokenClassifierOutput]:
1462
+ r"""
1463
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1464
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1465
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1466
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1467
+ """
1468
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1469
+
1470
+ transformer_outputs = self.transformer(
1471
+ input_ids,
1472
+ past_key_values=past_key_values,
1473
+ attention_mask=attention_mask,
1474
+ token_type_ids=token_type_ids,
1475
+ position_ids=position_ids,
1476
+ head_mask=head_mask,
1477
+ inputs_embeds=inputs_embeds,
1478
+ use_cache=use_cache,
1479
+ output_attentions=output_attentions,
1480
+ output_hidden_states=output_hidden_states,
1481
+ return_dict=return_dict,
1482
+ )
1483
+
1484
+ hidden_states = transformer_outputs[0]
1485
+ hidden_states = self.dropout(hidden_states)
1486
+ logits = self.classifier(hidden_states)
1487
+
1488
+ loss = None
1489
+ if labels is not None:
1490
+ loss_fct = CrossEntropyLoss()
1491
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))
1492
+
1493
+ if not return_dict:
1494
+ output = (logits,) + transformer_outputs[2:]
1495
+ return ((loss,) + output) if loss is not None else output
1496
+
1497
+ return TokenClassifierOutput(
1498
+ loss=loss,
1499
+ logits=logits,
1500
+ hidden_states=transformer_outputs.hidden_states,
1501
+ attentions=transformer_outputs.attentions,
1502
+ )
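End of the vendored modeling file. A hedged usage sketch for the causal-LM head, using the Hugging Face library classes rather than this local copy; the checkpoint name is only an illustrative GPTBigCode model.

import torch
from transformers import AutoTokenizer, GPTBigCodeForCausalLM

checkpoint = "bigcode/starcoderbase-1b"      # illustrative GPTBigCode checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = GPTBigCodeForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))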
starvector/model/image_encoder/clip_model.py ADDED
@@ -0,0 +1,191 @@
1
+ # Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py
2
+
3
+ from collections import OrderedDict
4
+ from itertools import repeat
5
+ import collections.abc
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
11
+
12
+ def convert_weights_to_precision(model: nn.Module, precision: torch.dtype):
13
+ """Convert applicable model parameters to the specified precision"""
14
+
15
+ def _convert_weights_to_precision(l):
16
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
17
+ l.weight.data = l.weight.data.to(precision)
18
+ if l.bias is not None:
19
+ l.bias.data = l.bias.data.to(precision)
20
+
21
+ elif isinstance(l, (nn.MultiheadAttention)):
22
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
23
+ tensor = getattr(l, attr)
24
+ if tensor is not None:
25
+ tensor.data = tensor.data.to(precision)
26
+ else:
27
+ for _, p in l.named_parameters():
28
+ p.data = p.data.to(precision)
29
+
30
+ model.apply(_convert_weights_to_precision)
31
+
32
+ class Bottleneck(nn.Module):
33
+ expansion = 4
34
+
35
+ def __init__(self, inplanes, planes, stride=1):
36
+ super().__init__()
37
+
38
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
39
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
40
+ self.bn1 = nn.BatchNorm2d(planes)
41
+ self.relu1 = nn.ReLU(inplace=True)
42
+
43
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
44
+ self.bn2 = nn.BatchNorm2d(planes)
45
+ self.relu2 = nn.ReLU(inplace=True)
46
+
47
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
48
+
49
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
50
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
51
+ self.relu3 = nn.ReLU(inplace=True)
52
+
53
+ self.downsample = None
54
+ self.stride = stride
55
+
56
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
57
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
58
+ self.downsample = nn.Sequential(OrderedDict([
59
+ ("-1", nn.AvgPool2d(stride)),
60
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
61
+ ("1", nn.BatchNorm2d(planes * self.expansion))
62
+ ]))
63
+
64
+ def forward(self, x: torch.Tensor):
65
+ identity = x
66
+
67
+ out = self.relu1(self.bn1(self.conv1(x)))
68
+ out = self.relu2(self.bn2(self.conv2(out)))
69
+ out = self.avgpool(out)
70
+ out = self.bn3(self.conv3(out))
71
+
72
+ if self.downsample is not None:
73
+ identity = self.downsample(x)
74
+
75
+ out += identity
76
+ out = self.relu3(out)
77
+ return out
78
+
79
+
80
+ class AttentionPool2d(nn.Module):
81
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
82
+ super().__init__()
83
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
84
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
85
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
86
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
87
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
88
+ self.num_heads = num_heads
89
+
90
+ def forward(self, x):
91
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
92
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
93
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
94
+ x, _ = F.multi_head_attention_forward(
95
+ query=x, key=x, value=x,
96
+ embed_dim_to_check=x.shape[-1],
97
+ num_heads=self.num_heads,
98
+ q_proj_weight=self.q_proj.weight,
99
+ k_proj_weight=self.k_proj.weight,
100
+ v_proj_weight=self.v_proj.weight,
101
+ in_proj_weight=None,
102
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
103
+ bias_k=None,
104
+ bias_v=None,
105
+ add_zero_attn=False,
106
+ dropout_p=0,
107
+ out_proj_weight=self.c_proj.weight,
108
+ out_proj_bias=self.c_proj.bias,
109
+ use_separate_proj_weight=True,
110
+ training=self.training,
111
+ need_weights=False
112
+ )
113
+
114
+ return x[0]
115
+
116
+
117
+ class LayerNorm(nn.LayerNorm):
118
+ """Subclass torch's LayerNorm to handle fp16."""
119
+
120
+ def forward(self, x: torch.Tensor):
121
+ orig_type = x.dtype
122
+ layernorm_dtype = self.weight.dtype
123
+ ret = super().forward(x.type(layernorm_dtype))
124
+ return ret.type(orig_type)
125
+
126
+ class QuickGELU(nn.Module):
127
+ def forward(self, x: torch.Tensor):
128
+ return x * torch.sigmoid(1.702 * x)
129
+
130
+ class ResidualAttentionBlock(nn.Module):
131
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
132
+ super().__init__()
133
+
134
+ self.attn = nn.MultiheadAttention(d_model, n_head)
135
+ self.ln_1 = LayerNorm(d_model)
136
+ self.mlp = nn.Sequential(OrderedDict([
137
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
138
+ ("gelu", QuickGELU()),
139
+ ("c_proj", nn.Linear(d_model * 4, d_model))
140
+ ]))
141
+ self.ln_2 = LayerNorm(d_model)
142
+ self.attn_mask = attn_mask
143
+
144
+ if use_grad_checkpointing:
145
+ self.attn = checkpoint_wrapper(self.attn)
146
+ self.mlp = checkpoint_wrapper(self.mlp)
147
+
148
+ def attention(self, x: torch.Tensor):
149
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
150
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
151
+
152
+ def forward(self, x: torch.Tensor):
153
+ x = x + self.attention(self.ln_1(x))
154
+ x = x + self.mlp(self.ln_2(x))
155
+ return x
156
+
157
+ class Transformer(nn.Module):
158
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
159
+ super().__init__()
160
+ self.width = width
161
+ self.layers = layers
162
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
163
+
164
+ def forward(self, x: torch.Tensor):
165
+ return self.resblocks(x)
166
+
167
+ class VisionTransformer(nn.Module):
168
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
169
+ super().__init__()
170
+ self.input_resolution = input_resolution
171
+ self.num_features = width
172
+ self.num_heads = heads
173
+ self.num_patches = (input_resolution // patch_size) ** 2
174
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
175
+ scale = width ** -0.5
176
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
177
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
178
+ self.ln_pre = LayerNorm(width)
179
+ self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
180
+
181
+ def forward(self, x: torch.Tensor):
182
+ x = self.conv1(x) # shape = [*, width, grid, grid]
183
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
184
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
185
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
186
+ x = x + self.positional_embedding.to(x.dtype)
187
+ x = self.ln_pre(x)
188
+ x = x.permute(1, 0, 2) # NLD -> LND
189
+ x = self.transformer(x)
190
+ x = x.permute(1, 0, 2) # LND -> NLD
191
+ return x
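A quick shape check for the ViT above, using the VisionTransformer just defined with the same hyperparameters that build_clip_encoder passes in image_encoder.py; the 224x224 input resolution is an assumption (the real value is config.image_size), and the module's dependencies (e.g. fairscale) must be installed.

import torch

vit = VisionTransformer(input_resolution=224, patch_size=14, width=1024, layers=23,
                        heads=16, use_grad_checkpointing=False)
images = torch.randn(1, 3, 224, 224)
tokens = vit(images)
print(tokens.shape)  # torch.Size([1, 257, 1024]) -> 16*16 patch tokens + 1 class token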
starvector/model/image_encoder/image_encoder.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+ from omegaconf import OmegaConf
6
+ from starvector.model.image_encoder.clip_model import convert_weights_to_precision
7
+ from starvector.data.util import ImageTrainProcessor
8
+
9
+ class ImageEncoder(nn.Module):
10
+ def __init__(self, config, **kwargs):
11
+ super(ImageEncoder, self).__init__()
12
+
13
+ image_size = config.image_size
14
+ torch_dtype = kwargs.get('model_precision', config.torch_dtype)
15
+ # torch_dtype = torch.float32
16
+ self.image_encoder_type = config.image_encoder_type
17
+ if self.image_encoder_type == 'clip':
18
+ self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size)
19
+ convert_weights_to_precision(self, torch_dtype)
20
+ self.processor = ImageTrainProcessor(size=config.image_size)
21
+
22
+ elif self.image_encoder_type == 'vqgan':
23
+ self.visual_encoder = self.build_vqgan_encoder()
24
+ self.ln_vision = None
25
+ self.processor = ImageTrainProcessor(size=config.image_size)
26
+
27
+ elif self.image_encoder_type == 'convnext':
28
+ self.visual_encoder = self.build_convnext_encoder()
29
+ self.ln_vision = None
30
+ self.processor = ImageTrainProcessor(size=config.image_size)
31
+
32
+ elif 'siglip' in self.image_encoder_type:
33
+ if self.image_encoder_type == 'siglip_512':
34
+ model_name = "google/siglip-base-patch16-512"
35
+ elif self.image_encoder_type == 'siglip_384':
36
+ model_name = "google/siglip-large-patch16-384"
37
+ elif self.image_encoder_type == 'siglip_256':
38
+ model_name = "google/siglip-base-patch16-256"
39
+
40
+ from transformers import AutoProcessor, AutoModel
41
+
42
+ self.visual_encoder = AutoModel.from_pretrained(
43
+ model_name, torch_dtype = torch_dtype
44
+ ).vision_model
45
+
46
+ self.processor = AutoProcessor.from_pretrained(
47
+ model_name, torch_dtype = torch_dtype
48
+ )
49
+
50
+ def build_clip_encoder(self, image_size):
51
+ from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm
52
+ visual_encoder = VisionTransformer(
53
+ input_resolution=image_size,
54
+ patch_size=14,
55
+ width=1024,
56
+ layers=23,
57
+ heads=16,
58
+ use_grad_checkpointing=False)
59
+
60
+ ln_vision = LayerNorm(visual_encoder.num_features)
61
+ return visual_encoder, ln_vision
62
+
63
+ def build_vqgan_encoder(self):
64
+ from taming.modules.diffusionmodules.model import Encoder
65
+ VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md
66
+ vqgan_chkp_path = VQGAN_CHECKPOINT
67
+ files_in_directory = os.listdir(vqgan_chkp_path + '/configs')
68
+ vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0]
69
+ vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file))
70
+ visual_encoder = Encoder(**vqgan_config.model.params.ddconfig)
71
+
72
+ # Load checkpoint weights
73
+ checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict']
74
+
75
+ # Create a new state_dict with modified keys
76
+ new_state_dict = {}
77
+ for key, value in checkpoint.items():
78
+ if key.startswith('encoder.'):
79
+ new_key = key[len('encoder.'):]
80
+ new_state_dict[new_key] = value
81
+
82
+ # Load weights
83
+ visual_encoder.load_state_dict(new_state_dict)
84
+ return visual_encoder
85
+
86
+ def build_convnext_encoder(self):
87
+ import open_clip
88
+ model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
89
+ return model.visual
90
+
91
+ def forward(self, image):
92
+ if self.image_encoder_type == 'clip':
93
+ embeds = self.visual_encoder(image)
94
+ out = self.ln_vision(embeds)
95
+ elif self.image_encoder_type == 'open-clip':
96
+ out = self.visual_encoder(image)[1]
97
+ out = self.ln_vision(out)
98
+ elif self.image_encoder_type == 'vqgan':
99
+ out = self.visual_encoder(image)
100
+ size = out.size()
101
+ out = out.view(size[0], size[1], -1)
102
+ out = out.permute(0, 2, 1)
103
+ elif self.image_encoder_type == 'convnext':
104
+ out = self.visual_encoder.trunk.forward_features(image)
105
+ size = out.size()
106
+ out = out.view(size[0], size[1], -1)
107
+ out = out.permute(0, 2, 1)
108
+ elif 'siglip' in self.image_encoder_type:
109
+ out = self.visual_encoder(image)["last_hidden_state"]
110
+ return out
111
+
112
+ def process_images(self, images):
113
+ if self.image_encoder_type == 'clip':
114
+ res = []
115
+ for image in images:
116
+ res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W
117
+ return res
118
+ else:
119
+ return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0)
120
+
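For reference, the 'siglip' branch above can be exercised on its own with the minimal sketch below (assumptions: the transformers and Pillow packages are available and the google/siglip-base-patch16-512 weights can be downloaded; the white test image is only a placeholder input):

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_name = "google/siglip-base-patch16-512"  # same checkpoint as the siglip_512 branch above
vision_model = AutoModel.from_pretrained(model_name).vision_model
processor = AutoProcessor.from_pretrained(model_name)

image = Image.new("RGB", (512, 512), "white")  # placeholder input image
pixel_values = processor(images=image, return_tensors="pt").pixel_values
with torch.no_grad():
    out = vision_model(pixel_values)["last_hidden_state"]  # (1, num_patches, hidden_dim)
print(out.shape)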
starvector/model/llm/starcoder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch.nn as nn
2
+ from transformers import (
3
+ AutoConfig,
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ )
7
+
8
+ class StarCoderModel(nn.Module):
9
+ def __init__(self, config, **kwargs):
10
+ super(StarCoderModel, self).__init__()
11
+
12
+ self.init_tokenizer(config.starcoder_model_name)
13
+
14
+ self.max_length = config.max_length
15
+ model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True)
16
+ kwargs = {}
17
+ kwargs['trust_remote_code'] = True
18
+ kwargs['torch_dtype'] = config.torch_dtype
19
+
20
+ # Configure special tokens for generation
21
+ model_config.eos_token_id = self.tokenizer.eos_token_id
22
+ model_config.pad_token_id = self.tokenizer.pad_token_id
23
+ model_config.bos_token_id = self.tokenizer.bos_token_id
24
+ try:
25
+ model_config.flash_attention = config.use_flash_attn
26
+ model_config._attn_implementation = "flash_attention_2"
27
+ except ImportError:
28
+ config.use_flash_attn = False
29
+
30
+ # model = GPTBigCodeForCausalLM(config=model_config)
31
+ model = AutoModelForCausalLM.from_pretrained(config.starcoder_model_name, config=model_config, **kwargs)
32
+ model.resize_token_embeddings(len(self.tokenizer))
33
+ self.transformer = model
34
+
35
+ # Prompt the model after image
36
+ self.prompt = '<svg'
37
+
38
+ def init_tokenizer(self, model_name):
39
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
40
+ # Include padding and eos tokens in the vocabulary
41
+ if self.tokenizer.eos_token_id is None:
42
+ self.tokenizer.add_special_tokens({"eos_token": "[EOS]"})
43
+ if self.tokenizer.pad_token_id is None:
44
+ self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
45
+
46
+ self.svg_start_token = "<svg-start>"
47
+ self.image_start_token = "<image-start>"
48
+ self.text_start_token = "<caption-start>"
49
+
50
+ self.tokenizer.add_tokens([self.svg_start_token, self.image_start_token, self.text_start_token])
51
+ self.svg_start_token_id = self.tokenizer.encode(self.svg_start_token)[0]
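The special-token bookkeeping above can be reproduced in isolation; the sketch below uses the public gpt2 tokenizer as a stand-in for the StarCoder tokenizer (an assumption made only to keep the example freely downloadable):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)  # stand-in tokenizer
if tokenizer.pad_token_id is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.add_tokens(["<svg-start>", "<image-start>", "<caption-start>"])
svg_start_token_id = tokenizer.encode("<svg-start>")[0]
print(len(tokenizer), svg_start_token_id)  # the LLM embeddings are then resized to len(tokenizer)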
starvector/model/llm/starcoder2.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch.nn as nn
2
+ from transformers import (
3
+ AutoConfig,
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ )
7
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
8
+ from functools import partial
9
+ from starvector.train.util import get_module_class_from_name
10
+ import torch
11
+
12
+ class StarCoderModel(nn.Module):
13
+ def __init__(self, config, **kwargs):
14
+ super(StarCoderModel, self).__init__()
15
+
16
+ self.init_tokenizer(config.starcoder_model_name)
17
+
18
+ self.max_length = config.max_length
19
+ model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True)
20
+ model_config.use_cache = config.use_cache
21
+ model_config.use_bfloat16 = True
22
+ model = AutoModelForCausalLM.from_pretrained(
23
+ config.starcoder_model_name,
24
+ config=model_config,
25
+ attn_implementation="flash_attention_2",
26
+ torch_dtype=torch.bfloat16,
27
+ trust_remote_code=True)
28
+ model.resize_token_embeddings(len(self.tokenizer))
29
+ self.transformer = model
30
+
31
+ # Prompt the model after image
32
+ self.prompt = '<svg'
33
+
34
+ transformer_layer_cls = kwargs.get('transformer_layer_cls', 'Starcoder2DecoderLayer')
35
+ self.transformer_layer_cls = get_module_class_from_name(self, transformer_layer_cls)
36
+
37
+ def init_tokenizer(self, model_name):
38
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
39
+ # Include padding and eos tokens in the vocabulary
40
+ if self.tokenizer.eos_token_id is None:
41
+ self.tokenizer.add_special_tokens({"eos_token": "[EOS]"})
42
+ if self.tokenizer.pad_token_id is None:
43
+ self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
44
+
45
+ self.svg_start_token = "<svg-start>"
46
+ self.svg_end_token = "<svg-end>"
47
+ self.image_start_token = "<image-start>"
48
+ self.text_start_token = "<caption-start>"
49
+
50
+ self.tokenizer.add_tokens([self.svg_start_token, self.image_start_token, self.text_start_token, self.svg_end_token])
51
+ self.svg_start_token_id = self.tokenizer.encode(self.svg_start_token)[0]
52
+ self.svg_end_token_id = self.tokenizer.encode(self.svg_end_token)[0]
53
+ self.tokenizer.padding_side = "left"
54
+
55
+ def get_fsdp_wrapping_policy(self):
56
+ """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
57
+ transformer_block_policy = partial(
58
+ transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
59
+ )
60
+
61
+ return transformer_block_policy
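A rough sketch of how the wrapping policy returned above would typically be consumed (assumptions: torch.distributed has been initialized elsewhere, and nn.TransformerEncoderLayer stands in for Starcoder2DecoderLayer):

from functools import partial

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

policy = partial(transformer_auto_wrap_policy,
                 transformer_layer_cls={nn.TransformerEncoderLayer})  # stand-in layer class
# wrapped = FSDP(model, auto_wrap_policy=policy)  # requires an initialized process group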
starvector/model/models/starvector_base.py ADDED
@@ -0,0 +1,339 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from abc import ABC, abstractmethod
4
+ from starvector.model.adapters.adapter import Adapter
5
+ from starvector.model.image_encoder.image_encoder import ImageEncoder
6
+ from starvector.util import print_trainable_parameters
7
+ from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
8
+
9
+ class StoppingCriteriaSub(StoppingCriteria):
10
+
11
+ def __init__(self, stops=[]):
12
+ super().__init__()
13
+ self.stops = stops
14
+
15
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
16
+ # Check if any of the stop sequences are in the input_ids
17
+ for stop_ids in self.stops:
18
+ if input_ids[0][-len(stop_ids):].tolist() == stop_ids:
19
+ return True
20
+ return False
21
+
22
+ class StarVectorBase(nn.Module, ABC):
23
+ def __init__(self, config, **kwargs):
24
+ super().__init__()
25
+ # Task-specific layers
26
+ self.task = kwargs.get('task', 'im2svg')
27
+ self.model_precision = kwargs.get('model_precision', config.torch_dtype)
28
+ # Build Code LLM (StarCoder)
29
+ self.svg_transformer = self._get_svg_transformer(config, **kwargs)
30
+
31
+ if self.use_image_encoder():
32
+ # Build Image Encoder
33
+ self.image_encoder = ImageEncoder(config, **kwargs)
34
+
35
+ # Build Adapter
36
+ self.image_projection = self.get_adapter(config, **kwargs).to(dtype=self.model_precision)
37
+ else:
38
+ self.query_length = 0
39
+
40
+ self.max_length = config.max_length_train - self.query_length - 4 # for added special tokens
41
+
42
+ self.train_image_encoder = kwargs.get('train_image_encoder', False)
43
+ self.train_LLM = kwargs.get('train_LLM', False)
44
+ self.train_connector = kwargs.get('train_connector', False)
45
+
46
+ # Freeze parameters
47
+ self.freze_parameters(self.train_image_encoder, self.train_LLM, self.train_connector)
48
+ print_trainable_parameters(self)
49
+
50
+ @abstractmethod
51
+ def _get_svg_transformer(self, config, **kwargs):
52
+ """Get SVG transformer model - implementation differs between versions"""
53
+ pass
54
+
55
+ def freze_parameters(self, train_image_encoder, train_LLM, train_connector):
56
+ """Freeze or unfreeze the image encoder, adapter, and LLM according to the training flags"""
57
+ if self.use_image_encoder():
58
+ for _, param in self.image_encoder.named_parameters():
59
+ param.requires_grad = train_image_encoder
60
+
61
+ # adapter trainable
62
+ for _, param in self.image_projection.named_parameters():
63
+ param.requires_grad = train_connector
64
+
65
+ for _, param in self.svg_transformer.named_parameters():
66
+ param.requires_grad = train_LLM
67
+
68
+ def use_image_encoder(self):
69
+ """Determine if image encoder should be used based on task"""
70
+ return self.task == 'im2svg'
71
+
72
+ def get_adapter(self, config, **kwargs):
73
+ """Get adapter layer for image projection"""
74
+ vision_hidden_size, self.query_length = self.get_hidden_size_and_query_length(config.image_encoder_type)
75
+ llm_hidden_size = self.svg_transformer.transformer.config.hidden_size
76
+ image_projection = Adapter(
77
+ vision_hidden_size,
78
+ llm_hidden_size,
79
+ adapter_norm=config.adapter_norm,
80
+ query_length=self.query_length,
81
+ dropout_prob=kwargs.get('dropout', 0.1)
82
+ )
83
+ return image_projection
84
+
85
+ def get_hidden_size_and_query_length(self, image_encoder_type):
86
+ """Get hidden size and query length based on encoder type"""
87
+ if image_encoder_type == 'clip':
88
+ hidden_size = self.image_encoder.visual_encoder.num_features
89
+ query_length = 257
90
+ elif image_encoder_type == 'open-clip':
91
+ hidden_size = self.image_encoder.visual_encoder.transformer.width
92
+ query_length = 256
93
+ elif image_encoder_type == 'vqgan':
94
+ hidden_size = 256
95
+ query_length = 196
96
+ elif image_encoder_type == 'convnext':
97
+ hidden_size = 1024
98
+ query_length = 49
99
+ elif 'siglip' in image_encoder_type:
100
+ hidden_size = self.image_encoder.visual_encoder.head.mlp.fc2.out_features
101
+ if '512' in image_encoder_type:
102
+ query_length = 1024
103
+ elif '384' in image_encoder_type:
104
+ query_length = 576
105
+
106
+ return hidden_size, query_length
107
+
108
+ def _tokenize(self, text, max_length, device, add_special_tokens=True):
109
+ """Common tokenization logic"""
110
+ tokens = self.svg_transformer.tokenizer(
111
+ text,
112
+ truncation=True,
113
+ add_special_tokens=add_special_tokens,
114
+ padding='longest',
115
+ max_length=max_length,
116
+ return_tensors="pt"
117
+ ).to(device)
118
+ return tokens
119
+
120
+ def _create_targets(self, tokens):
121
+ """Create targets with padding mask"""
122
+ target_mask = (tokens.input_ids == self.svg_transformer.tokenizer.pad_token_id)
123
+ return tokens.input_ids.masked_fill(target_mask, -100)
124
+
125
+ @abstractmethod
126
+ def _get_embeddings(self, input_ids):
127
+ """Get embeddings from input ids - implementation differs between v1 and v2"""
128
+ pass
129
+
130
+ def embed_text_to_svg(self, batch, device):
131
+ """Common text to SVG embedding logic"""
132
+ captions = batch["caption"]
133
+ svgs = batch["svg"]
134
+ samples = [captions[i] + self.svg_transformer.svg_start_token + svgs[i] + self.svg_transformer.tokenizer.eos_token
135
+ for i in range(len(captions))]
136
+
137
+ tokens = self._tokenize(samples, self.max_length, device)
138
+ targets = self._create_targets(tokens)
139
+ inputs_embeds = self._get_embeddings(tokens.input_ids)
140
+
141
+ return inputs_embeds, tokens.attention_mask, targets
142
+
143
+ def get_image_embeddings(self, batch, device):
144
+ """Get image embeddings"""
145
+ image = batch["image"].to(dtype=self.model_precision)
146
+ embedded_image = self.image_encoder(image)
147
+ conditioning_embeds = self.image_projection(embedded_image)
148
+ return conditioning_embeds
149
+
150
+ def embed_im_to_svg(self, batch, device):
151
+ """Common image to SVG embedding logic"""
152
+ # Process image
153
+ image = batch["image"].to(dtype=self.model_precision)
154
+ embedded_image = self.image_encoder(image)
155
+ conditioning_embeds = self.image_projection(embedded_image)
156
+ conditioning_embeds_att = torch.ones(conditioning_embeds.size()[:-1], dtype=torch.long).to(device)
157
+
158
+ # Get SVG text with appropriate end tokens (implemented by subclasses)
159
+ svg_text = self._get_svg_text(batch["svg"])
160
+
161
+ svg_tokens = self._tokenize(svg_text, self.max_length, device)
162
+ svg_tokens_embeds = self._get_embeddings(svg_tokens.input_ids)
163
+
164
+ inputs_embeds = torch.cat([conditioning_embeds, svg_tokens_embeds], dim=1)
165
+
166
+ svg_targets = self._create_targets(svg_tokens)
167
+ empty_targets = torch.ones(conditioning_embeds_att.size(), dtype=torch.long).to(device).fill_(-100)
168
+ targets = torch.cat([empty_targets, svg_targets], dim=1)
169
+
170
+ attention_mask = torch.cat([conditioning_embeds_att, svg_tokens.attention_mask], dim=1)
171
+
172
+ return inputs_embeds, attention_mask, targets
173
+
174
+ def forward(self, batch):
175
+ """Forward pass"""
176
+ device = batch["image"].device
177
+ task = self.task
178
+
179
+ # Build the input embeddings, attention mask, and targets depending on the task
180
+ if task == 'text2svg':
181
+ inputs_embeds, attention_mask, targets = self.embed_text_to_svg(batch, device)
182
+ elif task == 'im2svg':
183
+ inputs_embeds, attention_mask, targets = self.embed_im_to_svg(batch, device)
184
+
185
+ outputs = self.svg_transformer.transformer(
186
+ inputs_embeds=inputs_embeds,
187
+ attention_mask=attention_mask,
188
+ labels=targets,
189
+ return_dict=True,
190
+ output_hidden_states=True,
191
+ use_cache=False,
192
+ )
193
+ loss = outputs.loss
194
+ return loss
195
+
196
+
197
+ @abstractmethod
198
+ def _get_svg_text(self, svg_list):
199
+ """Get SVG text with appropriate end tokens - implementation differs between v1 and v2"""
200
+ pass
201
+
202
+
203
+ def _prepare_generation_inputs(self, batch, prompt, device):
204
+ """Common preparation for generation inputs"""
205
+ image = batch["image"]
206
+ image = image.to(device).to(self.model_precision)
207
+
208
+ embedded_image = self.image_encoder(image)
209
+ embedded_image = self.image_projection(embedded_image)
210
+ embedded_att = torch.ones(embedded_image.size()[:-1], dtype=torch.long).to(device)
211
+
212
+ if prompt is None:
213
+ prompt = self.svg_transformer.prompt
214
+ prompt = [prompt] * image.size(0)
215
+
216
+ prompt_tokens = self._tokenize(prompt, None, device, add_special_tokens=False)
217
+ attention_mask = torch.cat([embedded_att, prompt_tokens.attention_mask], dim=1)
218
+ inputs_embeds = self._get_embeddings(prompt_tokens.input_ids)
219
+ inputs_embeds = torch.cat([embedded_image, inputs_embeds], dim=1)
220
+
221
+ return inputs_embeds, attention_mask, prompt_tokens
222
+
223
+ def _get_generation_kwargs(self, base_kwargs):
224
+ """Common generation kwargs preparation"""
225
+ # Get token IDs for "</svg>"
226
+ end_sequence = self.svg_transformer.tokenizer("</svg>", add_special_tokens=False)['input_ids']
227
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[end_sequence])])
228
+ return {
229
+ 'inputs_embeds': base_kwargs['inputs_embeds'],
230
+ 'attention_mask': base_kwargs['attention_mask'],
231
+ 'do_sample': base_kwargs.get('use_nucleus_sampling', True),
232
+ 'top_p': base_kwargs.get('top_p', 0.9),
233
+ 'temperature': base_kwargs.get('temperature', 1),
234
+ 'num_beams': base_kwargs.get('num_beams', 2),
235
+ 'max_length': base_kwargs.get('max_length', 30),
236
+ 'min_length': base_kwargs.get('min_length', 1),
237
+ 'repetition_penalty': base_kwargs.get('repetition_penalty', 1.0),
238
+ 'length_penalty': base_kwargs.get('length_penalty', 1.0),
239
+ 'use_cache': base_kwargs.get('use_cache', True),
240
+ 'stopping_criteria': stopping_criteria
241
+ }
242
+
243
+ def generate_im2svg(self, batch, **kwargs):
244
+ """Base implementation of image to SVG generation"""
245
+ inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
246
+ batch, kwargs.get('prompt'), batch["image"].device
247
+ )
248
+
249
+ generation_kwargs = self._get_generation_kwargs(
250
+ {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
251
+ )
252
+ # Let subclasses override these defaults if needed
253
+ generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
254
+
255
+ outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
256
+ outputs = torch.cat([prompt_tokens.input_ids, outputs], dim=1)
257
+ raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
258
+
259
+ return raw_svg
260
+
261
+ def generate_im2svg_grpo(self, batch, **kwargs):
262
+ """Base implementation of image to SVG generation"""
263
+ inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
264
+ batch, kwargs.get('prompt'), batch["image"].device
265
+ )
266
+
267
+ generation_kwargs = self._get_generation_kwargs(
268
+ {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
269
+ )
270
+ # Let subclasses override these defaults if needed
271
+ generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
272
+
273
+ num_return_sequences = kwargs.get('num_return_sequences', 1)
274
+ if num_return_sequences > 1:
275
+ generation_kwargs['num_return_sequences'] = num_return_sequences
276
+ generation_kwargs['num_beams'] = 1
277
+
278
+ outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
279
+ outputs = torch.cat([prompt_tokens.input_ids.repeat(num_return_sequences, 1), outputs], dim=1)
280
+ raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
281
+
282
+ return {
283
+ "raw_svg": raw_svg,
284
+ "outputs": outputs,
285
+ "inputs_embeds": inputs_embeds,
286
+ }
287
+
288
+
289
+ def _get_im2svg_specific_kwargs(self, kwargs):
290
+ """Default implementation of im2svg specific generation kwargs.
291
+ Subclasses can override this to customize generation behavior."""
292
+ return {
293
+ 'early_stopping': True,
294
+ 'pad_token_id': self.svg_transformer.tokenizer.pad_token_id
295
+ }
296
+
297
+ def generate_text2svg(self, batch, **kwargs):
298
+ """Base implementation of text to SVG generation"""
299
+ device = batch["image"].device
300
+ prompt = batch["caption"]
301
+
302
+ prompt_tokens = self._tokenize(
303
+ prompt,
304
+ max_length=kwargs.get('max_length', 30),
305
+ device=device,
306
+ add_special_tokens=False
307
+ )
308
+
309
+ trigger_token = self._tokenize(
310
+ [self.svg_transformer.svg_start_token for _ in batch["caption"]],
311
+ max_length=None,
312
+ device=device,
313
+ add_special_tokens=False
314
+ )
315
+
316
+ input_tokens = torch.cat([prompt_tokens.input_ids, trigger_token.input_ids], dim=1)
317
+ attention_mask = torch.cat([prompt_tokens.attention_mask, trigger_token.attention_mask], dim=1)
318
+ inputs_embeds = self._get_embeddings(input_tokens)
319
+ max_length = kwargs.get('max_length', 30) - input_tokens.size(1)
320
+
321
+ generation_kwargs = self._get_generation_kwargs(
322
+ {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
324
+ )
325
+ # Let subclasses override these defaults if needed
326
+ generation_kwargs.update(self._get_text2svg_specific_kwargs(kwargs))
327
+ generation_kwargs['max_length'] = max_length
328
+
329
+ outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
330
+ return outputs
331
+
332
+ def _get_text2svg_specific_kwargs(self, kwargs):
333
+ """Default implementation of text2svg specific generation kwargs.
334
+ Subclasses can override this to customize generation behavior."""
335
+ return {
336
+ 'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
337
+ 'early_stopping': True,
338
+ 'length_penalty': kwargs.get('length_penalty', 1.0)
339
+ }
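The </svg> stopping criterion assembled in _get_generation_kwargs above can be tried end to end with a small public model; distilgpt2 below is only a stand-in for the SVG transformer (an assumption), but the stopping mechanics are the same:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList

class StopOnSequence(StoppingCriteria):
    def __init__(self, stops):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids, scores, **kwargs):
        # Stop as soon as the most recent tokens match any stop sequence
        return any(input_ids[0][-len(s):].tolist() == s for s in self.stops)

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
stop_ids = tokenizer("</svg>", add_special_tokens=False)["input_ids"]

inputs = tokenizer("<svg", return_tensors="pt")
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        stopping_criteria=StoppingCriteriaList([StopOnSequence([stop_ids])]),
    )
print(tokenizer.decode(output[0]))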
starvector/model/models/starvector_v1.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from starvector.model.models.starvector_base import StarVectorBase
4
+ from transformers import AutoProcessor
5
+
6
+ class StarVectorStarCoder(StarVectorBase):
7
+ def __init__(self, config, **kwargs):
8
+ super().__init__(config, **kwargs)
9
+
10
+ self.processor = AutoProcessor.from_pretrained(config._name_or_path)
11
+
12
+ def _get_svg_transformer(self, config, **kwargs):
13
+ from starvector.model.llm.starcoder import StarCoderModel # This uses StarCoder (V1)
14
+ return StarCoderModel(config, **kwargs)
15
+
16
+ def _get_embeddings(self, input_ids):
17
+ """V1 specific embedding method"""
18
+ return self.svg_transformer.transformer.transformer.wte(input_ids)
19
+
20
+ def _get_svg_text(self, svg_list):
21
+ """V1 specific SVG text preparation"""
22
+ return [t + self.svg_transformer.tokenizer.eos_token for t in svg_list]
starvector/model/models/starvector_v2.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
4
+ from functools import partial
5
+ from starvector.model.models.starvector_base import StarVectorBase
6
+ from transformers import AutoImageProcessor
7
+
8
+ class StarVectorStarCoder2(StarVectorBase):
9
+ def __init__(self, config, **kwargs):
10
+ super().__init__(config, **kwargs)
11
+
12
+ self.processor = AutoImageProcessor.from_pretrained(config._name_or_path, trust_remote_code=True)
13
+
14
+ def _get_svg_transformer(self, config, **kwargs):
15
+ from starvector.model.llm.starcoder2 import StarCoderModel # This is a different model than V1, uses StarCoder2
16
+ return StarCoderModel(config, **kwargs)
17
+
18
+
19
+ def get_fsdp_wrapping_policy(self):
20
+ """V2 specific FSDP wrapping policy"""
21
+ from starvector.model.image_encoder.image_encoder import ImageEncoder
22
+
23
+ image_encoder_wrapping_policy = partial(
24
+ _module_wrap_policy,
25
+ module_classes={ImageEncoder},
26
+ )
27
+
28
+ llm_fsdp_wrapping_policy = self.svg_transformer.get_fsdp_wrapping_policy()
29
+ from starvector.model.adapters.adapter import Adapter
30
+
31
+ adapter_wrapping_policy = partial(
32
+ _module_wrap_policy,
33
+ module_classes={Adapter},
34
+ )
35
+
36
+ return partial(
37
+ _or_policy,
38
+ policies=[
39
+ image_encoder_wrapping_policy,
40
+ llm_fsdp_wrapping_policy,
41
+ adapter_wrapping_policy,
42
+ ],
43
+ )
44
+
45
+ def _get_embeddings(self, input_ids):
46
+ """V2 specific embedding method"""
47
+ return self.svg_transformer.transformer.model.embed_tokens(input_ids)
48
+
49
+ def _get_svg_text(self, svg_list):
50
+ """V2 specific SVG text preparation"""
51
+ return [t + self.svg_transformer.svg_end_token + self.svg_transformer.tokenizer.eos_token for t in svg_list]
52
+
53
+ def _get_im2svg_specific_kwargs(self, kwargs):
54
+ """V2 specific generation kwargs"""
55
+ return {
56
+ # 'eos_token_id': self.svg_transformer.svg_end_token_id,
57
+ }
58
+
59
+ def _get_text2svg_specific_kwargs(self, kwargs):
60
+ """V2 specific text2svg generation kwargs"""
61
+ return {
62
+ 'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
63
+ }
starvector/model/starvector_arch.py ADDED
@@ -0,0 +1,194 @@
1
+ from transformers import (
2
+ PretrainedConfig,
3
+ PreTrainedModel
4
+ )
5
+ from torch.nn import CrossEntropyLoss
6
+ from transformers.models.gpt_bigcode.modeling_gpt_bigcode import CausalLMOutputWithCrossAttentions
7
+ from typing import Optional, Tuple, Union
8
+ import torch
9
+
10
+ from transformers.processing_utils import ProcessorMixin
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode, pad
13
+ from transformers.feature_extraction_sequence_utils import BatchFeature
14
+ from transformers import AutoProcessor
15
+
16
+ class SimpleStarVectorProcessor(ProcessorMixin):
17
+ attributes = ["tokenizer"] # Only include tokenizer in attributes
18
+ valid_kwargs = ["size", "mean", "std"] # Add other parameters as valid kwargs
19
+ image_processor_class = "AutoImageProcessor"
20
+ tokenizer_class = "AutoTokenizer"
21
+
22
+ def __init__(self,
23
+ tokenizer=None, # Make tokenizer the first argument
24
+ size=224,
25
+ mean=None,
26
+ std=None,
27
+ **kwargs,
28
+ ):
29
+ if mean is None:
30
+ mean = (0.48145466, 0.4578275, 0.40821073)
31
+ if std is None:
32
+ std = (0.26862954, 0.26130258, 0.27577711)
33
+
34
+ # Store these as instance variables
35
+ self.mean = mean
36
+ self.std = std
37
+ self.size = size
38
+ self.normalize = transforms.Normalize(mean=mean, std=std)
39
+
40
+ self.transform = transforms.Compose([
41
+ transforms.Lambda(lambda img: img.convert("RGB") if img.mode == "RGBA" else img),
42
+ transforms.Lambda(lambda img: self._pad_to_square(img)),
43
+ transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
44
+ transforms.ToTensor(),
45
+ self.normalize
46
+ ])
47
+
48
+ # Initialize parent class with tokenizer
49
+ super().__init__(tokenizer=tokenizer)
50
+
51
+
52
+ def __call__(self, images=None, text=None, max_length=None, **kwargs) -> BatchFeature:
53
+ """
54
+ Process images and/or text inputs.
55
+
56
+ Args:
57
+ images: Optional image input(s)
58
+ text: Optional text input(s)
59
+ **kwargs: Additional arguments
60
+ """
61
+ if images is None and text is None:
62
+ raise ValueError("You have to specify at least one of `images` or `text`.")
63
+
64
+ image_inputs = {}
65
+ if images is not None:
66
+ if isinstance(images, (list, tuple)):
67
+ images_ = torch.stack([self.transform(img) for img in images])
68
+ else:
69
+ images_ = self.transform(images)
70
+ image_inputs = {"pixel_values": images_}
71
+
72
+ text_inputs = {}
73
+ if text is not None:
74
+ text_inputs = self.tokenizer(
75
+ text, truncation=True,
76
+ add_special_tokens=True,
77
+ padding='longest',
78
+ max_length=max_length,
79
+ return_tensors="pt"
80
+ )
81
+
82
+ return BatchFeature(data={**text_inputs, **image_inputs})
83
+
84
+ def _pad_to_square(self, img):
85
+ # Calculate padding to make the image square
86
+ width, height = img.size
87
+ max_dim = max(width, height)
88
+ padding = [(max_dim - width) // 2, (max_dim - height) // 2]
89
+ padding += [max_dim - width - padding[0], max_dim - height - padding[1]]
90
+ return pad(img, padding, fill=255) # Assuming white padding
91
+
92
+
93
+ AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)
94
+
95
+
96
+ class StarVectorConfig(PretrainedConfig):
97
+ model_type = "starvector"
98
+
99
+ def __init__(
100
+ self,
101
+ starcoder_model_name: str = "bigcode/starcoderbase-1b",
102
+ image_encoder_type: str = "clip",
103
+ adapter_norm: str = "layer_norm",
104
+ image_size: int = 224,
105
+ max_length: int = 8192,
106
+ max_length_train: int = 8192,
107
+ use_flash_attn: bool = True,
108
+ use_cache: bool = True,
109
+ num_attention_heads: int = 16,
110
+ num_hidden_layers: int = 24,
111
+ vocab_size: int = 49152,
112
+ hidden_size: int = 2048,
113
+ num_kv_heads: int = 4,
114
+ torch_dtype: str = "bfloat16",
115
+ **kwargs,
116
+ ):
117
+ kwargs["torch_dtype"] = torch_dtype
118
+ self.starcoder_model_name = starcoder_model_name
119
+ self.image_encoder_type = image_encoder_type
120
+ self.adapter_norm = adapter_norm
121
+ self.image_size = image_size
122
+ self.max_length = max_length
123
+ self.max_length_train = max_length_train
124
+ self.use_flash_attn = use_flash_attn
125
+ self.use_cache = use_cache
126
+ self.num_attention_heads = num_attention_heads
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.vocab_size = vocab_size
129
+ self.hidden_size = hidden_size
130
+ self.num_kv_heads = num_kv_heads
131
+ super().__init__(**kwargs)
132
+
133
+ class StarVectorForCausalLM(PreTrainedModel):
134
+ config_class = StarVectorConfig
135
+ _no_split_modules = []
136
+
137
+ def __init__(self, config: StarVectorConfig, **kwargs):
138
+ super().__init__(config)
139
+ starcoder_model_name = config.starcoder_model_name
140
+ if 'starcoder2' in starcoder_model_name:
141
+ from starvector.model.models.starvector_v2 import StarVectorStarCoder2
142
+ self.model = StarVectorStarCoder2(config=config, **kwargs)
143
+ else:
144
+ from starvector.model.models.starvector_v1 import StarVectorStarCoder
145
+ self.model = StarVectorStarCoder(config=config, **kwargs)
146
+
147
+
148
+ @property
149
+ def supports_gradient_checkpointing(self):
150
+ # If the underlying transformer (e.g., the one in StarCoderModel)
151
+ # supports gradient checkpointing, delegate to it.
152
+ if hasattr(self.model, 'svg_transformer'):
153
+ return getattr(self.model.svg_transformer, 'supports_gradient_checkpointing', False)
154
+ return False
155
+
156
+ def gradient_checkpointing_enable(self):
157
+ # Optionally, forward this call to the internal transformer.
158
+ if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
159
+ self.model.svg_transformer.gradient_checkpointing_enable()
160
+
161
+ def forward(self, vision_embeds, input_ids, num_generations, attention_mask, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
162
+ completion_embeds = self.model._get_embeddings(input_ids)
163
+ inputs_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1)
164
+
165
+ transformer_outputs = self.model.svg_transformer.transformer.transformer(
166
+ inputs_embeds=inputs_embeds,
167
+ attention_mask=attention_mask,
168
+ )
169
+ hidden_states = transformer_outputs[0]
170
+
171
+ if num_logits_to_keep > 0:
172
+ lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :])
173
+ else:
174
+ lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states)
175
+
176
+ loss = None
177
+ return CausalLMOutputWithCrossAttentions(
178
+ loss=loss,
179
+ logits=lm_logits,
180
+ past_key_values=transformer_outputs.past_key_values,
181
+ hidden_states=transformer_outputs.hidden_states,
182
+ attentions=transformer_outputs.attentions,
183
+ cross_attentions=transformer_outputs.cross_attentions,
184
+ )
185
+
186
+ def generate_im2svg(self, batch, **kwargs):
187
+ return self.model.generate_im2svg(batch, **kwargs)
188
+
189
+ def generate_im2text(self, batch, **kwargs):
190
+ return self.model.generate_im2text(batch, **kwargs)
191
+
192
+ def process_images(self, images):
193
+ return self.model.image_encoder.process_images(images)
194
+
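As a usage sketch, the wrapper class above is what gets instantiated when the released checkpoint is loaded through transformers (assumptions: enough GPU memory and acceptance of trust_remote_code):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "starvector/starvector-8b-im2svg",  # checkpoint used elsewhere in this repo
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.cuda().eval()
# model.process_images([...]) and model.generate_im2svg(batch, ...) then follow the API defined above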
starvector/serve/__init__.py ADDED
File without changes
starvector/serve/constants.py ADDED
@@ -0,0 +1,16 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
14
+
15
+ CLIP_QUERY_LENGTH = 257
16
+
starvector/serve/controller.py ADDED
@@ -0,0 +1,293 @@
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from starvector.serve.util import build_logger, server_error_msg
23
+
24
+ logger = build_logger("controller", "controller.log")
25
+
26
+ class DispatchMethod(Enum):
27
+ LOTTERY = auto()
28
+ SHORTEST_QUEUE = auto()
29
+
30
+ @classmethod
31
+ def from_str(cls, name):
32
+ if name == "lottery":
33
+ return cls.LOTTERY
34
+ elif name == "shortest_queue":
35
+ return cls.SHORTEST_QUEUE
36
+ else:
37
+ raise ValueError(f"Invalid dispatch method: {name}")
38
+
39
+
40
+ @dataclasses.dataclass
41
+ class WorkerInfo:
42
+ model_names: List[str]
43
+ speed: int
44
+ queue_length: int
45
+ check_heart_beat: bool
46
+ last_heart_beat: float
47
+
48
+
49
+ def heart_beat_controller(controller):
50
+ while True:
51
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
52
+ controller.remove_stable_workers_by_expiration()
53
+
54
+
55
+ class Controller:
56
+ def __init__(self, dispatch_method: str):
57
+ # Dict[str -> WorkerInfo]
58
+ self.worker_info = {}
59
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
60
+
61
+ self.heart_beat_thread = threading.Thread(
62
+ target=heart_beat_controller, args=(self,))
63
+ self.heart_beat_thread.start()
64
+
65
+ logger.info("Init controller")
66
+
67
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
68
+ worker_status: dict):
69
+ if worker_name not in self.worker_info:
70
+ logger.info(f"Register a new worker: {worker_name}")
71
+ else:
72
+ logger.info(f"Register an existing worker: {worker_name}")
73
+
74
+ if not worker_status:
75
+ worker_status = self.get_worker_status(worker_name)
76
+ if not worker_status:
77
+ return False
78
+
79
+ self.worker_info[worker_name] = WorkerInfo(
80
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
81
+ check_heart_beat, time.time())
82
+
83
+ logger.info(f"Register done: {worker_name}, {worker_status}")
84
+ return True
85
+
86
+ def get_worker_status(self, worker_name: str):
87
+ try:
88
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
89
+ except requests.exceptions.RequestException as e:
90
+ logger.error(f"Get status fails: {worker_name}, {e}")
91
+ return None
92
+
93
+ if r.status_code != 200:
94
+ logger.error(f"Get status fails: {worker_name}, {r}")
95
+ return None
96
+
97
+ return r.json()
98
+
99
+ def remove_worker(self, worker_name: str):
100
+ del self.worker_info[worker_name]
101
+
102
+ def refresh_all_workers(self):
103
+ old_info = dict(self.worker_info)
104
+ self.worker_info = {}
105
+
106
+ for w_name, w_info in old_info.items():
107
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
108
+ logger.info(f"Remove stale worker: {w_name}")
109
+
110
+ def list_models(self):
111
+ model_names = set()
112
+
113
+ for w_name, w_info in self.worker_info.items():
114
+ model_names.update(w_info.model_names)
115
+
116
+ return list(model_names)
117
+
118
+ def get_worker_address(self, model_name: str):
119
+ if self.dispatch_method == DispatchMethod.LOTTERY:
120
+ worker_names = []
121
+ worker_speeds = []
122
+ for w_name, w_info in self.worker_info.items():
123
+ if model_name in w_info.model_names:
124
+ worker_names.append(w_name)
125
+ worker_speeds.append(w_info.speed)
126
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
127
+ norm = np.sum(worker_speeds)
128
+ if norm < 1e-4:
129
+ return ""
130
+ worker_speeds = worker_speeds / norm
131
+ if True: # Directly return address
132
+ pt = np.random.choice(np.arange(len(worker_names)),
133
+ p=worker_speeds)
134
+ worker_name = worker_names[pt]
135
+ return worker_name
136
+
137
+ # Check status before returning
138
+ while True:
139
+ pt = np.random.choice(np.arange(len(worker_names)),
140
+ p=worker_speeds)
141
+ worker_name = worker_names[pt]
142
+
143
+ if self.get_worker_status(worker_name):
144
+ break
145
+ else:
146
+ self.remove_worker(worker_name)
147
+ worker_speeds[pt] = 0
148
+ norm = np.sum(worker_speeds)
149
+ if norm < 1e-4:
150
+ return ""
151
+ worker_speeds = worker_speeds / norm
152
+ continue
153
+ return worker_name
154
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
155
+ worker_names = []
156
+ worker_qlen = []
157
+ for w_name, w_info in self.worker_info.items():
158
+ if model_name in w_info.model_names:
159
+ worker_names.append(w_name)
160
+ worker_qlen.append(w_info.queue_length / w_info.speed)
161
+ if len(worker_names) == 0:
162
+ return ""
163
+ min_index = np.argmin(worker_qlen)
164
+ w_name = worker_names[min_index]
165
+ self.worker_info[w_name].queue_length += 1
166
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
167
+ return w_name
168
+ else:
169
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
170
+
171
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
172
+ if worker_name not in self.worker_info:
173
+ logger.info(f"Receive unknown heart beat. {worker_name}")
174
+ return False
175
+
176
+ self.worker_info[worker_name].queue_length = queue_length
177
+ self.worker_info[worker_name].last_heart_beat = time.time()
178
+ logger.info(f"Receive heart beat. {worker_name}")
179
+ return True
180
+
181
+ def remove_stable_workers_by_expiration(self):
182
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
183
+ to_delete = []
184
+ for worker_name, w_info in self.worker_info.items():
185
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
186
+ to_delete.append(worker_name)
187
+
188
+ for worker_name in to_delete:
189
+ self.remove_worker(worker_name)
190
+
191
+ def worker_api_generate_stream(self, params):
192
+ worker_addr = self.get_worker_address(params["model"])
193
+ if not worker_addr:
194
+ logger.info(f"no worker: {params['model']}")
195
+ ret = {
196
+ "text": server_error_msg,
197
+ "error_code": 2,
198
+ }
199
+ yield json.dumps(ret).encode() + b"\0"
200
+
201
+ try:
202
+ response = requests.post(worker_addr + "/worker_generate_stream",
203
+ json=params, stream=True, timeout=5)
204
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
205
+ if chunk:
206
+ yield chunk + b"\0"
207
+ except requests.exceptions.RequestException as e:
208
+ logger.info(f"worker timeout: {worker_addr}")
209
+ ret = {
210
+ "text": server_error_msg,
211
+ "error_code": 3,
212
+ }
213
+ yield json.dumps(ret).encode() + b"\0"
214
+
215
+
216
+ # Let the controller act as a worker to achieve hierarchical
217
+ # management. This can be used to connect isolated sub networks.
218
+ def worker_api_get_status(self):
219
+ model_names = set()
220
+ speed = 0
221
+ queue_length = 0
222
+
223
+ for w_name in self.worker_info:
224
+ worker_status = self.get_worker_status(w_name)
225
+ if worker_status is not None:
226
+ model_names.update(worker_status["model_names"])
227
+ speed += worker_status["speed"]
228
+ queue_length += worker_status["queue_length"]
229
+
230
+ return {
231
+ "model_names": list(model_names),
232
+ "speed": speed,
233
+ "queue_length": queue_length,
234
+ }
235
+
236
+
237
+ app = FastAPI()
238
+
239
+ @app.post("/register_worker")
240
+ async def register_worker(request: Request):
241
+ data = await request.json()
242
+ controller.register_worker(
243
+ data["worker_name"], data["check_heart_beat"],
244
+ data.get("worker_status", None))
245
+
246
+ @app.post("/refresh_all_workers")
247
+ async def refresh_all_workers():
248
+ models = controller.refresh_all_workers()
249
+
250
+
251
+ @app.post("/list_models")
252
+ async def list_models():
253
+ models = controller.list_models()
254
+ return {"models": models}
255
+
256
+
257
+ @app.post("/get_worker_address")
258
+ async def get_worker_address(request: Request):
259
+ data = await request.json()
260
+ addr = controller.get_worker_address(data["model"])
261
+ return {"address": addr}
262
+
263
+ @app.post("/receive_heart_beat")
264
+ async def receive_heart_beat(request: Request):
265
+ data = await request.json()
266
+ exist = controller.receive_heart_beat(
267
+ data["worker_name"], data["queue_length"])
268
+ return {"exist": exist}
269
+
270
+
271
+ @app.post("/worker_generate_stream")
272
+ async def worker_api_generate_stream(request: Request):
273
+ params = await request.json()
274
+ generator = controller.worker_api_generate_stream(params)
275
+ return StreamingResponse(generator)
276
+
277
+
278
+ @app.post("/worker_get_status")
279
+ async def worker_api_get_status(request: Request):
280
+ return controller.worker_api_get_status()
281
+
282
+
283
+ if __name__ == "__main__":
284
+ parser = argparse.ArgumentParser()
285
+ parser.add_argument("--host", type=str, default="localhost")
286
+ parser.add_argument("--port", type=int, default=21001)
287
+ parser.add_argument("--dispatch-method", type=str, choices=[
288
+ "lottery", "shortest_queue"], default="shortest_queue")
289
+ args = parser.parse_args()
290
+ logger.info(f"args: {args}")
291
+
292
+ controller = Controller(args.dispatch_method)
293
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
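A minimal client-side sketch of the HTTP endpoints defined above (assumptions: the controller runs locally on port 21001, and the worker address is hypothetical):

import requests

controller_url = "http://localhost:21001"

# Register a (hypothetical) worker with the controller
requests.post(controller_url + "/register_worker", json={
    "worker_name": "http://localhost:40000",
    "check_heart_beat": True,
    "worker_status": {"model_names": ["starvector-1.4b"], "speed": 1, "queue_length": 0},
})

# Ask the controller which models are being served and where
models = requests.post(controller_url + "/list_models").json()["models"]
addr = requests.post(controller_url + "/get_worker_address", json={"model": models[0]}).json()["address"]
print(models, addr)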
starvector/serve/conversation.py ADDED
@@ -0,0 +1,211 @@
1
+ import dataclasses
2
+ from typing import List
3
+ from PIL import Image
4
+ import concurrent.futures
5
+ from bs4 import BeautifulSoup
6
+ import cairosvg
7
+ from io import BytesIO
8
+
9
+ @dataclasses.dataclass
10
+ class Conversation:
11
+ """A class that keeps all conversation history."""
12
+ system: str
13
+ image_prompt: str
14
+ roles: List[str]
15
+ messages: List[List[str]]
16
+ offset: int
17
+ version: str = "Unknown"
18
+ stop_sampling: bool = False
19
+ skip_next: bool = False
20
+ display_images: bool = False
21
+ task: str = "Im2SVG"
22
+
23
+ def set_task(self, task):
24
+ self.task = task
25
+
26
+ def get_image_prompt(self):
27
+ return self.image_prompt
28
+
29
+ def get_images(self, return_pil=False):
30
+ images = []
31
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
32
+ if i % 2 == 0:
33
+ if type(msg) is tuple:
34
+ import base64
35
+ from io import BytesIO
36
+ from PIL import Image
37
+ image, image_process_mode = msg
38
+ if image_process_mode == "Pad":
39
+ def expand2square(pil_img, background_color=(255, 255, 255)):
40
+ width, height = pil_img.size
41
+ if width == height:
42
+ return pil_img
43
+ elif width > height:
44
+ result = Image.new(pil_img.mode, (width, width), background_color)
45
+ result.paste(pil_img, (0, (width - height) // 2))
46
+ return result
47
+ else:
48
+ result = Image.new(pil_img.mode, (height, height), background_color)
49
+ result.paste(pil_img, ((height - width) // 2, 0))
50
+ return result
51
+ image = expand2square(image)
52
+ elif image_process_mode in ["Default", "Crop"]:
53
+ pass
54
+ elif image_process_mode == "Resize":
55
+ image = image.resize((224, 224))
56
+ else:
57
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
58
+ max_hw, min_hw = max(image.size), min(image.size)
59
+ aspect_ratio = max_hw / min_hw
60
+ max_len, min_len = 800, 400
61
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
62
+ longest_edge = int(shortest_edge * aspect_ratio)
63
+ W, H = image.size
64
+ if longest_edge != max(image.size):
65
+ if H > W:
66
+ H, W = longest_edge, shortest_edge
67
+ else:
68
+ H, W = shortest_edge, longest_edge
69
+ image = image.resize((W, H))
70
+ if return_pil:
71
+ images.append(image)
72
+ else:
73
+ buffered = BytesIO()
74
+ image.save(buffered, format="PNG")
75
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
76
+ images.append(img_b64_str)
77
+ return images
78
+
79
+ def append_message(self, role, message):
80
+ self.messages.append([role, message])
81
+
82
+ def download_files(self):
83
+ svg_string = self.messages[-1][-1][:-1]
84
+ image = self.render_svg(svg_string)
85
+ svg_out = self.clean_svg(svg_string)
86
+
87
+ return image, svg_out
88
+
89
+ def rasterize_svg(self, svg_string, resolution=224, dpi = 128, scale=2):
90
+ try:
91
+ svg_raster_bytes = cairosvg.svg2png(
92
+ bytestring=svg_string,
93
+ background_color='white',
94
+ output_width=resolution,
95
+ output_height=resolution,
96
+ dpi=dpi,
97
+ scale=scale)
98
+ svg_raster = Image.open(BytesIO(svg_raster_bytes))
99
+ except:
100
+ try:
101
+ svg = self.clean_svg(svg_string)
102
+ svg_raster_bytes = cairosvg.svg2png(
103
+ bytestring=svg,
104
+ background_color='white',
105
+ output_width=resolution,
106
+ output_height=resolution,
107
+ dpi=dpi,
108
+ scale=scale)
109
+ svg_raster = Image.open(BytesIO(svg_raster_bytes))
110
+ except:
111
+ svg_raster = Image.new('RGB', (resolution, resolution), color = 'white')
112
+ return svg_raster
113
+
114
+ def clean_svg(self, svg_text, output_width=None, output_height=None):
115
+ soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml
116
+ svg_bs4 = soup.prettify() # Prettify to get a string
117
+ svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode()
118
+ svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("<?xml")]) # Remove xml header
119
+ return svg_clean
120
+
121
+ def render_svg(self, svg_string):
122
+ with concurrent.futures.ThreadPoolExecutor() as executor:
123
+ future = executor.submit(self.rasterize_svg, svg_string, resolution = 512)
124
+ try:
125
+ result = future.result(timeout=0.1) # Specify the timeout duration in seconds
126
+ except concurrent.futures.TimeoutError:
127
+ print("Timeout occurred!")
128
+ result = None
129
+ return result
130
+
131
+ def to_gradio_svg_render(self):
132
+ svg_string = self.messages[-1][-1][:-1]
133
+ result = self.render_svg(svg_string)
134
+ return result
135
+
136
+ def to_gradio_svg_code(self):
137
+ ret = []
138
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
139
+ if i % 2 == 0:
140
+ if type(msg) is tuple:
141
+ import base64
142
+ from io import BytesIO
143
+ image, image_process_mode = msg
144
+ max_hw, min_hw = max(image.size), min(image.size)
145
+ aspect_ratio = max_hw / min_hw
146
+ max_len, min_len = 800, 400
147
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
148
+ longest_edge = int(shortest_edge * aspect_ratio)
149
+ W, H = image.size
150
+ if H > W:
151
+ H, W = longest_edge, shortest_edge
152
+ else:
153
+ H, W = shortest_edge, longest_edge
154
+ image = image.resize((W, H))
155
+ buffered = BytesIO()
156
+ image.save(buffered, format="JPEG")
157
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
158
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
159
+ msg = img_str
160
+ ret.append([msg, None])
161
+ else:
162
+ ret.append([msg, None])
163
+ else:
164
+ ret[-1][-1] = msg
165
+ return ret
166
+
167
+ def copy(self):
168
+ return Conversation(
169
+ system=self.system,
170
+ image_prompt=self.image_prompt,
171
+ roles=self.roles,
172
+ messages=[[x, y] for x, y in self.messages],
173
+ offset=self.offset,
174
+ version=self.version
175
+
176
+ )
177
+ def dict(self):
178
+ if len(self.get_images()) > 0:
179
+ return {
180
+ "system": self.system,
181
+ "image_prompt": self.image_prompt,
182
+ "roles": self.roles,
183
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
184
+ "offset": self.offset,
185
+ }
186
+ return {
187
+ "system": self.system,
188
+ "image_prompt": self.image_prompt,
189
+ "roles": self.roles,
190
+ "messages": self.messages,
191
+ "offset": self.offset,
192
+ }
193
+
194
+ starvector_v1 = Conversation(
195
+ system="StarVector",
196
+ # prompt='<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 32 32" version="1.1">',
197
+ image_prompt='<svg',
198
+ roles=("Human", "StarVector"),
199
+ version="v1",
200
+ messages=(
201
+ ),
202
+ offset=0,
203
+ task="Im2SVG",
204
+ )
205
+ default_conversation = starvector_v1
206
+ conv_templates = {
207
+ "default": default_conversation,
208
+ }
209
+
210
+ if __name__ == "__main__":
211
+ print(default_conversation.get_image_prompt())
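The clean/rasterize path used by the Conversation class above can be reproduced standalone; the SVG string below is only a placeholder example:

from io import BytesIO

import cairosvg
from bs4 import BeautifulSoup
from PIL import Image

svg_string = '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32"><circle cx="16" cy="16" r="10"/></svg>'
pretty = BeautifulSoup(svg_string, "xml").prettify()  # parse as XML, as clean_svg does
png_bytes = cairosvg.svg2png(bytestring=pretty, background_color="white",
                             output_width=224, output_height=224)
Image.open(BytesIO(png_bytes)).save("preview.png")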
starvector/serve/gradio_demo_with_updated_gradio.py ADDED
@@ -0,0 +1,432 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+ import requests
8
+ from starvector.serve.conversation import default_conversation
9
+ from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
10
+ from starvector.serve.util import (build_logger, server_error_msg)
11
+
12
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
13
+ headers = {"User-Agent": "StarVector Client"}
14
+
15
+ no_change_btn = gr.Button()
16
+ enable_btn = gr.Button(interactive=True)
17
+ disable_btn = gr.Button(interactive=False)
18
+
19
+ priority = {
20
+ "starvector-1.4b": "aaaaaaa",
21
+ }
22
+
23
+ def get_conv_log_filename():
24
+ t = datetime.datetime.now()
25
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
26
+ return name
27
+
28
+ def get_model_list():
29
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
30
+ assert ret.status_code == 200
31
+ ret = requests.post(args.controller_url + "/list_models")
32
+ models = ret.json()["models"]
33
+ models.sort(key=lambda x: priority.get(x, x))
34
+ logger.info(f"Models: {models}")
35
+ return models
36
+
37
+ get_window_url_params = """
38
+ function() {
39
+ const params = new URLSearchParams(window.location.search);
40
+ url_params = Object.fromEntries(params);
41
+ console.log(url_params);
42
+ return url_params;
43
+ }
44
+ """
45
+
46
+ def load_demo(url_params, request: gr.Request):
47
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
48
+
49
+ dropdown_update = gr.Dropdown(visible=True)
50
+ if "model" in url_params:
51
+ model = url_params["model"]
52
+ if model in models:
53
+ dropdown_update = gr.Dropdown(value=model, visible=True)
54
+
55
+ state = default_conversation.copy()
56
+ return state, dropdown_update
57
+
58
+
59
+ def load_demo_refresh_model_list(request: gr.Request):
60
+ logger.info(f"load_demo. ip: {request.client.host}")
61
+ models = get_model_list()
62
+ state = default_conversation.copy()
63
+ dropdown_update = gr.Dropdown(
64
+ choices=models,
65
+ value=models[0] if len(models) > 0 else ""
66
+ )
67
+ return state, dropdown_update
68
+
69
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
70
+ with open(get_conv_log_filename(), "a") as fout:
71
+ data = {
72
+ "tstamp": round(time.time(), 4),
73
+ "type": vote_type,
74
+ "model": model_selector,
75
+ "state": state.dict(),
76
+ "ip": request.client.host,
77
+ }
78
+ fout.write(json.dumps(data) + "\n")
79
+
80
+ def upvote_last_response(state, model_selector, request: gr.Request):
81
+ logger.info(f"upvote. ip: {request.client.host}")
82
+ vote_last_response(state, "upvote", model_selector, request)
83
+ return ("",) + (disable_btn,) * 3
84
+
85
+ def downvote_last_response(state, model_selector, request: gr.Request):
86
+ logger.info(f"downvote. ip: {request.client.host}")
87
+ vote_last_response(state, "downvote", model_selector, request)
88
+ return ("",) + (disable_btn,) * 3
89
+
90
+ def flag_last_response(state, model_selector, request: gr.Request):
91
+ logger.info(f"flag. ip: {request.client.host}")
92
+ vote_last_response(state, "flag", model_selector, request)
93
+ return ("",) + (disable_btn,) * 3
94
+
95
+ def regenerate(state, image_process_mode, request: gr.Request):
96
+ logger.info(f"regenerate. ip: {request.client.host}")
97
+ state.messages[-1][-1] = None
98
+ prev_human_msg = state.messages[-2]
99
+ if type(prev_human_msg[1]) in (tuple, list):
100
+ prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
101
+ state.skip_next = False
102
+ return (state, None, None, None) + (disable_btn,) * 6
103
+
104
+ def clear_history(request: gr.Request):
105
+ logger.info(f"clear_history. ip: {request.client.host}")
106
+ state = default_conversation.copy()
107
+ return (state, None, None) + (disable_btn,) * 6
108
+
109
+ def send_image(state, image, image_process_mode, request: gr.Request):
110
+ logger.info(f"send_image. ip: {request.client.host}.")
111
+ state.stop_sampling = False
112
+ if image is None:
113
+ state.skip_next = True
114
+ return (state, None, None, image) + (no_change_btn,) * 6
115
+
116
+ if image is not None:
117
+ text = (image, image_process_mode)
118
+ state.append_message(state.roles[0], text)
119
+ state.append_message(state.roles[1], "▌")
120
+ state.skip_next = False
121
+ msg = state.to_gradio_svg_code()[0][1]
122
+ return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6
123
+
124
+ def stop_sampling(state, image, request: gr.Request):
125
+ logger.info(f"stop_sampling. ip: {request.client.host}")
126
+ state.stop_sampling = True
127
+ return (state, None, None, image) + (disable_btn,) * 6
128
+
129
+ def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
130
+ logger.info(f"http_bot. ip: {request.client.host}")
131
+ start_tstamp = time.time()
132
+ model_name = model_selector
133
+
134
+ if state.skip_next:
135
+ # This generate call is skipped due to invalid inputs
136
+ yield (state, None, None) + (no_change_btn,) * 6
137
+ return
138
+
139
+ # Query worker address
140
+ controller_url = args.controller_url
141
+ ret = requests.post(controller_url + "/get_worker_address",
142
+ json={"model": model_name})
143
+ worker_addr = ret.json()["address"]
144
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
145
+
146
+ # No available worker
147
+ if worker_addr == "":
148
+ state.messages[-1][-1] = server_error_msg
149
+ yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
150
+ return
151
+
152
+ # Construct prompt
153
+ prompt = state.get_image_prompt()
154
+
155
+ # Make requests
156
+ pload = {
157
+ "model": model_name,
158
+ "prompt": prompt,
159
+ "num_beams": int(num_beams),
160
+ "temperature": float(temperature),
161
+ "len_penalty": float(len_penalty),
162
+ "top_p": float(top_p),
163
+ "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
164
+ }
165
+ logger.info(f"==== request ====\n{pload}")
166
+
167
+ pload['images'] = state.get_images()
168
+
169
+ state.messages[-1][-1] = "▌"
170
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn)
171
+
172
+ try:
173
+ # Stream output
174
+ if state.stop_sampling:
175
+ state.messages[1][-1] = "▌"
176
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
177
+ return
178
+
179
+ response = requests.post(worker_addr + "/worker_generate_stream",
180
+ headers=headers, json=pload, stream=True, timeout=100)
181
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
182
+ if chunk:
183
+ data = json.loads(chunk.decode())
184
+ if data["error_code"] == 0:
185
+ # output = data["text"].strip().replace('<', '&lt;').replace('>', '&gt;') # trick to avoid the SVG getting rendered
186
+ output = data["text"].strip()
187
+ state.messages[-1][-1] = output + "β–Œ"
188
+ st = state.to_gradio_svg_code()
189
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)
190
+ else:
191
+ output = data["text"] + f" (error_code: {data['error_code']})"
192
+ state.messages[-1][-1] = output
193
+
194
+ yield (state, state.to_gradio_svg_code()[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
195
+ return
196
+ time.sleep(0.03)
197
+ except requests.exceptions.RequestException as e:
198
+ state.messages[-1][-1] = server_error_msg
199
+ yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
200
+ return
201
+
202
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6
203
+
204
+ finish_tstamp = time.time()
205
+ logger.info(f"{output}")
206
+
207
+ with open(get_conv_log_filename(), "a") as fout:
208
+ data = {
209
+ "tstamp": round(finish_tstamp, 4),
210
+ "type": "chat",
211
+ "model": model_name,
212
+ "start": round(start_tstamp, 4),
213
+ "finish": round(finish_tstamp, 4),
214
+ "svg": state.messages[-1][-1],
215
+ "ip": request.client.host,
216
+ }
217
+ fout.write(json.dumps(data) + "\n")
218
+
219
+ title_markdown = ("""
220
+ # 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text
221
+ [[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)]
222
+ """)
223
+
224
+ sub_title_markdown = (""" Throw an image and vectorize it! The model expects vector-like images to generate the corresponding svg code.""")
225
+ tos_markdown = ("""
226
+ ### Terms of use
227
+ By using this service, users are required to agree to the following terms:
228
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
229
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
230
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
231
+ """)
232
+
233
+
234
+ learn_more_markdown = ("""
235
+ ### License
236
+ The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
237
+ """)
238
+
239
+ block_css = """
240
+
241
+ #buttons button {
242
+ min-width: min(120px,100%);
243
+ }
244
+
245
+ .gradio-container{
246
+ max-width: 1200px!important
247
+ }
248
+
249
+ #svg_render{
250
+ padding: 20px !important;
251
+ }
252
+
253
+ #svg_code{
254
+ height: 200px !important;
255
+ overflow: scroll !important;
256
+ white-space: unset !important;
257
+ flex-shrink: unset !important;
258
+ }
259
+
260
+
261
+ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
262
+ *{transition: width 0.5s ease, flex-grow 0.5s ease}
263
+ """
264
+
265
+ def build_demo(embed_mode, concurrency_count=10):
266
+ with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
267
+ state = gr.State()
268
+ if not embed_mode:
269
+ gr.Markdown(title_markdown)
270
+ gr.Markdown(sub_title_markdown)
271
+ with gr.Row():
272
+ with gr.Column(scale=3):
273
+ with gr.Row(elem_id="model_selector_row"):
274
+ model_selector = gr.Dropdown(
275
+ choices=models,
276
+ value=models[0] if len(models) > 0 else "",
277
+ interactive=True,
278
+ show_label=False,
279
+ container=False)
280
+ imagebox = gr.Image(type="pil")
281
+ image_process_mode = gr.Radio(
282
+ ["Resize", "Pad", "Default"],
283
+ value="Pad",
284
+ label="Preprocess for non-square image", visible=False)
285
+
286
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
287
+ gr.Examples(examples=[
288
+ [f"{cur_dir}/examples/sample-4.png"],
289
+ [f"{cur_dir}/examples/sample-7.png"],
290
+ [f"{cur_dir}/examples/sample-16.png"],
291
+ [f"{cur_dir}/examples/sample-17.png"],
292
+ [f"{cur_dir}/examples/sample-18.png"],
293
+ [f"{cur_dir}/examples/sample-0.png"],
294
+ [f"{cur_dir}/examples/sample-1.png"],
295
+ [f"{cur_dir}/examples/sample-6.png"],
296
+ ], inputs=[imagebox])
297
+
298
+ with gr.Column(scale=1, min_width=50):
299
+ submit_btn = gr.Button(value="Send", variant="primary")
300
+
301
+ with gr.Accordion("Parameters", open=True) as parameter_row:
302
+ num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
303
+ temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",)
304
+ len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
305
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",)
306
+ max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",)
307
+
308
+ with gr.Column(scale=8):
309
+ with gr.Row():
310
+ svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5)
311
+ with gr.Row():
312
+ gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render')
313
+ with gr.Row(elem_id="buttons") as button_row:
314
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
315
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
316
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
317
+ stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
318
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False, visible=False)
319
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
320
+
321
+ if not embed_mode:
322
+ gr.Markdown(tos_markdown)
323
+ gr.Markdown(learn_more_markdown)
324
+ url_params = gr.JSON(visible=False)
325
+
326
+ # Register listeners
327
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn]
328
+ upvote_btn.click(
329
+ upvote_last_response,
330
+ [state, model_selector],
331
+ [upvote_btn, downvote_btn, flag_btn],
332
+ queue=False
333
+ )
334
+ downvote_btn.click(
335
+ downvote_last_response,
336
+ [state, model_selector],
337
+ [upvote_btn, downvote_btn, flag_btn],
338
+ queue=False
339
+ )
340
+ flag_btn.click(
341
+ flag_last_response,
342
+ [state, model_selector],
343
+ [upvote_btn, downvote_btn, flag_btn],
344
+ queue=False
345
+ )
346
+
347
+ regenerate_btn.click(
348
+ regenerate,
349
+ [state, image_process_mode],
350
+ [state, svg_code, svg_render, imagebox] + btn_list,
351
+ queue=False
352
+ ).then(
353
+ http_bot,
354
+ [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
355
+ [state, svg_code, svg_render] + btn_list,
356
+ concurrency_limit=concurrency_count
357
+ )
358
+
359
+ submit_btn.click(
360
+ send_image,
361
+ [state, imagebox, image_process_mode],
362
+ [state, svg_code, svg_render, imagebox] + btn_list,
363
+ queue=False
364
+ ).then(
365
+ http_bot,
366
+ [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
367
+ [state, svg_code, svg_render] + btn_list,
368
+ concurrency_limit=concurrency_count
369
+ )
370
+
371
+ clear_btn.click(
372
+ clear_history,
373
+ None,
374
+ [state, svg_code, svg_render] + btn_list,
375
+ queue=False
376
+ )
377
+
378
+ stop_btn.click(
379
+ stop_sampling,
380
+ [state, imagebox],
381
+ [state, imagebox] + btn_list,
382
+ queue=False
383
+ ).then(
384
+ clear_history,
385
+ None,
386
+ [state, svg_code, svg_render] + btn_list,
387
+ queue=False
388
+ )
389
+
390
+ if args.model_list_mode == "once":
391
+ demo.load(
392
+ load_demo,
393
+ [url_params],
394
+ [state, model_selector],
395
+ _js=get_window_url_params,
396
+ )
397
+ elif args.model_list_mode == "reload":
398
+ demo.load(
399
+ load_demo_refresh_model_list,
400
+ None,
401
+ [state, model_selector],
402
+ queue=False
403
+ )
404
+ else:
405
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
406
+
407
+ return demo
408
+
409
+ if __name__ == "__main__":
410
+ parser = argparse.ArgumentParser()
411
+ parser.add_argument("--host", type=str, default="0.0.0.0")
412
+ parser.add_argument("--port", type=int)
413
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
414
+ parser.add_argument("--concurrency-count", type=int, default=15)
415
+ parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
416
+ parser.add_argument("--share", action="store_true")
417
+ parser.add_argument("--moderate", action="store_true")
418
+ parser.add_argument("--embed", action="store_true")
419
+ args = parser.parse_args()
420
+ logger.info(f"args: {args}")
421
+
422
+ models = get_model_list()
423
+
424
+ logger.info(args)
425
+ demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
426
+ demo.queue(
427
+ api_open=False
428
+ ).launch(
429
+ server_name=args.host,
430
+ server_port=args.port,
431
+ share=args.share
432
+ )
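Both Gradio front-ends in this commit talk to a model worker through the controller (first /get_worker_address, then /worker_generate_stream). For reference, the sketch below is a minimal standalone client for that streaming protocol; it is illustrative only, and the controller URL, model name, and base64 image encoding are assumptions, since the demos themselves build the payload from the conversation state via state.get_prompt() and state.get_images().

import base64
import json

import requests

CONTROLLER_URL = "http://localhost:21001"   # assumed controller address (argparse default above)
MODEL_NAME = "starvector-1b-im2svg"         # assumed name of a registered worker model

def stream_svg(image_path):
    # Ask the controller which worker serves the requested model.
    ret = requests.post(CONTROLLER_URL + "/get_worker_address", json={"model": MODEL_NAME})
    worker_addr = ret.json()["address"]

    # Assumption: the worker accepts base64-encoded image strings under "images".
    with open(image_path, "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()

    pload = {
        "model": MODEL_NAME,
        "prompt": "",            # the demos build this from the conversation state
        "num_beams": 1,
        "temperature": 0.8,
        "len_penalty": 0.6,
        "top_p": 0.9,
        "max_new_tokens": 2000,
        "images": [image_b64],
    }

    response = requests.post(worker_addr + "/worker_generate_stream",
                             headers={"User-Agent": "StarVector Client"},
                             json=pload, stream=True, timeout=100)

    svg = ""
    # The worker streams b"\0"-delimited JSON chunks with the text generated so far.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(f"worker error {data['error_code']}: {data['text']}")
        svg = data["text"].strip()
    return svg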
starvector/serve/gradio_web_server.py ADDED
@@ -0,0 +1,562 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+ import requests
8
+ from starvector.serve.conversation import default_conversation
9
+ from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
10
+ from starvector.serve.util import (build_logger, server_error_msg)
11
+
12
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
13
+ headers = {"User-Agent": "StarVector Client"}
14
+
15
+ no_change_btn = gr.Button.update()
16
+ enable_btn = gr.Button.update(interactive=True)
17
+ disable_btn = gr.Button.update(interactive=False)
18
+
19
+ priority = {
20
+ "starvector-1b-im2svg": "aaaaaaa",
21
+ }
22
+
23
+ def get_conv_log_filename():
24
+ t = datetime.datetime.now()
25
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
26
+ return name
27
+
28
+ def get_model_list():
29
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
30
+ assert ret.status_code == 200
31
+ ret = requests.post(args.controller_url + "/list_models")
32
+ models = ret.json()["models"]
33
+ models.sort(key=lambda x: priority.get(x, x))
34
+ logger.info(f"Models: {models}")
35
+ return models
36
+
37
+ def load_demo(url_params, request: gr.Request):
38
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
39
+
40
+ dropdown_update = gr.Dropdown.update(visible=True)
41
+ if "model" in url_params:
42
+ model = url_params["model"]
43
+ if model in models:
44
+ dropdown_update = gr.Dropdown.update(
45
+ value=model, visible=True)
46
+
47
+ state = default_conversation.copy()
48
+ return state, dropdown_update
49
+
50
+ mapping_model_task = {
51
+ 'Image2SVG': 'im2svg',
52
+ 'Text2SVG': 'text2svg'
53
+ }
54
+
55
+ def get_models_dropdown_from_task(task):
56
+ models = get_model_list()
57
+ models = [model for model in models if mapping_model_task[task] in model]
58
+ dropdown_update = gr.Dropdown.update(
59
+ choices=models,
60
+ value=models[0] if len(models) > 0 else ""
61
+ )
62
+ return dropdown_update
63
+
64
+
65
+ def load_demo_refresh_model_list(task, request: gr.Request):
66
+ logger.info(f"load_demo. ip: {request.client.host}")
67
+ dropdown_update = get_models_dropdown_from_task(task)
68
+ state = default_conversation.copy()
69
+ return state, dropdown_update
70
+
71
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
72
+ with open(get_conv_log_filename(), "a") as fout:
73
+ data = {
74
+ "tstamp": round(time.time(), 4),
75
+ "type": vote_type,
76
+ "model": model_selector,
77
+ "state": state.dict(),
78
+ "ip": request.client.host,
79
+ }
80
+ fout.write(json.dumps(data) + "\n")
81
+
82
+ def upvote_last_response(state, model_selector, request: gr.Request):
83
+ logger.info(f"upvote. ip: {request.client.host}")
84
+ vote_last_response(state, "upvote", model_selector, request)
85
+ return ("",) + (disable_btn,) * 7
86
+
87
+ def downvote_last_response(state, model_selector, request: gr.Request):
88
+ logger.info(f"downvote. ip: {request.client.host}")
89
+ vote_last_response(state, "downvote", model_selector, request)
90
+ return ("",) + (disable_btn,) * 7
91
+
92
+ def flag_last_response(state, model_selector, request: gr.Request):
93
+ logger.info(f"flag. ip: {request.client.host}")
94
+ vote_last_response(state, "flag", model_selector, request)
95
+ return ("",) + (disable_btn,) * 7
96
+
97
+ def regenerate(state, image_process_mode, request: gr.Request):
98
+ logger.info(f"regenerate. ip: {request.client.host}")
99
+ state.messages[-1][-1] = None
100
+ prev_human_msg = state.messages[-2]
101
+ if type(prev_human_msg[1]) in (tuple, list):
102
+ prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
103
+ state.skip_next = False
104
+ return (state, None, None, None) + (disable_btn,) * 7
105
+
106
+ def clear_history(request: gr.Request):
107
+ logger.info(f"clear_history. ip: {request.client.host}")
108
+ state = default_conversation.copy()
109
+ return (state, None, None) + (disable_btn,) * 7
110
+
111
+ def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request):
112
+ logger.info(f"send_data. ip: {request.client.host}.")
113
+ if task == 'Image2SVG':
114
+ if image is None:
115
+ state.skip_next = True
116
+ return (state, None, None, image) + (no_change_btn,) * 7
117
+
118
+ if image is not None:
119
+ image_message = (image, image_process_mode)
120
+ state.append_message(state.roles[0], image_message)
121
+ state.append_message(state.roles[1], "β–Œ")
122
+ state.skip_next = False
123
+ msg = state.to_gradio_svg_code()[0][1]
124
+ return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
125
+ else:
126
+ if text_caption is None:
127
+ state.skip_next = True
128
+ return (state, None, None, image) + (no_change_btn,) * 7
129
+
130
+ state.append_message(state.roles[0], text_caption)
131
+ state.append_message(state.roles[1], "β–Œ")
132
+ state.skip_next = False
133
+ msg = state.to_gradio_svg_code()[0][1]
134
+ return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
135
+
136
+ def download_files(state, request: gr.Request):
137
+ logger.info(f"download_files. ip: {request.client.host}")
138
+ svg_str, image = state.download_files()
139
+
140
+ # TODO: Figure out how to trigger the SVG download in the user's browser; not implemented yet.
141
+
142
+ def update_task(task):
143
+ dropdown_update = get_models_dropdown_from_task(task)
144
+
145
+ if task == "Text2SVG":
146
+ return 1.0, 0.9, 0.95, dropdown_update
147
+ else:
148
+ return 0.6, 0.9, 0.95, dropdown_update
149
+
150
+
151
+ def stop_sampling(state, image, request: gr.Request):
152
+ logger.info(f"stop_sampling. ip: {request.client.host}")
153
+ state.stop_sampling = True
154
+ return (state, None, None, image) + (disable_btn,) * 7
155
+
156
+ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
157
+ logger.info(f"http_bot. ip: {request.client.host}")
158
+ start_tstamp = time.time()
159
+ model_name = model_selector
160
+
161
+ if state.skip_next:
162
+ # This generate call is skipped due to invalid inputs
163
+ yield (state, None, None) + (no_change_btn,) * 7
164
+ return
165
+
166
+ # Query worker address
167
+ controller_url = args.controller_url
168
+ ret = requests.post(controller_url + "/get_worker_address",
169
+ json={"model": model_name})
170
+ worker_addr = ret.json()["address"]
171
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
172
+
173
+ # No available worker
174
+ if worker_addr == "":
175
+ state.messages[-1][-1] = server_error_msg
176
+ yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
177
+ return
178
+
179
+ # Construct prompt
180
+ if task_selector == "Image2SVG":
181
+ prompt = state.get_image_prompt()
182
+ else:
183
+ prompt = text_caption
184
+
185
+ # Make requests
186
+ pload = {
187
+ "model": model_name,
188
+ "prompt": prompt,
189
+ "num_beams": int(num_beams),
190
+ "temperature": float(temperature),
191
+ "len_penalty": float(len_penalty),
192
+ "top_p": float(top_p),
193
+ "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
194
+ }
195
+ logger.info(f"==== request ====\n{pload}")
196
+
197
+ pload['images'] = state.get_images()
198
+
199
+ state.messages[-1][-1] = "β–Œ"
200
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
201
+
202
+ try:
203
+ # Stream output
204
+ if state.stop_sampling:
205
+ state.messages[1][-1] = "β–Œ"
206
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn)
207
+ return
208
+
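+ # The worker streams partial generations as b"\0"-delimited JSON chunks; each chunk
+ # carries the decoded text generated so far plus an error_code (0 means success).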
209
+ response = requests.post(worker_addr + "/worker_generate_stream",
210
+ headers=headers, json=pload, stream=True, timeout=10)
211
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
212
+ if chunk:
213
+ data = json.loads(chunk.decode())
214
+ if data["error_code"] == 0:
215
+ # output = data["text"].strip().replace('<', '&lt;').replace('>', '&gt;') # trick to avoid the SVG getting rendered
216
+ output = data["text"].strip()
217
+ state.messages[-1][-1] = output + "β–Œ"
218
+ st = state.to_gradio_svg_code()
219
+ yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
220
+ else:
221
+ output = data["text"] + f" (error_code: {data['error_code']})"
222
+ state.messages[-1][-1] = output
223
+
224
+ yield (state, state.to_gradio_svg_code()[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
225
+ return
226
+ time.sleep(0.03)
227
+ except requests.exceptions.RequestException as e:
228
+ state.messages[-1][-1] = server_error_msg
229
+ yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
230
+ return
231
+
232
+ yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7
233
+
234
+ finish_tstamp = time.time()
235
+ logger.info(f"{output}")
236
+
237
+ with open(get_conv_log_filename(), "a") as fout:
238
+ data = {
239
+ "tstamp": round(finish_tstamp, 4),
240
+ "type": "chat",
241
+ "model": model_name,
242
+ "start": round(start_tstamp, 4),
243
+ "finish": round(finish_tstamp, 4),
244
+ "svg": state.messages[-1][-1],
245
+ "ip": request.client.host,
246
+ }
247
+ fout.write(json.dumps(data) + "\n")
248
+
249
+ title_markdown = ("""
250
+ # 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text
251
+
252
+ [[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)]""")
253
+
254
+ sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \
255
+ **Note**: The current model works on vector-like images like icons and or vector-like designs.""")
256
+ tos_markdown = ("""
257
+ ### Terms of use
258
+ By using this service, users are required to agree to the following terms:
259
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
260
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
261
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
262
+ """)
263
+
264
+ learn_more_markdown = ("""
265
+ ### License
266
+ The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
267
+ """)
268
+
269
+ block_css = """
270
+
271
+ #buttons button {
272
+ min-width: min(120px,100%);
273
+ }
274
+
275
+ .gradio-container{
276
+ max-width: 1200px!important
277
+ }
278
+
279
+ .ͼ1 .cm-content {
280
+ white-space: unset !important;
281
+ flex-shrink: unset !important;
282
+ }
283
+
284
+ .ͼ2p .cm-scroller {
285
+ max-height: 200px;
286
+ overflow: scroll;
287
+ }
288
+
289
+ #svg_render{
290
+ padding: 20px !important;
291
+ }
292
+
293
+ #submit_btn{
294
+ max-height: 40px;
295
+ }
296
+
297
+ .selector{
298
+ max-height: 100px;
299
+ }
300
+ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
301
+ *{transition: width 0.5s ease, flex-grow 0.5s ease}
302
+ """
303
+ def build_demo(embed_mode):
304
+ svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300)
305
+ svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5)
306
+
307
+ with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
308
+ state = gr.State()
309
+ if not embed_mode:
310
+ gr.Markdown(title_markdown)
311
+ gr.Markdown(sub_title_markdown)
312
+ with gr.Row():
313
+ with gr.Column(scale=4):
314
+ task_selector = gr.Dropdown(
315
+ choices=["Image2SVG", "Text2SVG"],
316
+ value="Image2SVG",
317
+ label="Task",
318
+ interactive=True,
319
+ show_label=True,
320
+ container=True,
321
+ elem_id="task_selector",
322
+ elem_classes=["selector"],
323
+ )
324
+ model_selector = gr.Dropdown(
325
+ choices=models,
326
+ value=models[0] if len(models) > 0 else "",
327
+ label="Model",
328
+ interactive=True,
329
+ show_label=True,
330
+ container=True,
331
+ elem_classes=["selector"],
332
+ )
333
+
334
+ imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox")
335
+ image_process_mode = gr.Radio(
336
+ ["Resize", "Pad", "Default"],
337
+ value="Pad",
338
+ label="Preprocess for non-square image", visible=False)
339
+
340
+ # Text input
341
+ text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption")
342
+
343
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
344
+ gr.Examples(examples=[
345
+ [f"{cur_dir}/examples/sample-4.png"],
346
+ [f"{cur_dir}/examples/sample-7.png"],
347
+ [f"{cur_dir}/examples/sample-16.png"],
348
+ [f"{cur_dir}/examples/sample-17.png"],
349
+ [f"{cur_dir}/examples/sample-18.png"],
350
+ [f"{cur_dir}/examples/sample-0.png"],
351
+ [f"{cur_dir}/examples/sample-1.png"],
352
+ [f"{cur_dir}/examples/sample-6.png"],
353
+ ], inputs=[imagebox], elem_id="examples")
354
+
355
+ submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True)
356
+
357
+ with gr.Accordion("Parameters", open=False):
358
+ num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
359
+ temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",)
360
+ len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
361
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
362
+ max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=8192, step=64, interactive=True, label="Max output tokens",)
363
+
364
+ with gr.Column(scale=9):
365
+ with gr.Row():
366
+ svg_code.render()
367
+ with gr.Row():
368
+ svg_render.render()
369
+
370
+ with gr.Row(elem_id="buttons") as button_row:
371
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
372
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
373
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
374
+ stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
375
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False, visible=False)
376
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
377
+ download_btn = gr.Button(value="Download SVG", interactive=False, visible=False)
378
+
379
+ if not embed_mode:
380
+ gr.Markdown(tos_markdown)
381
+ gr.Markdown(learn_more_markdown)
382
+ url_params = gr.JSON(visible=False)
383
+
384
+ # Register listeners
385
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn]
386
+ upvote_btn.click(
387
+ upvote_last_response,
388
+ [state, model_selector],
389
+ [upvote_btn, downvote_btn, flag_btn],
390
+ queue=False
391
+ )
392
+ downvote_btn.click(
393
+ downvote_last_response,
394
+ [state, model_selector],
395
+ [upvote_btn, downvote_btn, flag_btn],
396
+ queue=False
397
+ )
398
+ flag_btn.click(
399
+ flag_last_response,
400
+ [state, model_selector],
401
+ [upvote_btn, downvote_btn, flag_btn],
402
+ queue=False
403
+ )
404
+
405
+ regenerate_btn.click(
406
+ regenerate,
407
+ [state, image_process_mode],
408
+ [state, svg_code, svg_render, imagebox] + btn_list,
409
+ queue=False
410
+ ).then(
411
+ http_bot,
412
+ [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
413
+ [state, svg_code, svg_render] + btn_list)
414
+
415
+ submit_btn.click(
416
+ send_data,
417
+ [state, imagebox, image_process_mode, text_caption, task_selector],
418
+ [state, svg_code, svg_render, imagebox] + btn_list,
419
+ queue=False
420
+ ).then(
421
+ http_bot,
422
+ [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
423
+ [state, svg_code, svg_render] + btn_list
424
+ )
425
+
426
+ clear_btn.click(
427
+ clear_history,
428
+ None,
429
+ [state, svg_code, svg_render] + btn_list,
430
+ queue=False
431
+ )
432
+
433
+ stop_btn.click(
434
+ stop_sampling,
435
+ [state, imagebox],
436
+ [state, imagebox] + btn_list,
437
+ queue=False
438
+ ).then(
439
+ clear_history,
440
+ None,
441
+ [state, svg_code, svg_render] + btn_list,
442
+ queue=False
443
+ )
444
+
445
+ download_btn.click(
446
+ download_files,
447
+ [state],
448
+ None,
449
+ queue=False
450
+ )
451
+ task_selector.change(
452
+ update_task,
453
+ inputs=[task_selector],
454
+ outputs=[len_penalty, temperature, top_p, model_selector],
455
+ queue=False,
456
+ _js="""
457
+ function(task) {
458
+ var imageBoxElement = document.getElementById("imagebox");
459
+ var textCaptionElement = document.getElementById("text_caption");
460
+ var examplesElement = document.getElementById("examples");
461
+ if (task === "Text2SVG") {
462
+ imageBoxElement.style.display = "none";
463
+ textCaptionElement.style.display = "block";
464
+ examplesElement.style.display = "none";
465
+ } else if (task === "Image2SVG") {
466
+ imageBoxElement.style.display = "block";
467
+ textCaptionElement.style.display = "none";
468
+ examplesElement.style.display = "block";
469
+ }
470
+ return task;
471
+ }
472
+ """
473
+ )
474
+
475
+ if args.model_list_mode == "once":
476
+ demo.load(
477
+ load_demo,
478
+ [url_params, task_selector],
479
+ [state, model_selector],
480
+ _js="""
481
+ function() {
482
+ const params = new URLSearchParams(window.location.search);
483
+ url_params = Object.fromEntries(params);
484
+ console.log(url_params);
485
+ return url_params;
486
+
487
+ }
488
+ """,
489
+ queue=False
490
+ )
491
+ elif args.model_list_mode == "reload":
492
+ demo.load(
493
+ load_demo_refresh_model_list,
494
+ [task_selector],
495
+ [state, model_selector],
496
+ _js="""
497
+ function(task) {
498
+ var textCaptionElement = document.getElementById("text_caption");
499
+ var autoScrollBottom = true;
500
+ textCaptionElement.style.display = "none";
501
+ function updateScroll(){
502
+ if (autoScrollBottom) {
503
+ var element = document.getElementsByClassName("cm-scroller")[0];
504
+ element.scrollTop = element.scrollHeight;
505
+ }
506
+ }
507
+ function handleScroll() {
508
+ var element = document.getElementsByClassName("cm-scroller")[0];
509
+ //if (element.scrollHeight - element.scrollTop === element.clientHeight) {
510
+ if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) {
511
+ // User has scrolled to the bottom, enable auto-scrolling
512
+ autoScrollBottom = true;
513
+ console.log("bottom");
514
+ } else {
515
+ console.log("not bottom");
516
+ // User has scrolled away from the bottom, disable auto-scrolling
517
+ autoScrollBottom = false;
518
+ }
519
+ }
520
+ setInterval(updateScroll,500);
521
+ var element = document.getElementsByClassName("cm-scroller")[0];
522
+ element.addEventListener("scroll", handleScroll);
523
+
524
+ return task;
525
+ }
526
+
527
+ """,
528
+ queue=False,
529
+ )
530
+
531
+ else:
532
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
533
+
534
+ return demo
535
+
536
+ if __name__ == "__main__":
537
+
538
+ parser = argparse.ArgumentParser()
539
+ parser.add_argument("--host", type=str, default="0.0.0.0")
540
+ parser.add_argument("--port", type=int)
541
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
542
+ parser.add_argument("--concurrency-count", type=int, default=10)
543
+ parser.add_argument("--model-list-mode", type=str, default="once",
544
+ choices=["once", "reload"])
545
+ parser.add_argument("--share", action="store_true")
546
+ parser.add_argument("--moderate", action="store_true")
547
+ parser.add_argument("--embed", action="store_true")
548
+ args = parser.parse_args()
549
+ logger.info(f"args: {args}")
550
+
551
+ models = get_model_list()
552
+
553
+ logger.info(args)
554
+ demo = build_demo(args.embed)
555
+ demo.queue(
556
+ concurrency_count=args.concurrency_count,
557
+ api_open=False
558
+ ).launch(
559
+ server_name=args.host,
560
+ server_port=args.port,
561
+ share=args.share
562
+ )
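Both servers append one JSON object per line to the file returned by get_conv_log_filename(): vote records from vote_last_response (type "upvote", "downvote", or "flag") and one "chat" record per completed generation from http_bot. The sketch below summarizes such a log; it is illustrative only, and the file path is an assumption, since the real location depends on LOGDIR from starvector.serve.constants.

import json
from collections import Counter
from pathlib import Path

def tally_log(log_path):
    # Count records per (model, type) pair from the JSON-lines conversation log.
    counts = Counter()
    for line in Path(log_path).read_text().splitlines():
        if not line.strip():
            continue
        record = json.loads(line)
        counts[(record.get("model"), record.get("type"))] += 1
    return counts

# Example with an assumed file name following get_conv_log_filename()'s pattern:
# print(tally_log("2025-01-31-conv.json"))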