Spaces: Running on Zero
Commit · ee2e0b7
1 Parent(s): 6845a5c
Update app.py
app.py CHANGED
@@ -7,25 +7,24 @@ from mario_gpt.prompter import Prompter
 from mario_gpt.lm import MarioLM
 from mario_gpt.utils import view_level, convert_level_to_png
 
-import
-import
+from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
 
-
+import os
+import uvicorn
 
 mario_lm = MarioLM()
 device = torch.device('cuda')
 mario_lm = mario_lm.to(device)
 TILE_DIR = "data/tiles"
 
-
-ngrok.set_auth_token(os.environ.get('NGROK_TOKEN'))
-http_tunnel = ngrok.connect(7861,bind_tls=True)
+app = FastAPI()
 
 def make_html_file(generated_level):
     level_text = f"""{'''
 '''.join(view_level(generated_level,mario_lm.tokenizer))}"""
     unique_id = uuid.uuid1()
-    with open(f"demo-{unique_id}.html", 'w', encoding='utf-8') as f:
+    with open(f"static/demo-{unique_id}.html", 'w', encoding='utf-8') as f:
         f.write(f'''<!DOCTYPE html>
 <html lang="en">
@@ -42,7 +41,7 @@ def make_html_file(generated_level):
         cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
     }});
     cheerpjCreateDisplay(512, 500);
-    cheerpjRunJar("/app/mario.jar");
+    cheerpjRunJar("/app/static/mario.jar");
 </script>
 </html>''')
     return f"demo-{unique_id}.html"
@@ -61,9 +60,9 @@ def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size =
     filename = make_html_file(generated_level)
     img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
 
-    gradio_html = f'''<div
-    <iframe width=512 height=512 style="margin: 0 auto" src="
-    <p style="text-align:center">Press the arrow keys to move. Press <code>s</code> to jump and <code>
+    gradio_html = f'''<div>
+    <iframe width=512 height=512 style="margin: 0 auto" src="static/{filename}"></iframe>
+    <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
 </div>'''
     return [img, gradio_html]
 
@@ -72,16 +71,16 @@ with gr.Blocks() as demo:
     [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
     ''')
     with gr.Tabs():
-        with gr.TabItem("Type prompt"):
-            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
         with gr.TabItem("Compose prompt"):
             with gr.Row():
-                pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
-                enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
+                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
+                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
             with gr.Row():
-                blocks = gr.Radio(["little", "some", "many"], label="blocks")
-                elevation = gr.Radio(["low", "high"], label="
-
+                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
+                elevation = gr.Radio(["low", "high"], label="Elevation?")
+        with gr.TabItem("Type prompt"):
+            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
+
     with gr.Accordion(label="Advanced settings", open=False):
         temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
         level_size = gr.Number(value=1399, precision=0, label="level_size")
@@ -104,4 +103,7 @@ with gr.Blocks() as demo:
         fn=generate,
        cache_examples=True,
     )
-
+
+app.mount("/static", StaticFiles(directory="static", html=True), name="static")
+app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")
+uvicorn.run(app, host="0.0.0.0", port=7860)
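In short, this commit drops the ngrok tunnel and serves everything from a single FastAPI app: generated level pages are written into a static/ directory exposed at /static, the Gradio UI is mounted at the root, and uvicorn runs the combined server on port 7860. Below is a minimal sketch of that serving pattern, assuming gradio, fastapi, and uvicorn are installed and a ./static directory already exists; the placeholder Blocks UI and the main-guard are illustrative, not the Space's actual code.

import gradio as gr
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

# Host application; the Gradio UI is mounted onto it instead of being launched directly.
app = FastAPI()

with gr.Blocks() as demo:
    # Placeholder UI standing in for the MarioGPT prompt tabs (illustrative only).
    gr.Markdown("MarioGPT demo placeholder")

# Anything written into ./static (e.g. the generated demo-<uuid>.html pages)
# becomes reachable at /static/<filename>, which is what the returned iframe points at.
app.mount("/static", StaticFiles(directory="static", html=True), name="static")

# Mount the Gradio Blocks app at the root path of the same FastAPI server.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # One uvicorn process on port 7860 replaces the earlier ngrok tunnel.
    uvicorn.run(app, host="0.0.0.0", port=7860)

Registration order matters here: routes are matched in the order they are added, so /static is mounted before the root-mounted Gradio app, exactly as in the commit, otherwise the root mount would shadow the static files.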