prd101-wd commited on
Commit
6a947f8
·
verified ·
1 Parent(s): 58aa69f

Update src/model_consumer.py

Browse files
Files changed (1) hide show
  1. src/model_consumer.py +25 -3
src/model_consumer.py CHANGED
@@ -7,7 +7,8 @@ model_id = "prd101-wd/phi1_5-bankingqa-merged"
7
  # Load model only once
8
  @st.cache_resource
9
  def load_model():
10
- return pipeline("question-answering", model=model_id)
 
11
 
12
  # Create a text generation pipeline
13
  pipe = load_model()
@@ -23,10 +24,31 @@ if st.button("Ask"):
23
  if user_input.strip():
24
  # Format the prompt like Alpaca-style
25
  prompt = f"### Instruction:\n{user_input}\n\n### Response:\n"
26
- output = pipe(prompt, max_new_tokens=200, do_sample=True)[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Extract only the model's response (remove prompt part if included in output)
29
- answer = output.split("### Response:")[-1].strip()
 
 
 
 
 
30
  st.markdown("### HelpdeskBot Answer:")
31
  st.success(answer)
32
  else:
 
7
  # Load model only once
8
  @st.cache_resource
9
  def load_model():
10
+ #return pipeline("question-answering", model=model_id)
11
+ return pipeline("text-generation", model=model_id, trust_remote_code=True)
12
 
13
  # Create a text generation pipeline
14
  pipe = load_model()
 
24
  if user_input.strip():
25
  # Format the prompt like Alpaca-style
26
  prompt = f"### Instruction:\n{user_input}\n\n### Response:\n"
27
+ output = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
28
+
29
+ # Process output
30
+ if isinstance(output, list) and output:
31
+ answer = output[0]['generated_text']
32
+ # Extract only the response part
33
+ if "### Response:" in answer:
34
+ answer = answer.split("### Response:")[-1].strip()
35
+ else:
36
+ answer = "Unable to generate a response. Please try again."
37
+
38
+
39
+
40
+ # if isinstance(output, list) and len(output) > 0 and "generated_text" in output[0]:
41
+ # answer = output[0]["generated_text"]
42
+ # else:
43
+ # answer = "Unable to generate a response. Please try again."
44
 
45
  # Extract only the model's response (remove prompt part if included in output)
46
+ #answer = output.split("### Response:")[-1].strip()
47
+ # if isinstance(output, str):
48
+ # answer = output.split("### Response:")[-1].strip()
49
+ # else:
50
+ # answer = "Unexpected output format. Please try again."
51
+
52
  st.markdown("### HelpdeskBot Answer:")
53
  st.success(answer)
54
  else: