add returned string type (const char*) for nexa-omni-audio

parent 6a4cf0b983
commit 5edadffd88

2 changed files with 20 additions and 9 deletions
@@ -23,6 +23,8 @@
 // Constants
 //
 
+void* internal_chars = nullptr;
+
 static const char *AUDIO_TOKEN = "<|AUDIO|>";
 
 //
@@ -570,7 +572,7 @@ static omni_params get_omni_params_from_context_params(omni_context_params &params)
     all_params.gpt.n_gpu_layers = params.n_gpu_layers;
     all_params.gpt.model = params.model;
     all_params.gpt.prompt = params.prompt;
 
     // Initialize whisper params
     all_params.whisper.model = params.mmproj;
     all_params.whisper.fname_inp = {params.file};
@@ -703,6 +705,10 @@ struct omni_context *omni_init_context(omni_context_params &params)
 
 void omni_free(struct omni_context *ctx_omni)
 {
+    if(internal_chars != nullptr)
+    {
+        free(internal_chars);
+    }
     if (ctx_omni->ctx_whisper)
     {
         whisper_free(ctx_omni->ctx_whisper);
@@ -792,7 +798,7 @@ ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params &params)
     return embed_proj;
 }
 
-void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
+const char* omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
 {
     int n_past = 0;
 
@@ -833,12 +839,11 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
     for (int i = 0; i < max_tgt_len; i++)
     {
         const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past);
-        response += tmp;
         if (strcmp(tmp, "</s>") == 0)
             break;
         if (strstr(tmp, "###"))
             break; // Yi-VL behavior
-        printf("%s", tmp);
+        // printf("%s", tmp);
         if (strstr(response.c_str(), "<|im_end|>"))
             break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
         if (strstr(response.c_str(), "<|im_start|>"))
@@ -847,16 +852,22 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
             break; // mistral llava-1.6
 
         fflush(stdout);
+        response += tmp;
     }
 
     llama_sampling_free(ctx_sampling);
     printf("\n");
+    if(internal_chars != nullptr) { free(internal_chars); }
+    internal_chars = malloc(sizeof(char)*(response.size()+1));
+    strncpy((char*)(internal_chars), response.c_str(), response.size());
+    ((char*)(internal_chars))[response.size()] = '\0';
+    return (const char*)(internal_chars);
 }
 
-void omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
+const char* omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
 {
     omni_params all_params = get_omni_params_from_context_params(params);
 
     ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params);
-    omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
+    return omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
 }
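The new return path copies `response` into a module-level buffer (`internal_chars`) that is overwritten on every call and released in `omni_free`. A minimal standalone sketch of that ownership pattern follows; `publish_response` and `main` are hypothetical names used only for illustration and are not part of the library.

// Standalone sketch of the buffer-ownership pattern introduced above.
// publish_response() is a hypothetical helper, not part of the library.
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>

static void *internal_chars = nullptr; // reused across calls, freed at teardown

static const char *publish_response(const std::string &response)
{
    if (internal_chars != nullptr) { free(internal_chars); }       // drop the previous result
    internal_chars = malloc(response.size() + 1);
    memcpy(internal_chars, response.c_str(), response.size() + 1); // copy including '\0'
    return (const char *)internal_chars;
}

int main()
{
    const char *a = publish_response("first answer");
    printf("%s\n", a);     // valid until the next call
    const char *b = publish_response("second answer");
    printf("%s\n", b);     // 'a' now dangles: the buffer was reallocated
    free(internal_chars);  // mirrors the cleanup added to omni_free
    internal_chars = nullptr;
    return 0;
}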
@@ -54,11 +54,11 @@ OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &params);
 
 OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);
 
-OMNI_AUDIO_API void omni_process_full(
+OMNI_AUDIO_API const char* omni_process_full(
     struct omni_context *ctx_omni,
     omni_context_params &params
 );
 
 #ifdef __cplusplus
 }
 #endif
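On the caller side, the `const char*` returned by `omni_process_full` (and `omni_process_prompt`) aliases the library's internal buffer, so it stays valid only until the next call or `omni_free`. A sketch of how a caller might take ownership of the text, assuming the library's public declarations are in scope (the header name is not shown in this diff):

// Sketch only: assumes the omni-audio declarations are already included.
#include <string>

std::string transcribe_once(struct omni_context *ctx, omni_context_params &params)
{
    const char *raw = omni_process_full(ctx, params); // points into the library's buffer
    return raw ? std::string(raw) : std::string();    // copy before the buffer is reused
}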