add reference code from vllm
Files changed:
- .gitignore +1 -0
- app.py +34 -74
.gitignore (ADDED)
@@ -0,0 +1 @@
+notes.py
app.py (CHANGED)
@@ -14,7 +14,7 @@ import spaces
 import math
 from typing import List, Optional, Tuple
 
-title = "# 🙋🏻♂️Welcome to Tonic's Pixtral Model Demo"
+title = "# **WIP / DEMO** 🙋🏻♂️Welcome to Tonic's Pixtral Model Demo"
 description = """
 This demo showcases two capabilities of the Pixtral model:
 1. Image-to-Text Generation
@@ -25,6 +25,7 @@ This demo showcases two capabilities of the Pixtral model:
 """
 
 model_path = snapshot_download(repo_id="mistralai/Pixtral-12B-2409")
+
 with open(f'{model_path}/params.json', 'r') as f:
     params = json.load(f)
 
@@ -40,32 +41,16 @@ class RMSNorm(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
 
-def precompute_freqs_cis_2d(
-    dim: int,
-    height: int,
-    width: int,
-    theta: float,
-) -> torch.Tensor:
+def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
     freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
-    h = torch.arange(height
-    w = torch.arange(width
-
+    h = torch.arange(height)
+    w = torch.arange(width)
     freqs_h = torch.outer(h, freqs[::2]).float()
     freqs_w = torch.outer(w, freqs[1::2]).float()
-    freqs_2d = torch.cat(
-        [
-            freqs_h[:, None, :].repeat(1, width, 1),
-            freqs_w[None, :, :].repeat(height, 1, 1),
-        ],
-        dim=-1,
-    )
+    freqs_2d = torch.cat([freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1)], dim=-1)
     return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
 
-def apply_rotary_emb_vit(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def apply_rotary_emb_vit(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = freqs_cis.view(*freqs_cis.shape[:2], 1, freqs_cis.shape[-1])
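Note on the collapsed helper: precompute_freqs_cis_2d assigns the even-indexed frequencies to the row coordinate and the odd-indexed ones to the column coordinate, so the result is a complex tensor of shape (height, width, dim // 2). A minimal, self-contained shape check, assuming the body exactly as shown in this hunk (the toy values dim=64, height=4, width=6 are illustrative, not values read from params.json):

import torch

def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
    # Copied from the hunk above: even-indexed freqs rotate along height, odd-indexed along width.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    h = torch.arange(height)
    w = torch.arange(width)
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat([freqs_h[:, None, :].repeat(1, width, 1),
                          freqs_w[None, :, :].repeat(height, 1, 1)], dim=-1)
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)

# Toy configuration: a 64-dim head gives 32 complex rotary pairs per (row, col) patch position.
freqs_cis = precompute_freqs_cis_2d(dim=64, height=4, width=6, theta=10000.0)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4, 6, 32]) torch.complex64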
@@ -78,7 +63,6 @@ class Attention(nn.Module):
         super().__init__()
         self.n_heads = args['num_attention_heads']
         self.head_dim = args['hidden_size'] // args['num_attention_heads']
-
         self.wq = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
         self.wk = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
         self.wv = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
@@ -86,14 +70,11 @@ class Attention(nn.Module):
 
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         batch, patches, _ = x.shape
-
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
         q = q.reshape(batch, patches, self.n_heads, self.head_dim)
         k = k.reshape(batch, patches, self.n_heads, self.head_dim)
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
-
         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-
         scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
         attn = F.softmax(scores, dim=-1)
         out = torch.matmul(attn, v)
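Reviewer observation on the unchanged context lines: q, k, and v stay in (batch, patches, heads, head_dim) layout, so torch.matmul(q, k.transpose(-1, -2)) yields (batch, patches, heads, heads) scores, i.e. heads attend to each other within a patch rather than patches attending to each other per head. The conventional ordering moves the head axis ahead of the patch axis first; a standalone sketch of that pattern (illustrative only, not code from this commit):

import math
import torch
import torch.nn.functional as F

def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, patches, heads, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))          # -> (batch, heads, patches, head_dim)
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))  # -> (batch, heads, patches, patches)
    out = F.softmax(scores, dim=-1) @ v                       # -> (batch, heads, patches, head_dim)
    return out.transpose(1, 2)                                # back to (batch, patches, heads, head_dim)

q = k = v = torch.randn(2, 9, 4, 16)
print(attention_reference(q, k, v).shape)  # torch.Size([2, 9, 4, 16])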
@@ -119,9 +100,9 @@ class TransformerBlock(nn.Module):
         self.ffn_norm = RMSNorm(args['hidden_size'], eps=1e-5)
 
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        r = self.attention
+        r = self.attention(self.attention_norm(x), freqs_cis=freqs_cis)
         h = x + r
-        r = self.feed_forward
+        r = self.feed_forward(self.ffn_norm(h))
         out = h + r
         return out
 
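The corrected forward is the usual pre-norm residual block: normalize, transform, and add the result back to the residual stream, once for attention and once for the feed-forward. A tiny sketch of the pattern with stand-in modules (LayerNorm and an Identity "attention" are placeholders here, not the RMSNorm and attention used in app.py):

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)   # stand-in for RMSNorm
        self.ffn_norm = nn.LayerNorm(dim)
        self.attention = nn.Identity()            # stand-in for self-attention
        self.feed_forward = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.attention(self.attention_norm(x))   # residual around (norm -> attention)
        return h + self.feed_forward(self.ffn_norm(h))   # residual around (norm -> MLP)

print(PreNormBlock(32)(torch.randn(2, 9, 32)).shape)  # torch.Size([2, 9, 32])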
@@ -129,16 +110,9 @@ class VisionTransformer(nn.Module):
     def __init__(self, args):
         super().__init__()
         self.args = args
-        self.patch_conv = nn.Conv2d(
-            in_channels=args['num_channels'],
-            out_channels=args['hidden_size'],
-            kernel_size=args['patch_size'],
-            stride=args['patch_size'],
-            bias=False,
-        )
+        self.patch_conv = nn.Conv2d(args['num_channels'], args['hidden_size'], kernel_size=args['patch_size'], stride=args['patch_size'], bias=False)
         self.ln_pre = RMSNorm(args['hidden_size'], eps=1e-5)
         self.transformer = nn.ModuleList([TransformerBlock(args) for _ in range(args['num_hidden_layers'])])
-
         self.max_patches_per_side = args['image_size'] // args['patch_size']
         self._freqs_cis = None
 
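A Conv2d whose kernel_size equals its stride is the standard ViT patchify step, and the flatten(2).transpose(1, 2) seen in the forward pass below turns the resulting feature map into a patch sequence. A quick shape check with illustrative sizes (a 3-channel 224x224 input, patch size 16, and hidden size 1024 are example values, not ones read from params.json):

import torch
import torch.nn as nn

patch_conv = nn.Conv2d(3, 1024, kernel_size=16, stride=16, bias=False)

x = torch.randn(1, 3, 224, 224)    # (batch, channels, H, W)
x = patch_conv(x)                  # -> (1, 1024, 14, 14): one embedding per non-overlapping 16x16 patch
x = x.flatten(2).transpose(1, 2)   # -> (1, 196, 1024): sequence of 14*14 patch embeddings
print(x.shape)                     # torch.Size([1, 196, 1024])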
@@ -157,11 +131,9 @@ class VisionTransformer(nn.Module):
         x = self.patch_conv(x)
         x = x.flatten(2).transpose(1, 2)
         x = self.ln_pre(x)
-
         freqs_cis = self.freqs_cis
         for layer in self.transformer:
             x = layer(x, freqs_cis=freqs_cis)
-
         return x
 
 class VisionLanguageAdapter(nn.Module):
@@ -180,9 +152,7 @@ class PixtralModel(nn.Module):
         self.vision_encoder = VisionTransformer(params['vision_encoder'])
         self.vision_language_adapter = VisionLanguageAdapter(params['vision_encoder'], params['dim'])
         self.language_model = nn.TransformerDecoder(
-            nn.TransformerDecoderLayer(d_model=params['dim'],
-                                       nhead=params['n_heads'],
-                                       dim_feedforward=params['hidden_dim']),
+            nn.TransformerDecoderLayer(d_model=params['dim'], nhead=params['n_heads'], dim_feedforward=params['hidden_dim']),
             num_layers=params['n_layers']
         )
         self.lm_head = nn.Linear(params['dim'], params['vocab_size'], bias=False)
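nn.TransformerDecoder stacks clones of the single nn.TransformerDecoderLayer passed to it and is called with a target sequence plus an encoder memory; since batch_first is not set here, both default to (seq_len, batch, d_model). A small self-contained sketch with toy dimensions (d_model=64, nhead=4, num_layers=2 are illustrative, not the Pixtral-12B parameters):

import torch
import torch.nn as nn

decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=64, nhead=4, dim_feedforward=256),
    num_layers=2,
)

tgt = torch.randn(10, 2, 64)       # (target_len, batch, d_model) -- batch_first defaults to False
memory = torch.randn(7, 2, 64)     # (source_len, batch, d_model), e.g. adapted image embeddings
print(decoder(tgt, memory).shape)  # torch.Size([10, 2, 64])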
@@ -201,12 +171,10 @@ class PixtralModel(nn.Module):
 
 def load_model(params, model_path):
     model = PixtralModel(params)
-
     with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
         for name, param in model.named_parameters():
             if name in f.keys():
                 param.data = f.get_tensor(name)
-
     model.eval()
     return model
 
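The loader above silently skips any model parameter whose name is not present in the checkpoint. A hedged variant that also reports the unmatched names can surface weight-mapping problems early; safe_open, keys(), and get_tensor() are the actual safetensors calls used in the hunk, while the missing-name report and the helper itself (load_weights_reporting) are an added suggestion:

import torch
from safetensors import safe_open

def load_weights_reporting(model: torch.nn.Module, checkpoint_path: str) -> None:
    """Copy matching tensors from a .safetensors file and list parameters left untouched."""
    missing = []
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        available = set(f.keys())
        for name, param in model.named_parameters():
            if name in available:
                param.data = f.get_tensor(name)
            else:
                missing.append(name)
    if missing:
        print(f"{len(missing)} parameters not found in checkpoint, e.g. {missing[:5]}")

# usage sketch: load_weights_reporting(model, f'{model_path}/consolidated.safetensors')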
@@ -224,53 +192,45 @@ def preprocess_image(image):
 @spaces.GPU(duration=120)
 def generate_text(image, prompt, max_tokens):
     try:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        image_tensor = preprocess_image(image).to(device)
+        model.to(device)
 
         tokenized = tokenizer.encode_chat_completion(
             ChatCompletionRequest(
-                messages=[
-                    UserMessage(
-                        content=[
-                            TextChunk(text=prompt),
-                            ImageChunk(image=image),
-                        ]
-                    )
-                ],
+                messages=[UserMessage(content=[TextChunk(text=prompt), ImageChunk(image=image)])],
                 model="pixtral",
             )
         )
-        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).
-        model
-        if next_token.item() == tokenizer.eos_token_id:
-            break
-        model.cpu()
+        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).to(device)
+
+        for _ in range(max_tokens):
+            logits = model(image_tensor, input_ids)
+            next_token_logits = logits[0, -1, :]
+            next_token = torch.argmax(next_token_logits, dim=-1)
+            input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
+            if next_token.item() == tokenizer.eos_token_id:
+                break
 
         generated_text = tokenizer.decode(input_ids[0].tolist())
+        # model.to("cpu")
+        return generated_text, len(input_ids[0]), 1
     except Exception as e:
         return f"Error: {str(e)}", 0, 0
 
 @spaces.GPU(duration=60)
 def calculate_similarity(image1, image2):
     try:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        tensor1 = preprocess_image(image1).to(device)
+        tensor2 = preprocess_image(image2).to(device)
+        model.to(device)
 
-        embedding1 = model(tensor1).mean(dim=1) # Average over spatial dimensions
-        embedding2 = model(tensor2).mean(dim=1)
-        model.cpu()
+        embedding1 = model(tensor1).mean(dim=1)
+        embedding2 = model(tensor2).mean(dim=1)
 
         similarity = F.cosine_similarity(embedding1, embedding2).item()
+        # model.to("cpu")
         return similarity
     except Exception as e:
         return f"Error: {str(e)}"
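The new generation loop is plain greedy decoding: take the argmax of the logits at the last position, append it to the input ids, and stop on EOS or after max_tokens steps. A self-contained sketch of the same loop against a stand-in model (DummyLM, its vocabulary size, and eos_id are invented for illustration; the real loop also passes image_tensor to the model):

import torch
import torch.nn as nn

class DummyLM(nn.Module):
    """Stand-in language model: returns random logits of shape (batch, seq_len, vocab)."""
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return torch.randn(input_ids.shape[0], input_ids.shape[1], self.vocab_size)

def greedy_decode(model: nn.Module, input_ids: torch.Tensor, max_tokens: int, eos_id: int) -> torch.Tensor:
    for _ in range(max_tokens):
        logits = model(input_ids)                    # (1, seq_len, vocab)
        next_token = torch.argmax(logits[0, -1, :])  # most likely token at the last position
        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
        if next_token.item() == eos_id:
            break
    return input_ids

out = greedy_decode(DummyLM(), torch.tensor([[1, 2, 3]]), max_tokens=10, eos_id=0)
print(out.shape)  # (1, n) with n <= 13, depending on when the toy EOS id appears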
@@ -299,7 +259,7 @@ with gr.Blocks() as demo:
         with gr.Column():
             input_image = gr.Image(type="pil", label="Input Image")
             input_prompt = gr.Textbox(label="Prompt")
-            max_tokens_slider = gr.Slider(minimum=
+            max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Tokens")
             submit_btn = gr.Button("Generate Text")
 
         with gr.Column():