orionweller committed (verified) · Commit 30b14e3 · 1 Parent(s): f563dfb

Update README.md

Files changed (1): README.md +83 -19
README.md CHANGED
@@ -13,7 +13,7 @@ datasets:
 
 # Model Summary
 
-Promptriever is a new way of using dense retriever models. This version, `promptriever-llama2-7b-v1` was instruction-trained on a corpus of 490k MSMarco samples with instructions and 490k without instructions. See the [paper]() for more details.
+Promptriever is a bi-encoder retrieval model that can take in natural language instructions and prompts. This version, `promptriever-llama2-7b-v1`, was instruction-trained on a corpus of 490k MSMarco samples with instructions and 490k without instructions. See the [paper](todo) for more details.
 
 - **Repository:** [orionw/Promptriever](https://github.com/orionw/promptriever)
 - **Paper:** [Promptriever: Instruction-Trained Retrievers Can Be Prompted Like Language Models](TODO)
@@ -22,7 +22,16 @@ Promptriever is a new way of using dense retriever models. This version, `prompt
 
 # Use
 
-Below is an example to compute the similarity score of a query-document pair
+You can use MTEB to load this model ([source code](https://github.com/embeddings-benchmark/mteb/blob/main/mteb/models/promptriever_models.py)):
+```python
+import mteb
+model = mteb.get_model("samaya-ai/promptriever-llama2-7b-v1")
+tasks = mteb.get_tasks(tasks=["NFCorpus"], languages=["eng"])
+evaluation = mteb.MTEB(tasks=tasks)
+evaluation.run(model, batch_size=16)
+```
+
+If you want to use a different framework, here's an example of how to batch:
 ```python
 import torch
 import torch.nn.functional as F
@@ -33,7 +42,7 @@ import numpy as np
 class Promptriever:
     def __init__(self, model_name_or_path):
         self.model, self.tokenizer = self.get_model(model_name_or_path)
-        self.model.eval()
+        self.model.eval().cuda()
 
     def get_model(self, peft_model_name):
         # Load the PEFT configuration to get the base model name
@@ -44,33 +53,82 @@ class Promptriever:
         base_model = AutoModel.from_pretrained(base_model_name)
         tokenizer = AutoTokenizer.from_pretrained(base_model_name)
         tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.padding_side = "right"
 
         # Load and merge the PEFT model
         model = PeftModel.from_pretrained(base_model, peft_model_name)
         model = model.merge_and_unload()
 
+        # can be much longer, but for the example 512 is enough
+        model.config.max_length = 512
+        tokenizer.model_max_length = 512
+
         return model, tokenizer
 
-    def encode(self, texts):
-        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        embeddings = outputs.last_hidden_state[:, 0]  # Using [CLS] token
-        return F.normalize(embeddings, p=2, dim=1)
+    def create_batch_dict(self, tokenizer, input_texts):
+        max_length = self.model.config.max_length
+        batch_dict = tokenizer(
+            input_texts,
+            max_length=max_length - 1,
+            return_token_type_ids=False,
+            return_attention_mask=False,
+            padding=False,
+            truncation=True,
+        )
+        batch_dict["input_ids"] = [
+            input_ids + [tokenizer.eos_token_id]
+            for input_ids in batch_dict["input_ids"]
+        ]
+        return tokenizer.pad(
+            batch_dict,
+            padding=True,
+            pad_to_multiple_of=8,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+
+    def encode(self, sentences, max_length: int = 2048, batch_size: int = 4):
+        all_embeddings = []
+        for i in range(0, len(sentences), batch_size):
+            batch_texts = sentences[i : i + batch_size]
+
+            batch_dict = self.create_batch_dict(self.tokenizer, batch_texts)
+            batch_dict = {
+                key: value.to(self.model.device) for key, value in batch_dict.items()
+            }
+
+            with torch.cuda.amp.autocast():
+                with torch.no_grad():
+                    outputs = self.model(**batch_dict)
+                    last_hidden_state = outputs.last_hidden_state
+                    sequence_lengths = batch_dict["attention_mask"].sum(dim=1) - 1
+                    batch_size = last_hidden_state.shape[0]
+                    reps = last_hidden_state[
+                        torch.arange(batch_size, device=last_hidden_state.device),
+                        sequence_lengths,
+                    ]
+                    embeddings = F.normalize(reps, p=2, dim=-1)
+                    all_embeddings.append(embeddings.cpu().numpy())
+
+        return np.concatenate(all_embeddings, axis=0)
 
 # Initialize the model
 model = Promptriever("samaya-ai/promptriever-llama2-7b-v1")
 
 # Example query and instruction
 query = "What universities are in Baltimore, Maryland?"
-instruction = "A relevant document would describe any university in Baltimore. I am only interested in the United States, so ignore any document with a campus in Italy."
 
-# Combine query and instruction with two spaces after "query: "
-input_text = f"query: {query} {instruction}"
+# add specific relevance conditions if desired (and/or/not) and any other prompts
+instruction = "A relevant document would describe any university in Baltimore. I am not interested in any university that was the first American university. Think carefully about these conditions when determining relevance."
+
+# Combine query and instruction with **two spaces** after "query: "
+input_text = f"query: {query.strip()} {instruction.strip()}".strip()
 
 # Example documents
-doc1 = "Johns Hopkins University (often abbreviated as Johns Hopkins, Hopkins, or JHU) is a private research university in Baltimore, Maryland. Founded in 1876, Johns Hopkins was the first American university based on the European research institution model. The university also has graduate campuses in Italy, China, and Washington, D.C."
-doc2 = "Johns Hopkins University (often abbreviated as Johns Hopkins, Hopkins, or JHU) is a private research university in Baltimore, Maryland. Founded in 1876, Johns Hopkins was the first American university based on the European research institution model. The university also has graduate campuses in China, and Washington, D.C."
+# NOTE: double space after `passage:`
+doc1 = "passage: Johns Hopkins University (often abbreviated as Johns Hopkins, Hopkins, or JHU) is a private research university in Baltimore, Maryland. Founded in 1876, Johns Hopkins was the first American university based on the European research institution model."
+doc2 = "passage: Johns Hopkins University (often abbreviated as Johns Hopkins, Hopkins, or JHU) is a private research university in Baltimore, Maryland. Founded in 1876, Johns Hopkins was the second American university based on the European research institution model."
 
 # Encode query and documents
 query_embedding = model.encode([input_text])
@@ -78,11 +136,17 @@ doc_embeddings = model.encode([doc1, doc2])
 
 # Calculate similarities
 similarities = np.dot(query_embedding, doc_embeddings.T)[0]
+print(f"Similarities: {similarities}")  # Similarities: [0.53341305 0.53451955]
+assert similarities[1] > similarities[0]
+
 
-# Print results
-print("Similarities:")
-print(f"Document 1: {similarities[0]:.4f}")
-print(f"Document 2: {similarities[1]:.4f}")
+# change up the instruction to the opposite, to see that it works
+instruction = "A relevant document would describe any university in Baltimore. I am interested in any university that was the first American university. Think carefully about these conditions when determining relevance."
+input_text = f"query: {query.strip()} {instruction.strip()}".strip()
+query_embedding = model.encode([input_text])
+similarities = np.dot(query_embedding, doc_embeddings.T)[0]
+print(f"Similarities: {similarities}")  # Similarities: [0.60182875 0.5874183 ]
+assert similarities[0] > similarities[1]
 ```
 
 # Training
@@ -124,7 +188,7 @@ deepspeed --include localhost:0,1,2,3 --master_port "60002" --module tevatron.re
 ```
 
 # License
-This model is released under the Apache-2 license, following the terms of service of the Llama licence. This model was used for research efforts and is not used in any production systems at Samaya AI.
+This model was used for research efforts and is not used in any production systems at Samaya AI. Usage must follow the license of the base model as well, as this is a LoRA fine-tune.
 
 # Citation
 
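Beyond the two-document comparison in the updated example, the same embeddings can rank any number of passages. The snippet below is a minimal sketch and not part of the committed README: `rank_passages` is an illustrative helper name, and it assumes the `Promptriever` class plus the `query: `/`passage: ` prefixing convention shown in the diff above.

```python
import numpy as np

def rank_passages(model, input_text, passages):
    # `model` is a Promptriever instance; `input_text` is the "query: ..." string
    # and every passage is already prefixed with "passage: " as in the example.
    query_embedding = model.encode([input_text])            # (1, dim), L2-normalized
    doc_embeddings = model.encode(list(passages))           # (n, dim), L2-normalized
    scores = np.dot(query_embedding, doc_embeddings.T)[0]   # cosine similarities
    order = np.argsort(-scores)                             # highest score first
    return [(int(i), float(scores[i])) for i in order]

# e.g. rank_passages(model, input_text, [doc1, doc2]) ranks doc2 first for the
# first instruction above, matching the asserts in the README example.
```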