from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import yaml
import transformers

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
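# Note: allow_origins=["*"] is convenient while developing; in production you
# would typically pin this to the known frontend origin(s), e.g. (illustrative):
#   allow_origins=["https://your-frontend.example.com"]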
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/LuoKslGloss")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/LuoKslGloss")
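# Optional (illustrative): move the model to the GPU once at startup, instead of
# calling model.to(device) inside every translate() call further below.
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = model.to(device)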
# Where should output files be stored locally
drive_folder = "./quadserverlogs"

if not os.path.exists(drive_folder):
    os.makedirs(drive_folder)

# Large batch sizes generally give good results for translation
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size
gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)
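# Worked example of the arithmetic above: with a per-device batch size of 6 and a
# target effective batch size of 480, gradient_accumulation_steps = 480 / 6 = 80,
# i.e. gradients are accumulated over 80 forward/backward passes per optimizer step.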
# Everything in one yaml string, so that it can all be logged.
yaml_config = '''
training_args:
    output_dir: "{drive_folder}"
    eval_strategy: steps
    eval_steps: 200
    save_steps: 200
    gradient_accumulation_steps: {gradient_accumulation_steps}
    learning_rate: 3.0e-4  # Include decimal point to parse as float
    # optim: adafactor
    per_device_train_batch_size: {train_batch_size}
    per_device_eval_batch_size: {eval_batch_size}
    weight_decay: 0.01
    save_total_limit: 3
    max_steps: 500
    predict_with_generate: True
    fp16: True
    logging_dir: "{drive_folder}"
    load_best_model_at_end: True
    metric_for_best_model: loss
    seed: 123
    push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# A 600M-parameter checkpoint (facebook/nllb-200-distilled-600M) is easier to
# train on a free Colab instance, but bigger models work better: results are
# improved by training on nllb-200-1.3B, which is what we use here.
model_checkpoint: facebook/nllb-200-1.3B

datasets:
    train:
        huggingface_load:
            # We load three datasets here: English/KSL Gloss, Dholuo/Swahili,
            # and SALT Swahili/English, so that we can try out multi-way
            # translation.
            - path: EzekielMW/Eksl_dataset
              split: train[:-1000]
            - path: EzekielMW/Luo_Swa
              split: train[:-2000]
            - path: sunbird/salt
              name: text-all
              split: train
        source:
            # This is a text translation only, no audio.
            type: text
            # The source text can be any of English, KSL, Swahili or Dholuo.
            language: [eng,ksl,swa,luo]
            preprocessing:
                # The models are case sensitive, so if the training text is all
                # capitals, then it will only learn to translate capital letters
                # and won't understand lower case. Make everything lower case
                # for now.
                - lower_case
                # We can also augment the spelling of the input text, which
                # makes the model more robust to spelling errors.
                - augment_characters
        target:
            type: text
            # The target text can be any of English, KSL, Swahili or Dholuo.
            language: [eng,ksl,swa,luo]
            # The models are case sensitive: make everything lower case for now.
            preprocessing:
                - lower_case
        shuffle: True
        allow_same_src_and_tgt_language: False
    validation:
        huggingface_load:
            # Use the last 1000 of the KSL examples for validation.
            - path: EzekielMW/Eksl_dataset
              split: train[-1000:]
            # Use the last 2000 of the Luo examples for validation.
            - path: EzekielMW/Luo_Swa
              split: train[-2000:]
            # Add some Swahili validation text.
            - path: sunbird/salt
              name: text-all
              split: dev
        source:
            type: text
            language: [swa,ksl,eng,luo]
            preprocessing:
                - lower_case
        target:
            type: text
            language: [swa,ksl,eng,luo]
            preprocessing:
                - lower_case
        allow_same_src_and_tgt_language: False
'''
yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
config = yaml.safe_load(yaml_config)

training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])
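# Note: training_settings (and the datasets section of the YAML above) are only
# needed when fine-tuning; the inference endpoints below use just the loaded
# model and tokenizer. Illustrative checks of the parsed config, assuming the
# YAML above:
#   config["model_checkpoint"] == "facebook/nllb-200-1.3B"
#   training_settings.gradient_accumulation_steps == 80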
# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse the token from a different language.
LANGUAGE_CODES = ["eng", "swa", "ksl", "luo"]

code_mapping = {
    # Exact/close mapping
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    'luo': 'luo_Latn',
    # Borrowed code for KSL, which NLLB was not trained on
    'ksl': 'ace_Latn',
}
# Reload the tokenizer from the base NLLB checkpoint (this replaces the one
# loaded above) and register the custom language codes.
tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    src_lang='eng_Latn',
    tgt_lang='eng_Latn')

offset = tokenizer.sp_model_size + tokenizer.fairseq_offset

# Point each custom code at the id of the NLLB token it is mapped to.
for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i

# Alias ForcedBOSTokenLogitsProcessor into transformers.generation.utils
# (compatibility shim across transformers versions).
transformers.generation.utils.ForcedBOSTokenLogitsProcessor = transformers.ForcedBOSTokenLogitsProcessor
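# Sanity-check sketch (illustrative, assuming the mapping above): each custom
# code should now resolve to the id of the NLLB token it was mapped to, e.g.
#   tokenizer.convert_tokens_to_ids('ksl') == tokenizer.convert_tokens_to_ids('ace_Latn')
#   tokenizer.convert_tokens_to_ids('luo') == tokenizer.convert_tokens_to_ids('luo_Latn')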
# Define a translation function
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the first token (the source language tag) with our custom code.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    # KSL gloss is conventionally written in capitals.
    if target_language == 'ksl':
        result = result.upper()
    return result
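# Illustrative direct calls (inputs assumed; actual outputs depend on the model):
#   translate("how are you", "eng", "ksl")    # English -> KSL gloss (upper-cased)
#   translate("habari yako", "swa", "luo")    # Swahili -> Dholuo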
# Route path is assumed here; adjust it to match the frontend's API calls.
@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")
    translation = translate(text, source_language, target_language)
    return {"translation": translation}
@app.get("/")
async def root():
    return {"message": "Welcome to the translation API!"}