
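# Evaluate a trained en->cs Transformer on the Europarl test set: translate a
# sample of sentences with greedy (or beam-search) decoding and report BLEU,
# chrF2, and wall-clock decoding time.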
import re
import time

import keras
import keras_nlp
import tensorflow as tf
from sacrebleu.metrics import CHRF

# Hyperparameters
MAX_SEQUENCE_LENGTH = 64
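# Assumed to match the sequence length used when the loaded checkpoint was trained.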

transformer = keras.models.load_model('models_europarl/en_cs_translator_saved_20231209_0046.keras')
def read_files(path, lowercase=False):
    """Read a text file and return its lines as a list, dropping the trailing empty line."""
    with open(path, "r", encoding="utf-8") as f:
        dataset_split = f.read().split("\n")[:-1]
    # Optionally lowercase each line (for uncased vocabularies).
    if lowercase:
        dataset_split = [line.lower() for line in dataset_split]
    return dataset_split

en_vocab = read_files("tokenizers/en_europarl_vocab")
cs_vocab = read_files("tokenizers/cs_europarl_vocab")
en_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=en_vocab, 
    lowercase=False
)
cs_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=cs_vocab, 
    lowercase=False
)
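# Illustrative sanity check (outputs are hypothetical; [START], [END] and
# [PAD] are assumed to exist in both WordPiece vocabularies):
#   en_tokenizer("parliament")           # -> tensor of int token ids
#   cs_tokenizer.token_to_id("[START]")  # -> int id used to seed decoding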

def compute_probabilities(logits):
    return keras.activations.softmax(logits)

def next_token_logits(encoder_input_tokens, prompt, predicted_token_index):
    # The model emits the prediction for position `predicted_token_index`
    # at the index of the last already-generated token, hence the `- 1`.
    logits = transformer(
        [tf.expand_dims(encoder_input_tokens, axis=0), tf.expand_dims(prompt, axis=0)]
    )[:, predicted_token_index - 1, :]
    return logits


def greedy_decode(encoder_input_tokens, prompt, end_token_id):
    """Greedily pick the most probable token at each position until [END]."""
    start_index = 1
    current_prompt = prompt
    for predicted_token_index in range(start_index, MAX_SEQUENCE_LENGTH):
        next_logits = next_token_logits(encoder_input_tokens, current_prompt, predicted_token_index)
        next_probabilities = compute_probabilities(next_logits)
        # The index into the probability vector equals the token id.
        max_probability_token_id = tf.argmax(next_probabilities, axis=-1)
        indices = tf.constant([[predicted_token_index]])
        # argmax returns int64; cast to the prompt dtype for the scatter update.
        data = tf.constant([max_probability_token_id.numpy()[0]], dtype=current_prompt.dtype)
        current_prompt = tf.tensor_scatter_nd_update(current_prompt, indices, data)
        # Stop once the end token has been generated.
        if max_probability_token_id == end_token_id:
            break
    return current_prompt
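# Note: keras_nlp also ships ready-made samplers (e.g. keras_nlp.samplers.GreedySampler);
# the explicit loop above is kept so each decoding step stays visible.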
    


def beam_decode(encoder_input_tokens, prompt, end_token_id, beam_size):
    start_index = 1
    # Initial beam: expand the start token into the beam_size most probable first tokens.
    next_logits = next_token_logits(encoder_input_tokens, prompt, start_index)
    next_probabilities = compute_probabilities(next_logits)
    top_k_probabilities, top_k_token_indices = tf.math.top_k(next_probabilities, k=beam_size)
    current_subsequences = []
    for index, value in enumerate(top_k_token_indices.numpy()[0]):
        # Create beam_size versions of the prompt, each with one of the
        # top-k tokens at index 1.
        indices = tf.constant([[start_index]])
        data = tf.constant([value])
        current_prompt = tf.tensor_scatter_nd_update(prompt, indices, data)
        # Store each hypothesis with its log probability and its length-normalized
        # log probability (length is 1 here, so the two are equal).
        log_prob = tf.math.log(top_k_probabilities.numpy()[0][index])
        current_subsequences.append((current_prompt, log_prob, log_prob))

    final_potential_solutions = []
    for predicted_token_index in range(start_index + 1, MAX_SEQUENCE_LENGTH):
        # Stop once beam_size hypotheses have generated the end token.
        if len(final_potential_solutions) == beam_size:
            break

        tmp_subsequences = []
        for subseq_prompt, subseq_log_probability, _ in current_subsequences:
            next_logits = next_token_logits(encoder_input_tokens, subseq_prompt, predicted_token_index)
            next_probabilities = compute_probabilities(next_logits)
            top_k_probabilities, top_k_token_indices = tf.math.top_k(
                next_probabilities, k=beam_size - len(final_potential_solutions)
            )
            for index, value in enumerate(top_k_token_indices.numpy()[0]):
                # Extend this hypothesis with each of its top-k continuations.
                indices = tf.constant([[predicted_token_index]])
                data = tf.constant([value])
                updated_subseq_prompt = tf.tensor_scatter_nd_update(subseq_prompt, indices, data)
                # Accumulate the log probability and keep a length-normalized copy for ranking.
                next_log_probability = tf.math.log(top_k_probabilities.numpy()[0][index])
                tmp_subsequences.append((
                    updated_subseq_prompt,
                    subseq_log_probability + next_log_probability,
                    (subseq_log_probability + next_log_probability) / (predicted_token_index + 1),
                ))

        current_subsequences = []
        current_sequences_to_find = beam_size - len(final_potential_solutions)
        tmp_subsequences = sorted(tmp_subsequences, key=lambda x: x[2], reverse=True)
        for i in range(current_sequences_to_find):
            if tmp_subsequences[i][0][predicted_token_index] == end_token_id:
                final_potential_solutions.append(tmp_subsequences[i])
            else:
                current_subsequences.append(tmp_subsequences[i])

    # Pick the finished hypothesis with the best length-normalized log probability.
    final_potential_solutions = sorted(final_potential_solutions, key=lambda x: x[2], reverse=True)

    if len(final_potential_solutions) > 0:
        return final_potential_solutions[0][0]
    # No hypothesis reached the end token; fall back to the best unfinished one.
    sorted_subs = sorted(current_subsequences, key=lambda x: x[2], reverse=True)
    return sorted_subs[0][0]
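# Length-normalization example: a 4-token hypothesis with summed log-prob -4.0
# scores -1.0 per token and beats a 2-token hypothesis with log-prob -2.4
# (-1.2 per token), which a plain sum of log-probabilities would rank higher.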


def decode_sequences(input_sentence):
    # Tokenize the encoder input and pad it to MAX_SEQUENCE_LENGTH
    # (0 is assumed to be the [PAD] id in the English vocabulary).
    encoder_input_tokens = en_tokenizer(input_sentence)
    if len(encoder_input_tokens) < MAX_SEQUENCE_LENGTH:
        pads = tf.fill((MAX_SEQUENCE_LENGTH - len(encoder_input_tokens),), 0)
        encoder_input_tokens = tf.concat([encoder_input_tokens, pads], 0)
    if len(encoder_input_tokens) > MAX_SEQUENCE_LENGTH:
        # Input is longer than the model can handle; return a placeholder "translation".
        tensor_content = "[START] Exceeded. [END] [PAD] [PAD] [PAD] [PAD]"
        return tf.constant([tensor_content], dtype=tf.string)

    # Decoder prompt: a [START] token followed by padding.
    start = tf.fill((1,), cs_tokenizer.token_to_id("[START]"))
    pads = tf.fill((MAX_SEQUENCE_LENGTH - 1,), cs_tokenizer.token_to_id("[PAD]"))
    prompt = tf.concat((start, pads), axis=-1)

    end_token_id = cs_tokenizer.token_to_id("[END]")

    generated_tokens = greedy_decode(encoder_input_tokens, prompt, end_token_id)
    # generated_tokens = beam_decode(encoder_input_tokens, prompt, end_token_id, 5)

    generated_sentences = cs_tokenizer.detokenize(tf.expand_dims(generated_tokens, axis=0))
    return generated_sentences
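# Illustrative usage (sentence is hypothetical):
#   decode_sequences("I declare the session resumed.")
# returns a string tensor with one detokenized Czech sentence, still carrying
# the [START]/[END]/[PAD] markers that the evaluation loop strips below.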


test_en = read_files('datasets/europarl/test-cs-en.en')
test_cs = read_files('datasets/europarl/test-cs-en.cs')
bleu_metrics = keras_nlp.metrics.Bleu(
    name="bleu",
    tokenizer=cs_tokenizer,
)

eval_samples = 100
chrf = CHRF() 
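# sacrebleu's CHRF defaults to beta=2, i.e. the chrF2 score reported below.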
refs = test_cs[:eval_samples]
translations = []
start_time = time.time()

for i in range(len(refs)):
    cs_translated = decode_sequences(test_en[i])
    cs_translated = cs_translated.numpy()[0].decode("utf-8")
    cs_translated = (
        cs_translated.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )
    # Remove spaces before punctuation introduced by detokenization.
    cs_translated = re.sub(r'\s+([.,;!?:])', r'\1', cs_translated)
    print(cs_translated, flush=True)
    translations.append(cs_translated)

end_time = time.time()

# keras_nlp's Bleu expects one list of references per hypothesis.
refs_twodim = [[ref] for ref in refs]
bleu_metrics(refs_twodim, translations)

print("evaluating chrf", flush=True)
# sacrebleu expects references as a list of reference streams, i.e. [[ref_1, ..., ref_N]].
chrf2_result = chrf.corpus_score(translations, [refs])

print("chrf2")
print(chrf2_result)
print("bleu")
print(bleu_metrics.result().numpy())
print("elapsed time")
elapsed_time = end_time - start_time
print(elapsed_time)