return the generated string (const char*) for qwen2 audio

李为 2024-11-06 15:23:59 +08:00
parent 5574bda471
commit b24a409e22
2 changed files with 23 additions and 11 deletions


@@ -22,6 +22,7 @@
 //
 // Constants
 //
+void* internal_chars = nullptr;
 static const char *AUDIO_TOKEN = "<|AUDIO|>";
@@ -565,16 +566,16 @@ bool omni_params_parse(int argc, char **argv, omni_params &params)
 static omni_params get_omni_params_from_context_params(omni_context_params &params)
 {
     omni_params all_params;
+    // Initialize gpt params
     all_params.gpt.n_gpu_layers = params.n_gpu_layers;
     all_params.gpt.model = params.model;
     all_params.gpt.prompt = params.prompt;
+    // Initialize whisper params
     all_params.whisper.model = params.mmproj;
     all_params.whisper.fname_inp = {params.file};
     if (all_params.gpt.n_threads <= 0)
     {
         all_params.gpt.n_threads = std::thread::hardware_concurrency();
@@ -703,6 +704,11 @@ struct omni_context *omni_init_context(omni_context_params &params)
 void omni_free(struct omni_context *ctx_omni)
 {
+    if (internal_chars != nullptr)
+    {
+        free(internal_chars);
+    }
     if (ctx_omni->ctx_whisper)
     {
         whisper_free(ctx_omni->ctx_whisper);
@@ -792,7 +798,7 @@ ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params &params)
     return embed_proj;
 }
-void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
+const char* omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
 {
     int n_past = 0;
@@ -841,12 +847,11 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
     for (int i = 0; i < max_tgt_len; i++)
     {
         const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past);
-        response += tmp;
         if (strcmp(tmp, "</s>") == 0)
             break;
         if (strstr(tmp, "###"))
             break; // Yi-VL behavior
-        printf("%s", tmp);
+        // printf("%s", tmp);
         if (strstr(response.c_str(), "<|im_end|>"))
             break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
         if (strstr(response.c_str(), "<|im_start|>"))
@@ -855,16 +860,23 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
             break; // mistral llava-1.6
         fflush(stdout);
+        response += tmp;
     }
     llama_sampling_free(ctx_sampling);
     printf("\n");
+    if (internal_chars != nullptr) { free(internal_chars); }
+    internal_chars = malloc(sizeof(char)*(response.size()+1));
+    strncpy((char*)(internal_chars), response.c_str(), response.size());
+    ((char*)(internal_chars))[response.size()] = '\0';
+    return (const char*)(internal_chars);
 }
-void omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
+const char* omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
 {
     omni_params all_params = get_omni_params_from_context_params(params);
     ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params);
-    omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
-}
+    return omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
+}
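
The return path added above follows a common C-API pattern: the C++ std::string response is copied into the heap buffer internal_chars, which the library owns, and the caller receives a read-only pointer into it. A minimal sketch of the same pattern in isolation (the names result_buf and publish_result are hypothetical, not part of this codebase):

    #include <cstdlib>
    #include <cstring>
    #include <string>

    static void *result_buf = nullptr; // plays the role of internal_chars

    const char *publish_result(const std::string &s)
    {
        if (result_buf != nullptr)
            free(result_buf);                         // drop the previous result
        result_buf = malloc(s.size() + 1);            // +1 for the terminator
        memcpy(result_buf, s.c_str(), s.size() + 1);  // copy including '\0'
        return (const char *)result_buf;              // valid until the next call
    }

The pointer returned by omni_process_prompt is therefore invalidated by the next call and released by omni_free; a caller that needs the text to outlive either must copy it first.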


@@ -54,11 +54,11 @@ OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &params);
 OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);
-OMNI_AUDIO_API void omni_process_full(
+OMNI_AUDIO_API const char* omni_process_full(
     struct omni_context *ctx_omni,
     omni_context_params &params
 );
 #ifdef __cplusplus
 }
 #endif
 #endif
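
With the header updated, omni_process_full hands back the generated text instead of only printing it. A hedged caller-side sketch (the header name "omni.h", the file paths, and the prompt are placeholder assumptions; the field names come from the diff above):

    #include <cstdio>
    #include <string>
    #include "omni.h" // assumed name of the header changed above

    int main()
    {
        omni_context_params params{};
        params.model        = "qwen2-audio-llm.gguf"; // placeholder path
        params.mmproj       = "whisper-encoder.gguf"; // placeholder path
        params.file         = "sample.wav";           // placeholder path
        params.prompt       = "Describe this audio clip.";
        params.n_gpu_layers = 0;

        omni_context *ctx = omni_init_context(params);
        const char *text  = omni_process_full(ctx, params);

        // The pointer aliases the library-owned internal_chars buffer,
        // so copy it before the next call or before omni_free.
        std::string result = text ? text : "";
        omni_free(ctx);

        printf("%s\n", result.c_str());
        return 0;
    }

Returning a library-owned buffer keeps the C interface free of caller-side deallocation, at the cost of thread safety: each call overwrites the previous result.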