Update modeling_prismatic.py to match public GitHub repo
Browse files
Check input_ids before adding the special empty token ('') rather
than adding it unconditionally.
- modeling_prismatic.py +12 -11
modeling_prismatic.py
CHANGED
@@ -504,14 +504,15 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
|
504 |
self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
|
505 |
|
506 |
def predict_action(
|
507 |
-
self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs
|
508 |
) -> np.ndarray:
|
509 |
"""Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
|
510 |
-
#
|
511 |
-
#
|
512 |
-
|
513 |
-
|
514 |
-
|
|
|
515 |
|
516 |
# Run VLA inference
|
517 |
generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
|
@@ -535,7 +536,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
|
535 |
return actions
|
536 |
|
537 |
@staticmethod
|
538 |
-
def _check_unnorm_key(norm_stats, unnorm_key):
|
539 |
if unnorm_key is None:
|
540 |
assert len(norm_stats) == 1, (
|
541 |
f"Your model was trained on more than one dataset, "
|
@@ -550,12 +551,12 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
|
550 |
)
|
551 |
return unnorm_key
|
552 |
|
553 |
-
def get_action_dim(self, unnorm_key=None):
|
554 |
-
"""
|
555 |
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
556 |
return len(self.norm_stats[unnorm_key]["action"]["q01"])
|
557 |
|
558 |
-
def get_action_stats(self, unnorm_key=None):
|
559 |
-
"""
|
560 |
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
561 |
return self.norm_stats[unnorm_key]["action"]
|
|
|
504 |
self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
|
505 |
|
506 |
def predict_action(
|
507 |
+
self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
|
508 |
) -> np.ndarray:
|
509 |
"""Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
|
510 |
+
# If the special empty token ('') does not already appear after the colon (':') token in the prompt
|
511 |
+
# (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
|
512 |
+
if not torch.all(input_ids[:, -1] == 29871):
|
513 |
+
input_ids = torch.cat(
|
514 |
+
(input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
|
515 |
+
)
|
516 |
|
517 |
# Run VLA inference
|
518 |
generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
|
|
|
536 |
return actions
|
537 |
|
538 |
@staticmethod
|
539 |
+
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
|
540 |
if unnorm_key is None:
|
541 |
assert len(norm_stats) == 1, (
|
542 |
f"Your model was trained on more than one dataset, "
|
|
|
551 |
)
|
552 |
return unnorm_key
|
553 |
|
554 |
+
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
|
555 |
+
"""Get the dimensionality of the policy's action space."""
|
556 |
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
557 |
return len(self.norm_stats[unnorm_key]["action"]["q01"])
|
558 |
|
559 |
+
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
|
560 |
+
"""Get all the logged statistics for the given dataset."""
|
561 |
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
|
562 |
return self.norm_stats[unnorm_key]["action"]
|