	Update story_generator.py
story_generator.py CHANGED (+8, -8)
@@ -29,12 +29,12 @@ def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
     with torch.no_grad():
         output = story_model.generate(
             input_ids,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=max_new_tokens,
             temperature=temperature,
-            top_k=20,
-            top_p=0.7,
+            top_k=20,
+            top_p=0.7,
             do_sample=True,
-            early_stopping=True,
+            early_stopping=True,
             pad_token_id=story_tokenizer.pad_token_id,
             eos_token_id=story_tokenizer.eos_token_id,
             attention_mask=input_ids.ne(story_tokenizer.pad_token_id)
@@ -50,12 +50,12 @@ def generate_questions(story, max_new_tokens=150, temperature=0.7):
     with torch.no_grad():
         output = question_model.generate(
             input_ids,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=max_new_tokens,
             temperature=temperature,
-            top_k=20,
-            top_p=0.7,
+            top_k=20,
+            top_p=0.7,
             do_sample=True,
-            early_stopping=True,
+            early_stopping=True,
             pad_token_id=question_tokenizer.pad_token_id,
             eos_token_id=question_tokenizer.eos_token_id,
             attention_mask=input_ids.ne(question_tokenizer.pad_token_id)
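The removed and re-added lines are textually identical in the rendered diff, so this commit appears to change only whitespace on the sampling arguments. For context, below is a minimal sketch of how the decoding call in generate_story is presumably wired up, assuming GPT-2-style causal LMs loaded via transformers; the "gpt2" checkpoint, the prompt format, and the decode step are illustrative assumptions, not taken from this Space.

# A minimal sketch of the story-generation call, assuming GPT-2-style models
# loaded via transformers. The "gpt2" checkpoint and the prompt format are
# illustrative assumptions, not taken from this repository.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

story_tokenizer = AutoTokenizer.from_pretrained("gpt2")   # assumed checkpoint
story_model = AutoModelForCausalLM.from_pretrained("gpt2")
story_model.eval()

# GPT-2 ships without a pad token; reusing EOS keeps pad_token_id defined,
# which the generate() kwargs and the attention-mask expression rely on.
if story_tokenizer.pad_token is None:
    story_tokenizer.pad_token = story_tokenizer.eos_token

prompt = "Write a story about friendship for a 2nd-grade reader:\n"  # assumed
input_ids = story_tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    output = story_model.generate(
        input_ids,
        max_new_tokens=400,
        temperature=0.7,
        top_k=20,        # sample only from the 20 most likely next tokens
        top_p=0.7,       # ...further restricted to 70% of probability mass
        do_sample=True,  # stochastic decoding rather than greedy/beam search
        pad_token_id=story_tokenizer.pad_token_id,
        eos_token_id=story_tokenizer.eos_token_id,
        attention_mask=input_ids.ne(story_tokenizer.pad_token_id),
    )
print(story_tokenizer.decode(output[0], skip_special_tokens=True))

One caveat: in transformers, early_stopping is a beam-search option, so with do_sample=True and the default num_beams=1 it has no effect (recent versions warn that the flag is unused). The sketch above omits it for that reason, while the committed code keeps it.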