diff --git a/rag/llm/rpc_server.py b/rag/llm/rpc_server.py
deleted file mode 100644
index df3d6d310df..00000000000
--- a/rag/llm/rpc_server.py
+++ /dev/null
@@ -1,173 +0,0 @@
-#
-#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-
-import argparse
-import pickle
-import random
-import time
-from copy import deepcopy
-from multiprocessing.connection import Listener
-from threading import Thread
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
-from api.utils.log_utils import logger
-
-
-def torch_gc():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            # with torch.cuda.device(DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-        elif torch.backends.mps.is_available():
-            try:
-                from torch.mps import empty_cache
-                empty_cache()
-            except Exception:
-                pass
-    except Exception:
-        pass
-
-
-class RPCHandler:
-    def __init__(self):
-        self._functions = {}
-
-    def register_function(self, func):
-        self._functions[func.__name__] = func
-
-    def handle_connection(self, connection):
-        try:
-            while True:
-                # Receive a message
-                func_name, args, kwargs = pickle.loads(connection.recv())
-                # Run the RPC and send a response
-                try:
-                    r = self._functions[func_name](*args, **kwargs)
-                    connection.send(pickle.dumps(r))
-                except Exception as e:
-                    connection.send(pickle.dumps(e))
-        except EOFError:
-            pass
-
-
-def rpc_server(hdlr, address, authkey):
-    sock = Listener(address, authkey=authkey)
-    while True:
-        try:
-            client = sock.accept()
-            t = Thread(target=hdlr.handle_connection, args=(client,))
-            t.daemon = True
-            t.start()
-        except Exception:
-            logger.exception("rpc_server got exception")
-
-
-models = []
-tokenizer = None
-
-
-def chat(messages, gen_conf):
-    global tokenizer
-    model = Model()
-    try:
-        torch_gc()
-        conf = {
-            "max_new_tokens": int(
-                gen_conf.get(
-                    "max_tokens", 256)), "temperature": float(
-                gen_conf.get(
-                    "temperature", 0.1))}
-        logger.debug(f"messages: {messages}, conf: {conf}")
-        text = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-        generated_ids = model.generate(
-            model_inputs.input_ids,
-            **conf
-        )
-        generated_ids = [
-            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-        ]
-
-        return tokenizer.batch_decode(
-            generated_ids, skip_special_tokens=True)[0]
-    except Exception as e:
-        logger.exception("chat got exception")
-        return str(e)
-
-
-def chat_streamly(messages, gen_conf):
-    global tokenizer
-    model = Model()
-    try:
-        torch_gc()
-        conf = deepcopy(gen_conf)
-        logger.debug(f"messages: {messages}, conf: {conf}")
-        text = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-        streamer = TextStreamer(tokenizer)
-        conf["inputs"] = model_inputs.input_ids
-        conf["streamer"] = streamer
-        conf["max_new_tokens"] = conf["max_tokens"]
-        del conf["max_tokens"]
-        thread = Thread(target=model.generate, kwargs=conf)
-        thread.start()
-        for _, new_text in enumerate(streamer):
-            yield new_text
-    except Exception as e:
-        yield "**ERROR**: " + str(e)
-
-
-def Model():
-    global models
-    random.seed(time.time())
-    return random.choice(models)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, help="Model name")
-    parser.add_argument(
-        "--port",
-        default=7860,
-        type=int,
-        help="RPC serving port")
-    args = parser.parse_args()
-
-    handler = RPCHandler()
-    handler.register_function(chat)
-    handler.register_function(chat_streamly)
-
-    models = []
-    for _ in range(1):
-        m = AutoModelForCausalLM.from_pretrained(args.model_name,
-                                                 device_map="auto",
-                                                 torch_dtype='auto')
-        models.append(m)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-
-    # Run the server
-    rpc_server(handler, ('0.0.0.0', args.port),
-               authkey=b'infiniflow-token4kevinhu')
diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py
index 1b2bbe47104..f022ecdb904 100644
--- a/rag/utils/infinity_conn.py
+++ b/rag/utils/infinity_conn.py
@@ -70,7 +70,7 @@ def health(self) -> dict:
         TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
         """
         inf_conn = self.connPool.get_conn()
-        res = infinity.show_current_node()
+        res = inf_conn.show_current_node()
         self.connPool.release_conn(inf_conn)
         color = "green" if res.error_code == 0 else "red"
         res2 = {