qingy2024 committed on
Commit
07537b9
·
verified ·
1 Parent(s): 34b8f81

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -80
README.md CHANGED
@@ -48,87 +48,10 @@ The primary purpose of this model is to provide varied and persona-specific utte
48
 
49
  ## How to Use (as part of the `interactive_chat.py` script)
50
 
51
- This model is primarily intended to be used within the provided `interactive_chat.py` script. Here's how it's loaded and used:
52
-
53
- **1. Model Definition (Python):**
54
- ```python
55
- import torch
56
- import torch.nn as nn
57
-
58
- # (Ensure VOCAB_SIZE, NUM_PERSONS, etc. are defined based on your training)
59
-
60
- class ConditionalCharLSTM(nn.Module):
61
- def __init__(self, vocab_size, embedding_dim, hidden_dim, condition_dim, num_layers=1, dropout=0.1):
62
- super().__init__()
63
- self.vocab_size = vocab_size
64
- self.embedding_dim = embedding_dim
65
- self.hidden_dim = hidden_dim
66
- self.condition_dim = condition_dim
67
- self.num_layers = num_layers
68
-
69
- # Assuming CHAR_TO_IDX[PAD_TOKEN] is defined
70
- self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=CHAR_TO_IDX[PAD_TOKEN])
71
- self.lstm = nn.LSTM(embedding_dim + condition_dim, hidden_dim, num_layers,
72
- batch_first=True, dropout=dropout if num_layers > 1 else 0)
73
- self.fc_out = nn.Linear(hidden_dim, vocab_size)
74
- self.dropout = nn.Dropout(dropout)
75
-
76
- def forward(self, input_chars, condition_vector, hidden_state=None):
77
- embedded_chars = self.dropout(self.embedding(input_chars))
78
- # Expand condition_vector to match sequence length for concatenation
79
- condition_expanded = condition_vector.unsqueeze(1).repeat(1, embedded_chars.size(1), 1)
80
- lstm_input = torch.cat((embedded_chars, condition_expanded), dim=2)
81
-
82
- if hidden_state is None: # Initialize hidden state if not provided
83
- batch_size = input_chars.size(0)
84
- dev = input_chars.device
85
- h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(dev)
86
- c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(dev)
87
- hidden_state = (h0, c0)
88
-
89
- lstm_out, hidden_state = self.lstm(lstm_input, hidden_state)
90
- lstm_out_dropped = self.dropout(lstm_out)
91
- output_logits = self.fc_out(lstm_out_dropped)
92
- return output_logits, hidden_state
93
- ```
94
-
95
- **2. Loading the Pretrained Weights:**
96
- ```python
97
- # In interactive_chat.py
98
- # (Define VOCAB_SIZE, embedding_dim, hidden_dim, NUM_PERSONS, etc. to match training)
99
- model_lstm = ConditionalCharLSTM(
100
- vocab_size=VOCAB_SIZE,
101
- embedding_dim=args.lstm_embedding_dim, # from command line
102
- hidden_dim=args.lstm_hidden_dim, # from command line
103
- condition_dim=NUM_PERSONS,
104
- # ... other params
105
- ).to(device)
106
-
107
- model_lstm.load_state_dict(torch.load(args.lstm_model_path, map_location=device))
108
- model_lstm.eval()
109
- ```
110
 
111
- **3. Generating Text (Streaming):**
112
- ```python
113
- # In interactive_chat.py (simplified)
114
- def generate_text_lstm_stream(model_lstm, person_id_int, device, temperature=0.7, max_len=100):
115
- model_lstm.eval()
116
- condition_vector = torch.zeros(1, NUM_PERSONS, dtype=torch.float).to(device)
117
- condition_vector[0, person_id_int] = 1.0 # One-hot encode the person_id
118
-
119
- current_char_idx = torch.tensor([[CHAR_TO_IDX[SOS_TOKEN]]], dtype=torch.long).to(device)
120
- hidden_state = None
121
-
122
- for _ in range(max_len - 1):
123
- output_logits, hidden_state = model_lstm(current_char_idx, condition_vector, hidden_state)
124
- # Apply temperature and sample next character
125
- probabilities = torch.softmax(output_logits.squeeze(0).squeeze(0) / temperature, dim=0)
126
- next_char_idx = torch.multinomial(probabilities, 1).item()
127
-
128
- if next_char_idx == CHAR_TO_IDX[EOS_TOKEN]: break
129
- char = IDX_TO_CHAR.get(next_char_idx, "")
130
- yield char
131
- current_char_idx = torch.tensor([[next_char_idx]], dtype=torch.long).to(device)
132
  ```
133
 
134
  ## Training Data 📚
 
48
 
49
  ## How to Use (as part of the `interactive_chat.py` script)
50
 
51
+ This model is primarily intended to be used within the provided `rcj_inference.py` script. Here's how it's used:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ ```bash
54
+ python rcj_inference.py --classifier_model_path /path/to/personify-67m/ --lstm_model_path rcj_lstm.pth
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  ```
56
 
57
  ## Training Data 📚