initial commit
This commit is contained in:
commit
46583caabf
15 changed files with 6230 additions and 0 deletions
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
*venv*
|
||||||
|
__pycache__
|
||||||
|
models/*
|
||||||
|
launch.json
|
||||||
|
*-bak
|
||||||
25
README.md
Normal file
25
README.md
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
# rkllm server
|
||||||
|
|
||||||
|
Hosts a simple flask-based chat interface to a rkllm-model at localhost:8080.
|
||||||
|
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
On a r3588 system:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone "<repo>/rkllm_server"
|
||||||
|
cd rkllm_server
|
||||||
|
python3 -m venv venv
|
||||||
|
source venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
deactivate
|
||||||
|
```
|
||||||
|
|
||||||
|
you can now start the server with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
bash ./start_server.sh "/path/to/model.rkllm"
|
||||||
|
```
|
||||||
|
|
||||||
|
The first time on each boot it will ask for a sudo password to fix the npu speed (see [fix_freq_rk3588.sh]).
|
||||||
9
fix_freq_rk3576.sh
Normal file
9
fix_freq_rk3576.sh
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
#!/system/bin/sh
|
||||||
|
|
||||||
|
echo userspace > /sys/class/devfreq/27700000.npu/governor
|
||||||
|
echo 1000000000 > /sys/class/devfreq/27700000.npu/userspace/set_freq
|
||||||
|
|
||||||
|
echo userspace > /sys/devices/system/cpu/cpufreq/policy0/scaling_governor
|
||||||
|
echo 2208000 > /sys/devices/system/cpu/cpufreq/policy0/scaling_setspeed
|
||||||
|
echo userspace > /sys/devices/system/cpu/cpufreq/policy4/scaling_governor
|
||||||
|
echo 2304000 > /sys/devices/system/cpu/cpufreq/policy4/scaling_setspeed
|
||||||
44
fix_freq_rk3588.sh
Normal file
44
fix_freq_rk3588.sh
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu0/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu1/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu2/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu3/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu4/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu5/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu6/cpuidle/state1/disable
|
||||||
|
echo 1 > /sys/devices/system/cpu/cpu7/cpuidle/state1/disable
|
||||||
|
|
||||||
|
echo "NPU available frequencies:"
|
||||||
|
cat /sys/class/devfreq/fdab0000.npu/available_frequencies
|
||||||
|
echo "Fix NPU max frequency:"
|
||||||
|
echo userspace > /sys/class/devfreq/fdab0000.npu/governor
|
||||||
|
echo 1000000000 > /sys/class/devfreq/fdab0000.npu/userspace/set_freq
|
||||||
|
cat /sys/class/devfreq/fdab0000.npu/cur_freq
|
||||||
|
|
||||||
|
echo "CPU available frequencies:"
|
||||||
|
cat /sys/devices/system/cpu/cpufreq/policy0/scaling_available_frequencies
|
||||||
|
cat /sys/devices/system/cpu/cpufreq/policy4/scaling_available_frequencies
|
||||||
|
cat /sys/devices/system/cpu/cpufreq/policy6/scaling_available_frequencies
|
||||||
|
echo "Fix CPU max frequency:"
|
||||||
|
echo userspace > /sys/devices/system/cpu/cpufreq/policy0/scaling_governor
|
||||||
|
echo 1800000 > /sys/devices/system/cpu/cpufreq/policy0/scaling_setspeed
|
||||||
|
cat /sys/devices/system/cpu/cpufreq/policy0/scaling_cur_freq
|
||||||
|
echo userspace > /sys/devices/system/cpu/cpufreq/policy4/scaling_governor
|
||||||
|
echo 2352000 > /sys/devices/system/cpu/cpufreq/policy4/scaling_setspeed
|
||||||
|
cat /sys/devices/system/cpu/cpufreq/policy4/scaling_cur_freq
|
||||||
|
echo userspace > /sys/devices/system/cpu/cpufreq/policy6/scaling_governor
|
||||||
|
echo 2352000 > /sys/devices/system/cpu/cpufreq/policy6/scaling_setspeed
|
||||||
|
cat /sys/devices/system/cpu/cpufreq/policy6/scaling_cur_freq
|
||||||
|
|
||||||
|
echo "GPU available frequencies:"
|
||||||
|
cat /sys/class/devfreq/fb000000.gpu-panthor/available_frequencies
|
||||||
|
echo "Fix GPU max frequency:"
|
||||||
|
echo userspace > /sys/class/devfreq/fb000000.gpu-panthor/governor
|
||||||
|
echo 1000000000 > /sys/class/devfreq/fb000000.gpu-panthor/userspace/set_freq
|
||||||
|
cat //sys/class/devfreq/fb000000.gpu-panthor/cur_freq
|
||||||
|
|
||||||
|
echo "DDR available frequencies:"
|
||||||
|
cat /sys/class/devfreq/dmc/available_frequencies
|
||||||
|
echo "Fix DDR max frequency:"
|
||||||
|
echo userspace > /sys/class/devfreq/dmc/governor
|
||||||
|
echo 2112000000 > /sys/class/devfreq/dmc/userspace/set_freq
|
||||||
|
cat /sys/class/devfreq/dmc/cur_freq
|
||||||
268
flask_server.py
Normal file
268
flask_server.py
Normal file
|
|
@ -0,0 +1,268 @@
|
||||||
|
import ctypes
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import resource
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from flask import Flask, request, jsonify, Response
|
||||||
|
from rkllm import *
|
||||||
|
|
||||||
|
app = Flask(__name__,static_url_path='',static_folder='static')
|
||||||
|
@app.route('/')
|
||||||
|
def root():
|
||||||
|
return app.send_static_file('index.html')
|
||||||
|
|
||||||
|
PROMPT_TEXT_PREFIX = "<|im_start|>system\nYou are a helpful assistant. You only give short answers.<|im_end|>\n<|im_start|>user\n"
|
||||||
|
PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
MSG_START_TOKEN = "<|im_start|>"
|
||||||
|
MSG_END_TOKEN = "<|im_end|>"
|
||||||
|
|
||||||
|
def msg_to_prompt(user, msg):
|
||||||
|
return f'{MSG_START_TOKEN}{user}\n{msg}{MSG_END_TOKEN}\n'
|
||||||
|
|
||||||
|
def msgs_to_prompt(msgs: list[dict]):
|
||||||
|
system = msgs[0] if msgs and msgs[0]['role'] == 'system' else ""
|
||||||
|
if not system:
|
||||||
|
msgs.insert(0, {'role':'system', 'content': "You are a helpful assistant. You only give short but complete answers."})
|
||||||
|
return (''.join(msg_to_prompt(msg['role'], msg['content']) for msg in msgs) +
|
||||||
|
f'{MSG_START_TOKEN}assistant\n')
|
||||||
|
|
||||||
|
|
||||||
|
# Create a lock to control multi-user access to the server.
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
# Create a global variable to indicate whether the server is currently in a blocked state.
|
||||||
|
is_blocking = False
|
||||||
|
|
||||||
|
# Define global variables to store the callback function output for displaying in the Gradio interface
|
||||||
|
global_text = []
|
||||||
|
global_state = -1
|
||||||
|
split_byte_data = bytes(b"") # Used to store the segmented byte data
|
||||||
|
global_abort = False
|
||||||
|
|
||||||
|
# Define the callback function
|
||||||
|
def callback_impl(result, userdata, state):
|
||||||
|
global global_text, global_state, split_byte_data
|
||||||
|
if state == LLMCallState.RKLLM_RUN_FINISH:
|
||||||
|
global_state = state
|
||||||
|
print("\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
elif state == LLMCallState.RKLLM_RUN_ERROR:
|
||||||
|
global_state = state
|
||||||
|
print("run error")
|
||||||
|
sys.stdout.flush()
|
||||||
|
elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
|
||||||
|
'''
|
||||||
|
If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size.
|
||||||
|
With these three parameters, you can retrieve the data from last_hidden_layer.
|
||||||
|
Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback.
|
||||||
|
'''
|
||||||
|
if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0:
|
||||||
|
data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
|
||||||
|
print(f"data_size: {data_size}")
|
||||||
|
global_text.append(f"data_size: {data_size}\n")
|
||||||
|
output_path = os.getcwd() + "/last_hidden_layer.bin"
|
||||||
|
with open(output_path, "wb") as outFile:
|
||||||
|
data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
|
||||||
|
float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
|
||||||
|
float_array = float_array_type.from_address(ctypes.addressof(data.contents))
|
||||||
|
outFile.write(bytearray(float_array))
|
||||||
|
print(f"Data saved to {output_path} successfully!")
|
||||||
|
global_text.append(f"Data saved to {output_path} successfully!")
|
||||||
|
else:
|
||||||
|
print("Invalid hidden layer data.")
|
||||||
|
global_text.append("Invalid hidden layer data.")
|
||||||
|
global_state = state
|
||||||
|
time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result
|
||||||
|
sys.stdout.flush()
|
||||||
|
else:
|
||||||
|
# Save the output token text and the RKLLM running state
|
||||||
|
global_state = state
|
||||||
|
# Monitor if the current byte data is complete; if incomplete, record it for later parsing
|
||||||
|
try:
|
||||||
|
global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
|
||||||
|
print((split_byte_data + result.contents.text).decode('utf-8'), end='')
|
||||||
|
split_byte_data = bytes(b"")
|
||||||
|
except:
|
||||||
|
split_byte_data += result.contents.text
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
# Connect the callback function between the Python side and the C++ side
|
||||||
|
callback = callback_type(callback_impl)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
|
||||||
|
parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
|
||||||
|
parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
|
||||||
|
parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.rkllm_model_path):
|
||||||
|
print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
|
||||||
|
sys.stdout.flush()
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not (args.target_platform in ["rk3588", "rk3576"]):
|
||||||
|
print("Error: Please specify the correct target platform: rk3588/rk3576.")
|
||||||
|
sys.stdout.flush()
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.lora_model_path:
|
||||||
|
if not os.path.exists(args.lora_model_path):
|
||||||
|
print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.")
|
||||||
|
sys.stdout.flush()
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.prompt_cache_path:
|
||||||
|
if not os.path.exists(args.prompt_cache_path):
|
||||||
|
print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.")
|
||||||
|
sys.stdout.flush()
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# Fix frequency
|
||||||
|
command = "[ \"$(cat /sys/class/devfreq/fdab0000.npu/governor)\" = userspace ] || sudo bash fix_freq_{}.sh".format(args.target_platform)
|
||||||
|
subprocess.run(command, shell=True)
|
||||||
|
|
||||||
|
# Set resource limit
|
||||||
|
resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))
|
||||||
|
|
||||||
|
# Initialize RKLLM model
|
||||||
|
print("=========init....===========")
|
||||||
|
sys.stdout.flush()
|
||||||
|
model_path = args.rkllm_model_path
|
||||||
|
rkllm_model = RKLLM(model_path, callback, args.lora_model_path, args.prompt_cache_path)
|
||||||
|
print("RKLLM Model has been initialized successfully!")
|
||||||
|
print("==============================")
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
@app.route('/rkllm_abort', methods=['POST'])
|
||||||
|
def abort():
|
||||||
|
global global_abort
|
||||||
|
global_abort = True
|
||||||
|
code = rkllm_model.abort()
|
||||||
|
return {"code":code}, 200 if code is None else 500
|
||||||
|
|
||||||
|
# Create a function to receive data sent by the user using a request
|
||||||
|
@app.route('/rkllm_chat', methods=['POST'])
|
||||||
|
def receive_message():
|
||||||
|
# Link global variables to retrieve the output information from the callback function
|
||||||
|
global global_text, global_state
|
||||||
|
global is_blocking
|
||||||
|
|
||||||
|
# If the server is in a blocking state, return a specific response.
|
||||||
|
if is_blocking or global_state==0:
|
||||||
|
return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503
|
||||||
|
|
||||||
|
lock.acquire()
|
||||||
|
try:
|
||||||
|
# Set the server to a blocking state.
|
||||||
|
is_blocking = True
|
||||||
|
|
||||||
|
# Get JSON data from the POST request.
|
||||||
|
data = request.json
|
||||||
|
if data and 'messages' in data:
|
||||||
|
# Reset global variables.
|
||||||
|
global_text = []
|
||||||
|
global_state = -1
|
||||||
|
|
||||||
|
# Define the structure for the returned response.
|
||||||
|
rkllm_responses = {
|
||||||
|
"id": "rkllm_chat",
|
||||||
|
"object": "rkllm_chat",
|
||||||
|
"created": None,
|
||||||
|
"choices": [],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": None,
|
||||||
|
"completion_tokens": None,
|
||||||
|
"total_tokens": None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if not "stream" in data.keys() or data["stream"] == False:
|
||||||
|
# Process the received data here.
|
||||||
|
messages = data['messages']
|
||||||
|
print("Received messages:", messages)
|
||||||
|
input_prompt = msgs_to_prompt(messages)
|
||||||
|
print("generated prompt:", input_prompt)
|
||||||
|
rkllm_output = ""
|
||||||
|
|
||||||
|
# Create a thread for model inference.
|
||||||
|
model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
|
||||||
|
model_thread.start()
|
||||||
|
|
||||||
|
# Wait for the model to finish running and periodically check the inference thread of the model.
|
||||||
|
model_thread_finished = False
|
||||||
|
global global_abort
|
||||||
|
while not model_thread_finished:
|
||||||
|
while len(global_text) > 0:
|
||||||
|
rkllm_output += global_text.pop(0)
|
||||||
|
time.sleep(0.01)
|
||||||
|
if global_abort:
|
||||||
|
global_abort = False
|
||||||
|
|
||||||
|
|
||||||
|
model_thread.join(timeout=0.005)
|
||||||
|
model_thread_finished = not model_thread.is_alive()
|
||||||
|
|
||||||
|
rkllm_responses["choices"].append(
|
||||||
|
{"index": len(messages),
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": rkllm_output,
|
||||||
|
},
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return jsonify(rkllm_responses), 200
|
||||||
|
else:
|
||||||
|
messages = data['messages']
|
||||||
|
print("Received messages:", messages)
|
||||||
|
input_prompt = msgs_to_prompt(messages)
|
||||||
|
print("generated prompt:", input_prompt)
|
||||||
|
rkllm_output = ""
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
|
||||||
|
model_thread.start()
|
||||||
|
|
||||||
|
model_thread_finished = False
|
||||||
|
while not model_thread_finished:
|
||||||
|
while len(global_text) > 0:
|
||||||
|
rkllm_output = global_text.pop(0)
|
||||||
|
|
||||||
|
rkllm_responses["choices"].append(
|
||||||
|
{"index": len(messages),
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": rkllm_output,
|
||||||
|
},
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "stop" if global_state == 1 else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield f"{json.dumps(rkllm_responses)}\n\n"
|
||||||
|
|
||||||
|
model_thread.join(timeout=0.005)
|
||||||
|
model_thread_finished = not model_thread.is_alive()
|
||||||
|
|
||||||
|
return Response(generate(), content_type='text/plain')
|
||||||
|
else:
|
||||||
|
return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400
|
||||||
|
finally:
|
||||||
|
lock.release()
|
||||||
|
is_blocking = False
|
||||||
|
|
||||||
|
# Start the Flask application.
|
||||||
|
app.run(host='0.0.0.0', port=8080, threaded=True, use_reloader=False, debug=True, use_debugger=False) # maybe no debug?
|
||||||
|
|
||||||
|
print("====================")
|
||||||
|
print("RKLLM model inference completed, releasing RKLLM model resources...")
|
||||||
|
rkllm_model.release()
|
||||||
|
print("====================")
|
||||||
BIN
lib/librkllmrt.so
Executable file
BIN
lib/librkllmrt.so
Executable file
Binary file not shown.
2
requirements.txt
Normal file
2
requirements.txt
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
flask==2.2.2
|
||||||
|
Werkzeug==2.2.2
|
||||||
280
rkllm.py
Normal file
280
rkllm.py
Normal file
|
|
@ -0,0 +1,280 @@
|
||||||
|
import ctypes
|
||||||
|
|
||||||
|
|
||||||
|
# Set the dynamic library path
|
||||||
|
rkllm_lib = ctypes.CDLL("lib/librkllmrt.so")
|
||||||
|
|
||||||
|
# Define the structures from the library
|
||||||
|
RKLLM_Handle_t = ctypes.c_void_p
|
||||||
|
userdata = ctypes.c_void_p(None)
|
||||||
|
|
||||||
|
LLMCallState = ctypes.c_int
|
||||||
|
LLMCallState.RKLLM_RUN_NORMAL = 0
|
||||||
|
LLMCallState.RKLLM_RUN_WAITING = 1
|
||||||
|
LLMCallState.RKLLM_RUN_FINISH = 2
|
||||||
|
LLMCallState.RKLLM_RUN_ERROR = 3
|
||||||
|
LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4
|
||||||
|
|
||||||
|
RKLLMInputMode = ctypes.c_int
|
||||||
|
RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
|
||||||
|
RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
|
||||||
|
RKLLMInputMode.RKLLM_INPUT_EMBED = 2
|
||||||
|
RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3
|
||||||
|
|
||||||
|
RKLLMInferMode = ctypes.c_int
|
||||||
|
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
|
||||||
|
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMExtendParam(ctypes.Structure):
|
||||||
|
_fields_ = [("base_domain_id", ctypes.c_int32), ("reserved", ctypes.c_uint8 * 112)]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMParam(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("model_path", ctypes.c_char_p),
|
||||||
|
("max_context_len", ctypes.c_int32),
|
||||||
|
("max_new_tokens", ctypes.c_int32),
|
||||||
|
("top_k", ctypes.c_int32),
|
||||||
|
("top_p", ctypes.c_float),
|
||||||
|
("temperature", ctypes.c_float),
|
||||||
|
("repeat_penalty", ctypes.c_float),
|
||||||
|
("frequency_penalty", ctypes.c_float),
|
||||||
|
("presence_penalty", ctypes.c_float),
|
||||||
|
("mirostat", ctypes.c_int32),
|
||||||
|
("mirostat_tau", ctypes.c_float),
|
||||||
|
("mirostat_eta", ctypes.c_float),
|
||||||
|
("skip_special_token", ctypes.c_bool),
|
||||||
|
("is_async", ctypes.c_bool),
|
||||||
|
("img_start", ctypes.c_char_p),
|
||||||
|
("img_end", ctypes.c_char_p),
|
||||||
|
("img_content", ctypes.c_char_p),
|
||||||
|
("extend_param", RKLLMExtendParam),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMLoraAdapter(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("lora_adapter_path", ctypes.c_char_p),
|
||||||
|
("lora_adapter_name", ctypes.c_char_p),
|
||||||
|
("scale", ctypes.c_float),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMEmbedInput(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("embed", ctypes.POINTER(ctypes.c_float)),
|
||||||
|
("n_tokens", ctypes.c_size_t),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMTokenInput(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("input_ids", ctypes.POINTER(ctypes.c_int32)),
|
||||||
|
("n_tokens", ctypes.c_size_t),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMMultiModelInput(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("prompt", ctypes.c_char_p),
|
||||||
|
("image_embed", ctypes.POINTER(ctypes.c_float)),
|
||||||
|
("n_image_tokens", ctypes.c_size_t),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMInputUnion(ctypes.Union):
|
||||||
|
_fields_ = [
|
||||||
|
("prompt_input", ctypes.c_char_p),
|
||||||
|
("embed_input", RKLLMEmbedInput),
|
||||||
|
("token_input", RKLLMTokenInput),
|
||||||
|
("multimodal_input", RKLLMMultiModelInput),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMInput(ctypes.Structure):
|
||||||
|
_fields_ = [("input_mode", ctypes.c_int), ("input_data", RKLLMInputUnion)]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMLoraParam(ctypes.Structure):
|
||||||
|
_fields_ = [("lora_adapter_name", ctypes.c_char_p)]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMPromptCacheParam(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("save_prompt_cache", ctypes.c_int),
|
||||||
|
("prompt_cache_path", ctypes.c_char_p),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMInferParam(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("mode", RKLLMInferMode),
|
||||||
|
("lora_params", ctypes.POINTER(RKLLMLoraParam)),
|
||||||
|
("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMResultLastHiddenLayer(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("hidden_states", ctypes.POINTER(ctypes.c_float)),
|
||||||
|
("embd_size", ctypes.c_int),
|
||||||
|
("num_tokens", ctypes.c_int),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RKLLMResult(ctypes.Structure):
|
||||||
|
_fields_ = [
|
||||||
|
("text", ctypes.c_char_p),
|
||||||
|
("size", ctypes.c_int),
|
||||||
|
("last_hidden_layer", RKLLMResultLastHiddenLayer),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
callback_type = ctypes.CFUNCTYPE(
|
||||||
|
None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
|
||||||
|
class RKLLM(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path,
|
||||||
|
callback,
|
||||||
|
lora_model_path=None,
|
||||||
|
prompt_cache_path=None,
|
||||||
|
):
|
||||||
|
rkllm_param = RKLLMParam()
|
||||||
|
rkllm_param.model_path = bytes(model_path, "utf-8")
|
||||||
|
|
||||||
|
rkllm_param.max_context_len = 10000
|
||||||
|
rkllm_param.max_new_tokens = -1
|
||||||
|
rkllm_param.skip_special_token = True
|
||||||
|
|
||||||
|
rkllm_param.top_k = 20
|
||||||
|
rkllm_param.top_p = 0.8
|
||||||
|
rkllm_param.temperature = 0.7
|
||||||
|
rkllm_param.repeat_penalty = 1.1
|
||||||
|
rkllm_param.frequency_penalty = 0.0
|
||||||
|
rkllm_param.presence_penalty = 0.0
|
||||||
|
|
||||||
|
rkllm_param.mirostat = 0
|
||||||
|
rkllm_param.mirostat_tau = 5.0
|
||||||
|
rkllm_param.mirostat_eta = 0.1
|
||||||
|
|
||||||
|
rkllm_param.is_async = False
|
||||||
|
|
||||||
|
rkllm_param.img_start = "".encode("utf-8")
|
||||||
|
rkllm_param.img_end = "".encode("utf-8")
|
||||||
|
rkllm_param.img_content = "".encode("utf-8")
|
||||||
|
|
||||||
|
rkllm_param.extend_param.base_domain_id = 0
|
||||||
|
|
||||||
|
self.handle = RKLLM_Handle_t()
|
||||||
|
|
||||||
|
self.rkllm_init = rkllm_lib.rkllm_init
|
||||||
|
self.rkllm_init.argtypes = [
|
||||||
|
ctypes.POINTER(RKLLM_Handle_t),
|
||||||
|
ctypes.POINTER(RKLLMParam),
|
||||||
|
callback_type,
|
||||||
|
]
|
||||||
|
self.rkllm_init.restype = ctypes.c_int
|
||||||
|
self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
|
||||||
|
|
||||||
|
self.rkllm_run = rkllm_lib.rkllm_run
|
||||||
|
self.rkllm_run.argtypes = [
|
||||||
|
RKLLM_Handle_t,
|
||||||
|
ctypes.POINTER(RKLLMInput),
|
||||||
|
ctypes.POINTER(RKLLMInferParam),
|
||||||
|
ctypes.c_void_p,
|
||||||
|
]
|
||||||
|
self.rkllm_run.restype = ctypes.c_int
|
||||||
|
|
||||||
|
self.rkllm_run_async = rkllm_lib.rkllm_run_async
|
||||||
|
self.rkllm_run_async.argtypes = [
|
||||||
|
RKLLM_Handle_t,
|
||||||
|
ctypes.POINTER(RKLLMInput),
|
||||||
|
ctypes.POINTER(RKLLMInferParam),
|
||||||
|
ctypes.c_void_p,
|
||||||
|
]
|
||||||
|
self.rkllm_run_async.restype = ctypes.c_int
|
||||||
|
|
||||||
|
self.rkllm_abort = rkllm_lib.rkllm_abort
|
||||||
|
self.rkllm_abort.argtypes = [RKLLM_Handle_t]
|
||||||
|
self.rkllm_abort.restype = ctypes.c_int
|
||||||
|
|
||||||
|
self.rkllm_destroy = rkllm_lib.rkllm_destroy
|
||||||
|
self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
|
||||||
|
self.rkllm_destroy.restype = ctypes.c_int
|
||||||
|
|
||||||
|
self.lora_adapter_path = None
|
||||||
|
self.lora_model_name = None
|
||||||
|
if lora_model_path:
|
||||||
|
self.lora_adapter_path = lora_model_path
|
||||||
|
self.lora_adapter_name = "test"
|
||||||
|
|
||||||
|
lora_adapter = RKLLMLoraAdapter()
|
||||||
|
ctypes.memset(
|
||||||
|
ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter)
|
||||||
|
)
|
||||||
|
lora_adapter.lora_adapter_path = ctypes.c_char_p(
|
||||||
|
(self.lora_adapter_path).encode("utf-8")
|
||||||
|
)
|
||||||
|
lora_adapter.lora_adapter_name = ctypes.c_char_p(
|
||||||
|
(self.lora_adapter_name).encode("utf-8")
|
||||||
|
)
|
||||||
|
lora_adapter.scale = 1.0
|
||||||
|
|
||||||
|
rkllm_load_lora = rkllm_lib.rkllm_load_lora
|
||||||
|
rkllm_load_lora.argtypes = [
|
||||||
|
RKLLM_Handle_t,
|
||||||
|
ctypes.POINTER(RKLLMLoraAdapter),
|
||||||
|
]
|
||||||
|
rkllm_load_lora.restype = ctypes.c_int
|
||||||
|
rkllm_load_lora(self.handle, ctypes.byref(lora_adapter))
|
||||||
|
|
||||||
|
self.prompt_cache_path = None
|
||||||
|
if prompt_cache_path:
|
||||||
|
self.prompt_cache_path = prompt_cache_path
|
||||||
|
|
||||||
|
rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
|
||||||
|
rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
|
||||||
|
rkllm_load_prompt_cache.restype = ctypes.c_int
|
||||||
|
rkllm_load_prompt_cache(
|
||||||
|
self.handle, ctypes.c_char_p((prompt_cache_path).encode("utf-8"))
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self, prompt):
|
||||||
|
rkllm_lora_params = None
|
||||||
|
if self.lora_model_name:
|
||||||
|
rkllm_lora_params = RKLLMLoraParam()
|
||||||
|
rkllm_lora_params.lora_adapter_name = ctypes.c_char_p(
|
||||||
|
(self.lora_model_name).encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
rkllm_infer_params = RKLLMInferParam()
|
||||||
|
ctypes.memset(
|
||||||
|
ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam)
|
||||||
|
)
|
||||||
|
rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
|
||||||
|
rkllm_infer_params.lora_params = (
|
||||||
|
ctypes.byref(rkllm_lora_params) if rkllm_lora_params else None
|
||||||
|
)
|
||||||
|
|
||||||
|
rkllm_input = RKLLMInput()
|
||||||
|
rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
|
||||||
|
rkllm_input.input_data.prompt_input = ctypes.c_char_p((prompt).encode("utf-8"))
|
||||||
|
self.rkllm_run(
|
||||||
|
self.handle,
|
||||||
|
ctypes.byref(rkllm_input),
|
||||||
|
ctypes.byref(rkllm_infer_params),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
def abort(self):
|
||||||
|
return self.rkllm_abort(self.handle)
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
self.rkllm_destroy(self.handle)
|
||||||
18
start_server.sh
Normal file
18
start_server.sh
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
MODEL="$(realpath "$1")"
|
||||||
|
SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||||
|
echo "Serverdir: $SCRIPT_DIR"
|
||||||
|
|
||||||
|
if [ -z "$1" ]; then
|
||||||
|
echo "USAGE: start.sh <path-to-model>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Modelpath: $MODEL"
|
||||||
|
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
source "./venv/bin/activate"
|
||||||
|
python3 -m debugpy --listen 0.0.0.0:5679 flask_server.py --rkllm_model_path "$MODEL" --target_platform rk3588
|
||||||
|
|
||||||
|
|
||||||
1
static/atom-one-dark.min.css
vendored
Normal file
1
static/atom-one-dark.min.css
vendored
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{color:#abb2bf;background:#282c34}.hljs-comment,.hljs-quote{color:#5c6370;font-style:italic}.hljs-doctag,.hljs-formula,.hljs-keyword{color:#c678dd}.hljs-deletion,.hljs-name,.hljs-section,.hljs-selector-tag,.hljs-subst{color:#e06c75}.hljs-literal{color:#56b6c2}.hljs-addition,.hljs-attribute,.hljs-meta .hljs-string,.hljs-regexp,.hljs-string{color:#98c379}.hljs-attr,.hljs-number,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-pseudo,.hljs-template-variable,.hljs-type,.hljs-variable{color:#d19a66}.hljs-bullet,.hljs-link,.hljs-meta,.hljs-selector-id,.hljs-symbol,.hljs-title{color:#61aeee}.hljs-built_in,.hljs-class .hljs-title,.hljs-title.class_{color:#e6c07b}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}.hljs-link{text-decoration:underline}
|
||||||
BIN
static/favicon.ico
Normal file
BIN
static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 KiB |
3861
static/highlight.min.js
vendored
Normal file
3861
static/highlight.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
38
static/index.html
Normal file
38
static/index.html
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>RKLLM Chat</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: Arial, sans-serif; background-color: #333; color: #fff; margin: 0; padding: 0; display: flex; justify-content: center; align-items: center; height: 100vh; }
|
||||||
|
.chat-container { width: 95%; height: 95%; background-color: #444; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
|
||||||
|
.messages { flex: 1; padding: 20px; flex-direction: column; overflow-y: auto; }
|
||||||
|
.message { flex: 1; padding: 10px; max-width: 60%; margin: 5px; border-radius: 5px; background-color: #333; }
|
||||||
|
.selected { background-color: #343; }
|
||||||
|
.message:hover { background-color: #353; }
|
||||||
|
.clearfix { clear: both; display: table; margin:5px 0; }
|
||||||
|
.input-area { display: flex; }
|
||||||
|
.input-box { flex: 1; padding: 10px; background-color: #555; border: none; resize: vertical; color: #fff; height:fit-content;}
|
||||||
|
.send-btn { padding: 10px 20px; background-color: #666; border: none; cursor: pointer; transition: background-color 0.3s ease;}
|
||||||
|
.send-btn:hover { background-color: #777; }
|
||||||
|
.abort-btn { padding: 10px 20px; background-color: #511; border: none; cursor: pointer; transition: background-color 0.3s ease; }
|
||||||
|
.abort-btn:hover { background-color: #711; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="chat-container">
|
||||||
|
<textarea class="input-box" id="systemMessage" style="max-height:20px;" placeholder="Change the system directions here."></textarea>
|
||||||
|
<div class="messages" id="messages"></div>
|
||||||
|
<div class="input-area">
|
||||||
|
<textarea class="input-box" id="messageInput" placeholder="Type your message here..."></textarea>
|
||||||
|
<button class="send-btn" onclick="sendMessage()">Send</button>
|
||||||
|
<button class="abort-btn" onclick="sendAbort()">Abort</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script src="stmd.js"></script>
|
||||||
|
<script src="highlight.min.js"></script>
|
||||||
|
<script src="v2.js"></script>
|
||||||
|
<link rel="stylesheet" href="atom-one-dark.min.css">
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
1547
static/stmd.js
Normal file
1547
static/stmd.js
Normal file
File diff suppressed because it is too large
Load diff
132
static/v2.js
Normal file
132
static/v2.js
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
//var stmd = require('stmd');
|
||||||
|
|
||||||
|
|
||||||
|
const messagesContainer = document.getElementById('messages');
|
||||||
|
const systemInput = document.getElementById('systemMessage');
|
||||||
|
const messageInput = document.getElementById('messageInput');
|
||||||
|
var parser = new stmd.DocParser();
|
||||||
|
var renderer = new stmd.HtmlRenderer();
|
||||||
|
var cur_idx = 0;
|
||||||
|
const questions = [];
|
||||||
|
const answers = [];
|
||||||
|
var selected_msg = [];
|
||||||
|
|
||||||
|
messageInput.addEventListener('keydown', (event) => {
|
||||||
|
if (event.key === "Enter" && (event.metaKey || event.ctrlKey)) {
|
||||||
|
sendMessage();
|
||||||
|
} else if (event.key === "Escape" && (event.metaKey || event.ctrlKey)) {
|
||||||
|
sendAbort();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
function messageListener(elem, idx, type) {
|
||||||
|
return () => {
|
||||||
|
elem.classList.toggle('selected');
|
||||||
|
const e = { id: idx, type: type };
|
||||||
|
const i = selected_msg.indexOf(e);
|
||||||
|
if (i === -1)
|
||||||
|
selected_msg.push(e);
|
||||||
|
else selected_msg.splice(i, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sendAbort() {
|
||||||
|
fetch('/rkllm_abort', {
|
||||||
|
method: 'POST',
|
||||||
|
body: true
|
||||||
|
}).catch(error => alert(error))
|
||||||
|
}
|
||||||
|
|
||||||
|
function _try_highlightAll() {
|
||||||
|
try {
|
||||||
|
hljs.highlightAll();
|
||||||
|
} catch { /*e*/ }
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sendMessage() {
|
||||||
|
const messageText = messageInput.value.trim();
|
||||||
|
if (!messageText) return;
|
||||||
|
messageHTML = renderer.render(parser.parse(`**You:**\n\n${messageText}`))
|
||||||
|
mess = document.createElement('div');
|
||||||
|
mess.classList.add('message', 'clearfix');
|
||||||
|
mess.style.float = "right"
|
||||||
|
mess.innerHTML = messageHTML
|
||||||
|
mess.addEventListener('click', messageListener(mess, cur_idx, 'q'));
|
||||||
|
questions.push(messageText);
|
||||||
|
messagesContainer.appendChild(mess);
|
||||||
|
Array.from(mess.getElementsByTagName('code')).forEach(e => hljs.highlightElement(e));
|
||||||
|
messageInput.value = '';
|
||||||
|
messagesContainer.scrollTop = messagesContainer.scrollHeight; // Auto-scroll to bottom
|
||||||
|
|
||||||
|
const messagesToSend = selected_msg.sort((a,b) => {
|
||||||
|
const sub = a.id-b.id;
|
||||||
|
if (sub === 0) return b.type === 'q' ? 1 : -1;
|
||||||
|
return sub;
|
||||||
|
}).map(m => {
|
||||||
|
const q = m.type === 'q';
|
||||||
|
return {
|
||||||
|
role: q ? 'user' : 'assistant',
|
||||||
|
content: (q ? questions : answers)[m.id]
|
||||||
|
}});
|
||||||
|
messagesToSend.push({role:'user', content: messageText})
|
||||||
|
const sysText = systemInput.value.trim();
|
||||||
|
|
||||||
|
const withSysMsg = (sysText ? [{role: 'system', 'content': sysText}] : []).concat(messagesToSend);
|
||||||
|
|
||||||
|
const response = await fetch('/rkllm_chat', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: "some model",
|
||||||
|
messages: withSysMsg,
|
||||||
|
stream: true
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to connect to chat server');
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
let decoder = new TextDecoder('utf-8');
|
||||||
|
//const name = (Math.random() + 1).toString(36).substring(7);
|
||||||
|
const message = document.createElement('div');
|
||||||
|
message.classList.add('message', 'clearfix');
|
||||||
|
message.style.float = 'left';
|
||||||
|
messagesContainer.appendChild(message);
|
||||||
|
let chunks = '';
|
||||||
|
function double_try() {
|
||||||
|
function display(n) {
|
||||||
|
const ccs = chunks.split(/\n/).filter(e => e !== "");
|
||||||
|
const json = JSON.parse(ccs[ccs.length - n]);
|
||||||
|
const text = `${json.choices.map((c) => c.delta.content).join('')}`;
|
||||||
|
while (cur_idx >= answers.length) answers.push("");
|
||||||
|
answers[cur_idx] = text;
|
||||||
|
message.innerHTML = renderer.render(parser.parse(`**RKLLM**:\n\n${text}`));
|
||||||
|
Array.from(message.getElementsByTagName('code')).forEach(e => hljs.highlightElement(e));
|
||||||
|
}
|
||||||
|
const at_bottom = Math.abs(messagesContainer.scrollHeight - messagesContainer.clientHeight - messagesContainer.scrollTop) <= 1;;
|
||||||
|
console.log(messagesContainer.scrollHeight, messagesContainer.scrollTop, messagesContainer.clientHeight);
|
||||||
|
try {
|
||||||
|
display(1);
|
||||||
|
} catch {
|
||||||
|
try {
|
||||||
|
display(2);
|
||||||
|
} catch (e) { console.error(e, chunks); }
|
||||||
|
} finally {
|
||||||
|
if (at_bottom) messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
chunks += chunk.replace('\r', '');
|
||||||
|
double_try();
|
||||||
|
}
|
||||||
|
double_try();
|
||||||
|
|
||||||
|
message.addEventListener('click', messageListener(message, cur_idx++, 'a'));
|
||||||
|
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue