dasomaru committed
Commit a88d56c · verified · 1 Parent(s): 6fb6387

Update app.py

Files changed (1)
  1. app.py +80 -79
app.py CHANGED
@@ -1,79 +1,80 @@
- import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- # from retriever.vectordb_rerank import search_documents  # 🧠 load the RAG retriever
- from services.rag_pipeline import rag_pipeline
-
- model_name = "dasomaru/gemma-3-4bit-it-demo"
-
-
- # 1. Load the model/tokenizer once
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- # 🚀 Put the model on CPU first (no GPU yet)
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     torch_dtype=torch.float16,  # since this is a 4-bit model
-     device_map="auto",  # ✅ important: assigns the GPU automatically
-     trust_remote_code=True,
- )
-
- # 2. Cache management
- search_cache = {}
-
- @spaces.GPU(duration=300)
- def generate_response(query: str):
-     tokenizer = AutoTokenizer.from_pretrained(
-         "dasomaru/gemma-3-4bit-it-demo",
-         trust_remote_code=True,
-     )
-     model = AutoModelForCausalLM.from_pretrained(
-         "dasomaru/gemma-3-4bit-it-demo",
-         torch_dtype=torch.float16,  # since this is a 4-bit model
-         device_map="auto",  # ✅ important: assigns the GPU automatically
-         trust_remote_code=True,
-
-     )
-     model.to("cuda")
-
-     if query in search_cache:
-         print(f"⚡ Cache hit: '{query}'")
-         return search_cache[query]
-
-     # 🔥 Call rag_pipeline: retrieval + generation
-     # Retrieval
-     top_k = 5
-     results = rag_pipeline(query, top_k=top_k)
-
-     # Join the results if they come back as a list
-     if isinstance(results, list):
-         results = "\n\n".join(results)
-
-     search_cache[query] = results
-     # return results
-
-     inputs = tokenizer(results, return_tensors="pt").to(model.device)  # ✅ model.device
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=512,
-         temperature=0.7,
-         top_p=0.9,
-         top_k=50,
-         do_sample=True,
-     )
-
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-
- # 3. Gradio interface
- demo = gr.Interface(
-     fn=generate_response,
-     # inputs=gr.Textbox(lines=2, placeholder="Enter your question"),
-     inputs="text",
-     outputs="text",
-     title="Law RAG Assistant",
-     description="Statute-based RAG pipeline test",
- )
-
- # demo.launch(server_name="0.0.0.0", server_port=7860)  # 🚀 ready for API deployment
- demo.launch()
 
 
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ # from retriever.vectordb_rerank import search_documents  # 🧠 load the RAG retriever
+ from services.rag_pipeline import rag_pipeline
+
+ model_name = "dasomaru/gemma-3-4bit-it-demo"
+
+
+ # 1. Load the model/tokenizer once
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ # 🚀 Put the model on CPU first (no GPU yet)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16,  # since this is a 4-bit model
+     device_map="auto",  # ✅ important: assigns the GPU automatically
+     trust_remote_code=True,
+ )
+
+ # 2. Cache management
+ search_cache = {}
+
+ @spaces.GPU(duration=300)
+ def generate_response(query: str):
+     tokenizer = AutoTokenizer.from_pretrained(
+         "dasomaru/gemma-3-4bit-it-demo",
+         trust_remote_code=True,
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         "dasomaru/gemma-3-4bit-it-demo",
+         torch_dtype=torch.float16,  # since this is a 4-bit model
+         device_map="auto",  # ✅ important: assigns the GPU automatically
+         trust_remote_code=True,
+
+     )
+     model.to("cuda")
+
+     if query in search_cache:
+         print(f"⚡ Cache hit: '{query}'")
+         return search_cache[query]
+
+     # 🔥 Call rag_pipeline: retrieval + generation
+     # Retrieval
+     top_k = 5
+     results = rag_pipeline(query, top_k=top_k)
+
+     # Join the results if they come back as a list
+     if isinstance(results, list):
+         results = "\n\n".join(results)
+
+     search_cache[query] = results
+     # return results
+
+     inputs = tokenizer(results, return_tensors="pt").to(model.device)  # ✅ model.device
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=512,
+         temperature=0.7,
+         top_p=0.9,
+         top_k=50,
+         do_sample=True,
+     )
+
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+ # 3. Gradio interface
+ demo = gr.Interface(
+     fn=generate_response,
+     # inputs=gr.Textbox(lines=2, placeholder="Enter your question"),
+     inputs="text",
+     outputs="text",
+     title="Law RAG Assistant",
+     description="Statute-based RAG pipeline test",
+ )
+
+ # demo.launch(server_name="0.0.0.0", server_port=7860)  # 🚀 ready for API deployment
+ # demo.launch()
+ demo.launch(debug=True)
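
The one functional change in this commit is the last line: `demo.launch()` is commented out in favor of `demo.launch(debug=True)`. In Gradio, `debug=True` blocks the launching thread and prints errors raised inside the handler to the console, which makes failures easier to spot in the Space logs while testing.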
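Worth flagging, since both revisions share it: the model and tokenizer are loaded once at module level and then reloaded on every call inside `generate_response`, and the explicit `model.to("cuda")` is redundant once `device_map="auto"` has already dispatched the weights (with accelerate-managed models it can even raise an error). Below is a minimal sketch of the handler reusing the module-level objects; the model id, the `rag_pipeline` import, and the generation settings are taken from the file above, but this is an assumed simplification, not the author's intended fix.

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from services.rag_pipeline import rag_pipeline

model_name = "dasomaru/gemma-3-4bit-it-demo"

# Load once at import time; device_map="auto" lets accelerate place the
# weights, so no explicit .to("cuda") call is needed afterwards.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

@spaces.GPU(duration=300)
def generate_response(query: str) -> str:
    # Reuse the module-level model/tokenizer instead of reloading per request.
    context = rag_pipeline(query, top_k=5)
    if isinstance(context, list):
        context = "\n\n".join(context)
    inputs = tokenizer(context, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)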
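A second quirk: `search_cache[query] = results` stores the retrieval context before generation, so a cache hit returns raw retrieved passages while a cache miss returns the model's generated answer. If the intent is to memoize the final answer, caching after decoding keeps both paths consistent. A small sketch under that assumption; `cached_generate` is a hypothetical wrapper, not a function in this repo.

search_cache: dict[str, str] = {}

def cached_generate(query: str) -> str:
    # Hypothetical wrapper: memoize the generated answer,
    # not the raw retrieval context.
    if query in search_cache:
        print(f"⚡ Cache hit: '{query}'")
        return search_cache[query]
    answer = generate_response(query)  # the @spaces.GPU handler above
    search_cache[query] = answer
    return answer

Wiring `fn=cached_generate` into `gr.Interface` would then serve repeat queries from the cache without re-entering the GPU queue.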