Simplify usage
Browse files
README.md
CHANGED
|
@@ -1,5 +1,9 @@
|
|
| 1 |
---
|
| 2 |
pipeline_tag: text-generation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
| 4 |
|
| 5 |
## Usage
|
|
@@ -41,8 +45,7 @@ past_key_values = {
|
|
| 41 |
for kv in ('key', 'value')
|
| 42 |
}
|
| 43 |
input_ids = inputs['input_ids']
|
| 44 |
-
|
| 45 |
-
position_ids = np.cumsum(inputs['attention_mask'], axis=-1)
|
| 46 |
|
| 47 |
# 3. Generation loop
|
| 48 |
max_new_tokens = 1024
|
|
@@ -50,14 +53,12 @@ generated_tokens = np.array([[]], dtype=np.int64)
|
|
| 50 |
for i in range(max_new_tokens):
|
| 51 |
logits, *present_key_values = decoder_session.run(None, dict(
|
| 52 |
input_ids=input_ids,
|
| 53 |
-
attention_mask=attention_mask,
|
| 54 |
position_ids=position_ids,
|
| 55 |
**past_key_values,
|
| 56 |
))
|
| 57 |
|
| 58 |
## Update values for next generation loop
|
| 59 |
input_ids = logits[:, -1].argmax(-1, keepdims=True)
|
| 60 |
-
attention_mask = np.ones_like(input_ids)
|
| 61 |
position_ids = position_ids[:, -1:] + 1
|
| 62 |
for j, key in enumerate(past_key_values):
|
| 63 |
past_key_values[key] = present_key_values[j]
|
|
@@ -145,5 +146,4 @@ const messages = [
|
|
| 145 |
// Generate a response
|
| 146 |
const output = await generator(messages, { max_new_tokens: 512, do_sample: false });
|
| 147 |
console.log(output[0].generated_text.at(-1).content);
|
| 148 |
-
```
|
| 149 |
-
|
|
|
|
| 1 |
---
|
| 2 |
pipeline_tag: text-generation
|
| 3 |
+
base_model:
|
| 4 |
+
- google/gemma-3-1b-it
|
| 5 |
+
library_name: transformers.js
|
| 6 |
+
license: gemma
|
| 7 |
---
|
| 8 |
|
| 9 |
## Usage
|
|
|
|
| 45 |
for kv in ('key', 'value')
|
| 46 |
}
|
| 47 |
input_ids = inputs['input_ids']
|
| 48 |
+
position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1), (batch_size, 1))
|
|
|
|
| 49 |
|
| 50 |
# 3. Generation loop
|
| 51 |
max_new_tokens = 1024
|
|
|
|
| 53 |
for i in range(max_new_tokens):
|
| 54 |
logits, *present_key_values = decoder_session.run(None, dict(
|
| 55 |
input_ids=input_ids,
|
|
|
|
| 56 |
position_ids=position_ids,
|
| 57 |
**past_key_values,
|
| 58 |
))
|
| 59 |
|
| 60 |
## Update values for next generation loop
|
| 61 |
input_ids = logits[:, -1].argmax(-1, keepdims=True)
|
|
|
|
| 62 |
position_ids = position_ids[:, -1:] + 1
|
| 63 |
for j, key in enumerate(past_key_values):
|
| 64 |
past_key_values[key] = present_key_values[j]
|
|
|
|
| 146 |
// Generate a response
|
| 147 |
const output = await generator(messages, { max_new_tokens: 512, do_sample: false });
|
| 148 |
console.log(output[0].generated_text.at(-1).content);
|
| 149 |
+
```
|
|
|