Upload ChronoGPT_inference.py with huggingface_hub
ChronoGPT_inference.py  CHANGED  (+2 -86)
@@ -125,90 +125,6 @@ class Block(nn.Module):
         return x
 
 class ValueEmbedding(nn.Module):
-    def __init__(self, vocab_size, model_dim):
-        super().__init__()
-        self.embed = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
-
-    @torch.inference_mode()
-    def forward(self, inputs):
-        ve = [emb(inputs).bfloat16() for emb in self.embed]
-        ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, ve[0], ve[1], ve[2]]
-        return ve
-
-class ChronoGPT(nn.Module, PyTorchModelHubMixin):
-    def __init__(self, vocab_size, num_layers, num_heads, model_dim, **kwargs):
-        super().__init__()
-        self.num_heads = num_heads
-        self.vocab_size = vocab_size  # Store vocab_size as instance variable
-        self.embed = nn.Embedding(vocab_size, model_dim)
-        self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=(i != 7))
-                                     for i in range(num_layers)])
-        self.value_embeds = ValueEmbedding(vocab_size, model_dim)
-        self.lm_head = CastedLinear(model_dim, vocab_size)
-        self.lm_head.weight.data.zero_()
-        self.num_encoder_layers = num_layers // 2
-        self.num_decoder_layers = num_layers - self.num_encoder_layers
-        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
-    @torch.inference_mode()
-    def forward(self, inputs, past_key_values=None):
-        B = inputs.size(0)
-        if inputs.dim() == 1:
-            inputs = inputs.unsqueeze(0)  # Add batch dimension if not present
-
-        x0 = norm(self.embed(inputs).bfloat16())
-        x = x0
-
-        # Modify value embedding handling for batched input
-        ve = [self.value_embeds(inputs[i].view(-1)) for i in range(B)]
-        ve = [torch.stack([ve[b][i] for b in range(B)]) if ve[0][i] is not None else None
-              for i in range(len(ve[0]))]
-        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
-
-        # Handle cached states for batched input
-        if past_key_values is not None:
-            for i, block in enumerate(self.blocks):
-                if block.attn is not None:
-                    block.attn.kv_cache = past_key_values[i]
-
-        present = []
-        layer_outputs = []
-        skip_connections = []
-
-        # Process through encoder layers
-        for i in range(self.num_encoder_layers):
-            block = self.blocks[i]
-            x = block(x, ve_enc[i], x0)
-            if block.attn is not None:
-                present.append(block.attn.kv_cache)
-                block.attn.kv_cache = None
-            skip_connections.append(x)
-            layer_outputs.append(norm(x))
-
-        # Process through decoder layers
-        for i in range(self.num_decoder_layers):
-            x = x + self.skip_weights[i] * skip_connections.pop()
-            block = self.blocks[self.num_encoder_layers + i]
-            x = block(x, ve_dec[i], x0)
-            layer_outputs.append(norm(x))
-            if block.attn is not None:
-                present.append(block.attn.kv_cache)
-                block.attn.kv_cache = None
-
-        x = norm(x)
-        logits = self.lm_head(x)
-        logits = 15 * torch.tanh(logits / 15)
-
-        return logits.float(), layer_outputs
-    @classmethod
-    def from_pretrained(cls, repo_id, cache_dir=None, **kwargs):
-        config_path = hf_hub_download(repo_id=repo_id, filename="config.pt", cache_dir=cache_dir)
-        bin_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", cache_dir=cache_dir)
-        config = torch.load(config_path)
-        model = cls(**config)
-        model.load_state_dict(torch.load(bin_path))
-        return model
-
-class ValueEmbedding_xl(nn.Module):
     def __init__(self, vocab_size, model_dim, num_layers=52):
         super().__init__()
         self.num_layers = num_layers
@@ -228,14 +144,14 @@ class ValueEmbedding_xl(nn.Module):
         return encoder + decoder
 
 
-class ChronoGPT_xl(nn.Module, PyTorchModelHubMixin):
+class ChronoGPT(nn.Module, PyTorchModelHubMixin):
     def __init__(self, vocab_size, num_layers, num_heads, model_dim, **kwargs):
         super().__init__()
         self.num_heads = num_heads
         self.vocab_size = vocab_size  # Store vocab_size as instance variable
         self.embed = nn.Embedding(vocab_size, model_dim)
         self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=True) for i in range(num_layers)])
-        self.value_embeds = ValueEmbedding_xl(vocab_size, model_dim, num_layers=num_layers)
+        self.value_embeds = ValueEmbedding(vocab_size, model_dim, num_layers=num_layers)
        self.lm_head = CastedLinear(model_dim, vocab_size)
        self.lm_head.weight.data.zero_()
        self.num_encoder_layers = num_layers // 2
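Usage note: a minimal inference sketch against the retained ChronoGPT class might look like the code below. It assumes the file still exposes a working from_pretrained (either an override further down in ChronoGPT_inference.py, not shown in this hunk, or the default supplied by PyTorchModelHubMixin), and that forward keeps the (logits, layer_outputs) return signature of the removed class. The repo id is a placeholder, not an actual checkpoint.

import torch
from ChronoGPT_inference import ChronoGPT  # the module uploaded in this commit

# "your-org/chronogpt-checkpoint" is a placeholder repo id; substitute a real one.
model = ChronoGPT.from_pretrained("your-org/chronogpt-checkpoint")
model.eval()

# forward() takes a batch of token ids and (assumed) returns (logits, per-layer hidden states).
tokens = torch.randint(0, model.vocab_size, (1, 16))
with torch.inference_mode():
    logits, layer_outputs = model(tokens)

print(logits.shape)        # expected: (1, 16, vocab_size)
print(len(layer_outputs))  # expected: one normalized hidden state per block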

