import math from copy import deepcopy from functools import partial from typing import Callable, Optional, Sequence, Tuple, Union import array_typing as at import numpy as np import torch import torch.nn as nn from beartype import beartype as typechecker from jaxtyping import Float, jaxtyped from torch.distributions import Independent, Normal, TransformedDistribution from torch.distributions.transforms import ( AffineTransform, ComposeTransform, TanhTransform, ) from torch.optim import Adam, AdamW, Optimizer from torch.optim.lr_scheduler import ( LambdaLR, ) from transformers import ( AutoConfig, GemmaForCausalLM, PretrainedConfig, PreTrainedModel, ) from transformers.models.auto import CONFIG_MAPPING def extend_and_repeat(tensor: torch.Tensor, dim: int, repeat: int) -> torch.Tensor: return tensor.unsqueeze(dim).repeat_interleave(repeat, dim=dim) def init_module_weights(module: torch.nn.Module, orthogonal_init: bool = False): if isinstance(module, nn.Linear): if orthogonal_init: nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) nn.init.constant_(module.bias, 0.0) else: nn.init.xavier_uniform_(module.weight, gain=1e-2) class VQHBackboneConfig(PretrainedConfig): model_type = "VQHBackbone" sub_configs = {"gemma_expert_config": AutoConfig} def __init__( self, gemma_expert_config: dict | None = None, attention_implementation: str = "eager", **kwargs, ): self.attention_implementation = attention_implementation if gemma_expert_config is None: self.gemma_expert_config = CONFIG_MAPPING["gemma"]( attention_bias=False, attention_dropout=0.0, bos_token_id=2, eos_token_id=1, head_dim=256, hidden_act="gelu_pytorch_tanh", hidden_activation="gelu_pytorch_tanh", hidden_size=2048, initializer_range=0.02, intermediate_size=4096, max_position_embeddings=8192, model_type="gemma", num_attention_heads=8, num_hidden_layers=4, num_key_value_heads=1, pad_token_id=0, rms_norm_eps=1e-06, rope_theta=10000.0, torch_dtype="float32", transformers_version="4.48.1", use_cache=True, vocab_size=257152, ) elif isinstance(gemma_expert_config, dict): if "model_type" not in gemma_expert_config: gemma_expert_config["model_type"] = "gemma" cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]] self.gemma_expert_config = cfg_cls(**gemma_expert_config) super().__init__(**kwargs) def __post_init__(self): super().__post_init__() if self.attention_implementation not in ["eager", "fa2", "flex"]: raise ValueError( f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." ) def apply_rope(x, positions, max_wavelength=10_000): """ Applies RoPE positions [B, L] to x [B, L, H, D]. """ d_half = x.shape[-1] // 2 device = x.device dtype = x.dtype x = x.to(torch.float32) freq_exponents = (2.0 / x.shape[-1]) * torch.arange( d_half, dtype=torch.float32, device=device ) timescale = max_wavelength**freq_exponents radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to( torch.float32 ) radians = radians[..., None, :] sin = torch.sin(radians) # .to(dtype=dtype) cos = torch.cos(radians) # .to(dtype=dtype) x1, x2 = x.split(d_half, dim=-1) res = torch.empty_like(x) res[..., :d_half] = x1 * cos - x2 * sin res[..., d_half:] = x2 * cos + x1 * sin return res.to(dtype) class VQHBackbone(PreTrainedModel): config_class = VQHBackboneConfig def __init__(self, config: VQHBackboneConfig): super().__init__(config=config) self.config = config self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) self.to_bfloat16_like_physical_intelligence() def train(self, mode: bool = True): super().train(mode) def to_bfloat16_like_physical_intelligence(self): params_to_change_dtype = [ "language_model.model.layers", "gemma_expert.model.layers", ] for name, param in self.named_parameters(): if any(selector in name for selector in params_to_change_dtype): param.data = param.data.to(dtype=torch.bfloat16) def forward( self, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): # RMSNorm head_dim = self.gemma_expert.config.head_dim hidden_states = inputs_embeds batch_size = hidden_states.shape[0] for layer in self.gemma_expert.model.layers[ : self.gemma_expert.config.num_hidden_layers ]: # normalizer = torch.tensor(model.config.hidden_size**0.5, dtype=hidden_states.dtype) # hidden_states = hidden_states * normalizer hidden_states = layer.input_layernorm(hidden_states) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) # self attention hidden_states = hidden_states.to(dtype=torch.bfloat16) query_states = layer.self_attn.q_proj(hidden_states).view(hidden_shape) key_states = layer.self_attn.k_proj(hidden_states).view(hidden_shape) value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape) query_states = apply_rope(query_states, position_ids) key_states = apply_rope(key_states, position_ids) attention_interface = self.get_attention_interface() att_output = attention_interface( attention_mask, batch_size, head_dim, query_states, key_states, value_states, ) if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output) # first residual out_emb += hidden_states after_first_residual = out_emb.clone() out_emb = layer.post_attention_layernorm(out_emb) out_emb = layer.mlp(out_emb) # second residual out_emb += after_first_residual hidden_states = out_emb # final norm hidden_states = self.gemma_expert.model.norm(hidden_states) return hidden_states def get_attention_interface(self): if self.config.attention_implementation == "fa2": attention_interface = self.flash_attention_forward else: attention_interface = self.eager_attention_forward return attention_interface def eager_attention_forward( self, attention_mask, batch_size, head_dim, query_states, key_states, value_states, ): num_att_heads = self.config.gemma_expert_config.num_attention_heads num_key_value_heads = self.config.gemma_expert_config.num_key_value_heads num_key_value_groups = num_att_heads // num_key_value_heads # query_states: batch_size, sequence_length, num_att_head, head_dim # key_states: batch_size, sequence_length, num_key_value_head, head_dim # value_states: batch_size, sequence_length, num_key_value_head, head_dim sequence_length = key_states.shape[1] key_states = key_states[:, :, :, None, :].expand( batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim, ) key_states = key_states.reshape( batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim, ) value_states = value_states[:, :, :, None, :].expand( batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim, ) value_states = value_states.reshape( batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim, ) # Attention here is upcasted to float32 to match the original eager implementation. query_states = query_states.to(dtype=torch.float32) key_states = key_states.to(dtype=torch.float32) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) att_weights *= head_dim**-0.5 big_neg = -2.3819763e38 # See gemma/modules.py masked_att_weights = torch.where( attention_mask[:, None, :, :], att_weights, big_neg ) probs = nn.functional.softmax(masked_att_weights, dim=-1) probs = probs.to(dtype=value_states.dtype) # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length # value_states: batch_size, sequence_length, num_att_heads, head_dim att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) att_output = att_output.permute(0, 2, 1, 3) # we use -1 because sequence length can change att_output = att_output.reshape( batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim ) return att_output class LagrangeMultiplier(nn.Module): def __init__( self, init_value: float = 1.0, constraint_shape: Tuple[int, ...] = (), constraint_type: str = "eq", # One of ("eq", "leq", "geq") parameterization: Optional[ str ] = None, # One of ("softplus", "exp"), or None for equality constraints ): super().__init__() self.constraint_type = constraint_type self.parameterization = parameterization if constraint_type != "eq": assert ( init_value > 0 ), "Inequality constraints must have non-negative initial multiplier values" if parameterization == "softplus": init_value = torch.log(torch.exp(torch.tensor(init_value)) - 1).item() elif parameterization == "exp": init_value = torch.log(torch.tensor(init_value)).item() else: raise ValueError( f"Invalid multiplier parameterization {parameterization}" ) else: assert ( parameterization is None ), "Equality constraints must have no parameterization" self.multiplier = nn.Parameter(torch.full(constraint_shape, init_value)) def forward( self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor] = None ) -> torch.Tensor: multiplier = self.multiplier if self.constraint_type != "eq": if self.parameterization == "softplus": multiplier = torch.nn.functional.softplus(multiplier) elif self.parameterization == "exp": multiplier = torch.exp(multiplier) else: raise ValueError( f"Invalid multiplier parameterization {self.parameterization}" ) if lhs is None: return multiplier if rhs is None: rhs = torch.zeros_like(lhs) diff = lhs - rhs assert ( diff.shape == multiplier.shape ), f"Shape mismatch: {diff.shape} vs {multiplier.shape}" if self.constraint_type == "eq": return multiplier * diff elif self.constraint_type == "geq": return multiplier * diff elif self.constraint_type == "leq": return -multiplier * diff GeqLagrangeMultiplier = partial( LagrangeMultiplier, constraint_type="geq", parameterization="softplus" ) LeqLagrangeMultiplier = partial( LagrangeMultiplier, constraint_type="leq", parameterization="softplus" ) class MLP(nn.Module): def __init__( self, input_dim: int, hidden_dims: Sequence[int], activations: Union[Callable[[torch.Tensor], torch.Tensor], str] = "silu", activate_final: bool = False, use_layer_norm: bool = False, use_group_norm: bool = False, dropout_rate: Optional[float] = None, ): super().__init__() assert not (use_layer_norm and use_group_norm) self.activate_final = activate_final self.dropout_rate = dropout_rate self.input_dim = input_dim self.hidden_dims = hidden_dims if isinstance(activations, str): if activations == "silu" or activations == "swish": self.activations = nn.SiLU() else: self.activations = getattr(nn, activations)() else: self.activations = activations layers = [] for i, hidden_dim in enumerate(hidden_dims): layers.append(nn.Linear(input_dim, hidden_dim)) nn.init.xavier_uniform_(layers[-1].weight) nn.init.zeros_(layers[-1].bias) input_dim = hidden_dim if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) if use_layer_norm: layers.append(nn.LayerNorm(hidden_dim)) elif use_group_norm: num_groups = min(hidden_dim, 32) layers.append(nn.GroupNorm(num_groups, hidden_dim)) layers.append(self.activations) self.layers = nn.ModuleList(layers) def forward(self, x: torch.Tensor) -> torch.Tensor: for layer in self.layers: x = layer(x) return x class TanhMultivariateNormalDiag(TransformedDistribution): def __init__( self, loc: torch.Tensor, scale_diag: torch.Tensor, low: Optional[torch.Tensor] = None, high: Optional[torch.Tensor] = None, ): self.loc = loc self.scale_diag = scale_diag base_distribution = Independent(Normal(loc, scale_diag), 1) transforms = [] transforms.append(TanhTransform()) if not (low is None or high is None): transforms.append( AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2) ) transform = ComposeTransform(transforms) super().__init__(base_distribution, transform) def mode(self) -> torch.Tensor: """返回分布的众数""" # 对于正态分布,众数就是均值 mode = self.loc # 应用变换 for transform in self.transforms: mode = transform(mode) return mode def stddev(self) -> torch.Tensor: """返回变换后的标准差(近似值)""" # 注意:这只是一个近似,因为非线性变换后的标准差计算复杂 return self.transform(self.loc + self.scale_diag) - self.transform(self.loc) def log_prob(self, value: torch.Tensor) -> torch.Tensor: eps = 1e-6 value = torch.clamp(value, -1 + eps, 1 - eps) return super().log_prob(value) class Policy(nn.Module): def __init__( self, obs_encoded_dim: int, network: nn.Module, action_dim: int, std_parameterization: str = "exp", # "exp", "softplus", "fixed", or "uniform" std_min: Optional[float] = 1e-5, std_max: Optional[float] = 10.0, tanh_squash_distribution: bool = False, fixed_std: Optional[torch.Tensor] = None, ): super().__init__() self.obs_encoded_dim = obs_encoded_dim self.network = network self.action_dim = action_dim self.std_parameterization = std_parameterization self.std_min = std_min self.std_max = std_max self.tanh_squash_distribution = tanh_squash_distribution self.fixed_std = fixed_std self.mean_layer = nn.Linear(network.hidden_dims[-1], action_dim) if fixed_std is None: if std_parameterization in ["exp", "softplus"]: self.std_layer = nn.Linear(network.hidden_dims[-1], action_dim) elif std_parameterization == "uniform": self.log_stds = nn.Parameter(torch.zeros(action_dim)) else: raise ValueError( f"Invalid std_parameterization: {self.std_parameterization}" ) else: assert std_parameterization == "fixed" nn.init.xavier_uniform_(self.mean_layer.weight) nn.init.zeros_(self.mean_layer.bias) if fixed_std is None and std_parameterization in ["exp", "softplus"]: nn.init.xavier_uniform_(self.std_layer.weight) nn.init.zeros_(self.std_layer.bias) def forward( self, encoded_observations: torch.Tensor, temperature: float = 1.0 ) -> Union[TransformedDistribution, Normal]: outputs = self.network(encoded_observations) means = self.mean_layer(outputs) if self.fixed_std is None: if self.std_parameterization == "exp": log_stds = self.std_layer(outputs) stds = torch.exp(log_stds) elif self.std_parameterization == "softplus": stds = self.std_layer(outputs) stds = nn.functional.softplus(stds) elif self.std_parameterization == "uniform": stds = torch.exp(self.log_stds).expand_as(means) else: raise ValueError( f"Invalid std_parameterization: {self.std_parameterization}" ) else: stds = self.fixed_std.to(means.device).expand_as(means) stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt( torch.tensor(temperature) ) if self.tanh_squash_distribution: distribution = TanhMultivariateNormalDiag( loc=means, scale_diag=stds, ) else: distribution = Normal(loc=means, scale=stds) return distribution class Critics(nn.Module): def __init__( self, obs_encoded_dim: int, networks: list[nn.Module], num_backbones: int = 2, init_final: Optional[float] = None, ): super().__init__() assert len(networks) == num_backbones self.obs_encoded_dim = obs_encoded_dim self.networks = nn.ModuleList(networks) self.num_backbones = num_backbones self.init_final = init_final self.backbone_output_dims = networks[0].hidden_dims[-1] if init_final is not None: self.output_layer = nn.Linear(self.backbone_output_dims, 1) nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final) else: self.output_layer = nn.Linear(self.backbone_output_dims, 1) nn.init.xavier_uniform_(self.output_layer.weight) nn.init.zeros_(self.output_layer.bias) @jaxtyped(typechecker=typechecker) def forward( self, encoded_observations: Float[torch.Tensor, "batch {self.obs_encoded_dim}"], actions: Float[torch.Tensor, "batch *num_actions action_dim"], ) -> Float[torch.Tensor, "{self.num_backbones} batch *num_actions"]: if actions.ndim == 3: # forward the q function with multiple actions on each state encoded_observations = encoded_observations.unsqueeze(1).expand( -1, actions.shape[1], -1 ) # HACK: check dimensions here inputs = torch.cat([encoded_observations, actions], dim=-1) backbone_outputs = [] for network in self.networks: backbone_outputs.append(network(inputs)) backbone_outputs: Float[ torch.Tensor, "{self.num_backbones} batch *num_actions {self.backbone_output_dims}", ] = torch.stack(backbone_outputs, dim=0) value = self.output_layer(backbone_outputs) # HACK: check output shape here # if actions.ndim == 3: # value = value.squeeze(-1).permute(0, 2, 1) # else: value = value.squeeze(-1) return value # (num_backbones, batch, *num_actions) class CalQlConfig(PretrainedConfig): moedel_type = "calql" def __init__( self, obs_encoded_dim=2048, action_dim=70, actor_lr=1e-4, critic_lr=3e-4, temp_lr=3e-4, actor_wps=2000, critic_wps=2000, **kwargs, ): self.cql_clip_diff_min = -np.inf self.cql_clip_diff_max = np.inf self.cql_alpha = 5.0 self.cql_autotune_alpha = False self.action_dim = action_dim self.target_entropy = -self.action_dim self.obs_encoded_dim = obs_encoded_dim self.cql_temperature_init_value = 1.0 self.critic_ensemble_size = 2 self.cql_n_actions = 4 self.cql_max_target_backup = True self.policy_network_kwargs = dict( input_dim=self.obs_encoded_dim, hidden_dims=[256, 256], activate_final=True, use_layer_norm=False, ) self.critic_network_kwargs = dict( input_dim=self.obs_encoded_dim + self.action_dim, hidden_dims=[256, 256], activate_final=True, use_layer_norm=False, ) self.policy_kwargs = dict( tanh_squash_distribution=True, std_parameterization="exp", ) self.critic_subsample_size = None self.cql_max_target_backup = True self.backup_entropy = False self.discount = 0.98 self.goal_conditioned = True self.gc_kwargs = dict( negative_proportion=0.0, ) self.use_td_loss = True self.cql_action_sample_method = "uniform" self.cql_importance_sample = True self.cql_temp = 1.0 self.use_calql = True self.actor_optimizer_kwargs = dict( learning_rate=actor_lr, warmup_steps=actor_wps, ) self.critic_optimizer_kwargs = dict( learning_rate=critic_lr, warmup_steps=critic_wps, ) self.temperature_optimizer_kwargs = dict(learning_rate=temp_lr) super().__init__(**kwargs) class CalQL(PreTrainedModel): config_calss = CalQlConfig def __init__(self, config: CalQlConfig): super(CalQL, self).__init__(config=config) self.config = config self.temperature = GeqLagrangeMultiplier( init_value=self.config.cql_temperature_init_value, constraint_shape=(), ) self.policy = Policy( obs_encoded_dim=self.config.obs_encoded_dim, network=MLP(**self.config.policy_network_kwargs), action_dim=self.config.action_dim, **self.config.policy_kwargs, ) self.critics = Critics( obs_encoded_dim=self.config.obs_encoded_dim, networks=[ MLP(**self.config.critic_network_kwargs) for _ in range(self.config.critic_ensemble_size) ], num_backbones=self.config.critic_ensemble_size, ) self.target_critics = deepcopy(self.critics) def forward_policy_and_sample( self, encoded_obs: Float[torch.Tensor, "batch {self.config.obs_encoded_dim}"], repeat: int = None, ): action_dist = self.policy.forward(encoded_obs) if repeat: new_actions = action_dist.rsample( torch.tensor([repeat]) ) # repeat, tensor, act_dim log_pi = action_dist.log_prob(new_actions) new_actions = new_actions.permute(1, 0, 2) # (batch, repeat, action_dim) log_pi = log_pi.permute(1, 0) # (batch, repeat) else: new_actions = action_dist.rsample() # (batch, action_dim) log_pi = action_dist.log_prob(new_actions) # (batch) # NOTE: detach gradient here new_actions = new_actions.detach() log_pi = log_pi.detach() return new_actions, log_pi def _compute_next_actions(self, batch: at.CalQlBatch): """ compute the next actions but with repeat cql_n_actions times this should only be used when calculating critic loss using cql_max_target_backup """ sample_n_actions = ( self.config.cql_n_actions if self.config.cql_max_target_backup else None ) next_actions, next_actions_log_probs = self.forward_policy_and_sample( batch["encoded_next_observations"], repeat=sample_n_actions, ) return next_actions, next_actions_log_probs def temperature_loss_fn(self, batch: at.CalQlBatch): next_actions, next_actions_log_probs = self._compute_next_actions(batch) entropy = -next_actions_log_probs.mean() temperature_loss = self.temperature.forward( lhs=entropy, rhs=self.config.target_entropy, ) return temperature_loss, {"temperature_loss": temperature_loss} def policy_loss_fn(self, batch: at.CalQlBatch): batch_size = batch["rewards"].shape[0] temperature = self.temperature.forward().detach() # detach gradient action_distributions = self.policy.forward(batch["encoded_observations"]) actions = action_distributions.rsample() log_probs = action_distributions.log_prob(actions) predicted_qs = self.critics.forward( batch["encoded_observations"], actions, ).detach() # NOTE: detach grads predicted_q = predicted_qs.min(dim=0)[0] assert predicted_q.shape == (batch_size,) assert log_probs.shape == (batch_size,) nll_objective = -torch.mean( action_distributions.log_prob(torch.clip(batch["actions"], -0.99, 0.99)) ) actor_objective = predicted_q actor_loss = -torch.mean(actor_objective) + torch.mean(temperature * log_probs) info = { "actor_loss": actor_loss, "actor_nll": nll_objective, "temperature": temperature, "entropy": -log_probs.mean(), "log_probs": log_probs, "actions_mse": ((actions - batch["actions"]) ** 2).sum(dim=-1).mean(), "dataset_rewards": batch["rewards"], "mc_returns": batch.get("mc_returns", None), } return actor_loss, info def sac_critic_loss_fn(self, batch: at.CalQlBatch): """classes that inherit this class can change this function""" batch_size = batch["rewards"].shape[0] next_actions, next_actions_log_probs = self._compute_next_actions(batch) # (batch_size, ) for sac, (batch_size, cql_n_actions) for cql # Evaluate next Qs for all ensemble members (cheap because we're only doing the forward pass) with torch.no_grad(): self.target_critics.eval() target_next_qs = self.target_critics.forward( batch["encoded_next_observations"], next_actions, ) # (critic_ensemble_size, batch_size, cql_n_actions) self.target_critics.train() # Subsample if requested if self.config.critic_subsample_size is not None: subsample_idcs = torch.randint( 0, self.config.critic_ensemble_size, (self.config.critic_ensemble_size,), device=target_next_qs.device, ) target_next_qs = target_next_qs[subsample_idcs] # Minimum Q across (subsampled) ensemble members target_next_min_q = target_next_qs.min(dim=0)[0] assert target_next_min_q.shape == next_actions_log_probs.shape # (batch_size,) for sac, (batch_size, cql_n_actions) for cql target_next_min_q = self._process_target_next_qs( target_next_min_q, next_actions_log_probs, ) target_q = ( batch["rewards"] + self.config.discount * batch["masks"] * target_next_min_q ) assert target_q.shape == (batch_size,) predicted_qs = self.critics.forward( batch["encoded_observations"], batch["actions"] ) assert predicted_qs.shape == (self.config.critic_ensemble_size, batch_size) target_qs = target_q.unsqueeze(0).expand(self.config.critic_ensemble_size, -1) assert predicted_qs.shape == target_qs.shape critic_loss = torch.mean((predicted_qs - target_qs) ** 2) info = { "td_err": critic_loss, "online_q": torch.mean(predicted_qs), "target_q": torch.mean(target_qs), } if self.config.goal_conditioned: num_negatives = int( self.config.gc_kwargs["negative_proportion"] * batch_size ) info["negative_qs"] = torch.mean(predicted_qs, dim=-1)[ :num_negatives ].mean() info["positive_qs"] = torch.mean(predicted_qs, dim=-1)[ num_negatives: ].mean() return critic_loss, info def _process_target_next_qs(self, target_next_qs, next_actions_log_probs): """add cql_max_target_backup option""" if self.config.cql_max_target_backup: max_target_indices = torch.argmax(target_next_qs, dim=-1, keepdim=True) target_next_qs = torch.gather( target_next_qs, -1, max_target_indices ).squeeze(-1) next_actions_log_probs = torch.gather( next_actions_log_probs, -1, max_target_indices ).squeeze(-1) target_next_qs = self.sac_process_target_next_qs( target_next_qs, next_actions_log_probs, ) return target_next_qs def sac_process_target_next_qs(self, target_next_qs, next_actions_log_probs): """classes that inherit this class can add to this function e.g. CQL will add the cql_max_target_backup option """ if self.config.backup_entropy: temperature = self.forward_temperature() target_next_qs = target_next_qs - temperature * next_actions_log_probs return target_next_qs def critic_loss_fn(self, batch: at.CalQlBatch): """add CQL loss on top of SAC loss""" if self.config.use_td_loss: td_loss, td_loss_info = self.sac_critic_loss_fn(batch) else: td_loss, td_loss_info = 0.0, {} cql_q_diff, cql_intermediate_results = self._get_cql_q_diff(batch) """auto tune cql alpha""" if self.config.cql_autotune_alpha: raise NotImplementedError # alpha = self.forward_cql_alpha_lagrange() # cql_loss = (cql_q_diff - self.config["cql_target_action_gap"]).mean() else: alpha = self.config.cql_alpha cql_loss = torch.clip( cql_q_diff, self.config.cql_clip_diff_min, self.config.cql_clip_diff_max ).mean() critic_loss = td_loss + alpha * cql_loss info = { **td_loss_info, "critic_loss": critic_loss, "td_err": td_loss, "cql_loss": cql_loss, "cql_alpha": alpha, "cql_diff": cql_q_diff.mean(), **cql_intermediate_results, } return critic_loss, info def _get_cql_q_diff(self, batch: at.CalQlBatch): """ most of the CQL loss logic is here It is needed for both critic_loss_fn and cql_alpha_loss_fn """ batch_size = batch["rewards"].shape[0] q_pred = self.critics.forward(batch["encoded_observations"], batch["actions"]) # HACK: shape changed from jax implementation assert q_pred.shape == (self.config.critic_ensemble_size, batch_size) """sample random actions""" action_dim = batch["actions"].shape[-1] if self.config.cql_action_sample_method == "uniform": cql_random_actions = ( torch.rand( (batch_size, self.config.cql_n_actions, action_dim), device=batch["actions"].device, ) * 2.0 - 1.0 ) elif self.config.cql_action_sample_method == "normal": cql_random_actions = torch.randn( (batch_size, self.config.cql_n_actions, action_dim), device=batch["actions"].device, ) else: raise NotImplementedError cql_current_actions, cql_current_log_pis = self.forward_policy_and_sample( batch["encoded_observations"], repeat=self.config.cql_n_actions, ) assert cql_current_log_pis.shape == (batch_size, self.config.cql_n_actions) cql_next_actions, cql_next_log_pis = self.forward_policy_and_sample( batch["encoded_next_observations"], repeat=self.config.cql_n_actions, ) all_sampled_actions = torch.cat( [ cql_random_actions, cql_current_actions, cql_next_actions, ], dim=1, ) """q values of randomly sampled actions""" cql_q_samples = self.critics.forward( batch["encoded_observations"], all_sampled_actions ) # HACK: shape changed from jax implementation assert cql_q_samples.shape == ( self.config.critic_ensemble_size, batch_size, self.config.cql_n_actions * 3, ) if self.config.critic_subsample_size is not None: subsample_idcs = torch.randint( 0, self.config.critic_ensemble_size, (self.config.critic_ensemble_size,), device=cql_q_samples.device, ) cql_q_samples = cql_q_samples[subsample_idcs] """Cal-QL""" if self.config.use_calql: # HACK: check shape of mc_returns mc_lower_bound = ( batch["mc_returns"] .reshape(-1, 1) .repeat(1, self.config.cql_n_actions * 2) ) assert mc_lower_bound.shape == ( batch_size, self.config.cql_n_actions * 2, ) cql_q_pi = cql_q_samples[:, :, self.config.cql_n_actions :] num_vals = cql_q_pi.numel() calql_bound_rate = torch.sum((cql_q_pi < mc_lower_bound).float()) / num_vals cql_q_pi = torch.maximum(cql_q_pi, mc_lower_bound) cql_q_samples = torch.cat( [ cql_q_samples[:, :, : self.config.cql_n_actions], cql_q_pi, ], dim=-1, ) if self.config.cql_importance_sample: random_density = torch.log( torch.tensor(0.5**action_dim, device=cql_q_samples.device) ) importance_prob = torch.cat( [ random_density.expand(batch_size, self.config.cql_n_actions), cql_current_log_pis, cql_next_log_pis, ], dim=1, ) # HACK: check dim cql_q_samples = cql_q_samples - importance_prob.unsqueeze(0) else: cql_q_samples = torch.cat([cql_q_samples, q_pred.unsqueeze(-1)], dim=-1) cql_q_samples -= ( torch.log( torch.tensor( cql_q_samples.shape[-1], dtype=torch.float, device=cql_q_samples.device, ) ) * self.config.cql_temp ) # HACK: shape diff from jax implementation assert cql_q_samples.shape == ( self.config.critic_ensemble_size, batch_size, 3 * self.config.cql_n_actions + 1, ) """log sum exp of the ood actions""" cql_ood_values = ( torch.logsumexp(cql_q_samples / self.config.cql_temp, dim=-1) * self.config.cql_temp ) assert cql_ood_values.shape == (self.config.critic_ensemble_size, batch_size) cql_q_diff = cql_ood_values - q_pred info = { "cql_ood_values": cql_ood_values.mean(), } if self.config.use_calql: info["calql_bound_rate"] = calql_bound_rate return cql_q_diff, info @staticmethod def make_optimizer( params: torch.nn.Module, learning_rate: float = 3e-4, warmup_steps: int = 0, cosine_decay_steps: Optional[int] = None, weight_decay: Optional[float] = None, return_lr_schedule: bool = True, ) -> Union[Optimizer, Tuple[Optimizer, LambdaLR]]: optimizer: Optimizer if weight_decay is not None: optimizer = AdamW( params=params, lr=learning_rate, weight_decay=weight_decay, ) else: optimizer = Adam(params=params, lr=learning_rate) def _lr_lambda(step: int) -> float: if warmup_steps > 0 and step < warmup_steps: return step / warmup_steps if cosine_decay_steps is not None: decay_step = step - warmup_steps if decay_step < 0: return 0.0 if decay_step >= cosine_decay_steps: return 0.0 progress = decay_step / cosine_decay_steps return 0.5 * (1.0 + math.cos(math.pi * progress)) return 1.0 scheduler = LambdaLR(optimizer, lr_lambda=_lr_lambda) if return_lr_schedule: return optimizer, scheduler else: return optimizer def prepare_optimizers(self): actor_optimizer, actor_scheduler = self.make_optimizer( self.policy.parameters(), **self.config.actor_optimizer_kwargs ) critic_optimizer, critic_scheduler = self.make_optimizer( self.critics.parameters(), **self.config.critic_optimizer_kwargs ) temperature_optimizer, temperature_scheduler = self.make_optimizer( self.temperature.parameters(), **self.config.temperature_optimizer_kwargs ) return ( actor_optimizer, actor_scheduler, critic_optimizer, critic_scheduler, temperature_optimizer, temperature_scheduler, ) def forward(self, batch: at.CalQlBatch): temperature_loss, temperature_loss_info = self.temperature_loss_fn(batch) policy_loss, policy_loss_info = self.policy_loss_fn(batch) critic_loss, critic_loss_info = self.critic_loss_fn(batch) return ( temperature_loss, policy_loss, critic_loss, { **temperature_loss_info, **policy_loss_info, **critic_loss_info, }, ) @jaxtyped(typechecker=typechecker) def get_q_values( self, encoded_observations: Float[ torch.Tensor, "batch {self.config.obs_encoded_dim}" ], noise_actions: Float[torch.Tensor, "batch num_actions action_dim"], ) -> Float[torch.Tensor, "batch num_actions"]: # (num_backbones, batch, *num_actions) q_values = self.target_critics.forward(encoded_observations, noise_actions) q_values = q_values.min(dim=0)[0] return q_values