import base64
import os
import pathlib
import random
import re
import time
from io import BytesIO

from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import gradio as gr
import imgkit
from PIL import Image
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, pipeline

gpu = False

AUTH_TOKEN = os.environ.get('AUTH_TOKEN')
BASE_MODEL = "gpt2"
MERGED_MODEL = "gpt2-magic-card"

# The GPU and CPU paths differ only in the half-precision load options and the
# final move to CUDA; everything else is shared.
_pipeline_kwargs = {"use_auth_token": AUTH_TOKEN}
if gpu:
    _pipeline_kwargs.update(torch_dtype=torch.float16, revision="fp16")

image_pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **_pipeline_kwargs)
scheduler = EulerAncestralDiscreteScheduler.from_config(image_pipeline.scheduler.config)
image_pipeline.scheduler = scheduler
if gpu:
    image_pipeline.to("cuda")

# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources

model = GPT2LMHeadModel.from_pretrained(MERGED_MODEL)
tokenizer = GPT2TokenizerFast.from_pretrained(BASE_MODEL)
END_TOKEN = '###'
eos_id = tokenizer.encode(END_TOKEN)
text_pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)

# Captures the rest of the line that follows "Name: " in generated card text.
_NAME_PATTERN = re.compile(r'Name: (.*)')


def gen_card_text(name):
    """Generate the text of a card with the module-level text pipeline.

    Args:
        name: Card name used to seed the prompt. When it is the empty string,
            the prompt is seeded with a single random uppercase letter and
            the name is read back from the generated text.

    Returns:
        A ``(name, result)`` tuple. ``result`` is the generated text up to the
        first ``END_TOKEN``, with carriage returns removed. ``name`` is the
        argument unchanged, or the first ``Name: `` capture from ``result``
        when the argument was empty.

    Raises:
        IndexError: If ``name`` was empty and ``result`` has no ``Name: `` line.
    """
    if name == '':
        # No trailing newline: the model is left to complete the name itself.
        prompt = f"Name: {random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}"
    else:
        prompt = f"Name: {name}\n"

    print(f'GENERATING CARD TEXT with prompt: {prompt}')
    output = text_pipeline(
        prompt,
        max_length=512,
        num_return_sequences=1,
        num_beams=5,
        temperature=1.5,
        do_sample=True,
        repetition_penalty=1.2,
        eos_token_id=eos_id,
    )

    generated = output[0]['generated_text'].split(END_TOKEN)[0]
    # The raw-string replacements target literal backslash sequences in the
    # generated text; the plain '\r' replacement removes real carriage returns.
    result = generated.replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
    print('GENERATING CARD COMPLETE')
    print(result)

    if name == '':
        name = _NAME_PATTERN.findall(result)[0]
    return name, result
# One output directory per artifact kind; created up front at import time.
for _output_dir in ('card_data', 'card_images', 'card_html', 'rendered_cards'):
    pathlib.Path(_output_dir).mkdir(parents=True, exist_ok=True)


