Upload folder using huggingface_hub
- app.py +79 -53
- app_v2.py +61 -0
- data/index/exam_db/chroma.sqlite3 +2 -2
- data/index/law_db/chroma.sqlite3 +2 -2
- generator/llm_inference.py +29 -30
- generator/prompt_builder.py +13 -13
- generator/prompt_builder_v1.py +19 -0
- requirements.txt +6 -7
- retriever/vectordb_rerank.py +43 -24
- retriever/vectordb_rerank_exam.py +55 -0
- retriever/vectordb_rerank_law.py +68 -0
- services/rag_pipeline_v2.py +33 -0
app.py
CHANGED
@@ -1,53 +1,79 @@
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# from retriever.vectordb_rerank import search_documents  # RAG retriever
from services.rag_pipeline import rag_pipeline

model_name = "dasomaru/gemma-3-4bit-it-demo"


# 1. Load the model/tokenizer once
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the model on CPU first (no GPU yet)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # the model is 4-bit
    device_map="auto",          # important: assign to GPU automatically
    trust_remote_code=True,
)

# 2. Cache management
search_cache = {}

@spaces.GPU(duration=300)
def generate_response(query: str):
    tokenizer = AutoTokenizer.from_pretrained(
        "dasomaru/gemma-3-4bit-it-demo",
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "dasomaru/gemma-3-4bit-it-demo",
        torch_dtype=torch.float16,  # the model is 4-bit
        device_map="auto",          # important: assign to GPU automatically
        trust_remote_code=True,
    )
    model.to("cuda")

    if query in search_cache:
        print(f"Cache hit: '{query}'")
        return search_cache[query]

    # Call rag_pipeline for retrieval + generation
    # Retrieval
    top_k = 5
    results = rag_pipeline(query, top_k=top_k)

    # Join the results if a list was returned
    if isinstance(results, list):
        results = "\n\n".join(results)

    search_cache[query] = results
    # return results

    inputs = tokenizer(results, return_tensors="pt").to(model.device)  # model.device
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        do_sample=True,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# 3. Gradio interface
demo = gr.Interface(
    fn=generate_response,
    # inputs=gr.Textbox(lines=2, placeholder="Enter your question"),
    inputs="text",
    outputs="text",
    title="Law RAG Assistant",
    description="Statute-based RAG pipeline test",
)

# demo.launch(server_name="0.0.0.0", server_port=7860)  # ready to deploy as an API
demo.launch()
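As a quick sanity check of the handler above, it can be invoked once outside Gradio; a minimal sketch, assuming the Space's GPU context is available and using an illustrative query string:

# Illustrative smoke test for generate_response (the query text is made up).
if __name__ == "__main__":
    print(generate_response("Write one multiple-choice question about a broker's duty to verify and explain."))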
app_v2.py
ADDED
@@ -0,0 +1,61 @@
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from retriever.vectordb import search_documents  # RAG retriever

model_name = "dasomaru/gemma-3-4bit-it-demo"


# The tokenizer can be preloaded on CPU
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the model on CPU first (no GPU yet)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # the model is 4-bit
    trust_remote_code=True,
)

@spaces.GPU(duration=300)
def generate_response(query):
    # Reloaded on every call inside generate_response
    # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_name,
    #     torch_dtype=torch.float16,
    #     device_map="auto",  # important: assign to GPU automatically
    #     trust_remote_code=True,
    # )
    tokenizer = AutoTokenizer.from_pretrained("dasomaru/gemma-3-4bit-it-demo")
    model = AutoModelForCausalLM.from_pretrained("dasomaru/gemma-3-4bit-it-demo")
    model.to("cuda")

    # 1. Retrieval
    top_k = 5
    retrieved_docs = search_documents(query, top_k=top_k)

    # 2. Build the prompt
    prompt = (
        "You are an expert at writing questions for the licensed real estate agent exam.\n\n"
        "Below are past exam questions and related statutes:\n"
    )
    for idx, doc in enumerate(retrieved_docs, 1):
        prompt += f"- {doc}\n"
    prompt += f"\nUsing this information, please answer the user's request.\n\n"
    prompt += f"[Question]\n{query}\n\n[Answer]\n"

    # 3. Generate the answer
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # model.device
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        do_sample=True,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

demo = gr.Interface(fn=generate_response, inputs="text", outputs="text")
demo.launch()
data/index/exam_db/chroma.sqlite3
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ab1270442e19db5a1c0ec0217101b32e3d5ce379d9cf0a4278f7b4edac2489fb
+size 14610432
data/index/law_db/chroma.sqlite3
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:6dbbf1eed4fb2a85649ef2d22fdce84b1c10a268a59279dbb4a9e0d8141e1e55
+size 38465536
generator/llm_inference.py
CHANGED
@@ -1,30 +1,29 @@
from transformers import pipeline
import spaces

# 1. Load the model (loaded only once)
generator = pipeline(
    "text-generation",
    model="dasomaru/gemma-3-4bit-it-demo",  # the uploaded model
    tokenizer="dasomaru/gemma-3-4bit-it-demo",
    device=0,  # use CUDA:0 (GPU); set device=-1 for CPU only
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1
)

# 2. Answer generation function
@spaces.GPU(duration=300)
def generate_answer(prompt: str) -> str:
    """
    Generate an answer from the given prompt with the model.
    """
    print(f"Prompt length: {len(prompt)} characters")
    outputs = generator(
        prompt,
        do_sample=True,
        top_k=50,
        num_return_sequences=1
    )
    return outputs[0]["generated_text"].strip()
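A minimal usage sketch of generate_answer, assuming the pipeline above loaded successfully; the prompt string here is illustrative and would normally come from build_prompt:

sample_prompt = "[Question]\nSummarize the broker's duty to verify and explain.\n\n[Answer]\n"  # illustrative
print(generate_answer(sample_prompt))  # returns the generated text, stripped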
generator/prompt_builder.py
CHANGED
@@ -1,19 +1,19 @@
def build_prompt(query: str, law_docs: list, exam_docs: list) -> str:
    prompt = (
        "You are an expert at writing questions for the licensed real estate agent exam.\n\n"
        "Below are the relevant statutes and past exam questions:\n\n"
        "[Statutes]\n"
    )

    for doc in law_docs:
        prompt += f"- {doc}\n"

    prompt += "\n[Past exam questions]\n"

    for doc in exam_docs:
        prompt += f"- {doc}\n"

    prompt += f"\nUsing the information above, answer the user's request accurately and clearly.\n\n"
    prompt += f"[Question]\n{query}\n\n[Answer]\n"

    return prompt
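A short sketch of how the new build_prompt signature is called, with made-up document strings:

law_docs = ["Licensed Real Estate Agents Act, Article 25 (duty to verify and explain)"]  # illustrative
exam_docs = ["2023 exam: Which of the following is NOT part of the broker's duty to explain?"]  # illustrative
prompt = build_prompt("Create one question about Article 25", law_docs, exam_docs)
print(prompt)  # prints the [Statutes] / [Past exam questions] / [Question] / [Answer] sections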
generator/prompt_builder_v1.py
ADDED
@@ -0,0 +1,19 @@
def build_prompt(query: str, context_docs: list) -> str:
    """
    Combine the user's query and the retrieved documents into a prompt for the LLM.
    """
    context_text = "\n".join([f"- {doc}" for doc in context_docs])

    prompt = f"""You are an expert at writing questions for the licensed real estate agent exam.

Below are past exam questions and related statutes:
{context_text}

Using this information, please answer the user's request.

[Question]
{query}

[Answer]
"""
    return prompt
requirements.txt
CHANGED
@@ -1,7 +1,6 @@
 gradio
 torch
 transformers
 sentence-transformers
 faiss-cpu
 tqdm
-accelerate
retriever/vectordb_rerank.py
CHANGED
@@ -1,37 +1,56 @@
# vectordb_relank_law.py
import faiss
import numpy as np
import os
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from retriever.reranker import rerank_documents

# chroma vector config v2
embedding_models = [
    "upskyy/bge-m3-korean",
    "jhgan/ko-sbert-sts",
    "BM-K/KoSimCSE-roberta",
    "BM-K/KoSimCSE-v2-multitask",
    "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
    "beomi/KcELECTRA-small-v2022",
]
# law_db config v2
CHROMA_PATH = os.path.abspath("data/index/law_db")
COLLECTION_NAME = "law_all"
EMBEDDING_MODEL_NAME = embedding_models[4]  # pick the embedding model to use

# 1. Load the embedding model (v2)
# embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# 2. Embedding function
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL_NAME)

# 3. Load the Chroma client and collection
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_fn)


# 4. Search function
def search_documents(query: str, top_k: int = 5):
    print(f"\nQuery: '{query}'")
    results = collection.query(
        query_texts=[query],
        n_results=top_k,
        include=["documents", "metadatas", "distances"]
    )

    for i, (doc, meta, dist) in enumerate(zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0]
    )):
        print(f"\nResult {i+1} (similarity: {1 - dist:.2f})")
        print(f"Document: {doc[:150]}...")
        print("Metadata:")
        print(meta)
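Since this version of search_documents only prints the Chroma hits and returns nothing, a standalone check looks roughly like the following, assuming the law_db collection exists at the configured path; the query text is illustrative:

from retriever.vectordb_rerank import search_documents
search_documents("tenant protection under the Housing Lease Protection Act", top_k=3)  # results are printed, not returned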
retriever/vectordb_rerank_exam.py
ADDED
@@ -0,0 +1,55 @@
# vectordb_relank_law.py
import faiss
import numpy as np
import os
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from retriever.reranker import rerank_documents

# chroma vector config v2
embedding_models = [
    "upskyy/bge-m3-korean",
    "jhgan/ko-sbert-sts",
    "BM-K/KoSimCSE-roberta",
    "BM-K/KoSimCSE-v2-multitask",
    "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
    "beomi/KcELECTRA-small-v2022",
]
# exam_db config v2
CHROMA_PATH = os.path.abspath("data/index/exam_db")
COLLECTION_NAME = "exam_all"
EMBEDDING_MODEL_NAME = embedding_models[4]  # pick the embedding model to use

# 1. Load the embedding model (v2)
# embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# 2. Embedding function
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL_NAME)

# 3. Load the Chroma client and collection
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_fn)

# 4. Search function
def search_documents(query: str, top_k: int = 5):
    print(f"\nQuery: '{query}'")
    results = collection.query(
        query_texts=[query],
        n_results=top_k,
        include=["documents", "metadatas", "distances"]
    )

    # rerank documents
    # reranked_results = rerank_documents(query, results, top_k=top_k)

    for i, (doc, meta, dist) in enumerate(zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0]
    )):
        print(f"\nResult {i+1} (similarity: {1 - dist:.2f})")
        print(f"Document: {doc[:150]}...")
        print("Metadata:")
        print(meta)
retriever/vectordb_rerank_law.py
ADDED
@@ -0,0 +1,68 @@
# vectordb_relank_law.py
import faiss
import numpy as np
import os
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from retriever.reranker import rerank_documents

# chroma vector config v2
embedding_models = [
    "upskyy/bge-m3-korean",
    "jhgan/ko-sbert-sts",
    "BM-K/KoSimCSE-roberta",
    "BM-K/KoSimCSE-v2-multitask",
    "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
    "beomi/KcELECTRA-small-v2022",
]
# law_db config v2
CHROMA_PATH = os.path.abspath("data/index/law_db")
COLLECTION_NAME = "law_all"
EMBEDDING_MODEL_NAME = embedding_models[4]  # pick the embedding model to use


# 1. Load the embedding model (v2)
# embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)


# 2. Embedding function
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL_NAME)

# 3. Load the Chroma client and collection
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_fn)


# 4. Search function
def search_documents(query: str, top_k: int = 5):
    print(f"\nQuery: '{query}'")
    results = collection.query(
        query_texts=[query],
        n_results=top_k,
        include=["documents", "metadatas", "distances"]
    )

    # Extract the document lists only
    docs = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    # Rerank the documents
    reranked_docs = rerank_documents(query, docs, top_k=top_k)

    # Re-align metadata and distances with the reranked documents
    reranked_data = []
    for doc in reranked_docs:
        idx = docs.index(doc)
        reranked_data.append((doc, metadatas[idx], distances[idx]))

    for i, (doc, meta, dist) in enumerate(reranked_data):
        print(f"\nResult {i+1} (similarity: {1 - dist:.2f})")
        print(f"Document: {doc[:150]}...")
        print("Metadata:")
        print(meta)

    return reranked_data  # return if needed
|
services/rag_pipeline_v2.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from retriever.vectordb import search_documents
|
2 |
+
# from retriever.vectordb_rerank import search_documents
|
3 |
+
from retriever.vectordb_rerank_law import search_documents as search_law
|
4 |
+
from retriever.vectordb_rerank_exam import search_documents as search_exam
|
5 |
+
from generator.prompt_builder import build_prompt
|
6 |
+
from generator.llm_inference import generate_answer
|
7 |
+
|
8 |
+
def rag_pipeline(query: str, top_k: int = 5) -> str:
|
9 |
+
"""
|
10 |
+
1. ์ฌ์ฉ์ ์ง๋ฌธ์ผ๋ก ๊ด๋ จ ๋ฌธ์๋ฅผ ๊ฒ์
|
11 |
+
2. ๊ฒ์๋ ๋ฌธ์์ ํจ๊ป ํ๋กฌํํธ ๊ตฌ์ฑ
|
12 |
+
3. ํ๋กฌํํธ๋ก๋ถํฐ ๋ต๋ณ ์์ฑ
|
13 |
+
"""
|
14 |
+
# 1. ๋ฒ๋ น๊ณผ ๋ฌธ์ ๋ฅผ ๊ฐ๊ฐ ๊ฒ์
|
15 |
+
# context_docs = search_documents(query, top_k=top_k)
|
16 |
+
laws_docs = search_law(query, top_k=top_k)
|
17 |
+
exam_docs = search_exam(query, top_k=top_k)
|
18 |
+
|
19 |
+
# 2. ํ๋กฌํํธ ๊ตฌ์ฑ
|
20 |
+
# prompt = build_prompt(query, context_docs)
|
21 |
+
prompt = build_prompt(query, laws_docs, exam_docs)
|
22 |
+
|
23 |
+
# 3. LLM์ผ๋ก ๋ฌธ์ ์์ฑ
|
24 |
+
# output = generate_answer(prompt)
|
25 |
+
questions = generate_answer(prompt)
|
26 |
+
|
27 |
+
# 4. ๊ฒฐ๊ณผ ์ ์ฅ
|
28 |
+
# save_to_exam_vector_db(questions)
|
29 |
+
|
30 |
+
return questions
|
31 |
+
|
32 |
+
|
33 |
+
|
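End to end, the v2 pipeline is driven from a single call; a minimal sketch, assuming both Chroma collections and the generation model are available, with an illustrative query:

from services.rag_pipeline_v2 import rag_pipeline
questions = rag_pipeline("Create two questions about the duty to verify and explain", top_k=5)
print(questions)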