File size: 2,123 Bytes
056d415
56affed
 
770af88
0922cef
0f78bcc
056d415
 
 
b2b1a15
 
 
 
 
 
d76625a
1c3803a
770af88
 
1c3803a
0922cef
b2b1a15
0922cef
 
d76625a
99c8e88
 
 
56affed
 
b2b1a15
56affed
 
 
 
 
b2b1a15
1c3803a
056d415
56affed
 
056d415
 
6069b6c
 
056d415
56affed
 
 
b2b1a15
56affed
 
b2b1a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56affed
056d415
9e3acf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import json
import argparse
import asyncio

from tasks.data.data_loaders import TextDataLoader
from tasks.models.text_classifiers import ModelFactory
from tasks.text import evaluate_text
from tasks.utils.evaluation import TextEvaluationRequest

def load_config(config_path):
    """Load and return the JSON configuration stored at *config_path*.

    Args:
        config_path: Path (str or path-like) of the JSON configuration file.

    Returns:
        The deserialized JSON document.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8 by specification (RFC 8259); decode explicitly rather
    # than relying on the platform's locale encoding (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)

async def train_model(config):
    """Build the model described by *config* and train it if it is not yet trained.

    A ``TextDataLoader`` is created with ``light=False``. When the object
    returned by ``ModelFactory.create_model`` already has a non-None
    ``model`` attribute, training and saving are skipped. Otherwise the
    training split is fetched, the model is trained on it and then saved.

    Args:
        config: Parsed configuration passed to ``ModelFactory.create_model``.
    """
    # loading data
    text_request = TextEvaluationRequest()
    is_light_dataset = False
    data_loader = TextDataLoader(text_request, light=is_light_dataset)

    # define model
    model = ModelFactory.create_model(config)

    # NOTE(review): presumably ``model.model`` is None only when the factory
    # did not load pre-trained weights — confirm against ModelFactory.
    if model.model is not None:
        # Report accurately instead of claiming a training run that never happened.
        print("Model already available; skipping training.")
        return

    # Fetch the training split only when we are actually going to train on it.
    train_dataset = data_loader.get_train_dataset()
    model.train(train_dataset)
    model.save()
    print("Model training completed and saved.")

async def evaluate_model(config):
    """Evaluate the model built from *config* and print its results.

    The full result mapping returned by ``evaluate_text`` is printed as
    indented JSON, followed by its ``accuracy`` and ``energy_consumed_wh``
    entries.

    Args:
        config: Parsed configuration passed to ``ModelFactory.create_model``.
    """
    request = TextEvaluationRequest()
    # NOTE(review): this loader is never read below; the construction is kept
    # in case TextDataLoader's constructor has side effects — confirm and drop
    # it if it does not.
    _unused_loader = TextDataLoader(request)
    classifier = ModelFactory.create_model(config)

    results = await evaluate_text(request=request, model=classifier)

    # Full report first, then the two headline metrics.
    print(json.dumps(results, indent=2))
    print(f"Achieved accuracy: {results['accuracy']}")
    print(f"Energy consumed: {results['energy_consumed_wh']} Wh")

async def main():
    """Parse ``--config``, load that file, and dispatch on its ``"mode"`` entry.

    ``"train"`` runs ``train_model(config)`` and ``"evaluate"`` runs
    ``evaluate_model(config)``.

    Raises:
        ValueError: If the configuration has no ``"mode"`` key, or its value
            is neither ``"train"`` nor ``"evaluate"``.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Train or evaluate the model.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to the configuration file")
    args = parser.parse_args()

    # Load configuration
    config_path = args.config
    config = load_config(config_path)

    try:
        mode = config["mode"]
    except KeyError:
        # A missing dict key raises KeyError (the old ``except ValueError`` never
        # matched); translate it into the readable ValueError documented above.
        raise ValueError(f"Missing mode in configuration file: {config_path}") from None

    if mode == "train":
        await train_model(config)
    elif mode == "evaluate":
        await evaluate_model(config)
    else:
        raise ValueError(f"Invalid mode in file '{config_path}': '{mode}'")

if __name__ == "__main__":
    # Script entry point: drive the coroutine-based main() to completion
    # on a fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)