# conceptt / Vit_concept.py
# Last commit 25baf69 (by Liyew): "Moved all contents from concept--main to root directory"
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, T5Config
from torch.nn import CrossEntropyLoss
from custom_t5_vit import CustomT5ForConditionalGeneration
from GP import genetic_programming
from Nods import FUNCTIONS_dictionary
from task_loader import *
import os
# Resolve resources relative to this script's own folder so the script does
# not depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Tokenizer directory and trained classifier checkpoint, both under BASE_DIR.
# NOTE(review): "modell" looks like a typo, but it must match the checkpoint
# file name on disk — rename both together if at all.
TOKENIZER_PATH, MODEL_SAVE_PATH_1 = (
    os.path.join(BASE_DIR, "tokenizer_vs22_extendarctokens"),
    os.path.join(BASE_DIR, "model", "final_cls_modell.pt"),
)

print("Loading tokenizer from:", TOKENIZER_PATH)
print("Loading model from:", MODEL_SAVE_PATH_1)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
class CustomT5Config(T5Config):
    """``T5Config`` extended with four extra fields.

    The extra fields are presumably consumed by
    ``CustomT5ForConditionalGeneration`` (not visible here — confirm).

    Args:
        PE_mix_strategy: positional-encoding mixing strategy name, stored as-is.
        use_objidx: string flag (default ``"yes"``), stored as-is.
        grid_max_height: stored as-is (default 33).
        grid_max_width: stored as-is (default 34).
        **kwargs: forwarded unchanged to ``T5Config.__init__``.
    """

    def __init__(self, PE_mix_strategy="default", use_objidx="yes",
                 grid_max_height=33, grid_max_width=34, **kwargs):
        super().__init__(**kwargs)
        # Attach the custom fields only after the base config is initialised.
        extra_fields = {
            "PE_mix_strategy": PE_mix_strategy,
            "use_objidx": use_objidx,
            "grid_max_height": grid_max_height,
            "grid_max_width": grid_max_width,
        }
        for field_name, field_value in extra_fields.items():
            setattr(self, field_name, field_value)
