remove reference interface from extern C in qwen2audio examples

This commit is contained in:
李为 2024-11-21 20:10:27 +08:00
parent fe792d62b1
commit fd2c58286a
3 changed files with 44 additions and 37 deletions

View file

@ -1,9 +1,13 @@
#include "qwen2.h"
#include <iostream>
using std::cout;
using std::endl;
int main(int argc, char **argv)
{
omni_context_params ctx_params = omni_context_default_params();
omni_context_params * ctx_params = omni_context_default_params();
if (!omni_context_params_parse(argc, argv, ctx_params))
{
return 1;
@ -11,7 +15,9 @@ int main(int argc, char **argv)
omni_context *ctx_omni = omni_init_context(ctx_params);
omni_process_full(ctx_omni, ctx_params);
auto* ret_str = omni_process_full(ctx_omni, ctx_params);
cout << "RET: " << ret_str << endl;
omni_free(ctx_omni);

View file

@ -27,10 +27,26 @@ void* internal_chars = nullptr;
static const char *AUDIO_TOKEN = "<|AUDIO|>";
struct omni_context_params
{
const char *model;
const char *mmproj;
const char *file;
const char *prompt;
int32_t n_gpu_layers;
};
struct omni_context
{
struct whisper_context *ctx_whisper;
struct audio_projector *projector;
struct llama_context *ctx_llama;
struct llama_model *model;
};
//
// Whisper
//
struct whisper_params
{
int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
@ -476,8 +492,9 @@ static void omni_print_usage(int, char **argv)
LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
}
bool omni_context_params_parse(int argc, char **argv, omni_context_params &params)
bool omni_context_params_parse(int argc, char **argv, omni_context_params * in_params)
{
auto& params = *in_params;
for (int i = 1; i < argc; i++)
{
std::string arg = argv[i];
@ -523,15 +540,15 @@ bool omni_context_params_parse(int argc, char **argv, omni_context_params &param
return true;
}
omni_context_params omni_context_default_params()
omni_context_params * omni_context_default_params()
{
omni_context_params params;
static omni_context_params params;
params.model = "";
params.mmproj = "";
params.file = "";
params.prompt = "this conversation talks about";
params.n_gpu_layers = -1;
return params;
return &params;
}
struct omni_params
@ -540,8 +557,9 @@ struct omni_params
whisper_params whisper;
};
bool omni_params_parse(int argc, char **argv, omni_params &params)
bool omni_params_parse(int argc, char **argv, omni_params * in_params)
{
auto& params = *in_params;
if (!gpt_params_parse(argc, argv, params.gpt))
{
return false;
@ -564,8 +582,9 @@ bool omni_params_parse(int argc, char **argv, omni_params &params)
return true;
}
static omni_params get_omni_params_from_context_params(omni_context_params &params)
static omni_params get_omni_params_from_context_params(omni_context_params * in_params)
{
auto& params = *in_params;
omni_params all_params;
// Initialize gpt params
@ -639,10 +658,9 @@ static size_t find_audio_token(const std::string &prompt)
return prompt.find(AUDIO_TOKEN);
}
struct omni_context *omni_init_context(omni_context_params &params)
struct omni_context *omni_init_context(omni_context_params * in_params)
{
omni_params all_params = get_omni_params_from_context_params(params);
omni_params all_params = get_omni_params_from_context_params(in_params);
// llama
LLAMA_LOG_INFO("------- llama --------\n");
@ -877,7 +895,7 @@ const char* omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audi
return (const char*)(internal_chars);
}
const char* omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
const char* omni_process_full(struct omni_context *ctx_omni, omni_context_params * params)
{
omni_params all_params = get_omni_params_from_context_params(params);

View file

@ -29,34 +29,17 @@
extern "C" {
#endif
struct omni_context_params
{
const char *model;
const char *mmproj;
const char *file;
const char *prompt;
int32_t n_gpu_layers;
};
OMNI_AUDIO_API bool omni_context_params_parse(int argc, char **argv, struct omni_context_params * params);
struct omni_context
{
struct whisper_context *ctx_whisper;
struct audio_projector *projector;
struct llama_context *ctx_llama;
struct llama_model *model;
};
OMNI_AUDIO_API struct omni_context_params * omni_context_default_params();
OMNI_AUDIO_API bool omni_context_params_parse(int argc, char **argv, omni_context_params &params);
OMNI_AUDIO_API omni_context_params omni_context_default_params();
OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &params);
OMNI_AUDIO_API struct omni_context * omni_init_context(struct omni_context_params * params);
OMNI_AUDIO_API void omni_free(struct omni_context * ctx_omni);
OMNI_AUDIO_API const char* omni_process_full(
struct omni_context * ctx_omni,
omni_context_params &params
struct omni_context_params * params
);
#ifdef __cplusplus