From b24a409e22dc49daa7f7cb422492281403dfb239 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=BA?=
Date: Wed, 6 Nov 2024 15:23:59 +0800
Subject: [PATCH] add returned string (const char*) for qwen2 audio

---
 examples/qwen2-audio/qwen2.cpp | 30 +++++++++++++++++++++---------
 examples/qwen2-audio/qwen2.h   |  4 ++--
 2 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/examples/qwen2-audio/qwen2.cpp b/examples/qwen2-audio/qwen2.cpp
index be7d74d6d..71b2a4a12 100644
--- a/examples/qwen2-audio/qwen2.cpp
+++ b/examples/qwen2-audio/qwen2.cpp
@@ -22,6 +22,7 @@
 //
 // Constants
 //
+void* internal_chars = nullptr;
 
 static const char *AUDIO_TOKEN = "<|AUDIO|>";
 
@@ -565,16 +566,16 @@ bool omni_params_parse(int argc, char **argv, omni_params &params)
 static omni_params get_omni_params_from_context_params(omni_context_params &params)
 {
     omni_params all_params;
-    
+
     // Initialize gpt params
     all_params.gpt.n_gpu_layers = params.n_gpu_layers;
     all_params.gpt.model = params.model;
     all_params.gpt.prompt = params.prompt;
-    
+
     // Initialize whisper params
     all_params.whisper.model = params.mmproj;
     all_params.whisper.fname_inp = {params.file};
-    
+
     if (all_params.gpt.n_threads <= 0)
     {
         all_params.gpt.n_threads = std::thread::hardware_concurrency();
@@ -703,6 +704,11 @@ struct omni_context *omni_init_context(omni_context_params &params)
 
 void omni_free(struct omni_context *ctx_omni)
 {
+
+    if(internal_chars != nullptr)
+    {
+        free(internal_chars);
+    }
     if (ctx_omni->ctx_whisper)
     {
         whisper_free(ctx_omni->ctx_whisper);
@@ -792,7 +798,7 @@ ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params &para
     return embed_proj;
 }
 
-void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
+const char* omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
 {
     int n_past = 0;
 
@@ -841,12 +847,11 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed
     for (int i = 0; i < max_tgt_len; i++)
     {
         const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past);
-        response += tmp;
         if (strcmp(tmp, "</s>") == 0)
             break;
         if (strstr(tmp, "###"))
             break; // Yi-VL behavior
-        printf("%s", tmp);
+        // printf("%s", tmp);
         if (strstr(response.c_str(), "<|im_end|>"))
             break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
         if (strstr(response.c_str(), "<|im_start|>"))
             break; // mistral llava-1.6
 
         fflush(stdout);
+        response += tmp;
     }
 
     llama_sampling_free(ctx_sampling);
     printf("\n");
+
+    if(internal_chars != nullptr) { free(internal_chars); }
+    internal_chars = malloc(sizeof(char)*(response.size()+1));
+    strncpy((char*)(internal_chars), response.c_str(), response.size());
+    ((char*)(internal_chars))[response.size()] = '\0';
+    return (const char*)(internal_chars);
 }
 
-void omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
+const char* omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
 {
     omni_params all_params = get_omni_params_from_context_params(params);
 
     ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params);
 
-    omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
-}
\ No newline at end of file
+    return omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
+}
diff --git a/examples/qwen2-audio/qwen2.h b/examples/qwen2-audio/qwen2.h
index 5cbbd52ed..dcadb4288 100644
--- a/examples/qwen2-audio/qwen2.h
+++ b/examples/qwen2-audio/qwen2.h
@@ -54,11 +54,11 @@ OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &param
 
 OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);
 
-OMNI_AUDIO_API void omni_process_full(
+OMNI_AUDIO_API const char* omni_process_full(
     struct omni_context *ctx_omni,
     omni_context_params &params
 );
 
 #ifdef __cplusplus
 }
-#endif
\ No newline at end of file
+#endif
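
For reference, here is a minimal caller sketch (not part of the patch) showing how the new `const char*` return of `omni_process_full` might be consumed. It only uses functions and members that appear in the diff (`omni_init_context`, `omni_process_full`, `omni_free`, and the `model` / `mmproj` / `file` / `prompt` / `n_gpu_layers` fields of `omni_context_params`); the aggregate initialization, the string-like field types, and the file paths are assumptions for illustration.

    // Hypothetical caller for the patched API -- a sketch, not part of the patch.
    // Field names are taken from the diff; paths and prompt are placeholders.
    #include <cstdio>
    #include <string>

    #include "qwen2.h"  // examples/qwen2-audio/qwen2.h

    int main() {
        omni_context_params params{};                     // assumed default-initializable
        params.model        = "qwen2-audio-llm.gguf";     // placeholder LLM path
        params.mmproj       = "qwen2-audio-encoder.gguf"; // placeholder audio encoder/projector path
        params.file         = "sample.wav";               // input audio file
        params.prompt       = "Describe this audio.";
        params.n_gpu_layers = 0;

        omni_context *ctx = omni_init_context(params);
        if (ctx == nullptr) {
            fprintf(stderr, "failed to init context\n");
            return 1;
        }

        // The returned pointer refers to a library-owned buffer (the global
        // internal_chars); it is freed on the next generation call and in
        // omni_free, so copy the text if it must outlive those calls.
        const char *text = omni_process_full(ctx, params);
        std::string response = (text != nullptr) ? text : "";
        printf("%s\n", response.c_str());

        omni_free(ctx);
        return 0;
    }

Because the response is kept in a single process-global buffer, each call to `omni_process_full` invalidates the string returned by the previous one, and concurrent calls from multiple threads would race on `internal_chars`; callers that need to keep several responses should copy them into their own storage, as the `std::string` above does.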