IAozora commited on
Commit
4ee8f13
1 Parent(s): 421c1e9

Initial Demo

Browse files
.dockerignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ *.json
3
+ .venv/
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ .env
3
+ *.json
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python base image; the app serves a Gradio UI on port 7860.
FROM python:3.11-slim

WORKDIR /usr/src/app

# Copy the whole build context, then install Python dependencies.
COPY . .
RUN pip install --no-cache-dir -r requirements.txt
EXPOSE 7860
# Bind Gradio to all interfaces so the container port is reachable from outside.
ENV GRADIO_SERVER_NAME="0.0.0.0"

# Get service account key from HF Spaces' secrets
# https://huggingface.co/docs/hub/spaces-sdks-docker#buildtime
# NOTE(review): writing a secret into the filesystem during a RUN step persists
# it in the image layers — confirm this Space's image is private before relying
# on this pattern.
RUN --mount=type=secret,id=GCLOUD_SA_JSON,mode=0444,required=true \
    cat /run/secrets/GCLOUD_SA_JSON > /usr/src/app/credentials.json

# Second service-account key, read by download_gcs_object.py at build time.
RUN --mount=type=secret,id=GCLOUD_OBJECT_SA_JSON,mode=0444,required=true \
    cat /run/secrets/GCLOUD_OBJECT_SA_JSON > /usr/src/app/credentials_object.json

# Fetch an object (presumably a CLI tool archive — name kept secret) from GCS.
RUN --mount=type=secret,id=BUCKET_NAME,mode=0444,required=true \
    --mount=type=secret,id=CLI_OBJECT_NAME,mode=0444,required=true \
    python download_gcs_object.py --bucket-name $(cat /run/secrets/BUCKET_NAME) --object-name $(cat /run/secrets/CLI_OBJECT_NAME)

# Unpack the downloaded archive into the working directory.
RUN --mount=type=secret,id=CLI_OBJECT_NAME,mode=0444,required=true \
    tar -xf $(cat /run/secrets/CLI_OBJECT_NAME)

CMD ["python", "app.py"]
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: Llava Calm2 Preview
3
- emoji: 🏃
4
- colorFrom: pink
5
- colorTo: green
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: llava-calm2-preview
3
+ emoji: 🎨
4
+ colorFrom: purple
5
+ colorTo: yellow
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # HF Gradio Spaces for llava-calm2
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ import subprocess
5
+ from functools import partial
6
+
7
+ import gradio as gr
8
+ import httpx
9
+ from const import BASE_URL, CLI_COMMAND, CSS, FOOTER, HEADER, MODELS, PLACEHOLDER
10
+ from openai import OpenAI
11
+ from PIL import Image
12
+
13
+
14
def get_token() -> str:
    """Fetch an access token by running the configured external CLI.

    Executes ``CLI_COMMAND`` (assembled from environment variables in
    const.py) with the current environment, discards stderr, and returns
    the captured stdout stripped of surrounding whitespace.
    """
    completed = subprocess.run(
        CLI_COMMAND,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        env=os.environ.copy(),
    )
    token_bytes = completed.stdout
    return token_bytes.decode("utf-8").strip()
25
+
26
+
27
def get_headers(host: str) -> dict:
    """Build the HTTP headers for a JSON request to the given backend host.

    A fresh bearer token is obtained via ``get_token`` on every call.
    """
    token = get_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Host": host,
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    return headers
34
+
35
+
36
def proxy(request: httpx.Request, model_info: dict) -> httpx.Request:
    """Rewrite an outgoing request to target the selected model backend.

    Installed as an httpx request event hook: replaces the URL path with
    the model's endpoint and injects authorization/Host headers.
    """
    routed_url = request.url.copy_with(path=model_info["endpoint"])
    request.url = routed_url
    auth_headers = get_headers(host=model_info["host"])
    request.headers.update(auth_headers)
    return request
40
+
41
+
42
def encode_image_with_pillow(image_path: str) -> str:
    """Load an image, shrink it to fit within 384x384, and return it as a
    base64-encoded JPEG string (suitable for a ``data:`` URL)."""
    with Image.open(image_path) as img:
        img.thumbnail((384, 384))
        jpeg_buffer = io.BytesIO()
        img.convert("RGB").save(jpeg_buffer, format="JPEG")
        encoded = base64.b64encode(jpeg_buffer.getvalue())
        return encoded.decode("utf-8")
