Wang Baoling
		
		KevinHuSh
		
	committed on
		
		
					Commit 
							
							·
						
						04d3b7e
	
1
								Parent(s):
							
							ac8a9f7
								
Fix: bug #991 (#1013)
Browse files

### What problem does this PR solve?
issue #991
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
---------
Co-authored-by: KevinHuSh <[email protected]>
- api/db/init_data.py +1 -1
 - rag/llm/rerank_model.py +14 -0
 
    	
        api/db/init_data.py
    CHANGED
    
    | 
         @@ -386,7 +386,7 @@ def init_llm_factory(): 
     | 
|
| 386 | 
         
             
                        "fid": factory_infos[7]["name"],
         
     | 
| 387 | 
         
             
                        "llm_name": "maidalun1020/bce-reranker-base_v1",
         
     | 
| 388 | 
         
             
                        "tags": "RE-RANK, 8K",
         
     | 
| 389 | 
         
            -
                        "max_tokens":  
     | 
| 390 | 
         
             
                        "model_type": LLMType.RERANK.value
         
     | 
| 391 | 
         
             
                    },
         
     | 
| 392 | 
         
             
                    # ------------------------ DeepSeek -----------------------
         
     | 
| 
         | 
|
| 386 | 
         
             
                        "fid": factory_infos[7]["name"],
         
     | 
| 387 | 
         
             
                        "llm_name": "maidalun1020/bce-reranker-base_v1",
         
     | 
| 388 | 
         
             
                        "tags": "RE-RANK, 8K",
         
     | 
| 389 | 
         
            +
                        "max_tokens": 512,
         
     | 
| 390 | 
         
             
                        "model_type": LLMType.RERANK.value
         
     | 
| 391 | 
         
             
                    },
         
     | 
| 392 | 
         
             
                    # ------------------------ DeepSeek -----------------------
         
     | 
    	
        rag/llm/rerank_model.py
    CHANGED
    
    | 
         @@ -113,4 +113,18 @@ class YoudaoRerank(DefaultRerank): 
     | 
|
| 113 | 
         
             
                            YoudaoRerank._model = RerankerModel(
         
     | 
| 114 | 
         
             
                                model_name_or_path=model_name.replace(
         
     | 
| 115 | 
         
             
                                    "maidalun1020", "InfiniFlow"))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 116 | 
         | 
| 
         | 
|
| 113 | 
         
             
                            YoudaoRerank._model = RerankerModel(
         
     | 
| 114 | 
         
             
                                model_name_or_path=model_name.replace(
         
     | 
| 115 | 
         
             
                                    "maidalun1020", "InfiniFlow"))
         
     | 
def similarity(self, query: str, texts: list):
    """Score the relevance of each text in *texts* against *query*.

    Args:
        query: The search query string.
        texts: Candidate passages to rank against the query.

    Returns:
        A tuple ``(scores, token_count)`` where ``scores`` is a numpy
        array of sigmoid-normalized relevance scores (one per text, in
        input order) and ``token_count`` is the total token count of the
        truncated texts.
    """
    # Truncate each candidate to the model's max length before pairing
    # it with the query for cross-encoder scoring.
    # NOTE(review): a query+text pair may still exceed max_length in
    # combination — compute_score is also passed max_length to guard this.
    pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
    token_count = sum(num_tokens_from_string(t) for _, t in pairs)
    batch_size = 32  # score in fixed-size batches to bound memory use
    res = []
    for i in range(0, len(pairs), batch_size):
        scores = self._model.compute_score(pairs[i:i + batch_size],
                                           max_length=self._model.max_length)
        # Map raw cross-encoder logits into (0, 1) relevance scores.
        scores = sigmoid(np.array(scores)).tolist()
        res.extend(scores)
    return np.array(res), token_count