# Small T5 backbone: 3 encoder and 3 decoder layers, d_model=128.
# load_model() loads weights with a default (strict) load_state_dict(),
# so the architecture sizes below must match the saved checkpoint.
config = CustomT5Config(
    # Vocabulary size and special token ids come from the loaded tokenizer.
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    # Transformer dimensions.
    d_model=128,
    d_ff=256,
    num_heads=8,
    num_layers=3,
    num_decoder_layers=3,
    dropout_rate=0.1,
)
class ConceptDetector(torch.nn.Module):
    """Concept classifier on top of the encoder of a custom T5 model.

    The full ``CustomT5ForConditionalGeneration`` is instantiated
    (presumably so the checkpoint's parameter names line up), but only its
    encoder is run. The encoder hidden state at sequence position 0 is
    projected by a linear head to one score per class.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        self.model = CustomT5ForConditionalGeneration(config)
        self.classifier_head = torch.nn.Linear(config.d_model, num_classes)
        # Not used by forward(). NOTE(review): forward() returns softmax
        # probabilities, while CrossEntropyLoss expects raw logits; feeding
        # forward()'s output to this loss would apply softmax twice.
        # Confirm how training used it.
        self.loss_fn = CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        """Return class probabilities of shape ``(batch, num_classes)``.

        Args:
            input_ids: token ids passed to the T5 encoder.
            attention_mask: attention mask passed to the T5 encoder.
        """
        hidden_states = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Use the first token's state as the sequence summary.
        first_token_state = hidden_states[:, 0, :]
        return F.softmax(self.classifier_head(first_token_state), dim=1)
def load_model(model_path):
    """Build a ``ConceptDetector`` and load trained weights from ``model_path``.

    The number of classes is read from the checkpoint itself (the row count of
    ``classifier_head.weight``), so the head always matches the saved weights.
    Uses the module-level ``config`` for the backbone.

    Args:
        model_path: path to a checkpoint holding a plain ``state_dict``.

    Returns:
        The ``ConceptDetector`` on CPU, switched to evaluation mode.
    """
    print(f"Loading model from {model_path}...")
    # The checkpoint is a bare state_dict (it is passed straight to
    # load_state_dict below). weights_only=True therefore loses nothing, and
    # it refuses to unpickle arbitrary objects from the file.
    checkpoint = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    num_classes = checkpoint["classifier_head.weight"].shape[0]
    print(f"Detected `num_classes`: {num_classes}")
    model = ConceptDetector(config=config, num_classes=num_classes)
    model.load_state_dict(checkpoint)
    model.eval()
    return model
# Loaded once at import time; the __main__ section passes it to run_inference().
model = load_model(model_path=MODEL_SAVE_PATH_1)
def replace_digits_with_arc(grid):
    """Map every cell value ``v`` of a 2-D grid to the token string ``'<arc_v>'``.

    Returns a new list of lists; ``grid`` itself is not modified.
    """
    to_token = '<arc_{}>'.format
    return [list(map(to_token, row)) for row in grid]
def pad_2d_list(grid, pad_token='<arc_pad>', target_size=32):
    """Pad a ragged 2-D list to ``target_size`` rows and columns with ``pad_token``.

    Short rows are right-padded, and missing rows are appended at the bottom.
    Rows, or a grid, already at or beyond ``target_size`` are kept as-is;
    nothing is truncated. New row lists are returned, so the input is not mutated.
    """
    padded = [row + [pad_token] * (target_size - len(row)) for row in grid]
    # range() of a non-positive count is empty, so oversize grids gain no rows.
    missing_rows = target_size - len(padded)
    padded.extend([pad_token] * target_size for _ in range(missing_rows))
    return padded
def reformat_arc_tokens(grid):
    """Pad ``grid`` with ``pad_2d_list`` defaults and flatten it to one string.

    The defaults are 32 x 32 with ``'<arc_pad>'``. Cells are emitted in
    row-major order, separated by single spaces.
    """
    return " ".join(cell for row in pad_2d_list(grid) for cell in row)
def preprocess_for_inference(input_grid, output_grid):
    """Serialise an input/output grid pair into the text prompt fed to the tokenizer.

    Each grid becomes ``<arc_*>`` tokens, padded and flattened by
    ``reformat_arc_tokens``. The result has the form
    ``"<s> Input Grid: ... </s> Output Grid: ... </s>"``.
    """
    in_text = reformat_arc_tokens(replace_digits_with_arc(input_grid))
    out_text = reformat_arc_tokens(replace_digits_with_arc(output_grid))
    return f"<s> Input Grid: {in_text} </s> Output Grid: {out_text} </s>"
# Concept name -> class index of the classifier head's output.
CONCEPT_LABELS = {
    'Above_below': 0,
    'Below_row_line': 1,
    'Center': 2,
    'Copy': 3,
    'Horizontal_vertical': 4,
    'Inside_outside': 5,
    'Remove_below_horizontal_line': 6,
}
# Reverse lookup: class index -> concept name.
CONCEPT_LABELS_INV = dict(zip(CONCEPT_LABELS.values(), CONCEPT_LABELS.keys()))
# Predicted ViT concept name -> name of a GP primitive, which run_inference()
# looks up in FUNCTIONS_dictionary. A concept without an entry yields no function.
# NOTE(review): 'color_top_part' is not a key of CONCEPT_LABELS, so the
# classifier can never produce it and that entry is currently unreachable.
# 'Below_row_line', 'Inside_outside' and 'Remove_below_horizontal_line' have
# no mapping at all. Confirm both are intended.
CONCEPT_TO_FUNCTION_MAP = dict(
    Center='find_center_pixel',
    Copy='identity',
    Above_below='flip_horizontal',
    color_top_part='flip_vertical',
    Horizontal_vertical='Horizontal_vertical',
)
def run_inference(model, input_grid, output_grid):
    """Classify the concept of one input/output grid pair and look up its GP function.

    Args:
        model: a ``ConceptDetector``. It is called as
            ``model(input_ids, attention_mask)`` and must return per-class scores.
        input_grid: 2-D list of cell values for the task input.
        output_grid: 2-D list of cell values for the task output.

    Returns:
        ``(concept_label, mapped_function)``. ``concept_label`` is the
        predicted concept name, or ``"Unknown Concept"`` if the index is not in
        ``CONCEPT_LABELS_INV``. ``mapped_function`` is the
        ``FUNCTIONS_dictionary`` entry for that concept, or ``None`` when the
        concept has no mapping or the mapped name is not registered.
    """
    formatted_input = preprocess_for_inference(input_grid, output_grid)
    encoded = tokenizer(formatted_input, return_tensors="pt")
    with torch.no_grad():
        probs = model(encoded["input_ids"], encoded["attention_mask"])
    # A single prompt is tokenized, so the batch has exactly one row.
    predicted_class_index = torch.argmax(probs, dim=1).item()
    concept_label = CONCEPT_LABELS_INV.get(predicted_class_index, "Unknown Concept")
    print(f"Predicted class index: {predicted_class_index}")
    print(f"Predicted concept: {concept_label}")

    gp_function_name = CONCEPT_TO_FUNCTION_MAP.get(concept_label)
    if gp_function_name is None:
        print(f"Warning: No matching GP function found for concept `{concept_label}`.")
        return concept_label, None

    mapped_function = FUNCTIONS_dictionary.get(gp_function_name)
    if mapped_function is None:
        # Report a name in CONCEPT_TO_FUNCTION_MAP that FUNCTIONS_dictionary
        # lacks, instead of returning None silently.
        print(
            f"Warning: GP function `{gp_function_name}` for concept "
            f"`{concept_label}` is not registered in FUNCTIONS_dictionary."
        )
    return concept_label, mapped_function
if __name__ == "__main__":
    import sys

    # Default data file. Pass another JSON path as the first command-line
    # argument to run on a different task without editing this file.
    DEFAULT_JSON_DATA_PATH = r"C:\Users\gebre\OneDrive - GIST\문서\KakaoTalk Downloads\GPARC_concept_with_vit\GPARC\SRC\data\AboveBelow3.json"
    JSON_DATA_PATH = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_JSON_DATA_PATH
    RESULTS_PATH = "inference_results.json"

    # JSON files are UTF-8. Be explicit so the platform's default encoding
    # (a legacy code page on many Windows locales) is not used to decode them.
    with open(JSON_DATA_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Run inference on whichever of the train/test splits are present.
    results = []
    for split_name in ["train", "test"]:
        if split_name not in data:
            continue
        print(f"\nRunning inference on `{split_name}` set...")
        split_results = []
        for sample in data[split_name]:
            input_grid = sample["input"]
            output_grid = sample["output"]
            predicted_label, mapped_function = run_inference(model, input_grid, output_grid)
            split_results.append({
                "input": input_grid,
                "output": output_grid,
                "predicted_label": predicted_label,
                # str() keeps the callable (or None) JSON-serialisable.
                "mapped_function": str(mapped_function),
            })
        results.append({
            "split": split_name,
            "predictions": split_results,
        })

    with open(RESULTS_PATH, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("\nInference completed. Results saved to `inference_results.json`.")