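# Concept-detection inference for ARC grid pairs: a small custom T5 encoder
# classifies each (input, output) grid pair into a spatial concept, which is
# then mapped to a genetic-programming (GP) primitive from FUNCTIONS_dictionary.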
import json
import os

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, T5Config

from custom_t5_vit import CustomT5ForConditionalGeneration
from GP import genetic_programming
from Nods import FUNCTIONS_dictionary
from task_loader import *
# Get the base directory where the current script is running
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Build relative paths
TOKENIZER_PATH = os.path.join(BASE_DIR, "tokenizer_vs22_extendarctokens")
MODEL_SAVE_PATH_1 = os.path.join(BASE_DIR, "model", "final_cls_modell.pt")
print("Loading tokenizer from:", TOKENIZER_PATH)
print("Loading model from:", MODEL_SAVE_PATH_1)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
class CustomT5Config(T5Config):
    def __init__(self, PE_mix_strategy="default", use_objidx="yes",
                 grid_max_height=33, grid_max_width=34, **kwargs):
        super().__init__(**kwargs)
        self.PE_mix_strategy = PE_mix_strategy
        self.use_objidx = use_objidx
        self.grid_max_height = grid_max_height
        self.grid_max_width = grid_max_width
config = CustomT5Config(
    vocab_size=len(tokenizer),
    d_model=128,
    num_layers=3,
    num_decoder_layers=3,
    num_heads=8,
    d_ff=256,
    dropout_rate=0.1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id
)
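# This is a deliberately compact T5 (3 encoder / 3 decoder layers, d_model=128,
# d_ff=256); `vocab_size=len(tokenizer)` sizes the embeddings to the extended
# ARC-token vocabulary loaded above.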
class ConceptDetector(torch.nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.model = CustomT5ForConditionalGeneration(config)
        self.classifier_head = torch.nn.Linear(config.d_model, num_classes)
        self.loss_fn = CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        encoder_outputs = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        pooled_output = encoder_outputs.last_hidden_state[:, 0, :]
        logits = self.classifier_head(pooled_output)
        probs = F.softmax(logits, dim=1)
        return probs
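# Shape flow through ConceptDetector.forward:
#   input_ids / attention_mask: (batch, seq_len)
#   last_hidden_state:          (batch, seq_len, d_model), pooled at position 0
#   logits / probs:             (batch, num_classes)
# `loss_fn` is unused in forward; it is presumably applied during training.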
def load_model(model_path):
    print(f"Loading model from {model_path}...")
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    num_classes = checkpoint["classifier_head.weight"].shape[0]
    print(f"Detected `num_classes`: {num_classes}")
    model = ConceptDetector(config=config, num_classes=num_classes)
    model.load_state_dict(checkpoint)
    model.eval()
    return model
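# `num_classes` is recovered from the saved classifier weight, whose shape is
# (num_classes, d_model), so the checkpoint is self-describing and the class
# count never needs to be hard-coded.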
model = load_model(MODEL_SAVE_PATH_1)
def replace_digits_with_arc(grid):
    return [[f'<arc_{num}>' for num in row] for row in grid]
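# Example (hypothetical grid): replace_digits_with_arc([[0, 1], [2, 0]])
# -> [['<arc_0>', '<arc_1>'], ['<arc_2>', '<arc_0>']]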
def pad_2d_list(grid, pad_token='<arc_pad>', target_size=32):
    padded_grid = [row + [pad_token] * (target_size - len(row)) for row in grid]
    while len(padded_grid) < target_size:
        padded_grid.append([pad_token] * target_size)
    return padded_grid
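# Example: a 2x2 grid is right-padded to 32 columns per row, then extended with
# all-'<arc_pad>' rows to 32 rows, giving a fixed 32x32 token grid. Grids
# already wider or taller than `target_size` are passed through un-truncated.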
def reformat_arc_tokens(grid):
    padded_tokens_2d = pad_2d_list(grid)
    flattened_tokens = [token for row in padded_tokens_2d for token in row]
    return " ".join(flattened_tokens)
def preprocess_for_inference(input_grid, output_grid):
    input_grid = replace_digits_with_arc(input_grid)
    output_grid = replace_digits_with_arc(output_grid)
    input_tokens = "<s> Input Grid: " + reformat_arc_tokens(input_grid) + " </s>"
    output_tokens = " Output Grid: " + reformat_arc_tokens(output_grid) + " </s>"
    return input_tokens + output_tokens
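# The combined sequence looks like (abbreviated):
#   "<s> Input Grid: <arc_0> <arc_1> ... </s> Output Grid: <arc_0> ... </s>"
# so the encoder sees both grids of a pair in a single pass.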
# Concept Label Mapping
CONCEPT_LABELS = {
    'Above_below': 0,
    'Below_row_line': 1,
    'Center': 2,
    'Copy': 3,
    'Horizontal_vertical': 4,
    'Inside_outside': 5,
    'Remove_below_horizontal_line': 6,
}
CONCEPT_LABELS_INV = {v: k for k, v in CONCEPT_LABELS.items()}
# Map ViT Concept to GP Function
CONCEPT_TO_FUNCTION_MAP = {
    'Center': 'find_center_pixel',
    'Copy': 'identity',
    'Above_below': 'flip_horizontal',
    'color_top_part': 'flip_vertical',
    'Horizontal_vertical': 'Horizontal_vertical',
}
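# Caveat: only keys that also appear in CONCEPT_LABELS are reachable.
# 'color_top_part' is not a concept label above, and 'Below_row_line',
# 'Inside_outside', and 'Remove_below_horizontal_line' have no GP function
# mapped yet; those concepts fall through to the `None` branch below.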
def run_inference(model, input_grid, output_grid):
    formatted_input = preprocess_for_inference(input_grid, output_grid)
    encoded = tokenizer(formatted_input, return_tensors="pt")
    with torch.no_grad():
        probs = model(encoded["input_ids"], encoded["attention_mask"])
    predicted_class_index = torch.argmax(probs, dim=1).item()
    concept_label = CONCEPT_LABELS_INV.get(predicted_class_index, "Unknown Concept")
    print(f"Predicted class index: {predicted_class_index}")
    print(f"Predicted concept: {concept_label}")
    gp_function_name = CONCEPT_TO_FUNCTION_MAP.get(concept_label)
    if gp_function_name is None:
        print(f"Warning: No matching GP function found for concept `{concept_label}`.")
        return concept_label, None
    mapped_function = FUNCTIONS_dictionary.get(gp_function_name)
    return concept_label, mapped_function
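# Minimal single-sample usage sketch (made-up placeholder grids; real grids
# come from the ARC task JSON files):
#
#   demo_in = [[0, 0, 0], [5, 5, 5], [1, 1, 1]]
#   demo_out = [[1, 1, 1], [5, 5, 5], [0, 0, 0]]
#   label, gp_fn = run_inference(model, demo_in, demo_out)
#   if gp_fn is not None:
#       print(f"Concept `{label}` -> GP primitive {gp_fn}")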
if __name__ == "__main__":
# Path to your JSON file
JSON_DATA_PATH = r"C:\Users\gebre\OneDrive - GIST\문서\KakaoTalk Downloads\GPARC_concept_with_vit\GPARC\SRC\data\AboveBelow3.json"
# Load JSON data
with open(JSON_DATA_PATH, "r") as f:
data = json.load(f)
# Loop through both train and test sets
results = []
for split_name in ["train", "test"]:
if split_name in data:
print(f"\nRunning inference on `{split_name}` set...")
split_results = []
for sample in data[split_name]:
input_grid = sample["input"]
output_grid = sample["output"]
predicted_label, mapped_function = run_inference(model, input_grid, output_grid)
split_results.append({
"input": input_grid,
"output": output_grid,
"predicted_label": predicted_label,
"mapped_function": str(mapped_function) # in case it's a callable
})
results.append({
"split": split_name,
"predictions": split_results
})
# Optionally: save the result to a JSON file
with open("inference_results.json", "w") as f:
json.dump(results, f, indent=2)
print("\nInference completed. Results saved to `inference_results.json`.")