update gradio demo

- .streamlit/config.toml +0 -7
- README.md +2 -2
- app.py +115 -59
- {static → assets}/SimHei.ttf +0 -0
- assets/assistant.png +0 -0
- assets/human.png +0 -0
- controller.py +3 -1
- conversation.py +259 -0
- gallery/child_1.jpg +0 -0
- gallery/child_2.jpg +0 -0
- gallery/child_3.jpg +0 -0
- gradio_web_server.py +824 -0
- library.py +0 -95
- mm_utils.py +0 -102
- model_worker.py +283 -140
- requirements.txt +14 -4
- utils.py +63 -24
    	
.streamlit/config.toml
DELETED

@@ -1,7 +0,0 @@
-[server]
-enableStaticServing = false
-enableXsrfProtection = false
-enableCORS = false
-
-[browser] # This ip and port will show in command prompt
-enableCORS = false
    	
README.md
CHANGED

@@ -3,8 +3,8 @@ title: InternVL
 emoji: ⚡
 colorFrom: yellow
 colorTo: gray
-sdk: 
-sdk_version: 
+sdk: gradio
+sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 license: mit
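With those two lines swapped, the Space's front matter should end up as follows (context lines taken from the unchanged portion of the hunk; the old sdk values were lost in extraction, so this is the assumed end state only):

    title: InternVL
    emoji: ⚡
    colorFrom: yellow
    colorTo: gray
    sdk: gradio
    sdk_version: 4.36.1
    app_file: app.py
    pinned: false
    license: mit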
    	
app.py
CHANGED

@@ -1,60 +1,116 @@
-import 
-
-
-
-
-
-
-    header {visibility: hidden;}
-</style>
-"""
-
-st.markdown(hide_streamlit_style, unsafe_allow_html=True)
-
-st.markdown(
-    """
-    <style>
-    html, body, .fullScreenFrame, .fullScreenFrame iframe {
-        margin: 0;
-        padding: 0;
-        height: 100%;
-        width: 100%;
-        border: none;
-        display: block;
-        overflow: hidden;
-    }
-
-    .fullScreenFrame {
-        position: fixed;
-        top: 0;
-        left: 0;
-        right: 0;
-        bottom: 0;
-        z-index: 9999;
-    }
-
-    .main .block-container {
-        padding: 0;
-        margin: 0;
-        height: 100vh;
-    }
-
-    /* Hide Streamlit header and footer */
-    header, footer {
-        display: none;
-    }
-    </style>
-    """,
-    unsafe_allow_html=True,
-)
-
-# Embed the external Streamlit webpage
-st.markdown(
-    """
-    <div class="fullScreenFrame">
-        <iframe src="https://internvl.opengvlab.com/"></iframe>
-    </div>
-    """,
-    unsafe_allow_html=True,
-)
+import fire
+import subprocess
+import os
+import time
+import signal
+import atexit
+
+
+def kill_processes_by_cmd_substring(cmd_substring):
+    # execute `ps -ef` and obtain its output
+    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
+    lines = result.stdout.splitlines()
+
+    # visit each line
+    for line in lines:
+        if cmd_substring in line:
+            # extract PID
+            parts = line.split()
+            pid = int(parts[1])
+            print(f"Killing process with PID: {pid}, CMD: {line}")
+            os.kill(pid, signal.SIGTERM)
+
+
+def main(
+    python_path="python",
+    run_controller=True,
+    run_worker=True,
+    run_gradio=True,
+    controller_port=10086,
+    gradio_port=10087,
+    worker_names=[
+        "OpenGVLab/InternVL2-8B",
+    ],
+    run_sd_worker=False,
+    **kwargs,
+):
+    host = "http://0.0.0.0"
+    controller_process = None
+    if run_controller:
+        # python controller.py --host 0.0.0.0 --port 10086
+        cmd_args = [
+            f"{python_path}",
+            "controller.py",
+            "--host",
+            "0.0.0.0",
+            "--port",
+            f"{controller_port}",
+        ]
+        kill_processes_by_cmd_substring(" ".join(cmd_args))
+        print("Launching controller: ", " ".join(cmd_args))
+        controller_process = subprocess.Popen(cmd_args)
+        atexit.register(controller_process.terminate)
+
+    worker_processes = []
+    if run_worker:
+        worker_port = 10088
+        for worker_name in worker_names:
+            cmd_args = [
+                f"{python_path}",
+                "model_worker.py",
+                "--port",
+                f"{worker_port}",
+                "--controller-url",
+                f"{host}:{controller_port}",
+                "--model-path",
+                f"{worker_name}",
+                "--load-8bit",
+            ]
+            kill_processes_by_cmd_substring(" ".join(cmd_args))
+            print("Launching worker: ", " ".join(cmd_args))
+            worker_process = subprocess.Popen(cmd_args)
+            worker_processes.append(worker_process)
+            atexit.register(worker_process.terminate)
+            worker_port += 1
+
+    time.sleep(10)
+    gradio_process = None
+    if run_gradio:
+        # python gradio_web_server.py --port 10087 --controller-url http://0.0.0.0:10086
+        cmd_args = [
+            f"{python_path}",
+            "gradio_web_server.py",
+            "--port",
+            f"{gradio_port}",
+            "--controller-url",
+            f"{host}:{controller_port}",
+            "--model-list-mode",
+            "reload",
+        ]
+        kill_processes_by_cmd_substring(" ".join(cmd_args))
+        print("Launching gradio: ", " ".join(cmd_args))
+        gradio_process = subprocess.Popen(cmd_args)
+        atexit.register(gradio_process.terminate)
+
+    sd_worker_process = None
+    if run_sd_worker:
+        # python model_worker.py --port 10088 --controller-address http://
+        cmd_args = [f"{python_path}", "sd_worker.py"]
+        kill_processes_by_cmd_substring(" ".join(cmd_args))
+        print("Launching sd_worker: ", " ".join(cmd_args))
+        sd_worker_process = subprocess.Popen(cmd_args)
+        atexit.register(sd_worker_process.terminate)
+
+    for worker_process in worker_processes:
+        worker_process.wait()
+    if controller_process:
+        controller_process.wait()
+    if gradio_process:
+        gradio_process.wait()
+    if sd_worker_process:
+        sd_worker_process.wait()
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
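Because the entry point is fire.Fire(main), every keyword argument of main doubles as a command-line flag. A plausible local invocation (all values are just the defaults shown above; not taken from the commit itself):

    # launch controller, one 8-bit worker, and the gradio frontend
    python app.py --controller_port 10086 --gradio_port 10087 \
        --worker_names '["OpenGVLab/InternVL2-8B"]'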
    	
{static → assets}/SimHei.ttf
RENAMED

File without changes

assets/assistant.png
ADDED

(binary file)

assets/human.png
ADDED

(binary file)
    	
controller.py
CHANGED

@@ -5,9 +5,9 @@ It sends worker addresses to clients.
 import argparse
 import dataclasses
 import json
+import re
 import threading
 import time
-import re
 from enum import Enum, auto
 from typing import List

@@ -113,6 +113,8 @@ class Controller:
             model_names.update(w_info.model_names)

         def extract_key(s):
+            if 'Pro' in s:
+                return 999
             match = re.match(r'InternVL2-(\d+)B', s)
             if match:
                 return int(match.group(1))
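The new branch makes any model name containing "Pro" sort ahead of the numbered variants when the controller orders its model list. A minimal sketch of the effect (the reverse-sorted call and the fallback return are assumptions; the hunk only shows the top of extract_key):

    import re

    def extract_key(s):
        if 'Pro' in s:
            return 999
        match = re.match(r'InternVL2-(\d+)B', s)
        if match:
            return int(match.group(1))
        return 0  # assumed fallback for non-matching names

    names = ['InternVL2-8B', 'InternVL2-Pro', 'InternVL2-26B']
    print(sorted(names, key=extract_key, reverse=True))
    # ['InternVL2-Pro', 'InternVL2-26B', 'InternVL2-8B']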
    	
conversation.py
ADDED

@@ -0,0 +1,259 @@
+import os
+import dataclasses
+import base64
+import copy
+import hashlib
+import datetime
+from io import BytesIO
+from PIL import Image
+from typing import Any, List, Dict, Union
+from dataclasses import field
+
+from utils import LOGDIR
+
+
+def pil2base64(img: Image.Image) -> str:
+    buffered = BytesIO()
+    img.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode()
+
+
+def resize_img(img: Image.Image, max_len: int, min_len: int) -> Image.Image:
+    max_hw, min_hw = max(img.size), min(img.size)
+    aspect_ratio = max_hw / min_hw
+    # max_len, min_len = 800, 400
+    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+    longest_edge = int(shortest_edge * aspect_ratio)
+    W, H = img.size
+    if H > W:
+        H, W = longest_edge, shortest_edge
+    else:
+        H, W = shortest_edge, longest_edge
+    return img.resize((W, H))
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+    roles: List[str] = field(
+        default_factory=lambda: [
+            Conversation.SYSTEM,
+            Conversation.USER,
+            Conversation.ASSISTANT,
+        ]
+    )
+    # annotated so dataclasses treats it as a field; copy() passes it to __init__
+    mandatory_system_message: str = "我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+    system_message: str = "请尽可能详细地回答用户的问题。"
+    messages: List[Dict[str, Any]] = field(default_factory=lambda: [])
+    max_image_limit: int = 4
+    skip_next: bool = False
+    streaming_placeholder: str = "▌"
+
+    def get_system_message(self):
+        return self.mandatory_system_message + "\n\n" + self.system_message
+
+    def set_system_message(self, system_message: str):
+        self.system_message = system_message
+        return self
+
+    def get_prompt(self, include_image=False):
+        send_messages = [{"role": "system", "content": self.get_system_message()}]
+        for message in self.messages:
+            if message["role"] == self.USER:
+                user_message = {
+                    "role": self.USER,
+                    "content": message["content"],
+                }
+                if include_image and "image" in message:
+                    user_message["image"] = []
+                    for image in message["image"]:
+                        user_message["image"].append(pil2base64(image))
+                send_messages.append(user_message)
+            elif message["role"] == self.ASSISTANT:
+                send_messages.append(
+                    {"role": self.ASSISTANT, "content": message["content"]}
+                )
+            elif message["role"] == self.SYSTEM:
+                send_messages.append(
+                    {
+                        "role": self.SYSTEM,
+                        "content": message["content"],
+                    }
+                )
+            else:
+                raise ValueError(f"Invalid role: {message['role']}")
+        return send_messages
+
+    def append_message(
+        self,
+        role,
+        content,
+        image_list=None,
+    ):
+        self.messages.append(
+            {
+                "role": role,
+                "content": content,
+                "image": [] if image_list is None else image_list,
+            }
+        )
+
+    def get_images(
+        self,
+        return_copy=False,
+        return_base64=False,
+        source: Union[str, None] = None,
+    ):
+        assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {source}"
+        images = []
+        for i, msg in enumerate(self.messages):
+            if source and msg["role"] != source:
+                continue
+
+            for image in msg.get("image", []):
+                if return_copy:
+                    image = image.copy()
+
+                if return_base64:
+                    image = pil2base64(image)
+
+                images.append(image)
+
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, msg in enumerate(self.messages):
+            if msg["role"] == self.SYSTEM:
+                continue
+
+            alt_str = (
+                "user upload image" if msg["role"] == self.USER else "output image"
+            )
+            image = msg.get("image", [])
+            if not isinstance(image, list):
+                images = [image]
+            else:
+                images = image
+
+            img_str_list = []
+            for i in range(len(images)):
+                image = resize_img(
+                    images[i],
+                    400,
+                    800,
+                )
+                img_b64_str = pil2base64(image)
+                img_str = (
+                    f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" />'
+                )
+                img_str_list.append(img_str)
+
+            if msg["role"] == self.USER:
+                msg_str = " ".join(img_str_list) + msg["content"]
+                ret.append([msg_str, None])
+            else:
+                msg_str = msg["content"] + " ".join(img_str_list)
+                ret[-1][-1] = msg_str
+        return ret
+
+    def update_message(self, role, content, image=None, idx=-1):
+        assert len(self.messages) > 0, "No message in the conversation."
+
+        idx = (idx + len(self.messages)) % len(self.messages)
+
+        assert (
+            self.messages[idx]["role"] == role
+        ), f"Role mismatch: {role} vs {self.messages[idx]['role']}"
+
+        self.messages[idx]["content"] = content
+        if image is not None:
+            if image not in self.messages[idx]["image"]:
+                self.messages[idx]["image"] = []
+            if not isinstance(image, list):
+                image = [image]
+            self.messages[idx]["image"].extend(image)
+
+    def return_last_message(self):
+        return self.messages[-1]["content"]
+
+    def end_of_current_turn(self):
+        assert len(self.messages) > 0, "No message in the conversation."
+        assert (
+            self.messages[-1]["role"] == self.ASSISTANT
+        ), f"It should end with the message from assistant instead of {self.messages[-1]['role']}."
+
+        if self.messages[-1]["content"][-1] != self.streaming_placeholder:
+            return
+
+        self.update_message(self.ASSISTANT, self.messages[-1]["content"][:-1], None)
+
+    def copy(self):
+        return Conversation(
+            mandatory_system_message=self.mandatory_system_message,
+            system_message=self.system_message,
+            roles=copy.deepcopy(self.roles),
+            messages=copy.deepcopy(self.messages),
+        )
+
+    def dict(self):
+        """
+        all_images = state.get_images()
+        all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+        t = datetime.datetime.now()
+        for image, hash in zip(all_images, all_image_hash):
+            filename = os.path.join(
+                LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
+            )
+            if not os.path.isfile(filename):
+                os.makedirs(os.path.dirname(filename), exist_ok=True)
+                image.save(filename)
+        """
+        messages = []
+        for message in self.messages:
+            images = []
+            for image in message.get("image", []):
+                filename = self.save_image(image)
+                images.append(filename)
+
+            messages.append(
+                {
+                    "role": message["role"],
+                    "content": message["content"],
+                    "image": images,
+                }
+            )
+            if len(images) == 0:
+                messages[-1].pop("image")
+
+        return {
+            "mandatory_system_message": self.mandatory_system_message,
+            "system_message": self.system_message,
+            "roles": self.roles,
+            "messages": messages,
+        }
+
+    def save_image(self, image: Image.Image) -> str:
+        t = datetime.datetime.now()
+        image_hash = hashlib.md5(image.tobytes()).hexdigest()
+        filename = os.path.join(
+            LOGDIR,
+            "serve_images",
+            f"{t.year}-{t.month:02d}-{t.day:02d}",
+            f"{image_hash}.jpg",
+        )
+        if not os.path.isfile(filename):
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            image.save(filename)
+
+        return filename
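A rough sketch of how the web server presumably drives this class (the worker call is omitted; the flow below only exercises methods defined above, and the test image is made up):

    from PIL import Image
    from conversation import Conversation

    conv = Conversation()
    img = Image.new("RGB", (64, 64))  # stand-in for a user upload

    # one user turn with an image, then a streaming assistant turn
    conv.append_message(conv.USER, "Describe this image.", image_list=[img])
    conv.append_message(conv.ASSISTANT, "A plain test image" + conv.streaming_placeholder)
    conv.end_of_current_turn()  # strips the trailing "▌" placeholder

    send_messages = conv.get_prompt(include_image=True)  # role/content dicts for the worker
    chat_pairs = conv.to_gradio_chatbot()                # [[user_html, assistant_html], ...]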
    	
gallery/child_1.jpg
ADDED

(binary file)

gallery/child_2.jpg
ADDED

(binary file)

gallery/child_3.jpg
ADDED

(binary file)
    	
        gradio_web_server.py
    ADDED
    
    | @@ -0,0 +1,824 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            from ast import parse
         | 
| 3 | 
            +
            import datetime
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import time
         | 
| 7 | 
            +
            import hashlib
         | 
| 8 | 
            +
            import re
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import gradio as gr
         | 
| 11 | 
            +
            import requests
         | 
| 12 | 
            +
            import random
         | 
| 13 | 
            +
            from filelock import FileLock
         | 
| 14 | 
            +
            from io import BytesIO
         | 
| 15 | 
            +
            from PIL import Image, ImageDraw, ImageFont
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from constants import LOGDIR
         | 
| 18 | 
            +
            from utils import (
         | 
| 19 | 
            +
                build_logger,
         | 
| 20 | 
            +
                server_error_msg,
         | 
| 21 | 
            +
                violates_moderation,
         | 
| 22 | 
            +
                moderation_msg,
         | 
| 23 | 
            +
                load_image_from_base64,
         | 
| 24 | 
            +
                get_log_filename,
         | 
| 25 | 
            +
            )
         | 
| 26 | 
            +
            from conversation import Conversation
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            logger = build_logger("gradio_web_server", "gradio_web_server.log")
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            headers = {"User-Agent": "InternVL-Chat Client"}
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            no_change_btn = gr.Button()
         | 
| 33 | 
            +
            enable_btn = gr.Button(interactive=True)
         | 
| 34 | 
            +
            disable_btn = gr.Button(interactive=False)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def write2file(path, content):
         | 
| 38 | 
            +
                lock = FileLock(f"{path}.lock")
         | 
| 39 | 
            +
                with lock:
         | 
| 40 | 
            +
                    with open(path, "a") as fout:
         | 
| 41 | 
            +
                        fout.write(content)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def sort_models(models):
         | 
| 45 | 
            +
                def custom_sort_key(model_name):
         | 
| 46 | 
            +
                    # InternVL-Chat-V1-5 should be the first item
         | 
| 47 | 
            +
                    if model_name == "InternVL-Chat-V1-5":
         | 
| 48 | 
            +
                        return (1, model_name)  # 1 indicates highest precedence
         | 
| 49 | 
            +
                    elif model_name.startswith("InternVL-Chat-V1-5-"):
         | 
| 50 | 
            +
                        return (1, model_name)  # 1 indicates highest precedence
         | 
| 51 | 
            +
                    else:
         | 
| 52 | 
            +
                        return (0, model_name)  # 0 indicates normal order
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                models.sort(key=custom_sort_key, reverse=True)
         | 
| 55 | 
            +
                try:  # We have five InternVL-Chat-V1-5 models, randomly choose one to be the first
         | 
| 56 | 
            +
                    first_three = models[:4]
         | 
| 57 | 
            +
                    random.shuffle(first_three)
         | 
| 58 | 
            +
                    models[:4] = first_three
         | 
| 59 | 
            +
                except:
         | 
| 60 | 
            +
                    pass
         | 
| 61 | 
            +
                return models
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            def get_model_list():
         | 
| 65 | 
            +
                ret = requests.post(args.controller_url + "/refresh_all_workers")
         | 
| 66 | 
            +
                assert ret.status_code == 200
         | 
| 67 | 
            +
                ret = requests.post(args.controller_url + "/list_models")
         | 
| 68 | 
            +
                models = ret.json()["models"]
         | 
| 69 | 
            +
                models = sort_models(models)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                logger.info(f"Models: {models}")
         | 
| 72 | 
            +
                return models
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            get_window_url_params = """
         | 
| 76 | 
            +
            function() {
         | 
| 77 | 
            +
                const params = new URLSearchParams(window.location.search);
         | 
| 78 | 
            +
                url_params = Object.fromEntries(params);
         | 
| 79 | 
            +
                console.log(url_params);
         | 
| 80 | 
            +
                return url_params;
         | 
| 81 | 
            +
                }
         | 
| 82 | 
            +
            """
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def init_state(state=None):
         | 
| 86 | 
            +
                if state is not None:
         | 
| 87 | 
            +
                    del state
         | 
| 88 | 
            +
                return Conversation()
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def find_bounding_boxes(state, response):
         | 
| 92 | 
            +
                pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
         | 
| 93 | 
            +
                matches = pattern.findall(response)
         | 
| 94 | 
            +
                results = []
         | 
| 95 | 
            +
                for match in matches:
         | 
| 96 | 
            +
                    results.append((match[0], eval(match[1])))
         | 
| 97 | 
            +
                returned_image = None
         | 
| 98 | 
            +
    latest_image = state.get_images(source=state.USER)[-1]
    returned_image = latest_image.copy()
    width, height = returned_image.size
    draw = ImageDraw.Draw(returned_image)
    for result in results:
        line_width = max(1, int(min(width, height) / 200))
        random_color = (
            random.randint(0, 128),
            random.randint(0, 128),
            random.randint(0, 128),
        )
        category_name, coordinates = result
        coordinates = [
            (
                float(x[0]) / 1000,
                float(x[1]) / 1000,
                float(x[2]) / 1000,
                float(x[3]) / 1000,
            )
            for x in coordinates
        ]
        coordinates = [
            (
                int(x[0] * width),
                int(x[1] * height),
                int(x[2] * width),
                int(x[3] * height),
            )
            for x in coordinates
        ]
        for box in coordinates:
            draw.rectangle(box, outline=random_color, width=line_width)
            font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))
            text_size = font.getbbox(category_name)
            text_width, text_height = (
                text_size[2] - text_size[0],
                text_size[3] - text_size[1],
            )
            text_position = (box[0], max(0, box[1] - text_height))
            draw.rectangle(
                [
                    text_position,
                    (text_position[0] + text_width, text_position[1] + text_height),
                ],
                fill=random_color,
            )
            draw.text(text_position, category_name, fill="white", font=font)
    return returned_image if len(matches) > 0 else None
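
A quick illustration of the coordinate convention handled above: the model reports boxes in a normalized 0-1000 space, which the loop rescales to pixel coordinates. The values below are hypothetical; the regex that produces `results` is defined earlier in the file and not shown in this hunk.

    # Mapping one hypothetical 0-1000 box onto a 1280x960 image:
    width, height = 1280, 960
    box_0_1000 = (120, 250, 480, 610)
    pixel_box = (
        int(box_0_1000[0] / 1000 * width),   # x1 -> 153
        int(box_0_1000[1] / 1000 * height),  # y1 -> 240
        int(box_0_1000[2] / 1000 * width),   # x2 -> 614
        int(box_0_1000[3] / 1000 * height),  # y2 -> 585
    )
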
def query_image_generation(response, sd_worker_url, timeout=15):
    if not sd_worker_url:
        return None
    sd_worker_url = f"{sd_worker_url}/generate_image/"
    pattern = r"```drawing-instruction\n(.*?)\n```"
    match = re.search(pattern, response, re.DOTALL)
    if match:
        payload = {"caption": match.group(1)}
        print("drawing-instruction:", payload)
        response = requests.post(sd_worker_url, json=payload, timeout=timeout)
        response.raise_for_status()  # raise if the HTTP request failed
        image = Image.open(BytesIO(response.content))
        return image
    else:
        return None
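
This function only fires when the assistant's reply embeds a fenced drawing-instruction block. A reply shaped like the following (the caption and worker URL are illustrative) would POST {"caption": "a watercolor fox in a snowy forest"} to the SD worker's /generate_image/ endpoint:

    sample_response = (
        "Sure, I will draw from this prompt:\n"
        "```drawing-instruction\n"
        "a watercolor fox in a snowy forest\n"
        "```"
    )
    # query_image_generation(sample_response, sd_worker_url="http://localhost:10006")
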
def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown(value=model, visible=True)

    state = init_state()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo_refresh_model_list. ip: {request.client.host}")
    models = get_model_list()
    state = init_state()
    dropdown_update = gr.Dropdown(
        choices=models, value=models[0] if len(models) > 0 else ""
    )
    return state, dropdown_update
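
Since load_demo honors a "model" query parameter, the dropdown can be preselected straight from the page URL, e.g. (model name illustrative, port from the default --port below):

    http://localhost:11000/?model=InternVL2-26B
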
def vote_last_response(state, liked, model_selector, request: gr.Request):
    conv_data = {
        "tstamp": round(time.time(), 4),
        "like": liked,
        "model": model_selector,
        "state": state.dict(),
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(conv_data) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, True, model_selector, request)
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (textbox,) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, False, model_selector, request)
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (textbox,) + (disable_btn,) * 3


def vote_selected_response(
    state, model_selector, request: gr.Request, data: gr.LikeData
):
    logger.info(
        f"Vote: {data.liked}, index: {data.index}, value: {data.value}, ip: {request.client.host}"
    )
    conv_data = {
        "tstamp": round(time.time(), 4),
        "like": data.liked,
        "index": data.index,
        "model": model_selector,
        "state": state.dict(),
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(conv_data) + "\n")
    return


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (textbox,) + (disable_btn,) * 3
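
Every vote lands as one JSON object per line in the demo's log file, which keeps offline analysis trivial. A minimal sketch (the helper name and log path are hypothetical):

    import json

    def tally_votes(log_path="vote_log.jsonl"):
        counts = {"up": 0, "down": 0, "flag": 0}
        with open(log_path) as f:
            for line in f:
                like = json.loads(line).get("like")
                if like is True:
                    counts["up"] += 1
                elif like is False:
                    counts["down"] += 1
                elif like == "flag":
                    counts["flag"] += 1
        return counts
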
def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    # state.messages[-1][-1] = None
    state.update_message(Conversation.ASSISTANT, None, -1)
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = init_state()
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5


def change_system_prompt(state, system_prompt, request: gr.Request):
    logger.info(f"Change system prompt. ip: {request.client.host}")
    state.set_system_message(system_prompt)
    return state
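
All of these handlers return component updates positionally. A sketch of the convention, assuming the button helpers are defined earlier in the file in the LLaVA style (that part is outside this hunk):

    # Presumed definitions, not shown in this hunk:
    # no_change_btn = gr.Button()
    # enable_btn = gr.Button(interactive=True)
    # disable_btn = gr.Button(interactive=False)
    #
    # A return of (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
    # lines up one-to-one with the outputs registered in build_demo:
    # [state, chatbot, textbox] + [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
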
def add_text(state, message, system_prompt, request: gr.Request):
    images = message.get("files", [])
    text = message.get("text", "").strip()
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    textbox = gr.MultimodalTextbox(value=None, interactive=False)
    if len(text) <= 0 and len(images) == 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            textbox = gr.MultimodalTextbox(
                value={"text": moderation_msg}, interactive=True
            )
            return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
    images = [Image.open(path).convert("RGB") for path in images]

    # Uploading new images on top of an existing image conversation resets the state.
    if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
        state = init_state(state)
    state.set_system_message(system_prompt)
    state.append_message(Conversation.USER, text, images)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
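
For reference, gr.MultimodalTextbox delivers its value as a dict with "text" and "files" keys, which is exactly what add_text unpacks; the gallery examples further down use the same shape:

    message = {
        "text": "What's at the far end of the image?",
        "files": ["gallery/prod_9.jpg"],
    }
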
def http_bot(
    state,
    model_selector,
    temperature,
    top_p,
    repetition_penalty,
    max_new_tokens,
    max_input_tiles,
    # bbox_threshold,
    # mask_threshold,
    request: gr.Request,
):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    if hasattr(state, "skip_next") and state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
        ) + (no_change_btn,) * 5
        return

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(
        controller_url + "/get_worker_address", json={"model": model_name}
    )
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        # state.messages[-1][-1] = server_error_msg
        state.update_message(Conversation.ASSISTANT, server_error_msg)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    all_images = state.get_images(source=state.USER)
    all_image_paths = [state.save_image(image) for image in all_images]

    # Make requests
    pload = {
        "model": model_name,
        "prompt": state.get_prompt(),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": max_new_tokens,
        "max_input_tiles": max_input_tiles,
        # "bbox_threshold": bbox_threshold,
        # "mask_threshold": mask_threshold,
        "repetition_penalty": repetition_penalty,
        "images": f"List of {len(all_images)} images: {all_image_paths}",
    }
    logger.info(f"==== request ====\n{pload}")
    # The "images" entry above is only a summary for logging; drop it and send
    # the prompt variant that inlines the actual images.
    pload.pop("images")
    pload["prompt"] = state.get_prompt(inlude_image=True)  # "inlude_image" (sic) matches the keyword defined in conversation.py
    state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=False),
    ) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=pload,
            stream=True,
            timeout=20,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    if "text" in data:
                        output = data["text"].strip()
                        output += state.streaming_placeholder

                    image = None
                    if "image" in data:
                        image = load_image_from_base64(data["image"])
                        _ = state.save_image(image)

                    state.update_message(Conversation.ASSISTANT, output, image)
                    yield (
                        state,
                        state.to_gradio_chatbot(),
                        gr.MultimodalTextbox(interactive=False),
                    ) + (disable_btn,) * 5
                else:
                    output = (
                        f"**{data['text']}**" + f" (error_code: {data['error_code']})"
                    )

                    state.update_message(Conversation.ASSISTANT, output, None)
                    yield (
                        state,
                        state.to_gradio_chatbot(),
                        gr.MultimodalTextbox(interactive=True),
                    ) + (
                        disable_btn,
                        disable_btn,
                        disable_btn,
                        enable_btn,
                        enable_btn,
                    )
                    return
    except requests.exceptions.RequestException as e:
        state.update_message(Conversation.ASSISTANT, server_error_msg, None)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=True),
        ) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    ai_response = state.return_last_message()
    if "<ref>" in ai_response:
        returned_image = find_bounding_boxes(state, ai_response)
        returned_image = [returned_image] if returned_image else []
        state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
    if "```drawing-instruction" in ai_response:
        returned_image = query_image_generation(
            ai_response, sd_worker_url=sd_worker_url
        )
        returned_image = [returned_image] if returned_image else []
        state.update_message(Conversation.ASSISTANT, ai_response, returned_image)

    state.end_of_current_turn()

    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=True),
    ) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")
    data = {
        "tstamp": round(finish_tstamp, 4),
        "like": None,
        "model": model_name,
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "images": all_image_paths,
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(data) + "\n")
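
The streaming loop above pins down the worker wire format: /worker_generate_stream answers with NUL-delimited JSON chunks, each carrying an error_code plus optional "text" and base64 "image" fields. A minimal worker-side sketch of just the framing (the partial outputs are illustrative):

    import json

    def stream_chunks(partial_outputs):
        # One JSON object per chunk, terminated by b"\0" so the client can
        # split the stream with response.iter_lines(delimiter=b"\0").
        for text in partial_outputs:
            yield json.dumps({"text": text, "error_code": 0}).encode() + b"\0"
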
title_html = """
<h2> <span class="gradient-text" id="text">InternVL2</span><span class="plain-text">: Better than the Best—Expanding Performance Boundaries of Open-Source Multimodal Models with the Progressive Scaling Strategy</span></h2>
<a href="https://internvl.github.io/blog/2024-07-02-InternVL-2.0/">[📜 InternVL2 Blog]</a> 
<a href="https://huggingface.co/spaces/OpenGVLab/InternVL">[🤗 HF Demo]</a> 
<a href="https://github.com/OpenGVLab/InternVL?tab=readme-ov-file#quick-start-with-huggingface">[🚀 Quick Start]</a> 
<a href="https://github.com/OpenGVLab/InternVL/blob/main/document/How_to_use_InternVL_API.md">[🌐 API]</a> 
"""

tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect such feedback to keep improving our moderation.
For the best experience, please use a desktop computer, as mobile devices may degrade the demo's quality.
"""


learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violations.

### Acknowledgement
This demo is modified from LLaVA's demo. Thanks for their awesome work!
"""
# .gradio-container {margin: 5px 10px 0 10px !important};
block_css = """
.gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
#buttons button {
    min-width: min(120px,100%);
}

.gradient-text {
    font-size: 28px;
    width: auto;
    font-weight: bold;
    background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
    background-clip: text;
    -webkit-background-clip: text;
    color: transparent;
}

.plain-text {
    font-size: 22px;
    width: auto;
    font-weight: bold;
}
"""

js = """
function createWaveAnimation() {
    const text = document.getElementById('text');
    var i = 0;
    setInterval(function() {
        const colors = [
            'red, orange, yellow, green, blue, indigo, violet, purple',
            'orange, yellow, green, blue, indigo, violet, purple, red',
            'yellow, green, blue, indigo, violet, purple, red, orange',
            'green, blue, indigo, violet, purple, red, orange, yellow',
            'blue, indigo, violet, purple, red, orange, yellow, green',
            'indigo, violet, purple, red, orange, yellow, green, blue',
            'violet, purple, red, orange, yellow, green, blue, indigo',
            'purple, red, orange, yellow, green, blue, indigo, violet',
        ];
        const angle = 45;
        const colorIndex = i % colors.length;
        text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
        text.style.webkitBackgroundClip = 'text';
        text.style.backgroundClip = 'text';
        text.style.color = 'transparent';
        text.style.fontSize = '28px';
        text.style.width = 'auto';
        text.textContent = 'InternVL2';
        text.style.fontWeight = 'bold';
        i += 1;
    }, 200);
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}

"""
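
Note that this js snippet does double duty: it starts the animated title gradient, and because build_demo passes it to demo.load(..., js=js), its return value (the parsed query string) is what populates the hidden url_params JSON component that load_demo consumes.
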
def build_demo(embed_mode):
    textbox = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image", "video"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    with gr.Blocks(
        title="InternVL-Chat",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        state = gr.State()

        if not embed_mode:
            # gr.Markdown(title_markdown)
            gr.HTML(title_html)

        with gr.Row():
            with gr.Column(scale=2):

                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        # value="InternVL-Chat-V1-5",
                        interactive=True,
                        show_label=False,
                        container=False,
                    )

                with gr.Accordion("System Prompt", open=False) as system_prompt_row:
                    system_prompt = gr.Textbox(
                        value="请尽可能详细地回答用户的问题。",  # "Please answer the user's question in as much detail as possible."
                        label="System Prompt",
                        interactive=True,
                    )
                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        interactive=True,
                        label="Top P",
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=1.5,
                        value=1.1,
                        step=0.02,
                        interactive=True,
                        label="Repetition penalty",
                    )
                    max_output_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=1024,
                        step=64,
                        interactive=True,
                        label="Max output tokens",
                    )
                    max_input_tiles = gr.Slider(
                        minimum=1,
                        maximum=32,
                        value=12,
                        step=1,
                        interactive=True,
                        label="Max input tiles (control the image size)",
                    )
                examples = gr.Examples(
                    examples=[
                        [
                            {
                                "files": [
                                    "gallery/prod_9.jpg",
                                ],
                                "text": "What's at the far end of the image?",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/astro_on_unicorn.png",
                                ],
                                "text": "What does this image mean?",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/prod_12.png",
                                ],
                                "text": "What are the consequences of the easy decisions shown in this image?",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/child_1.jpg",
                                    "gallery/child_2.jpg",
                                    "gallery/child_3.jpg",
                                ],
                                "text": "这三帧图片讲述了一件什么事情?",  # "What story do these three frames tell?"
                            }
                        ],
                    ],
                    inputs=[textbox],
                )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="InternVL2",
                    height=580,
                    show_copy_button=True,
                    show_share_button=True,
                    avatar_images=[
                        "assets/human.png",
                        "assets/assistant.png",
                    ],
                    bubble_full_width=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️  Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️  Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(
                        value="🔄  Regenerate", interactive=False
                    )
                    clear_btn = gr.Button(value="🗑️  Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        chatbot.like(
            vote_selected_response,
            [state, model_selector],
            [],
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        regenerate_btn.click(
            regenerate,
            [state, system_prompt],  # NB: system_prompt is passed into regenerate's image_process_mode slot
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                model_selector,
                temperature,
                top_p,
                repetition_penalty,
                max_output_tokens,
                max_input_tiles,
                # bbox_threshold,
                # mask_threshold,
            ],
            [state, chatbot, textbox] + btn_list,
        )
        clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

        textbox.submit(
            add_text,
            [state, textbox, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                model_selector,
                temperature,
                top_p,
                repetition_penalty,
                max_output_tokens,
                max_input_tiles,
                # bbox_threshold,
                # mask_threshold,
            ],
            [state, chatbot, textbox] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                model_selector,
                temperature,
                top_p,
                repetition_penalty,
                max_output_tokens,
                max_input_tiles,
                # bbox_threshold,
                # mask_threshold,
            ],
            [state, chatbot, textbox] + btn_list,
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=js,
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                js=js,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=11000)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--sd-worker-url", type=str, default=None)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    sd_worker_url = args.sd_worker_url
    demo = build_demo(args.embed)
    demo.queue(api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=args.concurrency_count,
    )
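
Taken together with the flags above, a local launch against an already-running controller could look like this (the script name assumes this code is the gradio_web_server.py added in this commit; adjust as needed):

    python gradio_web_server.py --controller-url http://localhost:21001 --port 11000 --model-list-mode reload
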
    	
library.py
DELETED

@@ -1,95 +0,0 @@
-# --------------------------------------------------------
-# InternVL
-# Copyright (c) 2024 OpenGVLab
-# Licensed under The MIT License [see LICENSE for details]
-# Modified from https://github.com/hreikin/streamlit-uploads-library/blob/main/streamlit_uploads_library/library.py
-# --------------------------------------------------------
-
-import logging
-from math import ceil
-
-import streamlit as st
-
-logger = logging.getLogger(__name__)
-
-
-class Library():
-    """Create a simple library out of streamlit widgets.
-
-    Using the library is simple, import `streamlit_uploads_library` and then instantiate the class with the
-    required `directory` variable. Other options can be configured by passing in different variables
-    when instantiating the class.
-
-    Example Usage:
-        python
-        import streamlit as st
-        from library import Library
-
-        st.set_page_config(page_title="Streamlit Uploads Library", layout="wide")
-        default_library = Library(images=pil_images)
-    """
-
-    def __init__(self, images, image_alignment='end', number_of_columns=5):
-        self.images = images
-        self.image_alignment = image_alignment
-        self.number_of_columns = number_of_columns
-        self.root_container = self.create(images=self.images,
-                                          image_alignment=self.image_alignment,
-                                          number_of_columns=self.number_of_columns)
-
-    def create(_self, images, image_alignment, number_of_columns):
-        """Creates a simple library or gallery with columns.
-
-        Creates a library or gallery using columns out of streamlit widgets.
-        """
-        root_container = st.container()
-        with root_container:
-            # To be able to display the images, details and buttons all in one row and aligned
-            # correctly so that images of different sizes don't affect the alignment of the details
-            # and buttons we need do some minor maths and keep track of multiple index values.
-            # First we instantiate some defaults.
-            col_idx = 0
-            filename_idx = 0
-            max_idx = number_of_columns - 1
-            # Get the file list and filename list, work out the total number of files from the
-            # length of the file list.
-            library_files = images
-            num_of_files = len(library_files)
-            # Work out the number of rows required by dividing the number of files by the number of
-            # columns and rounding up using `math.ceil`.
-            num_of_rows_req = ceil(num_of_files / number_of_columns)
-            # Create the required number of rows (st.container).
-            library_rows = list()
-            library_rows_idx = 0
-            for i in range(num_of_rows_req):
-                library_rows.append(st.container())
-            # For each library row we need to create separate rows (st.container) for images,
-            # and rows (st.expander) for details and buttons to keep them in the correct columns.
-            for idx in range(num_of_rows_req):
-                with library_rows[library_rows_idx]:
-                    imgs_columns = list(st.columns(number_of_columns))
-                # Since we are keeping track of the column and filename indexes we can use
-                # those to slice the `library_files` list at the correct points for each row
-                # and then increase or reset the indexes as required.
-                for img in library_files[filename_idx:(filename_idx + number_of_columns)]:
-                    with imgs_columns[col_idx]:
-                        st.image(img, use_column_width='auto')
-                        st.write(
-                            f"""<style>
-                                [data-testid="stHorizontalBlock"] {{
-                                    align-items: {image_alignment};
-                                }}
-                                </style>
-                                """,
-                            unsafe_allow_html=True
-                        )
-                    # Keeps track of the current column, if we reach the `max_idx` we reset it
-                    # to 0 and increase the row index. This combined with the slicing should
-                    # ensure all images, details and buttons are in the correct columns.
-                    if col_idx < max_idx:
-                        col_idx += 1
-                    else:
-                        col_idx = 0
-                        library_rows_idx += 1
-                    filename_idx += 1
-        return root_container
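With the move to Gradio, this hand-rolled Streamlit grid becomes unnecessary: Gradio ships an equivalent layout primitive. A minimal sketch of the same five-column grid, assuming a handful of local image paths:

import gradio as gr

# Rough Gradio stand-in for the removed Library(images, number_of_columns=5):
# gr.Gallery handles the row/column bookkeeping that Library did by hand.
with gr.Blocks() as demo:
    gr.Gallery(
        value=["gallery/child_1.jpg", "gallery/child_2.jpg", "gallery/child_3.jpg"],
        columns=5,
        label="Image library",
    )

demo.launch()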
    	
mm_utils.py
DELETED

@@ -1,102 +0,0 @@
-import base64
-from io import BytesIO
-
-import torch
-from PIL import Image
-from transformers import StoppingCriteria
-
-from .constants import IMAGE_TOKEN_INDEX
-
-
-def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
-
-
-def expand2square(pil_img, background_color):
-    width, height = pil_img.size
-    if width == height:
-        return pil_img
-    elif width > height:
-        result = Image.new(pil_img.mode, (width, width), background_color)
-        result.paste(pil_img, (0, (width - height) // 2))
-        return result
-    else:
-        result = Image.new(pil_img.mode, (height, height), background_color)
-        result.paste(pil_img, ((height - width) // 2, 0))
-        return result
-
-
-def process_images(images, image_processor, model_cfg):
-    image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
-    new_images = []
-    if image_aspect_ratio == 'pad':
-        for image in images:
-            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
-            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
-            new_images.append(image)
-    else:
-        return image_processor(images, return_tensors='pt')['pixel_values']
-    if all(x.shape == new_images[0].shape for x in new_images):
-        new_images = torch.stack(new_images, dim=0)
-    return new_images
-
-
-def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
-                          num_image_tokens=None, return_tensors=None):
-    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
-
-    def insert_separator(X, sep):
-        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
-
-    input_ids = []
-    offset = 0
-    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-        offset = 1
-        input_ids.append(prompt_chunks[0][0])
-
-    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + num_image_tokens)):
-        input_ids.extend(x[offset:])
-
-    if return_tensors is not None:
-        if return_tensors == 'pt':
-            return torch.tensor(input_ids, dtype=torch.long)
-        raise ValueError(f'Unsupported tensor type: {return_tensors}')
-    return input_ids
-
-
-def get_model_name_from_path(model_path):
-    model_path = model_path.strip('/')
-    model_paths = model_path.split('/')
-    if model_paths[-1].startswith('checkpoint-'):
-        return model_paths[-2] + '_' + model_paths[-1]
-    else:
-        return model_paths[-1]
-
-
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords, tokenizer, input_ids):
-        self.keywords = keywords
-        self.keyword_ids = []
-        self.max_keyword_len = 0
-        for keyword in keywords:
-            cur_keyword_ids = tokenizer(keyword).input_ids
-            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
-                cur_keyword_ids = cur_keyword_ids[1:]
-            if len(cur_keyword_ids) > self.max_keyword_len:
-                self.max_keyword_len = len(cur_keyword_ids)
-            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
-        self.tokenizer = tokenizer
-        self.start_len = input_ids.shape[1]
-
-    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        assert output_ids.shape[0] == 1, 'Only support batch size 1 (yet)'  # TODO
-        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
-        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
-        for keyword_id in self.keyword_ids:
-            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
-                return True
-        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
-        for keyword in self.keywords:
-            if keyword in outputs:
-                return True
-        return False
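mm_utils.py belonged to the old Streamlit pipeline (base64 decoding, pad-to-square preprocessing, image-token splicing, keyword stopping). Of the removed helpers, expand2square has the least obvious behavior, so here it is reproduced verbatim with a quick self-contained check:

from PIL import Image

# The deleted expand2square: pad the short side with a solid color so the
# image becomes square, keeping the original centered.
def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

padded = expand2square(Image.new("RGB", (640, 360), (255, 0, 0)), (127, 127, 127))
assert padded.size == (640, 640)  # 640x360 -> 640x640 with gray bars top and bottom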
    	
model_worker.py
CHANGED

@@ -9,14 +9,15 @@ A model worker executes the model.
 """
 import argparse
 import asyncio
-import base64
+...
 import json
-import os
+import math
 import threading
 import time
 import uuid
+import traceback
 from functools import partial
-from io import BytesIO
+...
 from threading import Thread
 
 import requests
@@ -28,33 +29,36 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 from PIL import Image
 from torchvision.transforms.functional import InterpolationMode
-from transformers import ...
-...
-...
+from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
+from utils import (
+    build_logger,
+    pretty_print_semaphore,
+    server_error_msg,
+    load_image_from_base64,
+)
+import spaces
 
 worker_id = str(uuid.uuid4())[:6]
-logger = build_logger('model_worker', f'model_worker_{worker_id}.log')
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
 global_counter = 0
 model_semaphore = None
 
 
-def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
-
-
 def build_transform(input_size):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
-    transform = T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
-        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
-        T.ToTensor(),
-        T.Normalize(mean=MEAN, std=STD)
-    ])
+    transform = T.Compose(
+        [
+            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=MEAN, std=STD),
+        ]
+    )
     return transform
 
 
 def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
-    best_ratio_diff = float('inf')
+    best_ratio_diff = float("inf")
     best_ratio = (1, 1)
     area = width * height
     for ratio in target_ratios:
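The reformatted build_transform is behavior-identical: convert to RGB, bicubic-resize to a square, then normalize. A quick standalone check (the ImageNet mean/std values below are assumed; the actual constants are defined elsewhere in the file):

from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

# Assumed standard ImageNet statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = T.Compose(
    [
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

pixel_values = transform(Image.new("RGB", (800, 600), (0, 128, 255)))
print(pixel_values.shape)  # torch.Size([3, 448, 448])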
@@ -69,19 +73,26 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
     return best_ratio
 
 
-def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+def dynamic_preprocess(
+    image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
+):
     orig_width, orig_height = image.size
     aspect_ratio = orig_width / orig_height
 
     # calculate the existing image aspect ratio
     target_ratios = set(
-        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
-        i * j <= max_num and i * j >= min_num)
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
 
     # find the closest aspect ratio to the target
     target_aspect_ratio = find_closest_aspect_ratio(
-        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
@@ -96,7 +107,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
         (i % (target_width // image_size)) * image_size,
         (i // (target_width // image_size)) * image_size,
         ((i % (target_width // image_size)) + 1) * image_size,
-        ((i // (target_width // image_size)) + 1) * image_size
+        ((i // (target_width // image_size)) + 1) * image_size,
     )
     # split the image
     split_img = resized_img.crop(box)
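Concretely, dynamic_preprocess enumerates every (columns, rows) grid of up to max_num tiles and keeps the one whose aspect ratio best matches the input. A trimmed sketch of just that search, omitting the area tie-break and the actual cropping:

# Sketch of the grid-selection step of dynamic_preprocess (same math,
# reduced to the ratio search).
def closest_grid(width, height, min_num=1, max_num=6):
    target_ratios = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda x: x[0] * x[1],
    )
    aspect = width / height
    # Pick the (cols, rows) grid whose ratio is nearest the image's ratio.
    return min(target_ratios, key=lambda r: abs(aspect - r[0] / r[1]))

print(closest_grid(1920, 1080))  # -> (2, 1): two 448x448 tiles side by side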
@@ -114,78 +125,163 @@ def heart_beat_worker(controller):
         controller.send_heart_beat()
 
 
+def split_model(model_name):
+    device_map = {}
+    world_size = torch.cuda.device_count()
+    num_layers = {
+        "InternVL2-8B": 32,
+        "InternVL2-26B": 48,
+        "InternVL2-40B": 60,
+        "InternVL2-Llama3-76B": 80,
+        "InternVL2-78B": 80,
+        "InternVL2-Pro": 80,
+    }[model_name]
+    # Since the first GPU will be used for ViT, treat it as half a GPU.
+    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+    num_layers_per_gpu = [num_layers_per_gpu] * world_size
+    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+    layer_cnt = 0
+    for i, num_layer in enumerate(num_layers_per_gpu):
+        for j in range(num_layer):
+            device_map[f"language_model.model.layers.{layer_cnt}"] = i
+            layer_cnt += 1
+    device_map["vision_model"] = 0
+    device_map["mlp1"] = 0
+    device_map["language_model.model.tok_embeddings"] = 0
+    device_map["language_model.model.embed_tokens"] = 0
+    device_map["language_model.output"] = 0
+    device_map["language_model.model.norm"] = 0
+    device_map["language_model.lm_head"] = 0
+    device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
+
+    return device_map
+
+
 class ModelWorker:
-    def __init__(self, controller_addr, worker_addr, worker_id,
-                 model_path, model_name, load_8bit, device, context_len=8192):
+    def __init__(
+        self,
+        controller_addr,
+        worker_addr,
+        worker_id,
+        model_path,
+        model_name,
+        load_8bit,
+        device,
+        context_len=8192,
+    ):
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id
-        if model_path.endswith('/'):
+        if model_path.endswith("/"):
             model_path = model_path[:-1]
         if model_name is None:
-            model_paths = model_path.split('/')
-            if model_paths[-1].startswith('checkpoint-'):
-                self.model_name = model_paths[-2] + '_' + model_paths[-1]
+            model_paths = model_path.split("/")
+            if model_paths[-1].startswith("checkpoint-"):
+                self.model_name = model_paths[-2] + "_" + model_paths[-1]
             else:
                 self.model_name = model_paths[-1]
         else:
             self.model_name = model_name
 
-        logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...')
+        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
 
-        self.tokenizer = AutoTokenizer.from_pretrained(...)
-        ...
-        ...
-        ...
-        self.model = AutoModel.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, use_fast=False
+        )
+        tokens_to_keep = ["<box>", "</box>", "<ref>", "</ref>"]
+        tokenizer.additional_special_tokens = [
+            item
+            for item in tokenizer.additional_special_tokens
+            if item not in tokens_to_keep
+        ]
+        self.tokenizer = tokenizer
+
+        if device == "auto":
+            device_map = split_model(self.model_name)
+            self.model = AutoModel.from_pretrained(
             model_path,
             load_in_8bit=load_8bit,
-            torch_dtype=torch.float16,
-            device_map='auto',
-            trust_remote_code=True).eval()
+            torch_dtype=torch.bfloat16,
+            device_map=device_map,
+            trust_remote_code=True,
+        ).eval()
         else:
-            self.model = ...
+            self.model = AutoModel.from_pretrained(
             model_path,
             load_in_8bit=load_8bit,
-            torch_dtype=torch.float16,
-            trust_remote_code=True).eval()
-        ...
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        ).eval()
+        if not load_8bit and not device == "auto":
             self.model = self.model.cuda()
+        self.load_8bit = load_8bit
+        self.device = device
+        self.model_path = model_path
         self.image_size = self.model.config.force_image_size
         self.context_len = context_len
         self.register_to_controller()
         self.heart_beat_thread = threading.Thread(
-            target=heart_beat_worker, args=(self,))
+            target=heart_beat_worker, args=(self,)
+        )
         self.heart_beat_thread.start()
 
+    def reload_model(self):
+        del self.model
+        torch.cuda.empty_cache()
+        if self.device == "auto":
+            device_map = split_model(self.model_name)
+            self.model = AutoModel.from_pretrained(
+                self.model_path,
+                load_in_8bit=self.load_8bit,
+                torch_dtype=torch.bfloat16,
+                device_map=device_map,
+                trust_remote_code=True,
+            ).eval()
+        else:
+            self.model = AutoModel.from_pretrained(
+                self.model_path,
+                load_in_8bit=self.load_8bit,
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+            ).eval()
+        if not self.load_8bit and not self.device == "auto":
+            self.model = self.model.cuda()
+
     def register_to_controller(self):
-        logger.info('Register to controller')
+        logger.info("Register to controller")
 
-        url = self.controller_addr + '/register_worker'
+        url = self.controller_addr + "/register_worker"
         data = {
-            'worker_name': self.worker_addr,
-            'check_heart_beat': True,
-            'worker_status': self.get_status(),
+            "worker_name": self.worker_addr,
+            "check_heart_beat": True,
+            "worker_status": self.get_status(),
         }
         r = requests.post(url, json=data)
         assert r.status_code == 200
 
     def send_heart_beat(self):
-        logger.info(f'Send heart beat. Models: {[self.model_name]}. '
-                    f'Semaphore: {pretty_print_semaphore(model_semaphore)}. '
-                    f'global_counter: {global_counter}')
+        logger.info(
+            f"Send heart beat. Models: {[self.model_name]}. "
+            f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+            f"global_counter: {global_counter}"
+        )
 
-        url = self.controller_addr + '/receive_heart_beat'
+        url = self.controller_addr + "/receive_heart_beat"
 
         while True:
             try:
-                ret = requests.post(url, json={
-                    'worker_name': self.worker_addr,
-                    'queue_length': self.get_queue_length()}, timeout=5)
-                exist = ret.json()['exist']
+                ret = requests.post(
+                    url,
+                    json={
+                        "worker_name": self.worker_addr,
+                        "queue_length": self.get_queue_length(),
+                    },
+                    timeout=5,
+                )
+                exist = ret.json()["exist"]
                 break
             except requests.exceptions.RequestException as e:
-                logger.error(f'heart beat error: {e}')
+                logger.error(f"heart beat error: {e}")
             time.sleep(5)
 
         if not exist:
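The new split_model reserves roughly half of GPU 0 for the vision tower and spreads the language-model layers over the remaining capacity. Its arithmetic worked through for InternVL2-26B (48 layers) on two GPUs:

import math

# Same layer-allocation arithmetic as split_model, for 48 layers on 2 GPUs.
world_size, num_layers = 2, 48
per_gpu = math.ceil(num_layers / (world_size - 0.5))  # ceil(48 / 1.5) = 32
plan = [per_gpu] * world_size
plan[0] = math.ceil(plan[0] * 0.5)                    # halve GPU 0's share
print(plan)  # [16, 32] -- GPU 0 also hosts vision_model, mlp1 and the embeddings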
@@ -195,80 +291,115 @@ class ModelWorker:
         if model_semaphore is None:
             return 0
         else:
-            return args.limit_model_concurrency - model_semaphore._value + (len(
-                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+            return (
+                args.limit_model_concurrency
 
     def get_status(self):
         return {
-            'model_names': [self.model_name],
-            'speed': 1,
-            'queue_length': self.get_queue_length(),
         }
 
     @torch.inference_mode()
     def generate_stream(self, params):
-        system_message = params['prompt'][0]['content']
-        send_messages = params['prompt'][1:]
-        max_input_tiles = params['max_input_tiles']
-        temperature = params['temperature']
-        top_p = params['top_p']
-        max_new_tokens = params['max_new_tokens']
-        repetition_penalty = params['repetition_penalty']
         do_sample = True if temperature > 0.0 else False
 
-        global_image_cnt = 0
         history, pil_images, max_input_tile_list = [], [], []
         for message in send_messages:
-            if message['role'] == 'user':
-                prefix = ''
-                if 'image' in message:
                     max_input_tile_temp = []
-                    for image_str in message['image']:
                         pil_images.append(load_image_from_base64(image_str))
-                        prefix += f'...'
                         global_image_cnt += 1
-                        max_input_tile_temp.append(max(1, max_input_tiles // len(message['image'])))
                     if len(max_input_tile_temp) > 0:
                         max_input_tile_list.append(max_input_tile_temp)
-                content = prefix + message['content']
-                history.append([content, ])
             else:
-                history[-1].append(message['content'])
         question, history = history[-1][0], history[:-1]
 
         # Create a new list to store processed sublists
         flattened_list = []
         # Iterate through all but the last sublist in max_input_tile_list and process them
         for sublist in max_input_tile_list[:-1]:
-            processed_sublist = [1] * len(sublist)
-            flattened_list.extend(processed_sublist)
        # If max_input_tile_list is not empty, add the last sublist to the new list
         if max_input_tile_list:
             flattened_list.extend(max_input_tile_list[-1])
         max_input_tile_list = flattened_list
-        assert len(max_input_tile_list) == len(pil_images), \
-            ...
 
         old_system_message = self.model.system_message
         self.model.system_message = system_message
         image_tiles = []
         transform = build_transform(input_size=self.image_size)
         if len(pil_images) > 0:
-            for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
                 if self.model.config.dynamic_image_size:
                     tiles = dynamic_preprocess(
-                        pil_image, image_size=self.image_size, max_num=current_max_input_tiles,
-                        use_thumbnail=self.model.config.use_thumbnail)
                 else:
                     tiles = [pil_image]
                 image_tiles += tiles
             pixel_values = [transform(item) for item in image_tiles]
-            pixel_values = torch.stack(pixel_values).to(
-                ...)
         else:
             pixel_values = None
 
-        streamer = TextIteratorStreamer(...)
         generation_config = dict(
             num_beams=1,
             max_new_tokens=max_new_tokens,
 
@@ -279,53 +410,61 @@ class ModelWorker:
             top_p=top_p,
             streamer=streamer,
         )
-        logger.info(f'...')
-        ...
-        thread = Thread(target=self.model.chat, kwargs=dict(
-            ...
             tokenizer=self.tokenizer,
             pixel_values=pixel_values,
             question=question,
             history=history,
             return_history=False,
             generation_config=generation_config,
-        ))
-        thread.start()
-        ...
 
     def generate_stream_gate(self, params):
         try:
             for x in self.generate_stream(params):
                 yield x
         except ValueError as e:
-            print('Caught ValueError:', e)
             ret = {
-                'text': server_error_msg,
-                'error_code': 1,
             }
-            yield json.dumps(ret).encode() + b'\0'
         except torch.cuda.CudaError as e:
-            ...
             ret = {
-                'text': server_error_msg,
-                'error_code': 1,
             }
-            yield json.dumps(ret).encode() + b'\0'
         except Exception as e:
-            ...
             ret = {
-                'text': server_error_msg,
-                'error_code': 1,
             }
-            yield json.dumps(ret).encode() + b'\0'
 
 
 app = FastAPI()
 
@@ -337,7 +476,7 @@ def release_model_semaphore(fn=None):
         fn()
 
 
-@app.post('/worker_generate_stream')
 async def generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
 
@@ -349,35 +488,39 @@ async def generate_stream(request: Request):
     worker.send_heart_beat()
     generator = worker.generate_stream_gate(params)
     background_tasks = BackgroundTasks()
-    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
     return StreamingResponse(generator, background=background_tasks)
 
 
-@app.post('/worker_get_status')
 async def get_status(request: Request):
     return worker.get_status()
 
 
-if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
-    parser.add_argument(...)
     args = parser.parse_args()
-    logger.info(f'args: {args}')
-    ...
-    worker = ModelWorker(...)
-    ...
-    uvicorn.run(...)
                            - model_semaphore._value
         | 
| 297 | 
            +
                            + (
         | 
| 298 | 
            +
                                len(model_semaphore._waiters)
         | 
| 299 | 
            +
                                if model_semaphore._waiters is not None
         | 
| 300 | 
            +
                                else 0
         | 
| 301 | 
            +
                            )
         | 
| 302 | 
            +
                        )
         | 
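
For reference, the queue length reported above is the number of busy slots plus the number of requests queued on the semaphore. A small illustrative helper (not part of model_worker.py), where `limit`, `value`, and `num_waiters` stand in for `args.limit_model_concurrency`, `model_semaphore._value`, and `len(model_semaphore._waiters)`:

    def queue_length(limit, value, num_waiters):
        # busy slots (limit - value) plus requests still waiting on the semaphore
        return limit - value + num_waiters

    assert queue_length(5, 0, 2) == 7  # five requests running, two queued
    assert queue_length(5, 5, 0) == 0  # idle worker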
 
     def get_status(self):
         return {
+            "model_names": [self.model_name],
+            "speed": 1,
+            "queue_length": self.get_queue_length(),
         }
 
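register_to_controller and send_heart_beat assume two endpoints on the controller; the repo's controller.py provides the real implementation. A minimal sketch of that contract, for orientation only:

    from fastapi import FastAPI, Request

    controller = FastAPI()
    workers = {}  # worker_name -> last reported status

    @controller.post("/register_worker")
    async def register_worker(request: Request):
        data = await request.json()
        workers[data["worker_name"]] = data["worker_status"]

    @controller.post("/receive_heart_beat")
    async def receive_heart_beat(request: Request):
        data = await request.json()
        known = data["worker_name"] in workers
        if known:
            workers[data["worker_name"]]["queue_length"] = data["queue_length"]
        # "exist": False tells the worker to re-register itself
        return {"exist": known}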
+    @spaces.GPU
     @torch.inference_mode()
     def generate_stream(self, params):
+        system_message = params["prompt"][0]["content"]
+        send_messages = params["prompt"][1:]
+        max_input_tiles = params["max_input_tiles"]
+        temperature = params["temperature"]
+        top_p = params["top_p"]
+        max_new_tokens = params["max_new_tokens"]
+        repetition_penalty = params["repetition_penalty"]
         do_sample = True if temperature > 0.0 else False
 
+        global_image_cnt = 0
         history, pil_images, max_input_tile_list = [], [], []
         for message in send_messages:
+            if message["role"] == "user":
+                prefix = ""
+                if "image" in message:
                     max_input_tile_temp = []
+                    for image_str in message["image"]:
                         pil_images.append(load_image_from_base64(image_str))
+                        prefix += f"Image-{global_image_cnt + 1}: <image>\n\n"
                         global_image_cnt += 1
+                        max_input_tile_temp.append(
+                            max(1, max_input_tiles // len(message["image"]))
+                        )
                     if len(max_input_tile_temp) > 0:
                         max_input_tile_list.append(max_input_tile_temp)
+                content = prefix + message["content"]
+                history.append(
+                    [
+                        content,
+                    ]
+                )
             else:
+                history[-1].append(message["content"])
         question, history = history[-1][0], history[:-1]
 
+        if global_image_cnt == 1:
+            question = question.replace("Image-1: <image>\n\n", "<image>\n")
+            history = [
+                [item[0].replace("Image-1: <image>\n\n", "<image>\n"), item[1]]
+                for item in history
+            ]
+
         # Create a new list to store processed sublists
         flattened_list = []
         # Iterate through all but the last sublist in max_input_tile_list and process them
         for sublist in max_input_tile_list[:-1]:
+            processed_sublist = [1] * len(
+                sublist
+            )  # Change each element in the sublist to 1
+            flattened_list.extend(
+                processed_sublist
+            )  # Flatten the processed sublist and add to the new list
         # If max_input_tile_list is not empty, add the last sublist to the new list
         if max_input_tile_list:
             flattened_list.extend(max_input_tile_list[-1])
         max_input_tile_list = flattened_list
+        assert len(max_input_tile_list) == len(
+            pil_images
+        ), "The number of max_input_tile_list and pil_images should be the same."
 
             
                    old_system_message = self.model.system_message
         | 
| 375 | 
             
                    self.model.system_message = system_message
         | 
| 376 | 
             
                    image_tiles = []
         | 
| 377 | 
             
                    transform = build_transform(input_size=self.image_size)
         | 
| 378 | 
             
                    if len(pil_images) > 0:
         | 
| 379 | 
            +
                        for current_max_input_tiles, pil_image in zip(
         | 
| 380 | 
            +
                            max_input_tile_list, pil_images
         | 
| 381 | 
            +
                        ):
         | 
| 382 | 
             
                            if self.model.config.dynamic_image_size:
         | 
| 383 | 
             
                                tiles = dynamic_preprocess(
         | 
| 384 | 
            +
                                    pil_image,
         | 
| 385 | 
            +
                                    image_size=self.image_size,
         | 
| 386 | 
            +
                                    max_num=current_max_input_tiles,
         | 
| 387 | 
            +
                                    use_thumbnail=self.model.config.use_thumbnail,
         | 
| 388 | 
            +
                                )
         | 
| 389 | 
             
                            else:
         | 
| 390 | 
             
                                tiles = [pil_image]
         | 
| 391 | 
             
                            image_tiles += tiles
         | 
| 392 | 
             
                        pixel_values = [transform(item) for item in image_tiles]
         | 
| 393 | 
            +
                        pixel_values = torch.stack(pixel_values).to(
         | 
| 394 | 
            +
                            self.model.device, dtype=torch.bfloat16
         | 
| 395 | 
            +
                        )
         | 
| 396 | 
            +
                        logger.info(f"Split images to {pixel_values.shape}")
         | 
| 397 | 
             
                    else:
         | 
| 398 | 
             
                        pixel_values = None
         | 
| 399 |  | 
| 400 | 
            +
                    streamer = TextIteratorStreamer(
         | 
| 401 | 
            +
                        self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
         | 
| 402 | 
            +
                    )
         | 
| 403 | 
             
                    generation_config = dict(
         | 
| 404 | 
             
                        num_beams=1,
         | 
| 405 | 
             
                        max_new_tokens=max_new_tokens,
         | 
|  | |
| 410 | 
             
                        top_p=top_p,
         | 
| 411 | 
             
                        streamer=streamer,
         | 
| 412 | 
             
                    )
         | 
| 413 | 
            +
                    logger.info(f"Generation config: {generation_config}")
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    thread = Thread(
         | 
| 416 | 
            +
                        target=self.model.chat,
         | 
| 417 | 
            +
                        kwargs=dict(
         | 
| 418 | 
             
                            tokenizer=self.tokenizer,
         | 
| 419 | 
             
                            pixel_values=pixel_values,
         | 
| 420 | 
             
                            question=question,
         | 
| 421 | 
             
                            history=history,
         | 
| 422 | 
             
                            return_history=False,
         | 
| 423 | 
             
                            generation_config=generation_config,
         | 
| 424 | 
            +
                        ),
         | 
| 425 | 
            +
                    )
         | 
| 426 | 
            +
                    thread.start()
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    generated_text = ""
         | 
| 429 | 
            +
                    for new_text in streamer:
         | 
| 430 | 
            +
                        generated_text += new_text
         | 
| 431 | 
            +
                        if generated_text.endswith(self.model.conv_template.sep):
         | 
| 432 | 
            +
                            generated_text = generated_text[: -len(self.model.conv_template.sep)]
         | 
| 433 | 
            +
                        yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
         | 
| 434 | 
            +
                    logger.info(
         | 
| 435 | 
            +
                        f"max_input_tile_list: {max_input_tile_list}, history: {history}, "
         | 
| 436 | 
            +
                        f"question: {question}, answer: {generated_text}"
         | 
| 437 | 
            +
                    )
         | 
| 438 | 
            +
                    self.model.system_message = old_system_message
         | 
| 439 |  | 
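This is the usual transformers streaming pattern: generation runs on a worker thread while the request coroutine drains the streamer, and each chunk is framed as NUL-terminated JSON. In isolation the pattern looks like the sketch below (illustrative only; model.chat above is InternVL-specific, so plain model.generate is shown instead):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, model_inputs, **gen_kwargs):
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        Thread(
            target=model.generate,
            kwargs=dict(**model_inputs, streamer=streamer, **gen_kwargs),
        ).start()
        for new_text in streamer:  # blocks until the generation thread emits tokens
            yield new_text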
     def generate_stream_gate(self, params):
         try:
             for x in self.generate_stream(params):
                 yield x
         except ValueError as e:
+            print("Caught ValueError:", e)
+            traceback.print_exc()
             ret = {
+                "text": server_error_msg,
+                "error_code": 1,
             }
+            yield json.dumps(ret).encode() + b"\0"
         except torch.cuda.CudaError as e:
+            traceback.print_exc()
+            print("Caught torch.cuda.CudaError:", e)
             ret = {
+                "text": server_error_msg,
+                "error_code": 1,
             }
+            yield json.dumps(ret).encode() + b"\0"
         except Exception as e:
+            traceback.print_exc()
+            print("Caught Unknown Error", e)
             ret = {
+                "text": server_error_msg,
+                "error_code": 1,
             }
+            yield json.dumps(ret).encode() + b"\0"
 
 
 app = FastAPI()
@@ … @@
         fn()
 
 
+@app.post("/worker_generate_stream")
 async def generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
@@ … @@
     worker.send_heart_beat()
     generator = worker.generate_stream_gate(params)
     background_tasks = BackgroundTasks()
+    background_tasks.add_task(
+        partial(release_model_semaphore, fn=worker.send_heart_beat)
+    )
     return StreamingResponse(generator, background=background_tasks)
 
 
+@app.post("/worker_get_status")
 async def get_status(request: Request):
     return worker.get_status()
 
 
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=21002)
+    parser.add_argument("--worker-url", type=str, default="http://localhost")
+    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+    parser.add_argument("--model-name", type=str)
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--limit-model-concurrency", type=int, default=5)
+    parser.add_argument("--stream-interval", type=int, default=1)
+    parser.add_argument("--load-8bit", action="store_true")
     args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    worker = ModelWorker(
+        args.controller_url,
+        args.worker_url + f":{args.port}",
+        worker_id,
+        args.model_path,
+        args.model_name,
+        args.load_8bit,
+        args.device,
+    )
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
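Since every payload yielded by generate_stream is a JSON object terminated by b"\0", a client can split the response body on NUL bytes. A minimal consumer sketch (illustrative only; gradio_web_server.py plays this role in the repo, the params shape follows generate_stream above, and this assumes the elided handler lines read the request body as JSON):

    import json
    import requests

    params = {
        "prompt": [
            {"role": "system", "content": "You are a helpful assistant."},
            # images would travel in an extra "image" list, base64-encoded
            {"role": "user", "content": "Hello!"},
        ],
        "max_input_tiles": 12, "temperature": 0.2, "top_p": 0.7,
        "max_new_tokens": 512, "repetition_penalty": 1.1,
    }
    resp = requests.post(
        "http://localhost:21002/worker_generate_stream", json=params, stream=True
    )
    for chunk in resp.iter_lines(delimiter=b"\0"):  # chunks are NUL-terminated JSON
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                print(data["text"])  # cumulative text so far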
    	
requirements.txt
CHANGED

@@ -1,4 +1,14 @@
-…
-…
-…
-…
+diffusers==0.29.2
+fastapi==0.111.1
+filelock==3.15.4
+fire==0.6.0
+gradio==4.38.1
+numpy==2.0.1
+Pillow==10.4.0
+pydantic==2.8.2
+Requests==2.32.3
+spaces==0.28.3
+torch==2.0.1
+torchvision==0.15.2
+transformers==4.37.2
+uvicorn==0.30.3
    	
utils.py
CHANGED

@@ -1,13 +1,22 @@
+from ast import Dict
 import logging
 import logging.handlers
 import os
 import sys
-…
+import base64
+from PIL import Image
+from io import BytesIO
+import json
 import requests
 from constants import LOGDIR
+import datetime
 
-server_error_msg = …
-…
+server_error_msg = (
+    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+)
+moderation_msg = (
+    "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+)
 
 handler = None
 
@@ -16,8 +25,8 @@ def build_logger(logger_name, logger_filename):
     global handler
 
     formatter = logging.Formatter(
-        fmt=…
-        datefmt=…
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
     )
 
     # Set the format of root handlers
@@ -26,12 +35,12 @@ def build_logger(logger_name, logger_filename):
     logging.getLogger().handlers[0].setFormatter(formatter)
 
     # Redirect stdout and stderr to loggers
-    stdout_logger = logging.getLogger(…
+    stdout_logger = logging.getLogger("stdout")
     stdout_logger.setLevel(logging.INFO)
     sl = StreamToLogger(stdout_logger, logging.INFO)
     sys.stdout = sl
 
-    stderr_logger = logging.getLogger(…
+    stderr_logger = logging.getLogger("stderr")
     stderr_logger.setLevel(logging.ERROR)
     sl = StreamToLogger(stderr_logger, logging.ERROR)
     sys.stderr = sl
@@ -45,7 +54,8 @@ def build_logger(logger_name, logger_filename):
         os.makedirs(LOGDIR, exist_ok=True)
         filename = os.path.join(LOGDIR, logger_filename)
         handler = logging.handlers.TimedRotatingFileHandler(
-            filename, when=…
+            filename, when="D", utc=True
+        )
         handler.setFormatter(formatter)
 
         for name, item in logging.root.manager.loggerDict.items():
@@ -59,33 +69,34 @@ class StreamToLogger(object):
     """
     Fake file-like stream object that redirects writes to a logger instance.
     """
+
     def __init__(self, logger, log_level=logging.INFO):
         self.terminal = sys.stdout
         self.logger = logger
         self.log_level = log_level
-        self.linebuf = …
+        self.linebuf = ""
 
     def __getattr__(self, attr):
         return getattr(self.terminal, attr)
 
     def write(self, buf):
         temp_linebuf = self.linebuf + buf
-        self.linebuf = …
+        self.linebuf = ""
         for line in temp_linebuf.splitlines(True):
             # From the io.TextIOWrapper docs:
             #   On output, if newline is None, any '\n' characters written
             #   are translated to the system default line separator.
             # By default sys.stdout.write() expects '\n' newlines and then
             # translates them so this is still cross platform.
-            if line[-1] == …
+            if line[-1] == "\n":
                 self.logger.log(self.log_level, line.rstrip())
             else:
                 self.linebuf += line
 
     def flush(self):
-        if self.linebuf != …
+        if self.linebuf != "":
             self.logger.log(self.log_level, self.linebuf.rstrip())
-        self.linebuf = …
+        self.linebuf = ""
 
 
 def disable_torch_init():
@@ -93,23 +104,26 @@ def disable_torch_init():
     Disable the redundant torch default initialization to accelerate model creation.
     """
     import torch
-    …
-    setattr(torch.nn.…
+
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
 
 def violates_moderation(text):
     """
     Check whether the text violates OpenAI moderation API.
     """
-    url = …
-    headers = {…
-    …
-    …
-    …
-    …
+    url = "https://api.openai.com/v1/moderations"
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
+    }
+    text = text.replace("\n", "")
+    data = "{" + '"input": ' + f'"{text}"' + "}"
+    data = data.encode("utf-8")
     try:
         ret = requests.post(url, headers=headers, data=data, timeout=5)
-        flagged = ret.json()[…
+        flagged = ret.json()["results"][0]["flagged"]
     except requests.exceptions.RequestException as e:
         flagged = False
     except KeyError as e:
@@ -120,5 +134,30 @@ def violates_moderation(text):
 
 def pretty_print_semaphore(semaphore):
     if semaphore is None:
-        return …
+        return "None"
-    return f…
+    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def get_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
+
+
+def data_wrapper(data):
+    if isinstance(data, bytes):
+        return data
+    elif isinstance(data, Image.Image):
+        buffered = BytesIO()
+        data.save(buffered, format="PNG")
+        return buffered.getvalue()
+    elif isinstance(data, str):
+        return data.encode()
+    elif isinstance(data, Dict):
+        return json.dumps(data).encode()
+    else:
+        raise ValueError(f"Unsupported data type: {type(data)}")
			
