Spaces:
Sleeping
Sleeping
import json | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, T5Config | |
from torch.nn import CrossEntropyLoss | |
from custom_t5_vit import CustomT5ForConditionalGeneration | |
from GP import genetic_programming | |
from Nods import FUNCTIONS_dictionary | |
from task_loader import * | |
import os | |
# Resolve bundled assets relative to this source file rather than the current
# working directory, so the script can be launched from any directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

# Tokenizer directory and classifier checkpoint expected next to this file.
TOKENIZER_PATH = os.path.join(BASE_DIR, "tokenizer_vs22_extendarctokens")
MODEL_SAVE_PATH_1 = os.path.join(BASE_DIR, "model", "final_cls_modell.pt")

print("Loading tokenizer from:", TOKENIZER_PATH)
print("Loading model from:", MODEL_SAVE_PATH_1)

# Loaded once at import time; the config below and run_inference() both use it.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
class CustomT5Config(T5Config):
    """``T5Config`` extended with extra fields for the custom T5 backbone.

    Every keyword argument not named here is forwarded unchanged to
    ``T5Config``. The four extra fields are stored as plain attributes; their
    meaning is defined by ``custom_t5_vit`` (not visible here; presumably
    ``grid_max_height``/``grid_max_width`` bound the 2-D grid position
    encodings, so confirm there).
    """

    def __init__(self, PE_mix_strategy="default", use_objidx="yes",
                 grid_max_height=33, grid_max_width=34, **kwargs):
        super().__init__(**kwargs)
        extra_fields = (
            ("PE_mix_strategy", PE_mix_strategy),
            ("use_objidx", use_objidx),
            ("grid_max_height", grid_max_height),
            ("grid_max_width", grid_max_width),
        )
        for field_name, field_value in extra_fields:
            setattr(self, field_name, field_value)
# Size of the small encoder-decoder backbone. The shape-determining values must
# match the checkpoint at MODEL_SAVE_PATH_1; otherwise load_state_dict() in
# load_model() rejects the weights.
_BACKBONE_HPARAMS = dict(
    d_model=128,
    num_layers=3,
    num_decoder_layers=3,
    num_heads=8,
    d_ff=256,
    dropout_rate=0.1,
)

