moojink committed on
Commit
31f090d
1 Parent(s): e5822cc

Update modeling_prismatic.py to match public GitHub repo

Browse files

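Check whether input_ids already ends with the special empty token ('', ID 29871) before appending it, rather than appending it unconditionally.
Also add type annotations to the touched methods and correct the copy-pasted docstring on get_action_stats.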

Files changed (1) hide show
  1. modeling_prismatic.py +12 -11
modeling_prismatic.py CHANGED
@@ -504,14 +504,15 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
504
  self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
505
 
506
  def predict_action(
507
- self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs
508
  ) -> np.ndarray:
509
  """Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
510
- # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
511
- # in order for the predictions to match the training configuration and be accurate.
512
- input_ids = torch.cat(
513
- (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
514
- )
 
515
 
516
  # Run VLA inference
517
  generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
@@ -535,7 +536,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
535
  return actions
536
 
537
  @staticmethod
538
- def _check_unnorm_key(norm_stats, unnorm_key):
539
  if unnorm_key is None:
540
  assert len(norm_stats) == 1, (
541
  f"Your model was trained on more than one dataset, "
@@ -550,12 +551,12 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
550
  )
551
  return unnorm_key
552
 
553
- def get_action_dim(self, unnorm_key=None):
554
- """Dimensionality of the policy's action space."""
555
  unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
556
  return len(self.norm_stats[unnorm_key]["action"]["q01"])
557
 
558
- def get_action_stats(self, unnorm_key=None):
559
- """Dimensionality of the policy's action space."""
560
  unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
561
  return self.norm_stats[unnorm_key]["action"]
 
504
  self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
505
 
506
  def predict_action(
507
+ self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
508
  ) -> np.ndarray:
509
  """Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
510
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
511
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
512
+ if not torch.all(input_ids[:, -1] == 29871):
513
+ input_ids = torch.cat(
514
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
515
+ )
516
 
517
  # Run VLA inference
518
  generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
 
536
  return actions
537
 
538
  @staticmethod
539
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
540
  if unnorm_key is None:
541
  assert len(norm_stats) == 1, (
542
  f"Your model was trained on more than one dataset, "
 
551
  )
552
  return unnorm_key
553
 
554
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
555
+ """Get the dimensionality of the policy's action space."""
556
  unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
557
  return len(self.norm_stats[unnorm_key]["action"]["q01"])
558
 
559
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
560
+ """Get all the logged statistics for the given dataset."""
561
  unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
562
  return self.norm_stats[unnorm_key]["action"]