
Can't generate decent text out of it

by useless-ai - opened

Am I missing something, or does this upload have some issues? It's not generating any text that makes sense.

You have to use the implementation from

You have to use the implementation from

Out of curiosity, have you been able to run the current repo implementation under ./src?
If so, did you have to modify it?
Currently, on consumer hardware, a few people including me are getting TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'

You have to use the implementation from

Out of curiosity, have you been able to run the current repo implementation under ./src?
If so, did you have to modify it?
Currently, on consumer hardware, a few people including me are getting TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'

Same here, using the model from the repo gives cache_position error

Any chance of seeing the cache_position error getting fixed?

I ran the code and got the same error: TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'. Does anyone else have the same problem?

I added some code with the help of Cursor, and it runs now, but the results are bad: it only generates meaningless text. I do not know whether this is caused by faulty AI-generated code or by the InfiniTransformer code itself. I left `#` comments where I made changes.

# Attention layer derived from GemmaAttention that adds a learnable gate.
# NOTE(review): this paste is an excerpt and is not valid Python as shown --
# `self` and the closing `):` of both signatures appear to be missing, and
# `forward` has no body here.
class GemmaInfiniAttention(GemmaAttention):
    def __init__(
        config: GemmaConfig,
        layer_idx: Optional[int] = None,
        super().__init__(config, layer_idx)
        # Gate parameter of shape (1, num_heads, 1, 1), every entry initialised to -100.0.
        self.gate = nn.Parameter(torch.full((1, self.num_heads, 1, 1), -100.0))

    # The four parameters marked "Add this line" were added by the poster;
    # presumably so the layer accepts the cache-related keyword arguments that
    # the caller passes -- TODO confirm against the installed transformers version.
    def forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: bool = False,
        past_key_value: Optional[Cache] = None,  # Add this line
        output_attentions: Optional[bool] = False,  # Add this line
        use_cache: Optional[bool] = False,  # Add this line
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
# Gemma model excerpt: token embedding, a ModuleList of GemmaDecoderLayer
# modules and a final GemmaRMSNorm.
# NOTE(review): this paste is not valid Python as shown -- several closing
# brackets and `self` in the `forward` signature appear to be missing, and
# `forward` stops before the decoder layers are applied.
class GemmaModel(GemmaPreTrainedModel):
    def __init__(self, config: GemmaConfig):
        # NOTE(review): no super().__init__(config) call is visible in this
        # excerpt -- confirm against the full file.
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding table; padding_idx is passed as the third positional argument.
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        # One decoder layer per config.num_hidden_layers, each given its own layer_idx.
        self.layers = nn.ModuleList(
                GemmaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False


    def get_input_embeddings(self):
        # Return the token embedding module.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the token embedding module with `value`.
        self.embed_tokens = value

    def forward(
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Union[Tuple, InfiniBaseModelOutputWithPast]:
        # XOR check: raises when both or neither of input_ids / inputs_embeds are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Only a StaticCache (with use_cache set) contributes an offset; any
        # other past_key_values value leaves the offset at 0.
        past_seen_tokens = 0
        if use_cache and isinstance(past_key_values, StaticCache):
            past_seen_tokens = past_key_values.get_seq_length()

        # NOTE(review): cache_position is recomputed here unconditionally, so a
        # value passed in by the caller is ignored; the parameter only makes the
        # signature accept the keyword.
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
        # Default position ids are the cache positions with a leading batch dimension of 1.
        position_ids = cache_position.unsqueeze(0) if position_ids is None else position_ids
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1])

        # Scale the embeddings by sqrt(hidden_size), using the embeddings' dtype for the factor.
        hidden_states = inputs_embeds * torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)

        # Accumulators are tuples only when the corresponding output was requested.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        next_decoder_cache = None  # Initialize next_decoder_cache
        # NOTE(review): the remainder of forward is not included in this excerpt.
class GemmaInfiniAttention(GemmaAttention):
    """Gemma attention layer that blends segment attention with a memory read.

    A learnable per-head gate mixes the result of ``_retrieve_from_memory``
    with ordinary scaled dot-product attention over the current segment.

    NOTE(review): the pasted snippet was not valid Python. ``self``, the
    closing ``):`` of ``__init__``, two ``else:`` branches and most arguments
    of the ``scaled_dot_product_attention`` call were missing. They are
    restored here from context; confirm against the upstream implementation.
    """

    def __init__(
        self,
        config: GemmaConfig,
        layer_idx: Optional[int] = None,
    ):
        super().__init__(config, layer_idx)
        # One gate value per head. sigmoid(-100) is ~0, so with the initial
        # value the output is effectively the plain attention branch.
        self.gate = nn.Parameter(torch.full((1, self.num_heads, 1, 1), -100.0))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: bool = False,
        # The four cache-related parameters below are accepted so callers can
        # pass them as keywords; this method does not use them.
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.Tensor,
        None,
        None,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """Run gated memory + local attention over one segment.

        Args:
            hidden_states: 3-D input unpacked as (batch, seq_len, hidden).
            attention_mask: Optional mask forwarded to scaled dot-product
                attention; its last dimension is cropped to the key length.
                Indexed with four dimensions, so presumably 4-D -- TODO confirm.
            position_ids: Unused in this method (see the NOTE in the body).
            memory: Memory state handed to ``_retrieve_from_memory`` and
                ``_update_memory``.
            norm_term: Normalisation state that accompanies ``memory``.
            no_memory_update: When true, memory is neither read nor updated
                and ``None`` is returned for both state values.
            past_key_value, output_attentions, use_cache, cache_position:
                Accepted for interface compatibility only.

        Returns:
            ``(output, None, None, memory, norm_term)``. ``output`` is the
            projected attention result; the two ``None`` slots are
            placeholders (presumably attention weights and cache -- confirm
            against the caller).
        """
        bsz, seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(
            bsz, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # NOTE(review): no rotary position embedding is applied to the
        # query/key states here and position_ids is unused. If the base
        # attention normally applies one, its absence could explain degraded
        # generations -- verify against the upstream implementation.

        # Expand key/value heads by num_key_value_groups so they can be used
        # with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Crop the mask's last dimension when it does not match the key length.
        if (
            attention_mask is not None
            and attention_mask.shape[-1] != key_states.shape[-2]
        ):
            attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        if no_memory_update:
            memory_output = None
        else:
            # Read from the memory state before this segment is written to it.
            memory_output = self._retrieve_from_memory(
                query_states, memory, norm_term
            )
            updated_memory, updated_norm_term = self._update_memory(
                key_states, value_states, memory, norm_term
            )
            # Detach so the returned state carries no autograd history.
            memory = updated_memory.detach()
            norm_term = updated_norm_term.detach()

        # NOTE(review): the positional arguments and attn_mask were missing
        # from the paste and are reconstructed; confirm.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )

        if memory_output is None:
            combined_output = attn_output
        else:
            gate = torch.sigmoid(self.gate)
            combined_output = gate * memory_output + (1 - gate) * attn_output

        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        combined_output = combined_output.transpose(1, 2).contiguous()
        combined_output = combined_output.view(bsz, seq_len, self.hidden_size)

        final_output = self.o_proj(combined_output)

        if no_memory_update:
            memory = None
            norm_term = None

        # Five values are returned; the second and third are always None.
        return final_output, None, None, memory, norm_term

The following are my results:

You need to confirm your account before you can post a new comment.

Sign up or log in to comment