"""ctypes bindings for the RKLLM runtime library and a small ``RKLLM`` wrapper.

The ``ctypes.Structure`` layouts below are passed to the runtime by pointer,
so their field order and types must match the library's C declarations
exactly.
"""

import ctypes

# Set the dynamic library path.
# NOTE: this relative path is resolved against the process's current working
# directory (not this file's directory), and the library is loaded at import.
rkllm_lib = ctypes.CDLL("lib/librkllmrt.so")

# Define the structures from the library
RKLLM_Handle_t = ctypes.c_void_p
userdata = ctypes.c_void_p(None)

# NOTE(review): LLMCallState, RKLLMInputMode and RKLLMInferMode are all bound
# to the same object, ``ctypes.c_int``, so every constant below is stored as
# an attribute of ``ctypes.c_int`` itself (shared process-wide). Left as-is
# because callers may use these names both as constant namespaces and as
# ctypes types.
LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL = 0
LLMCallState.RKLLM_RUN_WAITING = 1
LLMCallState.RKLLM_RUN_FINISH = 2
LLMCallState.RKLLM_RUN_ERROR = 3
LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4

RKLLMInputMode = ctypes.c_int
RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
RKLLMInputMode.RKLLM_INPUT_EMBED = 2
RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3

RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1


class RKLLMExtendParam(ctypes.Structure):
    """Extended init parameters (embedded in ``RKLLMParam.extend_param``)."""

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("reserved", ctypes.c_uint8 * 112),
    ]


class RKLLMParam(ctypes.Structure):
    """Model/sampling parameters passed to ``rkllm_init``."""

    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        ("extend_param", RKLLMExtendParam),
    ]


class RKLLMLoraAdapter(ctypes.Structure):
    """LoRA adapter description passed to ``rkllm_load_lora``."""

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float),
    ]


class RKLLMEmbedInput(ctypes.Structure):
    """Embedding input: pointer to float data plus a token count."""

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t),
    ]


class RKLLMTokenInput(ctypes.Structure):
    """Token-id input: pointer to int32 ids plus a token count."""

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t),
    ]


class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal input: text prompt plus image embedding data."""

    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
    ]


class RKLLMInputUnion(ctypes.Union):
    """Input payload; which member is valid is selected by ``RKLLMInput.input_mode``."""

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput),
    ]


class RKLLMInput(ctypes.Structure):
    """Tagged input passed to ``rkllm_run`` / ``rkllm_run_async``."""

    _fields_ = [("input_mode", ctypes.c_int), ("input_data", RKLLMInputUnion)]


class RKLLMLoraParam(ctypes.Structure):
    """Selects a previously loaded LoRA adapter by name for one inference."""

    _fields_ = [("lora_adapter_name", ctypes.c_char_p)]


class RKLLMPromptCacheParam(ctypes.Structure):
    """Prompt-cache options for one inference."""

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p),
    ]


class RKLLMInferParam(ctypes.Structure):
    """Per-inference parameters; pointer fields may be left NULL."""

    _fields_ = [
        ("mode", RKLLMInferMode),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
    ]


class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Hidden-state output: float pointer plus embedding size and token count."""

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]


