Update README.md
Browse files
README.md
CHANGED
|
@@ -21,7 +21,7 @@ EAST was applied to a multimodal language model with RadGraph as the reward. Oth
|
|
| 21 |
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
|
| 22 |
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings.
|
| 23 |
|
| 24 |
-
##
|
| 25 |
|
| 26 |
```python
|
| 27 |
import torch
|
|
@@ -42,14 +42,23 @@ transforms = v2.Compose(
|
|
| 42 |
]
|
| 43 |
)
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
output_ids = model.generate(
|
| 48 |
-
pixel_values=images,
|
| 49 |
max_length=512,
|
| 50 |
-
bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
|
| 51 |
num_beams=4,
|
| 52 |
use_cache=True,
|
|
|
|
| 53 |
)
|
| 54 |
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
|
| 55 |
```
|
|
|
|
| 21 |
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
|
| 22 |
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings.
|
| 23 |
|
| 24 |
+
## Example:
|
| 25 |
|
| 26 |
```python
|
| 27 |
import torch
|
|
|
|
| 42 |
]
|
| 43 |
)
|
| 44 |
|
| 45 |
+
dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']
|
| 46 |
+
|
| 47 |
+
def transform_batch(batch):
|
| 48 |
+
batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
|
| 49 |
+
batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
|
| 50 |
+
return batch
|
| 51 |
+
|
| 52 |
+
dataset = dataset.with_transform(transform_batch)
|
| 53 |
+
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
|
| 54 |
+
batch = next(iter(dataloader))
|
| 55 |
|
| 56 |
output_ids = model.generate(
|
| 57 |
+
pixel_values=batch['images'],
|
| 58 |
max_length=512,
|
|
|
|
| 59 |
num_beams=4,
|
| 60 |
use_cache=True,
|
| 61 |
+
bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
|
| 62 |
)
|
| 63 |
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
|
| 64 |
```
|