import streamlit as st
import torch
from PIL import Image
import gc
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from torch.utils.data import DataLoader
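
# Pipeline overview (two stages, both on CPU in float32):
#   1. ColPali embeds the uploaded page image (used here as a stand-in for an OCR step).
#   2. Qwen2-VL answers a free-form question about the image plus the extracted/placeholder text.
# Each model is released between stages to keep peak memory low.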

# Load the ColPali retrieval model and its processor (cached across Streamlit reruns)
@st.cache_resource
def load_colpali_model():
    model = ColPali.from_pretrained("vidore/colpaligemma-3b-mix-448-base", torch_dtype=torch.float32, device_map="cpu").eval()
    model.load_adapter("vidore/colpali")
    processor = AutoProcessor.from_pretrained("vidore/colpali")
    return model, processor

# Load the Qwen2-VL vision-language model and its processor (cached across Streamlit reruns)
@st.cache_resource
def load_qwen_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float32, device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    return model, processor

# Free Python objects and, if a GPU is present, its cached memory
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Streamlit Interface
st.title("OCR and Visual Language Model Demo")
st.write("Upload an image for OCR extraction and then ask a question about the image.")

# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image:
    img = Image.open(image).convert("RGB")  # normalize to RGB (e.g. drop alpha channel) for both models
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # Embedding step with ColPali (a retrieval model, standing in for OCR in this demo)
    st.write("Extracting text from image...")
    colpali_model, colpali_processor = load_colpali_model()
    
    # Process image for Colpali
    dataloader = DataLoader(
        [img],
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: process_images(colpali_processor, x),
    )
    
    for batch_doc in dataloader:
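        # The forward pass yields ColBERT-style multi-vector embeddings (one vector per image patch)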
        with torch.no_grad():
            batch_doc = {k: v.to('cpu') for k, v in batch_doc.items()}
            embeddings_doc = colpali_model(**batch_doc)
    
    # ColPali scores text queries against image embeddings; use a generic query for the demo
    dummy_query = "Extract all text from the image"
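    # process_queries expects a placeholder image alongside each text query,
    # hence the blank 448x448 canvas passed in below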
    query_dataloader = DataLoader(
        [dummy_query],
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: process_queries(colpali_processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
    )
    
    for batch_query in query_dataloader:
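        # Query embeddings: one vector per query token, comparable against the image patch vectors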
        with torch.no_grad():
            batch_query = {k: v.to('cpu') for k, v in batch_query.items()}
            embeddings_query = colpali_model(**batch_query)
    
    # ColPali embeddings support retrieval/similarity scoring rather than transcription,
    # so a real pipeline would pair them with an OCR model; this demo shows a placeholder.
    extracted_text = (
        "This is a placeholder for the extracted text. In a real pipeline, an OCR model would "
        "run here; the ColPali embeddings are used for retrieval and scoring, not transcription."
    )
    
    st.write("Extracted Text:")
    st.write(extracted_text)
    
    # Drop local references to the ColPali model; note that st.cache_resource still
    # holds the cached instance, so this mainly frees per-run tensors
    del colpali_model, colpali_processor
    clear_memory()

    # Text input field for question
    question = st.text_input("Ask a question about the image and extracted text")

    if question:
        st.write("Processing with Qwen2-VL...")
        qwen_model, qwen_processor = load_qwen_model()

        # Prepare inputs for Qwen2-VL
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": f"Extracted text: {extracted_text}\n\nQuestion: {question}"},
                ],
            }
        ]

        # Render the chat template into a prompt string and gather the image inputs
        text_input = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = qwen_processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")
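        # inputs bundles the tokenized prompt (input_ids, attention_mask) with the preprocessed image tensors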

        # Move tensors to CPU
        inputs = inputs.to("cpu")

        # Run the model and generate output
        with torch.no_grad():
            generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)

        # Trim the prompt tokens so only the newly generated answer is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        generated_text = qwen_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Display the response
        st.write("Model's response:", generated_text)

        # Drop local references to the Qwen2-VL model (st.cache_resource still holds the cached instance)
        del qwen_model, qwen_processor
        clear_memory()