# ExllamaV2-Control-Vectors / exl2_wrapper.py
# llmixer, 1f3dd72: "Normalize hidden state after adding control vectors to preserve L2 norm"
import glob
import os
import sys

import torch
from gguf.gguf_reader import GGUFReader


class ExLlamaV2ModuleWrapper:
    """Proxy around an ExLlamaV2 module that adds a control vector to its output.

    ``forward`` is redirected to ``wrapped_forward``. The wrapper's own state
    (see ``_OWN_ATTRS``) is resolved on the wrapper. Every other attribute is
    looked up on the wrapped module first, so the proxy behaves like the
    module it wraps.
    """

    # Names always resolved on the wrapper itself. Without this, a wrapped
    # object exposing the same names would shadow the wrapper's state. This
    # happens with an already-wrapped module when ``wrap`` is called twice:
    # the outer wrapper would then reach past the inner wrapper and use the
    # inner control vector.
    _OWN_ATTRS = frozenset({'module', 'control_vector', 'wrapped_forward', '_NORM_EPS'})

    # Floor for the post-addition norm so an all-zero position cannot
    # produce 0/0 -> NaN during rescaling.
    _NORM_EPS = 1e-6

    @classmethod
    def wrap(cls, model, vector_configs):
        """Load control vectors for *model* and wrap its MLP modules in place.

        Args:
            model: object exposing ``config.model_dir``, ``config.hidden_size``,
                ``config.num_hidden_layers`` and a mutable ``modules`` list.
            vector_configs: comma-separated ``vector:direction:weight`` entries.
                Blank entries are ignored. Malformed entries are reported and
                skipped.

        Only modules whose ``name`` is ``'MLP'`` are replaced. The module at
        index 0 and the last two modules are never replaced. Calling ``wrap``
        again wraps the existing wrappers, so the vectors stack.
        """
        vectors = cls._load_vectors(model)
        control_vector = cls._build_control_vector(model, vectors, vector_configs)
        first_unwrappable = len(model.modules) - 2
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= first_unwrappable or module.name != 'MLP':
                continue
            model.modules[idx] = cls(module, control_vector)

    @staticmethod
    def _load_vectors(model):
        """Read ``<model_dir>-vectors/*.gguf`` into ``{vector: {direction: layers}}``.

        The file stem's part after its last ``-`` must be
        ``<vector>__<direction>``. Each ``layers`` tensor is float32 with shape
        ``(num_hidden_layers, hidden_size)``. Rows not present in the file stay
        zero. Exits the process if a file's vector length does not match
        ``hidden_size``.
        """
        config = model.config
        shape = (config.num_hidden_layers, config.hidden_size)
        pattern = os.path.join(os.path.normpath(str(config.model_dir)) + '-vectors', '*.gguf')
        vectors = {}
        # Sorted so that substring lookups in _find_vector are deterministic.
        for path in sorted(glob.glob(pattern)):
            stem = os.path.splitext(os.path.basename(path))[0]
            base = stem.rsplit('-', 1)[-1]
            vector, sep, direction = base.rpartition('__')
            if not (sep and vector and direction):
                print(f' !! Skipping {path}: expected a name ending in "<vector>__<direction>.gguf"')
                continue
            print(f"Loaded control vector: {vector}, Direction: {direction}")
            reader = GGUFReader(path)
            if not reader.tensors:
                print(f' !! Skipping {path}: file contains no tensors')
                continue
            n_elements = reader.tensors[0].n_elements
            if n_elements != config.hidden_size:
                # sys.exit(str) prints the message to stderr and exits with status 1.
                sys.exit(f' ## Control vector n_elements ({n_elements}) '
                         f'!= model.config.hidden_size ({config.hidden_size})')
            layers = torch.zeros(shape, dtype=torch.float32)
            for tensor in reader.tensors:
                # The layer index is the last dot-separated component of the tensor name.
                layer_idx = int(tensor.name.split('.')[-1])
                layers[layer_idx] = torch.from_numpy(tensor.data.copy())
            vectors.setdefault(vector, {})[direction] = layers
        return vectors

    @staticmethod
    def _find_vector(vectors, name):
        """Return ``(key, directions)`` for *name*: exact key first, else first key containing it."""
        if name in vectors:
            return name, vectors[name]
        for key, directions in vectors.items():
            if name in key:
                return key, directions
        return name, None

    @classmethod
    def _build_control_vector(cls, model, vectors, vector_configs):
        """Sum the requested vectors into one ``(num_hidden_layers, hidden_size)`` tensor.

        Each valid entry adds its vector's ``'debias'`` layers plus ``weight``
        times the requested direction's layers. Listing the same vector twice
        therefore adds its debias twice.
        """
        config = model.config
        control_vector = torch.zeros((config.num_hidden_layers, config.hidden_size), dtype=torch.float32)
        for vector_config in vector_configs.split(','):
            vector_config = vector_config.strip()
            if not vector_config:
                continue
            parts = vector_config.split(':')
            if len(parts) != 3:
                print(' !! Error: Expected "vector:direction:weight", got "%s"' % vector_config)
                continue
            vector, direction, weight = (part.strip() for part in parts)
            vector, vector_dirs = cls._find_vector(vectors, vector)
            if vector_dirs is None:
                print(' !! Error: No vector for "%s" (%s)' % (vector, vector_config))
                continue
            debias_layers = vector_dirs.get('debias')
            if debias_layers is None:
                print(' !! Error: No debias for "%s" (%s)' % (vector, vector_config))
                continue
            direction_layers = vector_dirs.get(direction)
            if direction_layers is None:
                print(' !! Error: No "%s" for "%s" (%s)' % (direction, vector, vector_config))
                continue
            try:
                weight = float(weight)
            except ValueError:
                print(' !! Non float weight %s (%s)' % (weight, vector_config))
                weight = 1.0
            print(' -- Applying %s debias and %s * %s' % (vector, direction, weight))
            control_vector += debias_layers
            control_vector += direction_layers * weight
        return control_vector

    def __init__(self, module, control_vector):
        """Wrap *module*; row ``module.layer_idx`` of *control_vector* is applied."""
        self.module = module
        self.control_vector = control_vector

    def __getattribute__(self, name):
        """Redirect ``forward``, keep wrapper state local, delegate everything else.

        A name outside ``_OWN_ATTRS`` is looked up on the wrapped module first.
        The lookup falls back to the wrapper only if that raises ``AttributeError``.
        """
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        if name in type(self)._OWN_ATTRS:
            return object.__getattribute__(self, name)
        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def wrapped_forward(self, *args, **kwargs):
        """Run the wrapped module, then steer its output with this layer's vector.

        The module's output ``x`` is modified in place and returned. First
        ``control_vector[layer_idx]`` is added. Then every position along the
        last axis is rescaled back to its pre-addition L2 norm. If
        ``layer_idx`` has no row in the control vector, ``x`` is returned
        unchanged.
        """
        module = self.module
        x = module.forward(*args, **kwargs)
        try:
            layer_vector = self.control_vector[module.layer_idx]
        except IndexError:
            # No control-vector row for this layer: pass the output through.
            return x
        # Norms are taken per position over the last axis (the axis the
        # hidden_size-long vector broadcasts against) and computed in float32.
        # A single norm over the whole tensor would couple batch rows and
        # tokens to each other, and can overflow in float16, yielding NaN.
        prev_norm = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True, dtype=torch.float32)
        x += layer_vector.to(x.device)
        new_norm = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True, dtype=torch.float32)
        x *= prev_norm / new_norm.clamp(min=self._NORM_EPS)
        return x