---
library_name: transformers
license: apache-2.0
---

### Model Description

Query Rewriting for Retrieval-Augmented Large Language Models

arXiv: https://arxiv.org/abs/2305.14283

Large Language Models (LLMs) act as powerful, black-box readers in the retrieve-then-read pipeline and have made remarkable progress on knowledge-intensive tasks. This work introduces a new framework, Rewrite-Retrieve-Read, which replaces the previous retrieve-then-read pipeline for retrieval-augmented LLMs by adding a query rewriting step. We first prompt an LLM to generate the query, then use a web search engine to retrieve contexts. Furthermore, to better align the query with the frozen modules (the retriever and the LLM reader), we propose a trainable scheme for the pipeline: a small language model is adopted as a trainable rewriter that caters to the black-box LLM reader, and it is trained with reinforcement learning using the LLM reader's feedback.
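
The three stages described above can be sketched as plain function composition. This is only an illustrative outline, not the paper's implementation: `rewrite_retrieve_read` and the `toy_*` functions are hypothetical names, and the toy functions are placeholders standing in for the trained rewriter, the web search engine, and the LLM reader.

```python
from typing import Callable, List


def rewrite_retrieve_read(
    question: str,
    rewriter: Callable[[str], str],
    retriever: Callable[[str], List[str]],
    reader: Callable[[str, List[str]], str],
) -> str:
    """Run one pass of a rewrite-retrieve-read style pipeline."""
    query = rewriter(question)         # 1. rewrite the question into a search query
    contexts = retriever(query)        # 2. retrieve contexts for the rewritten query
    return reader(question, contexts)  # 3. read: answer the original question


# Hypothetical stand-ins so the sketch runs without a model or a search API.
def toy_rewriter(question: str) -> str:
    return question.rstrip("?").lower()


def toy_retriever(query: str) -> List[str]:
    return [f"document about {query}"]


def toy_reader(question: str, contexts: List[str]) -> str:
    return f"{len(contexts)} context(s) used"


answer = rewrite_retrieve_read(
    "Who directed East of Eden?", toy_rewriter, toy_retriever, toy_reader
)
print(answer)
```

In practice, the `rewriter` slot would wrap a call to this checkpoint, while the retriever and reader remain frozen external components.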

- **Developed by:** https://github.com/xbmxb/RAG-query-rewriting
- **Model type:** google/t5-large
- **Checkpoint:** checkpoint_20

### Inference

```python
import torch
from transformers import BitsAndBytesConfig, T5ForConditionalGeneration, T5Tokenizer

# 8-bit quantization (requires the bitsandbytes package)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = T5ForConditionalGeneration.from_pretrained(
    "catyung/t5l-turbo-hotpot-0331",
    quantization_config=quantization_config,
)
tokenizer = T5Tokenizer.from_pretrained("catyung/t5l-turbo-hotpot-0331")

# Define the query before building the prompt: the f-string reads it immediately.
user_query = "What profession does Nicholas Ray and Elia Kazan have in common?"

rewrite_prompt = f"""rewrite a better search query: {user_query}
answer:"""

# Inference
input_ids = tokenizer(rewrite_prompt, return_tensors="pt").input_ids.to(device)
outputs = model.generate(input_ids, max_new_tokens=50)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
```