sigmoidneuron123 committed
Commit 3edb9ef · verified · 1 Parent(s): cfd9a86

Upload 2 files

Files changed (2)
  1. neochessppo.py +418 -0
  2. san_moves.txt +0 -0
neochessppo.py ADDED
@@ -0,0 +1,418 @@
+ import torchrl
+ import torch
+ import chess
+ import chess.engine
+ import gymnasium
+ import numpy as np
+ import tensordict
+ from collections import defaultdict
+ from tensordict.nn import TensorDictModule
+ from tensordict.nn.distributions import NormalParamExtractor
+ from torch import nn
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.data.replay_buffers import ReplayBuffer
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+ from torchrl.data.replay_buffers.storages import LazyTensorStorage
+ import torch.nn.functional as F
+ from torch.distributions import Categorical
+ from torchrl.envs import (
+     Compose,
+     DoubleToFloat,
+     ObservationNorm,
+     StepCounter,
+     TransformedEnv,
+ )
+ from torchrl.envs.libs.gym import GymEnv
+ from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
+ from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator, MaskedCategorical, ActorCriticWrapper
+ from torchrl.objectives import ClipPPOLoss
+ from torchrl.objectives.value import GAE
+ from tqdm import tqdm
+ from torchrl.envs.custom.chess import ChessEnv
+ from torchrl.envs.libs.gym import set_gym_backend, GymWrapper
+ from tensordict import TensorDict
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def board_to_tensor(board):
+     piece_encoding = {
+         'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
+         'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12
+     }
+
+     tensor = torch.zeros(64, dtype=torch.long)
+     for square in chess.SQUARES:
+         piece = board.piece_at(square)
+         if piece:
+             tensor[square] = piece_encoding[piece.symbol()]
+         else:
+             tensor[square] = 0
+
+     return tensor.unsqueeze(0)
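+ # Example: for chess.Board() (the initial position) this returns a (1, 64) LongTensor
+ # indexed by chess.SQUARES (a1 = 0 ... h8 = 63), with 0 for empty squares, 1-6 for
+ # white P/N/B/R/Q/K and 7-12 for black p/n/b/r/q/k.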
+
+ class Policy(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.embedding = nn.Embedding(13, 32)
+         self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=16)
+         self.neu = 256
+         self.neurons = nn.Sequential(
+             nn.Linear(64 * 32, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, 128),
+             nn.ReLU(),
+             nn.Linear(128, 29275),
+         )
+
+     def forward(self, x):
+         x = chess.Board(x)
+         color = x.turn
+         x = board_to_tensor(x).to(self.embedding.weight.device)  # keep the input on the model's device
+         x = self.embedding(x)
+         x = x.permute(1, 0, 2)
+         attn_output, _ = self.attention(x, x, x)
+         x = attn_output.permute(1, 0, 2).contiguous()
+         x = x.view(x.size(0), -1)
+         x = self.neurons(x) * color  # color is chess.WHITE (True) or chess.BLACK (False)
+         return x
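+ # The policy maps a FEN string to one logit per move in a fixed move vocabulary
+ # (29275 entries, presumably the SAN move list uploaded alongside as san_moves.txt):
+ # per-square piece embeddings, self-attention over the 64 squares, then an MLP head.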
+
+ class Value(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.embedding = nn.Embedding(13, 64)
+         self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
+         self.neu = 512
+         self.neurons = nn.Sequential(
+             nn.Linear(64 * 64, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, self.neu),
+             nn.ReLU(),
+             nn.Linear(self.neu, 64),
+             nn.ReLU(),
+             nn.Linear(64, 4)
+         )
+
+     def forward(self, x):
+         x = chess.Board(x)
+         color = x.turn
+         x = board_to_tensor(x).to(self.embedding.weight.device)  # keep the input on the model's device
+         x = self.embedding(x)
+         x = x.permute(1, 0, 2)
+         attn_output, _ = self.attention(x, x, x)
+         x = attn_output.permute(1, 0, 2).contiguous()
+         x = x.view(x.size(0), -1)
+         x = self.neurons(x)
+         x = x[0][0] / 10  # take the first of the 4 outputs and scale it down
+         if color == chess.WHITE:
+             x = -x
+         return x
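+ # The critic reuses the embedding + self-attention layout of the policy with a deeper
+ # 512-wide MLP; only the first of its 4 outputs is used as the scalar state value,
+ # scaled by 1/10 and sign-flipped when White is to move.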
+
+ with set_gym_backend("gymnasium"):
+     env = ChessEnv(
+         stateful=True,
+         include_fen=True,
+         include_san=False,
+     )
+
+ policy = Policy().to(device)
+ value = Value().to(device)
+ valweight = torch.load("NeoChess/chessy_model.pth", map_location=device)
+ value.load_state_dict(valweight)
+ polweight = torch.load("NeoChess/chessy_policy.pth", map_location=device)
+ policy.load_state_dict(polweight)
+
+ def sample_masked_action(logits, mask):
+     masked_logits = logits.clone()
+     masked_logits[~mask] = float('-inf')  # illegal moves
+     probs = F.softmax(masked_logits, dim=-1)
+     dist = Categorical(probs=probs)
+     action = dist.sample()
+     log_prob = dist.log_prob(action)
+     return action, log_prob
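+ # Standalone helper, not called by the training loop below; intended use is e.g.
+ #   action, log_prob = sample_masked_action(logits, legal_mask.bool())
+ # where legal_mask marks the legal entries of the move vocabulary (names illustrative).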
+
+ class FENPolicyWrapper(nn.Module):
+     def __init__(self, policy_net):
+         super().__init__()
+         self.policy_net = policy_net
+
+     def forward(self, fens, action_mask=None) -> torch.Tensor:
+         if isinstance(fens, (TensorDict, dict)):
+             fens = fens["fen"]
+
+         # Normalize to a list of FEN strings
+         if isinstance(fens, str):
+             fens = [fens]
+
+         # Flatten nested lists
+         while isinstance(fens[0], list):
+             fens = fens[0]
+
+         # Ensure action_mask is a list of tensors (or None)
+         if action_mask is not None:
+             if isinstance(action_mask, torch.Tensor):
+                 action_mask = action_mask.unsqueeze(0) if action_mask.ndim == 1 else action_mask
+             if not isinstance(action_mask, list):
+                 action_mask = [action_mask[i] for i in range(len(fens))]
+
+         logits_list = []
+
+         for i, fen in enumerate(fens):
+             logits = self.policy_net(fen)  # shape: [1, 29275]
+
+             # Apply masking if provided
+             if action_mask is not None:
+                 mask = action_mask[i].bool()
+                 logits = logits.masked_fill(~mask, float("-inf"))
+
+             logits_list.append(logits)
+
+         # [batch_size, 29275]; the batch dimension is squeezed away when batch_size == 1
+         return torch.stack(logits_list).squeeze(-2).squeeze(-2)
+
+ class FENValueWrapper(nn.Module):
+     def __init__(self, value_net):
+         super().__init__()
+         self.value_net = value_net
+
+     def forward(self, fens) -> torch.Tensor:
+         if isinstance(fens, (TensorDict, dict)):
+             fens = fens["fen"]
+         if isinstance(fens, str):
+             fens = [fens]  # wrap a single FEN string in a list
+         while isinstance(fens[0], list):
+             fens = fens[0]
+         state_value = []
+         for fen in fens:
+             state_value += [self.value_net(fen)]
+         state_value = torch.stack(state_value)
+         # Ensure the output has a batch dimension of 1 for a single sample
+         if state_value.ndim == 0:
+             state_value = state_value.unsqueeze(0)
+         return state_value
+
+ ACTION_DIM = 64 * 73
+
+ from functools import partial
+ # Wrap policy
+ policy_module = TensorDictModule(
+     FENPolicyWrapper(policy),
+     in_keys=["fen"],
+     out_keys=["logits"]
+ )
+ value_module = TensorDictModule(
+     FENValueWrapper(value),
+     in_keys=["fen"],
+     out_keys=["state_value"]
+ )
+
+ def masked_categorical_factory(logits, action_mask):
+     return MaskedCategorical(logits=logits, mask=action_mask)
+
+ actor = ProbabilisticActor(
+     module=policy_module,
+     in_keys=["logits", "action_mask"],
+     out_keys=["action"],
+     distribution_class=masked_categorical_factory,
+     return_log_prob=True,
+ )
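+ # The actor reads "logits" and "action_mask" from the tensordict, builds a
+ # MaskedCategorical over the move vocabulary, and writes "action" (plus the sampled
+ # log-probability) back into the tensordict, as the smoke test below shows.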
+ # Smoke test: run the modules on one reset observation
+ obs = env.reset()
+ print(obs)
+ print(policy_module(obs))
+ print(value_module(obs))
+ print(actor(obs))
+
+ rollout = env.rollout(3)
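+ # rollout is a TensorDict of 3 stacked steps: each step carries the "fen" observation
+ # and the "action_mask" consumed by the actor, with rewards and done flags nested
+ # under "next" as usual for torchrl rollouts.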
+
+ from torchrl.record.loggers import generate_exp_name, get_logger
+
+ def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=100,
+                     num_epochs=10, lr=3e-4, gamma=0.99, lmbda=0.95,
+                     clip_epsilon=0.2, device="cpu"):
+     """
+     Main PPO training loop for chess.
+
+     Args:
+         chess_env: the ChessEnv instance to train on
+         num_iterations: number of training iterations
+         frames_per_batch: number of environment steps per batch
+         num_epochs: number of PPO epochs per iteration
+         lr: learning rate
+         gamma: discount factor
+         lmbda: GAE lambda parameter
+         clip_epsilon: PPO clipping parameter
+         device: training device
+     """
+     # The actor, value module and loss module are shared at module level
+     global actor_module, value_module, loss_module
+
+     # Use the provided chess environment and the pre-built actor
+     env = chess_env
+     actor_module = actor
+
+     collector = SyncDataCollector(
+         env,
+         actor_module,
+         frames_per_batch=frames_per_batch,
+         total_frames=-1,
+         device=device,
+     )
+
+     # Create replay buffer
+     replay_buffer = ReplayBuffer(
+         storage=LazyTensorStorage(frames_per_batch),
+         sampler=SamplerWithoutReplacement(),
+         batch_size=256,  # mini-batch size for PPO updates
+     )
+
+     # Create PPO loss module
+     loss_module = ClipPPOLoss(
+         actor_network=actor_module,
+         critic_network=value_module,
+         clip_epsilon=clip_epsilon,
+         entropy_bonus=True,
+         entropy_coef=0.01,
+         critic_coef=1.0,
+         normalize_advantage=True,
+     )
+
+     optim = torch.optim.Adam(loss_module.parameters(), lr=lr)
+
+     # Set up logging
+     logger = get_logger("tensorboard", logger_name="ppo_chess", experiment_name=generate_exp_name("PPO", "Chess"))
+
+     # Training loop
+     collected_frames = 0
+
+     for iteration in range(num_iterations):
+         print(f"\n=== Iteration {iteration + 1}/{num_iterations} ===")
+
+         # Collect data
+         batch_data = []
+         for i, batch in enumerate(collector):
+             batch_data.append(batch)
+             collected_frames += batch.numel()
+
+             # Break after collecting enough frames
+             if len(batch_data) * collector.frames_per_batch >= frames_per_batch:
+                 break
+
+         # Concatenate all collected batches
+         if batch_data:
+             full_batch = torch.cat(batch_data, dim=0)
+
+             # Add GAE (Generalized Advantage Estimation) targets
+             with torch.no_grad():
+                 full_batch = loss_module.value_estimator(full_batch)
+
+             replay_buffer.extend(full_batch)
+
+         # Training phase
+         total_loss = 0
+         total_actor_loss = 0
+         total_critic_loss = 0
+         total_entropy_loss = 0
+
+         for epoch in range(num_epochs):
+             epoch_loss = 0
+             epoch_actor_loss = 0
+             epoch_critic_loss = 0
+             epoch_entropy_loss = 0
+             num_batches = 0
+
+             for batch in replay_buffer:
+                 print(batch)
+                 # Ensure batch entries have the expected dimensions
+                 if "state_value" in batch and batch["state_value"].dim() > 1:
+                     batch["state_value"] = batch["state_value"].squeeze(-1)
+
+                 batch["value_target"] = batch["value_target"].squeeze(1)
+                 # Compute losses
+                 loss_dict = loss_module(batch)
+                 loss = loss_dict["loss_objective"] + loss_dict["loss_critic"] + loss_dict["loss_entropy"]
+
+                 # Backward pass
+                 optim.zero_grad()
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=0.5)
+                 optim.step()
+
+                 # Accumulate losses
+                 epoch_loss += loss.item()
+                 epoch_actor_loss += loss_dict["loss_objective"].item()
+                 epoch_critic_loss += loss_dict["loss_critic"].item()
+                 epoch_entropy_loss += loss_dict["loss_entropy"].item()
+                 num_batches += 1
+
+             # Average losses over the epoch
+             if num_batches > 0:
+                 total_loss += epoch_loss / num_batches
+                 total_actor_loss += epoch_actor_loss / num_batches
+                 total_critic_loss += epoch_critic_loss / num_batches
+                 total_entropy_loss += epoch_entropy_loss / num_batches
+
+         # Average losses over all epochs
+         avg_total_loss = total_loss / num_epochs
+         avg_actor_loss = total_actor_loss / num_epochs
+         avg_critic_loss = total_critic_loss / num_epochs
+         avg_entropy_loss = total_entropy_loss / num_epochs
+
+         # Log metrics
+         metrics = {
+             "train/total_loss": avg_total_loss,
+             "train/actor_loss": avg_actor_loss,
+             "train/critic_loss": avg_critic_loss,
+             "train/entropy_loss": avg_entropy_loss,
+             "train/collected_frames": collected_frames,
+         }
+
+         # Log reward if available in the last mini-batch
+         if "reward" in batch.keys():
+             avg_reward = batch["reward"].mean().item()
+             metrics["train/avg_reward"] = avg_reward
+             print(f"Average Reward: {avg_reward:.3f}")
+
+         for key, value in metrics.items():
+             logger.log_scalar(key, value, step=iteration)
+
+         print(f"Total Loss: {avg_total_loss:.4f}")
+         print(f"Actor Loss: {avg_actor_loss:.4f}")
+         print(f"Critic Loss: {avg_critic_loss:.4f}")
+         print(f"Entropy Loss: {avg_entropy_loss:.4f}")
+         print(f"Collected Frames: {collected_frames}")
+
+         # Clear the replay buffer for the next iteration
+         replay_buffer.empty()
+
+     print("\nTraining completed!")
+
+ train_ppo_chess(env)
+ torch.save(value.state_dict(), "chessy_model.pth")
+ torch.save(policy.state_dict(), "chessy_policy.pth")
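+ # Other hyperparameters can be passed straight through, e.g. (illustrative values,
+ # not tuned):
+ #   train_ppo_chess(env, num_iterations=5, frames_per_batch=500, num_epochs=4,
+ #                   lr=1e-4, clip_epsilon=0.1, device=device)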
san_moves.txt ADDED
The diff for this file is too large to render. See raw diff