Sergiu2404 commited on
Commit
92b3bd3
·
1 Parent(s): feb2463

refactored inference.py

Browse files
Files changed (1) hide show
  1. inference.py +49 -20
inference.py CHANGED
@@ -1,8 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoTokenizer
3
- from fin_tinybert_pytorch import TinyFinBERTRegressor # You may need to rename or include this class here
4
 
5
- # Load model
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
  model = TinyFinBERTRegressor()
8
  model.load_state_dict(torch.load("./saved_model/pytorch_model.bin", map_location=device))
@@ -11,19 +38,24 @@ model.eval()
11
 
12
  tokenizer = AutoTokenizer.from_pretrained("./saved_model")
13
 
14
- def predict(texts):
15
- if isinstance(texts, str):
16
- texts = [texts]
17
-
18
- results = []
19
- for text in texts:
20
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding='max_length', max_length=128)
21
- inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
22
- with torch.no_grad():
23
- score = model(**inputs)["score"].item()
24
- sentiment = "positive" if score > 0.3 else "negative" if score < -0.3 else "neutral"
25
- results.append({"text": text, "score": score, "sentiment": sentiment})
26
- return results
 
 
 
 
 
27
  #
28
  # if __name__ == "__main__":
29
  # texts = [
@@ -32,8 +64,5 @@ def predict(texts):
32
  # "There was no noticeable change in performance."
33
  # ]
34
  #
35
- # predictions = predict(texts)
36
- # for pred in predictions:
37
- # print(f"Text: {pred['text']}")
38
- # print(f"Score: {pred['score']:.3f}")
39
- # print(f"Sentiment: {pred['sentiment']}\n")
 
1
+ # import torch
2
+ # from transformers import AutoTokenizer
3
+ # from fin_tinybert_pytorch import TinyFinBERTRegressor # You may need to rename or include this class here
4
+ #
5
+ # # Load model
6
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ # model = TinyFinBERTRegressor()
8
+ # model.load_state_dict(torch.load("./saved_model/pytorch_model.bin", map_location=device))
9
+ # model.to(device)
10
+ # model.eval()
11
+ #
12
+ # tokenizer = AutoTokenizer.from_pretrained("./saved_model")
13
+ #
14
+ # def predict(texts):
15
+ # if isinstance(texts, str):
16
+ # texts = [texts]
17
+ #
18
+ # results = []
19
+ # for text in texts:
20
+ # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding='max_length', max_length=128)
21
+ # inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
22
+ # with torch.no_grad():
23
+ # score = model(**inputs)["score"].item()
24
+ # sentiment = "positive" if score > 0.3 else "negative" if score < -0.3 else "neutral"
25
+ # results.append({"text": text, "score": score, "sentiment": sentiment})
26
+ # return results
27
+
28
+
29
  import torch
30
  from transformers import AutoTokenizer
31
+ from fin_tinybert_pytorch import TinyFinBERTRegressor
32
 
 
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
  model = TinyFinBERTRegressor()
35
  model.load_state_dict(torch.load("./saved_model/pytorch_model.bin", map_location=device))
 
38
 
39
  tokenizer = AutoTokenizer.from_pretrained("./saved_model")
40
 
41
+
42
+ def pipeline(text):
43
+ if not isinstance(text, str):
44
+ raise ValueError("Input must be a string")
45
+
46
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding='max_length', max_length=128)
47
+ inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}
48
+
49
+ with torch.no_grad():
50
+ score = model(**inputs)["score"].item()
51
+
52
+ sentiment = "positive" if score > 0.3 else "negative" if score < -0.3 else "neutral"
53
+
54
+ return [{
55
+ "label": sentiment,
56
+ "score": round(score, 4)
57
+ }]
58
+
59
  #
60
  # if __name__ == "__main__":
61
  # texts = [
 
64
  # "There was no noticeable change in performance."
65
  # ]
66
  #
67
+ # predictions = pipeline("The stock price soared after the earnings report.")[0]
68
+ # print(f"sentiment: {predictions['label']}, score: {predictions['score']}")