Spaces:
Sleeping
Sleeping
| import json | |
| import argparse | |
| import asyncio | |
| from tasks.data.data_loaders import TextDataLoader | |
| from tasks.models.text_classifiers import ModelFactory | |
| from tasks.text import evaluate_text | |
| from tasks.utils.evaluation import TextEvaluationRequest | |
def load_config(config_path):
    """Read and parse a JSON configuration file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The decoded JSON document. ``main`` indexes it with ``["mode"]``,
        so a JSON object (dict) is expected here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying on
    # the platform's locale default, which can mis-decode non-ASCII values.
    with open(config_path, encoding="utf-8") as config_file:
        return json.load(config_file)
async def train_model(config):
    """Create the model described by *config* and train it if it has no model yet.

    The full (non-light) training split is always loaded. Training only runs
    when ``model.model`` is ``None`` after construction.

    Args:
        config: Parsed configuration, passed unchanged to
            ``ModelFactory.create_model``.
    """
    request = TextEvaluationRequest()
    # light=False selects the full dataset rather than the light variant.
    loader = TextDataLoader(request, light=False)

    model = ModelFactory.create_model(config)

    dataset = loader.get_train_dataset()
    # NOTE(review): the original indentation was not recoverable; this assumes
    # save() runs only together with train() inside the guard — confirm.
    needs_training = model.model is None
    if needs_training:
        model.train(dataset)
        model.save()
    print("Model training completed and saved.")
async def evaluate_model(config):
    """Evaluate the model described by *config* and print the metrics.

    The full result mapping from ``evaluate_text`` is printed as indented JSON,
    then the ``accuracy`` and ``energy_consumed_wh`` entries on their own lines.

    Args:
        config: Parsed configuration, passed unchanged to
            ``ModelFactory.create_model``.
    """
    request = TextEvaluationRequest()
    # NOTE(review): this loader is never referenced afterwards; it is kept in
    # case its constructor has needed side effects — confirm it is required.
    _loader = TextDataLoader(request)

    model = ModelFactory.create_model(config)

    metrics = await evaluate_text(request=request, model=model)

    print(json.dumps(metrics, indent=2))
    print(f"Achieved accuracy: {metrics['accuracy']}")
    print(f"Energy consumed: {metrics['energy_consumed_wh']} Wh")
async def main():
    """Parse CLI arguments, load the config file and dispatch on its ``mode``.

    ``--config`` (default ``config.json``) names a JSON file that must hold a
    ``"mode"`` key whose value is ``"train"`` or ``"evaluate"``.

    Raises:
        ValueError: If the configuration has no ``"mode"`` key, or its value
            is not one of the supported modes.
    """
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    args = parser.parse_args()

    config_path = args.config
    config = load_config(config_path)

    # Indexing a dict with a missing key raises KeyError (not ValueError), so
    # catch that and turn it into the descriptive ValueError documented above.
    try:
        mode = config["mode"]
    except KeyError:
        raise ValueError(f"Missing mode in configuration file: {config_path}") from None

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")


if __name__ == "__main__":
    asyncio.run(main())