Spaces:
Sleeping
Sleeping
File size: 2,689 Bytes
76b1ec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# Keyword identifiers that scripts pass around to select a prompt template.
PF_RAW = "raw"
PF_RAWLINES = "rawlines"
PF_SMUGRI_MT = "smugri_mt"
PF_SMUGRI_LID = "smugri_lid"
PF_ALPACA = "alpaca"

# SMUGRI LID / MT templates. Every template is a str.format() pattern; the
# placeholders used are {src_segm}, {src_lang}, {task}, {tgt_lang}, {tgt_segm}.

# Inference, LID: the source segment enclosed in reserved tokens 12 and 13.
SMUGRI_INF_PROMPT_LID = (
    "<|reserved_special_token_12|>"
    "{src_segm}"
    "<|reserved_special_token_13|>"
)

# Shared middle piece: "<task> to <target language>" between tokens 14 and 15.
_SMUGRI_INF_PROMPT_TMPMID = (
    "<|reserved_special_token_14|>"
    "{task} to {tgt_lang}"
    "<|reserved_special_token_15|>"
)

# Building blocks of the training templates.
_SMUGRI_TRAIN_PROMPT_PREF = "".join((SMUGRI_INF_PROMPT_LID, "{src_lang}"))
_SMUGRI_TRAIN_PROMPT_MID = "".join((_SMUGRI_INF_PROMPT_TMPMID, "{tgt_segm}"))
_SMUGRI_TRAIN_PROMPT_SUF = "<|reserved_special_token_16|>" "<|end_of_text|>"

# Inference, MT: LID part + source language + task/target-language part.
SMUGRI_INF_PROMPT_MT = "".join((
    SMUGRI_INF_PROMPT_LID,
    "{src_lang}",
    _SMUGRI_INF_PROMPT_TMPMID,
))

# Training: the "para" variant carries the task and target segment,
# the "mono" variant stops right after the source language.
SMUGRI_PROMPT_TRAIN_PARA = "".join((
    _SMUGRI_TRAIN_PROMPT_PREF,
    _SMUGRI_TRAIN_PROMPT_MID,
    _SMUGRI_TRAIN_PROMPT_SUF,
))
SMUGRI_PROMPT_TRAIN_MONO = "".join((
    _SMUGRI_TRAIN_PROMPT_PREF,
    _SMUGRI_TRAIN_PROMPT_SUF,
))

# Alpaca instruction template; the training variant appends the expected output.
ALPACA_PROMPT_INF = (
    "Below is an instruction that describes a task, "
    "paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n"
)
ALPACA_PROMPT_TRAIN = ALPACA_PROMPT_INF + "{output}"
def prep_prompt(data, prompt_format, inference=False):
    """Turn ``data`` into a prompt string according to ``prompt_format``.

    Args:
        data: For the raw formats, the value that is handed back unchanged.
            For the SMUGRI and Alpaca formats, a mapping whose items are
            unpacked into the matching template's ``str.format`` call.
        prompt_format: One of the ``PF_*`` keyword identifiers.
        inference: Forwarded to the format-specific helper, which uses it to
            choose between inference and training templates.

    Returns:
        ``data`` itself for the raw formats, otherwise the formatted prompt.

    Raises:
        NotImplementedError: If ``prompt_format`` is not a known identifier.
    """
    passthrough_formats = {PF_RAW, PF_RAWLINES}
    smugri_formats = {PF_SMUGRI_MT, PF_SMUGRI_LID}

    # Raw formats: nothing to render, the data is returned as given.
    if prompt_format in passthrough_formats:
        return data

    # SMUGRI formats: entry carries src_segm, src_lang, tgt_lang, etc.
    if prompt_format in smugri_formats:
        return _prep_ljmf_entry(data, prompt_format, inference)

    # Alpaca format: entry carries the instruction and its input.
    if prompt_format == PF_ALPACA:
        return _prep_alpaca_entry(data, inference)

    raise NotImplementedError(f"Prompt format {prompt_format} is not implemented.")
def _prep_alpaca_entry(entry, inference=False):
    """Fill the Alpaca template with the fields of ``entry``.

    Args:
        entry: Mapping unpacked as keyword arguments into ``str.format``.
            The inference template references ``instruction`` and ``input``;
            the training template additionally references ``output``.
        inference: When true use ``ALPACA_PROMPT_INF``, otherwise
            ``ALPACA_PROMPT_TRAIN``.

    Returns:
        The formatted prompt string.
    """
    if inference:
        template = ALPACA_PROMPT_INF
    else:
        template = ALPACA_PROMPT_TRAIN
    return template.format(**entry)
def _prep_ljmf_entry(entry, fmt, inference=False):
    """Fill one of the SMUGRI templates with the fields of ``entry``.

    Args:
        entry: Mapping unpacked as keyword arguments into ``str.format``.
            In training mode the keys ``task``, ``tgt_segm`` and ``tgt_lang``
            are read directly to pick the template.
        fmt: ``PF_SMUGRI_MT`` or ``PF_SMUGRI_LID``. Only consulted when
            ``inference`` is true; training mode ignores it.
        inference: Selects the inference templates when true, the training
            templates otherwise.

    Returns:
        The formatted prompt string.

    Raises:
        NotImplementedError: In inference mode, if ``fmt`` is neither of the
            two SMUGRI identifiers.
    """
    if not inference:
        # Training: the template is chosen from the entry itself, not from
        # fmt. The parallel template needs a translation-type task and
        # non-empty target segment and target language; anything else is
        # rendered with the monolingual template.
        has_target = (
            entry['task'] in {'translate', 'approx-translate'}
            and entry['tgt_segm']
            and entry['tgt_lang']
        )
        template = SMUGRI_PROMPT_TRAIN_PARA if has_target else SMUGRI_PROMPT_TRAIN_MONO
        return template.format(**entry)

    # Inference: the template is chosen by the requested format.
    if fmt == PF_SMUGRI_MT:
        return SMUGRI_INF_PROMPT_MT.format(**entry)
    if fmt == PF_SMUGRI_LID:
        return SMUGRI_INF_PROMPT_LID.format(**entry)
    raise NotImplementedError(f"Prompt format {fmt} is not implemented.")
|