JenniferHJF commited on
Commit
9d12d23
·
verified ·
1 Parent(s): cfd7476

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +8 -15
agent.py CHANGED
@@ -1,25 +1,24 @@
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
 
4
- # ✅ Step 1: 加载 emoji 翻译模型(你微调后的模型)
5
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
6
  emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
7
  emoji_model = AutoModelForCausalLM.from_pretrained(
8
  emoji_model_id,
9
- device_map="auto",
10
- torch_dtype=torch.float16,
11
- trust_remote_code=True
12
- )
13
  emoji_model.eval()
14
 
15
- # ✅ Step 2: 加载冒犯文本分类器(你可更换为更强大的模型)
16
  classifier = pipeline("text-classification", model="unitary/toxic-bert", device=0 if torch.cuda.is_available() else -1)
17
 
18
  def classify_emoji_text(text: str):
19
  """
20
- 输入文本 -> 翻译 emoji -> 分类是否冒犯
 
21
  """
22
- # ✅ 构造翻译 prompt
23
  prompt = f"""请判断下面的文本是否具有冒犯性。
24
  这里的“冒犯性”主要指包含人身攻击、侮辱、歧视、仇恨言论或极端粗俗的内容。
25
  如果文本具有冒犯性,请仅回复冒犯;如果不具有冒犯性,请仅回复不冒犯。
@@ -27,18 +26,12 @@ def classify_emoji_text(text: str):
27
  {text}
28
  """
29
 
30
- # ✅ 生成翻译结果
31
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
32
  with torch.no_grad():
33
- output_ids = emoji_model.generate(
34
- **input_ids,
35
- max_new_tokens=50,
36
- do_sample=False
37
- )
38
  decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
  translated_text = decoded.strip().split("文本如下:")[-1].strip()
40
 
41
- # ✅ 送入第二阶段冒犯性识别
42
  result = classifier(translated_text)[0]
43
  label = result["label"]
44
  score = result["score"]
 
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
 
4
+ # ✅ Step 1: Emoji 翻译模型(你自己训练的模型)
5
  emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
6
  emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
7
  emoji_model = AutoModelForCausalLM.from_pretrained(
8
  emoji_model_id,
9
+ trust_remote_code=True,
10
+ torch_dtype=torch.float16
11
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
 
12
  emoji_model.eval()
13
 
14
+ # ✅ Step 2: 冒犯性文本识别模型
15
  classifier = pipeline("text-classification", model="unitary/toxic-bert", device=0 if torch.cuda.is_available() else -1)
16
 
17
  def classify_emoji_text(text: str):
18
  """
19
+ Step 1: 翻译文本中的 emoji
20
+ Step 2: 使用分类器判断是否冒犯
21
  """
 
22
  prompt = f"""请判断下面的文本是否具有冒犯性。
23
  这里的“冒犯性”主要指包含人身攻击、侮辱、歧视、仇恨言论或极端粗俗的内容。
24
  如果文本具有冒犯性,请仅回复冒犯;如果不具有冒犯性,请仅回复不冒犯。
 
26
  {text}
27
  """
28
 
 
29
  input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
30
  with torch.no_grad():
31
+ output_ids = emoji_model.generate(**input_ids, max_new_tokens=50, do_sample=False)
 
 
 
 
32
  decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
  translated_text = decoded.strip().split("文本如下:")[-1].strip()
34
 
 
35
  result = classifier(translated_text)[0]
36
  label = result["label"]
37
  score = result["score"]