tahsinhasem committed
Commit 8c27ab6 · verified · 1 Parent(s): 7a79370

Update generate endpoint

Files changed (1)
  1. main.py +18 -33
main.py CHANGED
@@ -11,48 +11,33 @@ app = FastAPI()
 
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
-class Item(BaseModel):
-    prompt: str
-    history: list
-    system_prompt: str
-    temperature: float = 0.0
-    max_new_tokens: int = 1048
-    top_p: float = 0.15
-    repetition_penalty: float = 1.0
 
 class EchoMessage(BaseModel):
     message: str
 
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+class Item(BaseModel):
+    prompt: str
+
 
 def generate(item: Item):
-    temperature = float(item.temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(item.top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=item.max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=item.repetition_penalty,
-        do_sample=True,
-        seed=42,
+    generator = pipeline("text-generation", model=model_name)
+
+    # Your input prompt
+    prompt = item.prompt
+
+    # Generate text
+    generated_texts = generator(
+        prompt,
+        max_length=50,           # Maximum length of the generated text
+        num_return_sequences=1,  # Number of different sequences to generate
+        temperature=0.8,         # Controls the randomness of the output
+        top_k=50,                # Limits the number of top tokens to consider
+        top_p=0.95,              # Nucleus sampling parameter
+        do_sample=True           # Enable sampling for non-deterministic output
     )
 
-    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
+    return generated_texts
 
-    for response in stream:
-        output += response.token.text
-    return output
 
 @app.post("/generate/")
 async def generate_text(item: Item):
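
Note: the new generate() body calls pipeline() and references model_name, neither of which appears in this hunk. A minimal sketch of the import and module-level definition the function seems to assume is shown below; the model_name value is a placeholder for illustration, not something taken from this commit.

# Sketch of assumed context elsewhere in main.py (not part of this diff)
from transformers import pipeline  # transformers pipeline used by the new generate()

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # placeholder; actual value defined elsewhere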
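
For reference, exercising the updated endpoint could look roughly like the sketch below. It assumes the handler simply returns the result of generate() and that the app is served locally on port 8000; both are assumptions, since the handler body and deployment details are outside this hunk.

# Hypothetical client-side call, not part of this commit
import requests

resp = requests.post(
    "http://localhost:8000/generate/",                    # assumed local dev URL
    json={"prompt": "Write a haiku about FastAPI"},       # Item now carries only a prompt
)
print(resp.json())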