48
+
49
+
50
def call_chat_api(message, history, model_name):
    """Stream a chat completion for a multimodal (image + text) conversation.

    Args:
        message: Value of the gr.MultimodalTextbox — a dict with "text" and
            "files" keys.
        history: Gradio chat history; turns that carried an image appear as
            tuples in the user slot.
        model_name: Key into MODELS selecting the backend endpoint/host.

    Yields:
        The accumulated assistant response after each streamed chunk.

    Raises:
        gr.Error: If no image is attached to the message and none can be
            found in the history.
    """
    # Locate the image: prefer one attached to the current message, otherwise
    # fall back to the most recent image found in the history.
    image = None
    if message["files"]:
        if isinstance(message["files"], dict):
            image = message["files"]["path"]
        else:
            image = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        # BUGFIX: previously `image` was left unbound here and the next line
        # raised a confusing NameError; fail with a clear, user-facing error.
        raise gr.Error("Please upload an image before sending a message.")

    img_base64 = encode_image_with_pillow(image)

    # The first user message always carries the image as a data URL.
    history_openai_format = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_base64}",
                    },
                },
            ],
        }
    ]

    if len(history) == 0:
        history_openai_format[0]["content"].append(
            {"type": "text", "text": message["text"]}
        )
    else:
        # Replay prior turns, skipping history[0] (the image-upload turn);
        # the first text turn is merged into the image-bearing message.
        for human, assistant in history[1:]:
            if len(history_openai_format) == 1:
                history_openai_format[0]["content"].append(
                    {"type": "text", "text": human}
                )
            else:
                history_openai_format.append({"role": "user", "content": human})
            history_openai_format.append({"role": "assistant", "content": assistant})
        history_openai_format.append({"role": "user", "content": message["text"]})

    # The real auth header is injected per-request by the `proxy` event hook,
    # so the OpenAI client itself carries an empty API key.
    client = OpenAI(
        api_key="",
        base_url=BASE_URL,
        http_client=httpx.Client(
            event_hooks={
                "request": [partial(proxy, model_info=MODELS[model_name])],
            },
            # NOTE(security): TLS verification is disabled — acceptable only
            # for a trusted internal backend; confirm before wider deployment.
            verify=False,
        ),
    )

    stream = client.chat.completions.create(
        model=f"/data/cyberagent/{model_name}",
        messages=history_openai_format,
        temperature=0.2,
        top_p=1.0,
        max_tokens=1024,
        stream=True,
        extra_body={"repetition_penalty": 1.1},
    )

    # Accumulate the streamed deltas and yield the growing response so the UI
    # updates incrementally.
    response = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content or ""
        response += delta
        yield response
118
+
119
+
120
def run():
    """Assemble the Gradio demo UI and launch the server."""
    # Widgets created up front and wired into the ChatInterface below.
    chat_display = gr.Chatbot(
        elem_id="chatbot", placeholder=PLACEHOLDER, scale=1, height=700
    )
    input_box = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    example_prompts = [
        [
            {
                "text": "この画像を詳しく説明してください。",
                "files": ["./examples/cat.jpg"],
            },
        ],
        [
            {
                "text": "この料理はどんな味がするか詳しく教えてください。",
                "files": ["./examples/takoyaki.jpg"],
            },
        ],
    ]
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=MODELS.keys(),
                value=list(MODELS.keys())[0],
                label="Model",
            )
        gr.ChatInterface(
            fn=call_chat_api,
            stop_btn="Stop Generation",
            examples=example_prompts,
            multimodal=True,
            textbox=input_box,
            chatbot=chat_display,
            additional_inputs=[model_selector],
        )
        gr.Markdown(FOOTER)
    demo.queue().launch(share=False)


if __name__ == "__main__":
    run()
const.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from pathlib import Path

# All configuration comes from environment variables (injected as secrets in
# the HF Spaces runtime); every value falls back to "" when unset.
SECRET_PREFIX = os.environ.get("SECRET_PREFIX", "")
PROJECT_ID = os.environ.get("PROJECT_ID", "")
ROLE_SUBJECT = os.environ.get("ROLE_SUBJECT", "")
CREDENTIALS = os.environ.get("CREDENTIALS", "")
# Mirror the values under SECRET_PREFIX-prefixed names so a subprocess can read
# them — NOTE(review): presumably consumed by the auth CLI invoked in app.py;
# confirm against the CLI's documentation.
os.environ[SECRET_PREFIX + "PROJECT_ID"] = PROJECT_ID
os.environ[SECRET_PREFIX + "ROLE_SUBJECT"] = ROLE_SUBJECT
os.environ[SECRET_PREFIX + "CREDENTIALS"] = CREDENTIALS

