fix bugs
Browse files- dnaflash.py +3 -2
dnaflash.py
CHANGED
@@ -351,8 +351,9 @@ class FLASHTransformer(nn.Module):
|
|
351 |
|
352 |
for flash in self.layers:
|
353 |
x = flash(x, mask = mask)
|
354 |
-
|
355 |
-
|
|
|
356 |
|
357 |
class FLASHTransformerConfig(PretrainedConfig):
|
358 |
model_type = "flash_transformer"
|
|
|
351 |
|
352 |
for flash in self.layers:
|
353 |
x = flash(x, mask = mask)
|
354 |
+
x_norm = self.to_logits[0](x)
|
355 |
+
logits = self.to_logits[1](x_norm)
|
356 |
+
return logits, x_norm
|
357 |
|
358 |
class FLASHTransformerConfig(PretrainedConfig):
|
359 |
model_type = "flash_transformer"
|