#!/usr/bin/env python3

from .promptops import *
from .aux import CmdlineArgs, log
from .data import get_data_loader
from .trainllm import env_stuff, load_model, load_tokenizer

import sys
import torch
import json
import torch.distributed as dist

from accelerate import Accelerator
from datetime import datetime
""" | |
This currently assumes the batch size to be 1. With larger batches the padding tokens went | |
into the decoder. Right-padding as a solution? | |
""" | |


def llm_generate(model, tokenizer, tok_batch, debug=False, max_len=2000):
    tok_batch['input_ids'] = tok_batch['input_ids'].to(model.device)
    tok_batch['attention_mask'] = tok_batch['attention_mask'].to(model.device)

    start_time = datetime.now()

    if debug:
        log(f"Tokenized input: {tok_batch['input_ids']}")

    raw_output_toks = model.generate(**tok_batch, tokenizer=tokenizer,
                                     do_sample=False, num_beams=4, max_length=max_len,
                                     top_p=None, temperature=None,
                                     eos_token_id=[tokenizer.eos_token_id,
                                                   tokenizer.convert_tokens_to_ids("<|reserved_special_token_14|>")])
    #clean_output_toks = remove_prompt_from_output(tok_batch['attention_mask'], raw_output_toks, filler_id)

    assert len(raw_output_toks) == 1, "Only batch size=1 supported %-("

    # the generated continuation starts right after the prompt tokens
    gen_idx = len(tok_batch['attention_mask'][0])

    if debug:
        log(f"Full tokenized output: {raw_output_toks[0]}")
        log(f"Full tokens: {tokenizer.convert_ids_to_tokens(raw_output_toks[0])}")
        full_out = tokenizer.batch_decode([raw_output_toks[0]], skip_special_tokens=True)
        log(f"Full text: {full_out[0]}")

    # keep only the newly generated tokens, dropping the prompt
    clean_output_toks = raw_output_toks[0][gen_idx:]
    clean_outputs = tokenizer.batch_decode([clean_output_toks], skip_special_tokens=True)

    if debug:
        log(f"Pruned tokenized output: {clean_output_toks}")
        log(f"Pruned tokens: {tokenizer.convert_ids_to_tokens(clean_output_toks)}")
        log(f"Cleaned output: {clean_outputs[0]}")

    end_time = datetime.now()
    log(f"This took: {end_time - start_time}")

    return clean_outputs


def reassemble_multi(list_of_lists):
    # interleave the per-process output lists round-robin, restoring the order in which
    # the data loader produced the inputs (process p handled batches p, p + num_processes, ...)
    result = []
    for gen_idx in range(len(list_of_lists[0])):
        for i in range(len(list_of_lists)):
            if gen_idx < len(list_of_lists[i]):
                result.append(list_of_lists[i][gen_idx])
    return result
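
# Example with hypothetical values: with two processes, where rank 0 produced
# ["out_0", "out_2"] and rank 1 produced ["out_1"],
#   reassemble_multi([["out_0", "out_2"], ["out_1"]])
# returns ["out_0", "out_1", "out_2"], i.e. the outputs in original dataset order.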


def predict(model, tokenizer, data_loader, accel, multi=False, debug=False, max_len=2000):
    outs_final = []

    with torch.no_grad():
        for idx, batch in enumerate(data_loader):
            # round-robin sharding: each process only generates for its own batch indices
            if idx % accel.num_processes == accel.process_index:
                start_time = datetime.now()
                outputs = llm_generate(model, tokenizer, batch, debug=debug, max_len=max_len)
                end_time = datetime.now()

                log(f"Generated for {idx} in proc {accel.process_index} in {end_time - start_time}")

                outs_final += outputs

    if multi:
        # collect every process's outputs on rank 0 and restore the original order there
        accel.wait_for_everyone()

        rank0_buffer = [None] * accel.num_processes if accel.is_main_process else None
        dist.gather_object(outs_final, rank0_buffer, dst=0)

        if accel.is_main_process:
            outs_final = reassemble_multi(rank0_buffer)
        else:
            outs_final = None

    return outs_final


def _cmdline_args():
    inputs = sys.argv[1:]

    description = """Predict output for an input via prompting"""

    pos_args = ["mdl_id"]

    args = CmdlineArgs(description, pos_args, input_args=inputs,
                       kw_arg_dict={"debug": False,
                                    "input_file": "none",
                                    "output_file": "none",
                                    "multiproc": False,
                                    "max_len": 2000,
                                    "prompt_format": PF_ALPACA})

    # post-process the arguments
    if args.input_file == "none":
        args.input_file = None
    if args.output_file == "none":
        args.output_file = None

    log(f"Launched as {args}")

    return args
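
# Hypothetical invocation (the module path and the key=value argument syntax are
# assumptions here; the actual parsing is done by CmdlineArgs):
#   accelerate launch -m somepackage.llmpredict my-model-id \
#       input_file=prompts.txt output_file=preds.json multiproc=True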


def save_all(outputs, args, acc):
    # only the main process holds the full, reassembled output list
    if acc.is_main_process:
        if args.output_file is None:
            log("Writing to STDOUT")
            out_fh = sys.stdout
        else:
            out_fh = open(args.output_file, "w")

        if args.prompt_format in {PF_RAW, PF_RAWLINES}:
            for line in outputs:
                out_fh.write(line + "\n")
        else:
            json.dump(outputs, out_fh)

        # close the output file (but not STDOUT) so everything is flushed to disk
        if out_fh is not sys.stdout:
            out_fh.close()


def and_i_called_this_function_do_main_too():
    args = _cmdline_args()

    if args.multiproc:
        env_stuff()

    acc = Accelerator()
    device = acc.device
    log(f"Device: {device}.", accelerator=acc)

    if not args.multiproc and not acc.is_main_process:
        log("Not launched in multi-processing mode, exiting non-main process.")
        sys.exit(0)

    tokenizer = load_tokenizer(args.mdl_id, acc)
    data_loader = get_data_loader(args.input_file, args.prompt_format, tokenizer, debug=args.debug)

    model = load_model(args.mdl_id, device, acc, attention="eager")
    model.eval()
    log(f"Device: {model.device}.", accelerator=acc)

    log("Model loaded, starting to generate")
    outputs = predict(model, tokenizer, data_loader, acc, multi=args.multiproc, debug=args.debug, max_len=args.max_len)

    save_all(outputs, args, acc)

    log("Done")


if __name__ == "__main__":
    and_i_called_this_function_do_main_too()