Update src/model_consumer.py
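The model behind this Space, prd101-wd/phi1_5-bankingqa-merged, is a causal language model, so the extractive question-answering pipeline (which expects separate question and context inputs rather than a single prompt string) is swapped for a text-generation pipeline loaded with trust_remote_code=True. The generation call gains temperature=0.7, and the bare string split on the raw output is replaced by type-checked parsing of the pipeline's list-of-dicts return value, with a fallback message when the output shape is unexpected; earlier parsing attempts are left commented out.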
src/model_consumer.py  CHANGED  (+25 -3)
@@ -7,7 +7,8 @@ model_id = "prd101-wd/phi1_5-bankingqa-merged"
 # Load model only once
 @st.cache_resource
 def load_model():
-    return pipeline("question-answering", model=model_id)
+    #return pipeline("question-answering", model=model_id)
+    return pipeline("text-generation", model=model_id, trust_remote_code=True)
 
 # Create a text generation pipeline
 pipe = load_model()
@@ -23,10 +24,31 @@ if st.button("Ask"):
     if user_input.strip():
         # Format the prompt like Alpaca-style
         prompt = f"### Instruction:\n{user_input}\n\n### Response:\n"
-        output = pipe(prompt, max_new_tokens=200, do_sample=True)
+        output = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
+
+        # Process output
+        if isinstance(output, list) and output:
+            answer = output[0]['generated_text']
+            # Extract only the response part
+            if "### Response:" in answer:
+                answer = answer.split("### Response:")[-1].strip()
+        else:
+            answer = "Unable to generate a response. Please try again."
+
+
+
+        # if isinstance(output, list) and len(output) > 0 and "generated_text" in output[0]:
+        #     answer = output[0]["generated_text"]
+        # else:
+        #     answer = "Unable to generate a response. Please try again."
 
         # Extract only the model's response (remove prompt part if included in output)
-        answer = output.split("### Response:")[-1].strip()
+        #answer = output.split("### Response:")[-1].strip()
+        # if isinstance(output, str):
+        #     answer = output.split("### Response:")[-1].strip()
+        # else:
+        #     answer = "Unexpected output format. Please try again."
+
         st.markdown("### HelpdeskBot Answer:")
         st.success(answer)
     else:
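For reference, here is a minimal standalone sketch of the shape the updated code relies on, assuming the standard transformers text-generation return format (a list of dicts whose "generated_text" field echoes the prompt followed by the continuation); the sample question is hypothetical:

from transformers import pipeline

# Load the merged phi-1.5 banking-QA model as a causal text generator.
pipe = pipeline(
    "text-generation",
    model="prd101-wd/phi1_5-bankingqa-merged",
    trust_remote_code=True,
)

# Hypothetical user question, wrapped in the same Alpaca-style template as the app.
prompt = "### Instruction:\nHow do I block a lost debit card?\n\n### Response:\n"
output = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7)

# Expected shape: [{"generated_text": "<prompt + continuation>"}], which is why
# the app indexes output[0]["generated_text"] and splits on "### Response:" to
# drop the echoed prompt.
if isinstance(output, list) and output:
    answer = output[0]["generated_text"]
    if "### Response:" in answer:
        answer = answer.split("### Response:")[-1].strip()
else:
    answer = "Unable to generate a response. Please try again."

print(answer)

Note that the text-generation pipeline also accepts return_full_text=False, which returns only the continuation and would make the "### Response:" split unnecessary.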