---
datasets:
- Mir-2002/python-google-style-docstrings
language:
- en
metrics:
- bleu
- rouge
base_model:
- Salesforce/codet5p-220m-bimodal
pipeline_tag: summarization
tags:
- code
---
# Overview
This is a fine-tuned CodeT5+ (220m) bimodal model trained on a dataset of 59,000 Python code-docstring pairs. The docstrings are in Google style format.
A Google style docstring is formatted as follows:
```
<Description of the code>

Args:
    <var1> (<data-type>): <description of var1>
    <var2> (<data-type>): <description of var2>

Returns:
    <var3> (<data-type>): <description of var3>

Raises:
    <var4> (<data-type>): <description of var4>
```
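For illustration, here is a short, hypothetical example (not taken from the dataset) of a Python function documented in this style:

```python
def divide(numerator, denominator):
    """Divides one number by another.

    Args:
        numerator (float): The number to be divided.
        denominator (float): The number to divide by.

    Returns:
        float: The result of the division.

    Raises:
        ZeroDivisionError: If denominator is zero.
    """
    return numerator / denominator
```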
For more information on the dataset, please see the referenced dataset (Mir-2002/python-google-style-docstrings).
You can test the model with the following script:
```python
from transformers import T5ForConditionalGeneration, AutoTokenizer
checkpoint = "Mir-2002/codet5p-google-style-docstrings"
device = "cuda" # or CPU
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint).to(device)
input = """
def calculate_sum(a, b):
return a + b
"""
inputs = tokenizer.encode(input, return_tensors="pt").to(device)
outputs = model.generate(
inputs,
max_length=128,
num_beams=8,
early_stopping=True,
no_repeat_ngram_size=3,
pad_token_id=tokenizer.pad_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Calculate the sum of two numbers.
# Args:
# a (int): The first number.
# b (int): The second number.
```
# Fine tuning
In fine-tuning the model, I used the special token `<tdec>`. According to the CodeT5+ paper:
" Specifically, when the input is a text
sample, we prepend a [CDec] token to the input
sequence to the decoder. In this case, the decoder
operates under code generation functionality. Alternatively, when the input is a code sample, we
prepend a [TDec] token to the input sequence to
the decoder. The decoder operates under text generation functionality in this case. This type of Causal
LM has been shown to be an effective learning
objective to close the pretrain-finetune gap for generative downstream tasks"
In short, the `<tdec>` token was prepended to the target (the docstring) to signal to the decoder that it is operating in text generation mode. A sample preprocessed target looks like this:
```
<s><tdec> Creates a task that to retry a previously abandoned task.
Returns:
Task: a task that was abandoned but should be retried or None if there are
no abandoned tasks that should be retried.</s>
```
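As a rough illustration (not the exact training script), targets of this form could be produced with a preprocessing function like the one below, assuming each example has `code` and `docstring` fields:

```python
# Illustrative sketch of the preprocessing step; assumes "code" and "docstring"
# fields in each dataset example. The tokenizer adds <s> and </s> on its own.
def preprocess(example, tokenizer, max_source_length=256, max_target_length=128):
    model_inputs = tokenizer(
        example["code"],
        max_length=max_source_length,
        truncation=True,
    )
    # Prepend <tdec> so the decoder knows it should generate text (a docstring).
    labels = tokenizer(
        "<tdec> " + example["docstring"],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
```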
Prepending the token helps the decoder know which downstream task it is being fine-tuned on, improving the process. However, the paper does not clearly state whether the token
is already included in the tokenizer's vocabulary. To be safe, I manually added the token to the tokenizer's vocabulary using this script:
```python
import os

from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "Salesforce/codet5p-220m-bimodal"
model_path = "/path/to/your/model"
os.makedirs(model_path, exist_ok=True)

# Load base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Add special token(s)
tokenizer.add_special_tokens({"additional_special_tokens": ["<tdec>"]})
# Resize embeddings to match new vocab size
model.resize_token_embeddings(len(tokenizer))
# Save both to a custom directory (or simply keep them in memory for the run)
tokenizer.save_pretrained(model_path)
model.save_pretrained(model_path)
```
I then verified the token was added using this script:
```python
print("Token ID for <tdec>:", tokenizer.convert_tokens_to_ids("<tdec>"))
print("Tokenized form of '<tdec>':", tokenizer.tokenize("<tdec>"))
# Token ID for <tdec>: 32103
# Tokenized form of '<tdec>': ['<tdec>']
```
These scripts were run beforehand, and the modified model and tokenizer were used during fine-tuning.
# Hyperparameters
- MAX_SOURCE_LENGTH = 256
- MAX_TARGET_LENGTH = 128
- BATCH_SIZE = 16
- NUM_EPOCHS = 35
- LEARNING_RATE = 3e-5
- GRADIENT_ACCUMULATION_STEPS = 4
- EARLY_STOPPING_PATIENCE = 2
- WEIGHT_DECAY = 0.01
- OPTIMIZER = ADAFACTOR
- LR_SCHEDULER = LINEAR
The model was trained via Colab Pro on an L4 GPU. Gradient accumulation over 4 steps was used to simulate an effective batch size of 64 (16 * 4).
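As a hedged sketch (not the exact training configuration), these hyperparameters could map onto Hugging Face `Seq2SeqTrainingArguments` roughly as follows, with `model`, `tokenizer`, and the preprocessed `train_dataset`/`eval_dataset` assumed from the earlier steps:

```python
# Sketch only: one plausible way to express the hyperparameters above with the
# Hugging Face Trainer API. Not the exact script used for this model.
from transformers import (
    EarlyStoppingCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

training_args = Seq2SeqTrainingArguments(
    output_dir="codet5p-google-style-docstrings",  # hypothetical output path
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,    # effective batch size of 64
    num_train_epochs=35,
    learning_rate=3e-5,
    weight_decay=0.01,
    optim="adafactor",
    lr_scheduler_type="linear",
    evaluation_strategy="epoch",      # `eval_strategy` in newer transformers versions
    save_strategy="epoch",
    load_best_model_at_end=True,      # required for early stopping
    metric_for_best_model="eval_loss",
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,                      # model with <tdec> added, from the script above
    args=training_args,
    train_dataset=train_dataset,      # assumed preprocessed datasets
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
```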
# Loss
On the 35th epoch, the model achieved the following loss:
| Epoch | Training Loss | Validation Loss |
| ----- | ------------- | --------------- |
| 35    | 0.894800      | 1.268536        |
# BLEU and ROUGE Scores
| SacreBLEU | ROUGE-1 | ROUGE-2 | ROUGE-L |
| --------- | ------- | ------- | ------- |
| 35.40     | 58.55   | 39.46   | 52.43   |
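For reference, scores like these can be computed with the Hugging Face `evaluate` library; the snippet below is a hedged sketch with toy inputs, not necessarily the exact evaluation script used:

```python
# Sketch: computing SacreBLEU and ROUGE over generated docstrings with `evaluate`.
import evaluate

predictions = ["Validate timestamp."]                        # model outputs (toy example)
references = ["Validate timestamp specified by request."]    # gold docstrings

sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")

bleu_result = sacrebleu.compute(
    predictions=predictions,
    references=[[ref] for ref in references],  # sacrebleu expects a list of reference lists
)
rouge_result = rouge.compute(predictions=predictions, references=references)

print(f"SacreBLEU: {bleu_result['score']:.2f}")
print(f"ROUGE-1: {rouge_result['rouge1'] * 100:.2f}")
print(f"ROUGE-2: {rouge_result['rouge2'] * 100:.2f}")
print(f"ROUGE-L: {rouge_result['rougeL'] * 100:.2f}")
```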
While a SacreBLEU score of 35 is only moderate, it is important to consider that Google style docstrings vary considerably in structure. Some outliers contain extra sections
that are not usually present in the rest of the dataset, which leads the model to generate "hallucinations". One example is this particular sample:
```
Reference: Validate timestamp specified by request.
See `validate.request` for additional info.
Args:
stamp: str. Time request was made as ISO 8601 timestamp.
tolerance: int. Number of seconds request remains valid from timestamp.
Returns
bool: True if valid, False otherwise.
-----------------------------------------------------------------------
Prediction: Validate timestamp.
Args:
stamp (str): A date string in the format YYYY-MM-DDThh:mm:ss.######[+-]##:##
Returns:
bool: True if valid, False otherwise.
```
As you can see, the model generated gibberish in the prediction's Args section, specifically in the string format for the date.