Rasmus Lellep
add loader
76b1ec5
# Keyword identifiers that scripts pass as ``prompt_format`` to ``prep_prompt``
# to select a prompt template.
PF_RAW = "raw"                # data is already a string; returned unchanged
PF_RAWLINES = "rawlines"      # likewise returned unchanged by prep_prompt
PF_SMUGRI_MT = "smugri_mt"    # SMUGRI machine-translation template
PF_SMUGRI_LID = "smugri_lid"  # SMUGRI language-identification template
PF_ALPACA = "alpaca"          # Alpaca instruction/input/response template
# SMUGRI LID / MT prompt templates, assembled from reserved special tokens.
# Language identification at inference: the source segment followed by the
# marker after which the source language is expected.
SMUGRI_INF_PROMPT_LID = (
    "<|reserved_special_token_12|>"
    "{src_segm}"
    "<|reserved_special_token_13|>"
)
# Task/target-language header shared by MT inference and parallel training.
_SMUGRI_INF_PROMPT_TMPMID = (
    "<|reserved_special_token_14|>"
    "{task} to {tgt_lang}"
    "<|reserved_special_token_15|>"
)
# Source segment plus its language; the common prefix of every training prompt.
_SMUGRI_TRAIN_PROMPT_PREF = SMUGRI_INF_PROMPT_LID + "{src_lang}"
# MT inference: prefix + task header, the target segment is left to the model.
SMUGRI_INF_PROMPT_MT = _SMUGRI_TRAIN_PROMPT_PREF + _SMUGRI_INF_PROMPT_TMPMID
_SMUGRI_TRAIN_PROMPT_MID = _SMUGRI_INF_PROMPT_TMPMID + "{tgt_segm}"
_SMUGRI_TRAIN_PROMPT_SUF = "<|reserved_special_token_16|><|end_of_text|>"
# Parallel (translation) training prompt: prefix, task header + target, suffix.
SMUGRI_PROMPT_TRAIN_PARA = "".join(
    (_SMUGRI_TRAIN_PROMPT_PREF, _SMUGRI_TRAIN_PROMPT_MID, _SMUGRI_TRAIN_PROMPT_SUF)
)
# Monolingual training prompt: prefix directly followed by the suffix.
SMUGRI_PROMPT_TRAIN_MONO = "".join((_SMUGRI_TRAIN_PROMPT_PREF, _SMUGRI_TRAIN_PROMPT_SUF))
# Alpaca instruction-following prompt template. The inference variant ends
# right after the "### Response:" header; the training variant appends the
# reference output.
ALPACA_PROMPT_INF = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n"
    "{instruction}\n\n"
    "### Input:\n"
    "{input}\n\n"
    "### Response:\n"
)
ALPACA_PROMPT_TRAIN = ALPACA_PROMPT_INF + "{output}"
def prep_prompt(data, prompt_format, inference=False):
    """Render one data item as a prompt according to ``prompt_format``.

    Args:
        data: for PF_RAW / PF_RAWLINES, a string that is returned as-is; for
            the SMUGRI formats and PF_ALPACA, a mapping whose keys fill the
            selected template's placeholders.
        prompt_format: one of the PF_* keyword identifiers.
        inference: if True, select the inference template (no target/output
            part); otherwise the training template.

    Returns:
        The rendered prompt string (or ``data`` unchanged for raw formats).

    Raises:
        NotImplementedError: if ``prompt_format`` is not a known identifier.
    """
    # Raw formats: nothing to render, the caller already has the text.
    if prompt_format in {PF_RAW, PF_RAWLINES}:
        return data

    # SMUGRI entries carry src_segm, src_lang, tgt_lang, etc.
    if prompt_format in {PF_SMUGRI_MT, PF_SMUGRI_LID}:
        return _prep_ljmf_entry(data, prompt_format, inference)

    # Alpaca entries carry instruction and input (plus output for training).
    if prompt_format == PF_ALPACA:
        return _prep_alpaca_entry(data, inference)

    raise NotImplementedError(f"Prompt format {prompt_format} is not implemented.")
def _prep_alpaca_entry(entry, inference=False):
    """Fill the Alpaca template with the fields of ``entry``.

    Args:
        entry: mapping providing 'instruction' and 'input', plus 'output'
            when ``inference`` is False. Extra keys are ignored by str.format.
        inference: if True, use the template that stops after the
            "### Response:" header; otherwise the one that appends the output.

    Returns:
        The rendered prompt string.
    """
    template = ALPACA_PROMPT_INF
    if not inference:
        template = ALPACA_PROMPT_TRAIN
    return template.format(**entry)
def _prep_ljmf_entry(entry, fmt, inference=False):
    """Render a SMUGRI LID/MT entry into a prompt string.

    Args:
        entry: mapping of template fields. Inference with PF_SMUGRI_MT needs
            src_segm, src_lang, task and tgt_lang; inference with
            PF_SMUGRI_LID needs src_segm. Training needs src_segm and
            src_lang, plus task, tgt_lang and tgt_segm for the parallel
            template.
        fmt: PF_SMUGRI_MT or PF_SMUGRI_LID. Only consulted when
            ``inference`` is True.
        inference: if True, build an inference prompt that ends where the
            model should continue; otherwise build a full training prompt.

    Returns:
        The rendered prompt string.

    Raises:
        NotImplementedError: in inference mode, if ``fmt`` is not a SMUGRI
            format.
    """
    if inference:
        if fmt == PF_SMUGRI_MT:
            return SMUGRI_INF_PROMPT_MT.format(**entry)
        if fmt == PF_SMUGRI_LID:
            return SMUGRI_INF_PROMPT_LID.format(**entry)
        raise NotImplementedError(f"Prompt format {fmt} is not implemented.")

    # Training: the entry itself decides between the parallel and the
    # monolingual template; fmt is not consulted here.
    # NOTE(review): a translate entry under PF_SMUGRI_LID still gets the
    # parallel template -- confirm that is intended.
    # .get() so entries without task/target keys (e.g. monolingual data)
    # fall back to the mono template instead of raising KeyError, consistent
    # with how empty tgt_segm / tgt_lang values are already treated.
    is_translation = entry.get('task') in {'translate', 'approx-translate'}
    if is_translation and entry.get('tgt_segm') and entry.get('tgt_lang'):
        return SMUGRI_PROMPT_TRAIN_PARA.format(**entry)
    return SMUGRI_PROMPT_TRAIN_MONO.format(**entry)