Spaces:
Runtime error
Runtime error
DimaKoshman
committed on
Commit
•
6a43216
1
Parent(s):
ba1576e
fix
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
import gradio
|
4 |
import pandas as pd
|
5 |
from matplotlib import pyplot as plt
|
@@ -44,7 +42,7 @@ def main():
|
|
44 |
interface = gradio.Interface(
|
45 |
title="Making graphs accessible",
|
46 |
description="Generate textual representation of a graph\n"
|
47 |
-
|
48 |
fn=lambda image: predict_string(image, model),
|
49 |
inputs="image",
|
50 |
outputs="text",
|
|
|
|
|
|
|
1 |
import gradio
|
2 |
import pandas as pd
|
3 |
from matplotlib import pyplot as plt
|
|
|
42 |
interface = gradio.Interface(
|
43 |
title="Making graphs accessible",
|
44 |
description="Generate textual representation of a graph\n"
|
45 |
+
"https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
|
46 |
fn=lambda image: predict_string(image, model),
|
47 |
inputs="image",
|
48 |
outputs="text",
|
data.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
import dataclasses
|
4 |
import enum
|
5 |
import functools
|
@@ -231,23 +229,23 @@ class AnnotatedImage:
|
|
231 |
|
232 |
def generate_annotated_images():
|
233 |
for image_id in tqdm.autonotebook.tqdm(
|
234 |
-
|
235 |
):
|
236 |
yield AnnotatedImage.from_image_id(image_id)
|
237 |
|
238 |
|
239 |
-
@functools.
|
240 |
def load_train_image_ids() -> list[str]:
|
241 |
train_image_ids = [i.replace(".jpg", "") for i in os.listdir("data/train/images")]
|
242 |
return train_image_ids[: 1000 if CONFIG.debug else None]
|
243 |
|
244 |
|
245 |
-
@functools.
|
246 |
def load_test_image_ids() -> list[str]:
|
247 |
return [i.replace(".jpg", "") for i in os.listdir("data/test/images")]
|
248 |
|
249 |
|
250 |
-
@functools.
|
251 |
def load_image_annotation(image_id: str) -> dict:
|
252 |
return json.load(open(f"data/train/annotations/{image_id}.json"))
|
253 |
|
@@ -309,7 +307,7 @@ def to_token_str(value: str or enum.Enum):
|
|
309 |
return f"<{string}>"
|
310 |
|
311 |
|
312 |
-
@functools.
|
313 |
def get_extra_tokens() -> types.SimpleNamespace:
|
314 |
token_ns = types.SimpleNamespace()
|
315 |
|
@@ -333,7 +331,7 @@ def convert_number_to_scientific_string(value: int or float) -> str:
|
|
333 |
|
334 |
|
335 |
def convert_axis_data_to_string(
|
336 |
-
|
337 |
) -> str:
|
338 |
formatted_axis_data = []
|
339 |
for value in axis_data:
|
|
|
|
|
|
|
1 |
import dataclasses
|
2 |
import enum
|
3 |
import functools
|
|
|
229 |
|
230 |
def generate_annotated_images():
|
231 |
for image_id in tqdm.autonotebook.tqdm(
|
232 |
+
load_train_image_ids(), "Iterating over annotated images"
|
233 |
):
|
234 |
yield AnnotatedImage.from_image_id(image_id)
|
235 |
|
236 |
|
237 |
+
@functools.cache
|
238 |
def load_train_image_ids() -> list[str]:
|
239 |
train_image_ids = [i.replace(".jpg", "") for i in os.listdir("data/train/images")]
|
240 |
return train_image_ids[: 1000 if CONFIG.debug else None]
|
241 |
|
242 |
|
243 |
+
@functools.cache
|
244 |
def load_test_image_ids() -> list[str]:
|
245 |
return [i.replace(".jpg", "") for i in os.listdir("data/test/images")]
|
246 |
|
247 |
|
248 |
+
@functools.cache
|
249 |
def load_image_annotation(image_id: str) -> dict:
|
250 |
return json.load(open(f"data/train/annotations/{image_id}.json"))
|
251 |
|
|
|
307 |
return f"<{string}>"
|
308 |
|
309 |
|
310 |
+
@functools.cache
|
311 |
def get_extra_tokens() -> types.SimpleNamespace:
|
312 |
token_ns = types.SimpleNamespace()
|
313 |
|
|
|
331 |
|
332 |
|
333 |
def convert_axis_data_to_string(
|
334 |
+
axis_data: list[str or float], values_type: ValuesType
|
335 |
) -> str:
|
336 |
formatted_axis_data = []
|
337 |
for value in axis_data:
|
metrics.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
import numpy as np
|
4 |
import rapidfuzz
|
5 |
import sklearn
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import rapidfuzz
|
3 |
import sklearn
|
model.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
import collections
|
4 |
import dataclasses
|
5 |
import types
|
@@ -30,7 +28,7 @@ class Model:
|
|
30 |
|
31 |
|
32 |
def add_unknown_tokens_to_tokenizer(
|
33 |
-
|
34 |
):
|
35 |
tokenizer.add_tokens(unknown_tokens)
|
36 |
encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
|
@@ -53,7 +51,7 @@ def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
|
|
53 |
|
54 |
|
55 |
def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
|
56 |
-
|
57 |
):
|
58 |
token_ids[token_ids == tokenizer.pad_token_id] = -100
|
59 |
return token_ids
|
@@ -144,7 +142,7 @@ def build_model(config: types.SimpleNamespace or object) -> Model:
|
|
144 |
|
145 |
|
146 |
def generate_token_strings(
|
147 |
-
|
148 |
) -> list[str]:
|
149 |
decoder_output = model.encoder_decoder.generate(
|
150 |
images,
|
|
|
|
|
|
|
1 |
import collections
|
2 |
import dataclasses
|
3 |
import types
|
|
|
28 |
|
29 |
|
30 |
def add_unknown_tokens_to_tokenizer(
|
31 |
+
tokenizer, encoder_decoder, unknown_tokens: list[str]
|
32 |
):
|
33 |
tokenizer.add_tokens(unknown_tokens)
|
34 |
encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
|
|
|
51 |
|
52 |
|
53 |
def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
|
54 |
+
tokenizer, token_ids
|
55 |
):
|
56 |
token_ids[token_ids == tokenizer.pad_token_id] = -100
|
57 |
return token_ids
|
|
|
142 |
|
143 |
|
144 |
def generate_token_strings(
|
145 |
+
model: Model, images: torch.Tensor, skip_special_tokens=True
|
146 |
) -> list[str]:
|
147 |
decoder_output = model.encoder_decoder.generate(
|
148 |
images,
|
train.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
import os
|
4 |
|
5 |
import pandas as pd
|
@@ -22,12 +20,12 @@ from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus
|
|
22 |
|
23 |
class MetricsCallback(pl.callbacks.Callback):
|
24 |
def on_validation_batch_start(
|
25 |
-
|
26 |
):
|
27 |
predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
|
28 |
|
29 |
for expected_data_index, predicted_string in zip(
|
30 |
-
|
31 |
):
|
32 |
benetech_score = benetech_score_string_prediction(
|
33 |
expected_data_index=expected_data_index,
|
@@ -52,7 +50,7 @@ class MetricsCallback(pl.callbacks.Callback):
|
|
52 |
|
53 |
class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
|
54 |
def __init__(
|
55 |
-
|
56 |
):
|
57 |
super().__init__()
|
58 |
self.pretrained_models = pretrained_models
|
|
|
|
|
|
|
1 |
import os
|
2 |
|
3 |
import pandas as pd
|
|
|
20 |
|
21 |
class MetricsCallback(pl.callbacks.Callback):
|
22 |
def on_validation_batch_start(
|
23 |
+
self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
|
24 |
):
|
25 |
predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
|
26 |
|
27 |
for expected_data_index, predicted_string in zip(
|
28 |
+
batch.data_indices, predicted_strings, strict=True
|
29 |
):
|
30 |
benetech_score = benetech_score_string_prediction(
|
31 |
expected_data_index=expected_data_index,
|
|
|
50 |
|
51 |
class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
|
52 |
def __init__(
|
53 |
+
self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
|
54 |
):
|
55 |
super().__init__()
|
56 |
self.pretrained_models = pretrained_models
|
utils.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
import os
|
4 |
import pickle
|
5 |
from typing import Callable, TypeVar
|
@@ -16,7 +14,7 @@ def set_torch_device_order_pci_bus():
|
|
16 |
|
17 |
|
18 |
def load_pickle_or_build_object_and_save(
|
19 |
-
|
20 |
) -> T:
|
21 |
if overwrite or not os.path.exists(pickle_path):
|
22 |
pickle.dump(build_object(), open(pickle_path, "wb"))
|
|
|
|
|
|
|
1 |
import os
|
2 |
import pickle
|
3 |
from typing import Callable, TypeVar
|
|
|
14 |
|
15 |
|
16 |
def load_pickle_or_build_object_and_save(
|
17 |
+
pickle_path: str, build_object: Callable[[], T], overwrite=False
|
18 |
) -> T:
|
19 |
if overwrite or not os.path.exists(pickle_path):
|
20 |
pickle.dump(build_object(), open(pickle_path, "wb"))
|