import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import argparse
import json
from flask import Flask, request, jsonify, Response
from rkllm import *

app = Flask(__name__, static_url_path='', static_folder='static')

@app.route('/')
def root():
    return app.send_static_file('index.html')

PROMPT_TEXT_PREFIX = "<|im_start|>system\nYou are a helpful assistant. You only give short answers.<|im_end|>\n<|im_start|>user\n"
PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant\n"
MSG_START_TOKEN = "<|im_start|>"
MSG_END_TOKEN = "<|im_end|>"

def msg_to_prompt(user, msg):
    return f'{MSG_START_TOKEN}{user}\n{msg}{MSG_END_TOKEN}\n'

def msgs_to_prompt(msgs: list[dict]):
    system = msgs[0] if msgs and msgs[0]['role'] == 'system' else ""
    if not system:
        msgs.insert(0, {'role': 'system', 'content': "You are a helpful assistant. You only give short but complete answers."})
    return (''.join(msg_to_prompt(msg['role'], msg['content']) for msg in msgs)
            + f'{MSG_START_TOKEN}assistant\n')
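# For illustration only: a history such as
#   [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello!'}]
# is flattened by msgs_to_prompt into a ChatML-style prompt (a system turn is injected if missing):
#   <|im_start|>system
#   You are a helpful assistant. You only give short but complete answers.<|im_end|>
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
#   Hello!<|im_end|>
#   <|im_start|>assistant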
# Create a lock to control multi-user access to the server.
lock = threading.Lock()
# Flag indicating whether the server is currently busy serving a request.
is_blocking = False

# Global variables that carry the callback output from the RKLLM runtime to the Flask handlers.
global_text = []
global_state = -1
split_byte_data = b""  # Used to store incomplete (split) UTF-8 byte data between callbacks
global_abort = False

# Define the callback function
def callback_impl(result, userdata, state):
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        '''
        When the GET_LAST_HIDDEN_LAYER function is used, the callback returns the memory pointer
        last_hidden_layer, the number of tokens num_tokens, and the hidden layer size embd_size.
        With these three parameters, the data can be read out of last_hidden_layer.
        Note: the data must be retrieved during the current callback; otherwise the pointer is
        released by the next callback.
        '''
        if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0:
            data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
            print(f"data_size: {data_size}")
            global_text.append(f"data_size: {data_size}\n")
            output_path = os.getcwd() + "/last_hidden_layer.bin"
            with open(output_path, "wb") as outFile:
                data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
                float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
                float_array = float_array_type.from_address(ctypes.addressof(data.contents))
                outFile.write(bytearray(float_array))
                print(f"Data saved to {output_path} successfully!")
                global_text.append(f"Data saved to {output_path} successfully!")
        else:
            print("Invalid hidden layer data.")
            global_text.append("Invalid hidden layer data.")
        global_state = state
        time.sleep(0.05)  # Delay for 0.05 seconds to wait for the output result
        sys.stdout.flush()
    else:
        # Save the output token text and the RKLLM running state
        global_state = state
        # Check whether the current byte data is a complete UTF-8 sequence; if not, keep it for the next callback
        try:
            global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
            print((split_byte_data + result.contents.text).decode('utf-8'), end='')
            split_byte_data = b""
        except:
            split_byte_data += result.contents.text
        sys.stdout.flush()

# Connect the callback function between the Python side and the C++ side
callback = callback_type(callback_impl)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
    parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
    args = parser.parse_args()

    if not os.path.exists(args.rkllm_model_path):
        print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
        sys.stdout.flush()
        exit()

    if args.target_platform not in ["rk3588", "rk3576"]:
        print("Error: Please specify the correct target platform: rk3588/rk3576.")
        sys.stdout.flush()
        exit()

    if args.lora_model_path:
        if not os.path.exists(args.lora_model_path):
            print("Error: Please provide the correct lora_model path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()

    if args.prompt_cache_path:
        if not os.path.exists(args.prompt_cache_path):
            print("Error: Please provide the correct prompt_cache_file path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()

    # Pin the NPU frequency unless the userspace governor is already active.
    command = "[ \"$(cat /sys/class/devfreq/fdab0000.npu/governor)\" = userspace ] || sudo bash fix_freq_{}.sh".format(args.target_platform)
    subprocess.run(command, shell=True)

    # Raise the open-file-descriptor limit.
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize the RKLLM model.
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    rkllm_model = RKLLM(model_path, callback, args.lora_model_path, args.prompt_cache_path)
    print("RKLLM Model has been initialized successfully!")
    print("==============================")
    sys.stdout.flush()

    @app.route('/rkllm_abort', methods=['POST'])
    def abort():
        global global_abort
        global_abort = True
        code = rkllm_model.abort()
        return {"code": code}, (200 if code is None else 500)
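    # Illustrative request shape for the /rkllm_chat endpoint below (field names are taken
    # from the handler; the values are examples only):
    #   POST /rkllm_chat
    #   {"messages": [{"role": "user", "content": "What is an NPU?"}], "stream": false}
    # A non-streaming call returns one choice whose "message.content" holds the full reply;
    # with "stream": true the response is a text/plain stream of JSON chunks carrying "delta" entries.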
    # Receive chat requests from the client (OpenAI-style "messages" payload).
    @app.route('/rkllm_chat', methods=['POST'])
    def receive_message():
        # Link global variables to retrieve the output information from the callback function.
        global global_text, global_state, global_abort
        global is_blocking

        # If the server is already handling a request, reject the new one.
        if is_blocking or global_state == 0:
            return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503

        lock.acquire()
        try:
            # Set the server to a blocking state.
            is_blocking = True

            # Get JSON data from the POST request.
            data = request.json
            if data and 'messages' in data:
                # Reset global variables.
                global_text = []
                global_state = -1

                # Define the structure for the returned response.
                rkllm_responses = {
                    "id": "rkllm_chat",
                    "object": "rkllm_chat",
                    "created": None,
                    "choices": [],
                    "usage": {
                        "prompt_tokens": None,
                        "completion_tokens": None,
                        "total_tokens": None
                    }
                }

                if "stream" not in data or data["stream"] == False:
                    # Non-streaming request: collect the full answer before responding.
                    messages = data['messages']
                    print("Received messages:", messages)
                    input_prompt = msgs_to_prompt(messages)
                    print("generated prompt:", input_prompt)

                    rkllm_output = ""

                    # Create a thread for model inference.
                    model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
                    model_thread.start()

                    # Wait for the model to finish running and periodically check the inference thread.
                    model_thread_finished = False
                    while not model_thread_finished:
                        while len(global_text) > 0:
                            rkllm_output += global_text.pop(0)
                            time.sleep(0.01)
                        if global_abort:
                            global_abort = False
                        model_thread.join(timeout=0.005)
                        model_thread_finished = not model_thread.is_alive()

                    rkllm_responses["choices"].append(
                        {
                            "index": len(messages),
                            "message": {
                                "role": "assistant",
                                "content": rkllm_output,
                            },
                            "logprobs": None,
                            "finish_reason": "stop"
                        }
                    )
                    return jsonify(rkllm_responses), 200
                else:
                    # Streaming request: yield partial results as they are produced.
                    messages = data['messages']
                    print("Received messages:", messages)
                    input_prompt = msgs_to_prompt(messages)
                    print("generated prompt:", input_prompt)

                    rkllm_output = ""

                    def generate():
                        model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
                        model_thread.start()

                        model_thread_finished = False
                        while not model_thread_finished:
                            while len(global_text) > 0:
                                rkllm_output = global_text.pop(0)

                                rkllm_responses["choices"].append(
                                    {
                                        "index": len(messages),
                                        "delta": {
                                            "role": "assistant",
                                            "content": rkllm_output,
                                        },
                                        "logprobs": None,
                                        "finish_reason": "stop" if global_state == 1 else None,
                                    }
                                )
                                yield f"{json.dumps(rkllm_responses)}\n\n"

                            model_thread.join(timeout=0.005)
                            model_thread_finished = not model_thread.is_alive()

                    return Response(generate(), content_type='text/plain')
            else:
                return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400
        finally:
            lock.release()
            is_blocking = False

    # Start the Flask application.
    app.run(host='0.0.0.0', port=8080, threaded=True, use_reloader=False, debug=True, use_debugger=False)  # maybe no debug?

    print("====================")
    print("RKLLM model inference completed, releasing RKLLM model resources...")
    rkllm_model.release()
    print("====================")
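# Example launch command (the script name and model path below are placeholders, not taken from this file):
#   python3 flask_server.py --rkllm_model_path /data/model.rkllm --target_platform rk3588
# and an example client call once the server is listening on port 8080:
#   curl -X POST http://<board-ip>:8080/rkllm_chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'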