habdine committed
Commit 17a799a · 1 Parent(s): 347fea4

Initial commit of Nile-Chat-12B space
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: Nile Chat 12B
- emoji: 👁
- colorFrom: blue
- colorTo: indigo
+ title: Nile-Chat-12B
+ emoji: 🏞️
+ colorFrom: indigo
+ colorTo: pink
  sdk: gradio
- sdk_version: 5.34.0
+ sdk_version: 5.1.0
  app_file: app.py
  pinned: false
  short_description: Egyptian Chatbot
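For reference, the Space front-matter after this change resolves to the following (reassembled from the hunk above; the closing `---` sits just outside the hunk):

```yaml
---
title: Nile-Chat-12B
emoji: 🏞️
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 5.1.0
app_file: app.py
pinned: false
short_description: Egyptian Chatbot
---
```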
app.py ADDED
@@ -0,0 +1,132 @@
+ import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ DESCRIPTION = """\
+ # 🏞️🏞️ JAIS Initiative: Nile-Chat-12B 🏞️🏞️
+
+ Disclaimer: This research demonstration of Nile-Chat-12B is not intended for end-user applications. The model may generate biased, offensive, or inaccurate content, as it is trained on diverse internet data. The developers do not endorse any views expressed by the model and assume no responsibility for the consequences of its use. Users should critically evaluate the generated responses and use the tool at their own risk.
+
+ Note: The model is expected to take input and generate output in Egyptian Arabic, in both Arabic and Latin scripts.
+ """
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2024"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "MBZUAI-Paris/Nile-Chat-12B"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+
+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     do_sample: bool = False,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.0,
+ ) -> Iterator[str]:
+     conversation = chat_history.copy()
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=do_sample,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Checkbox(label="Do Sample"),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.0,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.0,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["مين اللي عملك؟"],
+         ["اسمك ايه؟"],
+         ["Esmak eh?"],
+         ["ترجم للمصرية:\nWith a total length of about 6,650 km between the region of Lake Victoria and the Mediterranean Sea, the Nile is among the longest rivers on Earth."],
+     ],
+     cache_examples=False,
+     type="messages",
+ )
+
+ with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
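For quick sanity checks outside the Gradio UI, the same model can be driven directly with transformers. A minimal sketch mirroring app.py's setup (this is illustrative, not part of the Space; it assumes a GPU with enough memory for a 12B model in bfloat16, and the prompt is one of the Space's own examples):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "MBZUAI-Paris/Nile-Chat-12B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

conversation = [{"role": "user", "content": "اسمك ايه؟"}]  # "What's your name?"
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(input_ids, max_new_tokens=128, do_sample=False)
# Decode only the newly generated tokens, skipping the echoed prompt.
print(tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True))
```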
defaults/APP_COLOR ADDED
@@ -0,0 +1 @@
+ blue
defaults/APP_NAME ADDED
@@ -0,0 +1 @@
+ Nile-Chat
defaults/MODEL_NAME ADDED
@@ -0,0 +1 @@
+ MBZUAI-Paris/Nile-Chat-12B
defaults/MODEL_PARAMS ADDED
@@ -0,0 +1,8 @@
+ {
+   "temperature": 0.5,
+   "top_p": 0.95,
+   "repetition_penalty": 1.1,
+   "top_k": 50,
+   "truncate": 1000,
+   "max_new_tokens": 1024
+ }
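These defaults/* files appear to follow the Hugging Face chat-ui template convention for configuring a hosted chat front end. Most of the JSON keys map one-to-one onto transformers sampling arguments; `truncate` is a serving-side input-length limit (the counterpart of MAX_INPUT_TOKEN_LENGTH in app.py), not a `generate()` argument. A hypothetical sketch of reusing these defaults with transformers:

```python
import json

# Load the chat-ui style defaults shipped with the Space.
with open("defaults/MODEL_PARAMS") as f:
    params = json.load(f)

# `truncate` limits the prompt length and must not be passed to generate().
truncate = params.pop("truncate")
params["do_sample"] = True  # temperature/top_p/top_k only take effect when sampling

# Assuming `model` and `input_ids` as defined in app.py:
# output = model.generate(input_ids[:, -truncate:], **params)
```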
defaults/MODEL_PROMPT_TEMPLATE ADDED
@@ -0,0 +1 @@
+ <s>{{#each messages}}{{#ifUser}}[INST] {{#if @first}}{{#if @root.preprompt}}{{@root.preprompt}}\n{{/if}}{{/if}} {{content}} [/INST]{{/ifUser}}{{#ifAssistant}}{{content}}</s> {{/ifAssistant}}{{/each}}
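This is a Handlebars chat prompt template in the Mistral/Llama-2 instruction style: user turns are wrapped in `[INST] ... [/INST]`, the preprompt is prepended only to the first user turn, and each assistant turn is closed with `</s>`. Rendered against a two-turn history with a preprompt, it produces roughly:

`<s>[INST] {preprompt}\n {user turn 1} [/INST]{assistant turn 1}</s> [INST] {user turn 2} [/INST]`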
defaults/MONGODB_URL ADDED
@@ -0,0 +1 @@
+ mongodb://127.0.0.1:27017
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.6.0
+     hooks:
+       - id: check-executables-have-shebangs
+       - id: check-json
+       - id: check-merge-conflict
+       - id: check-shebang-scripts-are-executable
+       - id: check-toml
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: mixed-line-ending
+         args: ["--fix=lf"]
+       - id: requirements-txt-fixer
+       - id: trailing-whitespace
+   - repo: https://github.com/myint/docformatter
+     rev: v1.7.5
+     hooks:
+       - id: docformatter
+         args: ["--in-place"]
+   - repo: https://github.com/pycqa/isort
+     rev: 5.13.2
+     hooks:
+       - id: isort
+         args: ["--profile", "black"]
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.10.1
+     hooks:
+       - id: mypy
+         args: ["--ignore-missing-imports"]
+         additional_dependencies:
+           [
+             "types-python-slugify",
+             "types-requests",
+             "types-PyYAML",
+             "types-pytz",
+           ]
+   - repo: https://github.com/psf/black
+     rev: 24.4.2
+     hooks:
+       - id: black
+         language_version: python3.10
+         args: ["--line-length", "119"]
+   - repo: https://github.com/kynan/nbstripout
+     rev: 0.7.1
+     hooks:
+       - id: nbstripout
+         args:
+           [
+             "--extra-keys",
+             "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+           ]
+   - repo: https://github.com/nbQA-dev/nbQA
+     rev: 1.8.5
+     hooks:
+       - id: nbqa-black
+       - id: nbqa-pyupgrade
+         args: ["--py37-plus"]
+       - id: nbqa-isort
+         args: ["--float-to-top"]
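Once this config is committed, `pre-commit install` registers the hooks to run on each commit, and `pre-commit run --all-files` applies them to the whole tree.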
requirements.txt ADDED
@@ -0,0 +1,236 @@
+ # This file was autogenerated by uv via the following command:
+ #    uv pip compile pyproject.toml -o requirements.txt
+ accelerate==1.0.0
+     # via gemma-2-9b-it (pyproject.toml)
+ aiofiles==23.2.1
+     # via gradio
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.6.0
+     # via
+     #   gradio
+     #   httpx
+     #   starlette
+ certifi==2024.8.30
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ charset-normalizer==3.3.2
+     # via requests
+ click==8.1.7
+     # via
+     #   typer
+     #   uvicorn
+ exceptiongroup==1.2.2
+     # via anyio
+ fastapi==0.115.0
+     # via gradio
+ ffmpy==0.4.0
+     # via gradio
+ filelock==3.16.1
+     # via
+     #   huggingface-hub
+     #   torch
+     #   transformers
+     #   triton
+ fsspec==2024.9.0
+     # via
+     #   gradio-client
+     #   huggingface-hub
+     #   torch
+ gradio==5.0.1
+     # via
+     #   gemma-2-9b-it (pyproject.toml)
+     #   spaces
+ gradio-client==1.4.0
+     # via gradio
+ h11==0.14.0
+     # via
+     #   httpcore
+     #   uvicorn
+ hf-transfer==0.1.8
+     # via gemma-2-9b-it (pyproject.toml)
+ httpcore==1.0.5
+     # via httpx
+ httpx==0.27.2
+     # via
+     #   gradio
+     #   gradio-client
+     #   spaces
+ huggingface-hub==0.25.1
+     # via
+     #   accelerate
+     #   gradio
+     #   gradio-client
+     #   tokenizers
+     #   transformers
+ idna==3.10
+     # via
+     #   anyio
+     #   httpx
+     #   requests
+ jinja2==3.1.4
+     # via
+     #   gradio
+     #   torch
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.5
+     # via
+     #   gradio
+     #   jinja2
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.3
+     # via torch
+ numpy==2.1.1
+     # via
+     #   accelerate
+     #   gradio
+     #   pandas
+     #   transformers
+ nvidia-cublas-cu12==12.1.3.1
+     # via
+     #   nvidia-cudnn-cu12
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cuda-cupti-cu12==12.1.105
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.1.105
+     # via torch
+ nvidia-cuda-runtime-cu12==12.1.105
+     # via torch
+ nvidia-cudnn-cu12==9.1.0.70
+     # via torch
+ nvidia-cufft-cu12==11.0.2.54
+     # via torch
+ nvidia-curand-cu12==10.3.2.106
+     # via torch
+ nvidia-cusolver-cu12==11.4.5.107
+     # via torch
+ nvidia-cusparse-cu12==12.1.0.106
+     # via
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-nccl-cu12==2.20.5
+     # via torch
+ nvidia-nvjitlink-cu12==12.6.68
+     # via
+     #   nvidia-cusolver-cu12
+     #   nvidia-cusparse-cu12
+ nvidia-nvtx-cu12==12.1.105
+     # via torch
+ orjson==3.10.7
+     # via gradio
+ packaging==24.1
+     # via
+     #   accelerate
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   spaces
+     #   transformers
+ pandas==2.2.3
+     # via gradio
+ pillow==10.4.0
+     # via gradio
+ psutil==5.9.8
+     # via
+     #   accelerate
+     #   spaces
+ pydantic==2.9.2
+     # via
+     #   fastapi
+     #   gradio
+     #   spaces
+ pydantic-core==2.23.4
+     # via pydantic
+ pydub==0.25.1
+     # via gradio
+ pygments==2.18.0
+     # via rich
+ python-dateutil==2.9.0.post0
+     # via pandas
+ python-multipart==0.0.12
+     # via gradio
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via
+     #   accelerate
+     #   gradio
+     #   huggingface-hub
+     #   transformers
+ regex==2024.9.11
+     # via transformers
+ requests==2.32.3
+     # via
+     #   huggingface-hub
+     #   spaces
+     #   transformers
+ rich==13.8.1
+     # via typer
+ ruff==0.6.8
+     # via gradio
+ safetensors==0.4.5
+     # via
+     #   accelerate
+     #   transformers
+ semantic-version==2.10.0
+     # via gradio
+ shellingham==1.5.4
+     # via typer
+ six==1.16.0
+     # via python-dateutil
+ sniffio==1.3.1
+     # via
+     #   anyio
+     #   httpx
+ spaces==0.30.3
+     # via gemma-2-9b-it (pyproject.toml)
+ starlette==0.38.6
+     # via fastapi
+ sympy==1.13.3
+     # via torch
+ tokenizers==0.19
+     # via transformers
+ tomlkit==0.12.0
+     # via gradio
+ torch==2.4.0
+     # via
+     #   gemma-2-9b-it (pyproject.toml)
+     #   accelerate
+ tqdm==4.66.5
+     # via
+     #   huggingface-hub
+     #   transformers
+ transformers==4.44.2
+     # via gemma-2-9b-it (pyproject.toml)
+ triton==3.0.0
+     # via torch
+ typer==0.12.5
+     # via gradio
+ typing-extensions==4.12.2
+     # via
+     #   anyio
+     #   fastapi
+     #   gradio
+     #   gradio-client
+     #   huggingface-hub
+     #   pydantic
+     #   pydantic-core
+     #   spaces
+     #   torch
+     #   typer
+     #   uvicorn
+ tzdata==2024.2
+     # via pandas
+ urllib3==2.2.3
+     # via requests
+ uvicorn==0.31.0
+     # via gradio
+ websockets==12.0
+     # via gradio-client
style.css ADDED
@@ -0,0 +1,11 @@
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: #fff;
+   background: #1565c0;
+   border-radius: 100vh;
+ }