abancp committed
Commit 419d496 · verified · 1 Parent(s): 9cc709f

Update inference_fine_tune.py

Files changed (1)
  1. inference_fine_tune.py +7 -6
inference_fine_tune.py CHANGED
@@ -32,19 +32,18 @@ state = torch.load(model_path,map_location=torch.device('cpu'))
 model.load_state_dict(state['model_state_dict'])
 
 def generate_response(prompt: str):
-    print("Prompt:", prompt)
     input_tokens = tokenizer.encode(prompt).ids
     input_tokens = [user_token_id] + input_tokens + [ai_token_id]
 
     if len(input_tokens) > config['seq_len']:
-        print(f"Exceeding max length of input: {config['seq_len']}")
+        yield gr.Textbox.update(value="Prompt too long.")
         return
 
-    input_tokens = torch.tensor(input_tokens).unsqueeze(0).to(device) # (1, seq_len)
-
+    input_tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)
     temperature = 0.7
     top_k = 50
     i = 0
+    generated_text = ""
 
     while input_tokens.shape[1] < 2000:
         out = model.decode(input_tokens)
@@ -55,8 +54,10 @@ def generate_response(prompt: str):
         next_token = torch.multinomial(probs, num_samples=1)
         next_token = top_k_indices.gather(-1, next_token)
 
-        decoded_word = tokenizer.decode([next_token.item()])
-        yield decoded_word # Streaming output token-by-token
+        word = tokenizer.decode([next_token.item()])
+        generated_text += word
+
+        yield gr.Textbox.update(value=generated_text)
 
         input_tokens = torch.cat([input_tokens, next_token], dim=1)
         if input_tokens.shape[1] > config['seq_len']:
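
For reference, the second hunk sits inside a sampling step that draws from a temperature-scaled, top-k-truncated distribution. A minimal sketch of that pattern, assuming standard PyTorch calls (the helper name sample_next_token and the exact shapes are illustrative, not taken from the file):

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.7, top_k: int = 50) -> torch.Tensor:
    # logits: (1, vocab_size) scores for the next position
    logits = logits / temperature                                    # <1 sharpens, >1 flattens the distribution
    top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)  # keep the k most likely tokens
    probs = torch.softmax(top_k_logits, dim=-1)                      # probabilities over the k candidates
    next_token = torch.multinomial(probs, num_samples=1)             # index *within* the top-k slice
    return top_k_indices.gather(-1, next_token)                      # map back to vocabulary ids, shape (1, 1)

This matches the two context lines in the hunk: multinomial picks a position within the top-k slice, and gather translates it back into a real token id.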
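
The switch from yielding raw decoded words to yielding gr.Textbox.update(value=generated_text) implies the generator now feeds a Gradio UI, streaming the accumulated text instead of individual tokens. A minimal wiring sketch, assuming Gradio 3.x (where gr.Textbox.update exists; in Gradio 4+ a generator would yield plain strings instead) and hypothetical component names:

import gradio as gr

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Response")
    generate_btn = gr.Button("Generate")
    # Each yielded update replaces output_box's value, so the growing
    # generated_text streams into the UI as tokens are decoded.
    generate_btn.click(fn=generate_response, inputs=prompt_box, outputs=output_box)

demo.queue().launch()  # queue() is required for generator (streaming) handlers in Gradio 3.x

Yielding the accumulated string rather than single words also means the frontend never has to stitch tokens together itself; each update is a complete snapshot of the response so far.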