import pprint
import time
import urllib.parse
from threading import Lock
from typing import List, Optional, Union

import absl.logging
import gradio as gr
import numpy as np
import requests
import uvicorn
from fastapi import FastAPI
from ml_collections import ConfigDict
from pydantic import BaseModel
from requests.exceptions import Timeout, ConnectionError
from tqdm import tqdm, trange


class InferenceRequest(BaseModel):
    prefix_text: Optional[List[str]] = None
    text: Optional[List[str]] = None
    until: Optional[Union[List[str], List[List[str]]]] = None
    temperature: Optional[float] = None


class ChatRequest(BaseModel):
    prompt: str
    context: str = ''
    temperature: Optional[float] = None


class LMServer(object):
    """ HTTP server for serving language models. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.host = '0.0.0.0'
        config.port = 5007
        config.batch_size = 1
        config.logging = False
        config.pre_compile = 'loglikelihood'
        config.default_temperature = 1.0
        config.greedy_until_max_length = 5000
        config.prepend_to_prefix = ''
        config.append_to_prefix = ''
        config.prepend_to_text = ''
        config.append_to_text = ''
        config.chat_prepend_text = ''
        config.chat_user_prefix = ''
        config.chat_user_suffix = ''
        config.chat_lm_prefix = ''
        config.chat_lm_suffix = ''
        config.notes = ''

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config):
        self.config = self.get_default_config(config)
        # Serialize inference requests: the model serves one batch at a time.
        self.lock = Lock()
        self.app = FastAPI()
        self.app.post('/loglikelihood')(self.serve_loglikelihood)
        self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
        self.app.post('/generate')(self.serve_generate)
        self.app.post('/greedy-until')(self.serve_greedy_until)
        self.app.post('/chat')(self.serve_chat)
        self.app.get('/ready')(self.serve_ready)
        self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')

    # Subclasses implement the actual model inference by overriding the
    # following static methods.
    @staticmethod
    def loglikelihood(prefix_text, text):
        raise NotImplementedError()

    @staticmethod
    def loglikelihood_rolling(text):
        raise NotImplementedError()

    @staticmethod
    def generate(text, temperature):
        raise NotImplementedError()

    @staticmethod
    def greedy_until(prefix_text, until, max_length):
        raise NotImplementedError()

    @staticmethod
    def to_list(x):
        if isinstance(x, np.ndarray):
            return x.tolist()
        return x

    def serve_ready(self):
        return 'Ready!\n'

    def serve_loglikelihood(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    '\n========= Serving Log Likelihood Request ========= \n'
                    + pprint.pformat(data) + '\n'
                )

            if data.prefix_text is None:
                data.prefix_text = ['' for _ in data.text]

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]
            text = [
                self.config.prepend_to_text + t + self.config.append_to_text
                for t in data.text
            ]

            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), self.config.batch_size, ncols=0):
                batch_prefix_text = prefix_text[i:i + self.config.batch_size]
                batch_text = text[i:i + self.config.batch_size]
                batch_size = len(batch_text)

                # Pad the last batch with dummy examples so the model always
                # sees a full batch; padded outputs are discarded below.
                if batch_size < self.config.batch_size:
                    extra = self.config.batch_size - batch_size
                    batch_prefix_text.extend(['a' for _ in range(extra)])
                    batch_text.extend(['a' for _ in range(extra)])

                batch_log_likelihood, batch_is_greedy = self.loglikelihood(
                    batch_prefix_text, batch_text
                )
                batch_log_likelihood = self.to_list(batch_log_likelihood)
                batch_is_greedy = self.to_list(batch_is_greedy)
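    # Example request body for POST /loglikelihood (illustrative values, not
    # taken from the EasyLM docs):
    #   {"prefix_text": ["The capital of France is"], "text": [" Paris"]}
    # The response echoes `prefix_text` and `text` and adds one
    # `log_likelihood` float and one `is_greedy` boolean per example.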
                log_likelihood.extend(batch_log_likelihood[:batch_size])
                is_greedy.extend(batch_is_greedy[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'text': data.text,
                'log_likelihood': log_likelihood,
                'is_greedy': is_greedy,
            }
            if self.config.logging:
                absl.logging.info(
                    '\n========= Output ========= \n'
                    + pprint.pformat(output) + '\n'
                )

        return output

    def serve_loglikelihood_rolling(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    '\n========= Serving Log Likelihood Rolling Request ========= \n'
                    + pprint.pformat(data) + '\n'
                )

            text = [
                self.config.prepend_to_text + t + self.config.append_to_text
                for t in data.text
            ]

            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), self.config.batch_size, ncols=0):
                batch_text = text[i:i + self.config.batch_size]
                batch_size = len(batch_text)

                # Pad the last batch to the configured batch size; padded
                # outputs are discarded below.
                if batch_size < self.config.batch_size:
                    extra = self.config.batch_size - batch_size
                    batch_text.extend(['a' for _ in range(extra)])

                batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
                    batch_text
                )
                batch_log_likelihood = self.to_list(batch_log_likelihood)
                batch_is_greedy = self.to_list(batch_is_greedy)
                log_likelihood.extend(batch_log_likelihood[:batch_size])
                is_greedy.extend(batch_is_greedy[:batch_size])

            output = {
                'text': data.text,
                'log_likelihood': log_likelihood,
                'is_greedy': is_greedy,
            }
            if self.config.logging:
                absl.logging.info(
                    '\n========= Output ========= \n'
                    + pprint.pformat(output) + '\n'
                )

        return output

    def serve_generate(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    '\n========= Serving Generate Request ========= \n'
                    + pprint.pformat(data) + '\n'
                )

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]

            if data.temperature is None:
                data.temperature = self.config.default_temperature

            output_text = []
            for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
                batch_prefix_text = prefix_text[i:i + self.config.batch_size]
                batch_size = len(batch_prefix_text)

                if batch_size < self.config.batch_size:
                    extra = self.config.batch_size - batch_size
                    batch_prefix_text.extend(['a' for _ in range(extra)])

                batch_output_text = self.generate(
                    batch_prefix_text,
                    temperature=data.temperature,
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'output_text': output_text,
                'temperature': data.temperature,
            }
            if self.config.logging:
                absl.logging.info(
                    '\n========= Output ========= \n'
                    + pprint.pformat(output) + '\n'
                )

        return output

    def serve_greedy_until(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    '\n========= Serving Greedy Until Request ========= \n'
                    + pprint.pformat(data) + '\n'
                )

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]
            until = data.until
            max_length = self.config.greedy_until_max_length

            output_text = []
            for i in range(0, len(prefix_text), self.config.batch_size):
                batch_prefix_text = prefix_text[i:i + self.config.batch_size]
                batch_until = until[i:i + self.config.batch_size]
                batch_size = len(batch_prefix_text)

                batch_output_text = self.greedy_until(
                    batch_prefix_text, batch_until, max_length
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                'prefix_text': data.prefix_text,
                'until': data.until,
                'max_length': max_length,
                'output_text': output_text,
            }
            if self.config.logging:
                absl.logging.info(
                    '\n========= Output ========= \n'
                    + pprint.pformat(output) + '\n'
                )

        return output
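    # Example request body for POST /greedy-until (illustrative). Each entry
    # of `until` may be a single stop string or a list of stop strings:
    #   {"prefix_text": ["Q: What is 2 + 2? A:"], "until": [["\n", "Q:"]]}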
    def process_chat(self, prompt, context, temperature):
        # Wrap the user prompt with the configured chat delimiters, generate
        # a response, and return it together with the updated context.
        context = (
            context + self.config.chat_user_prefix
            + prompt + self.config.chat_user_suffix
            + self.config.chat_lm_prefix
        )
        response = self.generate(
            [self.config.chat_prepend_text + context],
            temperature=float(temperature),
        )[0]
        context = context + response + self.config.chat_lm_suffix
        return response, context

    def serve_chat(self, data: ChatRequest):
        if data.temperature is None:
            data.temperature = self.config.default_temperature
        response, context = self.process_chat(
            data.prompt, data.context,
            temperature=data.temperature,
        )
        return {
            'response': response,
            'context': context,
            'temperature': data.temperature,
        }

    def create_chat_app(self):
        with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
            gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
            gr.Markdown(self.config.notes)
            chatbot = gr.Chatbot(label='Chat history')
            msg = gr.Textbox(
                placeholder='Type your message here...',
                show_label=False
            )
            with gr.Row():
                send = gr.Button('Send')
                regenerate = gr.Button('Regenerate', interactive=False)
                clear = gr.Button('Reset')

            temp_slider = gr.Slider(
                label='Temperature', minimum=0, maximum=2.0,
                value=self.config.default_temperature
            )

            # context_state holds two contexts: the one before the latest
            # exchange (slot 0, used by the regenerate button) and the
            # current one (slot 1).
            context_state = gr.State(['', ''])

            def user_fn(user_message, history, context):
                return {
                    msg: gr.update(value='', interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                    chatbot: history + [[user_message, None]],
                    context_state: [context[1], context[1]],
                }

            def model_fn(history, context, temperature):
                history[-1][1], new_context = self.process_chat(
                    history[-1][0], context[0], temperature
                )
                return {
                    msg: gr.update(value='', interactive=True),
                    clear: gr.update(interactive=True),
                    send: gr.update(interactive=True),
                    chatbot: history,
                    context_state: [context[0], new_context],
                    regenerate: gr.update(interactive=True),
                }

            def regenerate_fn():
                return {
                    msg: gr.update(value='', interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                }

            def clear_fn():
                return {
                    chatbot: None,
                    msg: '',
                    context_state: ['', ''],
                    regenerate: gr.update(interactive=False),
                }

            msg.submit(
                user_fn,
                inputs=[msg, chatbot, context_state],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )
            send.click(
                user_fn,
                inputs=[msg, chatbot, context_state],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )
            regenerate.click(
                regenerate_fn,
                inputs=None,
                outputs=[msg, clear, send, regenerate],
                queue=False
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True
            )
            clear.click(
                clear_fn,
                inputs=None,
                outputs=[chatbot, msg, context_state, regenerate],
                queue=False
            )

        gradio_chatbot.queue(concurrency_count=1)
        return gradio_chatbot

    def run(self):
        if self.config.pre_compile != '':
            # Warm up (e.g. trigger JIT compilation of) the selected
            # inference functions with dummy inputs before serving traffic.
            if self.config.pre_compile == 'all':
                pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
            else:
                pre_compile = self.config.pre_compile.split(',')

            pre_compile_data = ['a' for _ in range(self.config.batch_size)]
            for task in pre_compile:
                if task == 'loglikelihood':
                    self.loglikelihood(pre_compile_data, pre_compile_data)
                    self.loglikelihood_rolling(pre_compile_data)
                elif task == 'generate':
                    self.generate(pre_compile_data, 1.0)
                elif task == 'greedy_until':
                    self.greedy_until(
                        pre_compile_data, pre_compile_data,
                        self.config.greedy_until_max_length
                    )
                elif task == 'chat':
                    self.process_chat('a', 'a', 1.0)
                else:
                    raise ValueError(f'Invalid precompile task: {task}!')

        uvicorn.run(self.app, host=self.config.host, port=self.config.port)
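
# A minimal LMServer subclass for illustration (`EchoLMServer` is
# hypothetical, not part of EasyLM): it returns dummy outputs so the HTTP
# endpoints can be exercised without a real model. A real server overrides
# the same four static methods with actual model inference.
class EchoLMServer(LMServer):

    @staticmethod
    def loglikelihood(prefix_text, text):
        # One (log likelihood, is_greedy) pair per input example.
        return [0.0 for _ in text], [False for _ in text]

    @staticmethod
    def loglikelihood_rolling(text):
        return [0.0 for _ in text], [False for _ in text]

    @staticmethod
    def generate(text, temperature):
        # Echo the prompt back as the "generated" continuation.
        return [t for t in text]

    @staticmethod
    def greedy_until(prefix_text, until, max_length):
        return ['' for _ in prefix_text]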

class LMClient(object):
    """ A simple client for the LM server. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.url = 'http://localhost:5007'
        config.batch_size = 1
        config.wait_for_ready = True
        config.dummy = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config=None):
        self.config = self.get_default_config(config)
        if self.config.wait_for_ready:
            self.wait_for_ready()

    def wait_for_ready(self):
        if self.config.dummy:
            return
        while True:
            try:
                requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
                return
            except (Timeout, ConnectionError):
                # The server is not up yet; retry after a short delay.
                time.sleep(10)

    @staticmethod
    def batched(iterator, batch_size):
        # Group an iterable into lists of at most batch_size elements.
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def loglikelihood(self, prefix, text):
        prefix, text = list(prefix), list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]

        log_likelihood = []
        is_greedy = []
        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(text, self.config.batch_size)
        ))
        for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
            response = requests.post(
                urllib.parse.urljoin(self.config.url, 'loglikelihood'),
                json={'prefix_text': batch_prefix, 'text': batch_text}
            ).json()
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])

        return log_likelihood, is_greedy

    def loglikelihood_rolling(self, text):
        text = list(text)
        if self.config.dummy:
            return [-1.0 for _ in text], [False for _ in text]

        log_likelihood = []
        is_greedy = []
        batched_iterator = list(self.batched(text, self.config.batch_size))
        for batch_text in tqdm(batched_iterator, ncols=0):
            response = requests.post(
                urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
                json={'text': batch_text}
            ).json()
            log_likelihood.extend(response['log_likelihood'])
            is_greedy.extend(response['is_greedy'])

        return log_likelihood, is_greedy

    def greedy_until(self, prefix, until):
        prefix, until = list(prefix), list(until)
        if self.config.dummy:
            results = []
            for u in until:
                if isinstance(u, str):
                    results.append('dummy text ' + u)
                else:
                    results.append('dummy text ' + u[0])
            return results

        batched_iterator = list(zip(
            self.batched(prefix, self.config.batch_size),
            self.batched(until, self.config.batch_size),
        ))
        output_text = []
        for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
            response = requests.post(
                urllib.parse.urljoin(self.config.url, 'greedy-until'),
                json={'prefix_text': batch_prefix, 'until': batch_until}
            ).json()
            output_text.extend(response['output_text'])

        return output_text

    def generate(self, prefix, temperature=None):
        prefix = list(prefix)
        if self.config.dummy:
            return ['' for _ in prefix]

        output_text = []
        batched_iterator = list(self.batched(prefix, self.config.batch_size))
        for batch_prefix in tqdm(batched_iterator, ncols=0):
            response = requests.post(
                urllib.parse.urljoin(self.config.url, 'generate'),
                json={
                    'prefix_text': batch_prefix,
                    'temperature': temperature,
                }
            ).json()
            output_text.extend(response['output_text'])

        return output_text

    def chat(self, prompt, context, temperature=None):
        if self.config.dummy:
            # Keep the return shape consistent with the non-dummy path.
            return '', context
        response = requests.post(
            urllib.parse.urljoin(self.config.url, 'chat'),
            json={
                'prompt': prompt,
                'context': context,
                'temperature': temperature,
            }
        ).json()
        return response['response'], response['context']
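
if __name__ == '__main__':
    # Example round trip (a sketch; it assumes a server is already listening
    # at the default URL http://localhost:5007, e.g. an LMServer subclass
    # started in another process):
    client = LMClient({'batch_size': 2})
    log_likelihood, is_greedy = client.loglikelihood(
        ['The capital of France is'], [' Paris']
    )
    print(log_likelihood, is_greedy)
    print(client.generate(['Once upon a time'], temperature=0.7))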