initial commit
This commit is contained in:
commit
46583caabf
15 changed files with 6230 additions and 0 deletions
280
rkllm.py
Normal file
280
rkllm.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
import ctypes
|
||||
|
||||
|
||||
# Set the dynamic library path
|
||||
rkllm_lib = ctypes.CDLL("lib/librkllmrt.so")
|
||||
|
||||
# Define the structures from the library
|
||||
RKLLM_Handle_t = ctypes.c_void_p
|
||||
userdata = ctypes.c_void_p(None)
|
||||
|
||||
LLMCallState = ctypes.c_int
|
||||
LLMCallState.RKLLM_RUN_NORMAL = 0
|
||||
LLMCallState.RKLLM_RUN_WAITING = 1
|
||||
LLMCallState.RKLLM_RUN_FINISH = 2
|
||||
LLMCallState.RKLLM_RUN_ERROR = 3
|
||||
LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4
|
||||
|
||||
RKLLMInputMode = ctypes.c_int
|
||||
RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
|
||||
RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
|
||||
RKLLMInputMode.RKLLM_INPUT_EMBED = 2
|
||||
RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3
|
||||
|
||||
RKLLMInferMode = ctypes.c_int
|
||||
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
|
||||
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
|
||||
|
||||
|
||||
class RKLLMExtendParam(ctypes.Structure):
|
||||
_fields_ = [("base_domain_id", ctypes.c_int32), ("reserved", ctypes.c_uint8 * 112)]
|
||||
|
||||
|
||||
class RKLLMParam(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("model_path", ctypes.c_char_p),
|
||||
("max_context_len", ctypes.c_int32),
|
||||
("max_new_tokens", ctypes.c_int32),
|
||||
("top_k", ctypes.c_int32),
|
||||
("top_p", ctypes.c_float),
|
||||
("temperature", ctypes.c_float),
|
||||
("repeat_penalty", ctypes.c_float),
|
||||
("frequency_penalty", ctypes.c_float),
|
||||
("presence_penalty", ctypes.c_float),
|
||||
("mirostat", ctypes.c_int32),
|
||||
("mirostat_tau", ctypes.c_float),
|
||||
("mirostat_eta", ctypes.c_float),
|
||||
("skip_special_token", ctypes.c_bool),
|
||||
("is_async", ctypes.c_bool),
|
||||
("img_start", ctypes.c_char_p),
|
||||
("img_end", ctypes.c_char_p),
|
||||
("img_content", ctypes.c_char_p),
|
||||
("extend_param", RKLLMExtendParam),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMLoraAdapter(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("lora_adapter_path", ctypes.c_char_p),
|
||||
("lora_adapter_name", ctypes.c_char_p),
|
||||
("scale", ctypes.c_float),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMEmbedInput(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("embed", ctypes.POINTER(ctypes.c_float)),
|
||||
("n_tokens", ctypes.c_size_t),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMTokenInput(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("input_ids", ctypes.POINTER(ctypes.c_int32)),
|
||||
("n_tokens", ctypes.c_size_t),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMMultiModelInput(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("prompt", ctypes.c_char_p),
|
||||
("image_embed", ctypes.POINTER(ctypes.c_float)),
|
||||
("n_image_tokens", ctypes.c_size_t),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMInputUnion(ctypes.Union):
|
||||
_fields_ = [
|
||||
("prompt_input", ctypes.c_char_p),
|
||||
("embed_input", RKLLMEmbedInput),
|
||||
("token_input", RKLLMTokenInput),
|
||||
("multimodal_input", RKLLMMultiModelInput),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMInput(ctypes.Structure):
|
||||
_fields_ = [("input_mode", ctypes.c_int), ("input_data", RKLLMInputUnion)]
|
||||
|
||||
|
||||
class RKLLMLoraParam(ctypes.Structure):
|
||||
_fields_ = [("lora_adapter_name", ctypes.c_char_p)]
|
||||
|
||||
|
||||
class RKLLMPromptCacheParam(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("save_prompt_cache", ctypes.c_int),
|
||||
("prompt_cache_path", ctypes.c_char_p),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMInferParam(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("mode", RKLLMInferMode),
|
||||
("lora_params", ctypes.POINTER(RKLLMLoraParam)),
|
||||
("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMResultLastHiddenLayer(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("hidden_states", ctypes.POINTER(ctypes.c_float)),
|
||||
("embd_size", ctypes.c_int),
|
||||
("num_tokens", ctypes.c_int),
|
||||
]
|
||||
|
||||
|
||||
class RKLLMResult(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("text", ctypes.c_char_p),
|
||||
("size", ctypes.c_int),
|
||||
("last_hidden_layer", RKLLMResultLastHiddenLayer),
|
||||
]
|
||||
|
||||
|
||||
callback_type = ctypes.CFUNCTYPE(
|
||||
None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int
|
||||
)
|
||||
|
||||
|
||||
# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
|
||||
class RKLLM(object):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
callback,
|
||||
lora_model_path=None,
|
||||
prompt_cache_path=None,
|
||||
):
|
||||
rkllm_param = RKLLMParam()
|
||||
rkllm_param.model_path = bytes(model_path, "utf-8")
|
||||
|
||||
rkllm_param.max_context_len = 10000
|
||||
rkllm_param.max_new_tokens = -1
|
||||
rkllm_param.skip_special_token = True
|
||||
|
||||
rkllm_param.top_k = 20
|
||||
rkllm_param.top_p = 0.8
|
||||
rkllm_param.temperature = 0.7
|
||||
rkllm_param.repeat_penalty = 1.1
|
||||
rkllm_param.frequency_penalty = 0.0
|
||||
rkllm_param.presence_penalty = 0.0
|
||||
|
||||
rkllm_param.mirostat = 0
|
||||
rkllm_param.mirostat_tau = 5.0
|
||||
rkllm_param.mirostat_eta = 0.1
|
||||
|
||||
rkllm_param.is_async = False
|
||||
|
||||
rkllm_param.img_start = "".encode("utf-8")
|
||||
rkllm_param.img_end = "".encode("utf-8")
|
||||
rkllm_param.img_content = "".encode("utf-8")
|
||||
|
||||
rkllm_param.extend_param.base_domain_id = 0
|
||||
|
||||
self.handle = RKLLM_Handle_t()
|
||||
|
||||
self.rkllm_init = rkllm_lib.rkllm_init
|
||||
self.rkllm_init.argtypes = [
|
||||
ctypes.POINTER(RKLLM_Handle_t),
|
||||
ctypes.POINTER(RKLLMParam),
|
||||
callback_type,
|
||||
]
|
||||
self.rkllm_init.restype = ctypes.c_int
|
||||
self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
|
||||
|
||||
self.rkllm_run = rkllm_lib.rkllm_run
|
||||
self.rkllm_run.argtypes = [
|
||||
RKLLM_Handle_t,
|
||||
ctypes.POINTER(RKLLMInput),
|
||||
ctypes.POINTER(RKLLMInferParam),
|
||||
ctypes.c_void_p,
|
||||
]
|
||||
self.rkllm_run.restype = ctypes.c_int
|
||||
|
||||
self.rkllm_run_async = rkllm_lib.rkllm_run_async
|
||||
self.rkllm_run_async.argtypes = [
|
||||
RKLLM_Handle_t,
|
||||
ctypes.POINTER(RKLLMInput),
|
||||
ctypes.POINTER(RKLLMInferParam),
|
||||
ctypes.c_void_p,
|
||||
]
|
||||
self.rkllm_run_async.restype = ctypes.c_int
|
||||
|
||||
self.rkllm_abort = rkllm_lib.rkllm_abort
|
||||
self.rkllm_abort.argtypes = [RKLLM_Handle_t]
|
||||
self.rkllm_abort.restype = ctypes.c_int
|
||||
|
||||
self.rkllm_destroy = rkllm_lib.rkllm_destroy
|
||||
self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
|
||||
self.rkllm_destroy.restype = ctypes.c_int
|
||||
|
||||
self.lora_adapter_path = None
|
||||
self.lora_model_name = None
|
||||
if lora_model_path:
|
||||
self.lora_adapter_path = lora_model_path
|
||||
self.lora_adapter_name = "test"
|
||||
|
||||
lora_adapter = RKLLMLoraAdapter()
|
||||
ctypes.memset(
|
||||
ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter)
|
||||
)
|
||||
lora_adapter.lora_adapter_path = ctypes.c_char_p(
|
||||
(self.lora_adapter_path).encode("utf-8")
|
||||
)
|
||||
lora_adapter.lora_adapter_name = ctypes.c_char_p(
|
||||
(self.lora_adapter_name).encode("utf-8")
|
||||
)
|
||||
lora_adapter.scale = 1.0
|
||||
|
||||
rkllm_load_lora = rkllm_lib.rkllm_load_lora
|
||||
rkllm_load_lora.argtypes = [
|
||||
RKLLM_Handle_t,
|
||||
ctypes.POINTER(RKLLMLoraAdapter),
|
||||
]
|
||||
rkllm_load_lora.restype = ctypes.c_int
|
||||
rkllm_load_lora(self.handle, ctypes.byref(lora_adapter))
|
||||
|
||||
self.prompt_cache_path = None
|
||||
if prompt_cache_path:
|
||||
self.prompt_cache_path = prompt_cache_path
|
||||
|
||||
rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
|
||||
rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
|
||||
rkllm_load_prompt_cache.restype = ctypes.c_int
|
||||
rkllm_load_prompt_cache(
|
||||
self.handle, ctypes.c_char_p((prompt_cache_path).encode("utf-8"))
|
||||
)
|
||||
|
||||
def run(self, prompt):
|
||||
rkllm_lora_params = None
|
||||
if self.lora_model_name:
|
||||
rkllm_lora_params = RKLLMLoraParam()
|
||||
rkllm_lora_params.lora_adapter_name = ctypes.c_char_p(
|
||||
(self.lora_model_name).encode("utf-8")
|
||||
)
|
||||
|
||||
rkllm_infer_params = RKLLMInferParam()
|
||||
ctypes.memset(
|
||||
ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam)
|
||||
)
|
||||
rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
|
||||
rkllm_infer_params.lora_params = (
|
||||
ctypes.byref(rkllm_lora_params) if rkllm_lora_params else None
|
||||
)
|
||||
|
||||
rkllm_input = RKLLMInput()
|
||||
rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
|
||||
rkllm_input.input_data.prompt_input = ctypes.c_char_p((prompt).encode("utf-8"))
|
||||
self.rkllm_run(
|
||||
self.handle,
|
||||
ctypes.byref(rkllm_input),
|
||||
ctypes.byref(rkllm_infer_params),
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
def abort(self):
|
||||
return self.rkllm_abort(self.handle)
|
||||
|
||||
def release(self):
|
||||
self.rkllm_destroy(self.handle)
|
||||
Loading…
Add table
Add a link
Reference in a new issue