|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable, Optional, Union |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, HybridCache, StaticCache |
|
from transformers.generation import GenerationMixin |
|
from transformers.integrations import use_kernel_forward_from_hub |
|
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
QuestionAnsweringModelOutput, |
|
SequenceClassifierOutputWithPast, |
|
TokenClassifierOutput, |
|
) |
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
from transformers.processing_utils import Unpack |
|
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
|
from transformers.utils.generic import check_model_inputs |
|
from .configuration_exaone4 import Exaone4Config |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class Exaone4RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain.

    The statistic is computed in float32; the normalized tensor is cast back to the
    input dtype before being multiplied by `weight`.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Exaone4RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Number of entries in the gain vector (initialized to ones).
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Return `hidden_states` scaled to unit RMS along the last dim, times `self.weight`."""
        original_dtype = hidden_states.dtype
        # Upcast first so the mean of squares is accumulated in float32.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Expose the gain shape and epsilon in the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
|
class Exaone4RotaryEmbedding(nn.Module):
    """Builds the cos/sin rotary tables for a batch of position ids.

    The inverse frequencies come from the entry of `ROPE_INIT_FUNCTIONS` selected by the
    config's `rope_scaling` dict ("default" when no such dict is present).
    """

    def __init__(self, config: Exaone4Config, device=None):
        super().__init__()

        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            # "rope_type" takes precedence; "type" is consulted only as a fallback key.
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is recomputed from the config and not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is used only for its device and dtype.
        """
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        device_kind = x.device.type
        if not isinstance(device_kind, str) or device_kind == "mps":
            device_kind = "cpu"

        # Autocast is disabled so the angle computation stays in float32.
        with torch.autocast(device_type=device_kind, enabled=False):
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the negated second
    part followed by the first part.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast against `q` and `k`. For
            example, if `cos`/`sin` have the shape [batch_size, seq_len, head_dim] and `q`/`k` have the
            shape [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `q`/`k` have the shape
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim). When `n_rep == 1`
    the input tensor itself is returned.
    """
    # Unpacking first keeps the 4-D requirement in force even on the `n_rep == 1` fast path.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Scaled dot-product attention written with plain matmul/softmax.

    Args:
        module: Attention module; `num_key_value_groups` and `training` are read from it.
        query: Query states.
        key: Key states; repeated `module.num_key_value_groups` times along dim 1.
        value: Value states; repeated the same way as `key`.
        attention_mask: Optional additive mask; its last dim is cropped to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability applied to the attention weights in training mode.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped and is contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
|
def check_is_sliding(config, layer_idx):
    """
    Check if the current layer is a sliding window attention (local attention) layer.

    Resolution order: no `sliding_window` means never sliding; otherwise `layer_types` wins when set;
    otherwise `sliding_window_pattern` is interpreted as an int period or as a string of layer codes
    in which "L" marks a sliding layer (the final layer is never sliding in the string form).
    """
    if config.sliding_window is None:
        return False
    if config.layer_types is not None:
        return config.layer_types[layer_idx] == "sliding_attention"

    pattern = config.sliding_window_pattern
    if isinstance(pattern, int):
        # Every `pattern`-th layer (1-based) is a full-attention layer.
        return ((layer_idx + 1) % pattern) != 0

    if isinstance(pattern, str):
        assert isinstance(config.sliding_window, int), (
            f"Sliding window must be positive integer, but got {config.sliding_window}"
        )
        is_last_layer = layer_idx == config.num_hidden_layers - 1
        return (not is_last_layer) and pattern[layer_idx % len(pattern)] == "L"

    logger.warning_once(
        "Sliding window is set, but none of `sliding_window_pattern` or `layer_types` is set. "
        "Defaulting to use 'full_attention' for all layers."
    )
    return False
|
|
|
|
|
class Exaone4Attention(nn.Module):
    """Self-attention with grouped key/value heads and per-head RMS normalization of queries and keys.

    Rotary position embeddings are applied when the config has no sliding window at all, or when this
    layer is a sliding-window layer; other layers skip the rotary step.
    """

    def __init__(self, config: Exaone4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Head geometry and scalar hyper-parameters.
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Sliding-window configuration for this layer.
        self.sliding_window = config.sliding_window
        self.sliding_window_pattern = config.sliding_window_pattern
        self.is_sliding = check_is_sliding(config, layer_idx)

        query_width = self.num_attention_heads * self.head_dim
        key_value_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, key_value_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, key_value_width, bias=False)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)

        # The norms act on the last (head_dim) axis of the per-head queries/keys.
        self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input activations; the last dim is projected to queries, keys and values.
            position_embeddings: `(cos, sin)` pair consumed by `apply_rotary_pos_emb`.
            attention_mask: Optional mask forwarded to the attention backend.
            past_key_value: Optional cache; when given it is updated with this layer's keys/values.
            cache_position: Positions forwarded to the cache update.
            **kwargs: Forwarded unchanged to the attention backend.

        Returns:
            The output of `o_proj` and the attention weights reported by the backend.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)
        attn_impl = self.config._attn_implementation

        def _to_heads(projected):
            # Split the last dim into (heads, head_dim), then swap dims 1 and 2.
            return projected.view(per_head_shape).transpose(1, 2)

        query_states = _to_heads(self.q_proj(hidden_states))
        key_states = _to_heads(self.k_proj(hidden_states))
        value_states = _to_heads(self.v_proj(hidden_states))

        # Normalize queries and keys per head before any positional rotation.
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        applies_rope = self.sliding_window is None or self.is_sliding
        if applies_rope:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "sliding_window": self.sliding_window,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            # For flash-attention-2 the cached keys/values are cropped to the live sequence length.
            if attn_impl == "flash_attention_2":
                if attention_mask is None:
                    seq_len = cache_position[-1] + 1
                else:
                    seq_len = attention_mask.shape[1]
                key_states = key_states[:, :, :seq_len, :]
                value_states = value_states[:, :, :seq_len, :]

        attention_interface: Callable = (
            eager_attention_forward if attn_impl == "eager" else ALL_ATTENTION_FUNCTIONS[attn_impl]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window if self.is_sliding else None,
            **kwargs,
        )

        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
|
|
|
|
|
class Exaone4MLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        width_in, width_mid = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width_in, width_mid, bias=False)
        self.up_proj = nn.Linear(width_in, width_mid, bias=False)
        self.down_proj = nn.Linear(width_mid, width_in, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated feed-forward transform to `x`."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
|
|
|
|
|
class Exaone4DecoderLayer(nn.Module):
    """One decoder block: self-attention and MLP, each followed by an RMS norm and a residual add.

    Normalization is applied to the sub-layer *output* (post-norm) before it is added back to the
    residual stream.
    """

    def __init__(self, config: Exaone4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size

        self.is_sliding = check_is_sliding(config, layer_idx)
        self.sliding_window = config.sliding_window

        # `check_is_sliding` supports configs whose `layer_types` is None (falling back to
        # `sliding_window_pattern`), so do the same here instead of indexing None.
        layer_types = getattr(config, "layer_types", None)
        if layer_types is not None:
            self.attention_type = layer_types[layer_idx]
        else:
            self.attention_type = "sliding_attention" if self.is_sliding else "full_attention"

        self.self_attn = Exaone4Attention(config, layer_idx)
        self.mlp = Exaone4MLP(config)

        self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the block to `hidden_states` and return the updated hidden states.

        Args:
            hidden_states: Input activations of the block.
            position_embeddings: `(cos, sin)` pair forwarded to the attention module.
            attention_mask: Optional mask forwarded to the attention module.
            position_ids: Forwarded to the attention module.
            past_key_value: Optional cache forwarded to the attention module.
            use_cache: Forwarded to the attention module.
            cache_position: Forwarded to the attention module.
            **kwargs: Forwarded to the attention module.

        Returns:
            The hidden states after both residual sub-layers (a single tensor, not a tuple).
        """
        # Self-attention sub-layer; the attention weights are not needed here.
        attn_out, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        # Feed-forward sub-layer.
        mlp_out = self.mlp(hidden_states)
        hidden_states = hidden_states + self.post_feedforward_layernorm(mlp_out)

        return hidden_states
|
|
|
|
|
@auto_docstring
class Exaone4PreTrainedModel(PreTrainedModel):
    # Shared base class: declares the config class, the capability flags and the weight
    # initialization scheme used by every Exaone4 model in this file.
    config_class = Exaone4Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Exaone4DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends.
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Cache implementations.
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    # Module classes whose outputs are collected as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": Exaone4DecoderLayer,
        "attentions": Exaone4Attention,
    }

    def _init_weights(self, module):
        # Linear and embedding weights: normal(0, initializer_range); biases and the padding row
        # are zeroed; RMS-norm gains are reset to one. Other module types are left untouched.
        init_std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Exaone4RMSNorm):
            module.weight.data.fill_(1.0)
|
|
|
|
|
@auto_docstring
class Exaone4Model(Exaone4PreTrainedModel):
    # Decoder backbone: token embedding, `config.num_hidden_layers` decoder layers, a final
    # RMS norm, and a single rotary-embedding module whose output is shared by all layers.

    def __init__(self, config: Exaone4Config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.padding_idx = config.pad_token_id

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [Exaone4DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Exaone4RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Outside of training, create a cache sized to the current input when none was supplied:
        # a static cache without a sliding window, a hybrid cache otherwise.
        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            cache_cls = StaticCache if self.config.sliding_window is None else HybridCache
            past_key_values = cache_cls(
                self.config,
                max_batch_size=batch_size,
                max_cache_len=seq_len,
                dtype=inputs_embeds.dtype,
                device=self.device,
            )

        if cache_position is None:
            first_position = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                first_position, first_position + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # A dict passed as `attention_mask` is used as the per-attention-type mask mapping as is;
        # anything else is turned into that mapping here.
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            mask_kwargs = dict(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
            if self.config.sliding_window is not None:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # Computed once and handed to every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            layer_mask = causal_mask_mapping[decoder_layer.attention_type]
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
|
|
|
|
|
@auto_docstring
class Exaone4ForCausalLM(Exaone4PreTrainedModel, GenerationMixin):
    # Backbone plus a bias-free linear head that maps hidden states to vocabulary logits.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.model = Exaone4Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")

        >>> prompt = "Explain how wonderful you are"
        >>> messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        >>> input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=False,
        )

        >>> output = model.generate(input_ids, max_new_tokens=128)
        >>> tokenizer.decode(output[0], skip_special_tokens=False)
        "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
        ```

        NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future."""
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = backbone_out.last_hidden_state

        # An int keeps the trailing `logits_to_keep` positions (0 keeps all of them, since
        # `slice(-0, None)` is the full range); a tensor is used directly as the index.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The Exaone4 Model transformer with a sequence classification head on top (linear layer).

    [`Exaone4ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class Exaone4ForSequenceClassification(Exaone4PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Exaone4Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        # Score every position; one position per row is selected below.
        logits = self.score(backbone_out.last_hidden_state)

        batch_source = input_ids if input_ids is not None else inputs_embeds
        batch_size = batch_source.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # Pad positions contribute 0 to the product, so argmax yields the largest index
            # holding a non-pad token, regardless of which side the padding is on.
            non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        row_indices = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_indices, last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
@auto_docstring
class Exaone4ForTokenClassification(Exaone4PreTrainedModel):
    # Backbone followed by dropout and a linear layer that scores every token position.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Exaone4Model(config)

        # Dropout rate: `classifier_dropout` if set, else `hidden_dropout` if set, else 0.1.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = getattr(config, "hidden_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        # One score vector per token position: dropout, then the linear classifier.
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
@auto_docstring
class Exaone4ForQuestionAnswering(Exaone4PreTrainedModel):
    # Backbone (stored under the `transformer` attribute) plus a 2-way linear head whose two
    # output channels are the span-start and span-end logits.
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Exaone4Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> QuestionAnsweringModelOutput:
        backbone_out: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        span_logits = self.qa_outputs(backbone_out.last_hidden_state)
        # Split the two output channels and drop the trailing singleton dimension of each.
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # The loss needs both targets; with only one of them given no loss is computed.
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
# Names exported by a star-import of this module: the shared base class, the bare backbone,
# and the four task-specific heads.
__all__ = [
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]
|
|