rkllm-server/flask_server.py

268 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import argparse
import json
from flask import Flask, request, jsonify, Response
from rkllm import *
# Flask application; static assets live in ./static and are served from the
# site root (empty static URL prefix).
app = Flask(
    __name__,
    static_folder='static',
    static_url_path='',
)


def root():
    """Serve the bundled web UI page ``static/index.html`` for ``GET /``."""
    return app.send_static_file('index.html')


app.add_url_rule('/', view_func=root)
# Fixed single-turn prompt fragments (system + user header, assistant header).
# msgs_to_prompt below builds its prompt from the message tokens instead.
PROMPT_TEXT_PREFIX = "<|im_start|>system\nYou are a helpful assistant. You only give short answers.<|im_end|>\n<|im_start|>user\n"
PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant\n"
# Chat-template delimiters. These work for Qwen, MiniCPM and DeepSeek, but not for ChatGLM3.
MSG_START_TOKEN = "<|im_start|>"
MSG_END_TOKEN = "<|im_end|>"
# System prompt used when a conversation does not start with a system message.
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. You only give short but complete answers."


def msg_to_prompt(user, msg):
    """Render one message as ``<|im_start|>{user}\\n{msg}<|im_end|>\\n``.

    Args:
        user: the role name (e.g. ``'system'``, ``'user'``, ``'assistant'``).
        msg: the message text.
    """
    return f'{MSG_START_TOKEN}{user}\n{msg}{MSG_END_TOKEN}\n'


def msgs_to_prompt(msgs: list[dict]) -> str:
    """Render a list of ``{'role': ..., 'content': ...}`` messages into one prompt.

    If the list is empty or its first message is not a system message,
    a system message with ``DEFAULT_SYSTEM_PROMPT`` is prepended.  The prompt
    ends with an open assistant turn so the model continues as the assistant.

    The caller's list is left unchanged; a new list is built when the
    system message has to be added.

    Raises:
        KeyError: if a message lacks ``'role'`` or ``'content'``.
    """
    if not (msgs and msgs[0]['role'] == 'system'):
        msgs = [{'role': 'system', 'content': DEFAULT_SYSTEM_PROMPT}, *msgs]
    body = ''.join(msg_to_prompt(m['role'], m['content']) for m in msgs)
    return f'{body}{MSG_START_TOKEN}assistant\n'
# Serialises chat requests: the chat handler holds this while a request runs.
lock = threading.Lock()
# Mirrors "a chat request currently holds `lock`"; lets the handler answer
# "busy" without waiting on the lock.
is_blocking = False

