wangleiofficial commited on
Commit
f93d8f2
·
verified ·
1 Parent(s): ef5bad5
Files changed (1) hide show
  1. dnaflash.py +3 -2
dnaflash.py CHANGED
@@ -351,8 +351,9 @@ class FLASHTransformer(nn.Module):
351
 
352
  for flash in self.layers:
353
  x = flash(x, mask = mask)
354
-
355
- return self.to_logits(x), x
 
356
 
357
  class FLASHTransformerConfig(PretrainedConfig):
358
  model_type = "flash_transformer"
 
351
 
352
  for flash in self.layers:
353
  x = flash(x, mask = mask)
354
+ x_norm = self.to_logits[0](x)
355
+ logits = self.to_logits[1](x_norm)
356
+ return logits, x_norm
357
 
358
  class FLASHTransformerConfig(PretrainedConfig):
359
  model_type = "flash_transformer"