import gradio as gr
import os
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_gpt2 import GPT, GPTConfig
from huggingface_hub import hf_hub_download
import joblib

REPO_ID = "sayanbanerjee32/nanogpt2_test"
model_file = "saved_model/ckpt.pt"

# download the checkpoint from the HF Hub and load it on CPU
model_spec = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=model_file),
    map_location=torch.device('cpu'))
model_args = model_spec['model_args']
model_weights = model_spec['model']

# fix the keys of the state dictionary: checkpoints saved from a torch.compile'd
# model carry an '_orig_mod.' prefix that the plain GPT module does not expect
unwanted_prefix = '_orig_mod.'
for k, v in list(model_weights.items()):
    if k.startswith(unwanted_prefix):
        model_weights[k[len(unwanted_prefix):]] = model_weights.pop(k)

#### build the model and load the weights
modelconf = GPTConfig(**model_args)
trained_model = GPT(modelconf)
trained_model.load_state_dict(model_weights)

# GPT-2 BPE encoder/decoder from tiktoken
enc = tiktoken.get_encoding("gpt2")


def generate_text(seed_text, max_new_tokens, temperature, top_k=None):
    # fall back to a single space so the prompt is never empty,
    # and make sure the prompt ends with a space before encoding
    text = seed_text if seed_text is not None else " "
    text = text if text.endswith(" ") else text + " "
    context = torch.tensor(enc.encode(text), dtype=torch.long).unsqueeze(0)
    # temperature must be strictly positive; guard against a 0 from the UI
    temperature = temperature if temperature > 0 else 1e-5
    # a top_k of 0 (or negative) means "no top-k filtering"
    top_k = top_k if top_k is None or top_k > 0 else None
    return enc.decode(trained_model.generate(context,
                                             temperature=temperature,
                                             top_k=top_k,
                                             max_new_tokens=max_new_tokens)[0].tolist())


with gr.Blocks() as demo:
    gr.HTML("