Text Generation
Transformers
PyTorch
bloom
text-generation-inference
mrm8488 committed
Commit 2875eb9
1 Parent(s): 92b2ebf

Update README.md

Files changed (1)
  1. README.md +229 -1
README.md CHANGED
@@ -1,3 +1,231 @@
  ---
- license: wtfpl
+ inference: false
+ license: bigscience-bloom-rail-1.0
+ language:
+ - ak
+ - ar
+ - as
+ - bm
+ - bn
+ - ca
+ - en
+ - es
+ - eu
+ - fon
+ - fr
+ - gu
+ - hi
+ - id
+ - ig
+ - ki
+ - kn
+ - lg
+ - ln
+ - ml
+ - mr
+ - ne
+ - nso
+ - ny
+ - or
+ - pa
+ - pt
+ - rn
+ - rw
+ - sn
+ - st
+ - sw
+ - ta
+ - te
+ - tn
+ - ts
+ - tum
+ - tw
+ - ur
+ - vi
+ - wo
+ - xh
+ - yo
+ - zh
+ - zu
+ pipeline_tag: text-generation
  ---
+ ### Quantized bigscience/bloom 6B3 with 8-bit weights
+
+ Heavily inspired by [Hivemind's GPT-J-6B with 8-bit weights](https://huggingface.co/hivemind/gpt-j-6B-8bit), this is a version of [bigscience/bloom-6b3](https://huggingface.co/bigscience/bloom-6b3), a ~6-billion-parameter language model that you can run and fine-tune with less memory.
+
+ Here, we also apply [LoRA (Low Rank Adaptation)](https://arxiv.org/abs/2106.09685) to reduce model size. The original version takes \~353GB of memory; this version takes **\~180GB**.
+
+ Our main goal is to generate a model compressed enough to be deployed in a traditional Kubernetes cluster.
+
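+ To make the "8-bit weights" idea concrete, here is a minimal sketch of the dynamic blockwise quantization this model relies on, using the same `quantize_blockwise` / `dequantize_blockwise` helpers from `bitsandbytes` that appear in the code below; the tensor shape and the printed check are illustrative only:
+
+ ```python
+ import torch
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
+ # A toy fp32 matrix stands in for a real Bloom weight.
+ weight = torch.randn(1024, 1024)
+
+ # Quantize to uint8 in blocks; each block keeps its own absmax scale, plus a shared codebook.
+ weight_int8, (absmax, code) = quantize_blockwise(weight.flatten())
+
+ # Dequantize back to fp32 only when the weight is actually needed (e.g. for a matmul).
+ weight_restored = dequantize_blockwise(weight_int8, absmax=absmax, code=code).view_as(weight)
+
+ # The round-trip error stays small relative to the weight scale.
+ print((weight - weight_restored).abs().max())
+ ```
+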
+ ### How to fine-tune
+
+ In this [notebook](https://nbviewer.org/urls/huggingface.co/joaoalvarenga/bloom-8bit/raw/main/fine-tuning-example.ipynb) you can find an adaptation of [Hivemind's GPT-J 8-bit fine-tuning notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es) to fine-tune Bloom 8-bit on a 3x NVIDIA A100 instance.
+
+ ### How to use
+
+ This model can be used by adapting Bloom's original implementation. The code below is an adaptation of [Hivemind's GPT-J 8-bit](https://nbviewer.org/urls/huggingface.co/hivemind/gpt-j-6B-8bit/raw/main/convert-gpt-j.ipynb):
+
+ ```python
+ import transformers
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+ from typing import Tuple
+ from torch.cuda.amp import custom_fwd, custom_bwd
+ from transformers import BloomTokenizerFast  # tokenizer class used at the end of this snippet
+
+
+ class FrozenBNBLinear(nn.Module):
+     """Linear layer whose weight is stored blockwise-quantized to 8 bits and kept frozen."""
+
+     def __init__(self, weight, absmax, code, bias=None):
+         assert isinstance(bias, nn.Parameter) or bias is None
+         super().__init__()
+         self.out_features, self.in_features = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+         self.bias = bias
+
+     def forward(self, input):
+         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+         weights_int8, state = quantize_blockise_lowmemory(linear.weight)
+         return cls(weights_int8, *state, linear.bias)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+ class DequantizeAndLinear(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         ctx.save_for_backward(input, weights_quantized, absmax, code)
+         ctx._has_bias = bias is not None
+         return F.linear(input, weights_deq, bias)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output: torch.Tensor):
+         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+         input, weights_quantized, absmax, code = ctx.saved_tensors
+         # grad_output: [*batch, out_features]
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         grad_input = grad_output @ weights_deq
+         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+         return grad_input, None, None, None, grad_bias
+
+
+ class FrozenBNBEmbedding(nn.Module):
+     def __init__(self, weight, absmax, code):
+         super().__init__()
+         self.num_embeddings, self.embedding_dim = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+
+     def forward(self, input, **kwargs):
+         with torch.no_grad():
+             # note: both the quantized weights and the input indices are *not* differentiable
+             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+             output = F.embedding(input, weight_deq, **kwargs)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+         weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
+         return cls(weights_int8, *state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
+
+ def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+     assert chunk_size % 4096 == 0
+     code = None
+     chunks = []
+     absmaxes = []
+     flat_tensor = matrix.view(-1)
+     for i in range((matrix.numel() - 1) // chunk_size + 1):
+         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+         chunks.append(quantized_chunk)
+         absmaxes.append(absmax_chunk)
+
+     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+     absmax = torch.cat(absmaxes)
+     return matrix_i8, (absmax, code)
+
+
+ def convert_to_int8(model):
+     """Convert linear and embedding modules to 8-bit with optional adapters"""
+     for module in list(model.modules()):
+         for name, child in module.named_children():
+             if isinstance(child, nn.Linear):
+                 print(name, child)
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBLinear(
+                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                         bias=child.bias,
+                     ),
+                 )
+             elif isinstance(child, nn.Embedding):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBEmbedding(
+                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                     )
+                 )
+
+ class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):
+     def __init__(self, config, layer_number=None):
+         super().__init__(config, layer_number)
+
+         convert_to_int8(self.self_attention)
+         convert_to_int8(self.mlp)
+
+
+ class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+ # Patch BloomBlock so that from_pretrained builds the 8-bit modules before loading the checkpoint
+ transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock
+
+ model_name = 'mrm8488/bloom-6b3-8bit'
+ model = BloomForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
+ tokenizer = BloomTokenizerFast.from_pretrained(model_name)
+
+ prompt = tokenizer("Given a table named salaries and columns id, created_at, salary, age. Creates a SQL to answer What is the average salary for 22 years old:", return_tensors='pt')
+ out = model.generate(**prompt, min_length=10, do_sample=True)
+ tokenizer.decode(out[0])
+ ```
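+
+ For fine-tuning (see the notebook linked above), small trainable low-rank adapters can be attached to the frozen 8-bit layers, which is what the `adapter` slots in `FrozenBNBLinear` and `FrozenBNBEmbedding` are for. Below is a minimal sketch of that idea, assuming the classes defined above have already been run; the adapter rank, initialization and helper name are illustrative choices, not the notebook's exact code:
+
+ ```python
+ import torch.nn as nn
+
+ def add_adapters(model, adapter_dim: int = 16):
+     """Attach a trainable LoRA-style adapter to every frozen 8-bit linear layer."""
+     for module in model.modules():
+         if isinstance(module, FrozenBNBLinear):
+             adapter = nn.Sequential(
+                 nn.Linear(module.in_features, adapter_dim, bias=False),
+                 nn.Linear(adapter_dim, module.out_features, bias=False),
+             )
+             # Zero-init the up-projection so the adapted model starts out identical to the frozen one.
+             nn.init.zeros_(adapter[1].weight)
+             module.adapter = adapter
+
+ # add_adapters(model)  # afterwards, only the adapter parameters need gradients and optimizer state
+ ```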