initial commit
commit 46583caabf
15 changed files with 6230 additions and 0 deletions
flask_server.py (new file, 268 lines)

@@ -0,0 +1,268 @@
import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import argparse
import json
from flask import Flask, request, jsonify, Response
from rkllm import *

app = Flask(__name__, static_url_path='', static_folder='static')

@app.route('/')
def root():
    return app.send_static_file('index.html')

PROMPT_TEXT_PREFIX = "<|im_start|>system\nYou are a helpful assistant. You only give short answers.<|im_end|>\n<|im_start|>user\n"
PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant\n"

MSG_START_TOKEN = "<|im_start|>"
MSG_END_TOKEN = "<|im_end|>"

def msg_to_prompt(user, msg):
    return f'{MSG_START_TOKEN}{user}\n{msg}{MSG_END_TOKEN}\n'

def msgs_to_prompt(msgs: list[dict]):
    # If the conversation does not start with a system message, insert a default one.
    system = msgs[0] if msgs and msgs[0]['role'] == 'system' else ""
    if not system:
        msgs.insert(0, {'role': 'system', 'content': "You are a helpful assistant. You only give short but complete answers."})
    # Render every message in ChatML form and open the assistant turn for generation.
    return (''.join(msg_to_prompt(msg['role'], msg['content']) for msg in msgs) +
            f'{MSG_START_TOKEN}assistant\n')

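# A minimal sketch of what msgs_to_prompt produces; the message list below is a
# hypothetical example, not part of the server:
#
#   msgs_to_prompt([{'role': 'user', 'content': 'Hi'}])
#   -> "<|im_start|>system\nYou are a helpful assistant. You only give short but complete answers.<|im_end|>\n"
#      "<|im_start|>user\nHi<|im_end|>\n"
#      "<|im_start|>assistant\n"
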
# Create a lock to control multi-user access to the server.
lock = threading.Lock()

# Global flag indicating whether the server is currently handling a request.
is_blocking = False

# Global variables used to pass the callback output from the RKLLM runtime to the Flask handlers.
global_text = []
global_state = -1
split_byte_data = bytes(b"")  # Buffer for incomplete UTF-8 byte sequences from the runtime
global_abort = False

# Define the callback function invoked by the RKLLM runtime for every generation event.
def callback_impl(result, userdata, state):
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        '''
        When the GET_LAST_HIDDEN_LAYER function is used, the callback returns the memory pointer
        last_hidden_layer, the number of tokens num_tokens, and the hidden layer size embd_size.
        With these three parameters the data behind last_hidden_layer can be retrieved.
        Note: the data must be copied out during the current callback; otherwise the pointer is
        released before the next callback.
        '''
        if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0:
            data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
            print(f"data_size: {data_size}")
            global_text.append(f"data_size: {data_size}\n")
            output_path = os.getcwd() + "/last_hidden_layer.bin"
            with open(output_path, "wb") as outFile:
                # Reinterpret the C pointer as a float array and dump it to disk.
                data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
                float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
                float_array = float_array_type.from_address(ctypes.addressof(data.contents))
                outFile.write(bytearray(float_array))
            print(f"Data saved to {output_path} successfully!")
            global_text.append(f"Data saved to {output_path} successfully!")
        else:
            print("Invalid hidden layer data.")
            global_text.append("Invalid hidden layer data.")
        global_state = state
        time.sleep(0.05)  # Short delay to let the output result be flushed
        sys.stdout.flush()
    else:
        # Save the output token text and the RKLLM running state.
        global_state = state
        # The runtime may hand back an incomplete UTF-8 sequence; buffer it until it decodes cleanly.
        try:
            global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
            print((split_byte_data + result.contents.text).decode('utf-8'), end='')
            split_byte_data = bytes(b"")
        except UnicodeDecodeError:
            split_byte_data += result.contents.text
        sys.stdout.flush()

# Connect the callback function between the Python side and the C++ side.
# callback_type comes from the rkllm bindings (imported above) and wraps callback_impl
# so that the C++ runtime can invoke it.
callback = callback_type(callback_impl)

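# For reference, callback_type is assumed to be a ctypes function prototype declared in
# rkllm.py, roughly like the sketch below; the actual RKLLMResult layout lives in the
# bindings and is not reproduced here:
#
#   callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
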
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board')
    parser.add_argument('--target_platform', type=str, required=True, help='Target platform, e.g. rk3588 or rk3576')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board')
    args = parser.parse_args()

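    # Example invocation (the model path below is hypothetical):
    #   python3 flask_server.py --rkllm_model_path /data/models/qwen.rkllm --target_platform rk3588
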
    if not os.path.exists(args.rkllm_model_path):
        print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
        sys.stdout.flush()
        exit()

    if args.target_platform not in ["rk3588", "rk3576"]:
        print("Error: Please specify the correct target platform: rk3588/rk3576.")
        sys.stdout.flush()
        exit()

    if args.lora_model_path:
        if not os.path.exists(args.lora_model_path):
            print("Error: Please provide the correct lora_model path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()

    if args.prompt_cache_path:
        if not os.path.exists(args.prompt_cache_path):
            print("Error: Please provide the correct prompt_cache_file path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()

    # Pin the NPU frequency unless the devfreq governor is already set to userspace.
    command = "[ \"$(cat /sys/class/devfreq/fdab0000.npu/governor)\" = userspace ] || sudo bash fix_freq_{}.sh".format(args.target_platform)
    subprocess.run(command, shell=True)

    # Raise the open-file limit for the server process.
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize the RKLLM model.
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    rkllm_model = RKLLM(model_path, callback, args.lora_model_path, args.prompt_cache_path)
    print("RKLLM Model has been initialized successfully!")
    print("==============================")
    sys.stdout.flush()

    @app.route('/rkllm_abort', methods=['POST'])
    def abort():
        # Signal the running inference loop to stop and ask the runtime to abort.
        global global_abort
        global_abort = True
        code = rkllm_model.abort()
        return {"code": code}, 200 if code is None else 500

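    # Example request (the board address is a placeholder):
    #   curl -X POST http://<board-ip>:8080/rkllm_abort
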
    # Endpoint that receives chat requests from the client.
    @app.route('/rkllm_chat', methods=['POST'])
    def receive_message():
        # Link global variables to retrieve the output information from the callback function.
        global global_text, global_state
        global is_blocking

        # If the server is already handling a request, return a busy response.
        if is_blocking or global_state == 0:
            return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503

        lock.acquire()
        try:
            # Set the server to a blocking state.
            is_blocking = True

            # Get JSON data from the POST request.
            data = request.json
            if data and 'messages' in data:
                # Reset global variables.
                global_text = []
                global_state = -1

                # Define the structure for the returned response (OpenAI-style choices/usage).
                rkllm_responses = {
                    "id": "rkllm_chat",
                    "object": "rkllm_chat",
                    "created": None,
                    "choices": [],
                    "usage": {
                        "prompt_tokens": None,
                        "completion_tokens": None,
                        "total_tokens": None
                    }
                }

if not "stream" in data.keys() or data["stream"] == False:
|
||||
# Process the received data here.
|
||||
messages = data['messages']
|
||||
print("Received messages:", messages)
|
||||
input_prompt = msgs_to_prompt(messages)
|
||||
print("generated prompt:", input_prompt)
|
||||
rkllm_output = ""
|
||||
|
||||
# Create a thread for model inference.
|
||||
model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
|
||||
model_thread.start()
|
||||
|
||||
# Wait for the model to finish running and periodically check the inference thread of the model.
|
||||
model_thread_finished = False
|
||||
global global_abort
|
||||
while not model_thread_finished:
|
||||
while len(global_text) > 0:
|
||||
rkllm_output += global_text.pop(0)
|
||||
time.sleep(0.01)
|
||||
if global_abort:
|
||||
global_abort = False
|
||||
|
||||
|
||||
model_thread.join(timeout=0.005)
|
||||
model_thread_finished = not model_thread.is_alive()
|
||||
|
||||
                    rkllm_responses["choices"].append(
                        {"index": len(messages),
                         "message": {
                             "role": "assistant",
                             "content": rkllm_output,
                         },
                         "logprobs": None,
                         "finish_reason": "stop"
                         }
                    )
                    return jsonify(rkllm_responses), 200

                else:
                    # Streaming request: yield partial results as the callback produces them.
                    messages = data['messages']
                    print("Received messages:", messages)
                    input_prompt = msgs_to_prompt(messages)
                    print("generated prompt:", input_prompt)
                    rkllm_output = ""

                    def generate():
                        model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
                        model_thread.start()

                        model_thread_finished = False
                        while not model_thread_finished:
                            while len(global_text) > 0:
                                rkllm_output = global_text.pop(0)

                                # Note: each yielded payload carries the full choices list accumulated so far,
                                # with the newest delta appended at the end.
                                rkllm_responses["choices"].append(
                                    {"index": len(messages),
                                     "delta": {
                                         "role": "assistant",
                                         "content": rkllm_output,
                                     },
                                     "logprobs": None,
                                     "finish_reason": "stop" if global_state == 1 else None,
                                     }
                                )
                                yield f"{json.dumps(rkllm_responses)}\n\n"

                            model_thread.join(timeout=0.005)
                            model_thread_finished = not model_thread.is_alive()

                    return Response(generate(), content_type='text/plain')

            else:
                return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400
        finally:
            lock.release()
            is_blocking = False

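    # Example chat request (the payload mirrors the messages/stream fields this endpoint reads;
    # the board address is a placeholder):
    #   curl -X POST http://<board-ip>:8080/rkllm_chat \
    #        -H "Content-Type: application/json" \
    #        -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'
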
    # Start the Flask application. debug=True is kept for development; consider disabling it in production.
    app.run(host='0.0.0.0', port=8080, threaded=True, use_reloader=False, debug=True, use_debugger=False)

    # Reached once the Flask server stops.
    print("====================")
    print("RKLLM model inference completed, releasing RKLLM model resources...")
    rkllm_model.release()
    print("====================")