Spaces:
Sleeping
Sleeping
File size: 2,123 Bytes
056d415 56affed 770af88 0922cef 0f78bcc 056d415 b2b1a15 d76625a 1c3803a 770af88 1c3803a 0922cef b2b1a15 0922cef d76625a 99c8e88 56affed b2b1a15 56affed b2b1a15 1c3803a 056d415 56affed 056d415 6069b6c 056d415 56affed b2b1a15 56affed b2b1a15 56affed 056d415 9e3acf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import json
import argparse
import asyncio
from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import ModelFactory
from tasks.text import evaluate_text
from tasks.utils.evaluation import TextEvaluationRequest
def load_config(config_path):
    """Load a JSON configuration file and return its decoded contents.

    Args:
        config_path: Path (str or path-like) of the JSON file to read.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which is not UTF-8 on every system.
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
async def train_model(config):
    """Create the configured model and train it if no model is loaded yet.

    Builds a ``TextEvaluationRequest`` and a ``TextDataLoader`` (with
    ``light=False``), creates the model via ``ModelFactory.create_model``,
    fetches the training split from the loader and, when ``model.model`` is
    ``None``, trains the model on it and saves it.

    Args:
        config: Parsed configuration passed to ``ModelFactory.create_model``.
    """
    # loading data
    text_request = TextEvaluationRequest()
    # NOTE(review): presumably light=False selects the full rather than a
    # reduced dataset — inferred from the name, confirm in TextDataLoader.
    is_light_dataset = False
    data_loader = TextDataLoader(text_request, light=is_light_dataset)
    # define model
    model = ModelFactory.create_model(config)
    # train model
    train_dataset = data_loader.get_train_dataset()
    # Only train when the factory did not already populate `model.model`
    # (presumably because previously saved weights were loaded — verify).
    if model.model is None:
        model.train(train_dataset)
        # NOTE(review): confirm whether save() should also run when
        # training is skipped.
        model.save()
    # NOTE(review): this message is printed even when training was skipped.
    print("Model training completed and saved.")
async def evaluate_model(config):
    """Evaluate the configured model and print its metrics.

    Creates a fresh ``TextEvaluationRequest``, builds the model from
    *config* via ``ModelFactory.create_model``, awaits ``evaluate_text`` on
    both, then prints the whole result mapping as indented JSON followed by
    its ``accuracy`` and ``energy_consumed_wh`` entries.

    Args:
        config: Parsed configuration passed to ``ModelFactory.create_model``.
    """
    request = TextEvaluationRequest()
    # The loader is not used below; it is still constructed in case its
    # constructor has side effects. NOTE(review): confirm it is needed.
    _loader = TextDataLoader(request)

    evaluated_model = ModelFactory.create_model(config)
    metrics = await evaluate_text(request=request, model=evaluated_model)

    # Full report first, then the two headline figures.
    print(json.dumps(metrics, indent=2))
    accuracy = metrics["accuracy"]
    print(f"Achieved accuracy: {accuracy}")
    energy_wh = metrics["energy_consumed_wh"]
    print(f"Energy consumed: {energy_wh} Wh")
async def main():
    """Command-line entry point: train or evaluate according to a config file.

    Parses ``--config`` (default ``config.json``), loads that file with
    ``load_config`` and dispatches on its ``"mode"`` entry: ``"train"``
    awaits ``train_model`` and ``"evaluate"`` awaits ``evaluate_model``.

    Raises:
        ValueError: If the configuration has no ``"mode"`` key, or if the
            mode is neither ``"train"`` nor ``"evaluate"``.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument(
        "--config",
        type=str,
        default="config.json",
        help="Path to the configuration file",
    )
    args = parser.parse_args()

    # Load configuration
    config_path = args.config
    config = load_config(config_path)

    # A missing dict key raises KeyError (not ValueError); translate it into
    # a ValueError that names the offending configuration file.
    try:
        mode = config["mode"]
    except KeyError:
        raise ValueError(f"Missing mode in configuration file: {config_path}") from None

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")
if __name__ == "__main__":
    # Script entry point: build the main() coroutine and drive it to
    # completion on a new event loop.
    main_coroutine = main()
    asyncio.run(main_coroutine)
|