# Backend routing: shared base URL plus per-model endpoint paths/Host headers.
BASE_URL = os.environ.get("GCLOUD_BASE_URL", "")
BASE_ENDPOINT = os.environ.get("GCLOUD_ENDPOINT", "")
CHATTY_ENDPOINT = os.environ.get("GCLOUD_CHATTY_ENDPOINT", "")
BASE_HOST = os.environ.get("GCLOUD_HOST", "")
CHATTY_HOST = os.environ.get("GCLOUD_CHATTY_HOST", "")
# Token-fetching CLI: binary name and two arguments, resolved relative to this
# file's directory (the binary is unpacked there at image build time).
CLI_COMMAND_NAME = os.environ.get("CLI_COMMAND_NAME", "")
CLI_ARG1 = os.environ.get("CLI_ARG1", "")
CLI_ARG2 = os.environ.get("CLI_ARG2", "")
ROOT_DIR = Path(__file__).parent.absolute()
GCLOUD_BIN = str(ROOT_DIR / CLI_COMMAND_NAME)
CLI_COMMAND = [GCLOUD_BIN, CLI_ARG1, CLI_ARG2]

# Model name -> backend endpoint/host, used by app.proxy to route requests.
MODELS = {
    "llava-calm2-siglip-chatty": {"endpoint": CHATTY_ENDPOINT, "host": CHATTY_HOST},
    "llava-calm2-siglip": {"endpoint": BASE_ENDPOINT, "host": BASE_HOST},
}

# Markdown banner shown above the chat UI.
HEADER = """
# LLaVA-CALM2-SigLIP
LLaVA-CALM2-SigLIPは、calm2-7b-chatとsiglip-so400m-patch14-384からファインチューニングされたLLaVAモデルです。
## Models
- **llava-calm2-siglip**: 公開データを用いて学習されたVLM
- **llava-calm2-siglip-chatty**: よりチャットに最適化するように学習したVLM
"""

# Disclaimer and license markdown shown below the chat UI.
FOOTER = """
## Term of Use
Please note that by using this service, you agree to the following terms: This model is provided for research purposes only. CyberAgent expressly disclaim any liability for direct, indirect, special, incidental, or consequential damages, as well as for any losses that may result from using this model, regardless of the outcomes. It is essential for users to fully understand these limitations before employing the model.

## License
The service is a research preview intended for non-commercial use only.
"""

# HTML placeholder rendered inside the empty chatbot panel.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://d23iyfk1a359di.cloudfront.net/files/topics/26317_ext_03_0.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-CALM2-SigLIP</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">LLaVA-CALM2-SigLIP is a LLaVA model fine-tuned from calm2-7b-chat and siglip-so400m-patch14-384</p>
</div>
"""

# Custom CSS for the chatbot element.
# BUGFIX: the property was written "max_height" (underscore), which is not a
# valid CSS property and was silently ignored by browsers; it is "max-height".
CSS = """
#chatbot {
    height: auto !important;
    max-height: none !important;
    overflow: auto !important;
    flex-grow: 1 !important;
}
"""
download_gcs_object.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ from google.cloud import storage
5
+ from google.oauth2 import service_account
6
+
7
+
8
def download_gcs_object(
    bucket_name: str,
    object_name: str,
    credentials_path: str = "/usr/src/app/credentials_object.json",
) -> None:
    """Download one object from a GCS bucket into the current directory.

    Args:
        bucket_name: Name of the source GCS bucket.
        object_name: Object key to download; also used as the local filename.
        credentials_path: Path to a service-account JSON key file. Defaults to
            the location where the Dockerfile writes the build-time secret,
            keeping existing callers unchanged.
    """
    # Load the service-account key and build credentials from its contents.
    with open(credentials_path, "r") as f:
        credentials_dict = json.load(f)
    credentials = service_account.Credentials.from_service_account_info(
        credentials_dict
    )
    # The key file itself names the owning project.
    client = storage.Client(
        credentials=credentials, project=credentials_dict["project_id"]
    )
    blob = client.bucket(bucket_name).blob(object_name)
    blob.download_to_filename(object_name)
19
+
20
+
21
if __name__ == "__main__":
    # CLI entry point: both flags are mandatory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--bucket-name", type=str, required=True)
    arg_parser.add_argument("--object-name", type=str, required=True)
    cli_args = arg_parser.parse_args()
    download_gcs_object(cli_args.bucket_name, cli_args.object_name)
examples/LICENSE.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ## Licenses
2
+
3
+
4
+ | title | filename | source_url | license |
5
+ |------------------------------------------------|---------------|-------------------------------------------------------|---------------|
6
+ | \[フリー写真\] 窓の外を見ている猫の横顔 | cat.jpg | https://publicdomainq.net/cat-watch-animal-photo-0079232/ | public domain |
7
+ | \[フリー写真\] 食べ物のプレートが乗った木製のテーブル | takoyaki.jpg | https://unsplash.com/photos/LipkIP4fXbM | unsplash |
8
+
examples/cat.jpg ADDED
examples/takoyaki.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ httpx
3
+ openai
4
+ pillow
5
+ google-cloud-storage
6
+ google-auth