Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update app/main.py
Browse files- app/main.py +39 -15
 
    	
        app/main.py
    CHANGED
    
    | 
         @@ -15,6 +15,7 @@ from typing import Optional 
     | 
|
| 15 | 
         
             
            print("Loading model...")
         
     | 
| 16 | 
         
             
            SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
         
     | 
| 17 | 
         
             
            FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
         
     | 
| 
         | 
|
| 18 | 
         
             
                  # n_gpu_layers=28, # Uncomment to use GPU acceleration
         
     | 
| 19 | 
         
             
                  # seed=1337, # Uncomment to set a specific seed
         
     | 
| 20 | 
         
             
                  # n_ctx=2048, # Uncomment to increase the context window
         
     | 
| 
         @@ -23,9 +24,9 @@ FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock 
     | 
|
| 23 | 
         
             
            def extract_restext(response):
         
     | 
| 24 | 
         
             
              return response['choices'][0]['text'].strip()
         
     | 
| 25 | 
         | 
| 26 | 
         
            -
            def  
     | 
| 27 | 
         
             
              prompt = f"""###User: {question}\n###Assistant:"""
         
     | 
| 28 | 
         
            -
              result = extract_restext( 
     | 
| 29 | 
         
             
              return result
         
     | 
| 30 | 
         | 
| 31 | 
         
             
            def check_sentiment(text):
         
     | 
| 
         @@ -43,7 +44,8 @@ def check_sentiment(text): 
     | 
|
| 43 | 
         
             
            # TESTING THE MODEL
         
     | 
| 44 | 
         
             
            print("Testing model...")
         
     | 
| 45 | 
         
             
            assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
         
     | 
| 46 | 
         
            -
            assert  
     | 
| 
         | 
|
| 47 | 
         
             
            print("Ready.")
         
     | 
| 48 | 
         | 
| 49 | 
         | 
| 
         @@ -70,12 +72,12 @@ class SA_Result(str, Enum): 
     | 
|
| 70 | 
         
             
              negative = "negative"
         
     | 
| 71 | 
         
             
              unknown = "unknown"
         
     | 
| 72 | 
         | 
| 73 | 
         
            -
            class  
     | 
| 74 | 
         
             
              code: int = 200
         
     | 
| 75 | 
         
             
              text: Optional[str] = None
         
     | 
| 76 | 
         
             
              result: SA_Result = None
         
     | 
| 77 | 
         | 
| 78 | 
         
            -
            class  
     | 
| 79 | 
         
             
              code: int = 200
         
     | 
| 80 | 
         
             
              question: Optional[str] = None
         
     | 
| 81 | 
         
             
              answer: str = None
         
     | 
| 
         @@ -89,18 +91,18 @@ def docs(): 
     | 
|
| 89 | 
         
             
              return responses.RedirectResponse('./docs')
         
     | 
| 90 | 
         | 
| 91 | 
         
             
            @app.post('/classifications/sentiment')
         
     | 
| 92 | 
         
            -
            async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) ->  
     | 
| 93 | 
         
             
              """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
         
     | 
| 94 | 
         
             
              if prompt:
         
     | 
| 95 | 
         
             
                try:
         
     | 
| 96 | 
         
             
                  print(f"Checking sentiment for {prompt}")
         
     | 
| 97 | 
         
             
                  result = check_sentiment(prompt)
         
     | 
| 98 | 
         
             
                  print(f"Result: {result}")
         
     | 
| 99 | 
         
            -
                  return  
     | 
| 100 | 
         
             
                except Exception as e:
         
     | 
| 101 | 
         
            -
                  return HTTPException(500,  
     | 
| 102 | 
         
             
              else:
         
     | 
| 103 | 
         
            -
                return HTTPException(400,  
     | 
| 104 | 
         | 
| 105 | 
         | 
| 106 | 
         
             
            @app.post('/questions/finance')
         
     | 
| 
         @@ -108,18 +110,40 @@ async def ask_gemmaFinanceTH( 
     | 
|
| 108 | 
         
             
                prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
         
     | 
| 109 | 
         
             
                temperature: float = Body(0.5, embed=True), 
         
     | 
| 110 | 
         
             
                max_new_tokens: int = Body(200, embed=True)
         
     | 
| 111 | 
         
            -
            ) ->  
     | 
| 112 | 
         
             
              """
         
     | 
| 113 | 
         
             
              Ask a finetuned Gemma a finance-related question, just for fun.
         
     | 
| 114 | 
         
             
              NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
         
     | 
| 115 | 
         
             
              """
         
     | 
| 116 | 
         
             
              if prompt:
         
     | 
| 117 | 
         
             
                try:
         
     | 
| 118 | 
         
            -
                  print(f'Asking  
     | 
| 119 | 
         
            -
                  result =  
     | 
| 120 | 
         
             
                  print(f"Result: {result}")
         
     | 
| 121 | 
         
            -
                  return  
     | 
| 122 | 
         
             
                except Exception as e:
         
     | 
| 123 | 
         
            -
                  return HTTPException(500,  
     | 
| 124 | 
         
             
              else:
         
     | 
| 125 | 
         
            -
                return HTTPException(400,  
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# Load all three finetuned GGUF models once at import time.
# mmap=False + mlock=True pins the full weights in RAM so per-request
# latency is predictable (at the cost of a slow, memory-heavy startup).
print("Loading model...")
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", mmap=False, mlock=True)
      # n_gpu_layers=28, # Uncomment to use GPU acceleration
      # seed=1337, # Uncomment to set a specific seed
      # n_ctx=2048, # Uncomment to increase the context window
| 
         | 
|
def extract_restext(response):
  """Return the text of the first completion choice, stripped of surrounding whitespace."""
  first_choice = response["choices"][0]
  return first_choice["text"].strip()

def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
  """Prompt *llm* in the ###User/###Assistant chat format and return its stripped reply.

  The stop sequences cut generation off before the model can start a new turn.
  """
  prompt = f"###User: {question}\n###Assistant:"
  completion = llm(
    prompt,
    max_tokens=max_new_tokens,
    temperature=temperature,
    stop=["###User:", "###Assistant:"],
    echo=False,
  )
  return extract_restext(completion)
| 31 | 
         | 
| 32 | 
         
             
            def check_sentiment(text):
         
     | 
| 
         | 
|
# TESTING THE MODEL — fail fast at startup if any model cannot produce output.
# FIX: these checks used `assert`, which is silently stripped under `python -O`;
# raise explicitly so the smoke test always runs.
print("Testing model...")
if "positive" not in check_sentiment("ดอกไม้ร้านนี้สวยจัง"):
  raise RuntimeError("Sentiment model smoke test failed")
if not ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1):  # Just checking that it can run
  raise RuntimeError("Finance model smoke test failed")
if not ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1):  # Just checking that it can run
  raise RuntimeError("Open-domain model smoke test failed")
print("Ready.")
| 50 | 
         | 
| 51 | 
         | 
| 
         | 
|
| 72 | 
         
             
              negative = "negative"
         
     | 
| 73 | 
         
             
              unknown = "unknown"
         
     | 
| 74 | 
         | 
class SAResponse(BaseModel):
  """Response payload for the sentiment-analysis endpoint."""
  code: int = 200                     # HTTP-style status code mirrored in the body
  text: Optional[str] = None          # the prompt that was classified
  # FIX: was `result: SA_Result = None` — a None default on a non-Optional
  # field; declare it Optional explicitly.
  result: Optional[SA_Result] = None
| 79 | 
         | 
class QuestionResponse(BaseModel):
  """Response payload for the question-answering endpoints."""
  code: int = 200                     # HTTP-style status code mirrored in the body
  question: Optional[str] = None      # the prompt that was asked
  # FIX: was `answer: str = None` — a None default on a non-Optional field.
  answer: Optional[str] = None
  # FIX: the endpoints pass `config={...}` (generation settings); without a
  # declared field pydantic silently dropped it from every response.
  config: Optional[dict] = None
| 
         | 
|
| 91 | 
         
             
              return responses.RedirectResponse('./docs')
         
     | 
| 92 | 
         | 
@app.post('/classifications/sentiment')
async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SAResponse:
  """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
  if not prompt:
    # FIX: HTTPException must be raised, not returned — returning it makes
    # FastAPI serialize the exception object in a 200 response.
    raise HTTPException(400, "Request argument 'prompt' not provided.")
  try:
    print(f"Checking sentiment for {prompt}")
    result = check_sentiment(prompt)
    print(f"Result: {result}")
    return SAResponse(result=result, text=prompt)
  except HTTPException:
    raise
  except Exception as e:
    # FIX: the old handler built SAResponse(result=str(e), ...), which cannot
    # validate against the SA_Result enum and was returned instead of raised.
    raise HTTPException(500, str(e))
| 106 | 
         | 
| 107 | 
         | 
@app.post('/questions/finance')
async def ask_gemmaFinanceTH(
    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
  """
  Ask a finetuned Gemma a finance-related question, just for fun.
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
  """
  if not prompt:
    # FIX: raise (not return) so FastAPI actually sends a 400 status.
    raise HTTPException(400, "Request argument 'prompt' not provided.")
  try:
    print(f'Asking GemmaFinance with the question "{prompt}"')
    result = ask_llm(FIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
    print(f"Result: {result}")
    return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
  except HTTPException:
    raise
  except Exception as e:
    # FIX: raise (not return) so the client sees a real 500.
    raise HTTPException(500, str(e))
| 128 | 
         
            +
              
         
     | 
| 129 | 
         
            +
             
     | 
# FIX: this endpoint duplicated both the route path '/questions/finance' AND
# the function name ask_gemmaFinanceTH of the endpoint above, so FastAPI
# matched the first registration and this handler was unreachable. It serves
# the GemmaWild open-domain model (WIllm), so it gets its own path and name.
@app.post('/questions/open')
async def ask_gemmaWild(
    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
  """
  Ask a finetuned Gemma an open-ended question..
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
  """
  if not prompt:
    # FIX: raise (not return) so FastAPI actually sends a 400 status.
    raise HTTPException(400, "Request argument 'prompt' not provided.")
  try:
    print(f'Asking GemmaWild with the question "{prompt}"')
    result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
    print(f"Result: {result}")
    return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
  except HTTPException:
    raise
  except Exception as e:
    # FIX: raise (not return) so the client sees a real 500.
    raise HTTPException(500, str(e))