RoBERTa-large-finnish / flax_model_to_pytorch.py
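
"""Convert the Flax checkpoint in this repo from bfloat16 to float32,
export it as a PyTorch checkpoint, check that both frameworks produce
the same logits, and save the tokenizer alongside the weights."""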
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM, AutoTokenizer
import torch
import numpy as np
import jax
import jax.numpy as jnp

# The working directory is the model repo itself; it holds the Flax
# checkpoint (flax_model.msgpack) and the tokenizer files.
MODEL_PATH = "./"

model = FlaxRobertaForMaskedLM.from_pretrained(MODEL_PATH)

def to_f32(t):
    # Cast every bfloat16 leaf of the parameter pytree to float32 so the
    # exported PyTorch checkpoint ends up with standard fp32 weights.
    # (Newer JAX versions deprecate jax.tree_map for jax.tree_util.tree_map.)
    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)


model.params = to_f32(model.params)
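
# Quick verification sketch (not strictly required): confirm no bfloat16
# leaves remain after the cast before writing the checkpoint back to disk.
assert all(
    leaf.dtype != jnp.bfloat16 for leaf in jax.tree_util.tree_leaves(model.params)
), "some parameters are still bfloat16"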

# Overwrite the Flax checkpoint with the float32 parameters, then reload
# the same directory as a PyTorch model.
model.save_pretrained(MODEL_PATH)
pt_model = RobertaForMaskedLM.from_pretrained(MODEL_PATH, from_flax=True)

# Run both models on a dummy batch (two all-zero sequences of length 128)
# and print the logits for a quick eyeball comparison.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

logits_pt = pt_model(input_ids_pt).logits
print(logits_pt)

logits_fx = model(input_ids).logits
print(logits_fx)
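
# Sanity-check sketch: the two logit tensors should agree to float32
# precision; the 1e-3 tolerance is a loose, illustrative choice.
max_diff = np.max(np.abs(np.asarray(logits_fx) - logits_pt.detach().numpy()))
print(f"max abs diff between Flax and PyTorch logits: {max_diff:.2e}")
assert max_diff < 1e-3, "Flax and PyTorch logits diverge"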

# Save the PyTorch weights next to the Flax checkpoint, and also save
# the tokenizer so the directory is complete for both frameworks.
pt_model.save_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.save_pretrained(MODEL_PATH)
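
# Usage sketch (assumes the converted files are in this directory): the
# checkpoint can now be loaded directly from PyTorch, e.g.
#   from transformers import RobertaForMaskedLM, AutoTokenizer
#   model = RobertaForMaskedLM.from_pretrained("./")
#   tokenizer = AutoTokenizer.from_pretrained("./")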