Update README.md
README.md
CHANGED
@@ -48,87 +48,10 @@ The primary purpose of this model is to provide varied and persona-specific utterances
## How to Use (as part of the `interactive_chat.py` script)

This model is primarily intended to be used within the provided `interactive_chat.py` script. Here's how it's used:

**1. Model Definition (Python):**
```python
import torch
import torch.nn as nn

# (Ensure VOCAB_SIZE, NUM_PERSONS, etc. are defined based on your training)

class ConditionalCharLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, condition_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.condition_dim = condition_dim
        self.num_layers = num_layers

        # Assuming CHAR_TO_IDX[PAD_TOKEN] is defined
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=CHAR_TO_IDX[PAD_TOKEN])
        self.lstm = nn.LSTM(embedding_dim + condition_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_chars, condition_vector, hidden_state=None):
        embedded_chars = self.dropout(self.embedding(input_chars))
        # Expand condition_vector to match the sequence length for concatenation
        condition_expanded = condition_vector.unsqueeze(1).repeat(1, embedded_chars.size(1), 1)
        lstm_input = torch.cat((embedded_chars, condition_expanded), dim=2)

        if hidden_state is None:  # Initialize hidden state if not provided
            batch_size = input_chars.size(0)
            dev = input_chars.device
            h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=dev)
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=dev)
            hidden_state = (h0, c0)

        lstm_out, hidden_state = self.lstm(lstm_input, hidden_state)
        lstm_out_dropped = self.dropout(lstm_out)
        output_logits = self.fc_out(lstm_out_dropped)
        return output_logits, hidden_state
```
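The class leaves the vocabulary and persona constants to the caller. As a minimal sketch of how they might be built (the token strings, character inventory, and persona count below are illustrative assumptions, not the model's actual values; they must match whatever was used during training):

```python
# Illustrative placeholders only: the real values must match training exactly.
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = "<pad>", "<sos>", "<eos>"

# Special tokens first, then the character inventory seen during training.
_charset = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + list("abcdefghijklmnopqrstuvwxyz .,!?'")
CHAR_TO_IDX = {ch: i for i, ch in enumerate(_charset)}
IDX_TO_CHAR = {i: ch for ch, i in CHAR_TO_IDX.items()}

VOCAB_SIZE = len(CHAR_TO_IDX)
NUM_PERSONS = 8  # hypothetical persona count; use the value from training
```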
**2. Loading the Pretrained Weights:**

```python
# In interactive_chat.py
# (Define VOCAB_SIZE, embedding_dim, hidden_dim, NUM_PERSONS, etc. to match training)
model_lstm = ConditionalCharLSTM(
    vocab_size=VOCAB_SIZE,
    embedding_dim=args.lstm_embedding_dim,  # from command line
    hidden_dim=args.lstm_hidden_dim,        # from command line
    condition_dim=NUM_PERSONS,
    # ... other params
).to(device)

model_lstm.load_state_dict(torch.load(args.lstm_model_path, map_location=device))
model_lstm.eval()
```
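This snippet assumes an `args` namespace and a `device` already exist. How `interactive_chat.py` actually declares them is not shown here, so the following argparse wiring is a hypothetical sketch, with flag names inferred from the attributes used above:

```python
import argparse

import torch

# Hypothetical flag definitions, inferred from the args.* attributes above;
# the real script may name or default these differently.
parser = argparse.ArgumentParser()
parser.add_argument("--lstm_model_path", type=str, default="rcj_lstm.pth")
parser.add_argument("--lstm_embedding_dim", type=int, default=128)
parser.add_argument("--lstm_hidden_dim", type=int, default=256)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```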
**3. Generating Text (Streaming):**

```python
# In interactive_chat.py (simplified)
def generate_text_lstm_stream(model_lstm, person_id_int, device, temperature=0.7, max_len=100):
    model_lstm.eval()
    condition_vector = torch.zeros(1, NUM_PERSONS, dtype=torch.float, device=device)
    condition_vector[0, person_id_int] = 1.0  # One-hot encode the person_id

    current_char_idx = torch.tensor([[CHAR_TO_IDX[SOS_TOKEN]]], dtype=torch.long, device=device)
    hidden_state = None

    for _ in range(max_len - 1):
        output_logits, hidden_state = model_lstm(current_char_idx, condition_vector, hidden_state)
        # Apply temperature and sample the next character
        probabilities = torch.softmax(output_logits.squeeze(0).squeeze(0) / temperature, dim=0)
        next_char_idx = torch.multinomial(probabilities, 1).item()

        if next_char_idx == CHAR_TO_IDX[EOS_TOKEN]:
            break
        char = IDX_TO_CHAR.get(next_char_idx, "")
        yield char
        current_char_idx = torch.tensor([[next_char_idx]], dtype=torch.long, device=device)
```
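Because the function is a generator, callers can print each character as it is sampled. A usage sketch (assuming the model and constants above; persona index 0 is arbitrary):

```python
# Stream one sampled utterance for persona 0 to the terminal.
for ch in generate_text_lstm_stream(model_lstm, person_id_int=0, device=device, temperature=0.7):
    print(ch, end="", flush=True)
print()
```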
## How to Use (as part of the `interactive_chat.py` script)

This model is primarily intended to be used within the provided `rcj_inference.py` script. Here's how it's used:

```bash
python rcj_inference.py --classifier_model_path /path/to/personify-67m/ --lstm_model_path rcj_lstm.pth
```

## Training Data 📚