Matt
committed on
Commit
·
b5c7055
1
Parent(s):
f285b2c
Tie weights correctly
Browse files
- modeling_florence2.py +8 -2
modeling_florence2.py
CHANGED
|
@@ -2066,6 +2066,12 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2066 |
# Initialize weights and apply final processing
|
| 2067 |
self.post_init()
|
| 2068 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2069 |
def get_encoder(self):
|
| 2070 |
return self.model.get_encoder()
|
| 2071 |
|
|
@@ -2523,6 +2529,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2523 |
FLORENCE2_START_DOCSTRING,
|
| 2524 |
)
|
| 2525 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
|
|
|
|
|
| 2526 |
def __init__(self, config: Florence2Config):
|
| 2527 |
super().__init__(config)
|
| 2528 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
|
@@ -2537,8 +2545,6 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2537 |
|
| 2538 |
language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
|
| 2539 |
|
| 2540 |
-
if language_model._tied_weights_keys is not None:
|
| 2541 |
-
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
| 2542 |
self.language_model = language_model
|
| 2543 |
|
| 2544 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
|
|
|
| 2066 |
# Initialize weights and apply final processing
|
| 2067 |
self.post_init()
|
| 2068 |
|
| 2069 |
+
def _tie_weights(self):
|
| 2070 |
+
if self.config.tie_word_embeddings:
|
| 2071 |
+
self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
|
| 2072 |
+
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
|
| 2073 |
+
self._tie_or_clone_weights(self.lm_head, self.model.shared)
|
| 2074 |
+
|
| 2075 |
def get_encoder(self):
|
| 2076 |
return self.model.get_encoder()
|
| 2077 |
|
|
|
|
| 2529 |
FLORENCE2_START_DOCSTRING,
|
| 2530 |
)
|
| 2531 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
| 2532 |
+
_tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
|
| 2533 |
+
|
| 2534 |
def __init__(self, config: Florence2Config):
|
| 2535 |
super().__init__(config)
|
| 2536 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
|
|
|
| 2545 |
|
| 2546 |
language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
|
| 2547 |
|
|
|
|
|
|
|
| 2548 |
self.language_model = language_model
|
| 2549 |
|
| 2550 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|