support omni-audio

Zack Zhiyuan Li 2024-11-03 17:58:08 +00:00
parent 4a29bca867
commit c7b912bdca
19 changed files with 20695 additions and 39 deletions


@@ -66,6 +66,9 @@ add_library(${TARGET} STATIC
train.cpp
ngram-cache.h
ngram-cache.cpp
common-nexa.h
common-nexa.cpp
dr_wav.h
)
if (BUILD_SHARED_LIBS)

common/common-nexa.cpp (Normal file, 317 lines)
@@ -0,0 +1,317 @@
#include "common-nexa.h"
#include <thread>
#include <vector>
#include <string.h>
#include <functional>
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <algorithm>
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif
#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif
#include "common.h"
#include <cmath>
#include <numeric>
void print_ggml_tensor(const char *name, const struct ggml_tensor *tensor, bool use_backend, int precision) {
std::vector<float> data(ggml_nelements(tensor));
if (use_backend) {
ggml_backend_tensor_get(tensor, data.data(), 0, ggml_nbytes(tensor));
} else {
memcpy(data.data(), ggml_get_data_f32(tensor), ggml_nbytes(tensor));
}
std::vector<int64_t> shape;
for (int i = 0; i < GGML_MAX_DIMS && tensor->ne[i] > 1; ++i) shape.push_back(tensor->ne[i]);
print_ggml_tensor_shape(name, tensor);
size_t offset = 0;
std::function<void(size_t, size_t &)> print_recursive = [&](size_t dim, size_t &offset) {
if (dim == shape.size()) {
printf("%.*f", precision, data[offset++]);
} else {
printf("[ ");
for (int64_t i = 0; i < shape[dim]; ++i) {
if (i > 0) printf(dim == shape.size() - 1 ? ", " : ",\n%*s", static_cast<int>(dim) + 1, "");
print_recursive(dim + 1, offset);
}
printf("]");
}
};
print_recursive(0, offset);
printf("\n");
}
void print_ggml_tensor_shape(const char *name, const struct ggml_tensor *tensor) {
printf("%s: [ ", name);
for (int i = 0; i < GGML_MAX_DIMS && tensor->ne[i] > 1; ++i) printf("%d ", static_cast<int>(tensor->ne[i]));
printf("]\n");
}
bool load_hparams_and_tensors_from_gguf(const std::string &fname, NexaBaseModel &model, bool verbose)
{
// Initialize GGUF context
ggml_context *meta = nullptr;
gguf_init_params params = {true, &meta};
gguf_context *ctx_gguf = gguf_init_from_file(fname.c_str(), params);
// Check if GGUF context initialization was successful
if (!ctx_gguf)
return fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__), false;
// Get the number of tensors in the GGUF file
const int n_tensors = gguf_get_n_tensors(ctx_gguf);
const int n_tensors_in_model = model.tensor_names.size();
if (n_tensors_in_model > n_tensors)
{
fprintf(stderr, "%s: model tensor_names size (%d) is greater than the number of tensors in the GGUF file (%d)\n", __func__, n_tensors_in_model, n_tensors);
gguf_free(ctx_gguf);
return false;
}
// Load hyperparameters
for (const auto &name : model.hparam_names)
{
int key = gguf_find_key(ctx_gguf, name.c_str());
if (key != -1)
{
model.hparams[name] = gguf_get_val_i32(ctx_gguf, key);
if (verbose)
fprintf(stderr, "%s: loaded hparam '%s' = %d\n", __func__, name.c_str(), std::get<int32_t>(model.hparams[name]));
}
else
return fprintf(stderr, "%s: failed to load hparam '%s'\n", __func__, name.c_str()), gguf_free(ctx_gguf), false;
}
// Initialize GGML context for tensor data
model.ctx_data = ggml_init({(n_tensors_in_model + 1) * ggml_tensor_overhead(), nullptr, true});
if (!model.ctx_data)
return fprintf(stderr, "%s: ggml_init() failed\n", __func__), gguf_free(ctx_gguf), false;
// Open the GGUF file for reading tensor data
std::ifstream fin(fname, std::ios::binary);
if (!fin)
return fprintf(stderr, "%s: cannot open model file for loading tensors\n", __func__), gguf_free(ctx_gguf), false;
// Create tensor structures in the GGML context
for (const auto &name : model.tensor_names)
{
ggml_tensor *t = ggml_dup_tensor(model.ctx_data, ggml_get_tensor(meta, name.c_str()));
ggml_set_name(t, name.c_str());
}
// Allocate memory for tensors using the specified backend
model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx_data, model.backend);
// Load tensors from the GGUF file
for (const auto &name : model.tensor_names)
{
char *tensor_name = const_cast<char *>(name.c_str());
if (verbose)
fprintf(stderr, "%s: loading tensor '%s'\n", __func__, tensor_name);
ggml_tensor *cur = ggml_get_tensor(model.ctx_data, tensor_name);
if (!cur)
return fprintf(stderr, "%s: failed to get tensor %s\n", __func__, tensor_name), gguf_free(ctx_gguf), false;
int tensor_idx = gguf_find_tensor(ctx_gguf, tensor_name);
fin.seekg(gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, tensor_idx), std::ios::beg);
int num_bytes = ggml_nbytes(cur);
if (ggml_backend_buffer_is_host(model.buffer))
{
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
}
else
{
std::vector<uint8_t> read_buf(num_bytes);
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
}
// mapping tensor name to tensor pointer
model.tensors[name] = cur;
if (verbose)
{
fprintf(stderr, "%s: mapped tensor ", __func__);
print_ggml_tensor_shape(tensor_name, cur);
}
}
gguf_free(ctx_gguf);
ggml_free(meta);
return true;
}
//
// NexaBaseModel
//
// initialize from gguf file
bool NexaBaseModel::load_from_gguf(const std::string &fname)
{
init_backend();
bool verbose = false;
#ifdef NEXA_DEBUG
verbose = true;
#endif
if (!load_hparams_and_tensors_from_gguf(fname, *this, verbose))
{
NEXA_LOG("failed to load params and tensors");
return false;
}
reserve_memory();
return true;
}
// Initialize the backend based on available hardware
void NexaBaseModel::init_backend()
{
#ifdef GGML_USE_CUDA
NEXA_LOG("using CUDA backend");
backend = ggml_backend_cuda_init(0); // Initialize CUDA on device 0
#endif
#ifdef GGML_USE_METAL
NEXA_LOG("using Metal backend");
backend = ggml_backend_metal_init(); // Initialize Metal backend
#endif
// Fallback to CPU backend if no GPU is available
if (!backend)
{
backend = ggml_backend_cpu_init();
fprintf(stderr, "%s: using CPU backend\n", __func__);
}
}
// measure mem requirement and allocate
void NexaBaseModel::reserve_memory()
{
compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
struct ggml_cgraph *gf = build_graph();
ggml_gallocr_reserve(compute_alloc, gf);
size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(compute_alloc, 0);
NEXA_LOG("compute allocated memory: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0);
}
// set the number of threads
void NexaBaseModel::set_n_threads(int n_threads)
{
if (n_threads <= 0)
{
// if n_threads is not set, use the number of cores
n_threads = std::thread::hardware_concurrency();
}
// Set backend options
if (ggml_backend_is_cpu(backend))
{
ggml_backend_cpu_set_n_threads(backend, n_threads);
}
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(backend))
{
ggml_backend_metal_set_n_cb(backend, n_threads);
}
#endif
}
// Free allocated memory
void NexaBaseModel::free()
{
ggml_gallocr_free(compute_alloc);
ggml_free(ctx_data);
ggml_backend_buffer_free(buffer);
ggml_backend_free(backend);
}
void print_ggml_tensor_stats(const char *name, const struct ggml_tensor *tensor, bool use_backend) {
std::vector<float> data(ggml_nelements(tensor));
if (use_backend) {
ggml_backend_tensor_get(tensor, data.data(), 0, ggml_nbytes(tensor));
} else {
memcpy(data.data(), ggml_get_data_f32(tensor), ggml_nbytes(tensor));
}
if (data.empty()) {
printf("%s: Empty tensor\n", name);
return;
}
// Calculate mean
double sum = std::accumulate(data.begin(), data.end(), 0.0);
double mean = sum / data.size();
// Calculate variance using two-pass algorithm for better numerical stability
double sq_sum = 0.0;
for (const auto &val : data) {
double diff = val - mean;
sq_sum += diff * diff;
}
double variance = sq_sum / data.size();
// Print statistics
printf("%s:\n", name);
printf(" Shape: [");
for (int i = 0; i < GGML_MAX_DIMS && tensor->ne[i] > 1; ++i) {
printf("%d%s", static_cast<int>(tensor->ne[i]), (i < GGML_MAX_DIMS - 1 && tensor->ne[i+1] > 1) ? ", " : "");
}
printf("]\n");
printf(" Mean: %.6f\n", mean);
printf(" Variance: %.6f\n", variance);
printf(" Standard Deviation: %.6f\n", std::sqrt(variance));
}
void print_all_tensor_names(struct gguf_context *ctx) {
int n_tensors = gguf_get_n_tensors(ctx);
printf("Number of tensors: %d\n", n_tensors);
const char *separator = "";
printf("Tensors: ");
for (int i = 0; i < n_tensors; ++i) {
const char *tensor_name = gguf_get_tensor_name(ctx, i);
printf("%s%s", separator, tensor_name);
separator = ", "; // Set separator after the first tensor
}
printf("\n");
}
struct ggml_tensor * checked_get_tensor(struct ggml_context * ctx, const char * name) {
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);
// print_ggml_tensor_stats(name, tensor, false);
if (!tensor) {
fprintf(stderr, "%s: tensor '%s' not found\n", __func__, name);
throw std::runtime_error("ggml_get_tensor() failed");
}
return tensor;
}
//
// original ggml functions
//
struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) {
if (i < 0) {
GGML_ASSERT(cgraph->n_nodes + i >= 0);
return cgraph->nodes[cgraph->n_nodes + i];
}
GGML_ASSERT(i < cgraph->n_nodes);
return cgraph->nodes[i];
}
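As a usage note (a hedged sketch, not part of the diff): `ggml_graph_node()` above accepts Python-style negative indices, so the output of a computed graph can be read from its last node, and the debug printers then work on either host or backend memory. Assuming a `model` derived from `NexaBaseModel` whose weights are already loaded:

```cpp
// Hypothetical driver code; `model` is an assumption, not part of this commit.
struct ggml_cgraph  * gf  = model.build_graph();               // graph from a NexaBaseModel subclass
ggml_gallocr_alloc_graph(model.compute_alloc, gf);             // allocate tensors for this graph
ggml_backend_graph_compute(model.backend, gf);                 // run on the selected backend
struct ggml_tensor  * out = ggml_graph_node(gf, -1);           // negative index -> last (output) node
print_ggml_tensor_stats("output", out, /*use_backend=*/true);  // mean / variance / std, as defined above
```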

common/common-nexa.h (Normal file, 80 lines)
@@ -0,0 +1,80 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <string>
#include <map>
#include <fstream>
#include <sstream>
#include <vector>
#include <variant>
#include <cmath>
#include <cxxabi.h>
#define NEXA_CLASS_NAME (abi::__cxa_demangle(typeid(*this).name(), nullptr, nullptr, nullptr))
#define NEXA_LOG(fmt, ...) fprintf(stderr, "%s::%s: " fmt "\n", NEXA_CLASS_NAME, __func__, ##__VA_ARGS__)
// Prints the content of a ggml_tensor with specified precision. Can use the backend if available.
void print_ggml_tensor(const char *name, const struct ggml_tensor *tensor, bool use_backend, int precision = 4);
// Prints the shape (dimensions) of a ggml_tensor without printing its contents.
void print_ggml_tensor_shape(const char *name, const struct ggml_tensor *tensor);
// Prints the statistics (mean, variance, standard deviation) of a ggml_tensor. Can use the backend if available.
void print_ggml_tensor_stats(const char *name, const struct ggml_tensor *tensor, bool use_backend);
// Print all tensor names in the provided GGUF context.
void print_all_tensor_names(struct gguf_context *ctx);
// get tensor, print stats and check for null
struct ggml_tensor * checked_get_tensor(struct ggml_context * ctx, const char * name);
// Base class for all Nexa models
struct NexaBaseModel
{
std::vector<std::string> hparam_names;
std::map<std::string, std::variant<int32_t, float_t>> hparams; // hyperparameters, dict value can be either int32_t or float_t
std::vector<std::string> tensor_names;
std::map<std::string, struct ggml_tensor *> tensors; // mapping from tensor name to tensor pointer
struct ggml_context *ctx_data; // GGML context for tensor management
ggml_backend_buffer_t buffer; // Backend buffer to store tensor data
ggml_backend_t backend = NULL; // Backend for computation (CPU, CUDA, METAL)
ggml_gallocr_t compute_alloc = NULL; // Memory allocator for computation
// constructor & destructor
NexaBaseModel() {}
~NexaBaseModel()
{
free();
NEXA_LOG("allocated resources freed");
}
// Initialize the backend based on available hardware
void init_backend();
// measure mem requirement and allocate
void reserve_memory();
// initialize from gguf file
bool load_from_gguf(const std::string &fname);
// build the computation graph
// this is a pure virtual function that must be implemented by the derived class
virtual ggml_cgraph *build_graph() = 0;
// set the number of threads
void set_n_threads(int n_threads);
// Free allocated memory
void free();
};
bool load_hparams_and_tensors_from_gguf(const std::string &fname, NexaBaseModel &model, bool verbose = false);
struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i);
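To make the `NexaBaseModel` contract concrete, the sketch below shows a minimal derived model (a hedged illustration with hypothetical hparam/tensor names; the real `audio_projector` added later in this commit follows the same pattern). Only `build_graph()` must be overridden; loading, backend selection, and memory reservation come from the base class.

```cpp
// Hypothetical subclass of NexaBaseModel; assumes #include "common-nexa.h".
// "n_embd" and "proj.weight" are made-up GGUF keys used only for illustration.
struct toy_model : public NexaBaseModel {
    toy_model() {
        hparam_names = { "n_embd" };
        tensor_names = { "proj.weight" };
    }
    struct ggml_cgraph * build_graph() override {
        // small scratch context: graph metadata only, no tensor data (no_alloc = true)
        static std::vector<uint8_t> buf(ggml_tensor_overhead() * 16 + ggml_graph_overhead());
        struct ggml_init_params p = { buf.size(), buf.data(), /*no_alloc=*/ true };
        struct ggml_context * ctx0 = ggml_init(p);
        struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32,
                                                    std::get<int32_t>(hparams["n_embd"]));
        ggml_set_name(x, "input");
        ggml_set_input(x);

        struct ggml_tensor * y = ggml_mul_mat(ctx0, tensors["proj.weight"], x);
        ggml_set_name(y, "output");
        ggml_set_output(y);

        ggml_build_forward_expand(gf, y);
        ggml_free(ctx0);
        return gf;
    }
};
// Intended lifecycle: toy_model m; m.load_from_gguf("toy.gguf"); then allocate the graph and compute.
```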

common/dr_wav.h (Normal file, 6434 lines)
File diff suppressed because it is too large.


@@ -12,44 +12,45 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN)
else()
add_subdirectory(cvector-generator)
add_subdirectory(baby-llama)
add_subdirectory(batched-bench)
add_subdirectory(batched)
add_subdirectory(benchmark)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(eval-callback)
add_subdirectory(export-lora)
add_subdirectory(gbnf-validator)
add_subdirectory(gguf-hash)
add_subdirectory(gguf-split)
add_subdirectory(gguf)
add_subdirectory(gritlm)
add_subdirectory(imatrix)
add_subdirectory(infill)
add_subdirectory(llama-bench)
# add_subdirectory(cvector-generator)
# add_subdirectory(baby-llama)
# add_subdirectory(batched-bench)
# add_subdirectory(batched)
# add_subdirectory(benchmark)
# add_subdirectory(convert-llama2c-to-ggml)
# add_subdirectory(embedding)
# add_subdirectory(eval-callback)
# add_subdirectory(export-lora)
# add_subdirectory(gbnf-validator)
# add_subdirectory(gguf-hash)
# add_subdirectory(gguf-split)
# add_subdirectory(gguf)
# add_subdirectory(gritlm)
# add_subdirectory(imatrix)
# add_subdirectory(infill)
# add_subdirectory(llama-bench)
add_subdirectory(llava)
add_subdirectory(lookahead)
add_subdirectory(lookup)
add_subdirectory(main)
add_subdirectory(parallel)
add_subdirectory(passkey)
add_subdirectory(perplexity)
add_subdirectory(quantize-stats)
add_subdirectory(quantize)
add_subdirectory(retrieval)
if (GGML_RPC)
add_subdirectory(rpc)
endif()
if (LLAMA_BUILD_SERVER)
add_subdirectory(server)
endif()
if (GGML_SYCL)
add_subdirectory(sycl)
endif()
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(speculative)
add_subdirectory(tokenize)
# add_subdirectory(lookahead)
# add_subdirectory(lookup)
# add_subdirectory(main)
# add_subdirectory(parallel)
# add_subdirectory(passkey)
# add_subdirectory(perplexity)
# add_subdirectory(quantize-stats)
# add_subdirectory(quantize)
# add_subdirectory(retrieval)
# if (GGML_RPC)
# add_subdirectory(rpc)
# endif()
# if (LLAMA_BUILD_SERVER)
# add_subdirectory(server)
# endif()
# if (GGML_SYCL)
# add_subdirectory(sycl)
# endif()
# add_subdirectory(save-load-state)
# add_subdirectory(simple)
# add_subdirectory(speculative)
# add_subdirectory(tokenize)
add_subdirectory(nexa-omni-audio)
endif()


@@ -0,0 +1,56 @@
# whisper
# Find the Threads package
find_package(Threads REQUIRED)
# build nexa-whisper-utils
set(WHISPER_LIB nexa-whisper-utils)
add_library(${WHISPER_LIB} OBJECT
whisper.cpp
)
target_link_libraries(${WHISPER_LIB} PRIVATE ggml_llama common Threads::Threads)
# build the whisper encoder
# add_executable(whisper-encode main-encode.cpp)
# target_link_libraries(whisper-encode PRIVATE ggml_llama common Threads::Threads ${WHISPER_LIB})
# build the audio projector
# add_executable(audio-projector-cli audio-projector-cli.cpp audio-projector.cpp)
# target_link_libraries(audio-projector-cli PRIVATE ggml_llama common)
# add nexa-omni-audio-lib library
set(OMNI_AUDIO_LIB nexa-omni-audio-lib)
add_library(${OMNI_AUDIO_LIB} OBJECT
omni.cpp
omni.h
audio-projector.cpp
audio-projector.h
)
target_link_libraries(${OMNI_AUDIO_LIB} PRIVATE ggml_llama common ${WHISPER_LIB})
# build the nexa-omni-cli
add_executable(nexa-omni-cli omni-cli.cpp)
target_link_libraries(nexa-omni-cli PRIVATE ggml_llama common Threads::Threads ${WHISPER_LIB} ${OMNI_AUDIO_LIB})
# If BUILD_SHARED_LIBS is ON, also build a shared library
if(BUILD_SHARED_LIBS)
message(STATUS "Building shared libraries")
set_target_properties(${WHISPER_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(${OMNI_AUDIO_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(${OMNI_AUDIO_LIB}_shared SHARED $<TARGET_OBJECTS:${OMNI_AUDIO_LIB}>)
target_link_libraries(${OMNI_AUDIO_LIB}_shared PRIVATE ggml_llama common ${WHISPER_LIB})
set_target_properties(${OMNI_AUDIO_LIB}_shared PROPERTIES
PUBLIC_HEADER omni.h
POSITION_INDEPENDENT_CODE ON
)
# Add OMNI_AUDIO_SHARED definition when building the shared library
target_compile_definitions(${OMNI_AUDIO_LIB}_shared PRIVATE OMNI_AUDIO_SHARED WHISPER_SHARED)
# Ensure all symbols are exported on Windows
if(MSVC)
set_target_properties(${OMNI_AUDIO_LIB}_shared PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
endif()


@@ -0,0 +1,63 @@
# nexa-omni-audio
## Build this example
Build the whole project from the repo root directory.

CPU build:
```shell
cmake -B build-cpu
cmake --build build-cpu --config Release -j 24
```

CUDA build:
```shell
cmake -B build-cuda -DGGML_CUDA=ON
cmake --build build-cuda --config Release -j 24
```
## Prepare GGUF files
### mmproj (mel-filter + whisper-encoder + audio-projector)
1. Extract whisper-medium encoder (`audio_tower`) from `nexa-collaboration/nano-omini-instruct`
```shell
python convert-to-gguf-f16.py
```
2. Extract the mel filters from `ggml-medium.bin` and add them to the previously extracted `nano-omni-audio-encoder-f16.gguf`
```shell
python add-mel-filters.py
```
> Run `bash download-ggml-model.sh` to download `ggml-medium.bin`, then move it into the `models` folder
> Don't forget to modify the input and output file paths in the Python script above before running it.
### gemma2
```shell
python prepare-gemma2-hf.py
python ../../convert_hf_to_gguf.py ./models/nano-omini-instruct-gemma2 --outfile ./models/nano-omini-instruct.gemma2.gguf [--outtype bf16]
```
## Run nexa-omni-audio
From the root directory of the repo, run the commands below.

CPU build:
```shell
./build-cpu/bin/nexa-omni-cli \
--model examples/nexa-omni-audio/models/nano-omni-instruct.gemma2.gguf \
--mmproj examples/nexa-omni-audio/models/nano-omni-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
--file examples/nexa-omni-audio/samples/jfk.wav \
--prompt "this conversation talks about"
```
With GPU offload (adjust the model paths to your local setup):
```shell
./build/bin/nexa-omni-cli \
--model /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/gemma2-2b.gguf \
--mmproj /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/nano-omni-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
--file /home/azureuser/zack/ggml-project-apollo/examples/whisper/samples/jfk.wav \
--prompt "this conversation talks about" \
--n-gpu-layers 27 # offload all 27 layers of gemma2 model to GPU
```


@@ -0,0 +1,37 @@
#include "audio-projector.h"
#include "common-nexa.h"
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <vector>
struct ggml_tensor *audio_projector_inference(audio_projector &model, std::vector<float> &audio_feature_data)
{
// Build the computation graph for inference
struct ggml_cgraph *gf = model.build_graph();
// Allocate the graph tensors
ggml_gallocr_alloc_graph(model.compute_alloc, gf);
// Set the input data
struct ggml_tensor *input = ggml_graph_get_tensor(gf, "input");
ggml_backend_tensor_set(input, audio_feature_data.data(), 0, audio_feature_data.size() * sizeof(float));
model.set_n_threads(0);
// Execute the graph on the backend
ggml_backend_graph_compute(model.backend, gf);
// Return the output tensor (named "output" in build_graph())
return ggml_graph_get_tensor(gf, "output");
}
struct ggml_tensor *audio_projector_inference(audio_projector &model, struct ggml_tensor *audio_feature_tensor)
{
// Set the input data
std::vector<float> data(ggml_nelements(audio_feature_tensor));
ggml_backend_tensor_get(audio_feature_tensor, data.data(), 0, ggml_nbytes(audio_feature_tensor));
return audio_projector_inference(model, data);
}
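A hedged end-to-end sketch of how these overloads are intended to be used (the GGUF file name and the zero-filled feature buffer are placeholders; in the real pipeline the features come from the whisper encoder):

```cpp
// Hypothetical usage; "projector.gguf" is a placeholder path.
audio_projector proj;
if (!proj.load_from_gguf("projector.gguf")) {
    fprintf(stderr, "failed to load projector\n");
    return 1;
}

// Encoder output laid out as d_model x (max_source_positions / 2); zero-filled here for illustration.
const int d_model = std::get<int32_t>(proj.hparams["d_model"]);
const int n_pos   = std::get<int32_t>(proj.hparams["max_source_positions"]) / 2;
std::vector<float> features(static_cast<size_t>(d_model) * n_pos, 0.0f);

struct ggml_tensor * embd = audio_projector_inference(proj, features);  // the graph's "output" tensor
```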


@@ -0,0 +1,67 @@
#pragma once
#include "ggml.h"
#include "common-nexa.h"
#include <vector>
//
// Audio Projector
//
struct audio_projector : public NexaBaseModel
{
audio_projector() : NexaBaseModel()
{
this->hparam_names = {
"max_source_positions",
"d_model",
};
this->tensor_names = {
"multi_modal_projector.linear.weight",
"multi_modal_projector.linear.bias",
};
}
struct ggml_cgraph *build_graph() override
{
const int MAX_NODES = 64;
size_t buf_size = ggml_tensor_overhead() * MAX_NODES + ggml_graph_overhead_custom(MAX_NODES, false);
static std::vector<uint8_t> buf(buf_size);
// Create temporary GGML context for building the graph
struct ggml_init_params params = {
/*.mem_size =*/buf_size,
/*.mem_buffer =*/buf.data(),
/*.no_alloc =*/true, // Memory will be allocated later
};
struct ggml_context *ctx0 = ggml_init(params);
struct ggml_cgraph *gf = ggml_new_graph_custom(ctx0, MAX_NODES, false); // Create new graph
// Create input tensor
struct ggml_tensor *input = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
std::get<int32_t>(hparams["d_model"]),
std::get<int32_t>(hparams["max_source_positions"]) / 2);
ggml_set_name(input, "input");
ggml_set_input(input); // Mark tensor as input
// weight * input + bias
struct ggml_tensor *cur = ggml_mul_mat(ctx0, tensors["multi_modal_projector.linear.weight"], input);
cur = ggml_add(ctx0, cur, tensors["multi_modal_projector.linear.bias"]);
// Set the final output
ggml_set_name(cur, "output");
ggml_set_output(cur);
ggml_build_forward_expand(gf, cur); // Expand graph with operations
ggml_free(ctx0); // Free temporary context
return gf;
}
};
struct ggml_tensor *audio_projector_inference(audio_projector &model, std::vector<float> &audio_feature_data);
struct ggml_tensor *audio_projector_inference(audio_projector &model, struct ggml_tensor *audio_feature_tensor);


@@ -0,0 +1,614 @@
#pragma once
// GGML CPU internal header
#include "ggml.h"
#include "ggml-impl.h"
#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
//#include <stddef.h>
#include <stdbool.h>
#include <string.h> // memcpy
#include <math.h> // fabsf
#ifdef __cplusplus
extern "C" {
#endif
#if defined(_MSC_VER)
#define m512bh(p) p
#define m512i(p) p
#else
#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)
#endif
/**
 * Converts brain16 to float32.
 *
 * The bfloat16 floating point format has the following structure:
 *
 *     1 sign bit | 8 exponent bits | 7 mantissa bits
 *     0b0000000000000000 brain16
 *
 * Since bf16 has the same number of exponent bits as a 32bit float,
 * encoding and decoding numbers becomes relatively straightforward.
 *
 *     1 sign bit | 8 exponent bits | 23 mantissa bits
 *     0b00000000000000000000000000000000 IEEE binary32
 *
 * For comparison, the standard fp16 format has fewer exponent bits.
 *
 *     1 sign bit | 5 exponent bits | 10 mantissa bits
 *     0b0000000000000000 IEEE binary16
 *
 * @see IEEE 754-2008
 */
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}
/**
* Converts float32 to brain16.
*
* This is binary identical with Google Brain float conversion.
* Floats shall round to nearest even, and NANs shall be quiet.
* Subnormals aren't flushed to zero, except perhaps when used.
* This code should vectorize nicely if using modern compilers.
*/
static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
ggml_bf16_t h;
union {
float f;
uint32_t i;
} u;
u.f = s;
if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
h.bits = (u.i >> 16) | 64; /* force to quiet */
return h;
}
h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
return h;
}
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
#ifndef __FMA__
#define __FMA__
#endif
#ifndef __F16C__
#define __F16C__
#endif
#endif
// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
#ifndef __SSE3__
#define __SSE3__
#endif
#ifndef __SSSE3__
#define __SSSE3__
#endif
#endif
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#include <sys/prctl.h>
#endif
// 16-bit float
// on Arm, we use __fp16
// on x86, we use uint16_t
#if defined(__ARM_NEON)
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
//
#include <arm_neon.h>
#ifdef _MSC_VER
typedef uint16_t ggml_fp16_internal_t;
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
#else
typedef __fp16 ggml_fp16_internal_t;
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
#endif // _MSC_VER
#if !defined(__aarch64__)
// 32-bit ARM compatibility
// vaddlvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
// vzip1_u8
// vzip2_u8
inline static int32_t vaddlvq_s16(int16x8_t v) {
int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
}
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
return vcombine_s16(a0, b0);
}
inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
return vcombine_s32(a0, b0);
}
inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
inline static float vaddvq_f32(float32x4_t v) {
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
}
inline static float vmaxvq_f32(float32x4_t v) {
return
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
}
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
int32x4_t res;
res[0] = roundf(vgetq_lane_f32(v, 0));
res[1] = roundf(vgetq_lane_f32(v, 1));
res[2] = roundf(vgetq_lane_f32(v, 2));
res[3] = roundf(vgetq_lane_f32(v, 3));
return res;
}
inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
uint8x8_t res;
res[0] = a[0]; res[1] = b[0];
res[2] = a[1]; res[3] = b[1];
res[4] = a[2]; res[5] = b[2];
res[6] = a[3]; res[7] = b[3];
return res;
}
inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
uint8x8_t res;
res[0] = a[4]; res[1] = b[4];
res[2] = a[5]; res[3] = b[5];
res[4] = a[6]; res[5] = b[6];
res[6] = a[7]; res[7] = b[7];
return res;
}
// vld1q_s16_x2
// vld1q_u8_x2
// vld1q_u8_x4
// vld1q_s8_x2
// vld1q_s8_x4
// TODO: double-check these work correctly
typedef struct ggml_int16x8x2_t {
int16x8_t val[2];
} ggml_int16x8x2_t;
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
ggml_int16x8x2_t res;
res.val[0] = vld1q_s16(ptr + 0);
res.val[1] = vld1q_s16(ptr + 8);
return res;
}
typedef struct ggml_uint8x16x2_t {
uint8x16_t val[2];
} ggml_uint8x16x2_t;
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
ggml_uint8x16x2_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
return res;
}
typedef struct ggml_uint8x16x4_t {
uint8x16_t val[4];
} ggml_uint8x16x4_t;
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
ggml_uint8x16x4_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
res.val[2] = vld1q_u8(ptr + 32);
res.val[3] = vld1q_u8(ptr + 48);
return res;
}
typedef struct ggml_int8x16x2_t {
int8x16_t val[2];
} ggml_int8x16x2_t;
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
ggml_int8x16x2_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
return res;
}
typedef struct ggml_int8x16x4_t {
int8x16_t val[4];
} ggml_int8x16x4_t;
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
ggml_int8x16x4_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
res.val[2] = vld1q_s8(ptr + 32);
res.val[3] = vld1q_s8(ptr + 48);
return res;
}
// NOTE: not tested
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
int8x16_t res;
res[ 0] = a[b[ 0]];
res[ 1] = a[b[ 1]];
res[ 2] = a[b[ 2]];
res[ 3] = a[b[ 3]];
res[ 4] = a[b[ 4]];
res[ 5] = a[b[ 5]];
res[ 6] = a[b[ 6]];
res[ 7] = a[b[ 7]];
res[ 8] = a[b[ 8]];
res[ 9] = a[b[ 9]];
res[10] = a[b[10]];
res[11] = a[b[11]];
res[12] = a[b[12]];
res[13] = a[b[13]];
res[14] = a[b[14]];
res[15] = a[b[15]];
return res;
}
// NOTE: not tested
inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
uint8x16_t res;
res[ 0] = a[b[ 0]];
res[ 1] = a[b[ 1]];
res[ 2] = a[b[ 2]];
res[ 3] = a[b[ 3]];
res[ 4] = a[b[ 4]];
res[ 5] = a[b[ 5]];
res[ 6] = a[b[ 6]];
res[ 7] = a[b[ 7]];
res[ 8] = a[b[ 8]];
res[ 9] = a[b[ 9]];
res[10] = a[b[10]];
res[11] = a[b[11]];
res[12] = a[b[12]];
res[13] = a[b[13]];
res[14] = a[b[14]];
res[15] = a[b[15]];
return res;
}
#else
#define ggml_int16x8x2_t int16x8x2_t
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_int8x16x2_t int8x16x2_t
#define ggml_int8x16x4_t int8x16x4_t
#define ggml_vld1q_s16_x2 vld1q_s16_x2
#define ggml_vld1q_u8_x2 vld1q_u8_x2
#define ggml_vld1q_u8_x4 vld1q_u8_x4
#define ggml_vld1q_s8_x2 vld1q_s8_x2
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#define ggml_vqtbl1q_s8 vqtbl1q_s8
#define ggml_vqtbl1q_u8 vqtbl1q_u8
#endif // !defined(__aarch64__)
#if !defined(__ARM_FEATURE_DOTPROD)
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#else
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
#endif // !defined(__ARM_FEATURE_DOTPROD)
#endif // defined(__ARM_NEON)
#if defined(__ARM_NEON) && !defined(_MSC_VER)
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
ggml_fp16_internal_t tmp;
memcpy(&tmp, &h, sizeof(ggml_fp16_t));
return (float)tmp;
}
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
ggml_fp16_internal_t tmp = f;
memcpy(&res, &tmp, sizeof(ggml_fp16_t));
return res;
}
#else
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#ifdef __POWER9_VECTOR__
#include <altivec.h>
#undef bool
#define bool _Bool
#else
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif
#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
#endif
#if defined(__loongarch64)
#if defined(__loongarch_asx)
#include <lasxintrin.h>
#endif
#if defined(__loongarch_sx)
#include <lsxintrin.h>
#endif
#endif
#if defined(__loongarch_asx)
typedef union {
int32_t i;
float f;
} ft_union;
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}
static __m256 __lasx_xvreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
#endif
#ifdef __F16C__
#ifdef _MSC_VER
#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
#else
#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
#endif
#elif defined(__POWER9_VECTOR__)
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
/* the inline asm below is about 12% faster than the lookup method */
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
register float f;
register double d;
__asm__(
"mtfprd %0,%2\n"
"xscvhpdp %0,%0\n"
"frsp %1,%0\n" :
/* temp */ "=d"(d),
/* out */ "=f"(f):
/* in */ "r"(h));
return f;
}
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
register double d;
register ggml_fp16_t r;
__asm__( /* xscvdphp can work on double or single precision */
"xscvdphp %0,%2\n"
"mffprd %1,%0\n" :
/* temp */ "=d"(d),
/* out */ "=r"(r):
/* in */ "f"(f));
return r;
}
#else
// FP16 <-> FP32
// ref: https://github.com/Maratyszcza/FP16
static inline float fp32_from_bits(uint32_t w) {
union {
uint32_t as_bits;
float as_value;
} fp32;
fp32.as_bits = w;
return fp32.as_value;
}
static inline uint32_t fp32_to_bits(float f) {
union {
float as_value;
uint32_t as_bits;
} fp32;
fp32.as_value = f;
return fp32.as_bits;
}
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
const uint32_t w = (uint32_t) h << 16;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t two_w = w + w;
const uint32_t exp_offset = UINT32_C(0xE0) << 23;
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
const float exp_scale = 0x1.0p-112f;
#else
const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
#endif
const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
const uint32_t magic_mask = UINT32_C(126) << 23;
const float magic_bias = 0.5f;
const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
return fp32_from_bits(result);
}
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
const float scale_to_inf = 0x1.0p+112f;
const float scale_to_zero = 0x1.0p-110f;
#else
const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
#endif
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
const uint32_t w = fp32_to_bits(f);
const uint32_t shl1_w = w + w;
const uint32_t sign = w & UINT32_C(0x80000000);
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
if (bias < UINT32_C(0x71000000)) {
bias = UINT32_C(0x71000000);
}
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
const uint32_t bits = fp32_to_bits(base);
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
const uint32_t nonsign = exp_bits + mantissa_bits;
return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
}
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
#endif // __F16C__
#endif // defined(__ARM_NEON) && !defined(_MSC_VER)
#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE
// precomputed f32 table for f16 (256 KB)
// defined in ggml.c, initialized in ggml_init()
extern float ggml_table_f32_f16[1 << 16];
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
#if !defined(GGML_FP16_TO_FP32)
inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
uint16_t s;
memcpy(&s, &f, sizeof(uint16_t));
return ggml_table_f32_f16[s];
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif
#if !defined(GGML_FP32_TO_FP16)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
#endif
#ifdef __cplusplus
}
#endif
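For a quick sanity check of the bf16 helpers declared near the top of this header, the snippet below (a hedged sketch, not part of the commit) round-trips a float through brain16; only the top 7 mantissa bits survive:

```cpp
// bf16 keeps the sign, the full 8-bit exponent, and the top 7 mantissa bits of an IEEE binary32.
float       x = 3.14159265f;
ggml_bf16_t b = GGML_FP32_TO_BF16(x);  // rounds to nearest-even on the dropped mantissa bits
float       y = GGML_BF16_TO_FP32(b);  // 3.140625f, i.e. x at bf16 precision
```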


@@ -0,0 +1,916 @@
#include "common.h"
#include "common-nexa.h"
#include "whisper.h"
#include "grammar-parser.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable : 4244 4267) // possible loss of data
#endif
// helper function to replace substrings
static void replace_all(std::string &s, const std::string &search, const std::string &replace)
{
for (size_t pos = 0;; pos += replace.length())
{
pos = s.find(search, pos);
if (pos == std::string::npos)
break;
s.erase(pos, search.length());
s.insert(pos, replace);
}
}
// command-line parameters
struct whisper_params
{
int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t audio_ctx = 0;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float grammar_penalty = 100.0f;
float temperature = 0.0f;
float temperature_inc = 0.2f;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool no_prints = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
bool use_gpu = true;
bool flash_attn = false;
std::string language = "en";
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string grammar;
std::string grammar_rule;
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
// A regular expression that matches tokens to suppress
std::string suppress_regex;
std::string openvino_encode_device = "CPU";
std::string dtw = "";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
grammar_parser::parse_state grammar_parsed;
};
static void whisper_print_usage(int argc, char **argv, const whisper_params &params);
static char *whisper_param_turn_lowercase(char *in)
{
int string_len = strlen(in);
for (int i = 0; i < string_len; i++)
{
*(in + i) = tolower((unsigned char)*(in + i));
}
return in;
}
static bool whisper_params_parse(int argc, char **argv, whisper_params &params)
{
for (int i = 1; i < argc; i++)
{
std::string arg = argv[i];
if (arg == "-")
{
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-')
{
params.fname_inp.push_back(arg);
continue;
}
if (arg == "-h" || arg == "--help")
{
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads")
{
params.n_threads = std::stoi(argv[++i]);
}
else if (arg == "-p" || arg == "--processors")
{
params.n_processors = std::stoi(argv[++i]);
}
else if (arg == "-ot" || arg == "--offset-t")
{
params.offset_t_ms = std::stoi(argv[++i]);
}
else if (arg == "-on" || arg == "--offset-n")
{
params.offset_n = std::stoi(argv[++i]);
}
else if (arg == "-d" || arg == "--duration")
{
params.duration_ms = std::stoi(argv[++i]);
}
else if (arg == "-mc" || arg == "--max-context")
{
params.max_context = std::stoi(argv[++i]);
}
else if (arg == "-ml" || arg == "--max-len")
{
params.max_len = std::stoi(argv[++i]);
}
else if (arg == "-bo" || arg == "--best-of")
{
params.best_of = std::stoi(argv[++i]);
}
else if (arg == "-bs" || arg == "--beam-size")
{
params.beam_size = std::stoi(argv[++i]);
}
else if (arg == "-ac" || arg == "--audio-ctx")
{
params.audio_ctx = std::stoi(argv[++i]);
}
else if (arg == "-wt" || arg == "--word-thold")
{
params.word_thold = std::stof(argv[++i]);
}
else if (arg == "-et" || arg == "--entropy-thold")
{
params.entropy_thold = std::stof(argv[++i]);
}
else if (arg == "-lpt" || arg == "--logprob-thold")
{
params.logprob_thold = std::stof(argv[++i]);
}
else if (arg == "-tp" || arg == "--temperature")
{
params.temperature = std::stof(argv[++i]);
}
else if (arg == "-tpi" || arg == "--temperature-inc")
{
params.temperature_inc = std::stof(argv[++i]);
}
else if (arg == "-debug" || arg == "--debug-mode")
{
params.debug_mode = true;
}
else if (arg == "-tr" || arg == "--translate")
{
params.translate = true;
}
else if (arg == "-di" || arg == "--diarize")
{
params.diarize = true;
}
else if (arg == "-tdrz" || arg == "--tinydiarize")
{
params.tinydiarize = true;
}
else if (arg == "-sow" || arg == "--split-on-word")
{
params.split_on_word = true;
}
else if (arg == "-nf" || arg == "--no-fallback")
{
params.no_fallback = true;
}
else if (arg == "-fp" || arg == "--font-path")
{
params.font_path = argv[++i];
}
else if (arg == "-np" || arg == "--no-prints")
{
params.no_prints = true;
}
else if (arg == "-ps" || arg == "--print-special")
{
params.print_special = true;
}
else if (arg == "-pc" || arg == "--print-colors")
{
params.print_colors = true;
}
else if (arg == "-pp" || arg == "--print-progress")
{
params.print_progress = true;
}
else if (arg == "-nt" || arg == "--no-timestamps")
{
params.no_timestamps = true;
}
else if (arg == "-l" || arg == "--language")
{
params.language = whisper_param_turn_lowercase(argv[++i]);
}
else if (arg == "-dl" || arg == "--detect-language")
{
params.detect_language = true;
}
else if (arg == "--prompt")
{
params.prompt = argv[++i];
}
else if (arg == "-m" || arg == "--model")
{
params.model = argv[++i];
}
else if (arg == "-f" || arg == "--file")
{
params.fname_inp.emplace_back(argv[++i]);
}
else if (arg == "-oved" || arg == "--ov-e-device")
{
params.openvino_encode_device = argv[++i];
}
else if (arg == "-dtw" || arg == "--dtw")
{
params.dtw = argv[++i];
}
else if (arg == "-ls" || arg == "--log-score")
{
params.log_score = true;
}
else if (arg == "-ng" || arg == "--no-gpu")
{
params.use_gpu = false;
}
else if (arg == "-fa" || arg == "--flash-attn")
{
params.flash_attn = true;
}
else if (arg == "--suppress-regex")
{
params.suppress_regex = argv[++i];
}
else if (arg == "--grammar")
{
params.grammar = argv[++i];
}
else if (arg == "--grammar-rule")
{
params.grammar_rule = argv[++i];
}
else if (arg == "--grammar-penalty")
{
params.grammar_penalty = std::stof(argv[++i]);
}
else
{
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
static void whisper_print_usage(int /*argc*/, char **argv, const whisper_params &params)
{
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n", params.temperature_inc);
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
struct whisper_print_user_data
{
const whisper_params *params;
const std::vector<std::vector<float>> *pcmf32s;
int progress_prev;
};
static std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false)
{
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples, WHISPER_SAMPLE_RATE);
const int64_t is1 = timestamp_to_sample(t1, n_samples, WHISPER_SAMPLE_RATE);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++)
{
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1 * energy1)
{
speaker = "0";
}
else if (energy1 > 1.1 * energy0)
{
speaker = "1";
}
else
{
speaker = "?";
}
// printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
if (!id_only)
{
speaker.insert(0, "(speaker ");
speaker.append(")");
}
return speaker;
}
static void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data)
{
int progress_step = ((whisper_print_user_data *)user_data)->params->progress_step;
int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev);
if (progress >= *progress_prev + progress_step)
{
*progress_prev += progress_step;
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
}
}
static void whisper_print_segment_callback(struct whisper_context *ctx, struct whisper_state * /*state*/, int n_new, void *user_data)
{
const auto &params = *((whisper_print_user_data *)user_data)->params;
const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0 = 0;
int64_t t1 = 0;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0)
{
printf("\n");
}
for (int i = s0; i < n_segments; i++)
{
if (!params.no_timestamps || params.diarize)
{
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps)
{
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
if (params.print_colors)
{
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j)
{
if (params.print_special == false)
{
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx))
{
continue;
}
}
const char *text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p(ctx, i, j);
const int col = std::max(0, std::min((int)k_colors.size() - 1, (int)(std::pow(p, 3) * float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
}
else
{
const char *text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
if (params.tinydiarize)
{
if (whisper_full_get_segment_speaker_turn_next(ctx, i))
{
printf("%s", params.tdrz_speaker_turn.c_str());
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize)
{
printf("\n");
}
fflush(stdout);
}
}
static char *escape_double_quotes_and_backslashes(const char *str)
{
if (str == NULL)
{
return NULL;
}
size_t escaped_length = strlen(str) + 1;
for (size_t i = 0; str[i] != '\0'; i++)
{
if (str[i] == '"' || str[i] == '\\')
{
escaped_length++;
}
}
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
if (escaped == NULL)
{
return NULL;
}
size_t pos = 0;
for (size_t i = 0; str[i] != '\0'; i++)
{
if (str[i] == '"' || str[i] == '\\')
{
escaped[pos++] = '\\';
}
escaped[pos++] = str[i];
}
// no need to set zero due to calloc() being used prior
return escaped;
}
// double quote should be escaped by another double quote. (rfc4180)
static char *escape_double_quotes_in_csv(const char *str)
{
if (str == NULL)
{
return NULL;
}
size_t escaped_length = strlen(str) + 1;
for (size_t i = 0; str[i] != '\0'; i++)
{
if (str[i] == '"')
{
escaped_length++;
}
}
char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
if (escaped == NULL)
{
return NULL;
}
size_t pos = 0;
for (size_t i = 0; str[i] != '\0'; i++)
{
if (str[i] == '"')
{
escaped[pos++] = '"';
}
escaped[pos++] = str[i];
}
// no need to set zero due to calloc() being used prior
return escaped;
}
static void cb_log_disable(enum ggml_log_level, const char *, void *) {}
int main(int argc, char **argv)
{
whisper_params params;
// If the only argument starts with "@", read arguments line-by-line
// from the given file.
std::vector<std::string> vec_args;
if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@')
{
// Save the name of the executable.
vec_args.push_back(argv[0]);
// Open the response file.
char const *rspfile = argv[1] + sizeof(char);
std::ifstream fin(rspfile);
if (fin.is_open() == false)
{
fprintf(stderr, "error: response file '%s' not found\n", rspfile);
return 1;
}
// Read the entire response file.
std::string line;
while (std::getline(fin, line))
{
vec_args.push_back(line);
}
// Use the contents of the response file as the command-line arguments.
argc = static_cast<int>(vec_args.size());
argv = static_cast<char **>(alloca(argc * sizeof(char *)));
for (int i = 0; i < argc; ++i)
{
argv[i] = const_cast<char *>(vec_args[i].c_str());
}
}
if (whisper_params_parse(argc, argv, params) == false)
{
whisper_print_usage(argc, argv, params);
return 1;
}
// remove non-existent files
for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();)
{
const auto fname_inp = it->c_str();
if (*it != "-" && !is_file_exist(fname_inp))
{
fprintf(stderr, "error: input file not found '%s'\n", fname_inp);
it = params.fname_inp.erase(it);
continue;
}
it++;
}
if (params.fname_inp.empty())
{
fprintf(stderr, "error: no input files specified\n");
whisper_print_usage(argc, argv, params);
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1)
{
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
if (params.diarize && params.tinydiarize)
{
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params);
exit(0);
}
if (params.no_prints)
{
whisper_log_set(cb_log_disable, NULL);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
if (!params.dtw.empty())
{
cparams.dtw_token_timestamps = true;
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
if (params.dtw == "tiny")
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
if (params.dtw == "tiny.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
if (params.dtw == "base")
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
if (params.dtw == "base.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
if (params.dtw == "small")
cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
if (params.dtw == "small.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
if (params.dtw == "medium")
cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
if (params.dtw == "medium.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
if (params.dtw == "large.v1")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
if (params.dtw == "large.v2")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
if (params.dtw == "large.v3")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE)
{
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
return 3;
}
}
struct whisper_context *ctx = whisper_encoder_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr)
{
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
if (!params.grammar.empty())
{
auto &grammar = params.grammar_parsed;
if (is_file_exist(params.grammar.c_str()))
{
// read grammar from file
std::ifstream ifs(params.grammar.c_str());
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
grammar = grammar_parser::parse(txt.c_str());
}
else
{
// read grammar from string
grammar = grammar_parser::parse(params.grammar.c_str());
}
// will be empty (default) if there are parse errors
if (grammar.rules.empty())
{
fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
return 4;
}
else
{
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, grammar);
fprintf(stderr, "\n");
}
}
for (int f = 0; f < (int)params.fname_inp.size(); ++f)
{
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize))
{
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
continue;
}
if (!whisper_is_multilingual(ctx)) // TODO: something off here
{
if (params.language != "en" || params.translate)
{
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
if (params.detect_language)
{
params.language = "auto";
}
if (!params.no_prints)
{
// print system information
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads * params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
// print some info about the processing
fprintf(stderr, "\n");
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size()) / WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors, params.beam_size, params.best_of,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;
wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc;
wparams.temperature = params.temperature;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.no_timestamps = params.no_timestamps;
whisper_print_user_data user_data = {&params, &pcmf32s, 0};
const auto &grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();
if (use_grammar)
{
if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end())
{
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
}
else
{
wparams.grammar_rules = grammar_rules.data();
wparams.n_grammar_rules = grammar_rules.size();
wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
wparams.grammar_penalty = params.grammar_penalty;
}
}
// this callback is called on each new segment
if (!wparams.print_realtime)
{
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
if (wparams.print_progress)
{
wparams.progress_callback = whisper_print_progress_callback;
wparams.progress_callback_user_data = &user_data;
}
// examples for abort mechanism
// in examples below, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void *user_data)
{
bool is_aborted = *(bool *)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void *user_data)
{
bool is_aborted = *(bool *)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_encode_wo_cross_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
}
ggml_tensor *embd_enc = whisper_full_get_embd_enc(ctx);
// print_ggml_tensor("embd_enc", embd_enc, true);
print_ggml_tensor_shape("embd_enc", embd_enc);
}
}
if (!params.no_prints)
{
whisper_print_timings(ctx);
}
whisper_free(ctx);
return 0;
}

View file

@@ -0,0 +1,19 @@
#include "omni.h"
int main(int argc, char **argv)
{
omni_context_params ctx_params = omni_context_default_params();
if (!omni_context_params_parse(argc, argv, ctx_params))
{
return 1;
}
omni_context *ctx_omni = omni_init_context(ctx_params);
omni_process_full(ctx_omni, ctx_params);
omni_free(ctx_omni);
return 0;
}

View file

@@ -0,0 +1,864 @@
#include "omni.h"
#include "audio-projector.h"
#include "common-nexa.h"
#include "whisper.h"
#include "llama.h"
#include "common.h"
#include "log.h"
// #include "arg.h"
#include "sampling.h"
#include "llama-impl.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <cstring>
//
// Constants
//
static const char *AUDIO_TOKEN = "<|AUDIO|>";
//
// Whisper
//
struct whisper_params
{
int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t audio_ctx = 0;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float grammar_penalty = 100.0f;
float temperature = 0.0f;
float temperature_inc = 0.2f;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool no_prints = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
bool use_gpu = true;
bool flash_attn = false;
std::string language = "en";
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string grammar;
std::string grammar_rule;
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
// A regular expression that matches tokens to suppress
std::string suppress_regex;
std::string openvino_encode_device = "CPU";
std::string dtw = "";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
grammar_parser::parse_state grammar_parsed;
};
static void whisper_print_usage(int argc, char **argv, const whisper_params &params);
static char *whisper_param_turn_lowercase(char *in)
{
int string_len = strlen(in);
for (int i = 0; i < string_len; i++)
{
*(in + i) = tolower((unsigned char)*(in + i));
}
return in;
}
static bool whisper_params_parse(int argc, char **argv, whisper_params &params)
{
for (int i = 1; i < argc; i++)
{
std::string arg = argv[i];
if (arg == "-")
{
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-')
{
// params.fname_inp.push_back(arg);
continue;
}
if (arg == "-h" || arg == "--help")
{
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads")
{
params.n_threads = std::stoi(argv[++i]);
}
else if (arg == "-p" || arg == "--processors")
{
params.n_processors = std::stoi(argv[++i]);
}
else if (arg == "-ot" || arg == "--offset-t")
{
params.offset_t_ms = std::stoi(argv[++i]);
}
else if (arg == "-on" || arg == "--offset-n")
{
params.offset_n = std::stoi(argv[++i]);
}
else if (arg == "-d" || arg == "--duration")
{
params.duration_ms = std::stoi(argv[++i]);
}
else if (arg == "-mc" || arg == "--max-context")
{
params.max_context = std::stoi(argv[++i]);
}
else if (arg == "-ml" || arg == "--max-len")
{
params.max_len = std::stoi(argv[++i]);
}
else if (arg == "-bo" || arg == "--best-of")
{
params.best_of = std::stoi(argv[++i]);
}
else if (arg == "-bs" || arg == "--beam-size")
{
params.beam_size = std::stoi(argv[++i]);
}
else if (arg == "-ac" || arg == "--audio-ctx")
{
params.audio_ctx = std::stoi(argv[++i]);
}
else if (arg == "-wt" || arg == "--word-thold")
{
params.word_thold = std::stof(argv[++i]);
}
else if (arg == "-et" || arg == "--entropy-thold")
{
params.entropy_thold = std::stof(argv[++i]);
}
else if (arg == "-lpt" || arg == "--logprob-thold")
{
params.logprob_thold = std::stof(argv[++i]);
}
else if (arg == "-tp" || arg == "--temperature")
{
params.temperature = std::stof(argv[++i]);
}
else if (arg == "-tpi" || arg == "--temperature-inc")
{
params.temperature_inc = std::stof(argv[++i]);
}
else if (arg == "-debug" || arg == "--debug-mode")
{
params.debug_mode = true;
}
else if (arg == "-tr" || arg == "--translate")
{
params.translate = true;
}
else if (arg == "-di" || arg == "--diarize")
{
params.diarize = true;
}
else if (arg == "-tdrz" || arg == "--tinydiarize")
{
params.tinydiarize = true;
}
else if (arg == "-sow" || arg == "--split-on-word")
{
params.split_on_word = true;
}
else if (arg == "-nf" || arg == "--no-fallback")
{
params.no_fallback = true;
}
else if (arg == "-fp" || arg == "--font-path")
{
params.font_path = argv[++i];
}
else if (arg == "-np" || arg == "--no-prints")
{
params.no_prints = true;
}
else if (arg == "-ps" || arg == "--print-special")
{
params.print_special = true;
}
else if (arg == "-pc" || arg == "--print-colors")
{
params.print_colors = true;
}
else if (arg == "-pp" || arg == "--print-progress")
{
params.print_progress = true;
}
else if (arg == "-nt" || arg == "--no-timestamps")
{
params.no_timestamps = true;
}
else if (arg == "-l" || arg == "--language")
{
params.language = whisper_param_turn_lowercase(argv[++i]);
}
else if (arg == "-dl" || arg == "--detect-language")
{
params.detect_language = true;
}
else if (arg == "--prompt")
{
params.prompt = argv[++i];
}
else if (arg == "-m" || arg == "--model")
{
params.model = argv[++i];
}
else if (arg == "-f" || arg == "--file")
{
params.fname_inp.emplace_back(argv[++i]);
}
else if (arg == "-oved" || arg == "--ov-e-device")
{
params.openvino_encode_device = argv[++i];
}
else if (arg == "-dtw" || arg == "--dtw")
{
params.dtw = argv[++i];
}
else if (arg == "-ls" || arg == "--log-score")
{
params.log_score = true;
}
else if (arg == "-ng" || arg == "--no-gpu")
{
params.use_gpu = false;
}
else if (arg == "-fa" || arg == "--flash-attn")
{
params.flash_attn = true;
}
else if (arg == "--suppress-regex")
{
params.suppress_regex = argv[++i];
}
else if (arg == "--grammar")
{
params.grammar = argv[++i];
}
else if (arg == "--grammar-rule")
{
params.grammar_rule = argv[++i];
}
else if (arg == "--grammar-penalty")
{
params.grammar_penalty = std::stof(argv[++i]);
}
else if (arg == "--mmproj")
{
continue;
// params.mmproj = argv[++i];
}
else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers")
{
continue;
// params.n_gpu_layers = std::stoi(argv[++i]);
}
else
{
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
static void whisper_print_usage(int /*argc*/, char **argv, const whisper_params &params)
{
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n", params.temperature_inc);
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
struct whisper_context *whisper_init_context(whisper_params &params)
{
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
if (!params.dtw.empty())
{
cparams.dtw_token_timestamps = true;
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
if (params.dtw == "tiny")
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
if (params.dtw == "tiny.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
if (params.dtw == "base")
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
if (params.dtw == "base.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
if (params.dtw == "small")
cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
if (params.dtw == "small.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
if (params.dtw == "medium")
cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
if (params.dtw == "medium.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
if (params.dtw == "large.v1")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
if (params.dtw == "large.v2")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
if (params.dtw == "large.v3")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE)
{
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
return NULL;
}
}
struct whisper_context *ctx = whisper_encoder_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr)
{
fprintf(stderr, "error: failed to initialize whisper context\n");
return NULL;
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
return ctx;
}
static struct whisper_full_params get_whisper_inference_params_from_whisper_params(whisper_params &params)
{
struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;
wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc;
wparams.temperature = params.temperature;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.no_timestamps = params.no_timestamps;
return wparams;
}
//
// Omni
//
static void omni_print_usage(int, char **argv)
{
LOG("\n example usage:\n");
LOG("\n %s --model <omni/ggml-model.gguf> --mmproj <whisper/model-f16.gguf> --file <path/to/an/audio.wav> [-p \"describe the audio in detail.\"]\n", argv[0]);
LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
}
bool omni_context_params_parse(int argc, char **argv, omni_context_params &params)
{
for (int i = 1; i < argc; i++)
{
std::string arg = argv[i];
if (arg[0] != '-')
{
continue;
}
if (arg == "-h" || arg == "--help")
{
omni_print_usage(argc, argv);
exit(0);
}
if (arg == "--prompt")
{
params.prompt = argv[++i];
}
else if (arg == "-m" || arg == "--model")
{
params.model = argv[++i];
}
else if (arg == "-f" || arg == "--file")
{
params.file = argv[++i];
}
else if (arg == "--mmproj")
{
params.mmproj = argv[++i];
}
else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers")
{
params.n_gpu_layers = std::stoi(argv[++i]);
}
else
{
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
omni_print_usage(argc, argv);
exit(0);
}
}
return true;
}
omni_context_params omni_context_default_params()
{
omni_context_params params = {
.model = "",
.mmproj = "",
.file = "",
.prompt = "this conversation talks about",
.n_gpu_layers = -1,
};
return params;
}
struct omni_params
{
gpt_params gpt;
whisper_params whisper;
};
bool omni_params_parse(int argc, char **argv, omni_params &params)
{
if (!gpt_params_parse(argc, argv, params.gpt))
{
return false;
}
if (!whisper_params_parse(argc, argv, params.whisper))
{
whisper_print_usage(argc, argv, params.whisper);
return false;
}
if (params.gpt.model.empty() || params.gpt.mmproj.empty() || params.whisper.fname_inp.empty())
{
omni_print_usage(argc, argv);
return false;
}
params.whisper.model = params.gpt.mmproj;
return true;
}
static omni_params get_omni_params_from_context_params(omni_context_params &params)
{
omni_params all_params = {
.gpt = {
.n_gpu_layers = params.n_gpu_layers,
.model = params.model,
.prompt = params.prompt,
},
.whisper = {
.model = params.mmproj,
.fname_inp = {params.file},
},
};
if (all_params.gpt.n_threads <= 0)
{
all_params.gpt.n_threads = std::thread::hardware_concurrency();
}
return all_params;
}
static bool eval_tokens(struct llama_context *ctx_llama, std::vector<llama_token> tokens, int n_batch, int *n_past)
{
int N = (int)tokens.size();
for (int i = 0; i < N; i += n_batch)
{
int n_eval = (int)tokens.size() - i;
if (n_eval > n_batch)
{
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0)))
{
LLAMA_LOG_ERROR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}
*n_past += n_eval;
}
return true;
}
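// Worked example (illustrative numbers): with tokens.size() == 100, n_batch == 32 and
// *n_past == 0, eval_tokens() issues four llama_decode() calls with n_eval = 32, 32, 32, 4
// and leaves *n_past == 100, so the subsequent prompt/audio chunks continue from there.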
static bool eval_id(struct llama_context *ctx_llama, int id, int *n_past)
{
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(ctx_llama, tokens, 1, n_past);
}
static bool eval_string(struct llama_context *ctx_llama, const char *str, int n_batch, int *n_past, bool add_bos)
{
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
return eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
}
static const char * sample(struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx_llama, id);
}
eval_id(ctx_llama, id, n_past);
return ret.c_str();
}
static size_t find_audio_token(const std::string &prompt)
{
return prompt.find(AUDIO_TOKEN);
}
struct omni_context *omni_init_context(omni_context_params &params)
{
omni_params all_params = get_omni_params_from_context_params(params);
// llama
LLAMA_LOG_INFO("------- llama --------\n");
std::string prompt = all_params.gpt.prompt;
if (prompt.empty())
{
prompt = "this conversation talks about";
}
llama_backend_init();
llama_numa_init(all_params.gpt.numa);
llama_model_params model_params = llama_model_params_from_gpt_params(all_params.gpt);
llama_model *model = llama_load_model_from_file(all_params.gpt.model.c_str(), model_params);
if (model == NULL)
{
LLAMA_LOG_ERROR("%s: unable to load model\n", __func__);
return NULL;
}
llama_context_params ctx_params = llama_context_params_from_gpt_params(all_params.gpt);
ctx_params.n_ctx = all_params.gpt.n_ctx < 2048 ? 2048 : all_params.gpt.n_ctx; // we need a longer context size to process audio embeddings
llama_context *ctx_llama = llama_new_context_with_model(model, ctx_params);
if (ctx_llama == NULL)
{
LLAMA_LOG_ERROR("%s: failed to create the llama_context\n", __func__);
return NULL;
}
// whisper
LLAMA_LOG_INFO("------- whisper --------\n");
whisper_context *ctx_whisper = whisper_init_context(all_params.whisper);
if (ctx_whisper == NULL)
{
LLAMA_LOG_ERROR("%s: failed to initialize whisper context\n", __func__);
return NULL;
}
// projector
LLAMA_LOG_INFO("------- projector --------\n");
audio_projector *projector = new audio_projector();
if (!projector->load_from_gguf(all_params.whisper.model))
{
fprintf(stderr, "Failed to load model.\n");
return NULL;
}
auto *ctx_omni = (struct omni_context *)malloc(sizeof(omni_context));
ctx_omni->ctx_llama = ctx_llama;
ctx_omni->ctx_whisper = ctx_whisper;
ctx_omni->model = model;
ctx_omni->projector = projector;
LLAMA_LOG_INFO("------- omni context initialized --------\n");
return ctx_omni;
}
void omni_free(struct omni_context *ctx_omni)
{
if (ctx_omni->ctx_whisper)
{
whisper_free(ctx_omni->ctx_whisper);
ctx_omni->ctx_whisper = NULL;
}
if (ctx_omni->projector)
{
ctx_omni->projector->free();
}
llama_free(ctx_omni->ctx_llama);
llama_free_model(ctx_omni->model);
llama_backend_free();
}
static bool omni_eval_audio_embed(llama_context *ctx_llama, ggml_tensor *audio_embed, int n_batch, int *n_past)
{
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
int n_audio_embed = audio_embed->ne[1];
GGML_ASSERT(audio_embed->ne[0] == n_embd);
size_t audio_embed_size = ggml_nbytes(audio_embed);
float *audio_embed_data = (float *)malloc(audio_embed_size);
ggml_backend_tensor_get(audio_embed, audio_embed_data, 0, audio_embed_size);
for (int i = 0; i < n_audio_embed; i += n_batch)
{
int n_eval = n_audio_embed - i;
if (n_eval > n_batch)
{
n_eval = n_batch;
}
llama_batch batch = {
/* n_tokens */ int32_t(n_eval),
/* token */ nullptr,
/* embd */ (audio_embed_data + i * n_embd),
/* pos */ nullptr,
/* n_seq_id */ nullptr,
/* seq_id */ nullptr,
/* logits */ nullptr,
/* all_pos_0 */ *n_past,
/* all_pos_1 */ 1,
/* all_seq_id */ 0,
};
if (llama_decode(ctx_llama, batch))
{
LLAMA_LOG_ERROR("%s : failed to eval\n", __func__);
free(audio_embed_data); // release the host copy of the embeddings before bailing out
return false;
}
*n_past += n_eval;
}
free(audio_embed_data); // all chunks have been decoded, the host copy is no longer needed
return true;
}
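// Layout assumed above (sketch): audio_embed is an [n_embd, n_audio_embed] tensor, so the
// host copy audio_embed_data holds n_audio_embed consecutive rows of n_embd floats; each
// llama_batch submits n_eval of those rows as raw embeddings (token == nullptr) starting
// at position *n_past, analogous to llava's image-embedding evaluation.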
ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params &params)
{
auto fname_inp = params.whisper.fname_inp[0];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.whisper.diarize))
{
LLAMA_LOG_ERROR("error: failed to read WAV file '%s'\n", fname_inp.c_str());
return NULL;
}
whisper_full_params wparams = get_whisper_inference_params_from_whisper_params(params.whisper);
if (whisper_encode_wo_cross_parallel(ctx_omni->ctx_whisper, wparams, pcmf32.data(), pcmf32.size(), params.whisper.n_processors) != 0)
{
LLAMA_LOG_ERROR("%s: failed to process audio\n", __func__);
return NULL;
}
ggml_tensor *embd_enc = whisper_full_get_embd_enc(ctx_omni->ctx_whisper);
#ifdef NEXA_DEBUG
print_ggml_tensor_shape("embd_enc", embd_enc);
#endif
ggml_tensor *embed_proj = audio_projector_inference(*ctx_omni->projector, embd_enc);
#ifdef NEXA_DEBUG
print_ggml_tensor_shape("embed_proj", embed_proj);
#endif
return embed_proj;
}
void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
{
int n_past = 0;
int n_audio_embed = audio_embed->ne[1];
GGML_ASSERT(params.gpt.n_predict < 0 || params.gpt.n_predict > n_audio_embed);
const int max_tgt_len = params.gpt.n_predict < 0 ? 256 + n_audio_embed : params.gpt.n_predict;
std::string system_prompt, user_prompt;
size_t audio_pos = find_audio_token(prompt);
if (audio_pos != std::string::npos)
{
system_prompt = prompt.substr(0, audio_pos);
user_prompt = prompt.substr(audio_pos + std::string(AUDIO_TOKEN).length());
// LLAMA_LOG_INFO("system_prompt: %s\n", system_prompt.c_str());
// LLAMA_LOG_INFO("user_prompt: %s\n", user_prompt.c_str());
}
else
{
system_prompt = "<start_of_turn>user\nAudio 1: <|audio_bos|>";
user_prompt = "<|audio_eos|>\n" + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
eval_string(ctx_omni->ctx_llama, system_prompt.c_str(), params.gpt.n_batch, &n_past, true);
omni_eval_audio_embed(ctx_omni->ctx_llama, audio_embed, params.gpt.n_batch, &n_past);
eval_string(ctx_omni->ctx_llama, user_prompt.c_str(), params.gpt.n_batch, &n_past, false);
// generate the response
LOG("\n");
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.gpt.sparams);
if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
std::string response = "";
for (int i = 0; i < max_tgt_len; i++)
{
const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0)
break;
if (strstr(tmp, "###"))
break; // Yi-VL behavior
printf("%s", tmp);
if (strstr(response.c_str(), "<|im_end|>"))
break; // Yi-34B llava-1.6 - for some reason these are not decoded as the correct token (the tokenizer itself works)
if (strstr(response.c_str(), "<|im_start|>"))
break; // Yi-34B llava-1.6
if (strstr(response.c_str(), "USER:"))
break; // mistral llava-1.6
fflush(stdout);
}
llama_sampling_free(ctx_sampling);
printf("\n");
}
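// Example prompt handling (illustrative): a caller passing
//   "<start_of_turn>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat is said?<end_of_turn>\n<start_of_turn>model\n"
// gets the text before "<|AUDIO|>" evaluated first, then the projected audio embeddings,
// then the remaining text - the same ordering the hard-coded fallback template above
// produces when the prompt contains no "<|AUDIO|>" placeholder.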
void omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
{
omni_params all_params = get_omni_params_from_context_params(params);
ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params);
if (audio_embed == NULL)
{
LLAMA_LOG_ERROR("%s: failed to process audio\n", __func__);
return;
}
omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
}

View file

@@ -0,0 +1,64 @@
#pragma once
#include "whisper.h"
#include "llama.h"
#include "grammar-parser.h"
#include "common.h"
#include "common-nexa.h"
#include <string>
#include <thread>
#include "audio-projector.h"
#ifdef OMNI_AUDIO_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef OMNI_AUDIO_BUILD
# define OMNI_AUDIO_API __declspec(dllexport)
# else
# define OMNI_AUDIO_API __declspec(dllimport)
# endif
# else
# define OMNI_AUDIO_API __attribute__ ((visibility ("default")))
# endif
#else
# define OMNI_AUDIO_API
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct omni_context_params
{
const char *model;
const char *mmproj;
const char *file;
const char *prompt;
int32_t n_gpu_layers;
};
struct omni_context
{
struct whisper_context *ctx_whisper;
struct audio_projector *projector;
struct llama_context *ctx_llama;
struct llama_model *model;
};
OMNI_AUDIO_API bool omni_context_params_parse(int argc, char **argv, omni_context_params &params);
OMNI_AUDIO_API omni_context_params omni_context_default_params();
OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &params);
OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);
OMNI_AUDIO_API void omni_process_full(
struct omni_context *ctx_omni,
omni_context_params &params
);
#ifdef __cplusplus
}
#endif

View file

@@ -0,0 +1,364 @@
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
#include "whisper-mel-cuda.hpp"
#include "whisper.h"
#include "common.cuh"
#include <ggml-backend.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cufft.h>
#include <cublas_v2.h>
#include <cuComplex.h>
#include <cub/device/device_reduce.cuh>
#include <device_launch_parameters.h>
#include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4324) // added padding
#endif
namespace {
static const char* cufftGetErrorString(cufftResult_t res) {
switch (res) {
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
case CUFFT_INVALID_TYPE: return "No longer used";
case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
case CUFFT_UNALIGNED_DATA: return "No longer used";
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
case CUFFT_PARSE_ERROR: return "Internal plan database error";
case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
default: return "Unknown error";
}
}
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
__global__ void k_fill_stft_input(
const float * padded_samples,
const int n_frames,
const float * hann_window,
float * stft_in
) {
auto y = blockIdx.y * blockDim.y + threadIdx.y;
// if (y >= n_frames) return;
auto x = blockIdx.x * blockDim.x + threadIdx.x;
// if (x >= WHISPER_N_FFT) return;
auto line = padded_samples + y * WHISPER_HOP_LENGTH;
auto outLine = stft_in + y * WHISPER_N_FFT;
outLine[x] = line[x] * hann_window[x];
}
__global__ void k_calc_magnitudes(
const cuComplex * stft_out,
const int n_frames,
float * magnitudes
) {
auto y = blockIdx.y * blockDim.y + threadIdx.y;
// if (y >= n_frames) return;
auto x = blockIdx.x * blockDim.x + threadIdx.x;
// if (x >= WHISPER_N_FFT_HALF) return;
auto idx = y * WHISPER_N_FFT_HALF + x;
auto r = stft_out[idx].x;
auto i = stft_out[idx].y;
magnitudes[idx] = r * r + i * i;
}
__global__ void k_calc_log_mel(
const float * mel_data,
const int n_mel,
const float * max_val,
float * log_mel
) {
auto x = blockIdx.x * blockDim.x + threadIdx.x;
if (x >= n_mel) return;
float val = mel_data[x];
constexpr float e = 1e-10f;
if (val < e) val = e;
val = log10(val);
const float max = log10(*max_val) - 8.f;
if (val < max) val = max;
log_mel[x] = (val + 4) / 4;
}
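// Per mel bin, the kernel above computes Whisper's usual dynamic-range compression:
//   log_mel = (max(log10(max(val, 1e-10)), log10(global_max) - 8) + 4) / 4
// where global_max is produced by the cub::DeviceReduce::Max pass in calc_log_mel() below.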
static void fill_stft_input(
const float * padded_samples,
int n_frames,
const float * hann_window,
float * stft_in,
cudaStream_t stream
) {
dim3 block(WHISPER_N_FFT, 1);
dim3 grid(1, n_frames);
k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
}
static void calc_magnitudes(
const cuComplex * stft_out,
int n_frames,
float * magnitudes,
cudaStream_t stream
) {
dim3 block(WHISPER_N_FFT_HALF, 1);
dim3 grid(1, n_frames);
k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
}
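// Launch-geometry note (sketch): fill_stft_input() and calc_magnitudes() above each map one
// thread per (bin, frame) pair - WHISPER_N_FFT and WHISPER_N_FFT_HALF threads per frame
// respectively - so the bounds checks commented out inside the kernels are redundant as long
// as the grids are sized exactly this way.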
constexpr auto LOG_MEL_PREFIX_SIZE = 256;
static void calc_log_mel(
const float * mel_data,
int n_mel,
void * tempStorage,
int tempStorageSize,
float * log_mel,
cudaStream_t stream
) {
float * max_val = reinterpret_cast<float *>(tempStorage);
void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;
size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);
int block = 256;
int grid = (n_mel + block - 1) / block;
k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
}
class mel_calc_cuda : public whisper_mel_calc {
const int m_n_mel;
ggml_backend_t m_backend = nullptr;
int m_device = -1;
cudaStream_t m_stream = nullptr;
cublasHandle_t m_cublas_handle = nullptr;
float * m_hann_window = nullptr;
float * m_filters = nullptr;
// max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
int m_n_max_samples = 0;
size_t m_cufft_workspace_size = 0;
void * m_cufft_workspace = nullptr;
size_t m_log_mel_temp_storage_size = 0;
void * m_log_mel_temp_storage = nullptr;
public:
mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
: m_n_mel(filters.n_mel)
, m_backend(backend)
{
ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*) m_backend->context;
m_device = cuda_ctx->device;
if (ggml_cuda_info().devices[m_device].cc < 600) {
// we've only tested on 6.0 and higher, and we've had reports of crashes on 5.0:
// https://github.com/ggerganov/whisper.cpp/issues/2230
// to be safe forbid anything below 6.0
throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
}
ggml_cuda_set_device(m_device);
if (filters.n_fft != WHISPER_N_FFT_HALF) {
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
}
assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
CUDA_CHECK(cudaStreamCreate(&m_stream));
CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
// create Hann window
{
auto hw = whisper_mel_calc::hann_window();
CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
}
// fill filters
{
auto& f = filters.data;
CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
}
// preallocate working areas enough for the most common cases (<= 30s)
ensure_working_areas(WHISPER_N_SAMPLES);
}
~mel_calc_cuda() {
ggml_cuda_set_device(m_device);
CUDA_CHECK(cudaStreamSynchronize(m_stream));
CUDA_CHECK(cudaStreamDestroy(m_stream));
CUDA_CHECK(cudaFree(m_hann_window));
CUDA_CHECK(cudaFree(m_cufft_workspace));
CUDA_CHECK(cudaFree(m_filters));
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
}
void ensure_working_areas(int n_samples) {
if (n_samples <= m_n_max_samples) {
return;
}
const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT;
const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
// cufft workspace
{
if (m_cufft_workspace) {
CUDA_CHECK(cudaFree(m_cufft_workspace));
m_cufft_workspace_size = 0;
m_cufft_workspace = nullptr;
}
CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size));
CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
}
// device reduce working area
{
if (m_log_mel_temp_storage) {
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
m_log_mel_temp_storage_size = 0;
m_log_mel_temp_storage = nullptr;
}
const auto max_mels = 160;
size_t nbytes = 0;
float* temp = nullptr;
cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels);
m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE;
CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
}
m_n_max_samples = n_samples;
}
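// Sizing example (illustrative): for n_samples == WHISPER_N_SAMPLES (30 s at 16 kHz, i.e.
// 480000 samples), max_padded_samples == 960400 and max_frames == 1 + (960400 - 400) / 160
// == 6001, which bounds both the cuFFT workspace estimate and the DeviceReduce scratch size.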
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
ggml_cuda_set_device(m_device);
ensure_working_areas(samples.len);
const size_t mirror_pad = WHISPER_N_FFT / 2;
const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;
// pad
std::vector<float> padded_samples(padded_size);
std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect
std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy
// fill the rest of the data
// it should canonically be mirrored at the end as well,
// but we just assume the last MEL_FRAME_SIZE/2 samples are zeros
std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f);
const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
float * cu_padded_samples = nullptr;
CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
float * stft_in = nullptr; // contiguous buffer for stft input
CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));
fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);
cufftComplex* stft_out;
CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));
cufftHandle plan;
CUFFT_CHECK(cufftCreate(&plan));
CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
{
size_t waSize;
CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
assert(waSize <= m_cufft_workspace_size);
CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
CUFFT_CHECK(cufftSetStream(plan, m_stream));
}
CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));
const auto n_mag_frames = n_frames - 1; // drop last frame
float * magnitudes;
CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);
float * mel_data = nullptr;
CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));
const float fone = 1.0f, fzero = 0.0f;
CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
&fone,
magnitudes, WHISPER_N_FFT_HALF,
m_filters, WHISPER_N_FFT_HALF,
&fzero,
mel_data, int(n_mag_frames)));
whisper_mel ret;
// Calculate semi-padded sample length to ensure compatibility
int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
calc_log_mel(
mel_data, int(m_n_mel * n_mag_frames),
m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
log_mels, m_stream);
CUDA_CHECK(cudaStreamSynchronize(m_stream));
// cleanup
CUFFT_CHECK(cufftDestroy(plan));
CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));
return ret;
}
};
}
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
try {
return new mel_calc_cuda(backend, filters);
}
catch (...) {
// TODO: log error (but for this we would have to expose the log state to be accessible here)
return nullptr;
}
}

View file

@@ -0,0 +1,3 @@
#include "whisper-mel.hpp"
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters);

View file

@@ -0,0 +1,34 @@
#pragma once
#include "ggml-backend.h"
#include <vector>
struct whisper_mel {
int n_len_org = 0;
ggml_context * ctx = nullptr;
ggml_tensor * tensor = nullptr;
ggml_backend_buffer_t buffer = nullptr;
};
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
void whisper_mel_free(whisper_mel & mel);
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
template <typename T>
struct whisper_span {
T * data;
int len;
};
struct whisper_mel_calc {
virtual ~whisper_mel_calc();
virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) = 0;
static whisper_span<const float> hann_window();
};
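// Implementation contract (sketch): a backend-specific calculator - e.g. the CUDA one created
// by whisper_mel_calc_create_cuda() earlier in this diff - subclasses whisper_mel_calc,
// allocates its output with whisper_mel_init(), fills mel.tensor with an n_mel x n_len
// log-mel spectrogram, and records n_len_org so later stages know how much of the padded
// mel corresponds to real audio.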

File diff suppressed because it is too large

View file

@@ -0,0 +1,686 @@
#ifndef WHISPER_H
#define WHISPER_H
#include "ggml.h"
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#include <string>
#include <vector>
#ifdef __GNUC__
# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
# define WHISPER_DEPRECATED(func, hint) func
#endif
#ifdef WHISPER_SHARED
# ifdef _WIN32
# ifdef WHISPER_BUILD
# define WHISPER_API __declspec(dllexport)
# else
# define WHISPER_API __declspec(dllimport)
# endif
# else
# define WHISPER_API __attribute__ ((visibility ("default")))
# endif
#else
# define WHISPER_API
#endif
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1)
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE)
#define COMMON_SAMPLE_RATE 16000 // Common sample rate for audio processing (16kHz)
#ifdef __cplusplus
extern "C" {
#endif
//
// C interface
//
// The following interface is thread-safe as long as the same whisper_context is not used by multiple threads
// concurrently.
//
// Basic usage:
//
// #include "whisper.h"
//
// ...
//
// whisper_context_params cparams = whisper_context_default_params();
//
// struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams);
//
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
// fprintf(stderr, "failed to process audio\n");
// return 7;
// }
//
// const int n_segments = whisper_full_n_segments(ctx);
// for (int i = 0; i < n_segments; ++i) {
// const char * text = whisper_full_get_segment_text(ctx, i);
// printf("%s", text);
// }
//
// whisper_free(ctx);
//
// ...
//
// This is a demonstration of the most straightforward usage of the library.
// "pcmf32" contains the RAW audio data in 32-bit floating point format.
//
// The interface also allows for more fine-grained control over the computation, but it requires a deeper
// understanding of how the model works.
//
struct whisper_context;
struct whisper_state;
struct whisper_full_params;
typedef int32_t whisper_pos;
typedef int32_t whisper_token;
typedef int32_t whisper_seq_id;
enum whisper_alignment_heads_preset {
WHISPER_AHEADS_NONE,
WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers
WHISPER_AHEADS_CUSTOM,
WHISPER_AHEADS_TINY_EN,
WHISPER_AHEADS_TINY,
WHISPER_AHEADS_BASE_EN,
WHISPER_AHEADS_BASE,
WHISPER_AHEADS_SMALL_EN,
WHISPER_AHEADS_SMALL,
WHISPER_AHEADS_MEDIUM_EN,
WHISPER_AHEADS_MEDIUM,
WHISPER_AHEADS_LARGE_V1,
WHISPER_AHEADS_LARGE_V2,
WHISPER_AHEADS_LARGE_V3,
};
typedef struct whisper_ahead {
int n_text_layer;
int n_head;
} whisper_ahead;
typedef struct whisper_aheads {
size_t n_heads;
const whisper_ahead * heads;
} whisper_aheads;
struct whisper_context_params {
bool use_gpu;
bool flash_attn;
int gpu_device; // CUDA device
// [EXPERIMENTAL] Token-level timestamps with DTW
bool dtw_token_timestamps;
enum whisper_alignment_heads_preset dtw_aheads_preset;
int dtw_n_top;
struct whisper_aheads dtw_aheads;
size_t dtw_mem_size; // TODO: remove
};
typedef struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
// [EXPERIMENTAL] Token-level timestamps with DTW
// do not use if you haven't computed token-level timestamps with dtw
// Roughly corresponds to the moment in audio in which the token was output
int64_t t_dtw;
float vlen; // voice length of the token
} whisper_token_data;
typedef struct whisper_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
void (*seek)(void * ctx, size_t offset);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// grammar element type
enum whisper_gretype {
// end of rule definition
WHISPER_GRETYPE_END = 0,
// start of alternate definition for rule
WHISPER_GRETYPE_ALT = 1,
// non-terminal element: reference to rule
WHISPER_GRETYPE_RULE_REF = 2,
// terminal element: character (code point)
WHISPER_GRETYPE_CHAR = 3,
// inverse char(s) ([^a], [^a-b] [^abc])
WHISPER_GRETYPE_CHAR_NOT = 4,
// modifies a preceding WHISPER_GRETYPE_CHAR or WHISPER_GRETYPE_CHAR_ALT to
// be an inclusive range ([a-z])
WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
// modifies a preceding WHISPER_GRETYPE_CHAR or
// WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
WHISPER_GRETYPE_CHAR_ALT = 6,
};
typedef struct whisper_grammar_element {
enum whisper_gretype type;
uint32_t value; // Unicode code point or rule ID
} whisper_grammar_element;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params);
// These are the same as the above, but the internal state of the context is not allocated automatically
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
"use whisper_init_from_file_with_params instead"
);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size),
"use whisper_init_from_buffer_with_params instead"
);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader),
"use whisper_init_with_params instead"
);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model),
"use whisper_init_from_file_with_params_no_state instead"
);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size),
"use whisper_init_from_buffer_with_params_no_state instead"
);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader),
"use whisper_init_with_params_no_state instead"
);
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
// Given a context, enable use of OpenVINO for encode inference.
// model_path: Optional path to OpenVINO encoder IR model. If set to nullptr,
// the path will be generated from the ggml model path that was passed
// in to whisper_init_from_file. For example, if 'path_model' was
// "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be
// assumed to be "/path/to/ggml-base.en-encoder-openvino.xml".
// device: OpenVINO device to run inference on ("CPU", "GPU", etc.)
// cache_dir: Optional cache directory that can speed up init time, especially for
// GPU, by caching compiled 'blobs' there.
// Set to nullptr if not used.
// Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1.
WHISPER_API int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
const char * model_path,
const char * device,
const char * cache_dir);
// Frees all allocated memory
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
WHISPER_API int whisper_set_mel(
struct whisper_context * ctx,
const float * data,
int n_len,
int n_mel);
WHISPER_API int whisper_set_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * data,
int n_len,
int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
WHISPER_API int whisper_encode(
struct whisper_context * ctx,
int offset,
int n_threads);
WHISPER_API int whisper_encode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset,
int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
WHISPER_API int whisper_decode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
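// Minimal sketch of the lower-level path (assuming pcmf32 holds 16 kHz mono float samples
// and `tokens` holds a decoder prompt, e.g. from whisper_tokenize() declared below):
//
//   whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), /*n_threads*/ 4);
//   whisper_encode    (ctx, /*offset*/ 0, /*n_threads*/ 4);
//   whisper_decode    (ctx, tokens, n_tokens, /*n_past*/ 0, /*n_threads*/ 4);
//
// whisper_full() wraps this sequence together with sampling; prefer it unless this level
// of control is actually needed.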
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
// Return the number of tokens in the provided text
// Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
int whisper_token_count(struct whisper_context * ctx, const char * text);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id(void);
// Return the id of the specified language, returns -1 if not found
// Examples:
// "de" -> 2
// "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str_full(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whisper_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx);
WHISPER_API int whisper_model_ftype (struct whisper_context * ctx);
WHISPER_API int whisper_model_type (struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
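//
// Example (a minimal sketch of greedy sampling from the last call to whisper_decode();
// assumes `n_tokens` tokens were just passed to the decoder):
//
//     const int     n_vocab = whisper_n_vocab(ctx);
//     const float * logits  = whisper_get_logits(ctx) + (n_tokens - 1)*n_vocab; // last row
//     whisper_token best = 0;
//     for (int i = 1; i < n_vocab; ++i) {
//         if (logits[i] > logits[best]) best = i;
//     }
//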
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
// Special tokens
WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
// Performance information from the default state.
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
// Print system information
WHISPER_API const char * whisper_print_system_info(void);
////////////////////////////////////////////////////////////////////////////
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
};
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
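//
// Example (a minimal sketch of a new-segment callback that prints each segment as it is
// produced; `on_new_segment` is a hypothetical name, registered via whisper_full_params below):
//
//     static void on_new_segment(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * /*user_data*/) {
//         const int n = whisper_full_n_segments(ctx);
//         for (int i = n - n_new; i < n; ++i) {
//             printf("%s\n", whisper_full_get_segment_text(ctx, i));
//         }
//     }
//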
// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
struct whisper_full_params {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate;
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
bool no_timestamps; // do not generate timestamps
bool single_segment; // force single segment output (useful for streaming)
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; // print timestamps for each text segment when printing realtime
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
bool split_on_word; // split on word rather than on token (when used with max_len)
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool debug_mode; // enable debug mode (e.g. dump the log mel spectrogram)
int audio_ctx; // overwrite the audio context size (0 = use default)
// [EXPERIMENTAL] [TDRZ] tinydiarize
bool tdrz_enable; // enable tinydiarize speaker turn detection
// A regular expression that matches tokens to suppress
const char * suppress_regex;
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
// use whisper_tokenize() to convert text to tokens
// maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
const char * initial_prompt;
const whisper_token * prompt_tokens;
int prompt_n_tokens;
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
bool detect_language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct {
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
} greedy;
struct {
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
} beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called on each progress update
whisper_progress_callback progress_callback;
void * progress_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called each time before ggml computation starts
ggml_abort_callback abort_callback;
void * abort_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
const whisper_grammar_element ** grammar_rules;
size_t n_grammar_rules;
size_t i_start_rule;
float grammar_penalty;
};
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
WHISPER_API struct whisper_context_params whisper_context_default_params (void);
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
WHISPER_API struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy);
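//
// Example (a minimal sketch of customizing the default parameters; `on_new_segment`
// refers to the hypothetical callback sketched above):
//
//     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//     wparams.n_threads            = 4;
//     wparams.language             = "auto";
//     wparams.print_progress       = false;
//     wparams.new_segment_callback = on_new_segment;
//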
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Not thread safe for same context
// Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples);
WHISPER_API int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples);
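//
// Example (a minimal sketch; `pcmf32` holds 16 kHz mono float samples and `wparams`
// was obtained from whisper_full_default_params()):
//
//     if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
//         fprintf(stderr, "failed to process audio\n");
//     }
//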
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// This approach can offer a speedup in some cases, but transcription accuracy tends to be
// worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
int n_processors);
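//
// Example (a minimal sketch; same inputs as whisper_full(), split across 2 processors):
//
//     if (whisper_full_parallel(ctx, wparams, pcmf32.data(), (int) pcmf32.size(), 2) != 0) {
//         fprintf(stderr, "failed to process audio\n");
//     }
//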
// Number of generated text segments
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
// Language id associated with the context's default state
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Language id associated with the provided state
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
// Get the audio embedding tensors (output of the conv stem and of the encoder, respectively)
WHISPER_API ggml_tensor * whisper_full_get_embd_conv(struct whisper_context * ctx);
WHISPER_API ggml_tensor * whisper_full_get_embd_enc(struct whisper_context * ctx);
// Get the start and end time of the specified segment
WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
// Get whether the next segment is predicted as a speaker turn
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
// Get number of tokens in the specified segment
WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment);
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
// Get the token text of the specified token in the specified segment
WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment
WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
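//
// Example (a minimal sketch of reading the results after a successful whisper_full() call;
// t0/t1 are in units of 10 ms):
//
//     const int n_segments = whisper_full_n_segments(ctx);
//     for (int i = 0; i < n_segments; ++i) {
//         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
//         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//         printf("[%lld -> %lld] %s\n", (long long) t0, (long long) t1, whisper_full_get_segment_text(ctx, i));
//     }
//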
////////////////////////////////////////////////////////////////////////////
// Temporary helpers needed for exposing ggml interface
WHISPER_API int whisper_bench_memcpy (int n_threads);
WHISPER_API const char * whisper_bench_memcpy_str (int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads);
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
// Control logging output; default behavior is to print to stderr
WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
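//
// Example (a minimal sketch of redirecting the library's log output; `my_log` is a
// hypothetical callback that only forwards errors):
//
//     static void my_log(enum ggml_log_level level, const char * text, void * /*user_data*/) {
//         if (level == GGML_LOG_LEVEL_ERROR) fputs(text, stderr);
//     }
//     ...
//     whisper_log_set(my_log, nullptr);
//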
// Whisper encode without cross-attention (encoder-only API)
WHISPER_API struct whisper_context * whisper_encoder_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_state * whisper_encoder_init_state(struct whisper_context * ctx);
WHISPER_API int whisper_encode_wo_cross(
struct whisper_context * ctx,
int offset,
int n_threads);
WHISPER_API int whisper_encode_wo_cross_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
int n_processors);
WHISPER_API bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo);
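//
// Example (a rough sketch of the encoder-only path; the exact initialization and call
// sequence is defined by the implementation, this only illustrates the intended order.
// "audio.wav" and `model_path` are placeholders; error handling omitted):
//
//     std::vector<float> pcmf32;               // mono samples
//     std::vector<std::vector<float>> pcmf32s; // stereo samples (unused here)
//     if (read_wav("audio.wav", pcmf32, pcmf32s, false)) {
//         struct whisper_context * ectx = whisper_encoder_init_from_file_with_params(model_path, whisper_context_default_params());
//         whisper_pcm_to_mel(ectx, pcmf32.data(), (int) pcmf32.size(), 4);
//         whisper_encode_wo_cross(ectx, 0, 4);
//     }
//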
#ifdef __cplusplus
}
#endif
#endif