# State shared between the RKLLM callback (writer) and the chat handler (reader).
global_text: list[str] = []  # decoded output pieces not yet consumed by the handler
global_state = -1            # last state reported by the callback; reset to -1 per request
split_byte_data = b""        # trailing bytes of an incomplete UTF-8 character
global_abort = False         # set by the abort endpoint, cleared by the chat handler
# Define the callback function
def callback_impl(result, userdata, state):
global global_text, global_state, split_byte_data
if state == LLMCallState.RKLLM_RUN_FINISH:
global_state = state
print("\n")
sys.stdout.flush()
elif state == LLMCallState.RKLLM_RUN_ERROR:
global_state = state
print("run error")
sys.stdout.flush()
elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
'''
If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size.
With these three parameters, you can retrieve the data from last_hidden_layer.
Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback.
'''
if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0:
data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
print(f"data_size: {data_size}")
global_text.append(f"data_size: {data_size}\n")
output_path = os.getcwd() + "/last_hidden_layer.bin"
with open(output_path, "wb") as outFile:
data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
float_array = float_array_type.from_address(ctypes.addressof(data.contents))
outFile.write(bytearray(float_array))
print(f"Data saved to {output_path} successfully!")
global_text.append(f"Data saved to {output_path} successfully!")
else:
print("Invalid hidden layer data.")
global_text.append("Invalid hidden layer data.")
global_state = state
time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result
sys.stdout.flush()
else:
# Save the output token text and the RKLLM running state
global_state = state
# Monitor if the current byte data is complete; if incomplete, record it for later parsing
try:
global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
print((split_byte_data + result.contents.text).decode('utf-8'), end='')
split_byte_data = bytes(b"")
except:
split_byte_data += result.contents.text
sys.stdout.flush()
# Connect the callback function between the Python side and the C++ side.
# `callback_type` comes from the `rkllm` star import (presumably a ctypes
# CFUNCTYPE — confirm).  If so, this module-level name is what keeps the
# wrapped callback alive while the runtime may still call it.
callback = callback_type(callback_impl)
def _fail(message: str):
    """Print ``message``, flush stdout and exit with status 1.

    Used for start-up configuration errors.  A bare ``exit()`` reported
    success (status 0) and depends on the ``site`` module's builtin.
    """
    print(message)
    sys.stdout.flush()
    sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
    parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
    args = parser.parse_args()

    if not os.path.exists(args.rkllm_model_path):
        _fail("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
    if args.target_platform not in ("rk3588", "rk3576"):
        _fail("Error: Please specify the correct target platform: rk3588/rk3576.")
    if args.lora_model_path and not os.path.exists(args.lora_model_path):
        _fail("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.")
    if args.prompt_cache_path and not os.path.exists(args.prompt_cache_path):
        _fail("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.")

    # Fix frequency: run fix_freq_<platform>.sh unless the NPU devfreq governor
    # is already "userspace".  target_platform was checked against a fixed
    # allow-list above, so interpolating it into this shell string is safe.
    command = "[ \"$(cat /sys/class/devfreq/fdab0000.npu/governor)\" = userspace ] || sudo bash fix_freq_{}.sh".format(args.target_platform)
    subprocess.run(command, shell=True)

    # Set resource limit: raise the soft and hard open-file limits to 102400.
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize RKLLM model
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    rkllm_model = RKLLM(model_path, callback, args.lora_model_path, args.prompt_cache_path)
    print("RKLLM Model has been initialized successfully")
    print("==============================")
    sys.stdout.flush()

    @app.route('/rkllm_abort', methods=['POST'])
    def abort():
        """Set the abort flag, ask the runtime to abort, and return its status code."""
        global global_abort
        global_abort = True
        code = rkllm_model.abort()
        return {"code": code}, 200

    @app.route('/rkllm_chat', methods=['POST'])
    def receive_message():
        """Run one chat completion for a JSON body ``{"messages": [...], "stream": bool}``.

        Non-streaming requests return a single JSON object.  Streaming requests
        return ``text/plain`` chunks, each a JSON object followed by a blank line.
        Only one request runs at a time; concurrent requests get HTTP 503.
        A missing or non-list ``messages`` gets HTTP 400.
        """
        global global_text, global_state, global_abort, is_blocking

        def busy():
            return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503

        def release():
            global is_blocking
            is_blocking = False
            lock.release()

        # NOTE(review): global_state == 0 presumably means a run is still
        # producing tokens; confirm against rkllm.LLMCallState.
        if is_blocking or global_state == 0:
            return busy()
        # Claim the lock atomically.  A blocking acquire after the flag check
        # would queue racing requests instead of answering 503.
        if not lock.acquire(blocking=False):
            return busy()

        release_here = True
        try:
            is_blocking = True
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or not isinstance(data.get('messages'), list):
                return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400

            # Reset the buffers the callback writes into.
            global_text = []
            global_state = -1
            rkllm_responses = {
                "id": "rkllm_chat",
                "object": "rkllm_chat",
                "created": None,
                "choices": [],
                "usage": {
                    "prompt_tokens": None,
                    "completion_tokens": None,
                    "total_tokens": None
                }
            }
            messages = data['messages']
            print("Received messages:", messages)
            input_prompt = msgs_to_prompt(messages)
            print("generated prompt:", input_prompt)
            model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))

            if not data.get("stream"):
                model_thread.start()
                pieces = []
                while True:
                    model_thread.join(timeout=0.01)
                    finished = not model_thread.is_alive()
                    # Drain after the liveness check, so text appended just
                    # before the thread exits is still collected.
                    while global_text:
                        pieces.append(global_text.pop(0))
                    if global_abort:
                        global_abort = False
                    if finished:
                        break
                rkllm_responses["choices"].append(
                    {"index": len(messages),
                     "message": {
                         "role": "assistant",
                         "content": ''.join(pieces),
                     },
                     "logprobs": None,
                     "finish_reason": "stop"
                     }
                )
                return jsonify(rkllm_responses), 200

            def generate():
                model_thread.start()
                try:
                    while True:
                        model_thread.join(timeout=0.005)
                        finished = not model_thread.is_alive()
                        while global_text:
                            # NOTE(review): "choices" grows by one delta per
                            # token and every chunk re-serialises all of them.
                            # This wire format is kept for existing clients.
                            rkllm_responses["choices"].append(
                                {"index": len(messages),
                                 "delta": {
                                     "role": "assistant",
                                     "content": global_text.pop(0),
                                 },
                                 "logprobs": None,
                                 # NOTE(review): confirm which LLMCallState value 1 is.
                                 "finish_reason": "stop" if global_state == 1 else None,
                                 }
                            )
                            yield f"{json.dumps(rkllm_responses)}\n\n"
                        if finished:
                            break
                finally:
                    # On client disconnect (generator closed early), wait for
                    # the run to end so the lock is not freed mid-inference.
                    model_thread.join()

            response = Response(generate(), content_type='text/plain')
            # Streaming continues after this function returns: hand the lock
            # to the response and release it when the response is closed.
            response.call_on_close(release)
            release_here = False
            return response
        finally:
            if release_here:
                release()

    # Start the Flask application; release the model even if the server fails.
    try:
        # NOTE(review): debug mode is on (interactive debugger disabled); consider disabling it.
        app.run(host='0.0.0.0', port=8080, threaded=True, use_reloader=False, debug=True, use_debugger=False)
    finally:
        print("====================")
        print("RKLLM model inference completed, releasing RKLLM model resources...")
        rkllm_model.release()
        print("====================")