"""Pad Qwen2-72B-style checkpoint shards by interleaving zero rows/columns.

For every safetensors shard, the first ``pad_size`` rows (or columns) of the
multi-modal projector and MLP weight matrices get a zero row (or column)
inserted after each of them, growing that dimension by ``pad_size``.
Modified shards are written back to their original files.
"""

import torch
from torch.nn import functional as F

pad_size = 128  # Specific to Qwen2-72B architecture
total_shards = 32  # Total number of shards in the model, edit according to the actual files


def interleave_zero_rows(tensor: torch.Tensor, pad: int) -> torch.Tensor:
    """Insert a zero row after each of the first ``pad`` rows of a 2-D tensor.

    Returns a tensor of shape ``(rows + pad, cols)``: the first ``pad`` rows
    alternate real/zero, the remaining rows are passed through unchanged.
    """
    rows = tensor.shape[0]
    # (rows, cols) -> (rows, 2, cols): pair every real row with a zero row,
    # then flatten to (2*rows, cols) so real and zero rows alternate.
    interleaved = F.pad(tensor.unsqueeze(1), (0, 0, 0, 1, 0, 0)).reshape(rows * 2, -1)
    return torch.cat([interleaved[: pad * 2], tensor[pad:]], dim=0)


def interleave_zero_cols(tensor: torch.Tensor, pad: int) -> torch.Tensor:
    """Insert a zero column after each of the first ``pad`` columns of a 2-D tensor.

    Returns a tensor of shape ``(rows, cols + pad)``: the first ``pad`` columns
    alternate real/zero, the remaining columns are passed through unchanged.
    """
    rows, cols = tensor.shape
    # (rows, cols) -> (rows, cols, 2): pair every element with a zero, then
    # flatten the last two dims so real and zero columns alternate.
    interleaved = F.pad(tensor.unsqueeze(2), (0, 1)).reshape(rows, cols * 2)
    return torch.cat([interleaved[:, : pad * 2], tensor[:, pad:]], dim=1)


def pad_shard(state_dict: dict, pad: int) -> bool:
    """Pad matching weight tensors in ``state_dict`` in place.

    Returns True if any tensor was modified.
    """
    modified = False
    for key in list(state_dict.keys()):
        tensor = state_dict[key]
        if ('multi_modal_projector.linear_1.weight' in key
                or 'multi_modal_projector.linear_3.weight' in key):
            # Output-dimension padding for the up/gate-style projections.
            state_dict[key] = interleave_zero_rows(tensor, pad)
            modified = True
        elif 'multi_modal_projector.linear_2.weight' in key:
            # Input-dimension padding for the down-style projection.
            state_dict[key] = interleave_zero_cols(tensor, pad)
            modified = True
        elif 'mlp.fc1.weight' in key:
            # fc1 stacks the gate and up projections along dim 0; pad each
            # half separately so the zero rows land in the right positions.
            gate_proj, up_proj = torch.chunk(tensor, 2, dim=0)
            state_dict[key] = torch.cat(
                [interleave_zero_rows(gate_proj, pad),
                 interleave_zero_rows(up_proj, pad)],
                dim=0,
            )
            modified = True
        elif 'mlp.fc2.weight' in key:
            # Down projection: pad along the input (column) dimension.
            state_dict[key] = interleave_zero_cols(tensor, pad)
            modified = True
    return modified


def main() -> None:
    """Process every shard file in place, reporting what was changed."""
    # Imported here so the tensor helpers above stay usable without
    # safetensors installed.
    from safetensors.torch import load_file, save_file

    for shard_idx in range(1, total_shards + 1):
        # Safetensors shards use zero-padded five-digit indices.
        filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
        state_dict = load_file(filename)
        if pad_shard(state_dict, pad_size):
            save_file(state_dict, filename, metadata={"format": "pt"})
            print(f"Processed and saved {filename}")
        else:
            print(f"No modifications needed for {filename}")


if __name__ == "__main__":
    main()