def run(name):
    """Generate a card end to end: text, illustration, HTML and rendering.

    Args:
        name: Card name passed to ``gen_card_text``; that function may return
            a different name, which is then used for every saved file.

    Returns:
        A ``(rendered, text, card_image, html)`` tuple: the value returned by
        ``html_to_png``, the generated card text, the last image produced by
        the image pipeline, and the formatted HTML string.

    Raises:
        IndexError: If the generated text has no ``Type: `` line.
        RuntimeError: If the image pipeline returns no images.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments during a long run.
    start = time.perf_counter()
    print(f'BEGINNING RUN FOR {name}')

    name, text = gen_card_text(name)
    save_name = get_savename('card_data', name, 'txt')
    pathlib.Path(f'card_data/{save_name}').write_text(text, encoding='utf-8')

    card_type = re.findall('Type: (.*)', text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")

    # Regarding sizing see https://huggingface.co/blog/stable_diffusion#:~:text=When%20choosing%20image%20sizes%2C%20we%20advise%20the%20following%3A
    images = image_pipeline(prompt_template, width=512, height=368, num_inference_steps=20).images

    card_image = None
    for image in images:
        save_name = get_savename('card_images', name, 'png')
        image.save(f"card_images/{save_name}")
        card_image = image
    if card_image is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise RuntimeError('image pipeline returned no images')

    image_data = pil_to_base64(card_image)
    html = format_html(text, image_data)
    save_name = get_savename('card_html', name, 'html')
    pathlib.Path(f'card_html/{save_name}').write_text(html, encoding='utf-8')

    rendered = html_to_png(name, html)

    elapsed = time.perf_counter() - start
    print(f'RUN COMPLETED IN {int(elapsed)} seconds')
    return rendered, text, card_image, html


def pil_to_base64(image):
    """Return ``image`` encoded as PNG and then base64, as ``bytes``.

    Args:
        image: Object with a ``save(fp, format=...)`` method, such as a PIL
            image.

    Returns:
        The base64-encoded PNG data as a ``bytes`` object (not ``str``).
    """
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return img_str


def format_html(text, image_data):
    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')
    if "['U']" in text:
        template = template.replace("{card_color}", 'style="background-color:#5a73ab"')
    elif "['W']" in text:
        template = template.replace("{card_color}", 'style="background-color:#f0e3d0"')
    # NOTE: this if/elif chain and the rest of format_html continue on the
    # source lines that follow.
elif "['G']" in text: template = template.replace("{card_color}", 'style="background-color:#325433"') elif "['B']" in text: template = template.replace("{card_color}", 'style="background-color:#1a1b1e"') elif "['R']" in text: template = template.replace("{card_color}", 'style="background-color:#c2401c"') elif "Type: Land" in text: template = template.replace("{card_color}", 'style="background-color:#aa8c71"') elif "Type: Artifact" in text: template = template.replace("{card_color}", 'style="background-color:#9ba7bc"') else: template = template.replace("{card_color}", 'style="background-color:#edd99d"') pattern = re.compile('Name: (.*)') name = pattern.findall(text)[0] template = template.replace("{name}", name) pattern = re.compile('ManaCost: (.*)') mana_cost = pattern.findall(text)[0] if mana_cost == "None": template = template.replace("{mana_cost}", '') else: symbols = [] for c in mana_cost: if c in {"{", "}"}: continue else: symbols.append(c.lower()) formatted_symbols = [] for s in symbols: formatted_symbols.append(f'') template = template.replace("{mana_cost}", "\n".join(formatted_symbols[::-1])) if not isinstance(image_data, (bytes, bytearray)): template = template.replace('{image_data}', f'{image_data}') else: template = template.replace('{image_data}', f'data:image/png;base64,{image_data.decode("utf-8")}') pattern = re.compile('Type: (.*)') card_type = pattern.findall(text)[0] template = template.replace("{card_type}", card_type) if len(card_type) > 30: template = template.replace("{type_size}", "16") else: template = template.replace("{type_size}", "18") pattern = re.compile('Rarity: (.*)') rarity = pattern.findall(text)[0] template = template.replace("{rarity}", f"ss-{rarity}") pattern = re.compile('Text: (.*)\nFlavorText', re.MULTILINE | re.DOTALL) card_text = pattern.findall(text)[0] text_lines = [] for line in card_text.splitlines(): line = line.replace('{T}', '') line = line.replace('{UT}', '') line = line.replace('{E}', '') line = re.sub(r"{(.*?)}", 
r''.lower(), line) line = re.sub(r"ms-(.)/(.)", r''.lower(), line) line = line.replace('(', '(').replace(')', ')') text_lines.append(f"
{line}
") template = template.replace("{card_text}", "\n".join(text_lines)) pattern = re.compile('FlavorText: (.*)\nPower', re.MULTILINE | re.DOTALL) flavor_text = pattern.findall(text) if flavor_text: flavor_text = flavor_text[0] flavor_text_lines = [] for line in flavor_text.splitlines(): flavor_text_lines.append(f"{line}
") template = template.replace("{flavor_text}", "" + "\n".join(flavor_text_lines) + "") else: template = template.replace("{flavor_text}", "") if len(card_text) + len(flavor_text or '') > 170 or len(text_lines) > 3: template = template.replace("{text_size}", '16') template = template.replace('ms-cost" style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;">', 'ms-cost" style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;">') else: template = template.replace("{text_size}", '18') pattern = re.compile('Power: (.*)') power = pattern.findall(text) if power: power = power[0] if not power: template = template.replace("{power_toughness}", "") pattern = re.compile('Toughness: (.*)') toughness = pattern.findall(text)[0] template = template.replace("{power_toughness}", f'