m41w4r3.exe
commited on
Commit
·
18f41a5
1
Parent(s):
1abfe53
add cuda if exists
Browse files
load.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from transformers import GPT2LMHeadModel
|
2 |
from transformers import PreTrainedTokenizerFast
|
3 |
import os
|
|
|
4 |
|
5 |
|
6 |
class LoadModel:
|
@@ -31,6 +32,8 @@ class LoadModel:
|
|
31 |
self.path = path
|
32 |
self.device = device
|
33 |
self.revision = revision
|
|
|
|
|
34 |
|
35 |
def load_model_and_tokenizer(self):
|
36 |
model = self.load_model()
|
@@ -40,11 +43,11 @@ class LoadModel:
|
|
40 |
|
41 |
def load_model(self):
|
42 |
if self.revision is None:
|
43 |
-
model = GPT2LMHeadModel.from_pretrained(self.path
|
44 |
else:
|
45 |
model = GPT2LMHeadModel.from_pretrained(
|
46 |
-
self.path, revision=self.revision
|
47 |
-
)
|
48 |
|
49 |
return model
|
50 |
|
|
|
1 |
from transformers import GPT2LMHeadModel
|
2 |
from transformers import PreTrainedTokenizerFast
|
3 |
import os
|
4 |
+
import torch
|
5 |
|
6 |
|
7 |
class LoadModel:
|
|
|
32 |
self.path = path
|
33 |
self.device = device
|
34 |
self.revision = revision
|
35 |
+
if torch.cuda.is_available():
|
36 |
+
self.device = "cuda"
|
37 |
|
38 |
def load_model_and_tokenizer(self):
|
39 |
model = self.load_model()
|
|
|
43 |
|
44 |
def load_model(self):
|
45 |
if self.revision is None:
|
46 |
+
model = GPT2LMHeadModel.from_pretrained(self.path).to(self.device)
|
47 |
else:
|
48 |
model = GPT2LMHeadModel.from_pretrained(
|
49 |
+
self.path, revision=self.revision
|
50 |
+
).to(self.device)
|
51 |
|
52 |
return model
|
53 |
|