"""Gradio demo: extract receipt/invoice fields as JSON with a fine-tuned PaliGemma model."""

import re

import gradio as gr
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "mychen76/paligemma-receipt-json-3b-mix-448-v2b"
dtype = torch.bfloat16
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

MAX_TOKENS = 512


# let's turn that into JSON (source: Donut)
def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.

    The sequence is expected to use Donut-style markup: ``<s_key>`` opens a
    field, ``</s_key>`` closes it, and ``<sep/>`` separates sibling values.

    Args:
        tokens: The decoded text to parse.
        is_inner_value: True when parsing the content of an enclosing field;
            the result is then always a list.
        added_vocab: Tokens added to the tokenizer vocabulary. Looked up from
            the module-level ``processor`` when omitted.

    Returns:
        A dict of parsed fields (a list of dicts for inner values). When no
        field is found, ``{"text_sequence": tokens}`` at the top level or
        ``[]`` for an inner value.
    """
    if added_vocab is None:
        added_vocab = processor.tokenizer.get_added_vocab()

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            # Unclosed field: drop the dangling opening tag and keep scanning.
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}


def modify_caption(caption: str) -> str:
    """
    Replace the first occurrence of a known prompt prefix in a caption.

    Matching is case-insensitive and not anchored to the start of the string.

    Args:
        caption (str): A string containing a caption.

    Returns:
        str: The caption with the first matched prefix replaced, or the
        caption unchanged when no prefix is present.
    """
    # Prefixes to remove, paired with their replacement text
    prefix_substrings = [
        ('EXTRACT_JSON_RECEIPT', '')
    ]

    # Regex alternation over all prefixes
    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
    # Keyed by lowercase so a case-insensitive match always finds its entry
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)


def json_inference(image, input_text="EXTRACT_JSON_RECEIPT", device="cuda:0", max_new_tokens=512):
    """
    Run the model on one image and parse the generated text into JSON.

    Args:
        image: The input image, passed to the processor as ``images``.
        input_text: The text prompt given to the model.
        device: Device the processor outputs are moved to before generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A ``(generated_text, generated_json)`` tuple: the decoded continuation
        and the result of ``token2json`` applied to it.
    """
    inputs = processor(text=input_text, images=image, return_tensors="pt").to(device)

    # The returned sequence starts with the prompt (image tokens plus text
    # prompt); remember its length so only the continuation is decoded.
    num_prompt_tokens = inputs["input_ids"].shape[-1]

    # No sampling arguments are passed, so decoding follows the model's
    # generation config. For other strategies see
    # https://huggingface.co/blog/how-to-generate
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Turn the predicted token IDs back into a string, without the prompt
    generated_text = processor.batch_decode(
        generated_ids[:, num_prompt_tokens:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Convert it into JSON using token2json (taken from Donut)
    generated_json = token2json(generated_text)
    return generated_text, generated_json


# enable space
# @spaces.GPU
def create_captions_rich(image):
    """
    Gradio callback: extract the receipt JSON from an uploaded image.

    Args:
        image: The image supplied by the ``gr.Image`` input.

    Returns:
        The parsed JSON structure produced by ``json_inference``.
    """
    torch.cuda.empty_cache()
    prompt = "EXTRACT_JSON_RECEIPT"
    _generated_text, generated_json = json_inference(
        image=image, input_text=prompt, device=device, max_new_tokens=MAX_TOKENS
    )
    return generated_json


css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
    # NOTE(review): the heading markup is a reconstruction; confirm the tags.
    gr.HTML("<h1><center>PaliGemma Receipt and Invoice Model</center></h1>")
    with gr.Tab(label="Receipt or Invoices Image"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Receipt Json")

        gr.Examples(
            [["receipt_image1.jpg"], ["receipt_image2.jpg"], ["receipt_image3.png"], ["receipt_image4.png"]],
            inputs=[input_img],
            outputs=[output],
            fn=create_captions_rich,
            label='Try captioning on examples',
        )

        submit_btn.click(create_captions_rich, [input_img], [output])

demo.queue().launch(share=True, server_name="0.0.0.0", debug=True)