# Vocabulary size and special-token ids are taken from the loaded tokenizer.
config = CustomT5Config(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    **_BACKBONE_HPARAMS,
)
class ConceptDetector(torch.nn.Module):
    """Sequence classifier on top of a CustomT5ForConditionalGeneration encoder.

    Only the encoder half of the backbone is run. The hidden state at sequence
    position 0 is passed through a linear head and turned into class
    probabilities.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        # These attribute names define the state-dict keys ("model.*",
        # "classifier_head.*"). load_model() reads "classifier_head.weight",
        # so do not rename them.
        self.model = CustomT5ForConditionalGeneration(config)
        self.classifier_head = torch.nn.Linear(config.d_model, num_classes)
        # NOTE(review): not used by forward() or elsewhere in the visible file.
        # CrossEntropyLoss expects raw logits, but forward() returns softmax
        # probabilities, so do not feed one into the other.
        self.loss_fn = CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        """Return class probabilities of shape (batch, num_classes).

        Softmax is taken over dim 1, so each row sums to 1.
        """
        hidden_states = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 serves as the whole-sequence summary vector.
        summary = hidden_states[:, 0, :]
        class_scores = self.classifier_head(summary)
        return class_scores.softmax(dim=1)
def load_model(model_path):
    """Build a ConceptDetector from a saved checkpoint and put it in eval mode.

    The class count is not hard-coded. It is read from the first dimension of
    the checkpoint's ``classifier_head.weight``, because ``nn.Linear`` stores
    its weight as (out_features, in_features). The backbone architecture comes
    from the module-level ``config``.

    Args:
        model_path: Path to a checkpoint that ``load_state_dict`` accepts
            directly (assumed to be a plain state dict; confirm with the
            training script).

    Returns:
        The ``ConceptDetector`` with weights mapped to the CPU, in eval mode.
    """
    print(f"Loading model from {model_path}...")
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    num_classes = state_dict["classifier_head.weight"].size(0)
    print(f"Detected `num_classes`: {num_classes}")

    detector = ConceptDetector(config=config, num_classes=num_classes)
    detector.load_state_dict(state_dict)
    # Module.eval() returns the module itself.
    return detector.eval()
# Build the classifier and load its trained weights once, at import time.
# The __main__ block passes this instance to run_inference().
model = load_model(model_path=MODEL_SAVE_PATH_1)
def replace_digits_with_arc(grid):
    """Turn every cell of a 2-D grid into its ``<arc_N>`` token string.

    Row structure is preserved, e.g. ``[[0, 3]] -> [['<arc_0>', '<arc_3>']]``.
    Cell values are expected to be ARC colour indices (assumed 0-9).
    """
    token_rows = []
    for row in grid:
        token_rows.append(["<arc_{}>".format(cell) for cell in row])
    return token_rows
def pad_2d_list(grid, pad_token='<arc_pad>', target_size=32):
    """Pad a (possibly ragged) 2-D list on the right and bottom to target_size.

    Each row shorter than ``target_size`` is extended with ``pad_token``.
    Full rows of ``pad_token`` are then appended until there are
    ``target_size`` rows. Rows or grids that are already longer are kept as
    they are; nothing is truncated. A new outer list is returned, and the rows
    of ``grid`` are not modified.
    """
    result = []
    for row in grid:
        shortfall = target_size - len(row)
        # A negative shortfall yields an empty pad list, i.e. no truncation.
        result.append(row + [pad_token] * shortfall)
    missing_rows = target_size - len(result)
    result.extend([pad_token] * target_size for _ in range(missing_rows))
    return result
def reformat_arc_tokens(grid):
    """Serialize a token grid into one space-separated string.

    The grid is first padded with ``pad_2d_list`` defaults (at least 32x32 of
    ``<arc_pad>``). Tokens are then emitted in row-major order: all of row 0,
    then row 1, and so on.
    """
    padded_rows = pad_2d_list(grid)
    return " ".join(token for row in padded_rows for token in row)
def preprocess_for_inference(input_grid, output_grid):
    """Build the text prompt the classifier expects for one (input, output) pair.

    Layout: ``"<s> Input Grid: <tokens> </s> Output Grid: <tokens> </s>"``,
    where each ``<tokens>`` is the padded, row-major ``<arc_N>``
    serialization of the corresponding grid.
    """
    input_cells = replace_digits_with_arc(input_grid)
    output_cells = replace_digits_with_arc(output_grid)
    return (
        f"<s> Input Grid: {reformat_arc_tokens(input_cells)} </s>"
        f" Output Grid: {reformat_arc_tokens(output_cells)} </s>"
    )
# Concept names in classifier output order: position i is class index i.
_CONCEPT_NAMES = (
    'Above_below',
    'Below_row_line',
    'Center',
    'Copy',
    'Horizontal_vertical',
    'Inside_outside',
    'Remove_below_horizontal_line',
)

# name -> class index, and the reverse lookup used to decode predictions.
CONCEPT_LABELS = {name: index for index, name in enumerate(_CONCEPT_NAMES)}
CONCEPT_LABELS_INV = dict(enumerate(_CONCEPT_NAMES))

# Concept name -> key in FUNCTIONS_dictionary. Concepts absent here have no
# GP primitive, and run_inference() warns for them.
# NOTE(review): 'color_top_part' is not one of the concept names above, so it
# can never be looked up. Confirm which concept it was meant for.
CONCEPT_TO_FUNCTION_MAP = {
    'Center': 'find_center_pixel',
    'Copy': 'identity',
    'Above_below': 'flip_horizontal',
    'color_top_part': 'flip_vertical',
    'Horizontal_vertical': 'Horizontal_vertical',
}
def run_inference(model, input_grid, output_grid):
    """Predict the transformation concept of one (input, output) grid pair.

    Args:
        model: A ConceptDetector (or any callable taking
            ``(input_ids, attention_mask)``) that returns per-class scores of
            shape (batch, num_classes). Only batch size 1 is produced here.
        input_grid: 2-D list of cell values for the task input.
        output_grid: 2-D list of cell values for the task output.

    Returns:
        ``(concept_label, mapped_function)``. ``concept_label`` is the
        predicted concept name, or ``"Unknown Concept"`` if the class index is
        not in ``CONCEPT_LABELS_INV``. ``mapped_function`` is the matching
        entry of ``FUNCTIONS_dictionary``, or ``None`` when the concept has no
        GP mapping or the mapped name is not defined. Both cases print a
        warning.
    """
    formatted_input = preprocess_for_inference(input_grid, output_grid)
    encoded = tokenizer(formatted_input, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # The tokenizer always produces CPU tensors. Move them to wherever the
    # model's weights live, so inference still works after the model has been
    # moved to another device (load_model itself maps weights to the CPU).
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            input_ids = input_ids.to(first_param.device)
            attention_mask = attention_mask.to(first_param.device)

    with torch.no_grad():
        probs = model(input_ids, attention_mask)
    predicted_class_index = torch.argmax(probs, dim=1).item()
    concept_label = CONCEPT_LABELS_INV.get(predicted_class_index, "Unknown Concept")
    print(f"Predicted class index: {predicted_class_index}")
    print(f"Predicted concept: {concept_label}")

    gp_function_name = CONCEPT_TO_FUNCTION_MAP.get(concept_label)
    if gp_function_name is None:
        print(f"Warning: No matching GP function found for concept `{concept_label}`.")
        return concept_label, None

    mapped_function = FUNCTIONS_dictionary.get(gp_function_name)
    if mapped_function is None:
        # Without this warning, a mapping to an undefined GP primitive would
        # return None with no diagnostic at all.
        print(
            f"Warning: GP function `{gp_function_name}` for concept "
            f"`{concept_label}` is not defined in FUNCTIONS_dictionary."
        )
    return concept_label, mapped_function
if __name__ == "__main__":
    import sys

    # Dataset location on the original author's machine. Pass a path as the
    # first command-line argument to run on any other file or machine.
    DEFAULT_JSON_DATA_PATH = r"C:\Users\gebre\OneDrive - GIST\문서\KakaoTalk Downloads\GPARC_concept_with_vit\GPARC\SRC\data\AboveBelow3.json"
    JSON_DATA_PATH = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_JSON_DATA_PATH

    # JSON files are UTF-8. Without an explicit encoding, open() uses the
    # locale's preferred codec (e.g. cp949 on Korean-locale Windows), which can
    # fail on or mis-decode the file.
    with open(JSON_DATA_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)

    results = []
    for split_name in ["train", "test"]:
        if split_name not in data:
            continue
        print(f"\nRunning inference on `{split_name}` set...")
        split_results = []
        for sample in data[split_name]:
            input_grid = sample["input"]
            output_grid = sample["output"]
            predicted_label, mapped_function = run_inference(model, input_grid, output_grid)
            split_results.append({
                "input": input_grid,
                "output": output_grid,
                "predicted_label": predicted_label,
                "mapped_function": str(mapped_function),
            })
        results.append({
            "split": split_name,
            "predictions": split_results,
        })

    with open("inference_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("\nInference completed. Results saved to `inference_results.json`.")