Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						6e64f14
	
1
								Parent(s):
							
							ca1fc51
								
from minGPT
Browse files
    	
        utils.py
    ADDED
    
    | @@ -0,0 +1,47 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            from torch.nn import functional as F
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def set_seed(seed):
         | 
| 8 | 
            +
                random.seed(seed)
         | 
| 9 | 
            +
                np.random.seed(seed)
         | 
| 10 | 
            +
                torch.manual_seed(seed)
         | 
| 11 | 
            +
                torch.cuda.manual_seed_all(seed)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            def top_k_logits(logits, k):
         | 
| 14 | 
            +
                v, ix = torch.topk(logits, k)
         | 
| 15 | 
            +
                out = logits.clone()
         | 
| 16 | 
            +
                out[out < v[:, [-1]]] = -float('Inf')
         | 
| 17 | 
            +
                return out
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            @torch.no_grad()
         | 
| 20 | 
            +
            def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
         | 
| 23 | 
            +
                the sequence, feeding the predictions back into the model each time. Clearly the sampling
         | 
| 24 | 
            +
                has quadratic complexity unlike an RNN that is only linear, and has a finite context window
         | 
| 25 | 
            +
                of block_size, unlike an RNN that has an infinite context window.
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                block_size = model.get_block_size()
         | 
| 28 | 
            +
                model.eval()
         | 
| 29 | 
            +
                for k in range(steps):
         | 
| 30 | 
            +
                    x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
         | 
| 31 | 
            +
                    logits, _ = model(x_cond)
         | 
| 32 | 
            +
                    # pluck the logits at the final step and scale by temperature
         | 
| 33 | 
            +
                    logits = logits[:, -1, :] / temperature
         | 
| 34 | 
            +
                    # optionally crop probabilities to only the top k options
         | 
| 35 | 
            +
                    if top_k is not None:
         | 
| 36 | 
            +
                        logits = top_k_logits(logits, top_k)
         | 
| 37 | 
            +
                    # apply softmax to convert to probabilities
         | 
| 38 | 
            +
                    probs = F.softmax(logits, dim=-1)
         | 
| 39 | 
            +
                    # sample from the distribution or take the most likely
         | 
| 40 | 
            +
                    if sample:
         | 
| 41 | 
            +
                        ix = torch.multinomial(probs, num_samples=1)
         | 
| 42 | 
            +
                    else:
         | 
| 43 | 
            +
                        _, ix = torch.topk(probs, k=1, dim=-1)
         | 
| 44 | 
            +
                    # append to the sequence and continue
         | 
| 45 | 
            +
                    x = torch.cat((x, ix), dim=1)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                return x
         | 
