"""Run the pretrained MMGP Tensile2d model on a few PLAID test samples.

Loads the Tensile2d dataset from the Hugging Face Hub and converts the first
five test samples to PLAID format. It then prints the 'U1' field of the first
sample before and after ``model.predict`` so the effect of prediction is visible.
"""
from datasets import load_dataset
from plaid.bridges.huggingface_bridge import huggingface_dataset_to_plaid

import mmgp_tensile2d

model = mmgp_tensile2d.load()

hf_dataset = load_dataset("PLAID-datasets/Tensile2d", split="all_samples")
# The dataset description presumably holds the train/test split ids; only the
# first five test ids are kept so the demo stays quick.
ids_test = hf_dataset.description["split"]["test"][:5]

dataset_test, _ = huggingface_dataset_to_plaid(
    hf_dataset, ids=ids_test, processes_number=5, verbose=True
)

# Before prediction, the output field 'U1' is expected to be absent.
print(
    "Check the 'U1' field is not present: dataset_test[0].get_field('U1') =",
    dataset_test[0].get_field("U1"),
)

print("Run prediction...")
dataset_pred = model.predict(dataset_test)

# After prediction, the model is expected to have filled in 'U1'.
print(
    "Check the 'U1' field is now present: dataset_pred[0].get_field('U1') =",
    dataset_pred[0].get_field("U1"),
)