Tomohiro committed on
Commit
fd9e32a
·
verified ·
1 Parent(s): b8e7134

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -13
README.md CHANGED
@@ -32,19 +32,13 @@ tags:
32
  import torch
33
  from transformers import AutoTokenizer, AutoModelForTokenClassification
34
 
35
- # 1) チェックポイントディレクトリを指定
36
- checkpoint_dir = "Tomohiro/MedTXTNER"
37
-
38
- # 2) モデルとトークナイザーをロード
39
- model = AutoModelForTokenClassification.from_pretrained(checkpoint_dir)
40
  tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=True)
41
-
42
- # 3) デバイス設定
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  model.to(device)
45
  model.eval()
46
 
47
- # 4) 推論用
48
  def predict_text(text: str):
49
  enc = tokenizer(
50
  text,
@@ -59,14 +53,10 @@ def predict_text(text: str):
59
  outputs = model(**enc)
60
  logits = outputs.logits
61
 
62
- # 各トークンごとの予測ラベルIDを取得
63
  pred_ids = torch.argmax(logits, dim=-1)[0].cpu().tolist()
64
-
65
- # トークン列と IOB ラベル列に変換
66
  tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
67
  id2label = model.config.id2label
68
 
69
- # special tokens を除いて結果を整形
70
  result = []
71
  for tok, pid in zip(tokens, pred_ids):
72
  if tok in tokenizer.all_special_tokens:
@@ -74,7 +64,6 @@ def predict_text(text: str):
74
  result.append((tok, id2label[pid]))
75
  return result
76
 
77
- # 5) 実際に試す
78
  sample = "症例】53歳女性。発熱と嘔気を認め、プレドニゾロンを中断しました。"
79
  for tok, lab in predict_text(sample):
80
  print(f"{tok}\t{lab}")
 
32
  import torch
33
  from transformers import AutoTokenizer, AutoModelForTokenClassification
34
 
35
+ model_dir = "Tomohiro/MedTXTNER"
36
+ model = AutoModelForTokenClassification.from_pretrained(model_dir)
 
 
 
37
  tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=True)
 
 
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  model.to(device)
40
  model.eval()
41
 
 
42
  def predict_text(text: str):
43
  enc = tokenizer(
44
  text,
 
53
  outputs = model(**enc)
54
  logits = outputs.logits
55
 
 
56
  pred_ids = torch.argmax(logits, dim=-1)[0].cpu().tolist()
 
 
57
  tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
58
  id2label = model.config.id2label
59
 
 
60
  result = []
61
  for tok, pid in zip(tokens, pred_ids):
62
  if tok in tokenizer.all_special_tokens:
 
64
  result.append((tok, id2label[pid]))
65
  return result
66
 
 
67
  sample = "症例】53歳女性。発熱と嘔気を認め、プレドニゾロンを中断しました。"
68
  for tok, lab in predict_text(sample):
69
  print(f"{tok}\t{lab}")