class RKLLMResult(ctypes.Structure):
    """Result handed to the callback."""

    _fields_ = [
        ("text", ctypes.c_char_p),
        ("size", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
    ]


# Result callback signature: void (*)(RKLLMResult*, void* userdata, int state).
callback_type = ctypes.CFUNCTYPE(
    None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int
)


# Define the RKLLM class, which includes initialization, inference, and
# release operations for the RKLLM model in the dynamic library
class RKLLM(object):
    """Owns one RKLLM runtime handle: init, run, abort and destroy."""

    def __init__(
        self,
        model_path,
        callback,
        lora_model_path=None,
        prompt_cache_path=None,
    ):
        """Initialize the runtime handle.

        Optionally loads a LoRA adapter and a prompt cache as well.

        Args:
            model_path: Model file path, passed to ``rkllm_init``.
            callback: Result callback. Either a ``callback_type`` instance or
                a plain Python callable ``(result_ptr, userdata, state)``. A
                plain callable is wrapped in ``callback_type``. A reference is
                kept on the instance so the native library never calls into a
                garbage-collected function object.
            lora_model_path: Optional LoRA adapter path. It is loaded under
                the adapter name ``"test"`` and selected by ``run``.
            prompt_cache_path: Optional prompt-cache file to load.

        Raises:
            RuntimeError: ``rkllm_init``, ``rkllm_load_lora`` or
                ``rkllm_load_prompt_cache`` returned a non-zero status. For
                the two load calls, the handle is released before raising.
        """
        rkllm_param = RKLLMParam()
        rkllm_param.model_path = bytes(model_path, "utf-8")
        rkllm_param.max_context_len = 10000
        rkllm_param.max_new_tokens = -1  # NOTE(review): presumably "no limit"; confirm
        rkllm_param.skip_special_token = True
        rkllm_param.top_k = 20
        rkllm_param.top_p = 0.8
        rkllm_param.temperature = 0.7
        rkllm_param.repeat_penalty = 1.1
        rkllm_param.frequency_penalty = 0.0
        rkllm_param.presence_penalty = 0.0
        rkllm_param.mirostat = 0
        rkllm_param.mirostat_tau = 5.0
        rkllm_param.mirostat_eta = 0.1
        rkllm_param.is_async = False
        rkllm_param.img_start = "".encode("utf-8")
        rkllm_param.img_end = "".encode("utf-8")
        rkllm_param.img_content = "".encode("utf-8")
        rkllm_param.extend_param.base_domain_id = 0
        # Keep the struct (and the byte strings it points to) alive with the
        # instance, in case the runtime keeps pointers into it after init.
        self._rkllm_param = rkllm_param

        # The runtime stores this function pointer, so the Python object must
        # outlive the handle; otherwise it may be garbage-collected while C
        # code can still invoke it.
        if callback is not None and not isinstance(callback, callback_type):
            callback = callback_type(callback)
        self._callback = callback

        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [
            ctypes.POINTER(RKLLM_Handle_t),
            ctypes.POINTER(RKLLMParam),
            callback_type,
        ]
        self.rkllm_init.restype = ctypes.c_int

        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [
            RKLLM_Handle_t,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p,
        ]
        self.rkllm_run.restype = ctypes.c_int

        self.rkllm_run_async = rkllm_lib.rkllm_run_async
        self.rkllm_run_async.argtypes = [
            RKLLM_Handle_t,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p,
        ]
        self.rkllm_run_async.restype = ctypes.c_int

        self.rkllm_abort = rkllm_lib.rkllm_abort
        self.rkllm_abort.argtypes = [RKLLM_Handle_t]
        self.rkllm_abort.restype = ctypes.c_int

        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

        self.handle = RKLLM_Handle_t()
        ret = self.rkllm_init(
            ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback
        )
        if ret != 0:
            # Assumes 0 means success (C-API convention; confirm in rkllm.h).
            # Do not continue with a handle that failed to initialize.
            raise RuntimeError("rkllm_init failed with status %d" % ret)

        self.lora_adapter_path = None
        self.lora_adapter_name = None
        # Name that run() passes in RKLLMLoraParam; set only after a
        # successful load.
        self.lora_model_name = None
        if lora_model_path:
            self.lora_adapter_path = lora_model_path
            self.lora_adapter_name = "test"
            lora_adapter = RKLLMLoraAdapter()  # ctypes zero-initializes it
            lora_adapter.lora_adapter_path = self.lora_adapter_path.encode("utf-8")
            lora_adapter.lora_adapter_name = self.lora_adapter_name.encode("utf-8")
            lora_adapter.scale = 1.0
            rkllm_load_lora = rkllm_lib.rkllm_load_lora
            rkllm_load_lora.argtypes = [
                RKLLM_Handle_t,
                ctypes.POINTER(RKLLMLoraAdapter),
            ]
            rkllm_load_lora.restype = ctypes.c_int
            self._release_on_error(
                rkllm_load_lora(self.handle, ctypes.byref(lora_adapter)),
                "rkllm_load_lora",
            )
            # Previously this stayed None, so run() never selected the adapter.
            self.lora_model_name = self.lora_adapter_name

        self.prompt_cache_path = None
        if prompt_cache_path:
            self.prompt_cache_path = prompt_cache_path
            rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
            rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
            rkllm_load_prompt_cache.restype = ctypes.c_int
            self._release_on_error(
                rkllm_load_prompt_cache(
                    self.handle, prompt_cache_path.encode("utf-8")
                ),
                "rkllm_load_prompt_cache",
            )

    def _release_on_error(self, ret, func_name):
        """Release the handle and raise RuntimeError if ``ret`` is non-zero."""
        if ret != 0:
            self.release()
            raise RuntimeError("%s failed with status %d" % (func_name, ret))

    def run(self, prompt):
        """Run generation on a text prompt via ``rkllm_run``.

        Output is delivered through the callback given to ``__init__``. The
        loaded LoRA adapter, if any, is selected by name. The status returned
        by ``rkllm_run`` is not propagated; this method returns None.

        Args:
            prompt: Prompt text; it is encoded as UTF-8.
        """
        rkllm_lora_params = None
        if self.lora_model_name:
            rkllm_lora_params = RKLLMLoraParam()
            rkllm_lora_params.lora_adapter_name = self.lora_model_name.encode(
                "utf-8"
            )

        # ctypes zero-initializes the struct, so unused pointer fields are NULL.
        rkllm_infer_params = RKLLMInferParam()
        rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        if rkllm_lora_params is not None:
            # A POINTER field needs a real pointer object. ctypes.byref()
            # is only valid as a foreign-function argument and is rejected
            # with TypeError on structure field assignment.
            rkllm_infer_params.lora_params = ctypes.pointer(rkllm_lora_params)

        prompt_bytes = prompt.encode("utf-8")  # kept alive for the call
        rkllm_input = RKLLMInput()
        rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
        rkllm_input.input_data.prompt_input = prompt_bytes

        self.rkllm_run(
            self.handle,
            ctypes.byref(rkllm_input),
            ctypes.byref(rkllm_infer_params),
            None,
        )
        return

    def abort(self):
        """Call ``rkllm_abort`` on the handle and return its status code."""
        return self.rkllm_abort(self.handle)

    def release(self):
        """Destroy the native handle; calling it again is a no-op."""
        if self.handle:
            self.rkllm_destroy(self.handle)
            # Null the handle so a second release() cannot destroy it twice.
            self.handle = RKLLM_Handle_t()