| """ | |
| Train a language model: | |
| Use PyTorch and spaCy to train a language model on the preprocessed tweet data. | |
| Load a pre-trained language model from spaCy and fine-tune it on your tweet data using PyTorch. Here's an example code snippet: | |
| """ | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import os
def generate_account_text(prompt, model_dir, num_return_sequences=5):
    if not os.path.exists(model_dir):
        print("****************** ERROR **************************")
        print(f"Error: {model_dir} does not exist.")
        print("****************** ERROR **************************")
        return f"Error: {model_dir} does not exist."

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)

    # Prepend the BOS (beginning-of-sequence) token to the prompt
    start_with_bos = "<|endoftext|>" + prompt

    # Encode the prompt with the model's tokenizer and convert to a PyTorch tensor
    encoded_prompt = tokenizer(
        start_with_bos, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    encoded_prompt = encoded_prompt.to(model.device)

    # Generate sequences using the encoded prompt as input
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200,
        min_length=10,
        temperature=0.85,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Flag for whether newlines are allowed in the generated text
    ALLOW_NEW_LINES = False

    # Decode the generated sequences and store them in a list of dictionaries
    generated_sequences = []
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(
            generated_sequence,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        if not ALLOW_NEW_LINES:
            # Truncate at the first newline so each result is a single tweet-like line
            limit = text.find("\n")
            text = text[: limit if limit != -1 else None]
        generated_sequences.append({"prompt": prompt, "generated_text": text.strip()})
    return generated_sequences
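# Example usage (a sketch): the "model" directory and the prompt below are
# placeholders, not values from this Space. Note that the function returns an
# error string when model_dir is missing, and a list of dicts otherwise.
if __name__ == "__main__":
    results = generate_account_text("My kind of Monday:", "model", num_return_sequences=3)
    if isinstance(results, str):  # error message when model_dir does not exist
        print(results)
    else:
        for result in results:
            print(result["generated_text"])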