# Captured from a Hugging Face Space page (Space status at capture time: Runtime error).
import html
import json
import pickle
import random
import urllib
import urllib.parse

import gradio as gr
import torch

from model import GPT, GPTConfig
from utils import sample
device = torch.device('cpu')

# Create the model. These hyper-parameters must match the ones the checkpoint
# below was trained with, otherwise load_state_dict will fail.
vocab_size = 147
block_size = 128
mconf = GPTConfig(vocab_size, block_size,
                  n_layer=6, n_head=8, n_embd=256)
model = GPT(mconf)

# Load checkpoint weights, mapping stored tensors onto `device`.
model.load_state_dict(torch.load('another_epoch_1.75total.ckpt', map_location=device))
# Keep parameters on the same device the inputs are sent to in generate_song.
model.to(device)
# Inference only: switch modules (e.g. any dropout layers) to evaluation mode.
# Modules are constructed in training mode by default.
model.eval()

# Vocab lookup tables. stoi is indexed with the byte values of the prompt.
# itos maps token ids back to character codes (passed to chr()).
# NOTE: pickle is acceptable here only because these files ship with the app;
# never unpickle data from untrusted sources.
with open('stoi.pkl', 'rb') as stoi_file:
    stoi = pickle.load(stoi_file)
with open('itos.pkl', 'rb') as itos_file:
    itos = pickle.load(itos_file)
| # Generate function | |
| def generate_song(randomize, title, nu, ks, key): | |
| # Start sequence | |
| context = b"""T:""" | |
| if not randomize: | |
| context += bytes(title+'\n', 'utf-8') | |
| context += bytes('M:'+ks+'\n', 'utf-8') | |
| context += bytes('K:'+key+'\n', 'utf-8') | |
| context += bytes('L:'+nu+'\n', 'utf-8') | |
| # Model inputs | |
| x = torch.tensor([stoi[s] for s in context], dtype=torch.long)[None,...].to(device) | |
| # Completion | |
| y = sample(model, x, 400, temperature=1.0, sample=True, top_k=10)[0] | |
| completion = ''.join([chr(itos[int(i)]) for i in y]) | |
| # Return the first song | |
| song = completion_to_song(completion) | |
| html_song = song.replace('\n', '<br>') | |
| url_song = urllib.parse.quote(song, safe='~@#$&()*!+=:;,?/\'') | |
| html_text = '<p><a href="https://editor.drawthedots.com?t='+url_song+'" target="_blank"><b>EDIT LINK - click to open abcjs editor (allows download and playback)</b></a></p>'+"<p>"+html_song+'</p>' | |
| return html_text | |
# Gradio demo
# NOTE(review): the Row/Column nesting below is reconstructed (the captured
# source lost its indentation); verify the layout visually.
demo = gr.Blocks()
with demo:
    gr.Markdown("Quick demo for [WhistleGen v2](https://wandb.ai/johnowhitaker/whistlegen_v2/reports/WhistleGen-v2--VmlldzoyMTAwNjAz) which lets you generate folk music using a transformer model. I can't get the javascript needed for rendering and playback working with gradio, so this shows the raw ABC notation from the model and a link to view it properly in an external editor.")
    with gr.Row():
        title = gr.Text(label='Title', value='The Song of AI')
        with gr.Column():
            nu = gr.Text(label='Note unit', value='1/8')
    with gr.Row():
        # Despite its name, this widget holds the *time* signature (ABC `M:`).
        key_signature = gr.Dropdown(['3/4', '4/4', '6/8', 'Random'], value='4/4', label='Time Signature')
        with gr.Column():
            key = gr.Text(label='Key', value='D')
    with gr.Row():
        randomize = gr.Checkbox(label='Randomize (ignores settings above)', value=True)
    with gr.Row():
        out = gr.HTML(label="Output", value='Output...')
    btn = gr.Button("Run")
    # Input order must match generate_song(randomize, title, nu, ks, key).
    btn.click(fn=generate_song, inputs=[randomize, title, nu, key_signature, key], outputs=out)
    with gr.Row():
        gr.Markdown("")
        gr.Markdown("This is currently using an early model. See the [report](https://wandb.ai/johnowhitaker/whistlegen_v2/reports/WhistleGen-v2--VmlldzoyMTAwNjAz) for training info and updates.")

# `launch(enable_queue=True)` is deprecated in Gradio 3.x and removed in 4.x.
# Enabling the queue on the Blocks object is the supported equivalent.
demo.queue()
demo.launch()