From c7b912bdca66ad5cc3edef66117d35837006ff2e Mon Sep 17 00:00:00 2001
From: Zack Zhiyuan Li
Date: Sun, 3 Nov 2024 17:58:08 +0000
Subject: [PATCH] support omni-audio

---
 common/CMakeLists.txt                         |     3 +
 common/common-nexa.cpp                        |   317 +
 common/common-nexa.h                          |    80 +
 common/dr_wav.h                               |  6434 ++++++++++
 examples/CMakeLists.txt                       |    79 +-
 examples/nexa-omni-audio/CMakeLists.txt       |    56 +
 examples/nexa-omni-audio/README.md            |    63 +
 examples/nexa-omni-audio/audio-projector.cpp  |    37 +
 examples/nexa-omni-audio/audio-projector.h    |    67 +
 examples/nexa-omni-audio/ggml-cpu-impl.h      |   614 +
 examples/nexa-omni-audio/main-encode.cpp      |   916 ++
 examples/nexa-omni-audio/omni-cli.cpp         |    19 +
 examples/nexa-omni-audio/omni.cpp             |   864 ++
 examples/nexa-omni-audio/omni.h               |    64 +
 examples/nexa-omni-audio/whisper-mel-cuda.cu  |   364 +
 examples/nexa-omni-audio/whisper-mel-cuda.hpp |     3 +
 examples/nexa-omni-audio/whisper-mel.hpp      |    34 +
 examples/nexa-omni-audio/whisper.cpp          | 10034 ++++++++++++++++
 examples/nexa-omni-audio/whisper.h            |   686 ++
 19 files changed, 20695 insertions(+), 39 deletions(-)
 create mode 100644 common/common-nexa.cpp
 create mode 100644 common/common-nexa.h
 create mode 100644 common/dr_wav.h
 create mode 100644 examples/nexa-omni-audio/CMakeLists.txt
 create mode 100644 examples/nexa-omni-audio/README.md
 create mode 100644 examples/nexa-omni-audio/audio-projector.cpp
 create mode 100644 examples/nexa-omni-audio/audio-projector.h
 create mode 100644 examples/nexa-omni-audio/ggml-cpu-impl.h
 create mode 100644 examples/nexa-omni-audio/main-encode.cpp
 create mode 100644 examples/nexa-omni-audio/omni-cli.cpp
 create mode 100644 examples/nexa-omni-audio/omni.cpp
 create mode 100644 examples/nexa-omni-audio/omni.h
 create mode 100644 examples/nexa-omni-audio/whisper-mel-cuda.cu
 create mode 100644 examples/nexa-omni-audio/whisper-mel-cuda.hpp
 create mode 100644 examples/nexa-omni-audio/whisper-mel.hpp
 create mode 100644 examples/nexa-omni-audio/whisper.cpp
 create mode 100644 examples/nexa-omni-audio/whisper.h

diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 761971d68..14d91bc0a 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -66,6 +66,9 @@ add_library(${TARGET} STATIC
     train.cpp
     ngram-cache.h
     ngram-cache.cpp
+    common-nexa.h
+    common-nexa.cpp
+    dr_wav.h
     )
 
 if (BUILD_SHARED_LIBS)
diff --git a/common/common-nexa.cpp b/common/common-nexa.cpp
new file mode 100644
index 000000000..e8a54ba04
--- /dev/null
+++ b/common/common-nexa.cpp
@@ -0,0 +1,317 @@
+#include "common-nexa.h"
+
+#include <cstring>
+#include <fstream>
+#include <functional>
+#include <thread>
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include <stdexcept>
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include "common.h"
+#include <numeric>
+#include <cmath>
+
+void print_ggml_tensor(const char *name, const struct ggml_tensor *tensor, bool use_backend, int precision) {
+    std::vector<float> data(ggml_nelements(tensor));
+    if (use_backend) {
+        ggml_backend_tensor_get(tensor, data.data(), 0, ggml_nbytes(tensor));
+    } else {
+        memcpy(data.data(), ggml_get_data_f32(tensor), ggml_nbytes(tensor));
+    }
+
+    std::vector<int64_t> shape;
+    for (int i = 0; i < GGML_MAX_DIMS && tensor->ne[i] > 1; ++i) shape.push_back(tensor->ne[i]);
+
+    print_ggml_tensor_shape(name, tensor);
+
+    size_t offset = 0;
+    std::function<void(size_t, size_t &)> print_recursive = [&](size_t dim, size_t &offset) {
+        if (dim == shape.size()) {
+            printf("%.*f", precision, data[offset++]);
+        } else {
+            printf("[ ");
+            for (int64_t i = 0; i < shape[dim]; ++i) {
+                if (i > 0) printf(dim == shape.size() - 1 ? ", " : ",\n%*s", static_cast<int>(dim) + 1, "");
+                print_recursive(dim + 1, offset);
+            }
+            printf("]");
+        }
+    };
+    print_recursive(0, offset);
+    printf("\n");
+}
+
+void print_ggml_tensor_shape(const char *name, const struct ggml_tensor *tensor) {
+    printf("%s: [ ", name);
+    for (int i = 0; i < GGML_MAX_DIMS && tensor->ne[i] > 1; ++i) printf("%d ", static_cast<int>(tensor->ne[i]));
+    printf("]\n");
+}
+
+bool load_hparams_and_tensors_from_gguf(const std::string &fname, NexaBaseModel &model, bool verbose)
+{
+    // Initialize GGUF context
+    ggml_context *meta = nullptr;
+    gguf_init_params params = {true, &meta};
+    gguf_context *ctx_gguf = gguf_init_from_file(fname.c_str(), params);
+
+    // Check if GGUF context initialization was successful
+    if (!ctx_gguf)
+        return fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__), false;
+
+    // Get the number of tensors in the GGUF file
+    const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+    const int n_tensors_in_model = model.tensor_names.size();
+    if (n_tensors_in_model > n_tensors)
+    {
+        fprintf(stderr, "%s: model tensor_names size (%d) is greater than the number of tensors in the GGUF file (%d)\n", __func__, n_tensors_in_model, n_tensors);
+        gguf_free(ctx_gguf);
+        return false;
+    }
+
+    // Load hyperparameters
+    for (const auto &name : model.hparam_names)
+    {
+        int key = gguf_find_key(ctx_gguf, name.c_str());
+        if (key != -1)
+        {
+            model.hparams[name] = gguf_get_val_i32(ctx_gguf, key);
+            if (verbose)
+                fprintf(stderr, "%s: loaded hparam '%s' = %d\n", __func__, name.c_str(), std::get<int32_t>(model.hparams[name]));
+        }
+        else
+            return fprintf(stderr, "%s: failed to load hparam '%s'\n", __func__, name.c_str()), gguf_free(ctx_gguf), false;
+    }
+
+    // Initialize GGML context for tensor data
+    model.ctx_data = ggml_init({(n_tensors_in_model + 1) * ggml_tensor_overhead(), nullptr, true});
+    if (!model.ctx_data)
+        return fprintf(stderr, "%s: ggml_init() failed\n", __func__), gguf_free(ctx_gguf), false;
+
+    // Open the GGUF file for reading tensor data
+    std::ifstream fin(fname, std::ios::binary);
+    if (!fin)
+        return fprintf(stderr, "%s: cannot open model file for loading tensors\n", __func__), gguf_free(ctx_gguf), false;
+
+    // Create tensor structures in the GGML context
+    for (const auto &name : model.tensor_names)
+    {
+        ggml_tensor *t = ggml_dup_tensor(model.ctx_data, ggml_get_tensor(meta, name.c_str()));
+        ggml_set_name(t, name.c_str());
+    }
+
+    // Allocate memory for tensors using the specified backend
+    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx_data, model.backend);
+
+    // Load tensors from the GGUF file
+    for (const auto &name : model.tensor_names)
+    {
+        char *tensor_name = const_cast<char *>(name.c_str());
+        if (verbose)
+            fprintf(stderr, "%s: loading tensor '%s'\n", __func__, tensor_name);
+
+        ggml_tensor *cur = ggml_get_tensor(model.ctx_data, tensor_name);
+        if (!cur)
+            return fprintf(stderr, "%s: failed to get tensor %s\n", __func__, tensor_name), gguf_free(ctx_gguf), false;
+
+        int tensor_idx = gguf_find_tensor(ctx_gguf, tensor_name);
+        fin.seekg(gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, tensor_idx), std::ios::beg);
+
+        int num_bytes = ggml_nbytes(cur);
+        if (ggml_backend_buffer_is_host(model.buffer))
+        {
+            fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+        }
+        else
+        {
+            std::vector<uint8_t> read_buf(num_bytes);
+            fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+            ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+        }
+
+        // mapping tensor name to tensor pointer
+        model.tensors[name] = cur;
+        if (verbose)
+        {
+            fprintf(stderr, "%s: mapped tensor ", __func__);
+            print_ggml_tensor_shape(tensor_name, cur);
+        }
+    }
+
+    gguf_free(ctx_gguf);
+    ggml_free(meta);
+    return true;
+}
+
+//
+// NexaBaseModel
+//
+
+// initialize from gguf file
+bool NexaBaseModel::load_from_gguf(const std::string &fname)
+{
+    init_backend();
+
+    bool verbose = false;
+#ifdef NEXA_DEBUG
+    verbose = true;
+#endif
+
+    if (!load_hparams_and_tensors_from_gguf(fname, *this, verbose))
+    {
+        NEXA_LOG("failed to load params and tensors");
+        return false;
+    }
+
+    reserve_memory();
+
+    return true;
+}
+
+// Initialize the backend based on available hardware
+void NexaBaseModel::init_backend()
+{
+#ifdef GGML_USE_CUDA
+    NEXA_LOG("using CUDA backend");
+    backend = ggml_backend_cuda_init(0); // Initialize CUDA on device 0
+#endif
+
+#ifdef GGML_USE_METAL
+    NEXA_LOG("using Metal backend");
+    backend = ggml_backend_metal_init(); // Initialize Metal backend
+#endif
+
+    // Fallback to CPU backend if no GPU is available
+    if (!backend)
+    {
+        backend = ggml_backend_cpu_init();
+        fprintf(stderr, "%s: using CPU backend\n", __func__);
+    }
+}
+
+// measure mem requirement and allocate
+void NexaBaseModel::reserve_memory()
+{
+    compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    struct ggml_cgraph *gf = build_graph();
+    ggml_gallocr_reserve(compute_alloc, gf);
+    size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(compute_alloc, 0);
+    NEXA_LOG("compute allocated memory: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0);
+}
+
+// set the number of threads
+void NexaBaseModel::set_n_threads(int n_threads)
+{
+    if (n_threads <= 0)
+    {
+        // if n_threads is not set, use the number of cores
+        n_threads = std::thread::hardware_concurrency();
+    }
+
+    // Set backend options
+    if (ggml_backend_is_cpu(backend))
+    {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend))
+    {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+}
+
+// Free allocated memory
+void NexaBaseModel::free()
+{
+    ggml_gallocr_free(compute_alloc);
+    ggml_free(ctx_data);
+    ggml_backend_buffer_free(buffer);
+    ggml_backend_free(backend);
+}
+
+void print_ggml_tensor_stats(const char *name, const struct ggml_tensor *tensor, bool use_backend) {
+    std::vector<float> data(ggml_nelements(tensor));
+    if (use_backend) {
+        ggml_backend_tensor_get(tensor, data.data(), 0, ggml_nbytes(tensor));
+    } else {
+        memcpy(data.data(), ggml_get_data_f32(tensor), ggml_nbytes(tensor));
+    }
+
+    if (data.empty()) {
+        printf("%s: Empty tensor\n", name);
+        return;
+    }
+
+    // Calculate mean
+    double sum = std::accumulate(data.begin(), data.end(), 0.0);
+    double mean = sum / data.size();
+
+    // Calculate variance using two-pass algorithm for better numerical stability
+    double sq_sum = 0.0;
+    for (const auto &val : data) {
+        double diff = val - mean;
+        sq_sum += diff * diff;
+    }
+    double variance = sq_sum / data.size();
+
+    // Print statistics
+    printf("%s:\n", name);
+    printf("  Shape: [");
+    for (int i = 0; i < GGML_MAX_DIMS && tensor->ne[i] > 1; ++i) {
+        printf("%d%s", static_cast<int>(tensor->ne[i]), (i < GGML_MAX_DIMS - 1 && tensor->ne[i+1] > 1) ? ", " : "");
+    }
+    printf("]\n");
+    printf("  Mean: %.6f\n", mean);
+    printf("  Variance: %.6f\n", variance);
+    printf("  Standard Deviation: %.6f\n", std::sqrt(variance));
+}
+
+void print_all_tensor_names(struct gguf_context *ctx) {
+    int n_tensors = gguf_get_n_tensors(ctx);
+    printf("Number of tensors: %d\n", n_tensors);
+
+    const char *separator = "";
+    printf("Tensors: ");
+    for (int i = 0; i < n_tensors; ++i) {
+        const char *tensor_name = gguf_get_tensor_name(ctx, i);
+        printf("%s%s", separator, tensor_name);
+        separator = ", "; // Set separator after the first tensor
+    }
+    printf("\n");
+}
+
+struct ggml_tensor * checked_get_tensor(struct ggml_context * ctx, const char * name) {
+    struct ggml_tensor * tensor = ggml_get_tensor(ctx, name);
+    // print_ggml_tensor_stats(name, tensor, false);
+    if (!tensor) {
+        fprintf(stderr, "%s: tensor '%s' not found\n", __func__, name);
+        throw std::runtime_error("ggml_get_tensor() failed");
+    }
+    return tensor;
+}
+
+//
+// original ggml functions
+//
+
+struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) {
+    if (i < 0) {
+        GGML_ASSERT(cgraph->n_nodes + i >= 0);
+        return cgraph->nodes[cgraph->n_nodes + i];
+    }
+
+    GGML_ASSERT(i < cgraph->n_nodes);
+    return cgraph->nodes[i];
+}
\ No newline at end of file
diff --git a/common/common-nexa.h b/common/common-nexa.h
new file mode 100644
index 000000000..1135eae57
--- /dev/null
+++ b/common/common-nexa.h
@@ -0,0 +1,80 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <cstdio>
+#include <cstdint>
+#include <cmath>
+#include <string>
+#include <vector>
+#include <map>
+#include <variant>
+
+#include <cxxabi.h>
+#define NEXA_CLASS_NAME (abi::__cxa_demangle(typeid(*this).name(), nullptr, nullptr, nullptr))
+#define NEXA_LOG(fmt, ...) fprintf(stderr, "%s::%s: " fmt "\n", NEXA_CLASS_NAME, __func__, ##__VA_ARGS__)
+
+// Prints the content of a ggml_tensor with specified precision. Can use the backend if available.
+void print_ggml_tensor(const char *name, const struct ggml_tensor *tensor, bool use_backend, int precision = 4);
+
+// Prints the shape (dimensions) of a ggml_tensor without printing its contents.
+void print_ggml_tensor_shape(const char *name, const struct ggml_tensor *tensor);
+
+// Prints the statistics (mean, variance, standard deviation) of a ggml_tensor. Can use the backend if available.
+void print_ggml_tensor_stats(const char *name, const struct ggml_tensor *tensor, bool use_backend);
+
+// Print all tensor names in the provided GGUF context.
+void print_all_tensor_names(struct gguf_context *ctx);
+
+// get tensor, print stats and check for null
+struct ggml_tensor * checked_get_tensor(struct ggml_context * ctx, const char * name);
+
+// Base class for all Nexa models
+
+struct NexaBaseModel
+{
+    std::vector<std::string> hparam_names;
+    std::map<std::string, std::variant<int32_t, float_t>> hparams; // hyperparameters, dict value can be either int32_t or float_t
+
+    std::vector<std::string> tensor_names;
+    std::map<std::string, ggml_tensor *> tensors; // std::variant is a type-safe union that can hold either: (1) int32_t (32-bit integer) (2) float_t (floating-point number)
+
+    struct ggml_context *ctx_data; // GGML context for tensor management
+
+    ggml_backend_buffer_t buffer; // Backend buffer to store tensor data
+
+    ggml_backend_t backend = NULL;       // Backend for computation (CPU, CUDA, METAL)
+    ggml_gallocr_t compute_alloc = NULL; // Memory allocator for computation
+
+    // constructor & destructor
+    NexaBaseModel() {}
+    ~NexaBaseModel()
+    {
+        free();
+        NEXA_LOG("allocated resources freed");
+    }
+
+    // Initialize the backend based on available hardware
+    void init_backend();
+
+    // measure mem requirement and allocate
+    void reserve_memory();
+
+    // initialize from gguf file
+    bool load_from_gguf(const std::string &fname);
+
+    // build the computation graph
+    // this is a pure virtual function that must be implemented by the derived class
+    virtual ggml_cgraph *build_graph() = 0;
+
+    // set the number of threads
+    void set_n_threads(int n_threads);
+
+    // Free allocated memory
+    void free();
+};
+
+bool load_hparams_and_tensors_from_gguf(const std::string &fname, NexaBaseModel &model, bool verbose = false);
+
+struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i);
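To make the `NexaBaseModel` contract concrete, here is a minimal hypothetical subclass. This is an illustrative sketch, not part of the patch: the hparam and tensor names (`n_embd`, `proj.weight`, `proj.bias`) and the file name are invented, and a real model must list exactly the names present in its GGUF file:

```cpp
// Sketch: a toy linear projector on top of NexaBaseModel (hypothetical names).
struct ToyProjector : NexaBaseModel {
    ToyProjector() {
        hparam_names = {"n_embd"};                   // loaded as int32_t from GGUF metadata
        tensor_names = {"proj.weight", "proj.bias"}; // loaded into `tensors` by name
    }

    ggml_cgraph *build_graph() override {
        // Graph metadata lives in this static buffer, so the returned graph
        // stays valid after the temporary context is freed.
        static std::vector<uint8_t> buf(ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
        ggml_init_params params = {buf.size(), buf.data(), /*no_alloc=*/true};
        ggml_context *ctx = ggml_init(params);

        const int32_t n_embd = std::get<int32_t>(hparams["n_embd"]);
        ggml_tensor *inp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
        ggml_set_name(inp, "input");

        ggml_tensor *cur = ggml_mul_mat(ctx, tensors["proj.weight"], inp);
        cur = ggml_add(ctx, cur, tensors["proj.bias"]);

        ggml_cgraph *gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, cur);
        ggml_free(ctx);
        return gf;
    }
};

// Usage: load_from_gguf() runs init_backend(), reads hparams and tensors,
// then reserve_memory() calls build_graph() once to size the compute buffer:
//   ToyProjector model;
//   model.load_from_gguf("projector.gguf");
```

After a graph has been computed, the helpers above are handy for spot-checking intermediate results; `ggml_graph_node()` accepts Python-style negative indices, so `-1` is the last node:

```cpp
// Sketch: inspect the final node of a computed graph `gf` (assumed to exist).
ggml_tensor *out = ggml_graph_node(gf, -1);    // last node; -2 would be second-to-last
print_ggml_tensor_shape("output", out);        // shape only
print_ggml_tensor_stats("output", out, true);  // mean / variance / stddev via the backend
// print_ggml_tensor("output", out, true, 4);  // full contents, 4 decimal places (verbose)
```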
diff --git a/common/dr_wav.h b/common/dr_wav.h
new file mode 100644
index 000000000..fd3e95b34
--- /dev/null
+++ b/common/dr_wav.h
@@ -0,0 +1,6434 @@
+/*
+WAV audio loader and writer. Choice of public domain or MIT-0. See license statements at the end of this file.
+dr_wav - v0.12.16 - 2020-12-02
+
+David Reid - mackron@gmail.com
+
+GitHub: https://github.com/mackron/dr_libs
+*/
+
+/*
+RELEASE NOTES - VERSION 0.12
+============================
+Version 0.12 includes breaking changes to custom chunk handling.
+
+
+Changes to Chunk Callback
+-------------------------
+dr_wav supports the ability to fire a callback when a chunk is encountered (except for WAVE and FMT chunks). The callback has been updated to include both the
+container (RIFF or Wave64) and the FMT chunk which contains information about the format of the data in the wave file.
+
+Previously, there was no direct way to determine the container, and therefore no way to discriminate against the different IDs in the chunk header (RIFF and
+Wave64 containers encode chunk ID's differently). The `container` parameter can be used to know which ID to use.
+
+Sometimes it can be useful to know the data format at the time the chunk callback is fired. A pointer to a `drwav_fmt` object is now passed into the chunk
+callback which will give you information about the data format. To determine the sample format, use `drwav_fmt_get_format()`. This will return one of the
+`DR_WAVE_FORMAT_*` tokens.
+*/
+
+/*
+Introduction
+============
+This is a single file library. To use it, do something like the following in one .c file.
+
+    ```c
+    #define DR_WAV_IMPLEMENTATION
+    #include "dr_wav.h"
+    ```
+
+You can then #include this file in other parts of the program as you would with any other header file. Do something like the following to read audio data:
+
+    ```c
+    drwav wav;
+    if (!drwav_init_file(&wav, "my_song.wav", NULL)) {
+        // Error opening WAV file.
+    }
+
+    drwav_int32* pDecodedInterleavedPCMFrames = malloc(wav.totalPCMFrameCount * wav.channels * sizeof(drwav_int32));
+    size_t numberOfSamplesActuallyDecoded = drwav_read_pcm_frames_s32(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+
+    ...
+
+    drwav_uninit(&wav);
+    ```
+
+If you just want to quickly open and read the audio data in a single operation you can do something like this:
+
+    ```c
+    unsigned int channels;
+    unsigned int sampleRate;
+    drwav_uint64 totalPCMFrameCount;
+    float* pSampleData = drwav_open_file_and_read_pcm_frames_f32("my_song.wav", &channels, &sampleRate, &totalPCMFrameCount, NULL);
+    if (pSampleData == NULL) {
+        // Error opening and reading WAV file.
+    }
+
+    ...
+
+    drwav_free(pSampleData);
+    ```
+
+The examples above use versions of the API that convert the audio data to a consistent format (32-bit signed PCM, in this case), but you can still output the
+audio data in its internal format (see notes below for supported formats):
+
+    ```c
+    size_t framesRead = drwav_read_pcm_frames(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+    ```
+
+You can also read the raw bytes of audio data, which could be useful if dr_wav does not have native support for a particular data format:
+
+    ```c
+    size_t bytesRead = drwav_read_raw(&wav, bytesToRead, pRawDataBuffer);
+    ```
+
+dr_wav can also be used to output WAV files. This does not currently support compressed formats. To use this, look at `drwav_init_write()`,
+`drwav_init_file_write()`, etc. Use `drwav_write_pcm_frames()` to write samples, or `drwav_write_raw()` to write raw data in the "data" chunk.
+
+    ```c
+    drwav_data_format format;
+    format.container = drwav_container_riff;   // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64.
+    format.format = DR_WAVE_FORMAT_PCM;        // <-- Any of the DR_WAVE_FORMAT_* codes.
+    format.channels = 2;
+    format.sampleRate = 44100;
+    format.bitsPerSample = 16;
+    drwav_init_file_write(&wav, "data/recording.wav", &format, NULL);
+
+    ...
+
+    drwav_uint64 framesWritten = drwav_write_pcm_frames(pWav, frameCount, pSamples);
+    ```
+
+dr_wav has seamless support for the Sony Wave64 format. The decoder will automatically detect it and it should Just Work without any manual intervention.
+
+
+Build Options
+=============
+#define these options before including this file.
+
+#define DR_WAV_NO_CONVERSION_API
+  Disables conversion APIs such as `drwav_read_pcm_frames_f32()` and `drwav_s16_to_f32()`.
+
+#define DR_WAV_NO_STDIO
+  Disables APIs that initialize a decoder from a file such as `drwav_init_file()`, `drwav_init_file_write()`, etc.
+
+
+
+Notes
+=====
+- Samples are always interleaved.
+- The default read function does not do any data conversion. Use `drwav_read_pcm_frames_f32()`, `drwav_read_pcm_frames_s32()` and `drwav_read_pcm_frames_s16()`
+  to read and convert audio data to 32-bit floating point, signed 32-bit integer and signed 16-bit integer samples respectively. Tested and supported internal
+  formats include the following:
+  - Unsigned 8-bit PCM
+  - Signed 12-bit PCM
+  - Signed 16-bit PCM
+  - Signed 24-bit PCM
+  - Signed 32-bit PCM
+  - IEEE 32-bit floating point
+  - IEEE 64-bit floating point
+  - A-law and u-law
+  - Microsoft ADPCM
+  - IMA ADPCM (DVI, format code 0x11)
+- dr_wav will try to read the WAV file as best it can, even if it's not strictly conformant to the WAV format.
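One further pattern worth illustrating (an editorial addition, not upstream dr_wav documentation): audio can also be pulled in fixed-size blocks instead of all at once, which keeps memory bounded for long recordings. The block size of 4096 frames is arbitrary:

```cpp
// Sketch: stream a WAV file in blocks of 4096 PCM frames, converted to f32.
// Requires <stdlib.h> for malloc()/free().
drwav wav;
if (!drwav_init_file(&wav, "my_song.wav", NULL)) {
    // Error opening WAV file.
}

float* block = (float*)malloc((size_t)(4096 * wav.channels) * sizeof(float));
drwav_uint64 framesRead;
while ((framesRead = drwav_read_pcm_frames_f32(&wav, 4096, block)) > 0) {
    // ... process framesRead * wav.channels interleaved samples in `block` ...
}

free(block);
drwav_uninit(&wav);
```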
+*/ + +#ifndef dr_wav_h +#define dr_wav_h + +#ifdef __cplusplus +extern "C" { +#endif + +#define DRWAV_STRINGIFY(x) #x +#define DRWAV_XSTRINGIFY(x) DRWAV_STRINGIFY(x) + +#define DRWAV_VERSION_MAJOR 0 +#define DRWAV_VERSION_MINOR 12 +#define DRWAV_VERSION_REVISION 16 +#define DRWAV_VERSION_STRING DRWAV_XSTRINGIFY(DRWAV_VERSION_MAJOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_MINOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_REVISION) + +#include /* For size_t. */ + +/* Sized types. */ +typedef signed char drwav_int8; +typedef unsigned char drwav_uint8; +typedef signed short drwav_int16; +typedef unsigned short drwav_uint16; +typedef signed int drwav_int32; +typedef unsigned int drwav_uint32; +#if defined(_MSC_VER) + typedef signed __int64 drwav_int64; + typedef unsigned __int64 drwav_uint64; +#else + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wlong-long" + #if defined(__clang__) + #pragma GCC diagnostic ignored "-Wc++11-long-long" + #endif + #endif + typedef signed long long drwav_int64; + typedef unsigned long long drwav_uint64; + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic pop + #endif +#endif +#if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__)) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + typedef drwav_uint64 drwav_uintptr; +#else + typedef drwav_uint32 drwav_uintptr; +#endif +typedef drwav_uint8 drwav_bool8; +typedef drwav_uint32 drwav_bool32; +#define DRWAV_TRUE 1 +#define DRWAV_FALSE 0 + +#if !defined(DRWAV_API) + #if defined(DRWAV_DLL) + #if defined(_WIN32) + #define DRWAV_DLL_IMPORT __declspec(dllimport) + #define DRWAV_DLL_EXPORT __declspec(dllexport) + #define DRWAV_DLL_PRIVATE static + #else + #if defined(__GNUC__) && __GNUC__ >= 4 + #define DRWAV_DLL_IMPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_EXPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_PRIVATE __attribute__((visibility("hidden"))) + #else + #define DRWAV_DLL_IMPORT + #define DRWAV_DLL_EXPORT + #define DRWAV_DLL_PRIVATE static + #endif + #endif + + #if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) + #define DRWAV_API DRWAV_DLL_EXPORT + #else + #define DRWAV_API DRWAV_DLL_IMPORT + #endif + #define DRWAV_PRIVATE DRWAV_DLL_PRIVATE + #else + #define DRWAV_API extern + #define DRWAV_PRIVATE static + #endif +#endif + +typedef drwav_int32 drwav_result; +#define DRWAV_SUCCESS 0 +#define DRWAV_ERROR -1 /* A generic error. 
*/ +#define DRWAV_INVALID_ARGS -2 +#define DRWAV_INVALID_OPERATION -3 +#define DRWAV_OUT_OF_MEMORY -4 +#define DRWAV_OUT_OF_RANGE -5 +#define DRWAV_ACCESS_DENIED -6 +#define DRWAV_DOES_NOT_EXIST -7 +#define DRWAV_ALREADY_EXISTS -8 +#define DRWAV_TOO_MANY_OPEN_FILES -9 +#define DRWAV_INVALID_FILE -10 +#define DRWAV_TOO_BIG -11 +#define DRWAV_PATH_TOO_LONG -12 +#define DRWAV_NAME_TOO_LONG -13 +#define DRWAV_NOT_DIRECTORY -14 +#define DRWAV_IS_DIRECTORY -15 +#define DRWAV_DIRECTORY_NOT_EMPTY -16 +#define DRWAV_END_OF_FILE -17 +#define DRWAV_NO_SPACE -18 +#define DRWAV_BUSY -19 +#define DRWAV_IO_ERROR -20 +#define DRWAV_INTERRUPT -21 +#define DRWAV_UNAVAILABLE -22 +#define DRWAV_ALREADY_IN_USE -23 +#define DRWAV_BAD_ADDRESS -24 +#define DRWAV_BAD_SEEK -25 +#define DRWAV_BAD_PIPE -26 +#define DRWAV_DEADLOCK -27 +#define DRWAV_TOO_MANY_LINKS -28 +#define DRWAV_NOT_IMPLEMENTED -29 +#define DRWAV_NO_MESSAGE -30 +#define DRWAV_BAD_MESSAGE -31 +#define DRWAV_NO_DATA_AVAILABLE -32 +#define DRWAV_INVALID_DATA -33 +#define DRWAV_TIMEOUT -34 +#define DRWAV_NO_NETWORK -35 +#define DRWAV_NOT_UNIQUE -36 +#define DRWAV_NOT_SOCKET -37 +#define DRWAV_NO_ADDRESS -38 +#define DRWAV_BAD_PROTOCOL -39 +#define DRWAV_PROTOCOL_UNAVAILABLE -40 +#define DRWAV_PROTOCOL_NOT_SUPPORTED -41 +#define DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED -42 +#define DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED -43 +#define DRWAV_SOCKET_NOT_SUPPORTED -44 +#define DRWAV_CONNECTION_RESET -45 +#define DRWAV_ALREADY_CONNECTED -46 +#define DRWAV_NOT_CONNECTED -47 +#define DRWAV_CONNECTION_REFUSED -48 +#define DRWAV_NO_HOST -49 +#define DRWAV_IN_PROGRESS -50 +#define DRWAV_CANCELLED -51 +#define DRWAV_MEMORY_ALREADY_MAPPED -52 +#define DRWAV_AT_END -53 + +/* Common data formats. */ +#define DR_WAVE_FORMAT_PCM 0x1 +#define DR_WAVE_FORMAT_ADPCM 0x2 +#define DR_WAVE_FORMAT_IEEE_FLOAT 0x3 +#define DR_WAVE_FORMAT_ALAW 0x6 +#define DR_WAVE_FORMAT_MULAW 0x7 +#define DR_WAVE_FORMAT_DVI_ADPCM 0x11 +#define DR_WAVE_FORMAT_EXTENSIBLE 0xFFFE + +/* Constants. */ +#ifndef DRWAV_MAX_SMPL_LOOPS +#define DRWAV_MAX_SMPL_LOOPS 1 +#endif + +/* Flags to pass into drwav_init_ex(), etc. */ +#define DRWAV_SEQUENTIAL 0x00000001 + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision); +DRWAV_API const char* drwav_version_string(void); + +typedef enum +{ + drwav_seek_origin_start, + drwav_seek_origin_current +} drwav_seek_origin; + +typedef enum +{ + drwav_container_riff, + drwav_container_w64, + drwav_container_rf64 +} drwav_container; + +typedef struct +{ + union + { + drwav_uint8 fourcc[4]; + drwav_uint8 guid[16]; + } id; + + /* The size in bytes of the chunk. */ + drwav_uint64 sizeInBytes; + + /* + RIFF = 2 byte alignment. + W64 = 8 byte alignment. + */ + unsigned int paddingSize; +} drwav_chunk_header; + +typedef struct +{ + /* + The format tag exactly as specified in the wave file's "fmt" chunk. This can be used by applications + that require support for data formats not natively supported by dr_wav. + */ + drwav_uint16 formatTag; + + /* The number of channels making up the audio data. When this is set to 1 it is mono, 2 is stereo, etc. */ + drwav_uint16 channels; + + /* The sample rate. Usually set to something like 44100. */ + drwav_uint32 sampleRate; + + /* Average bytes per second. You probably don't need this, but it's left here for informational purposes. */ + drwav_uint32 avgBytesPerSec; + + /* Block align. This is equal to the number of channels * bytes per sample. */ + drwav_uint16 blockAlign; + + /* Bits per sample. 
*/
+    drwav_uint16 bitsPerSample;
+
+    /* The size of the extended data. Only used internally for validation, but left here for informational purposes. */
+    drwav_uint16 extendedSize;
+
+    /*
+    The number of valid bits per sample. When <formatTag> is equal to WAVE_FORMAT_EXTENSIBLE, <bitsPerSample>
+    is always rounded up to the nearest multiple of 8. This variable contains information about exactly how
+    many bits are valid per sample. Mainly used for informational purposes.
+    */
+    drwav_uint16 validBitsPerSample;
+
+    /* The channel mask. Not used at the moment. */
+    drwav_uint32 channelMask;
+
+    /* The sub-format, exactly as specified by the wave file. */
+    drwav_uint8 subFormat[16];
+} drwav_fmt;
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT);
+
+
+/*
+Callback for when data is read. Return value is the number of bytes actually read.
+
+pUserData   [in]  The user data that was passed to drwav_init() and family.
+pBufferOut  [out] The output buffer.
+bytesToRead [in]  The number of bytes to read.
+
+Returns the number of bytes actually read.
+
+A return value of less than bytesToRead indicates the end of the stream. Do _not_ return from this callback until
+either the entire bytesToRead is filled or you have reached the end of the stream.
+*/
+typedef size_t (* drwav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead);
+
+/*
+Callback for when data is written. Return value is the number of bytes actually written.
+
+pUserData    [in]  The user data that was passed to drwav_init_write() and family.
+pData        [out] A pointer to the data to write.
+bytesToWrite [in]  The number of bytes to write.
+
+Returns the number of bytes actually written.
+
+If the return value differs from bytesToWrite, it indicates an error.
+*/
+typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite);
+
+/*
+Callback for when data needs to be seeked.
+
+pUserData [in] The user data that was passed to drwav_init() and family.
+offset    [in] The number of bytes to move, relative to the origin. Will never be negative.
+origin    [in] The origin of the seek - the current position or the start of the stream.
+
+Returns whether or not the seek was successful.
+
+Whether or not it is relative to the beginning or current position is determined by the "origin" parameter which will be either drwav_seek_origin_start or
+drwav_seek_origin_current.
+*/
+typedef drwav_bool32 (* drwav_seek_proc)(void* pUserData, int offset, drwav_seek_origin origin);
+
+/*
+Callback for when drwav_init_ex() finds a chunk.
+
+pChunkUserData    [in] The user data that was passed to the pChunkUserData parameter of drwav_init_ex() and family.
+onRead            [in] A pointer to the function to call when reading.
+onSeek            [in] A pointer to the function to call when seeking.
+pReadSeekUserData [in] The user data that was passed to the pReadSeekUserData parameter of drwav_init_ex() and family.
+pChunkHeader      [in] A pointer to an object containing basic header information about the chunk. Use this to identify the chunk.
+container         [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF.
+pFMT              [in] A pointer to the object containing the contents of the "fmt" chunk.
+
+Returns the number of bytes read + seeked.
+
+To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must
+be the total number of bytes you have read _plus_ seeked.
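As a sketch of this contract (the function name is arbitrary, and any chunk padding is handled by dr_wav itself), a callback that logs RIFF chunk IDs and skips each payload could look like this:

```cpp
// Sketch: a drwav_chunk_proc that skips every chunk it is shown.
// It must return the total number of bytes it read plus seeked. Requires <stdio.h>.
static drwav_uint64 my_on_chunk(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek,
                                void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader,
                                drwav_container container, const drwav_fmt* pFMT)
{
    (void)pChunkUserData; (void)onRead; (void)pFMT;

    if (container == drwav_container_riff || container == drwav_container_rf64) {
        printf("chunk %c%c%c%c, %u bytes\n",
               pChunkHeader->id.fourcc[0], pChunkHeader->id.fourcc[1],
               pChunkHeader->id.fourcc[2], pChunkHeader->id.fourcc[3],
               (unsigned int)pChunkHeader->sizeInBytes);
    }

    // Seek past the payload; the bytes seeked count towards the return value.
    if (!onSeek(pReadSeekUserData, (int)pChunkHeader->sizeInBytes, drwav_seek_origin_current)) {
        return 0;
    }
    return pChunkHeader->sizeInBytes;
}
```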
+ +Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should +use `id.fourcc`, otherwise you should use `id.guid`. + +The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the +`DR_WAVE_FORMAT_*` identifiers. + +The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk. +*/ +typedef drwav_uint64 (* drwav_chunk_proc)(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader, drwav_container container, const drwav_fmt* pFMT); + +typedef struct +{ + void* pUserData; + void* (* onMalloc)(size_t sz, void* pUserData); + void* (* onRealloc)(void* p, size_t sz, void* pUserData); + void (* onFree)(void* p, void* pUserData); +} drwav_allocation_callbacks; + +/* Structure for internal use. Only used for loaders opened with drwav_init_memory(). */ +typedef struct +{ + const drwav_uint8* data; + size_t dataSize; + size_t currentReadPos; +} drwav__memory_stream; + +/* Structure for internal use. Only used for writers opened with drwav_init_memory_write(). */ +typedef struct +{ + void** ppData; + size_t* pDataSize; + size_t dataSize; + size_t dataCapacity; + size_t currentWritePos; +} drwav__memory_stream_write; + +typedef struct +{ + drwav_container container; /* RIFF, W64. */ + drwav_uint32 format; /* DR_WAVE_FORMAT_* */ + drwav_uint32 channels; + drwav_uint32 sampleRate; + drwav_uint32 bitsPerSample; +} drwav_data_format; + + +/* See the following for details on the 'smpl' chunk: https://sites.google.com/site/musicgapi/technical-documents/wav-file-format#smpl */ +typedef struct +{ + drwav_uint32 cuePointId; + drwav_uint32 type; + drwav_uint32 start; + drwav_uint32 end; + drwav_uint32 fraction; + drwav_uint32 playCount; +} drwav_smpl_loop; + + typedef struct +{ + drwav_uint32 manufacturer; + drwav_uint32 product; + drwav_uint32 samplePeriod; + drwav_uint32 midiUnityNotes; + drwav_uint32 midiPitchFraction; + drwav_uint32 smpteFormat; + drwav_uint32 smpteOffset; + drwav_uint32 numSampleLoops; + drwav_uint32 samplerData; + drwav_smpl_loop loops[DRWAV_MAX_SMPL_LOOPS]; +} drwav_smpl; + +typedef struct +{ + /* A pointer to the function to call when more data is needed. */ + drwav_read_proc onRead; + + /* A pointer to the function to call when data needs to be written. Only used when the drwav object is opened in write mode. */ + drwav_write_proc onWrite; + + /* A pointer to the function to call when the wav file needs to be seeked. */ + drwav_seek_proc onSeek; + + /* The user data to pass to callbacks. */ + void* pUserData; + + /* Allocation callbacks. */ + drwav_allocation_callbacks allocationCallbacks; + + + /* Whether or not the WAV file is formatted as a standard RIFF file or W64. */ + drwav_container container; + + + /* Structure containing format information exactly as specified by the wav file. */ + drwav_fmt fmt; + + /* The sample rate. Will be set to something like 44100. */ + drwav_uint32 sampleRate; + + /* The number of channels. This will be set to 1 for monaural streams, 2 for stereo, etc. */ + drwav_uint16 channels; + + /* The bits per sample. Will be set to something like 16, 24, etc. 
*/ + drwav_uint16 bitsPerSample; + + /* Equal to fmt.formatTag, or the value specified by fmt.subFormat if fmt.formatTag is equal to 65534 (WAVE_FORMAT_EXTENSIBLE). */ + drwav_uint16 translatedFormatTag; + + /* The total number of PCM frames making up the audio data. */ + drwav_uint64 totalPCMFrameCount; + + + /* The size in bytes of the data chunk. */ + drwav_uint64 dataChunkDataSize; + + /* The position in the stream of the first byte of the data chunk. This is used for seeking. */ + drwav_uint64 dataChunkDataPos; + + /* The number of bytes remaining in the data chunk. */ + drwav_uint64 bytesRemaining; + + + /* + Only used in sequential write mode. Keeps track of the desired size of the "data" chunk at the point of initialization time. Always + set to 0 for non-sequential writes and when the drwav object is opened in read mode. Used for validation. + */ + drwav_uint64 dataChunkDataSizeTargetWrite; + + /* Keeps track of whether or not the wav writer was initialized in sequential mode. */ + drwav_bool32 isSequentialWrite; + + + /* smpl chunk. */ + drwav_smpl smpl; + + + /* A hack to avoid a DRWAV_MALLOC() when opening a decoder with drwav_init_memory(). */ + drwav__memory_stream memoryStream; + drwav__memory_stream_write memoryStreamWrite; + + /* Generic data for compressed formats. This data is shared across all block-compressed formats. */ + struct + { + drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */ + } compressed; + + /* Microsoft ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_uint16 predictor[2]; + drwav_int32 delta[2]; + drwav_int32 cachedFrames[4]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + drwav_int32 prevFrames[2][2]; /* The previous 2 samples for each channel (2 channels at most). */ + } msadpcm; + + /* IMA ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_int32 predictor[2]; + drwav_int32 stepIndex[2]; + drwav_int32 cachedFrames[16]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + } ima; +} drwav; + + +/* +Initializes a pre-allocated drwav object for reading. + +pWav [out] A pointer to the drwav object being initialized. +onRead [in] The function to call when data needs to be read from the client. +onSeek [in] The function to call when the read position of the client data needs to move. +onChunk [in, optional] The function to call when a chunk is enumerated at initialized time. +pUserData, pReadSeekUserData [in, optional] A pointer to application defined data that will be passed to onRead and onSeek. +pChunkUserData [in, optional] A pointer to application defined data that will be passed to onChunk. +flags [in, optional] A set of flags for controlling how things are loaded. + +Returns true if successful; false otherwise. + +Close the loader with drwav_uninit(). + +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file() and drwav_init_memory() +to open the stream from a file or from a block of memory respectively. + +Possible values for flags: + DRWAV_SEQUENTIAL: Never perform a backwards seek while loading. This disables the chunk callback and will cause this function + to return as soon as the data chunk is found. Any chunks after the data chunk will be ignored. 
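To make the callback form concrete, here is a minimal sketch of driving drwav_init() from a plain FILE* (which is essentially what drwav_init_file() does for you; the function names are illustrative):

```cpp
// Sketch: stdio-backed read/seek callbacks for drwav_init(). Requires <stdio.h>.
static size_t my_read(void* pUserData, void* pBufferOut, size_t bytesToRead) {
    return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData);
}

static drwav_bool32 my_seek(void* pUserData, int offset, drwav_seek_origin origin) {
    int whence = (origin == drwav_seek_origin_start) ? SEEK_SET : SEEK_CUR;
    return fseek((FILE*)pUserData, offset, whence) == 0 ? DRWAV_TRUE : DRWAV_FALSE;
}

// Usage:
//   FILE* f = fopen("my_song.wav", "rb");
//   drwav wav;
//   if (f != NULL && drwav_init(&wav, my_read, my_seek, f, NULL)) { /* ... */ }
```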
+ +drwav_init() is equivalent to "drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0);". + +The onChunk callback is not called for the WAVE or FMT chunks. The contents of the FMT chunk can be read from pWav->fmt +after the function returns. + +See also: drwav_init_file(), drwav_init_memory(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Initializes a pre-allocated drwav object for writing. + +onWrite [in] The function to call when data needs to be written. +onSeek [in] The function to call when the write position needs to move. +pUserData [in, optional] A pointer to application defined data that will be passed to onWrite and onSeek. + +Returns true if successful; false otherwise. + +Close the writer with drwav_uninit(). + +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file_write() and drwav_init_memory_write() +to open the stream from a file or from a block of memory respectively. + +If the total sample count is known, you can use drwav_init_write_sequential(). This avoids the need for dr_wav to perform +a post-processing step for storing the total sample count and the size of the data chunk which requires a backwards seek. + +See also: drwav_init_file_write(), drwav_init_memory_write(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Utility function to determine the target size of the entire data to be written (including all headers and chunks). + +Returns the target size in bytes. + +Useful if the application needs to know the size to allocate. + +Only writing to the RIFF chunk and one data chunk is currently supported. + +See also: drwav_init_write(), drwav_init_file_write(), drwav_init_memory_write() +*/ +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +/* +Uninitializes the given drwav object. + +Use this only for objects initialized with drwav_init*() functions (drwav_init(), drwav_init_ex(), drwav_init_write(), drwav_init_write_sequential()). +*/ +DRWAV_API drwav_result drwav_uninit(drwav* pWav); + + +/* +Reads raw audio data. + +This is the lowest level function for reading audio data. It simply reads the given number of +bytes of the raw internal sample data. + +Consider using drwav_read_pcm_frames_s16(), drwav_read_pcm_frames_s32() or drwav_read_pcm_frames_f32() for +reading sample data in a consistent format. + +pBufferOut can be NULL in which case a seek will be performed. 
+
+Returns the number of bytes actually read.
+*/
+DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut);
+
+/*
+Reads up to the specified number of PCM frames from the WAV file.
+
+The output data will be in the file's internal format, converted to native-endian byte order. Use
+drwav_read_pcm_frames_s16/f32/s32() to read data in a specific format.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached or
+you have requested more PCM frames than can possibly fit in the output buffer.
+
+This function will only work when sample data is of a fixed size and uncompressed. If you are
+using a compressed format consider using drwav_read_raw() or drwav_read_pcm_frames_s16/s32/f32().
+
+pBufferOut can be NULL in which case a seek will be performed.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+
+/*
+Seeks to the given PCM frame.
+
+Returns true if successful; false otherwise.
+*/
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex);
+
+
+/*
+Writes raw audio data.
+
+Returns the number of bytes actually written. If this differs from bytesToWrite, it indicates an error.
+*/
+DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData);
+
+/*
+Writes PCM frames.
+
+Returns the number of PCM frames written.
+
+Input samples need to be in native-endian byte order. On big-endian architectures the input data will be converted to
+little-endian. Use drwav_write_raw() to write raw audio data without performing any conversion.
+*/
+DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+
+
+/* Conversion Utilities */
+#ifndef DR_WAV_NO_CONVERSION_API
+
+/*
+Reads a chunk of audio data and converts it to signed 16-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 32-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to IEEE 32-bit floating point samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 16-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to signed 32-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 16-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 32-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +#endif /* DR_WAV_NO_CONVERSION_API */ + + +/* High-Level Convenience Helpers */ + +#ifndef DR_WAV_NO_STDIO +/* +Helper for initializing a wave file for reading using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. +*/ +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a wave file for writing using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. 
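To round out the write path with a complete flow (an illustrative sketch; the file name and values are arbitrary), this produces one second of 16-bit stereo silence:

```cpp
// Sketch: write one second of 16-bit stereo silence as a RIFF/WAV file.
// Requires <stdlib.h> for calloc()/free().
drwav_data_format format;
format.container     = drwav_container_riff;
format.format        = DR_WAVE_FORMAT_PCM;
format.channels      = 2;
format.sampleRate    = 44100;
format.bitsPerSample = 16;

drwav wav;
if (drwav_init_file_write(&wav, "silence.wav", &format, NULL)) {
    drwav_int16* frames = (drwav_int16*)calloc(44100 * 2, sizeof(drwav_int16)); // interleaved
    drwav_write_pcm_frames(&wav, 44100, frames); // counts PCM frames, not samples
    drwav_uninit(&wav);                          // finalizes headers (needs a backwards seek)
    free(frames);
}
```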
+*/ +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif /* DR_WAV_NO_STDIO */ + +/* +Helper for initializing a loader from a pre-allocated memory buffer. + +This does not create a copy of the data. It is up to the application to ensure the buffer remains valid for +the lifetime of the drwav object. + +The buffer should contain the contents of the entire wave file, not just the sample data. +*/ +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a writer which outputs data to a memory buffer. + +dr_wav will manage the memory allocations, however it is up to the caller to free the data with drwav_free(). + +The buffer will remain allocated even after drwav_uninit() is called. The buffer should not be considered valid +until after drwav_uninit() has been called. +*/ +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); + + +#ifndef DR_WAV_NO_CONVERSION_API +/* +Opens and reads an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. 
+*/ +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#ifndef DR_WAV_NO_STDIO +/* +Opens and decodes an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. +*/ +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif +/* +Opens and decodes an entire wav file from a block of memory in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. +*/ +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif + +/* Frees data that was allocated internally by dr_wav. 
*/ +DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* Converts bytes from a wav stream to a sized type of native endian. */ +DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data); +DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data); +DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data); +DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data); +DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data); +DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data); + +/* Compares a GUID for the purpose of checking the type of a Wave64 chunk. */ +DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]); + +/* Compares a four-character-code for the purpose of checking the type of a RIFF chunk. */ +DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b); + +#ifdef __cplusplus +} +#endif +#endif /* dr_wav_h */ + + +/************************************************************************************************************************************************************ + ************************************************************************************************************************************************************ + + IMPLEMENTATION + + ************************************************************************************************************************************************************ + ************************************************************************************************************************************************************/ +#if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) +#ifndef dr_wav_c +#define dr_wav_c + +#include +#include /* For memcpy(), memset() */ +#include /* For INT_MAX */ + +#ifndef DR_WAV_NO_STDIO +#include +#include +#endif + +/* Standard library stuff. */ +#ifndef DRWAV_ASSERT +#include +#define DRWAV_ASSERT(expression) assert(expression) +#endif +#ifndef DRWAV_MALLOC +#define DRWAV_MALLOC(sz) malloc((sz)) +#endif +#ifndef DRWAV_REALLOC +#define DRWAV_REALLOC(p, sz) realloc((p), (sz)) +#endif +#ifndef DRWAV_FREE +#define DRWAV_FREE(p) free((p)) +#endif +#ifndef DRWAV_COPY_MEMORY +#define DRWAV_COPY_MEMORY(dst, src, sz) memcpy((dst), (src), (sz)) +#endif +#ifndef DRWAV_ZERO_MEMORY +#define DRWAV_ZERO_MEMORY(p, sz) memset((p), 0, (sz)) +#endif +#ifndef DRWAV_ZERO_OBJECT +#define DRWAV_ZERO_OBJECT(p) DRWAV_ZERO_MEMORY((p), sizeof(*p)) +#endif + +#define drwav_countof(x) (sizeof(x) / sizeof(x[0])) +#define drwav_align(x, a) ((((x) + (a) - 1) / (a)) * (a)) +#define drwav_min(a, b) (((a) < (b)) ? (a) : (b)) +#define drwav_max(a, b) (((a) > (b)) ? (a) : (b)) +#define drwav_clamp(x, lo, hi) (drwav_max((lo), drwav_min((hi), (x)))) + +#define DRWAV_MAX_SIMD_VECTOR_SIZE 64 /* 64 for AVX-512 in the future. */ + +/* CPU architecture. */ +#if defined(__x86_64__) || defined(_M_X64) + #define DRWAV_X64 +#elif defined(__i386) || defined(_M_IX86) + #define DRWAV_X86 +#elif defined(__arm__) || defined(_M_ARM) + #define DRWAV_ARM +#endif + +#ifdef _MSC_VER + #define DRWAV_INLINE __forceinline +#elif defined(__GNUC__) + /* + I've had a bug report where GCC is emitting warnings about functions possibly not being inlineable. This warning happens when + the __attribute__((always_inline)) attribute is defined without an "inline" statement. I think therefore there must be some + case where "__inline__" is not always defined, thus the compiler emitting these warnings. 
When using -std=c89 or -ansi on the + command line, we cannot use the "inline" keyword and instead need to use "__inline__". In an attempt to work around this issue + I am using "__inline__" only when we're compiling in strict ANSI mode. + */ + #if defined(__STRICT_ANSI__) + #define DRWAV_INLINE __inline__ __attribute__((always_inline)) + #else + #define DRWAV_INLINE inline __attribute__((always_inline)) + #endif +#elif defined(__WATCOMC__) + #define DRWAV_INLINE __inline +#else + #define DRWAV_INLINE +#endif + +#if defined(SIZE_MAX) + #define DRWAV_SIZE_MAX SIZE_MAX +#else + #if defined(_WIN64) || defined(_LP64) || defined(__LP64__) + #define DRWAV_SIZE_MAX ((drwav_uint64)0xFFFFFFFFFFFFFFFF) + #else + #define DRWAV_SIZE_MAX 0xFFFFFFFF + #endif +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC +#elif defined(__clang__) + #if defined(__has_builtin) + #if __has_builtin(__builtin_bswap16) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap32) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap64) + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #endif +#elif defined(__GNUC__) + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif +#endif + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision) +{ + if (pMajor) { + *pMajor = DRWAV_VERSION_MAJOR; + } + + if (pMinor) { + *pMinor = DRWAV_VERSION_MINOR; + } + + if (pRevision) { + *pRevision = DRWAV_VERSION_REVISION; + } +} + +DRWAV_API const char* drwav_version_string(void) +{ + return DRWAV_VERSION_STRING; +} + +/* +These limits are used for basic validation when initializing the decoder. If you exceed these limits, first of all: what on Earth are +you doing?! (Let me know, I'd be curious!) Second, you can adjust these by #define-ing them before the dr_wav implementation. 
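+
+For example, a hypothetical build that needs a higher ceiling could override a limit before
+pulling in the implementation (the value shown is illustrative only):
+
+    #define DRWAV_MAX_SAMPLE_RATE 768000
+    #define DR_WAV_IMPLEMENTATION
+    #include "dr_wav.h"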
+*/ +#ifndef DRWAV_MAX_SAMPLE_RATE +#define DRWAV_MAX_SAMPLE_RATE 384000 +#endif +#ifndef DRWAV_MAX_CHANNELS +#define DRWAV_MAX_CHANNELS 256 +#endif +#ifndef DRWAV_MAX_BITS_PER_SAMPLE +#define DRWAV_MAX_BITS_PER_SAMPLE 64 +#endif + +static const drwav_uint8 drwavGUID_W64_RIFF[16] = {0x72,0x69,0x66,0x66, 0x2E,0x91, 0xCF,0x11, 0xA5,0xD6, 0x28,0xDB,0x04,0xC1,0x00,0x00}; /* 66666972-912E-11CF-A5D6-28DB04C10000 */ +static const drwav_uint8 drwavGUID_W64_WAVE[16] = {0x77,0x61,0x76,0x65, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 65766177-ACF3-11D3-8CD1-00C04F8EDB8A */ +/*static const drwav_uint8 drwavGUID_W64_JUNK[16] = {0x6A,0x75,0x6E,0x6B, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A};*/ /* 6B6E756A-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FMT [16] = {0x66,0x6D,0x74,0x20, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 20746D66-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FACT[16] = {0x66,0x61,0x63,0x74, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 74636166-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_DATA[16] = {0x64,0x61,0x74,0x61, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 61746164-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_SMPL[16] = {0x73,0x6D,0x70,0x6C, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 6C706D73-ACF3-11D3-8CD1-00C04F8EDB8A */ + +static DRWAV_INLINE drwav_bool32 drwav__guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + int i; + for (i = 0; i < 16; i += 1) { + if (a[i] != b[i]) { + return DRWAV_FALSE; + } + } + + return DRWAV_TRUE; +} + +static DRWAV_INLINE drwav_bool32 drwav__fourcc_equal(const drwav_uint8* a, const char* b) +{ + return + a[0] == b[0] && + a[1] == b[1] && + a[2] == b[2] && + a[3] == b[3]; +} + + + +static DRWAV_INLINE int drwav__is_little_endian(void) +{ +#if defined(DRWAV_X86) || defined(DRWAV_X64) + return DRWAV_TRUE; +#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && __BYTE_ORDER == __LITTLE_ENDIAN + return DRWAV_TRUE; +#else + int n = 1; + return (*(char*)&n) == 1; +#endif +} + +static DRWAV_INLINE drwav_uint16 drwav__bytes_to_u16(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8); +} + +static DRWAV_INLINE drwav_int16 drwav__bytes_to_s16(const drwav_uint8* data) +{ + return (short)drwav__bytes_to_u16(data); +} + +static DRWAV_INLINE drwav_uint32 drwav__bytes_to_u32(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8) | (data[2] << 16) | (data[3] << 24); +} + +static DRWAV_INLINE drwav_int32 drwav__bytes_to_s32(const drwav_uint8* data) +{ + return (drwav_int32)drwav__bytes_to_u32(data); +} + +static DRWAV_INLINE drwav_uint64 drwav__bytes_to_u64(const drwav_uint8* data) +{ + return + ((drwav_uint64)data[0] << 0) | ((drwav_uint64)data[1] << 8) | ((drwav_uint64)data[2] << 16) | ((drwav_uint64)data[3] << 24) | + ((drwav_uint64)data[4] << 32) | ((drwav_uint64)data[5] << 40) | ((drwav_uint64)data[6] << 48) | ((drwav_uint64)data[7] << 56); +} + +static DRWAV_INLINE drwav_int64 drwav__bytes_to_s64(const drwav_uint8* data) +{ + return (drwav_int64)drwav__bytes_to_u64(data); +} + +static DRWAV_INLINE void drwav__bytes_to_guid(const drwav_uint8* data, drwav_uint8* guid) +{ + int i; + for (i = 0; i < 16; ++i) { + guid[i] = data[i]; + } +} + + +static DRWAV_INLINE drwav_uint16 drwav__bswap16(drwav_uint16 n) +{ +#ifdef DRWAV_HAS_BYTESWAP16_INTRINSIC + #if defined(_MSC_VER) + 
return _byteswap_ushort(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap16(n); + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + return ((n & 0xFF00) >> 8) | + ((n & 0x00FF) << 8); +#endif +} + +static DRWAV_INLINE drwav_uint32 drwav__bswap32(drwav_uint32 n) +{ +#ifdef DRWAV_HAS_BYTESWAP32_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_ulong(n); + #elif defined(__GNUC__) || defined(__clang__) + #if defined(DRWAV_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 6) && !defined(DRWAV_64BIT) /* <-- 64-bit inline assembly has not been tested, so disabling for now. */ + /* Inline assembly optimized implementation for ARM. In my testing, GCC does not generate optimized code with __builtin_bswap32(). */ + drwav_uint32 r; + __asm__ __volatile__ ( + #if defined(DRWAV_64BIT) + "rev %w[out], %w[in]" : [out]"=r"(r) : [in]"r"(n) /* <-- This is untested. If someone in the community could test this, that would be appreciated! */ + #else + "rev %[out], %[in]" : [out]"=r"(r) : [in]"r"(n) + #endif + ); + return r; + #else + return __builtin_bswap32(n); + #endif + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + return ((n & 0xFF000000) >> 24) | + ((n & 0x00FF0000) >> 8) | + ((n & 0x0000FF00) << 8) | + ((n & 0x000000FF) << 24); +#endif +} + +static DRWAV_INLINE drwav_uint64 drwav__bswap64(drwav_uint64 n) +{ +#ifdef DRWAV_HAS_BYTESWAP64_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_uint64(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap64(n); + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + /* Weird "<< 32" bitshift is required for C89 because it doesn't support 64-bit constants. Should be optimized out by a good compiler. 
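+    For instance, ((drwav_uint64)0xFF000000 << 32) assembles the mask 0xFF00000000000000 from a
+    32-bit constant at compile time, which strict C89 has no direct way to spell as a single
+    64-bit literal.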
*/ + return ((n & ((drwav_uint64)0xFF000000 << 32)) >> 56) | + ((n & ((drwav_uint64)0x00FF0000 << 32)) >> 40) | + ((n & ((drwav_uint64)0x0000FF00 << 32)) >> 24) | + ((n & ((drwav_uint64)0x000000FF << 32)) >> 8) | + ((n & ((drwav_uint64)0xFF000000 )) << 8) | + ((n & ((drwav_uint64)0x00FF0000 )) << 24) | + ((n & ((drwav_uint64)0x0000FF00 )) << 40) | + ((n & ((drwav_uint64)0x000000FF )) << 56); +#endif +} + + +static DRWAV_INLINE drwav_int16 drwav__bswap_s16(drwav_int16 n) +{ + return (drwav_int16)drwav__bswap16((drwav_uint16)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s16(drwav_int16* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s16(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_s24(drwav_uint8* p) +{ + drwav_uint8 t; + t = p[0]; + p[0] = p[2]; + p[2] = t; +} + +static DRWAV_INLINE void drwav__bswap_samples_s24(drwav_uint8* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + drwav_uint8* pSample = pSamples + (iSample*3); + drwav__bswap_s24(pSample); + } +} + + +static DRWAV_INLINE drwav_int32 drwav__bswap_s32(drwav_int32 n) +{ + return (drwav_int32)drwav__bswap32((drwav_uint32)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s32(drwav_int32* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s32(pSamples[iSample]); + } +} + + +static DRWAV_INLINE float drwav__bswap_f32(float n) +{ + union { + drwav_uint32 i; + float f; + } x; + x.f = n; + x.i = drwav__bswap32(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f32(float* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f32(pSamples[iSample]); + } +} + + +static DRWAV_INLINE double drwav__bswap_f64(double n) +{ + union { + drwav_uint64 i; + double f; + } x; + x.f = n; + x.i = drwav__bswap64(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f64(double* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f64(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_samples_pcm(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + /* Assumes integer PCM. Floating point PCM is done in drwav__bswap_samples_ieee(). */ + switch (bytesPerSample) + { + case 2: /* s16, s12 (loosely packed) */ + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + case 3: /* s24 */ + { + drwav__bswap_samples_s24((drwav_uint8*)pSamples, sampleCount); + } break; + case 4: /* s32 */ + { + drwav__bswap_samples_s32((drwav_int32*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples_ieee(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + switch (bytesPerSample) + { + #if 0 /* Contributions welcome for f16 support. 
*/ + case 2: /* f16 */ + { + drwav__bswap_samples_f16((drwav_float16*)pSamples, sampleCount); + } break; + #endif + case 4: /* f32 */ + { + drwav__bswap_samples_f32((float*)pSamples, sampleCount); + } break; + case 8: /* f64 */ + { + drwav__bswap_samples_f64((double*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample, drwav_uint16 format) +{ + switch (format) + { + case DR_WAVE_FORMAT_PCM: + { + drwav__bswap_samples_pcm(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_IEEE_FLOAT: + { + drwav__bswap_samples_ieee(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_ALAW: + case DR_WAVE_FORMAT_MULAW: + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + + case DR_WAVE_FORMAT_ADPCM: + case DR_WAVE_FORMAT_DVI_ADPCM: + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + + +static void* drwav__malloc_default(size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_MALLOC(sz); +} + +static void* drwav__realloc_default(void* p, size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_REALLOC(p, sz); +} + +static void drwav__free_default(void* p, void* pUserData) +{ + (void)pUserData; + DRWAV_FREE(p); +} + + +static void* drwav__malloc_from_callbacks(size_t sz, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onMalloc != NULL) { + return pAllocationCallbacks->onMalloc(sz, pAllocationCallbacks->pUserData); + } + + /* Try using realloc(). */ + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(NULL, sz, pAllocationCallbacks->pUserData); + } + + return NULL; +} + +static void* drwav__realloc_from_callbacks(void* p, size_t szNew, size_t szOld, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(p, szNew, pAllocationCallbacks->pUserData); + } + + /* Try emulating realloc() in terms of malloc()/free(). */ + if (pAllocationCallbacks->onMalloc != NULL && pAllocationCallbacks->onFree != NULL) { + void* p2; + + p2 = pAllocationCallbacks->onMalloc(szNew, pAllocationCallbacks->pUserData); + if (p2 == NULL) { + return NULL; + } + + if (p != NULL) { + DRWAV_COPY_MEMORY(p2, p, szOld); + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } + + return p2; + } + + return NULL; +} + +static void drwav__free_from_callbacks(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (p == NULL || pAllocationCallbacks == NULL) { + return; + } + + if (pAllocationCallbacks->onFree != NULL) { + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } +} + + +static drwav_allocation_callbacks drwav_copy_allocation_callbacks_or_defaults(const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + /* Copy. */ + return *pAllocationCallbacks; + } else { + /* Defaults. 
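+        (Inferred from drwav_preinit() further below: when the caller does supply its own
+        callbacks, they must include onFree and at least one of onMalloc/onRealloc, otherwise
+        initialization is rejected.)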
*/ + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = NULL; + allocationCallbacks.onMalloc = drwav__malloc_default; + allocationCallbacks.onRealloc = drwav__realloc_default; + allocationCallbacks.onFree = drwav__free_default; + return allocationCallbacks; + } +} + + +static DRWAV_INLINE drwav_bool32 drwav__is_compressed_format_tag(drwav_uint16 formatTag) +{ + return + formatTag == DR_WAVE_FORMAT_ADPCM || + formatTag == DR_WAVE_FORMAT_DVI_ADPCM; +} + +static unsigned int drwav__chunk_padding_size_riff(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 2); +} + +static unsigned int drwav__chunk_padding_size_w64(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 8); +} + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +static drwav_result drwav__read_chunk_header(drwav_read_proc onRead, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_chunk_header* pHeaderOut) +{ + if (container == drwav_container_riff || container == drwav_container_rf64) { + drwav_uint8 sizeInBytes[4]; + + if (onRead(pUserData, pHeaderOut->id.fourcc, 4) != 4) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 4) != 4) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u32(sizeInBytes); + pHeaderOut->paddingSize = drwav__chunk_padding_size_riff(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 8; + } else { + drwav_uint8 sizeInBytes[8]; + + if (onRead(pUserData, pHeaderOut->id.guid, 16) != 16) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 8) != 8) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u64(sizeInBytes) - 24; /* <-- Subtract 24 because w64 includes the size of the header. */ + pHeaderOut->paddingSize = drwav__chunk_padding_size_w64(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 24; + } + + return DRWAV_SUCCESS; +} + +static drwav_bool32 drwav__seek_forward(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + drwav_uint64 bytesRemainingToSeek = offset; + while (bytesRemainingToSeek > 0) { + if (bytesRemainingToSeek > 0x7FFFFFFF) { + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek -= 0x7FFFFFFF; + } else { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek = 0; + } + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav__seek_from_start(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_start); + } + + /* Larger than 32-bit seek. */ + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + + for (;;) { + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_current); + } + + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + } + + /* Should never get here. 
*/ + /*return DRWAV_TRUE; */ +} + + +static drwav_bool32 drwav__read_fmt(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_fmt* fmtOut) +{ + drwav_chunk_header header; + drwav_uint8 fmt[16]; + + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + + /* Skip non-fmt chunks. */ + while (((container == drwav_container_riff || container == drwav_container_rf64) && !drwav__fourcc_equal(header.id.fourcc, "fmt ")) || (container == drwav_container_w64 && !drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT))) { + if (!drwav__seek_forward(onSeek, header.sizeInBytes + header.paddingSize, pUserData)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.sizeInBytes + header.paddingSize; + + /* Try the next header. */ + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + } + + + /* Validation. */ + if (container == drwav_container_riff || container == drwav_container_rf64) { + if (!drwav__fourcc_equal(header.id.fourcc, "fmt ")) { + return DRWAV_FALSE; + } + } else { + if (!drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT)) { + return DRWAV_FALSE; + } + } + + + if (onRead(pUserData, fmt, sizeof(fmt)) != sizeof(fmt)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += sizeof(fmt); + + fmtOut->formatTag = drwav__bytes_to_u16(fmt + 0); + fmtOut->channels = drwav__bytes_to_u16(fmt + 2); + fmtOut->sampleRate = drwav__bytes_to_u32(fmt + 4); + fmtOut->avgBytesPerSec = drwav__bytes_to_u32(fmt + 8); + fmtOut->blockAlign = drwav__bytes_to_u16(fmt + 12); + fmtOut->bitsPerSample = drwav__bytes_to_u16(fmt + 14); + + fmtOut->extendedSize = 0; + fmtOut->validBitsPerSample = 0; + fmtOut->channelMask = 0; + memset(fmtOut->subFormat, 0, sizeof(fmtOut->subFormat)); + + if (header.sizeInBytes > 16) { + drwav_uint8 fmt_cbSize[2]; + int bytesReadSoFar = 0; + + if (onRead(pUserData, fmt_cbSize, sizeof(fmt_cbSize)) != sizeof(fmt_cbSize)) { + return DRWAV_FALSE; /* Expecting more data. */ + } + *pRunningBytesReadOut += sizeof(fmt_cbSize); + + bytesReadSoFar = 18; + + fmtOut->extendedSize = drwav__bytes_to_u16(fmt_cbSize); + if (fmtOut->extendedSize > 0) { + /* Simple validation. */ + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + if (fmtOut->extendedSize != 22) { + return DRWAV_FALSE; + } + } + + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + drwav_uint8 fmtext[22]; + if (onRead(pUserData, fmtext, fmtOut->extendedSize) != fmtOut->extendedSize) { + return DRWAV_FALSE; /* Expecting more data. */ + } + + fmtOut->validBitsPerSample = drwav__bytes_to_u16(fmtext + 0); + fmtOut->channelMask = drwav__bytes_to_u32(fmtext + 2); + drwav__bytes_to_guid(fmtext + 6, fmtOut->subFormat); + } else { + if (!onSeek(pUserData, fmtOut->extendedSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + } + *pRunningBytesReadOut += fmtOut->extendedSize; + + bytesReadSoFar += fmtOut->extendedSize; + } + + /* Seek past any leftover bytes. For w64 the leftover will be defined based on the chunk size. 
*/
+        if (!onSeek(pUserData, (int)(header.sizeInBytes - bytesReadSoFar), drwav_seek_origin_current)) {
+            return DRWAV_FALSE;
+        }
+        *pRunningBytesReadOut += (header.sizeInBytes - bytesReadSoFar);
+    }
+
+    if (header.paddingSize > 0) {
+        if (!onSeek(pUserData, header.paddingSize, drwav_seek_origin_current)) {
+            return DRWAV_FALSE;
+        }
+        *pRunningBytesReadOut += header.paddingSize;
+    }
+
+    return DRWAV_TRUE;
+}
+
+
+static size_t drwav__on_read(drwav_read_proc onRead, void* pUserData, void* pBufferOut, size_t bytesToRead, drwav_uint64* pCursor)
+{
+    size_t bytesRead;
+
+    DRWAV_ASSERT(onRead != NULL);
+    DRWAV_ASSERT(pCursor != NULL);
+
+    bytesRead = onRead(pUserData, pBufferOut, bytesToRead);
+    *pCursor += bytesRead;
+    return bytesRead;
+}
+
+#if 0
+static drwav_bool32 drwav__on_seek(drwav_seek_proc onSeek, void* pUserData, int offset, drwav_seek_origin origin, drwav_uint64* pCursor)
+{
+    DRWAV_ASSERT(onSeek != NULL);
+    DRWAV_ASSERT(pCursor != NULL);
+
+    if (!onSeek(pUserData, offset, origin)) {
+        return DRWAV_FALSE;
+    }
+
+    if (origin == drwav_seek_origin_start) {
+        *pCursor = offset;
+    } else {
+        *pCursor += offset;
+    }
+
+    return DRWAV_TRUE;
+}
+#endif
+
+
+
+static drwav_uint32 drwav_get_bytes_per_pcm_frame(drwav* pWav)
+{
+    /*
+    The bytes per frame is a bit ambiguous. It can either be based on the bits per sample, or the block align. The way I'm doing it here
+    is that if the bits per sample is a multiple of 8, use floor(bitsPerSample*channels/8), otherwise fall back to the block align.
+    */
+    if ((pWav->bitsPerSample & 0x7) == 0) {
+        /* Bits per sample is a multiple of 8. */
+        return (pWav->bitsPerSample * pWav->fmt.channels) >> 3;
+    } else {
+        return pWav->fmt.blockAlign;
+    }
+}
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT)
+{
+    if (pFMT == NULL) {
+        return 0;
+    }
+
+    if (pFMT->formatTag != DR_WAVE_FORMAT_EXTENSIBLE) {
+        return pFMT->formatTag;
+    } else {
+        return drwav__bytes_to_u16(pFMT->subFormat); /* Only the first two bytes are required. */
+    }
+}
+
+static drwav_bool32 drwav_preinit(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+    if (pWav == NULL || onRead == NULL || onSeek == NULL) {
+        return DRWAV_FALSE;
+    }
+
+    DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav));
+    pWav->onRead = onRead;
+    pWav->onSeek = onSeek;
+    pWav->pUserData = pReadSeekUserData;
+    pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks);
+
+    if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) {
+        return DRWAV_FALSE; /* Invalid allocation callbacks. */
+    }
+
+    return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags)
+{
+    /* This function assumes drwav_preinit() has been called beforehand. */
+
+    drwav_uint64 cursor;    /* <-- Keeps track of the byte position so we can seek to specific locations. */
+    drwav_bool32 sequential;
+    drwav_uint8 riff[4];
+    drwav_fmt fmt;
+    unsigned short translatedFormatTag;
+    drwav_bool32 foundDataChunk;
+    drwav_uint64 dataChunkSize = 0; /* <-- Important! Don't explicitly set this to 0 anywhere else. Calculation of the size of the data chunk is performed in different paths depending on the container. */
+    drwav_uint64 sampleCountFromFactChunk = 0;  /* Same as dataChunkSize - make sure this is the only place this is initialized to 0.
*/ + drwav_uint64 chunkSize; + + cursor = 0; + sequential = (flags & DRWAV_SEQUENTIAL) != 0; + + /* The first 4 bytes should be the RIFF identifier. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff, sizeof(riff), &cursor) != sizeof(riff)) { + return DRWAV_FALSE; + } + + /* + The first 4 bytes can be used to identify the container. For RIFF files it will start with "RIFF" and for + w64 it will start with "riff". + */ + if (drwav__fourcc_equal(riff, "RIFF")) { + pWav->container = drwav_container_riff; + } else if (drwav__fourcc_equal(riff, "riff")) { + int i; + drwav_uint8 riff2[12]; + + pWav->container = drwav_container_w64; + + /* Check the rest of the GUID for validity. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff2, sizeof(riff2), &cursor) != sizeof(riff2)) { + return DRWAV_FALSE; + } + + for (i = 0; i < 12; ++i) { + if (riff2[i] != drwavGUID_W64_RIFF[i+4]) { + return DRWAV_FALSE; + } + } + } else if (drwav__fourcc_equal(riff, "RF64")) { + pWav->container = drwav_container_rf64; + } else { + return DRWAV_FALSE; /* Unknown or unsupported container. */ + } + + + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + drwav_uint8 chunkSizeBytes[4]; + drwav_uint8 wave[4]; + + /* RIFF/WAVE */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (pWav->container == drwav_container_riff) { + if (drwav__bytes_to_u32(chunkSizeBytes) < 36) { + return DRWAV_FALSE; /* Chunk size should always be at least 36 bytes. */ + } + } else { + if (drwav__bytes_to_u32(chunkSizeBytes) != 0xFFFFFFFF) { + return DRWAV_FALSE; /* Chunk size should always be set to -1/0xFFFFFFFF for RF64. The actual size is retrieved later. */ + } + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(wave, "WAVE")) { + return DRWAV_FALSE; /* Expecting "WAVE". */ + } + } else { + drwav_uint8 chunkSizeBytes[8]; + drwav_uint8 wave[16]; + + /* W64 */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (drwav__bytes_to_u64(chunkSizeBytes) < 80) { + return DRWAV_FALSE; + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__guid_equal(wave, drwavGUID_W64_WAVE)) { + return DRWAV_FALSE; + } + } + + + /* For RF64, the "ds64" chunk must come next, before the "fmt " chunk. */ + if (pWav->container == drwav_container_rf64) { + drwav_uint8 sizeBytes[8]; + drwav_uint64 bytesRemainingInChunk; + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(header.id.fourcc, "ds64")) { + return DRWAV_FALSE; /* Expecting "ds64". */ + } + + bytesRemainingInChunk = header.sizeInBytes + header.paddingSize; + + /* We don't care about the size of the RIFF chunk - skip it. */ + if (!drwav__seek_forward(pWav->onSeek, 8, pWav->pUserData)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + cursor += 8; + + + /* Next 8 bytes is the size of the "data" chunk. 
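+        (For reference, the ds64 body as parsed here is laid out as: 8 bytes RIFF size, which
+        was skipped just above; 8 bytes "data" chunk size; 8 bytes sample count; then an
+        optional table which is skipped below.)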
*/
+        if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) {
+            return DRWAV_FALSE;
+        }
+        bytesRemainingInChunk -= 8;
+        dataChunkSize = drwav__bytes_to_u64(sizeBytes);
+
+
+        /* Next 8 bytes is the same count which we would usually derive from the FACT chunk if it was available. */
+        if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) {
+            return DRWAV_FALSE;
+        }
+        bytesRemainingInChunk -= 8;
+        sampleCountFromFactChunk = drwav__bytes_to_u64(sizeBytes);
+
+
+        /* Skip over everything else. */
+        if (!drwav__seek_forward(pWav->onSeek, bytesRemainingInChunk, pWav->pUserData)) {
+            return DRWAV_FALSE;
+        }
+        cursor += bytesRemainingInChunk;
+    }
+
+
+    /* The next bytes should be the "fmt " chunk. */
+    if (!drwav__read_fmt(pWav->onRead, pWav->onSeek, pWav->pUserData, pWav->container, &cursor, &fmt)) {
+        return DRWAV_FALSE; /* Failed to read the "fmt " chunk. */
+    }
+
+    /* Basic validation. */
+    if ((fmt.sampleRate == 0 || fmt.sampleRate > DRWAV_MAX_SAMPLE_RATE) ||
+        (fmt.channels == 0 || fmt.channels > DRWAV_MAX_CHANNELS) ||
+        (fmt.bitsPerSample == 0 || fmt.bitsPerSample > DRWAV_MAX_BITS_PER_SAMPLE) ||
+        fmt.blockAlign == 0) {
+        return DRWAV_FALSE; /* Probably an invalid WAV file. */
+    }
+
+
+    /* Translate the internal format. */
+    translatedFormatTag = fmt.formatTag;
+    if (translatedFormatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+        translatedFormatTag = drwav__bytes_to_u16(fmt.subFormat + 0);
+    }
+
+
+    /*
+    We need to enumerate over each chunk for two reasons:
+      1) The "data" chunk may not be the next one
+      2) We may want to report each chunk back to the client
+
+    In order to correctly report each chunk back to the client we will need to keep looping until the end of the file.
+    */
+    foundDataChunk = DRWAV_FALSE;
+
+    /* The next chunk we care about is the "data" chunk. This is not necessarily the next chunk so we'll need to loop. */
+    for (;;)
+    {
+        drwav_chunk_header header;
+        drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header);
+        if (result != DRWAV_SUCCESS) {
+            if (!foundDataChunk) {
+                return DRWAV_FALSE;
+            } else {
+                break;  /* Probably at the end of the file. Get out of the loop. */
+            }
+        }
+
+        /* Tell the client about this chunk. */
+        if (!sequential && onChunk != NULL) {
+            drwav_uint64 callbackBytesRead = onChunk(pChunkUserData, pWav->onRead, pWav->onSeek, pWav->pUserData, &header, pWav->container, &fmt);
+
+            /*
+            dr_wav may need to read the contents of the chunk, so we now need to seek back to the position before
+            we called the callback.
+            */
+            if (callbackBytesRead > 0) {
+                if (!drwav__seek_from_start(pWav->onSeek, cursor, pWav->pUserData)) {
+                    return DRWAV_FALSE;
+                }
+            }
+        }
+
+
+        if (!foundDataChunk) {
+            pWav->dataChunkDataPos = cursor;
+        }
+
+        chunkSize = header.sizeInBytes;
+        if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+            if (drwav__fourcc_equal(header.id.fourcc, "data")) {
+                foundDataChunk = DRWAV_TRUE;
+                if (pWav->container != drwav_container_rf64) {  /* The data chunk size for RF64 will always be set to 0xFFFFFFFF here. It was set to its true value earlier. */
+                    dataChunkSize = chunkSize;
+                }
+            }
+        } else {
+            if (drwav__guid_equal(header.id.guid, drwavGUID_W64_DATA)) {
+                foundDataChunk = DRWAV_TRUE;
+                dataChunkSize = chunkSize;
+            }
+        }
+
+        /*
+        If at this point we have found the data chunk and we're running in sequential mode, we need to break out of this loop.
The reason for + this is that we would otherwise require a backwards seek which sequential mode forbids. + */ + if (foundDataChunk && sequential) { + break; + } + + /* Optional. Get the total sample count from the FACT chunk. This is useful for compressed formats. */ + if (pWav->container == drwav_container_riff) { + if (drwav__fourcc_equal(header.id.fourcc, "fact")) { + drwav_uint32 sampleCount; + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCount, 4, &cursor) != 4) { + return DRWAV_FALSE; + } + chunkSize -= 4; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + /* + The sample count in the "fact" chunk is either unreliable, or I'm not understanding it properly. For now I am only enabling this + for Microsoft ADPCM formats. + */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + sampleCountFromFactChunk = sampleCount; + } else { + sampleCountFromFactChunk = 0; + } + } + } else if (pWav->container == drwav_container_w64) { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_FACT)) { + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCountFromFactChunk, 8, &cursor) != 8) { + return DRWAV_FALSE; + } + chunkSize -= 8; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + } else if (pWav->container == drwav_container_rf64) { + /* We retrieved the sample count from the ds64 chunk earlier so no need to do that here. */ + } + + /* "smpl" chunk. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "smpl")) { + drwav_uint8 smplHeaderData[36]; /* 36 = size of the smpl header section, not including the loop data. */ + if (chunkSize >= sizeof(smplHeaderData)) { + drwav_uint64 bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplHeaderData, sizeof(smplHeaderData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplHeaderData)) { + drwav_uint32 iLoop; + + pWav->smpl.manufacturer = drwav__bytes_to_u32(smplHeaderData+0); + pWav->smpl.product = drwav__bytes_to_u32(smplHeaderData+4); + pWav->smpl.samplePeriod = drwav__bytes_to_u32(smplHeaderData+8); + pWav->smpl.midiUnityNotes = drwav__bytes_to_u32(smplHeaderData+12); + pWav->smpl.midiPitchFraction = drwav__bytes_to_u32(smplHeaderData+16); + pWav->smpl.smpteFormat = drwav__bytes_to_u32(smplHeaderData+20); + pWav->smpl.smpteOffset = drwav__bytes_to_u32(smplHeaderData+24); + pWav->smpl.numSampleLoops = drwav__bytes_to_u32(smplHeaderData+28); + pWav->smpl.samplerData = drwav__bytes_to_u32(smplHeaderData+32); + + for (iLoop = 0; iLoop < pWav->smpl.numSampleLoops && iLoop < drwav_countof(pWav->smpl.loops); ++iLoop) { + drwav_uint8 smplLoopData[24]; /* 24 = size of a loop section in the smpl chunk. */ + bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplLoopData, sizeof(smplLoopData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplLoopData)) { + pWav->smpl.loops[iLoop].cuePointId = drwav__bytes_to_u32(smplLoopData+0); + pWav->smpl.loops[iLoop].type = drwav__bytes_to_u32(smplLoopData+4); + pWav->smpl.loops[iLoop].start = drwav__bytes_to_u32(smplLoopData+8); + pWav->smpl.loops[iLoop].end = drwav__bytes_to_u32(smplLoopData+12); + pWav->smpl.loops[iLoop].fraction = drwav__bytes_to_u32(smplLoopData+16); + pWav->smpl.loops[iLoop].playCount = drwav__bytes_to_u32(smplLoopData+20); + } else { + break; /* Break from the smpl loop for loop. */ + } + } + } + } else { + /* Looks like invalid data. Ignore the chunk. 
*/
+                }
+            }
+        } else {
+            if (drwav__guid_equal(header.id.guid, drwavGUID_W64_SMPL)) {
+                /*
+                This path will be hit when a W64 WAV file contains a smpl chunk. I don't have a sample file to test this path, so a contribution
+                is welcome to add support for this.
+                */
+            }
+        }
+
+        /* Make sure we seek past the padding. */
+        chunkSize += header.paddingSize;
+        if (!drwav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData)) {
+            break;
+        }
+        cursor += chunkSize;
+
+        if (!foundDataChunk) {
+            pWav->dataChunkDataPos = cursor;
+        }
+    }
+
+    /* If we haven't found a data chunk, return an error. */
+    if (!foundDataChunk) {
+        return DRWAV_FALSE;
+    }
+
+    /* We may have moved past the data chunk. If so we need to move back. If running in sequential mode we can assume we are already sitting on the data chunk. */
+    if (!sequential) {
+        if (!drwav__seek_from_start(pWav->onSeek, pWav->dataChunkDataPos, pWav->pUserData)) {
+            return DRWAV_FALSE;
+        }
+        cursor = pWav->dataChunkDataPos;
+    }
+
+
+    /* At this point we should be sitting on the first byte of the raw audio data. */
+
+    pWav->fmt = fmt;
+    pWav->sampleRate = fmt.sampleRate;
+    pWav->channels = fmt.channels;
+    pWav->bitsPerSample = fmt.bitsPerSample;
+    pWav->bytesRemaining = dataChunkSize;
+    pWav->translatedFormatTag = translatedFormatTag;
+    pWav->dataChunkDataSize = dataChunkSize;
+
+    if (sampleCountFromFactChunk != 0) {
+        pWav->totalPCMFrameCount = sampleCountFromFactChunk;
+    } else {
+        pWav->totalPCMFrameCount = dataChunkSize / drwav_get_bytes_per_pcm_frame(pWav);
+
+        if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+            drwav_uint64 totalBlockHeaderSizeInBytes;
+            drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+
+            /* Make sure any trailing partial block is accounted for. */
+            if ((blockCount * fmt.blockAlign) < dataChunkSize) {
+                blockCount += 1;
+            }
+
+            /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */
+            totalBlockHeaderSizeInBytes = blockCount * (6*fmt.channels);
+            pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels;
+        }
+        if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+            drwav_uint64 totalBlockHeaderSizeInBytes;
+            drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+
+            /* Make sure any trailing partial block is accounted for. */
+            if ((blockCount * fmt.blockAlign) < dataChunkSize) {
+                blockCount += 1;
+            }
+
+            /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */
+            totalBlockHeaderSizeInBytes = blockCount * (4*fmt.channels);
+            pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels;
+
+            /* The header includes a decoded sample for each channel which acts as the initial predictor sample. */
+            pWav->totalPCMFrameCount += blockCount;
+        }
+    }
+
+    /* Some formats only support a certain number of channels. */
+    if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM || pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+        if (pWav->channels > 2) {
+            return DRWAV_FALSE;
+        }
+    }
+
+#ifdef DR_WAV_LIBSNDFILE_COMPAT
+    /*
+    I use libsndfile as a benchmark for testing, however in the version I'm using (from the Windows installer on the libsndfile website),
+    it appears the total sample count libsndfile uses for MS-ADPCM is incorrect.
It would seem they are computing the total sample count + from the number of blocks, however this results in the inclusion of extra silent samples at the end of the last block. The correct + way to know the total sample count is to inspect the "fact" chunk, which should always be present for compressed formats, and should + always include the sample count. This little block of code below is only used to emulate the libsndfile logic so I can properly run my + correctness tests against libsndfile, and is disabled by default. + */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (6*pWav->channels))) * 2)) / fmt.channels; /* x2 because two samples per byte. */ + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (4*pWav->channels))) * 2) + (blockCount * pWav->channels)) / fmt.channels; + } +#endif + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_uint32 drwav__riff_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return (drwav_uint32)chunkSize; /* Safe cast due to the clamp above. */ +} + +static drwav_uint32 drwav__data_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + if (dataChunkSize <= 0xFFFFFFFFUL) { + return (drwav_uint32)dataChunkSize; + } else { + return 0xFFFFFFFFUL; + } +} + +static drwav_uint64 drwav__riff_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 dataSubchunkPaddingSize = drwav__chunk_padding_size_w64(dataChunkSize); + + return 80 + 24 + dataChunkSize + dataSubchunkPaddingSize; /* +24 because W64 includes the size of the GUID and size fields. */ +} + +static drwav_uint64 drwav__data_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + return 24 + dataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ +} + +static drwav_uint64 drwav__riff_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 36 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 36 = "ds64" chunk. 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return chunkSize; +} + +static drwav_uint64 drwav__data_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + return dataChunkSize; +} + + +static size_t drwav__write(drwav* pWav, const void* pData, size_t dataSize) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + /* Generic write. Assumes no byte reordering required. 
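+    (Byte reordering, where needed, is handled by the drwav__write_u16ne_to_le(),
+    drwav__write_u32ne_to_le() and drwav__write_u64ne_to_le() wrappers defined below, which
+    swap on big-endian hosts before calling this.)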
*/ + return pWav->onWrite(pWav->pUserData, pData, dataSize); +} + +static size_t drwav__write_u16ne_to_le(drwav* pWav, drwav_uint16 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap16(value); + } + + return drwav__write(pWav, &value, 2); +} + +static size_t drwav__write_u32ne_to_le(drwav* pWav, drwav_uint32 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap32(value); + } + + return drwav__write(pWav, &value, 4); +} + +static size_t drwav__write_u64ne_to_le(drwav* pWav, drwav_uint64 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap64(value); + } + + return drwav__write(pWav, &value, 8); +} + + +static drwav_bool32 drwav_preinit_write(drwav* pWav, const drwav_data_format* pFormat, drwav_bool32 isSequential, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onWrite == NULL) { + return DRWAV_FALSE; + } + + if (!isSequential && onSeek == NULL) { + return DRWAV_FALSE; /* <-- onSeek is required when in non-sequential mode. */ + } + + /* Not currently supporting compressed formats. Will need to add support for the "fact" chunk before we enable this. */ + if (pFormat->format == DR_WAVE_FORMAT_EXTENSIBLE) { + return DRWAV_FALSE; + } + if (pFormat->format == DR_WAVE_FORMAT_ADPCM || pFormat->format == DR_WAVE_FORMAT_DVI_ADPCM) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onWrite = onWrite; + pWav->onSeek = onSeek; + pWav->pUserData = pUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + pWav->fmt.formatTag = (drwav_uint16)pFormat->format; + pWav->fmt.channels = (drwav_uint16)pFormat->channels; + pWav->fmt.sampleRate = pFormat->sampleRate; + pWav->fmt.avgBytesPerSec = (drwav_uint32)((pFormat->bitsPerSample * pFormat->sampleRate * pFormat->channels) / 8); + pWav->fmt.blockAlign = (drwav_uint16)((pFormat->channels * pFormat->bitsPerSample) / 8); + pWav->fmt.bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->fmt.extendedSize = 0; + pWav->isSequentialWrite = isSequential; + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* The function assumes drwav_preinit_write() was called beforehand. */ + + size_t runningPos = 0; + drwav_uint64 initialDataChunkSize = 0; + drwav_uint64 chunkSizeFMT; + + /* + The initial values for the "RIFF" and "data" chunks depends on whether or not we are initializing in sequential mode or not. In + sequential mode we set this to its final values straight away since they can be calculated from the total sample count. In non- + sequential mode we initialize it all to zero and fill it out in drwav_uninit() using a backwards seek. + */ + if (pWav->isSequentialWrite) { + initialDataChunkSize = (totalSampleCount * pWav->fmt.bitsPerSample) / 8; + + /* + The RIFF container has a limit on the number of samples. drwav is not allowing this. 
There are no practical limits for Wave64
+        so for the sake of simplicity I'm not doing any validation for that.
+        */
+        if (pFormat->container == drwav_container_riff) {
+            if (initialDataChunkSize > (0xFFFFFFFFUL - 36)) {
+                return DRWAV_FALSE; /* Not enough room to store every sample. */
+            }
+        }
+    }
+
+    pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize;
+
+
+    /* "RIFF" chunk. */
+    if (pFormat->container == drwav_container_riff) {
+        drwav_uint32 chunkSizeRIFF = 28 + (drwav_uint32)initialDataChunkSize; /* +28 = "WAVE" + [sizeof "fmt " chunk] */
+        runningPos += drwav__write(pWav, "RIFF", 4);
+        runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeRIFF);
+        runningPos += drwav__write(pWav, "WAVE", 4);
+    } else if (pFormat->container == drwav_container_w64) {
+        drwav_uint64 chunkSizeRIFF = 80 + 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+        runningPos += drwav__write(pWav, drwavGUID_W64_RIFF, 16);
+        runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeRIFF);
+        runningPos += drwav__write(pWav, drwavGUID_W64_WAVE, 16);
+    } else if (pFormat->container == drwav_container_rf64) {
+        runningPos += drwav__write(pWav, "RF64", 4);
+        runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always 0xFFFFFFFF for RF64. Set to a proper value in the "ds64" chunk. */
+        runningPos += drwav__write(pWav, "WAVE", 4);
+    }
+
+
+    /* "ds64" chunk (RF64 only). */
+    if (pFormat->container == drwav_container_rf64) {
+        drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */
+        drwav_uint64 initialRiffChunkSize = 8 + initialds64ChunkSize + initialDataChunkSize; /* +8 for the ds64 header. */
+
+        runningPos += drwav__write(pWav, "ds64", 4);
+        runningPos += drwav__write_u32ne_to_le(pWav, initialds64ChunkSize); /* Size of ds64. */
+        runningPos += drwav__write_u64ne_to_le(pWav, initialRiffChunkSize); /* Size of RIFF. Set to true value at the end. */
+        runningPos += drwav__write_u64ne_to_le(pWav, initialDataChunkSize); /* Size of DATA. Set to true value at the end. */
+        runningPos += drwav__write_u64ne_to_le(pWav, totalSampleCount); /* Sample count. */
+        runningPos += drwav__write_u32ne_to_le(pWav, 0); /* Table length. Always set to zero in our case since we're not doing any other chunks than "DATA". */
+    }
+
+
+    /* "fmt " chunk. */
+    if (pFormat->container == drwav_container_riff || pFormat->container == drwav_container_rf64) {
+        chunkSizeFMT = 16;
+        runningPos += drwav__write(pWav, "fmt ", 4);
+        runningPos += drwav__write_u32ne_to_le(pWav, (drwav_uint32)chunkSizeFMT);
+    } else if (pFormat->container == drwav_container_w64) {
+        chunkSizeFMT = 40;
+        runningPos += drwav__write(pWav, drwavGUID_W64_FMT, 16);
+        runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeFMT);
+    }
+
+    runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.formatTag);
+    runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.channels);
+    runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.sampleRate);
+    runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.avgBytesPerSec);
+    runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.blockAlign);
+    runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.bitsPerSample);
+
+    pWav->dataChunkDataPos = runningPos;
+
+    /* "data" chunk.
*/ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeDATA = (drwav_uint32)initialDataChunkSize; + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeDATA = 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_DATA, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always set to 0xFFFFFFFF for RF64. The true size of the data chunk is specified in the ds64 chunk. */ + } + + /* + The runningPos variable is incremented in the section above but is left unused which is causing some static analysis tools to detect it + as a dead store. I'm leaving this as-is for safety just in case I want to expand this function later to include other tags and want to + keep track of the running position for whatever reason. The line below should silence the static analysis tools. + */ + (void)runningPos; + + /* Set some properties for the client's convenience. */ + pWav->container = pFormat->container; + pWav->channels = (drwav_uint16)pFormat->channels; + pWav->sampleRate = pFormat->sampleRate; + pWav->bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->translatedFormatTag = (drwav_uint16)pFormat->format; + + return DRWAV_TRUE; +} + + +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_FALSE, onWrite, onSeek, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, 0); /* DRWAV_FALSE = Not Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_TRUE, onWrite, NULL, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); /* DRWAV_TRUE = Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_write_sequential(pWav, pFormat, totalPCMFrameCount*pFormat->channels, onWrite, pUserData, pAllocationCallbacks); +} + +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* Casting totalSampleCount to drwav_int64 for VC6 compatibility. No issues in practice because nobody is going to exhaust the whole 63 bits. 
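+    (Note on the arithmetic below: the multiplications are carried out in integer arithmetic
+    and only the final division by 8.0 happens in double precision; the outer cast then
+    truncates the result back to drwav_uint64.)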
*/
+    drwav_uint64 targetDataSizeBytes = (drwav_uint64)((drwav_int64)totalSampleCount * pFormat->channels * pFormat->bitsPerSample/8.0);
+    drwav_uint64 riffChunkSizeBytes;
+    drwav_uint64 fileSizeBytes = 0;
+
+    if (pFormat->container == drwav_container_riff) {
+        riffChunkSizeBytes = drwav__riff_chunk_size_riff(targetDataSizeBytes);
+        fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+    } else if (pFormat->container == drwav_container_w64) {
+        riffChunkSizeBytes = drwav__riff_chunk_size_w64(targetDataSizeBytes);
+        fileSizeBytes = riffChunkSizeBytes;
+    } else if (pFormat->container == drwav_container_rf64) {
+        riffChunkSizeBytes = drwav__riff_chunk_size_rf64(targetDataSizeBytes);
+        fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+    }
+
+    return fileSizeBytes;
+}
+
+
+#ifndef DR_WAV_NO_STDIO
+
+/* drwav_result_from_errno() is only used for fopen() and wfopen() so putting it inside DR_WAV_NO_STDIO for now. If something else needs this later we can move it out. */
+#include <errno.h>
+static drwav_result drwav_result_from_errno(int e)
+{
+    switch (e)
+    {
+        case 0: return DRWAV_SUCCESS;
+        #ifdef EPERM
+        case EPERM: return DRWAV_INVALID_OPERATION;
+        #endif
+        #ifdef ENOENT
+        case ENOENT: return DRWAV_DOES_NOT_EXIST;
+        #endif
+        #ifdef ESRCH
+        case ESRCH: return DRWAV_DOES_NOT_EXIST;
+        #endif
+        #ifdef EINTR
+        case EINTR: return DRWAV_INTERRUPT;
+        #endif
+        #ifdef EIO
+        case EIO: return DRWAV_IO_ERROR;
+        #endif
+        #ifdef ENXIO
+        case ENXIO: return DRWAV_DOES_NOT_EXIST;
+        #endif
+        #ifdef E2BIG
+        case E2BIG: return DRWAV_INVALID_ARGS;
+        #endif
+        #ifdef ENOEXEC
+        case ENOEXEC: return DRWAV_INVALID_FILE;
+        #endif
+        #ifdef EBADF
+        case EBADF: return DRWAV_INVALID_FILE;
+        #endif
+        #ifdef ECHILD
+        case ECHILD: return DRWAV_ERROR;
+        #endif
+        #ifdef EAGAIN
+        case EAGAIN: return DRWAV_UNAVAILABLE;
+        #endif
+        #ifdef ENOMEM
+        case ENOMEM: return DRWAV_OUT_OF_MEMORY;
+        #endif
+        #ifdef EACCES
+        case EACCES: return DRWAV_ACCESS_DENIED;
+        #endif
+        #ifdef EFAULT
+        case EFAULT: return DRWAV_BAD_ADDRESS;
+        #endif
+        #ifdef ENOTBLK
+        case ENOTBLK: return DRWAV_ERROR;
+        #endif
+        #ifdef EBUSY
+        case EBUSY: return DRWAV_BUSY;
+        #endif
+        #ifdef EEXIST
+        case EEXIST: return DRWAV_ALREADY_EXISTS;
+        #endif
+        #ifdef EXDEV
+        case EXDEV: return DRWAV_ERROR;
+        #endif
+        #ifdef ENODEV
+        case ENODEV: return DRWAV_DOES_NOT_EXIST;
+        #endif
+        #ifdef ENOTDIR
+        case ENOTDIR: return DRWAV_NOT_DIRECTORY;
+        #endif
+        #ifdef EISDIR
+        case EISDIR: return DRWAV_IS_DIRECTORY;
+        #endif
+        #ifdef EINVAL
+        case EINVAL: return DRWAV_INVALID_ARGS;
+        #endif
+        #ifdef ENFILE
+        case ENFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+        #endif
+        #ifdef EMFILE
+        case EMFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+        #endif
+        #ifdef ENOTTY
+        case ENOTTY: return DRWAV_INVALID_OPERATION;
+        #endif
+        #ifdef ETXTBSY
+        case ETXTBSY: return DRWAV_BUSY;
+        #endif
+        #ifdef EFBIG
+        case EFBIG: return DRWAV_TOO_BIG;
+        #endif
+        #ifdef ENOSPC
+        case ENOSPC: return DRWAV_NO_SPACE;
+        #endif
+        #ifdef ESPIPE
+        case ESPIPE: return DRWAV_BAD_SEEK;
+        #endif
+        #ifdef EROFS
+        case EROFS: return DRWAV_ACCESS_DENIED;
+        #endif
+        #ifdef EMLINK
+        case EMLINK: return DRWAV_TOO_MANY_LINKS;
+        #endif
+        #ifdef EPIPE
+        case EPIPE: return DRWAV_BAD_PIPE;
+        #endif
+        #ifdef EDOM
+        case EDOM: return DRWAV_OUT_OF_RANGE;
+        #endif
+        #ifdef ERANGE
+        case ERANGE: return DRWAV_OUT_OF_RANGE;
+        #endif
+        #ifdef EDEADLK
+        case EDEADLK: return DRWAV_DEADLOCK;
+        #endif
+        #ifdef ENAMETOOLONG
+ case ENAMETOOLONG: return DRWAV_PATH_TOO_LONG; + #endif + #ifdef ENOLCK + case ENOLCK: return DRWAV_ERROR; + #endif + #ifdef ENOSYS + case ENOSYS: return DRWAV_NOT_IMPLEMENTED; + #endif + #ifdef ENOTEMPTY + case ENOTEMPTY: return DRWAV_DIRECTORY_NOT_EMPTY; + #endif + #ifdef ELOOP + case ELOOP: return DRWAV_TOO_MANY_LINKS; + #endif + #ifdef ENOMSG + case ENOMSG: return DRWAV_NO_MESSAGE; + #endif + #ifdef EIDRM + case EIDRM: return DRWAV_ERROR; + #endif + #ifdef ECHRNG + case ECHRNG: return DRWAV_ERROR; + #endif + #ifdef EL2NSYNC + case EL2NSYNC: return DRWAV_ERROR; + #endif + #ifdef EL3HLT + case EL3HLT: return DRWAV_ERROR; + #endif + #ifdef EL3RST + case EL3RST: return DRWAV_ERROR; + #endif + #ifdef ELNRNG + case ELNRNG: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef EUNATCH + case EUNATCH: return DRWAV_ERROR; + #endif + #ifdef ENOCSI + case ENOCSI: return DRWAV_ERROR; + #endif + #ifdef EL2HLT + case EL2HLT: return DRWAV_ERROR; + #endif + #ifdef EBADE + case EBADE: return DRWAV_ERROR; + #endif + #ifdef EBADR + case EBADR: return DRWAV_ERROR; + #endif + #ifdef EXFULL + case EXFULL: return DRWAV_ERROR; + #endif + #ifdef ENOANO + case ENOANO: return DRWAV_ERROR; + #endif + #ifdef EBADRQC + case EBADRQC: return DRWAV_ERROR; + #endif + #ifdef EBADSLT + case EBADSLT: return DRWAV_ERROR; + #endif + #ifdef EBFONT + case EBFONT: return DRWAV_INVALID_FILE; + #endif + #ifdef ENOSTR + case ENOSTR: return DRWAV_ERROR; + #endif + #ifdef ENODATA + case ENODATA: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ETIME + case ETIME: return DRWAV_TIMEOUT; + #endif + #ifdef ENOSR + case ENOSR: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ENONET + case ENONET: return DRWAV_NO_NETWORK; + #endif + #ifdef ENOPKG + case ENOPKG: return DRWAV_ERROR; + #endif + #ifdef EREMOTE + case EREMOTE: return DRWAV_ERROR; + #endif + #ifdef ENOLINK + case ENOLINK: return DRWAV_ERROR; + #endif + #ifdef EADV + case EADV: return DRWAV_ERROR; + #endif + #ifdef ESRMNT + case ESRMNT: return DRWAV_ERROR; + #endif + #ifdef ECOMM + case ECOMM: return DRWAV_ERROR; + #endif + #ifdef EPROTO + case EPROTO: return DRWAV_ERROR; + #endif + #ifdef EMULTIHOP + case EMULTIHOP: return DRWAV_ERROR; + #endif + #ifdef EDOTDOT + case EDOTDOT: return DRWAV_ERROR; + #endif + #ifdef EBADMSG + case EBADMSG: return DRWAV_BAD_MESSAGE; + #endif + #ifdef EOVERFLOW + case EOVERFLOW: return DRWAV_TOO_BIG; + #endif + #ifdef ENOTUNIQ + case ENOTUNIQ: return DRWAV_NOT_UNIQUE; + #endif + #ifdef EBADFD + case EBADFD: return DRWAV_ERROR; + #endif + #ifdef EREMCHG + case EREMCHG: return DRWAV_ERROR; + #endif + #ifdef ELIBACC + case ELIBACC: return DRWAV_ACCESS_DENIED; + #endif + #ifdef ELIBBAD + case ELIBBAD: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBSCN + case ELIBSCN: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBMAX + case ELIBMAX: return DRWAV_ERROR; + #endif + #ifdef ELIBEXEC + case ELIBEXEC: return DRWAV_ERROR; + #endif + #ifdef EILSEQ + case EILSEQ: return DRWAV_INVALID_DATA; + #endif + #ifdef ERESTART + case ERESTART: return DRWAV_ERROR; + #endif + #ifdef ESTRPIPE + case ESTRPIPE: return DRWAV_ERROR; + #endif + #ifdef EUSERS + case EUSERS: return DRWAV_ERROR; + #endif + #ifdef ENOTSOCK + case ENOTSOCK: return DRWAV_NOT_SOCKET; + #endif + #ifdef EDESTADDRREQ + case EDESTADDRREQ: return DRWAV_NO_ADDRESS; + #endif + #ifdef EMSGSIZE + case EMSGSIZE: return DRWAV_TOO_BIG; + #endif + #ifdef EPROTOTYPE + case EPROTOTYPE: return DRWAV_BAD_PROTOCOL; + #endif + #ifdef ENOPROTOOPT + case ENOPROTOOPT: return DRWAV_PROTOCOL_UNAVAILABLE; + #endif + #ifdef 
EPROTONOSUPPORT + case EPROTONOSUPPORT: return DRWAV_PROTOCOL_NOT_SUPPORTED; + #endif + #ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: return DRWAV_SOCKET_NOT_SUPPORTED; + #endif + #ifdef EOPNOTSUPP + case EOPNOTSUPP: return DRWAV_INVALID_OPERATION; + #endif + #ifdef EPFNOSUPPORT + case EPFNOSUPPORT: return DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EAFNOSUPPORT + case EAFNOSUPPORT: return DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EADDRINUSE + case EADDRINUSE: return DRWAV_ALREADY_IN_USE; + #endif + #ifdef EADDRNOTAVAIL + case EADDRNOTAVAIL: return DRWAV_ERROR; + #endif + #ifdef ENETDOWN + case ENETDOWN: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETUNREACH + case ENETUNREACH: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETRESET + case ENETRESET: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNABORTED + case ECONNABORTED: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNRESET + case ECONNRESET: return DRWAV_CONNECTION_RESET; + #endif + #ifdef ENOBUFS + case ENOBUFS: return DRWAV_NO_SPACE; + #endif + #ifdef EISCONN + case EISCONN: return DRWAV_ALREADY_CONNECTED; + #endif + #ifdef ENOTCONN + case ENOTCONN: return DRWAV_NOT_CONNECTED; + #endif + #ifdef ESHUTDOWN + case ESHUTDOWN: return DRWAV_ERROR; + #endif + #ifdef ETOOMANYREFS + case ETOOMANYREFS: return DRWAV_ERROR; + #endif + #ifdef ETIMEDOUT + case ETIMEDOUT: return DRWAV_TIMEOUT; + #endif + #ifdef ECONNREFUSED + case ECONNREFUSED: return DRWAV_CONNECTION_REFUSED; + #endif + #ifdef EHOSTDOWN + case EHOSTDOWN: return DRWAV_NO_HOST; + #endif + #ifdef EHOSTUNREACH + case EHOSTUNREACH: return DRWAV_NO_HOST; + #endif + #ifdef EALREADY + case EALREADY: return DRWAV_IN_PROGRESS; + #endif + #ifdef EINPROGRESS + case EINPROGRESS: return DRWAV_IN_PROGRESS; + #endif + #ifdef ESTALE + case ESTALE: return DRWAV_INVALID_FILE; + #endif + #ifdef EUCLEAN + case EUCLEAN: return DRWAV_ERROR; + #endif + #ifdef ENOTNAM + case ENOTNAM: return DRWAV_ERROR; + #endif + #ifdef ENAVAIL + case ENAVAIL: return DRWAV_ERROR; + #endif + #ifdef EISNAM + case EISNAM: return DRWAV_ERROR; + #endif + #ifdef EREMOTEIO + case EREMOTEIO: return DRWAV_IO_ERROR; + #endif + #ifdef EDQUOT + case EDQUOT: return DRWAV_NO_SPACE; + #endif + #ifdef ENOMEDIUM + case ENOMEDIUM: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef EMEDIUMTYPE + case EMEDIUMTYPE: return DRWAV_ERROR; + #endif + #ifdef ECANCELED + case ECANCELED: return DRWAV_CANCELLED; + #endif + #ifdef ENOKEY + case ENOKEY: return DRWAV_ERROR; + #endif + #ifdef EKEYEXPIRED + case EKEYEXPIRED: return DRWAV_ERROR; + #endif + #ifdef EKEYREVOKED + case EKEYREVOKED: return DRWAV_ERROR; + #endif + #ifdef EKEYREJECTED + case EKEYREJECTED: return DRWAV_ERROR; + #endif + #ifdef EOWNERDEAD + case EOWNERDEAD: return DRWAV_ERROR; + #endif + #ifdef ENOTRECOVERABLE + case ENOTRECOVERABLE: return DRWAV_ERROR; + #endif + #ifdef ERFKILL + case ERFKILL: return DRWAV_ERROR; + #endif + #ifdef EHWPOISON + case EHWPOISON: return DRWAV_ERROR; + #endif + default: return DRWAV_ERROR; + } +} + +static drwav_result drwav_fopen(FILE** ppFile, const char* pFilePath, const char* pOpenMode) +{ +#if _MSC_VER && _MSC_VER >= 1400 + errno_t err; +#endif + + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. 
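+                          Null the output up front so a failed open can never hand the caller a stale
+                          pointer. For illustration, this is how the errno mapping above surfaces to a
+                          caller (hypothetical snippet, not part of this file):
+
+                              FILE* pFile;
+                              if (drwav_fopen(&pFile, "input.wav", "rb") != DRWAV_SUCCESS) {
+                                  ... e.g. DRWAV_DOES_NOT_EXIST when errno was ENOENT ...
+                              }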
*/
+    }
+
+    if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) {
+        return DRWAV_INVALID_ARGS;
+    }
+
+#if _MSC_VER && _MSC_VER >= 1400
+    err = fopen_s(ppFile, pFilePath, pOpenMode);
+    if (err != 0) {
+        return drwav_result_from_errno(err);
+    }
+#else
+#if defined(_WIN32) || defined(__APPLE__)
+    *ppFile = fopen(pFilePath, pOpenMode);
+#else
+    #if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS == 64 && defined(_LARGEFILE64_SOURCE)
+        *ppFile = fopen64(pFilePath, pOpenMode);
+    #else
+        *ppFile = fopen(pFilePath, pOpenMode);
+    #endif
+#endif
+    if (*ppFile == NULL) {
+        drwav_result result = drwav_result_from_errno(errno);
+        if (result == DRWAV_SUCCESS) {
+            result = DRWAV_ERROR;   /* Just a safety check to make sure we never ever return success when pFile == NULL. */
+        }
+
+        return result;
+    }
+#endif
+
+    return DRWAV_SUCCESS;
+}
+
+/*
+_wfopen() isn't available in all compilation environments.
+
+    * Windows only.
+    * MSVC seems to support it universally as far back as VC6 from what I can tell (haven't checked further back).
+    * MinGW-64 (both 32- and 64-bit) seems to support it.
+    * MinGW wraps it in !defined(__STRICT_ANSI__).
+    * OpenWatcom wraps it in !defined(_NO_EXT_KEYS).
+
+This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs()
+fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support.
+*/
+#if defined(_WIN32)
+    #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS))
+        #define DRWAV_HAS_WFOPEN
+    #endif
+#endif
+
+static drwav_result drwav_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_t* pOpenMode, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+    if (ppFile != NULL) {
+        *ppFile = NULL;  /* Safety. */
+    }
+
+    if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) {
+        return DRWAV_INVALID_ARGS;
+    }
+
+#if defined(DRWAV_HAS_WFOPEN)
+    {
+        /* Use _wfopen() on Windows. */
+    #if defined(_MSC_VER) && _MSC_VER >= 1400
+        errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode);
+        if (err != 0) {
+            return drwav_result_from_errno(err);
+        }
+    #else
+        *ppFile = _wfopen(pFilePath, pOpenMode);
+        if (*ppFile == NULL) {
+            return drwav_result_from_errno(errno);
+        }
+    #endif
+        (void)pAllocationCallbacks;
+    }
+#else
+    /*
+    Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can
+    think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for
+    maintaining state. I've checked this with -std=c89 and it works, but if somebody gets a compiler error I'll look into improving compatibility.
+    */
+    {
+        mbstate_t mbs;
+        size_t lenMB;
+        const wchar_t* pFilePathTemp = pFilePath;
+        char* pFilePathMB = NULL;
+        char pOpenModeMB[32] = {0};
+
+        /* Get the length first. */
+        DRWAV_ZERO_OBJECT(&mbs);
+        lenMB = wcsrtombs(NULL, &pFilePathTemp, 0, &mbs);
+        if (lenMB == (size_t)-1) {
+            return drwav_result_from_errno(errno);
+        }
+
+        pFilePathMB = (char*)drwav__malloc_from_callbacks(lenMB + 1, pAllocationCallbacks);
+        if (pFilePathMB == NULL) {
+            return DRWAV_OUT_OF_MEMORY;
+        }
+
+        pFilePathTemp = pFilePath;
+        DRWAV_ZERO_OBJECT(&mbs);
+        wcsrtombs(pFilePathMB, &pFilePathTemp, lenMB + 1, &mbs);
+
+        /* The open mode should always consist of ASCII characters so we should be able to do a trivial conversion.
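+           For example, L"rb" becomes "rb" by casting each wide character straight to char, which is
+           safe because every character of a valid open mode is in the ASCII range:
+
+               pOpenModeMB[0] = (char)L'r';   (0x0072 becomes 0x72)
+               pOpenModeMB[1] = (char)L'b';   (0x0062 becomes 0x62)
+               pOpenModeMB[2] = '\0';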
*/ + { + size_t i = 0; + for (;;) { + if (pOpenMode[i] == 0) { + pOpenModeMB[i] = '\0'; + break; + } + + pOpenModeMB[i] = (char)pOpenMode[i]; + i += 1; + } + } + + *ppFile = fopen(pFilePathMB, pOpenModeMB); + + drwav__free_from_callbacks(pFilePathMB, pAllocationCallbacks); + } + + if (*ppFile == NULL) { + return DRWAV_ERROR; + } +#endif + + return DRWAV_SUCCESS; +} + + +static size_t drwav__on_read_stdio(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData); +} + +static size_t drwav__on_write_stdio(void* pUserData, const void* pData, size_t bytesToWrite) +{ + return fwrite(pData, 1, bytesToWrite, (FILE*)pUserData); +} + +static drwav_bool32 drwav__on_seek_stdio(void* pUserData, int offset, drwav_seek_origin origin) +{ + return fseek((FILE*)pUserData, offset, (origin == drwav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; +} + +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file__internal_FILE(drwav* pWav, FILE* pFile, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit(pWav, drwav__on_read_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init__internal(pWav, onChunk, pChunkUserData, flags); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "rb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex_w(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"rb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. 
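+    "Ownership" here means the caller must never touch the FILE* again: drwav_init_file__internal_FILE()
+    closes it on failure, and drwav_uninit() closes it on success. A typical caller therefore looks like
+    this (illustrative sketch only, not part of this file; frameCount and pSamples stand in for the
+    caller's own buffer):
+
+        drwav wav;
+        if (drwav_init_file(&wav, "input.wav", NULL)) {
+            ... drwav_read_pcm_frames_s16(&wav, frameCount, pSamples) ...
+            drwav_uninit(&wav);   (this is what fclose()s the handle)
+        }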
*/ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file_write__internal_FILE(drwav* pWav, FILE* pFile, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init_write__internal(pWav, pFormat, totalSampleCount); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_file_write__internal(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "wb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +static drwav_bool32 drwav_init_file_write_w__internal(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"wb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 
drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential_w(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} +#endif /* DR_WAV_NO_STDIO */ + + +static size_t drwav__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStream.dataSize >= pWav->memoryStream.currentReadPos); + + bytesRemaining = pWav->memoryStream.dataSize - pWav->memoryStream.currentReadPos; + if (bytesToRead > bytesRemaining) { + bytesToRead = bytesRemaining; + } + + if (bytesToRead > 0) { + DRWAV_COPY_MEMORY(pBufferOut, pWav->memoryStream.data + pWav->memoryStream.currentReadPos, bytesToRead); + pWav->memoryStream.currentReadPos += bytesToRead; + } + + return bytesToRead; +} + +static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { + return DRWAV_FALSE; /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStream.currentReadPos < (size_t)-offset) { + return DRWAV_FALSE; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStream.currentReadPos += offset; + } else { + if ((drwav_uint32)offset <= pWav->memoryStream.dataSize) { + pWav->memoryStream.currentReadPos = offset; + } else { + return DRWAV_FALSE; /* Trying to seek too far forward. */ + } + } + + return DRWAV_TRUE; +} + +static size_t drwav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStreamWrite.dataCapacity >= pWav->memoryStreamWrite.currentWritePos); + + bytesRemaining = pWav->memoryStreamWrite.dataCapacity - pWav->memoryStreamWrite.currentWritePos; + if (bytesRemaining < bytesToWrite) { + /* Need to reallocate. */ + void* pNewData; + size_t newDataCapacity = (pWav->memoryStreamWrite.dataCapacity == 0) ? 256 : pWav->memoryStreamWrite.dataCapacity * 2; + + /* If doubling wasn't enough, just make it the minimum required size to write the data. 
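+           Worked example: with dataCapacity = 256, currentWritePos = 200 and bytesToWrite = 1000,
+           doubling only gets us to 512, which leaves 312 writable bytes, so the capacity is bumped
+           to exactly 200 + 1000 = 1200 instead. In effect:
+
+               newDataCapacity = max(dataCapacity * 2, currentWritePos + bytesToWrite)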
*/ + if ((newDataCapacity - pWav->memoryStreamWrite.currentWritePos) < bytesToWrite) { + newDataCapacity = pWav->memoryStreamWrite.currentWritePos + bytesToWrite; + } + + pNewData = drwav__realloc_from_callbacks(*pWav->memoryStreamWrite.ppData, newDataCapacity, pWav->memoryStreamWrite.dataCapacity, &pWav->allocationCallbacks); + if (pNewData == NULL) { + return 0; + } + + *pWav->memoryStreamWrite.ppData = pNewData; + pWav->memoryStreamWrite.dataCapacity = newDataCapacity; + } + + DRWAV_COPY_MEMORY(((drwav_uint8*)(*pWav->memoryStreamWrite.ppData)) + pWav->memoryStreamWrite.currentWritePos, pDataIn, bytesToWrite); + + pWav->memoryStreamWrite.currentWritePos += bytesToWrite; + if (pWav->memoryStreamWrite.dataSize < pWav->memoryStreamWrite.currentWritePos) { + pWav->memoryStreamWrite.dataSize = pWav->memoryStreamWrite.currentWritePos; + } + + *pWav->memoryStreamWrite.pDataSize = pWav->memoryStreamWrite.dataSize; + + return bytesToWrite; +} + +static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) { + offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { + offset = -(int)pWav->memoryStreamWrite.currentWritePos; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStreamWrite.currentWritePos += offset; + } else { + if ((drwav_uint32)offset <= pWav->memoryStreamWrite.dataSize) { + pWav->memoryStreamWrite.currentWritePos = offset; + } else { + pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Trying to seek too far forward. */ + } + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_ex(pWav, data, dataSize, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (data == NULL || dataSize == 0) { + return DRWAV_FALSE; + } + + if (!drwav_preinit(pWav, drwav__on_read_memory, drwav__on_seek_memory, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStream.data = (const drwav_uint8*)data; + pWav->memoryStream.dataSize = dataSize; + pWav->memoryStream.currentReadPos = 0; + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_bool32 drwav_init_memory_write__internal(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppData == NULL || pDataSize == NULL) { + return DRWAV_FALSE; + } + + *ppData = NULL; /* Important because we're using realloc()! 
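+    drwav__realloc_from_callbacks() needs a NULL (or previously allocated) pointer to start from, and
+    the buffer is then grown on demand by drwav__on_write_memory() as frames are written. A sketch of
+    the intended usage (hypothetical caller; format, frameCount and pSamples are placeholders):
+
+        void* pData;
+        size_t dataSize;
+        drwav wav;
+        if (drwav_init_memory_write(&wav, &pData, &dataSize, &format, NULL)) {
+            drwav_write_pcm_frames(&wav, frameCount, pSamples);
+            drwav_uninit(&wav);
+            ... pData/dataSize now hold a complete WAV image; release it with drwav_free(pData, NULL) ...
+        }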
*/ + *pDataSize = 0; + + if (!drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_memory, drwav__on_seek_memory_write, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStreamWrite.ppData = ppData; + pWav->memoryStreamWrite.pDataSize = pDataSize; + pWav->memoryStreamWrite.dataSize = 0; + pWav->memoryStreamWrite.dataCapacity = 0; + pWav->memoryStreamWrite.currentWritePos = 0; + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_memory_write_sequential(pWav, ppData, pDataSize, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + + + +DRWAV_API drwav_result drwav_uninit(drwav* pWav) +{ + drwav_result result = DRWAV_SUCCESS; + + if (pWav == NULL) { + return DRWAV_INVALID_ARGS; + } + + /* + If the drwav object was opened in write mode we'll need to finalize a few things: + - Make sure the "data" chunk is aligned to 16-bits for RIFF containers, or 64 bits for W64 containers. + - Set the size of the "data" chunk. + */ + if (pWav->onWrite != NULL) { + drwav_uint32 paddingSize = 0; + + /* Padding. Do not adjust pWav->dataChunkDataSize - this should not include the padding. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + paddingSize = drwav__chunk_padding_size_riff(pWav->dataChunkDataSize); + } else { + paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize); + } + + if (paddingSize > 0) { + drwav_uint64 paddingData = 0; + drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */ + } + + /* + Chunk sizes. When using sequential mode, these will have been filled in at initialization time. We only need + to do this when using non-sequential mode. + */ + if (pWav->onSeek && !pWav->isSequentialWrite) { + if (pWav->container == drwav_container_riff) { + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, 4, drwav_seek_origin_start)) { + drwav_uint32 riffChunkSize = drwav__riff_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, riffChunkSize); + } + + /* the "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 4, drwav_seek_origin_start)) { + drwav_uint32 dataChunkSize = drwav__data_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_w64) { + /* The "RIFF" chunk size. 
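+               The seek offset is 16 rather than 4 because W64 headers use 16 byte GUIDs instead of
+               four byte FourCCs:
+
+                   bytes 0..15   RIFF GUID
+                   bytes 16..23  chunk size (unsigned 64-bit; in W64 it counts the header fields themselves)
+                   bytes 24..39  WAVE GUID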
*/ + if (pWav->onSeek(pWav->pUserData, 16, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 16, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_rf64) { + /* We only need to update the ds64 chunk. The "RIFF" and "data" chunks always have their sizes set to 0xFFFFFFFF for RF64. */ + int ds64BodyPos = 12 + 8; + + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } + } + + /* Validation for sequential mode. */ + if (pWav->isSequentialWrite) { + if (pWav->dataChunkDataSize != pWav->dataChunkDataSizeTargetWrite) { + result = DRWAV_INVALID_FILE; + } + } + } + +#ifndef DR_WAV_NO_STDIO + /* + If we opened the file with drwav_open_file() we will want to close the file handle. We can know whether or not drwav_open_file() + was used by looking at the onRead and onSeek callbacks. + */ + if (pWav->onRead == drwav__on_read_stdio || pWav->onWrite == drwav__on_write_stdio) { + fclose((FILE*)pWav->pUserData); + } +#endif + + return result; +} + + + +DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut) +{ + size_t bytesRead; + + if (pWav == NULL || bytesToRead == 0) { + return 0; + } + + if (bytesToRead > pWav->bytesRemaining) { + bytesToRead = (size_t)pWav->bytesRemaining; + } + + if (pBufferOut != NULL) { + bytesRead = pWav->onRead(pWav->pUserData, pBufferOut, bytesToRead); + } else { + /* We need to seek. If we fail, we need to read-and-discard to make sure we get a good byte count. */ + bytesRead = 0; + while (bytesRead < bytesToRead) { + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > 0x7FFFFFFF) { + bytesToSeek = 0x7FFFFFFF; + } + + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, drwav_seek_origin_current) == DRWAV_FALSE) { + break; + } + + bytesRead += bytesToSeek; + } + + /* When we get here we may need to read-and-discard some data. */ + while (bytesRead < bytesToRead) { + drwav_uint8 buffer[4096]; + size_t bytesSeeked; + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > sizeof(buffer)) { + bytesToSeek = sizeof(buffer); + } + + bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek); + bytesRead += bytesSeeked; + + if (bytesSeeked < bytesToSeek) { + break; /* Reached the end. */ + } + } + } + + pWav->bytesRemaining -= bytesRead; + return bytesRead; +} + + + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 bytesToRead; /* Intentionally uint64 instead of size_t so we can do a check that we're not reading too much on 32-bit builds. */ + + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + /* Cannot use this function for compressed formats. 
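+    ADPCM doesn't have a fixed number of bytes per PCM frame, so the framesToRead * bytesPerFrame
+    arithmetic below would be meaningless; compressed data is instead decoded through the s16 readers
+    further down. Note also that drwav_read_raw() above doubles as a skip routine when given a NULL
+    buffer, e.g. (illustrative):
+
+        drwav_read_raw(pWav, 1024, NULL);   (seeks forward, or reads-and-discards if seeking fails)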
*/
+    if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+        return 0;
+    }
+
+    bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+    if (bytesPerFrame == 0) {
+        return 0;
+    }
+
+    /* Don't try to read more samples than can potentially fit in the output buffer. */
+    bytesToRead = framesToRead * bytesPerFrame;
+    if (bytesToRead > DRWAV_SIZE_MAX) {
+        bytesToRead = (DRWAV_SIZE_MAX / bytesPerFrame) * bytesPerFrame; /* Round the number of bytes to read to a clean frame boundary. */
+    }
+
+    /*
+    Doing an explicit check here just to make it clear that we don't want to attempt to read anything if there are no bytes to read. There
+    *could* be a time where it evaluates to 0 due to overflowing.
+    */
+    if (bytesToRead == 0) {
+        return 0;
+    }
+
+    return drwav_read_raw(pWav, (size_t)bytesToRead, pBufferOut) / bytesPerFrame;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+    drwav_uint64 framesRead = drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut);
+
+    if (pBufferOut != NULL) {
+        drwav__bswap_samples(pBufferOut, framesRead*pWav->channels, drwav_get_bytes_per_pcm_frame(pWav)/pWav->channels, pWav->translatedFormatTag);
+    }
+
+    return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+    if (drwav__is_little_endian()) {
+        return drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut);
+    } else {
+        return drwav_read_pcm_frames_be(pWav, framesToRead, pBufferOut);
+    }
+}
+
+
+
+DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav)
+{
+    if (pWav->onWrite != NULL) {
+        return DRWAV_FALSE; /* No seeking in write mode. */
+    }
+
+    if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, drwav_seek_origin_start)) {
+        return DRWAV_FALSE;
+    }
+
+    if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+        pWav->compressed.iCurrentPCMFrame = 0;
+
+        /* Cached data needs to be cleared for compressed formats. */
+        if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+            DRWAV_ZERO_OBJECT(&pWav->msadpcm);
+        } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+            DRWAV_ZERO_OBJECT(&pWav->ima);
+        } else {
+            DRWAV_ASSERT(DRWAV_FALSE);  /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */
+        }
+    }
+
+    pWav->bytesRemaining = pWav->dataChunkDataSize;
+    return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex)
+{
+    /* Seeking should be compatible with wave files > 2GB. */
+
+    if (pWav == NULL || pWav->onSeek == NULL) {
+        return DRWAV_FALSE;
+    }
+
+    /* No seeking in write mode. */
+    if (pWav->onWrite != NULL) {
+        return DRWAV_FALSE;
+    }
+
+    /* If there are no samples, just return DRWAV_TRUE without doing anything. */
+    if (pWav->totalPCMFrameCount == 0) {
+        return DRWAV_TRUE;
+    }
+
+    /* Make sure the target frame is clamped. */
+    if (targetFrameIndex >= pWav->totalPCMFrameCount) {
+        targetFrameIndex = pWav->totalPCMFrameCount - 1;
+    }
+
+    /*
+    For compressed formats we just use a slow generic seek. If we are seeking forward we just seek forward. If we are going backwards we need
+    to seek back to the start.
+    */
+    if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+        /* TODO: This can be optimized. */
+
+        /*
+        If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting.
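+        (The cost is therefore proportional to the seek distance in frames - e.g. if the decoder is
+        currently at frame 1000, drwav_seek_to_pcm_frame(pWav, 1500) decodes and throws away 500
+        frames into a small scratch buffer.)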
If we're seeking backwards, + we first need to seek back to the start and then just do the same thing as a forward seek. + */ + if (targetFrameIndex < pWav->compressed.iCurrentPCMFrame) { + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + } + + if (targetFrameIndex > pWav->compressed.iCurrentPCMFrame) { + drwav_uint64 offsetInFrames = targetFrameIndex - pWav->compressed.iCurrentPCMFrame; + + drwav_int16 devnull[2048]; + while (offsetInFrames > 0) { + drwav_uint64 framesRead = 0; + drwav_uint64 framesToRead = offsetInFrames; + if (framesToRead > drwav_countof(devnull)/pWav->channels) { + framesToRead = drwav_countof(devnull)/pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, devnull); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, devnull); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */ + } + + if (framesRead != framesToRead) { + return DRWAV_FALSE; + } + + offsetInFrames -= framesRead; + } + } + } else { + drwav_uint64 totalSizeInBytes; + drwav_uint64 currentBytePos; + drwav_uint64 targetBytePos; + drwav_uint64 offset; + + totalSizeInBytes = pWav->totalPCMFrameCount * drwav_get_bytes_per_pcm_frame(pWav); + DRWAV_ASSERT(totalSizeInBytes >= pWav->bytesRemaining); + + currentBytePos = totalSizeInBytes - pWav->bytesRemaining; + targetBytePos = targetFrameIndex * drwav_get_bytes_per_pcm_frame(pWav); + + if (currentBytePos < targetBytePos) { + /* Offset forwards. */ + offset = (targetBytePos - currentBytePos); + } else { + /* Offset backwards. */ + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + offset = targetBytePos; + } + + while (offset > 0) { + int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); + if (!pWav->onSeek(pWav->pUserData, offset32, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + + pWav->bytesRemaining -= offset32; + offset -= offset32; + } + } + + return DRWAV_TRUE; +} + + +DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData) +{ + size_t bytesWritten; + + if (pWav == NULL || bytesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesWritten = pWav->onWrite(pWav->pUserData, pData, bytesToWrite); + pWav->dataChunkDataSize += bytesWritten; + + return bytesWritten; +} + + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + while (bytesToWrite > 0) { + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. 
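+           As an aside, the frame/byte conversion here is symmetric: on the way in,
+           bytesToWrite = framesToWrite * channels * bitsPerSample / 8, and the return value maps
+           written bytes back with bytesWritten * 8 / bitsPerSample / channels. For 16-bit stereo,
+           100 frames = 100 * 2 * 16 / 8 = 400 bytes, and 400 bytes maps back to exactly 100 frames.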
*/ + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, pRunningData); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + drwav_uint32 bytesPerSample; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels; + + while (bytesToWrite > 0) { + drwav_uint8 temp[4096]; + drwav_uint32 sampleCount; + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */ + + /* + WAV files are always little-endian. We need to byte swap on big-endian architectures. Since our input buffer is read-only we need + to use an intermediary buffer for the conversion. + */ + sampleCount = sizeof(temp)/bytesPerSample; + + if (bytesToWriteThisIteration > ((drwav_uint64)sampleCount)*bytesPerSample) { + bytesToWriteThisIteration = ((drwav_uint64)sampleCount)*bytesPerSample; + } + + DRWAV_COPY_MEMORY(temp, pRunningData, (size_t)bytesToWriteThisIteration); + drwav__bswap_samples(temp, sampleCount, bytesPerSample, pWav->translatedFormatTag); + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, temp); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + if (drwav__is_little_endian()) { + return drwav_write_pcm_frames_le(pWav, framesToWrite, pData); + } else { + return drwav_write_pcm_frames_be(pWav, framesToWrite, pData); + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached frames we need to load a new block. */ + if (pWav->msadpcm.cachedFrameCount == 0 && pWav->msadpcm.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. 
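+                   A mono MS-ADPCM block opens with a 7 byte header, unpacked just below:
+
+                       byte 0      predictor index (selects the coeff1Table/coeff2Table entries)
+                       bytes 1-2   initial delta (little-endian int16)
+                       bytes 3-4   seed sample played second (int16)
+                       bytes 5-6   seed sample played first (int16)
+
+                   The two seed samples are emitted verbatim, which is why cachedFrameCount starts at 2.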
*/ + drwav_uint8 header[7]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 1); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 3); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 5); + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_uint8 header[14]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.predictor[1] = header[1]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 2); + pWav->msadpcm.delta[1] = drwav__bytes_to_s16(header + 4); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 6); + pWav->msadpcm.prevFrames[1][1] = (drwav_int32)drwav__bytes_to_s16(header + 8); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 10); + pWav->msadpcm.prevFrames[1][0] = (drwav_int32)drwav__bytes_to_s16(header + 12); + + pWav->msadpcm.cachedFrames[0] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[1] = pWav->msadpcm.prevFrames[1][0]; + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.cachedFrameCount = 2; + } + } + + /* Output anything that's cached. */ + while (framesToRead > 0 && pWav->msadpcm.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample = 0; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->msadpcm.cachedFrames[(drwav_countof(pWav->msadpcm.cachedFrames) - (pWav->msadpcm.cachedFrameCount*pWav->channels)) + iSample]; + } + + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->msadpcm.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->msadpcm.cachedFrameCount == 0) { + if (pWav->msadpcm.bytesRemainingInBlock == 0) { + continue; + } else { + static drwav_int32 adaptationTable[] = { + 230, 230, 230, 230, 307, 409, 512, 614, + 768, 614, 512, 409, 307, 230, 230, 230 + }; + static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + + drwav_uint8 nibbles; + drwav_int32 nibble0; + drwav_int32 nibble1; + + if (pWav->onRead(pWav->pUserData, &nibbles, 1) != 1) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock -= 1; + + /* TODO: Optimize away these if statements. */ + nibble0 = ((nibbles & 0xF0) >> 4); if ((nibbles & 0x80)) { nibble0 |= 0xFFFFFFF0UL; } + nibble1 = ((nibbles & 0x0F) >> 0); if ((nibbles & 0x08)) { nibble1 |= 0xFFFFFFF0UL; } + + if (pWav->channels == 1) { + /* Mono. 
*/ + drwav_int32 newSample0; + drwav_int32 newSample1; + + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[0]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample1; + + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_int32 newSample0; + drwav_int32 newSample1; + + /* Left. */ + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + /* Right. 
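+                       The right channel runs the same recurrence as the left, just with its own
+                       predictor, delta and history. In pseudocode, with raw = the unsigned nibble
+                       and nibble = its sign-extended value:
+
+                           predicted = (prev1 * coeff1Table[p] + prev2 * coeff2Table[p]) >> 8
+                           sample    = clamp(predicted + nibble * delta, -32768, 32767)
+                           delta     = max((adaptationTable[raw] * delta) >> 8, 16)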
*/ + newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[1]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8; + if (pWav->msadpcm.delta[1] < 16) { + pWav->msadpcm.delta[1] = 16; + } + + pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.prevFrames[1][1] = newSample1; + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 1; + } + } + } + } + + return totalFramesRead; +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + drwav_uint32 iChannel; + + static drwav_int32 indexTable[16] = { + -1, -1, -1, -1, 2, 4, 6, 8, + -1, -1, -1, -1, 2, 4, 6, 8 + }; + + static drwav_int32 stepTable[89] = { + 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, + 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, + 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, + 130, 143, 157, 173, 190, 209, 230, 253, 279, 307, + 337, 371, 408, 449, 494, 544, 598, 658, 724, 796, + 876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066, + 2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358, + 5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899, + 15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767 + }; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached samples we need to load a new block. */ + if (pWav->ima.cachedFrameCount == 0 && pWav->ima.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. */ + drwav_uint8 header[4]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. */ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[0]; + pWav->ima.cachedFrameCount = 1; + } else { + /* Stereo. */ + drwav_uint8 header[8]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable) || header[6] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. 
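+                        A step index past the end of stepTable[] (there are 89 entries, so valid
+                        indices are 0..88) can only mean a corrupt block, so the whole block is
+                        skipped instead of risking an out-of-bounds table read. For reference, each
+                        channel contributes 4 header bytes, so a stereo block opens with:
+
+                            bytes 0-1  left seed sample (int16)     bytes 4-5  right seed sample (int16)
+                            byte  2    left step index              byte  6    right step index
+                            byte  3    reserved                     byte  7    reserved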
*/ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.predictor[1] = drwav__bytes_to_s16(header + 4); + pWav->ima.stepIndex[1] = header[6]; + + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 2] = pWav->ima.predictor[0]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[1]; + pWav->ima.cachedFrameCount = 1; + } + } + + /* Output anything that's cached. */ + while (framesToRead > 0 && pWav->ima.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + iSample]; + } + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->ima.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->ima.cachedFrameCount == 0) { + if (pWav->ima.bytesRemainingInBlock == 0) { + continue; + } else { + /* + From what I can tell with stereo streams, it looks like every 4 bytes (8 samples) is for one channel. So it goes 4 bytes for the + left channel, 4 bytes for the right channel. + */ + pWav->ima.cachedFrameCount = 8; + for (iChannel = 0; iChannel < pWav->channels; ++iChannel) { + drwav_uint32 iByte; + drwav_uint8 nibbles[4]; + if (pWav->onRead(pWav->pUserData, &nibbles, 4) != 4) { + pWav->ima.cachedFrameCount = 0; + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock -= 4; + + for (iByte = 0; iByte < 4; ++iByte) { + drwav_uint8 nibble0 = ((nibbles[iByte] & 0x0F) >> 0); + drwav_uint8 nibble1 = ((nibbles[iByte] & 0xF0) >> 4); + + drwav_int32 step = stepTable[pWav->ima.stepIndex[iChannel]]; + drwav_int32 predictor = pWav->ima.predictor[iChannel]; + + drwav_int32 diff = step >> 3; + if (nibble0 & 1) diff += step >> 2; + if (nibble0 & 2) diff += step >> 1; + if (nibble0 & 4) diff += step; + if (nibble0 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble0], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+0)*pWav->channels + iChannel] = predictor; + + + step = stepTable[pWav->ima.stepIndex[iChannel]]; + predictor = pWav->ima.predictor[iChannel]; + + diff = step >> 3; + if (nibble1 & 1) diff += step >> 2; + if (nibble1 & 2) diff += step >> 1; + if (nibble1 & 4) diff += step; + if (nibble1 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble1], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+1)*pWav->channels + iChannel] = predictor; + } + } + } + } + } + + return totalFramesRead; +} + + +#ifndef 
DR_WAV_NO_CONVERSION_API +static unsigned short g_drwavAlawTable[256] = { + 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, + 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, + 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, + 0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00, + 0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58, + 0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58, + 0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960, + 0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0, + 0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80, + 0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40, + 0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00, + 0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500, + 0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8, + 0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8, + 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, + 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 +}; + +static unsigned short g_drwavMulawTable[256] = { + 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, + 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, + 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, + 0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844, + 0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64, + 0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74, + 0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C, + 0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000, + 0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C, + 0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C, + 0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 
0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC, + 0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC, + 0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C, + 0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C, + 0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084, + 0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000 +}; + +static DRWAV_INLINE drwav_int16 drwav__alaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavAlawTable[sampleIn]; +} + +static DRWAV_INLINE drwav_int16 drwav__mulaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavMulawTable[sampleIn]; +} + + + +static void drwav__pcm_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s16(pOut, pIn, totalSampleCount); + return; + } + + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int16*)pIn)[i]; + } + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s16(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_s16(pOut, (const drwav_int32*)pIn, totalSampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int16)((drwav_int64)sample >> 48); + } +} + +static void drwav__ieee_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s16(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s16(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + +static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + /* Fast path. 
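+    16-bit PCM is already in the output representation, so it can be streamed straight into the
+    caller's buffer with no conversion (a NULL pBufferOut short-circuits here too, since discarded
+    frames never need converting). Everything else takes the loop below: read raw frames into
+    sampleData[4096], convert with drwav__pcm_to_s16(), repeat. For 24-bit stereo that is 6 bytes
+    per frame, so each pass converts at most 4096 / 6 = 682 frames.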
*/ + if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + 
if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + if (framesToRead * pWav->channels * sizeof(drwav_int16) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int16) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s16__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s16__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s16__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s16__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x << 8; + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = ((int)(((unsigned int)(((const drwav_uint8*)pIn)[i*3+0]) << 8) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+1]) << 16) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+2])) << 24)) >> 8; + r = x >> 8; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x >> 16; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + float c; + c = ((x < -1) ? -1 : ((x > 1) ? 1 : x)); + c = c + 1; + r = (int)(c * 32767.5f); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + double x = pIn[i]; + double c; + c = ((x < -1) ? -1 : ((x > 1) ? 
1 : x)); + c = c + 1; + r = (int)(c * 32767.5); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__alaw_to_s16(pIn[i]); + } +} + +DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__mulaw_to_s16(pIn[i]); + } +} + + + +static void drwav__pcm_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_f32(pOut, pIn, sampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_f32(pOut, (const drwav_int16*)pIn, sampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_f32(pOut, pIn, sampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_f32(pOut, (const drwav_int32*)pIn, sampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < sampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (float)((drwav_int64)sample / 9223372036854775807.0); + } +} + +static void drwav__ieee_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + unsigned int i; + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((const float*)pIn)[i]; + } + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_f32(pOut, (const double*)pIn, sampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_f32__pcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_f32(pBufferOut, sampleData, (size_t)framesRead*pWav->channels, bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. 
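+    The pattern below decodes into a small stack buffer of s16 samples via drwav_read_pcm_frames_s16() and then
+    widens each chunk to f32 with drwav_s16_to_f32() until the requested frame count has been read.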
+ */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__ima(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__alaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__mulaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, 
drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + if (framesToRead * pWav->channels * sizeof(float) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(float) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_f32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_f32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_f32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_f32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_f32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_f32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + It appears libsndfile uses slightly different logic for the u8 -> f32 conversion to dr_wav, which in my opinion is incorrect. It appears + libsndfile performs the conversion something like "f32 = (u8 / 256) * 2 - 1", however I think it should be "f32 = (u8 / 255) * 2 - 1" (note + the divisor of 256 vs 255). I use libsndfile as a benchmark for testing, so I'm therefore leaving this block here just for my automated + correctness testing. This is disabled by default. 
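+    As a concrete example of the difference: for the maximum input value of 255, the divisor-256 formula yields
+    (255/256)*2 - 1 = 0.9921875, whereas dr_wav's default divisor-255 formula yields (255/255)*2 - 1 = 1.0 exactly.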
+ */ + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (pIn[i] / 256.0f) * 2 - 1; + } +#else + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + x = x * 0.00784313725490196078f; /* 0..255 to 0..2 */ + x = x - 1; /* 0..2 to -1..1 */ + + *pOut++ = x; + } +#endif +} + +DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] * 0.000030517578125f; + } +} + +DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + double x; + drwav_uint32 a = ((drwav_uint32)(pIn[i*3+0]) << 8); + drwav_uint32 b = ((drwav_uint32)(pIn[i*3+1]) << 16); + drwav_uint32 c = ((drwav_uint32)(pIn[i*3+2]) << 24); + + x = (double)((drwav_int32)(a | b | c) >> 8); + *pOut++ = (float)(x * 0.00000011920928955078125); + } +} + +DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + size_t i; + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)(pIn[i] / 2147483648.0); + } +} + +DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)pIn[i]; + } +} + +DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__alaw_to_s16(pIn[i]) / 32768.0f; + } +} + +DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__mulaw_to_s16(pIn[i]) / 32768.0f; + } +} + + + +static void drwav__pcm_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s32(pOut, pIn, totalSampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_s32(pOut, (const drwav_int16*)pIn, totalSampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s32(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int32*)pIn)[i]; + } + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. 
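+       Each sample's bytes are accumulated little-endian into the high end of a 64-bit integer (the shift starts at
+       (8 - bytesPerSample) * 8 and grows by 8 per byte), so the sign bit lands in bit 63; the value is then
+       arithmetic-shifted down to produce a sign-correct 32-bit sample.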
*/ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int32)((drwav_int64)sample >> 32); + } +} + +static void drwav__ieee_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s32(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s32(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. 
*/ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. 
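+       The clamp below caps framesToRead so that framesToRead * channels * sizeof(drwav_int32) can never exceed
+       DRWAV_SIZE_MAX.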
*/ + if (framesToRead * pWav->channels * sizeof(drwav_int32) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int32) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((int)pIn[i] - 128) << 24; + } +} + +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] << 16; + } +} + +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + unsigned int s0 = pIn[i*3 + 0]; + unsigned int s1 = pIn[i*3 + 1]; + unsigned int s2 = pIn[i*3 + 2]; + + drwav_int32 sample32 = (drwav_int32)((s0 << 8) | (s1 << 16) | (s2 << 24)); + *pOut++ = sample32; + } +} + +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__alaw_to_s16(pIn[i])) << 16; + } +} + +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, 
size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i= 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__mulaw_to_s16(pIn[i])) << 16; + } +} + + + +static drwav_int16* drwav__read_pcm_frames_and_close_s16(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int16* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int16); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int16*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_s16(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + +static float* drwav__read_pcm_frames_and_close_f32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + float* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (float*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_f32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + +static drwav_int32* drwav__read_pcm_frames_and_close_s32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int32* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int32); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int32*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. 
*/ + } + + framesRead = drwav_read_pcm_frames_s32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + + + +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +#ifndef DR_WAV_NO_STDIO +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, 
pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + + +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif + +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if 
(totalFrameCountOut) {
+        *totalFrameCountOut = 0;
+    }
+
+    if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+        return NULL;
+    }
+
+    return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+    drwav wav;
+
+    if (channelsOut) {
+        *channelsOut = 0;
+    }
+    if (sampleRateOut) {
+        *sampleRateOut = 0;
+    }
+    if (totalFrameCountOut) {
+        *totalFrameCountOut = 0;
+    }
+
+    if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+        return NULL;
+    }
+
+    return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+#endif  /* DR_WAV_NO_CONVERSION_API */
+
+
+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+    if (pAllocationCallbacks != NULL) {
+        drwav__free_from_callbacks(p, pAllocationCallbacks);
+    } else {
+        drwav__free_default(p, NULL);
+    }
+}
+
+DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data)
+{
+    return drwav__bytes_to_u16(data);
+}
+
+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data)
+{
+    return drwav__bytes_to_s16(data);
+}
+
+DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data)
+{
+    return drwav__bytes_to_u32(data);
+}
+
+DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data)
+{
+    return drwav__bytes_to_s32(data);
+}
+
+DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data)
+{
+    return drwav__bytes_to_u64(data);
+}
+
+DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data)
+{
+    return drwav__bytes_to_s64(data);
+}
+
+
+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16])
+{
+    return drwav__guid_equal(a, b);
+}
+
+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b)
+{
+    return drwav__fourcc_equal(a, b);
+}
+
+#endif  /* dr_wav_c */
+#endif  /* DR_WAV_IMPLEMENTATION */
+
+/*
+RELEASE NOTES - v0.11.0
+=======================
+Version 0.11.0 has breaking API changes.
+
+Improved Client-Defined Memory Allocation
+-----------------------------------------
+The main change with this release is the addition of a more flexible way of implementing custom memory allocation routines. The
+existing system of DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE is still in place and will be used by default when no custom
+allocation callbacks are specified.
+
+To use the new system, you pass in a pointer to a drwav_allocation_callbacks object to drwav_init() and family, like this:
+
+    void* my_malloc(size_t sz, void* pUserData)
+    {
+        return malloc(sz);
+    }
+    void* my_realloc(void* p, size_t sz, void* pUserData)
+    {
+        return realloc(p, sz);
+    }
+    void my_free(void* p, void* pUserData)
+    {
+        free(p);
+    }
+
+    ...
+
+    drwav_allocation_callbacks allocationCallbacks;
+    allocationCallbacks.pUserData = &myData;
+    allocationCallbacks.onMalloc  = my_malloc;
+    allocationCallbacks.onRealloc = my_realloc;
+    allocationCallbacks.onFree    = my_free;
+    drwav_init_file(&wav, "my_file.wav", &allocationCallbacks);
+
+The advantage of this new system is that it allows you to specify user data which will be passed in to the allocation routines.
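+
+For example (a hypothetical sketch, not part of the dr_wav API itself), the user data pointer could carry an
+allocation counter so the host program can see how much memory dr_wav requests; onRealloc and onFree would be
+hooked up in the same way as shown above:
+
+    typedef struct
+    {
+        size_t totalAllocated;
+    } my_alloc_stats;
+
+    void* my_counting_malloc(size_t sz, void* pUserData)
+    {
+        ((my_alloc_stats*)pUserData)->totalAllocated += sz;
+        return malloc(sz);
+    }
+
+    ...
+
+    my_alloc_stats stats = {0};
+    allocationCallbacks.pUserData = &stats;
+    allocationCallbacks.onMalloc  = my_counting_malloc;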
+
+Passing in null for the allocation callbacks object will cause dr_wav to use the defaults, DRWAV_MALLOC,
+DRWAV_REALLOC and DRWAV_FREE, which is equivalent to how it worked in previous versions.
+
+Every API that opens a drwav object now takes this extra parameter. These include the following:
+
+    drwav_init()
+    drwav_init_ex()
+    drwav_init_file()
+    drwav_init_file_ex()
+    drwav_init_file_w()
+    drwav_init_file_w_ex()
+    drwav_init_memory()
+    drwav_init_memory_ex()
+    drwav_init_write()
+    drwav_init_write_sequential()
+    drwav_init_write_sequential_pcm_frames()
+    drwav_init_file_write()
+    drwav_init_file_write_sequential()
+    drwav_init_file_write_sequential_pcm_frames()
+    drwav_init_file_write_w()
+    drwav_init_file_write_sequential_w()
+    drwav_init_file_write_sequential_pcm_frames_w()
+    drwav_init_memory_write()
+    drwav_init_memory_write_sequential()
+    drwav_init_memory_write_sequential_pcm_frames()
+    drwav_open_and_read_pcm_frames_s16()
+    drwav_open_and_read_pcm_frames_f32()
+    drwav_open_and_read_pcm_frames_s32()
+    drwav_open_file_and_read_pcm_frames_s16()
+    drwav_open_file_and_read_pcm_frames_f32()
+    drwav_open_file_and_read_pcm_frames_s32()
+    drwav_open_file_and_read_pcm_frames_s16_w()
+    drwav_open_file_and_read_pcm_frames_f32_w()
+    drwav_open_file_and_read_pcm_frames_s32_w()
+    drwav_open_memory_and_read_pcm_frames_s16()
+    drwav_open_memory_and_read_pcm_frames_f32()
+    drwav_open_memory_and_read_pcm_frames_s32()
+
+Endian Improvements
+-------------------
+Previously, the following APIs returned little-endian audio data. These now return native-endian data. This improves compatibility
+on big-endian architectures.
+
+    drwav_read_pcm_frames()
+    drwav_read_pcm_frames_s16()
+    drwav_read_pcm_frames_s32()
+    drwav_read_pcm_frames_f32()
+    drwav_open_and_read_pcm_frames_s16()
+    drwav_open_and_read_pcm_frames_s32()
+    drwav_open_and_read_pcm_frames_f32()
+    drwav_open_file_and_read_pcm_frames_s16()
+    drwav_open_file_and_read_pcm_frames_s32()
+    drwav_open_file_and_read_pcm_frames_f32()
+    drwav_open_file_and_read_pcm_frames_s16_w()
+    drwav_open_file_and_read_pcm_frames_s32_w()
+    drwav_open_file_and_read_pcm_frames_f32_w()
+    drwav_open_memory_and_read_pcm_frames_s16()
+    drwav_open_memory_and_read_pcm_frames_s32()
+    drwav_open_memory_and_read_pcm_frames_f32()
+
+APIs have been added to give you explicit control over whether or not audio data is read or written in big- or little-endian byte
+order:
+
+    drwav_read_pcm_frames_le()
+    drwav_read_pcm_frames_be()
+    drwav_read_pcm_frames_s16le()
+    drwav_read_pcm_frames_s16be()
+    drwav_read_pcm_frames_f32le()
+    drwav_read_pcm_frames_f32be()
+    drwav_read_pcm_frames_s32le()
+    drwav_read_pcm_frames_s32be()
+    drwav_write_pcm_frames_le()
+    drwav_write_pcm_frames_be()
+
+Removed APIs
+------------
+The following APIs were deprecated in version 0.10.0 and have now been removed:
+
+    drwav_open()
+    drwav_open_ex()
+    drwav_open_write()
+    drwav_open_write_sequential()
+    drwav_open_file()
+    drwav_open_file_ex()
+    drwav_open_file_write()
+    drwav_open_file_write_sequential()
+    drwav_open_memory()
+    drwav_open_memory_ex()
+    drwav_open_memory_write()
+    drwav_open_memory_write_sequential()
+    drwav_close()
+
+
+
+RELEASE NOTES - v0.10.0
+=======================
+Version 0.10.0 has breaking API changes. There are no significant bug fixes in this release, so if you are affected by the
+breaking changes you do not need to upgrade.
+
+Removed APIs
+------------
+The following APIs were deprecated in version 0.9.0 and have been completely removed in version 0.10.0:
+
+    drwav_read()
+    drwav_read_s16()
+    drwav_read_f32()
+    drwav_read_s32()
+    drwav_seek_to_sample()
+    drwav_write()
+    drwav_open_and_read_s16()
+    drwav_open_and_read_f32()
+    drwav_open_and_read_s32()
+    drwav_open_file_and_read_s16()
+    drwav_open_file_and_read_f32()
+    drwav_open_file_and_read_s32()
+    drwav_open_memory_and_read_s16()
+    drwav_open_memory_and_read_f32()
+    drwav_open_memory_and_read_s32()
+    drwav::totalSampleCount
+
+See release notes for version 0.9.0 at the bottom of this file for replacement APIs.
+
+Deprecated APIs
+---------------
+The following APIs have been deprecated. There is a confusing and completely arbitrary difference between drwav_init*() and
+drwav_open*(), where drwav_init*() initializes a pre-allocated drwav object, whereas drwav_open*() will first allocate a
+drwav object on the heap and then initialize it. drwav_open*() has been deprecated, which means you must now use a
+pre-allocated drwav object with drwav_init*(). If you need the previous functionality, you can just do a malloc() followed
+by a call to one of the drwav_init*() APIs.
+
+    drwav_open()
+    drwav_open_ex()
+    drwav_open_write()
+    drwav_open_write_sequential()
+    drwav_open_file()
+    drwav_open_file_ex()
+    drwav_open_file_write()
+    drwav_open_file_write_sequential()
+    drwav_open_memory()
+    drwav_open_memory_ex()
+    drwav_open_memory_write()
+    drwav_open_memory_write_sequential()
+    drwav_close()
+
+These APIs will be removed completely in a future version. The rationale for this change is to remove confusion between the
+two different ways to initialize a drwav object.
+*/
+
+/*
+REVISION HISTORY
+================
+v0.12.16 - 2020-12-02
+  - Fix a bug when trying to read more bytes than can fit in a size_t.
+
+v0.12.15 - 2020-11-21
+  - Fix compilation with OpenWatcom.
+
+v0.12.14 - 2020-11-13
+  - Minor code clean up.
+
+v0.12.13 - 2020-11-01
+  - Improve compiler support for older versions of GCC.
+
+v0.12.12 - 2020-09-28
+  - Add support for RF64.
+  - Fix a bug in writing mode where the size of the RIFF chunk incorrectly includes the header section.
+
+v0.12.11 - 2020-09-08
+  - Fix a compilation error on older compilers.
+
+v0.12.10 - 2020-08-24
+  - Fix a bug when seeking with ADPCM formats.
+
+v0.12.9 - 2020-08-02
+  - Simplify sized types.
+
+v0.12.8 - 2020-07-25
+  - Fix a compilation warning.
+
+v0.12.7 - 2020-07-15
+  - Fix some bugs on big-endian architectures.
+  - Fix an error in s24 to f32 conversion.
+
+v0.12.6 - 2020-06-23
+  - Change drwav_read_*() to allow NULL to be passed in as the output buffer which is equivalent to a forward seek.
+  - Fix a buffer overflow when trying to decode invalid IMA-ADPCM files.
+  - Add include guard for the implementation section.
+
+v0.12.5 - 2020-05-27
+  - Minor documentation fix.
+
+v0.12.4 - 2020-05-16
+  - Replace assert() with DRWAV_ASSERT().
+  - Add compile-time and run-time version querying.
+    - DRWAV_VERSION_MINOR
+    - DRWAV_VERSION_MAJOR
+    - DRWAV_VERSION_REVISION
+    - DRWAV_VERSION_STRING
+    - drwav_version()
+    - drwav_version_string()
+
+v0.12.3 - 2020-04-30
+  - Fix compilation errors with VC6.
+
+v0.12.2 - 2020-04-21
+  - Fix a bug where drwav_init_file() does not close the file handle after attempting to load an erroneous file.
+
+v0.12.1 - 2020-04-13
+  - Fix some pedantic warnings.
+
+v0.12.0 - 2020-04-04
+  - API CHANGE: Add container and format parameters to the chunk callback.
+  - Minor documentation updates.
+
+v0.11.5 - 2020-03-07
+  - Fix compilation error with Visual Studio .NET 2003.
+
+v0.11.4 - 2020-01-29
+  - Fix some static analysis warnings.
+  - Fix a bug when reading f32 samples from an A-law encoded stream.
+
+v0.11.3 - 2020-01-12
+  - Minor changes to some f32 format conversion routines.
+  - Minor bug fix for ADPCM conversion when end of file is reached.
+
+v0.11.2 - 2019-12-02
+  - Fix a possible crash when using custom memory allocators without a custom realloc() implementation.
+  - Fix an integer overflow bug.
+  - Fix a null pointer dereference bug.
+  - Add limits to sample rate, channels and bits per sample to tighten up some validation.
+
+v0.11.1 - 2019-10-07
+  - Internal code clean up.
+
+v0.11.0 - 2019-10-06
+  - API CHANGE: Add support for user defined memory allocation routines. This system allows the application to specify its own memory allocation
+    routines with a user data pointer for client-specific contextual data. This adds an extra parameter to the end of the following APIs:
+    - drwav_init()
+    - drwav_init_ex()
+    - drwav_init_file()
+    - drwav_init_file_ex()
+    - drwav_init_file_w()
+    - drwav_init_file_w_ex()
+    - drwav_init_memory()
+    - drwav_init_memory_ex()
+    - drwav_init_write()
+    - drwav_init_write_sequential()
+    - drwav_init_write_sequential_pcm_frames()
+    - drwav_init_file_write()
+    - drwav_init_file_write_sequential()
+    - drwav_init_file_write_sequential_pcm_frames()
+    - drwav_init_file_write_w()
+    - drwav_init_file_write_sequential_w()
+    - drwav_init_file_write_sequential_pcm_frames_w()
+    - drwav_init_memory_write()
+    - drwav_init_memory_write_sequential()
+    - drwav_init_memory_write_sequential_pcm_frames()
+    - drwav_open_and_read_pcm_frames_s16()
+    - drwav_open_and_read_pcm_frames_f32()
+    - drwav_open_and_read_pcm_frames_s32()
+    - drwav_open_file_and_read_pcm_frames_s16()
+    - drwav_open_file_and_read_pcm_frames_f32()
+    - drwav_open_file_and_read_pcm_frames_s32()
+    - drwav_open_file_and_read_pcm_frames_s16_w()
+    - drwav_open_file_and_read_pcm_frames_f32_w()
+    - drwav_open_file_and_read_pcm_frames_s32_w()
+    - drwav_open_memory_and_read_pcm_frames_s16()
+    - drwav_open_memory_and_read_pcm_frames_f32()
+    - drwav_open_memory_and_read_pcm_frames_s32()
+    Set this extra parameter to NULL to use the defaults, which is the same as the previous behaviour. Setting it to NULL will use
+    DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE.
+  - Add support for reading and writing PCM frames in an explicit endianness. New APIs:
+    - drwav_read_pcm_frames_le()
+    - drwav_read_pcm_frames_be()
+    - drwav_read_pcm_frames_s16le()
+    - drwav_read_pcm_frames_s16be()
+    - drwav_read_pcm_frames_f32le()
+    - drwav_read_pcm_frames_f32be()
+    - drwav_read_pcm_frames_s32le()
+    - drwav_read_pcm_frames_s32be()
+    - drwav_write_pcm_frames_le()
+    - drwav_write_pcm_frames_be()
+  - Remove deprecated APIs.
+  - API CHANGE: The following APIs now return native-endian data. Previously they returned little-endian data.
+ - drwav_read_pcm_frames() + - drwav_read_pcm_frames_s16() + - drwav_read_pcm_frames_s32() + - drwav_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_s32() + - drwav_open_memory_and_read_pcm_frames_f32() + +v0.10.1 - 2019-08-31 + - Correctly handle partial trailing ADPCM blocks. + +v0.10.0 - 2019-08-04 + - Remove deprecated APIs. + - Add wchar_t variants for file loading APIs: + drwav_init_file_w() + drwav_init_file_ex_w() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + - Add drwav_target_write_size_bytes() which calculates the total size in bytes of a WAV file given a format and sample count. + - Add APIs for specifying the PCM frame count instead of the sample count when opening in sequential write mode: + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write_sequential_pcm_frames() + - Deprecate drwav_open*() and drwav_close(): + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + - Minor documentation updates. + +v0.9.2 - 2019-05-21 + - Fix warnings. + +v0.9.1 - 2019-05-05 + - Add support for C89. + - Change license to choice of public domain or MIT-0. + +v0.9.0 - 2018-12-16 + - API CHANGE: Add new reading APIs for reading by PCM frames instead of samples. Old APIs have been deprecated and + will be removed in v0.10.0. Deprecated APIs and their replacements: + drwav_read() -> drwav_read_pcm_frames() + drwav_read_s16() -> drwav_read_pcm_frames_s16() + drwav_read_f32() -> drwav_read_pcm_frames_f32() + drwav_read_s32() -> drwav_read_pcm_frames_s32() + drwav_seek_to_sample() -> drwav_seek_to_pcm_frame() + drwav_write() -> drwav_write_pcm_frames() + drwav_open_and_read_s16() -> drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_f32() -> drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_s32() -> drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_s16() -> drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_f32() -> drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_s32() -> drwav_open_file_and_read_pcm_frames_s32() + drwav_open_memory_and_read_s16() -> drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_f32() -> drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_s32() -> drwav_open_memory_and_read_pcm_frames_s32() + drwav::totalSampleCount -> drwav::totalPCMFrameCount + - API CHANGE: Rename drwav_open_and_read_file_*() to drwav_open_file_and_read_*(). + - API CHANGE: Rename drwav_open_and_read_memory_*() to drwav_open_memory_and_read_*(). + - Add built-in support for smpl chunks. + - Add support for firing a callback for each chunk in the file at initialization time. + - This is enabled through the drwav_init_ex(), etc. 
family of APIs. + - Handle invalid FMT chunks more robustly. + +v0.8.5 - 2018-09-11 + - Const correctness. + - Fix a potential stack overflow. + +v0.8.4 - 2018-08-07 + - Improve 64-bit detection. + +v0.8.3 - 2018-08-05 + - Fix C++ build on older versions of GCC. + +v0.8.2 - 2018-08-02 + - Fix some big-endian bugs. + +v0.8.1 - 2018-06-29 + - Add support for sequential writing APIs. + - Disable seeking in write mode. + - Fix bugs with Wave64. + - Fix typos. + +v0.8 - 2018-04-27 + - Bug fix. + - Start using major.minor.revision versioning. + +v0.7f - 2018-02-05 + - Restrict ADPCM formats to a maximum of 2 channels. + +v0.7e - 2018-02-02 + - Fix a crash. + +v0.7d - 2018-02-01 + - Fix a crash. + +v0.7c - 2018-02-01 + - Set drwav.bytesPerSample to 0 for all compressed formats. + - Fix a crash when reading 16-bit floating point WAV files. In this case dr_wav will output silence for + all format conversion reading APIs (*_s16, *_s32, *_f32 APIs). + - Fix some divide-by-zero errors. + +v0.7b - 2018-01-22 + - Fix errors with seeking of compressed formats. + - Fix compilation error when DR_WAV_NO_CONVERSION_API + +v0.7a - 2017-11-17 + - Fix some GCC warnings. + +v0.7 - 2017-11-04 + - Add writing APIs. + +v0.6 - 2017-08-16 + - API CHANGE: Rename dr_* types to drwav_*. + - Add support for custom implementations of malloc(), realloc(), etc. + - Add support for Microsoft ADPCM. + - Add support for IMA ADPCM (DVI, format code 0x11). + - Optimizations to drwav_read_s16(). + - Bug fixes. + +v0.5g - 2017-07-16 + - Change underlying type for booleans to unsigned. + +v0.5f - 2017-04-04 + - Fix a minor bug with drwav_open_and_read_s16() and family. + +v0.5e - 2016-12-29 + - Added support for reading samples as signed 16-bit integers. Use the _s16() family of APIs for this. + - Minor fixes to documentation. + +v0.5d - 2016-12-28 + - Use drwav_int* and drwav_uint* sized types to improve compiler support. + +v0.5c - 2016-11-11 + - Properly handle JUNK chunks that come before the FMT chunk. + +v0.5b - 2016-10-23 + - A minor change to drwav_bool8 and drwav_bool32 types. + +v0.5a - 2016-10-11 + - Fixed a bug with drwav_open_and_read() and family due to incorrect argument ordering. + - Improve A-law and mu-law efficiency. + +v0.5 - 2016-09-29 + - API CHANGE. Swap the order of "channels" and "sampleRate" parameters in drwav_open_and_read*(). Rationale for this is to + keep it consistent with dr_audio and dr_flac. + +v0.4b - 2016-09-18 + - Fixed a typo in documentation. + +v0.4a - 2016-09-18 + - Fixed a typo. + - Change date format to ISO 8601 (YYYY-MM-DD) + +v0.4 - 2016-07-13 + - API CHANGE. Make onSeek consistent with dr_flac. + - API CHANGE. Rename drwav_seek() to drwav_seek_to_sample() for clarity and consistency with dr_flac. + - Added support for Sony Wave64. + +v0.3a - 2016-05-28 + - API CHANGE. Return drwav_bool32 instead of int in onSeek callback. + - Fixed a memory leak. + +v0.3 - 2016-05-22 + - Lots of API changes for consistency. + +v0.2a - 2016-05-16 + - Fixed Linux/GCC build. + +v0.2 - 2016-05-11 + - Added support for reading data as signed 32-bit PCM for consistency with dr_flac. + +v0.1a - 2016-05-07 + - Fixed a bug in drwav_open_file() where the file handle would not be closed if the loader failed to initialize. + +v0.1 - 2016-05-04 + - Initial versioned release. +*/ + +/* +This software is available as a choice of the following licenses. Choose +whichever you prefer. 
+
+===============================================================================
+ALTERNATIVE 1 - Public Domain (www.unlicense.org)
+===============================================================================
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
+software, either in source code form or as a compiled binary, for any purpose,
+commercial or non-commercial, and by any means.
+
+In jurisdictions that recognize copyright laws, the author or authors of this
+software dedicate any and all copyright interest in the software to the public
+domain. We make this dedication for the benefit of the public at large and to
+the detriment of our heirs and successors. We intend this dedication to be an
+overt act of relinquishment in perpetuity of all present and future rights to
+this software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <http://unlicense.org/>
+
+===============================================================================
+ALTERNATIVE 2 - MIT No Attribution
+===============================================================================
+Copyright 2020 David Reid
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/ diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 67b3d2774..ac54dae9b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -12,44 +12,45 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) if (EMSCRIPTEN) else() - add_subdirectory(cvector-generator) - add_subdirectory(baby-llama) - add_subdirectory(batched-bench) - add_subdirectory(batched) - add_subdirectory(benchmark) - add_subdirectory(convert-llama2c-to-ggml) - add_subdirectory(embedding) - add_subdirectory(eval-callback) - add_subdirectory(export-lora) - add_subdirectory(gbnf-validator) - add_subdirectory(gguf-hash) - add_subdirectory(gguf-split) - add_subdirectory(gguf) - add_subdirectory(gritlm) - add_subdirectory(imatrix) - add_subdirectory(infill) - add_subdirectory(llama-bench) + # add_subdirectory(cvector-generator) + # add_subdirectory(baby-llama) + # add_subdirectory(batched-bench) + # add_subdirectory(batched) + # add_subdirectory(benchmark) + # add_subdirectory(convert-llama2c-to-ggml) + # add_subdirectory(embedding) + # add_subdirectory(eval-callback) + # add_subdirectory(export-lora) + # add_subdirectory(gbnf-validator) + # add_subdirectory(gguf-hash) + # add_subdirectory(gguf-split) + # add_subdirectory(gguf) + # add_subdirectory(gritlm) + # add_subdirectory(imatrix) + # add_subdirectory(infill) + # add_subdirectory(llama-bench) add_subdirectory(llava) - add_subdirectory(lookahead) - add_subdirectory(lookup) - add_subdirectory(main) - add_subdirectory(parallel) - add_subdirectory(passkey) - add_subdirectory(perplexity) - add_subdirectory(quantize-stats) - add_subdirectory(quantize) - add_subdirectory(retrieval) - if (GGML_RPC) - add_subdirectory(rpc) - endif() - if (LLAMA_BUILD_SERVER) - add_subdirectory(server) - endif() - if (GGML_SYCL) - add_subdirectory(sycl) - endif() - add_subdirectory(save-load-state) - add_subdirectory(simple) - add_subdirectory(speculative) - add_subdirectory(tokenize) + # add_subdirectory(lookahead) + # add_subdirectory(lookup) + # add_subdirectory(main) + # add_subdirectory(parallel) + # add_subdirectory(passkey) + # add_subdirectory(perplexity) + # add_subdirectory(quantize-stats) + # add_subdirectory(quantize) + # add_subdirectory(retrieval) + # if (GGML_RPC) + # add_subdirectory(rpc) + # endif() + # if (LLAMA_BUILD_SERVER) + # add_subdirectory(server) + # endif() + # if (GGML_SYCL) + # add_subdirectory(sycl) + # endif() + # add_subdirectory(save-load-state) + # add_subdirectory(simple) + # add_subdirectory(speculative) + # add_subdirectory(tokenize) + add_subdirectory(nexa-omni-audio) endif() diff --git a/examples/nexa-omni-audio/CMakeLists.txt b/examples/nexa-omni-audio/CMakeLists.txt new file mode 100644 index 000000000..4e4fce175 --- /dev/null +++ b/examples/nexa-omni-audio/CMakeLists.txt @@ -0,0 +1,56 @@ +# whisper + +# Find the Threads package +find_package(Threads REQUIRED) + +# build nexa-whisper-utils +set(WHISPER_LIB nexa-whisper-utils) +add_library(${WHISPER_LIB} OBJECT + whisper.cpp + ) +target_link_libraries(${WHISPER_LIB} PRIVATE ggml_llama common Threads::Threads) + +# build the whisper encoder +# add_executable(whisper-encode main-encode.cpp) +# target_link_libraries(whisper-encode PRIVATE ggml_llama common Threads::Threads ${WHISPER_LIB}) + +# build the audio projector +# add_executable(audio-projector-cli audio-projector-cli.cpp audio-projector.cpp) +# target_link_libraries(audio-projector-cli PRIVATE ggml_llama common) + +# add nexa-omni-audio-lib library +set(OMNI_AUDIO_LIB nexa-omni-audio-lib) +add_library(${OMNI_AUDIO_LIB} OBJECT 
+    omni.cpp
+    omni.h
+    audio-projector.cpp
+    audio-projector.h
+    )
+target_link_libraries(${OMNI_AUDIO_LIB} PRIVATE ggml_llama common ${WHISPER_LIB})
+
+# build the nexa-omni-cli
+add_executable(nexa-omni-cli omni-cli.cpp)
+target_link_libraries(nexa-omni-cli PRIVATE ggml_llama common Threads::Threads ${WHISPER_LIB} ${OMNI_AUDIO_LIB})
+
+# If BUILD_SHARED_LIBS is ON, also build a shared library
+if(BUILD_SHARED_LIBS)
+    message(STATUS "Building shared libraries")
+
+    set_target_properties(${WHISPER_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    set_target_properties(${OMNI_AUDIO_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+    add_library(${OMNI_AUDIO_LIB}_shared SHARED $<TARGET_OBJECTS:${OMNI_AUDIO_LIB}>)
+    target_link_libraries(${OMNI_AUDIO_LIB}_shared PRIVATE ggml_llama common ${WHISPER_LIB})
+    set_target_properties(${OMNI_AUDIO_LIB}_shared PROPERTIES
+        PUBLIC_HEADER omni.h
+        POSITION_INDEPENDENT_CODE ON
+    )
+
+    # Add OMNI_AUDIO_SHARED definition when building the shared library
+    target_compile_definitions(${OMNI_AUDIO_LIB}_shared PRIVATE OMNI_AUDIO_SHARED WHISPER_SHARED)
+
+    # Ensure all symbols are exported on Windows
+    if(MSVC)
+        set_target_properties(${OMNI_AUDIO_LIB}_shared PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
+    endif()
+endif()
\ No newline at end of file
diff --git a/examples/nexa-omni-audio/README.md b/examples/nexa-omni-audio/README.md
new file mode 100644
index 000000000..2f319ac65
--- /dev/null
+++ b/examples/nexa-omni-audio/README.md
@@ -0,0 +1,63 @@
+# nexa-omni-audio
+
+## Build this example
+
+Build the whole project from the repo root dir.
+
+For a CPU-only build:
+
+```shell
+cmake -B build-cpu
+cmake --build build-cpu --config Release -j 24
+```
+
+For a CUDA build:
+
+```shell
+cmake -B build-cuda -DGGML_CUDA=ON
+cmake --build build-cuda --config Release -j 24
+```
+
+## Prepare GGUF files
+
+### mmproj (mel-filter + whisper-encoder + audio-projector)
+
+1. Extract the whisper-medium encoder (`audio_tower`) from `nexa-collaboration/nano-omini-instruct`
+
+   ```shell
+   python convert-to-gguf-f16.py
+   ```
+
+2. Extract the mel filters from `ggml-medium.bin` and add them to the previously extracted `nano-omni-audio-encoder-f16.gguf`
+
+   ```shell
+   python add-mel-filters.py
+   ```
+
+   > Run `bash download-ggml-model.sh` to download `ggml-medium.bin` and move it into the `models` folder.
+   > Don't forget to modify the input and output file paths in the Python script above before running it.
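+
+The mel filters sit near the top of the whisper ggml checkpoint (magic, hyperparameters, then the filterbank), so a script like `add-mel-filters.py` can read them with plain `struct` unpacking before writing them into the GGUF. Below is a minimal sketch of that read step only, assuming the standard whisper.cpp checkpoint layout (a `uint32` magic, 11 `int32` hyperparameters, then `n_mel`, `n_fft`, and `n_mel * n_fft` `float32` coefficients); the path and the printed shape are illustrative, not taken from the script:
+
+```python
+import struct
+import numpy as np
+
+# Sketch: read the mel filterbank out of a whisper.cpp ggml checkpoint.
+with open("models/ggml-medium.bin", "rb") as f:
+    (magic,) = struct.unpack("<I", f.read(4))
+    assert magic == 0x67676D6C, "not a ggml checkpoint"
+    hparams = struct.unpack("<11i", f.read(11 * 4))  # n_vocab ... n_mels, ftype
+    n_mel, n_fft = struct.unpack("<2i", f.read(2 * 4))
+    filters = np.frombuffer(f.read(n_mel * n_fft * 4), dtype=np.float32)
+    filters = filters.reshape(n_mel, n_fft)  # (80, 201) for the medium model
+
+print("mel filters:", filters.shape)
+```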
+
+### gemma2
+
+```shell
+python prepare-gemma2-hf.py
+python ../../convert_hf_to_gguf.py ./models/nano-omini-instruct-gemma2 --outfile ./models/nano-omni-instruct.gemma2.gguf [--outtype bf16]
+```
+
+## Run nexa-omni-audio
+
+From the root directory of the repo, run the commands below.
+
+With the CPU build:
+
+```shell
+./build-cpu/bin/nexa-omni-cli \
+    --model examples/nexa-omni-audio/models/nano-omni-instruct.gemma2.gguf \
+    --mmproj examples/nexa-omni-audio/models/nano-omni-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
+    --file examples/nexa-omni-audio/samples/jfk.wav \
+    --prompt "this conversation talks about"
+```
+
+With a CUDA build, offloading the language model to the GPU:
+
+```shell
+./build/bin/nexa-omni-cli \
+    --model /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/gemma2-2b.gguf \
+    --mmproj /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/nano-omni-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
+    --file /home/azureuser/zack/ggml-project-apollo/examples/whisper/samples/jfk.wav \
+    --prompt "this conversation talks about" \
+    --n-gpu-layers 27 # offload all 27 layers of gemma2 model to GPU
+```
diff --git a/examples/nexa-omni-audio/audio-projector.cpp b/examples/nexa-omni-audio/audio-projector.cpp
new file mode 100644
index 000000000..87d4136c2
--- /dev/null
+++ b/examples/nexa-omni-audio/audio-projector.cpp
@@ -0,0 +1,37 @@
+#include "audio-projector.h"
+#include "common-nexa.h"
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#include <vector>
+
+struct ggml_tensor *audio_projector_inference(audio_projector &model, std::vector<float> &audio_feature_data)
+{
+    // Build the computation graph for inference
+    struct ggml_cgraph *gf = model.build_graph();
+    // Allocate the graph tensors
+    ggml_gallocr_alloc_graph(model.compute_alloc, gf);
+
+    // Set the input data
+    struct ggml_tensor *input = ggml_graph_get_tensor(gf, "input");
+    ggml_backend_tensor_set(input, audio_feature_data.data(), 0, audio_feature_data.size() * sizeof(float));
+
+    model.set_n_threads(0);
+
+    // Execute the graph on the backend
+    ggml_backend_graph_compute(model.backend, gf);
+
+    // Return the output tensor (last node in the graph)
+    return ggml_graph_get_tensor(gf, "output");
+}
+
+struct ggml_tensor *audio_projector_inference(audio_projector &model, struct ggml_tensor *audio_feature_tensor)
+{
+    // Set the input data
+    std::vector<float> data(ggml_nelements(audio_feature_tensor));
+    ggml_backend_tensor_get(audio_feature_tensor, data.data(), 0, ggml_nbytes(audio_feature_tensor));
+
+    return audio_projector_inference(model, data);
+}
diff --git a/examples/nexa-omni-audio/audio-projector.h b/examples/nexa-omni-audio/audio-projector.h
new file mode 100644
index 000000000..d7252b756
--- /dev/null
+++ b/examples/nexa-omni-audio/audio-projector.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "ggml.h"
+#include "common-nexa.h"
+
+#include <vector>
+
+//
+// Audio Projector
+//
+
+struct audio_projector : public NexaBaseModel
+{
+
+    audio_projector() : NexaBaseModel()
+    {
+        this->hparam_names = {
+            "max_source_positions",
+            "d_model",
+        };
+        this->tensor_names = {
+            "multi_modal_projector.linear.weight",
+            "multi_modal_projector.linear.bias",
+        };
+    }
+
+    struct ggml_cgraph *build_graph() override
+    {
+        const int MAX_NODES = 64;
+        size_t buf_size = ggml_tensor_overhead() * MAX_NODES + ggml_graph_overhead_custom(MAX_NODES, false);
+        static std::vector<uint8_t> buf(buf_size);
+
+        // Create temporary GGML context for building the graph
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_size,
+            /*.mem_buffer =*/ buf.data(),
+            /*.no_alloc   =*/ true, // Memory will be allocated later
+        };
+        struct ggml_context *ctx0 = ggml_init(params);
+        struct ggml_cgraph *gf = ggml_new_graph_custom(ctx0, MAX_NODES, false); // Create new graph
+
+        // Create input tensor
+        struct ggml_tensor *input = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
+                                                       std::get<int32_t>(hparams["d_model"]),
+                                                       std::get<int32_t>(hparams["max_source_positions"]) / 2);
+        ggml_set_name(input, "input");
+        ggml_set_input(input); // Mark tensor as input
+
+        // weight * input + bias
+        struct ggml_tensor *cur = ggml_mul_mat(ctx0, tensors["multi_modal_projector.linear.weight"], input);
+        cur = ggml_add(ctx0, cur, tensors["multi_modal_projector.linear.bias"]);
+
+        // Set the final output
+        ggml_set_name(cur, "output");
+        ggml_set_output(cur);
+
+        ggml_build_forward_expand(gf, cur); // Expand graph with operations
+
+        ggml_free(ctx0); // Free temporary context
+
+        return gf;
+    }
+};
+
+struct ggml_tensor *audio_projector_inference(audio_projector &model, std::vector<float> &audio_feature_data);
+
+struct ggml_tensor *audio_projector_inference(audio_projector &model, struct ggml_tensor *audio_feature_tensor);
\ No newline at end of file
diff --git a/examples/nexa-omni-audio/ggml-cpu-impl.h b/examples/nexa-omni-audio/ggml-cpu-impl.h
new file mode 100644
index 000000000..5b45155b0
--- /dev/null
+++ b/examples/nexa-omni-audio/ggml-cpu-impl.h
@@ -0,0 +1,614 @@
+#pragma once
+
+// GGML CPU internal header
+
+#include "ggml.h"
+#include "ggml-impl.h"
+#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
+//#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#if defined(_MSC_VER)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This is binary identical with Google Brain float conversion.
+ * Floats shall round to nearest even, and NANs shall be quiet.
+ * Subnormals aren't flushed to zero, except perhaps when used.
+ * This code should vectorize nicely if using modern compilers.
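+ *
+ * For example, 3.14159274f has bits 0x40490FDB; bit 16 is set, so the
+ * round-to-nearest-even step below computes
+ * (0x40490FDB + 0x7fff + 1) >> 16 = 0x4049, and 0x4049 decodes back
+ * (0x4049 << 16 = 0x40490000) to 3.140625f.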
+ */ +static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { + ggml_bf16_t h; + union { + float f; + uint32_t i; + } u; + u.f = s; + if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ + h.bits = (u.i >> 16) | 64; /* force to quiet */ + return h; + } + h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; + return h; +} + +#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) +#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) + +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#endif + +// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available +#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __SSE3__ +#define __SSE3__ +#endif +#ifndef __SSSE3__ +#define __SSSE3__ +#endif +#endif + +#if defined(__ARM_FEATURE_SVE) +#include +#include +#endif + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#if defined(__ARM_NEON) + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#ifdef _MSC_VER + +typedef uint16_t ggml_fp16_internal_t; + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + +typedef __fp16 ggml_fp16_internal_t; + +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif // _MSC_VER + +#if !defined(__aarch64__) + +// 32-bit ARM compatibility + +// vaddlvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddlvq_s16(int16x8_t v) { + int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v))); + return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; 
+ res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif // !defined(__aarch64__) + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define 
ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + +#endif // !defined(__ARM_FEATURE_DOTPROD) + +#endif // defined(__ARM_NEON) + +#if defined(__ARM_NEON) && !defined(_MSC_VER) + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + ggml_fp16_internal_t tmp; + memcpy(&tmp, &h, sizeof(ggml_fp16_t)); + return (float)tmp; +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + ggml_fp16_internal_t tmp = f; + memcpy(&res, &tmp, sizeof(ggml_fp16_t)); + return res; +} + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#ifdef __riscv_v_intrinsic +#include +#endif + +#if defined(__loongarch64) +#if defined(__loongarch_asx) +#include +#endif +#if defined(__loongarch_sx) +#include +#endif +#endif + +#if defined(__loongarch_asx) + +typedef union { + int32_t i; + float f; +} ft_union; + +/* float type data load instructions */ +static __m128 __lsx_vreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); +} + +static __m256 __lasx_xvreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); +} +#endif + +#ifdef __F16C__ + +#ifdef _MSC_VER +#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) +#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) +#else +#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) +#endif + +#elif defined(__POWER9_VECTOR__) + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) +/* the inline asm below is about 12% faster than the lookup method */ +#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + register double d; + register ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; +} + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = 
UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) + +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE + +// precomputed f32 table for f16 (256 KB) +// defined in ggml.c, initialized in ggml_init() +extern float ggml_table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, +// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. 
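+// For example, on the lookup path below, GGML_FP16_TO_FP32(0x3C00) (the bit
+// pattern of 1.0f in IEEE binary16) returns ggml_table_f32_f16[0x3C00], which
+// ggml_init() precomputed to 1.0f.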
+#if !defined(GGML_FP16_TO_FP32) +inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return ggml_table_f32_f16[s]; +} + +#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) +#endif + +#if !defined(GGML_FP32_TO_FP16) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) +#endif + +#ifdef __cplusplus +} +#endif diff --git a/examples/nexa-omni-audio/main-encode.cpp b/examples/nexa-omni-audio/main-encode.cpp new file mode 100644 index 000000000..e9a28cf3e --- /dev/null +++ b/examples/nexa-omni-audio/main-encode.cpp @@ -0,0 +1,916 @@ +#include "common.h" +#include "common-nexa.h" + +#include "whisper.h" +#include "grammar-parser.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable : 4244 4267) // possible loss of data +#endif + +// helper function to replace substrings +static void replace_all(std::string &s, const std::string &search, const std::string &replace) +{ + for (size_t pos = 0;; pos += replace.length()) + { + pos = s.find(search, pos); + if (pos == std::string::npos) + break; + s.erase(pos, search.length()); + s.insert(pos, replace); + } +} + +// command-line parameters +struct whisper_params +{ + int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t progress_step = 5; + int32_t max_context = -1; + int32_t max_len = 0; + int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of; + int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; + int32_t audio_ctx = 0; + + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; + float grammar_penalty = 100.0f; + float temperature = 0.0f; + float temperature_inc = 0.2f; + + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool no_prints = false; + bool print_special = false; + bool print_colors = false; + bool print_progress = false; + bool no_timestamps = false; + bool log_score = false; + bool use_gpu = true; + bool flash_attn = false; + + std::string language = "en"; + std::string prompt; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + std::string grammar; + std::string grammar_rule; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line + + // A regular expression that matches tokens to suppress + std::string suppress_regex; + + std::string openvino_encode_device = "CPU"; + + std::string dtw = ""; + + std::vector fname_inp = {}; + std::vector fname_out = {}; + + grammar_parser::parse_state grammar_parsed; +}; + +static void whisper_print_usage(int argc, char **argv, const whisper_params ¶ms); + +static char *whisper_param_turn_lowercase(char *in) +{ + int string_len = strlen(in); + for (int i = 0; i < string_len; i++) + { + *(in + i) = tolower((unsigned char)*(in + i)); + } + return in; +} + +static bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms) +{ + for (int i = 1; i < argc; i++) + { + std::string arg = argv[i]; + + if (arg == "-") + { + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') + { + 
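+            // a bare argument with no leading '-' is treated as an input file path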
params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") + { + whisper_print_usage(argc, argv, params); + exit(0); + } + else if (arg == "-t" || arg == "--threads") + { + params.n_threads = std::stoi(argv[++i]); + } + else if (arg == "-p" || arg == "--processors") + { + params.n_processors = std::stoi(argv[++i]); + } + else if (arg == "-ot" || arg == "--offset-t") + { + params.offset_t_ms = std::stoi(argv[++i]); + } + else if (arg == "-on" || arg == "--offset-n") + { + params.offset_n = std::stoi(argv[++i]); + } + else if (arg == "-d" || arg == "--duration") + { + params.duration_ms = std::stoi(argv[++i]); + } + else if (arg == "-mc" || arg == "--max-context") + { + params.max_context = std::stoi(argv[++i]); + } + else if (arg == "-ml" || arg == "--max-len") + { + params.max_len = std::stoi(argv[++i]); + } + else if (arg == "-bo" || arg == "--best-of") + { + params.best_of = std::stoi(argv[++i]); + } + else if (arg == "-bs" || arg == "--beam-size") + { + params.beam_size = std::stoi(argv[++i]); + } + else if (arg == "-ac" || arg == "--audio-ctx") + { + params.audio_ctx = std::stoi(argv[++i]); + } + else if (arg == "-wt" || arg == "--word-thold") + { + params.word_thold = std::stof(argv[++i]); + } + else if (arg == "-et" || arg == "--entropy-thold") + { + params.entropy_thold = std::stof(argv[++i]); + } + else if (arg == "-lpt" || arg == "--logprob-thold") + { + params.logprob_thold = std::stof(argv[++i]); + } + else if (arg == "-tp" || arg == "--temperature") + { + params.temperature = std::stof(argv[++i]); + } + else if (arg == "-tpi" || arg == "--temperature-inc") + { + params.temperature_inc = std::stof(argv[++i]); + } + else if (arg == "-debug" || arg == "--debug-mode") + { + params.debug_mode = true; + } + else if (arg == "-tr" || arg == "--translate") + { + params.translate = true; + } + else if (arg == "-di" || arg == "--diarize") + { + params.diarize = true; + } + else if (arg == "-tdrz" || arg == "--tinydiarize") + { + params.tinydiarize = true; + } + else if (arg == "-sow" || arg == "--split-on-word") + { + params.split_on_word = true; + } + else if (arg == "-nf" || arg == "--no-fallback") + { + params.no_fallback = true; + } + else if (arg == "-fp" || arg == "--font-path") + { + params.font_path = argv[++i]; + } + else if (arg == "-np" || arg == "--no-prints") + { + params.no_prints = true; + } + else if (arg == "-ps" || arg == "--print-special") + { + params.print_special = true; + } + else if (arg == "-pc" || arg == "--print-colors") + { + params.print_colors = true; + } + else if (arg == "-pp" || arg == "--print-progress") + { + params.print_progress = true; + } + else if (arg == "-nt" || arg == "--no-timestamps") + { + params.no_timestamps = true; + } + else if (arg == "-l" || arg == "--language") + { + params.language = whisper_param_turn_lowercase(argv[++i]); + } + else if (arg == "-dl" || arg == "--detect-language") + { + params.detect_language = true; + } + else if (arg == "--prompt") + { + params.prompt = argv[++i]; + } + else if (arg == "-m" || arg == "--model") + { + params.model = argv[++i]; + } + else if (arg == "-f" || arg == "--file") + { + params.fname_inp.emplace_back(argv[++i]); + } + else if (arg == "-oved" || arg == "--ov-e-device") + { + params.openvino_encode_device = argv[++i]; + } + else if (arg == "-dtw" || arg == "--dtw") + { + params.dtw = argv[++i]; + } + else if (arg == "-ls" || arg == "--log-score") + { + params.log_score = true; + } + else if (arg == "-ng" || arg == "--no-gpu") + { + params.use_gpu = false; + } + 
else if (arg == "-fa" || arg == "--flash-attn") + { + params.flash_attn = true; + } + else if (arg == "--suppress-regex") + { + params.suppress_regex = argv[++i]; + } + else if (arg == "--grammar") + { + params.grammar = argv[++i]; + } + else if (arg == "--grammar-rule") + { + params.grammar_rule = argv[++i]; + } + else if (arg == "--grammar-penalty") + { + params.grammar_penalty = std::stof(argv[++i]); + } + else + { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +static void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms) +{ + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n", params.temperature_inc); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? 
"true" : "false"); + fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, "\n"); +} + +struct whisper_print_user_data +{ + const whisper_params *params; + + const std::vector> *pcmf32s; + int progress_prev; +}; + +static std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) +{ + std::string speaker = ""; + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples, WHISPER_SAMPLE_RATE); + const int64_t is1 = timestamp_to_sample(t1, n_samples, WHISPER_SAMPLE_RATE); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) + { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1 * energy1) + { + speaker = "0"; + } + else if (energy1 > 1.1 * energy0) + { + speaker = "1"; + } + else + { + speaker = "?"; + } + + // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str()); + + if (!id_only) + { + speaker.insert(0, "(speaker "); + speaker.append(")"); + } + + return speaker; +} + +static void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data) +{ + int progress_step = ((whisper_print_user_data *)user_data)->params->progress_step; + int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev); + if (progress >= *progress_prev + progress_step) + { + *progress_prev += progress_step; + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); + } +} + +static void whisper_print_segment_callback(struct whisper_context *ctx, struct whisper_state * /*state*/, int n_new, void *user_data) +{ + const auto ¶ms = *((whisper_print_user_data *)user_data)->params; + const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s; + + const int n_segments = whisper_full_n_segments(ctx); + + std::string speaker = ""; + + int64_t t0 = 0; + int64_t t1 = 0; + + // print the last n_new segments + const int s0 = n_segments - n_new; + + if (s0 == 0) + { + printf("\n"); + } + + for (int i = s0; i < n_segments; i++) + { + if (!params.no_timestamps || params.diarize) + { + t0 = whisper_full_get_segment_t0(ctx, i); + t1 = whisper_full_get_segment_t1(ctx, i); + } + + if (!params.no_timestamps) + { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + } + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + if (params.print_colors) + { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) + { + if (params.print_special == false) + { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) + { + continue; + } + } + + const char *text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p(ctx, i, j); + + const int col = std::max(0, std::min((int)k_colors.size() - 1, (int)(std::pow(p, 3) * float(k_colors.size())))); + + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); + } 
+ } + else + { + const char *text = whisper_full_get_segment_text(ctx, i); + + printf("%s%s", speaker.c_str(), text); + } + + if (params.tinydiarize) + { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) + { + printf("%s", params.tdrz_speaker_turn.c_str()); + } + } + + // with timestamps or speakers: each segment on new line + if (!params.no_timestamps || params.diarize) + { + printf("\n"); + } + + fflush(stdout); + } +} + +static char *escape_double_quotes_and_backslashes(const char *str) +{ + if (str == NULL) + { + return NULL; + } + + size_t escaped_length = strlen(str) + 1; + + for (size_t i = 0; str[i] != '\0'; i++) + { + if (str[i] == '"' || str[i] == '\\') + { + escaped_length++; + } + } + + char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed + if (escaped == NULL) + { + return NULL; + } + + size_t pos = 0; + for (size_t i = 0; str[i] != '\0'; i++) + { + if (str[i] == '"' || str[i] == '\\') + { + escaped[pos++] = '\\'; + } + escaped[pos++] = str[i]; + } + + // no need to set zero due to calloc() being used prior + + return escaped; +} + +// double quote should be escaped by another double quote. (rfc4180) +static char *escape_double_quotes_in_csv(const char *str) +{ + if (str == NULL) + { + return NULL; + } + + size_t escaped_length = strlen(str) + 1; + + for (size_t i = 0; str[i] != '\0'; i++) + { + if (str[i] == '"') + { + escaped_length++; + } + } + + char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed + if (escaped == NULL) + { + return NULL; + } + + size_t pos = 0; + for (size_t i = 0; str[i] != '\0'; i++) + { + if (str[i] == '"') + { + escaped[pos++] = '"'; + } + escaped[pos++] = str[i]; + } + + // no need to set zero due to calloc() being used prior + + return escaped; +} + +static void cb_log_disable(enum ggml_log_level, const char *, void *) {} + +int main(int argc, char **argv) +{ + whisper_params params; + + // If the only argument starts with "@", read arguments line-by-line + // from the given file. + std::vector vec_args; + if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@') + { + // Save the name of the executable. + vec_args.push_back(argv[0]); + + // Open the response file. + char const *rspfile = argv[1] + sizeof(char); + std::ifstream fin(rspfile); + if (fin.is_open() == false) + { + fprintf(stderr, "error: response file '%s' not found\n", rspfile); + return 1; + } + + // Read the entire response file. + std::string line; + while (std::getline(fin, line)) + { + vec_args.push_back(line); + } + + // Use the contents of the response file as the command-line arguments. 
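+        // e.g. running the tool as `main-encode @args.txt`, with one argument
+        // per line in args.txt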
+ argc = static_cast(vec_args.size()); + argv = static_cast(alloca(argc * sizeof(char *))); + for (int i = 0; i < argc; ++i) + { + argv[i] = const_cast(vec_args[i].c_str()); + } + } + + if (whisper_params_parse(argc, argv, params) == false) + { + whisper_print_usage(argc, argv, params); + return 1; + } + + // remove non-existent files + for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();) + { + const auto fname_inp = it->c_str(); + + if (*it != "-" && !is_file_exist(fname_inp)) + { + fprintf(stderr, "error: input file not found '%s'\n", fname_inp); + it = params.fname_inp.erase(it); + continue; + } + + it++; + } + + if (params.fname_inp.empty()) + { + fprintf(stderr, "error: no input files specified\n"); + whisper_print_usage(argc, argv, params); + return 2; + } + + if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) + { + fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + + if (params.diarize && params.tinydiarize) + { + fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n"); + whisper_print_usage(argc, argv, params); + exit(0); + } + + if (params.no_prints) + { + whisper_log_set(cb_log_disable, NULL); + } + + // whisper init + + struct whisper_context_params cparams = whisper_context_default_params(); + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + + if (!params.dtw.empty()) + { + cparams.dtw_token_timestamps = true; + cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; + + if (params.dtw == "tiny") + cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY; + if (params.dtw == "tiny.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN; + if (params.dtw == "base") + cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE; + if (params.dtw == "base.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN; + if (params.dtw == "small") + cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL; + if (params.dtw == "small.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN; + if (params.dtw == "medium") + cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM; + if (params.dtw == "medium.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN; + if (params.dtw == "large.v1") + cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1; + if (params.dtw == "large.v2") + cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2; + if (params.dtw == "large.v3") + cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; + + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) + { + fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); + return 3; + } + } + + struct whisper_context *ctx = whisper_encoder_init_from_file_with_params(params.model.c_str(), cparams); + + if (ctx == nullptr) + { + fprintf(stderr, "error: failed to initialize whisper context\n"); + return 3; + } + + // initialize openvino encoder. 
this has no effect on whisper.cpp builds that don't have OpenVINO configured + whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + + if (!params.grammar.empty()) + { + auto &grammar = params.grammar_parsed; + if (is_file_exist(params.grammar.c_str())) + { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } + else + { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + + // will be empty (default) if there are parse errors + if (grammar.rules.empty()) + { + fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str()); + return 4; + } + else + { + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, grammar); + fprintf(stderr, "\n"); + } + } + + for (int f = 0; f < (int)params.fname_inp.size(); ++f) + { + const auto fname_inp = params.fname_inp[f]; + const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; + + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) + { + fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); + continue; + } + + if (!whisper_is_multilingual(ctx)) // TODO: something off here + { + if (params.language != "en" || params.translate) + { + params.language = "en"; + params.translate = false; + fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + if (params.detect_language) + { + params.language = "auto"; + } + + if (!params.no_prints) + { + // print system information + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads * params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); + + // print some info about the processing + fprintf(stderr, "\n"); + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n", + __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size()) / WHISPER_SAMPLE_RATE, + params.n_threads, params.n_processors, params.beam_size, params.best_of, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.tinydiarize ? "tdrz = 1, " : "", + params.no_timestamps ? 0 : 1); + + fprintf(stderr, "\n"); + } + + // run the inference + { + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty()); + wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.detect_language = params.detect_language; + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? 
params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.max_len == 0 ? 60 : params.max_len; + wparams.split_on_word = params.split_on_word; + wparams.audio_ctx = params.audio_ctx; + + wparams.debug_mode = params.debug_mode; + + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + + wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); + + wparams.initial_prompt = params.prompt.c_str(); + + wparams.greedy.best_of = params.best_of; + wparams.beam_search.beam_size = params.beam_size; + + wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc; + wparams.temperature = params.temperature; + + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; + + wparams.no_timestamps = params.no_timestamps; + + whisper_print_user_data user_data = {¶ms, &pcmf32s, 0}; + + const auto &grammar_parsed = params.grammar_parsed; + auto grammar_rules = grammar_parsed.c_rules(); + + if (use_grammar) + { + if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) + { + fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str()); + } + else + { + wparams.grammar_rules = grammar_rules.data(); + wparams.n_grammar_rules = grammar_rules.size(); + wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule); + wparams.grammar_penalty = params.grammar_penalty; + } + } + + // this callback is called on each new segment + if (!wparams.print_realtime) + { + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = &user_data; + } + + if (wparams.print_progress) + { + wparams.progress_callback = whisper_print_progress_callback; + wparams.progress_callback_user_data = &user_data; + } + + // examples for abort mechanism + // in examples below, we do not abort the processing, but we could if the flag is set to true + + // the callback is called before every encoder run - if it returns false, the processing is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void *user_data) + { + bool is_aborted = *(bool *)user_data; + return !is_aborted; + }; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + + // the callback is called before every computation - if it returns true, the computation is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.abort_callback = [](void *user_data) + { + bool is_aborted = *(bool *)user_data; + return is_aborted; + }; + wparams.abort_callback_user_data = &is_aborted; + } + + if (whisper_encode_wo_cross_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { + fprintf(stderr, "%s: failed to process audio\n", argv[0]); + return 10; + } + + ggml_tensor *embd_enc = whisper_full_get_embd_enc(ctx); + // print_ggml_tensor("embd_enc", embd_enc, true); + print_ggml_tensor_shape("embd_enc", embd_enc); + } + + } + + if (!params.no_prints) + { + whisper_print_timings(ctx); + } + whisper_free(ctx); + + return 0; +} diff --git a/examples/nexa-omni-audio/omni-cli.cpp b/examples/nexa-omni-audio/omni-cli.cpp new file mode 
100644 index 000000000..5d24b5b5d --- /dev/null +++ b/examples/nexa-omni-audio/omni-cli.cpp @@ -0,0 +1,19 @@ +#include "omni.h" + +int main(int argc, char **argv) +{ + + omni_context_params ctx_params = omni_context_default_params(); + if (!omni_context_params_parse(argc, argv, ctx_params)) + { + return 1; + } + + omni_context *ctx_omni = omni_init_context(ctx_params); + + omni_process_full(ctx_omni, ctx_params); + + omni_free(ctx_omni); + + return 0; +} \ No newline at end of file diff --git a/examples/nexa-omni-audio/omni.cpp b/examples/nexa-omni-audio/omni.cpp new file mode 100644 index 000000000..d2701c2c1 --- /dev/null +++ b/examples/nexa-omni-audio/omni.cpp @@ -0,0 +1,864 @@ +#include "omni.h" +#include "audio-projector.h" +#include "common-nexa.h" + +#include "whisper.h" +#include "llama.h" +#include "common.h" +#include "log.h" +// #include "arg.h" +#include "sampling.h" +#include "llama-impl.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +// +// Constants +// + +static const char *AUDIO_TOKEN = "<|AUDIO|>"; + +// +// Whisper +// + +struct whisper_params +{ + int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t progress_step = 5; + int32_t max_context = -1; + int32_t max_len = 0; + int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of; + int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; + int32_t audio_ctx = 0; + + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; + float grammar_penalty = 100.0f; + float temperature = 0.0f; + float temperature_inc = 0.2f; + + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool no_prints = false; + bool print_special = false; + bool print_colors = false; + bool print_progress = false; + bool no_timestamps = false; + bool log_score = false; + bool use_gpu = true; + bool flash_attn = false; + + std::string language = "en"; + std::string prompt; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + std::string grammar; + std::string grammar_rule; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line + + // A regular expression that matches tokens to suppress + std::string suppress_regex; + + std::string openvino_encode_device = "CPU"; + + std::string dtw = ""; + + std::vector fname_inp = {}; + std::vector fname_out = {}; + + grammar_parser::parse_state grammar_parsed; +}; + +static void whisper_print_usage(int argc, char **argv, const whisper_params ¶ms); + +static char *whisper_param_turn_lowercase(char *in) +{ + int string_len = strlen(in); + for (int i = 0; i < string_len; i++) + { + *(in + i) = tolower((unsigned char)*(in + i)); + } + return in; +} + +static bool whisper_params_parse(int argc, char **argv, whisper_params ¶ms) +{ + for (int i = 1; i < argc; i++) + { + std::string arg = argv[i]; + + if (arg == "-") + { + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') + { + // params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") + { + whisper_print_usage(argc, argv, params); + exit(0); + } + else if (arg 
== "-t" || arg == "--threads") + { + params.n_threads = std::stoi(argv[++i]); + } + else if (arg == "-p" || arg == "--processors") + { + params.n_processors = std::stoi(argv[++i]); + } + else if (arg == "-ot" || arg == "--offset-t") + { + params.offset_t_ms = std::stoi(argv[++i]); + } + else if (arg == "-on" || arg == "--offset-n") + { + params.offset_n = std::stoi(argv[++i]); + } + else if (arg == "-d" || arg == "--duration") + { + params.duration_ms = std::stoi(argv[++i]); + } + else if (arg == "-mc" || arg == "--max-context") + { + params.max_context = std::stoi(argv[++i]); + } + else if (arg == "-ml" || arg == "--max-len") + { + params.max_len = std::stoi(argv[++i]); + } + else if (arg == "-bo" || arg == "--best-of") + { + params.best_of = std::stoi(argv[++i]); + } + else if (arg == "-bs" || arg == "--beam-size") + { + params.beam_size = std::stoi(argv[++i]); + } + else if (arg == "-ac" || arg == "--audio-ctx") + { + params.audio_ctx = std::stoi(argv[++i]); + } + else if (arg == "-wt" || arg == "--word-thold") + { + params.word_thold = std::stof(argv[++i]); + } + else if (arg == "-et" || arg == "--entropy-thold") + { + params.entropy_thold = std::stof(argv[++i]); + } + else if (arg == "-lpt" || arg == "--logprob-thold") + { + params.logprob_thold = std::stof(argv[++i]); + } + else if (arg == "-tp" || arg == "--temperature") + { + params.temperature = std::stof(argv[++i]); + } + else if (arg == "-tpi" || arg == "--temperature-inc") + { + params.temperature_inc = std::stof(argv[++i]); + } + else if (arg == "-debug" || arg == "--debug-mode") + { + params.debug_mode = true; + } + else if (arg == "-tr" || arg == "--translate") + { + params.translate = true; + } + else if (arg == "-di" || arg == "--diarize") + { + params.diarize = true; + } + else if (arg == "-tdrz" || arg == "--tinydiarize") + { + params.tinydiarize = true; + } + else if (arg == "-sow" || arg == "--split-on-word") + { + params.split_on_word = true; + } + else if (arg == "-nf" || arg == "--no-fallback") + { + params.no_fallback = true; + } + else if (arg == "-fp" || arg == "--font-path") + { + params.font_path = argv[++i]; + } + else if (arg == "-np" || arg == "--no-prints") + { + params.no_prints = true; + } + else if (arg == "-ps" || arg == "--print-special") + { + params.print_special = true; + } + else if (arg == "-pc" || arg == "--print-colors") + { + params.print_colors = true; + } + else if (arg == "-pp" || arg == "--print-progress") + { + params.print_progress = true; + } + else if (arg == "-nt" || arg == "--no-timestamps") + { + params.no_timestamps = true; + } + else if (arg == "-l" || arg == "--language") + { + params.language = whisper_param_turn_lowercase(argv[++i]); + } + else if (arg == "-dl" || arg == "--detect-language") + { + params.detect_language = true; + } + else if (arg == "--prompt") + { + params.prompt = argv[++i]; + } + else if (arg == "-m" || arg == "--model") + { + params.model = argv[++i]; + } + else if (arg == "-f" || arg == "--file") + { + params.fname_inp.emplace_back(argv[++i]); + } + else if (arg == "-oved" || arg == "--ov-e-device") + { + params.openvino_encode_device = argv[++i]; + } + else if (arg == "-dtw" || arg == "--dtw") + { + params.dtw = argv[++i]; + } + else if (arg == "-ls" || arg == "--log-score") + { + params.log_score = true; + } + else if (arg == "-ng" || arg == "--no-gpu") + { + params.use_gpu = false; + } + else if (arg == "-fa" || arg == "--flash-attn") + { + params.flash_attn = true; + } + else if (arg == "--suppress-regex") + { + params.suppress_regex = argv[++i]; + } + 
else if (arg == "--grammar") + { + params.grammar = argv[++i]; + } + else if (arg == "--grammar-rule") + { + params.grammar_rule = argv[++i]; + } + else if (arg == "--grammar-penalty") + { + params.grammar_penalty = std::stof(argv[++i]); + } + else if (arg == "--mmproj") + { + continue; + // params.mmproj = argv[++i]; + } + else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") + { + continue; + // params.n_gpu_layers = std::stoi(argv[++i]); + } + else + { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +static void whisper_print_usage(int /*argc*/, char **argv, const whisper_params ¶ms) +{ + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n", params.temperature_inc); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? 
"true" : "false"); + fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, "\n"); +} + +struct whisper_context *whisper_init_context(whisper_params ¶ms) +{ + struct whisper_context_params cparams = whisper_context_default_params(); + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + + if (!params.dtw.empty()) + { + cparams.dtw_token_timestamps = true; + cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; + + if (params.dtw == "tiny") + cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY; + if (params.dtw == "tiny.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN; + if (params.dtw == "base") + cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE; + if (params.dtw == "base.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN; + if (params.dtw == "small") + cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL; + if (params.dtw == "small.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN; + if (params.dtw == "medium") + cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM; + if (params.dtw == "medium.en") + cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN; + if (params.dtw == "large.v1") + cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1; + if (params.dtw == "large.v2") + cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2; + if (params.dtw == "large.v3") + cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; + + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) + { + fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); + return NULL; + } + } + + struct whisper_context *ctx = whisper_encoder_init_from_file_with_params(params.model.c_str(), cparams); + if (ctx == nullptr) + { + fprintf(stderr, "error: failed to initialize whisper context\n"); + return NULL; + } + + // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured + whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + + return ctx; +} + +static struct whisper_full_params get_whisper_inference_params_from_whisper_params(whisper_params ¶ms) +{ + + struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty()); + wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.detect_language = params.detect_language; + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.max_len == 0 ? 
60 : params.max_len; + wparams.split_on_word = params.split_on_word; + wparams.audio_ctx = params.audio_ctx; + + wparams.debug_mode = params.debug_mode; + + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + + wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); + + wparams.initial_prompt = params.prompt.c_str(); + + wparams.greedy.best_of = params.best_of; + wparams.beam_search.beam_size = params.beam_size; + + wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc; + wparams.temperature = params.temperature; + + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; + + wparams.no_timestamps = params.no_timestamps; + + return wparams; +} + +// +// Omni +// + +static void omni_print_usage(int, char **argv) +{ + LOG("\n example usage:\n"); + LOG("\n %s --model --mmproj --file [-p \"describe the audio in detail.\"]\n", argv[0]); + LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); +} + +bool omni_context_params_parse(int argc, char **argv, omni_context_params ¶ms) +{ + for (int i = 1; i < argc; i++) + { + std::string arg = argv[i]; + + if (arg[0] != '-') + { + continue; + } + + if (arg == "-h" || arg == "--help") + { + omni_print_usage(argc, argv); + exit(0); + } + if (arg == "--prompt") + { + params.prompt = argv[++i]; + } + else if (arg == "-m" || arg == "--model") + { + params.model = argv[++i]; + } + else if (arg == "-f" || arg == "--file") + { + params.file = argv[++i]; + } + else if (arg == "--mmproj") + { + params.mmproj = argv[++i]; + } + else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") + { + params.n_gpu_layers = std::stoi(argv[++i]); + } + else + { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + omni_print_usage(argc, argv); + exit(0); + } + } + + return true; +} + +omni_context_params omni_context_default_params() +{ + omni_context_params params = { + .model = "", + .mmproj = "", + .file = "", + .prompt = "this conversation talks about", + .n_gpu_layers = -1, + }; + return params; +} + +struct omni_params +{ + gpt_params gpt; + whisper_params whisper; +}; + +bool omni_params_parse(int argc, char **argv, omni_params ¶ms) +{ + if (!gpt_params_parse(argc, argv, params.gpt)) + { + return false; + } + + if (!whisper_params_parse(argc, argv, params.whisper)) + { + whisper_print_usage(argc, argv, params.whisper); + return false; + } + + if (params.gpt.model.empty() || params.gpt.mmproj.empty() || params.whisper.fname_inp.empty()) + { + omni_print_usage(argc, argv); + return false; + } + + params.whisper.model = params.gpt.mmproj; + + return true; +} + +static omni_params get_omni_params_from_context_params(omni_context_params ¶ms) +{ + omni_params all_params = { + .gpt = { + .n_gpu_layers = params.n_gpu_layers, + .model = params.model, + .prompt = params.prompt, + }, + .whisper = { + .model = params.mmproj, + .fname_inp = {params.file}, + }, + }; + + if (all_params.gpt.n_threads <= 0) + { + all_params.gpt.n_threads = std::thread::hardware_concurrency(); + } + return all_params; +} + +static bool eval_tokens(struct llama_context *ctx_llama, std::vector tokens, int n_batch, int *n_past) +{ + int N = (int)tokens.size(); + for (int i = 0; i < N; i += n_batch) + { + int n_eval = (int)tokens.size() - i; + if (n_eval > n_batch) + { + n_eval = n_batch; + } + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) + { + LLAMA_LOG_ERROR("%s : failed to eval. 
token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); + return false; + } + *n_past += n_eval; + } + return true; +} + +static bool eval_id(struct llama_context *ctx_llama, int id, int *n_past) +{ + std::vector tokens; + tokens.push_back(id); + return eval_tokens(ctx_llama, tokens, 1, n_past); +} + +static bool eval_string(struct llama_context *ctx_llama, const char *str, int n_batch, int *n_past, bool add_bos) +{ + std::string str2 = str; + std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); + eval_tokens(ctx_llama, embd_inp, n_batch, n_past); + return true; +} + +static const char * sample(struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_llama, + int * n_past) { + const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); + llama_sampling_accept(ctx_sampling, ctx_llama, id, true); + static std::string ret; + if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + ret = ""; + } else { + ret = llama_token_to_piece(ctx_llama, id); + } + eval_id(ctx_llama, id, n_past); + return ret.c_str(); +} + +static size_t find_audio_token(const std::string &prompt) +{ + return prompt.find(AUDIO_TOKEN); +} + +struct omni_context *omni_init_context(omni_context_params ¶ms) +{ + + omni_params all_params = get_omni_params_from_context_params(params); + + // llama + LLAMA_LOG_INFO("------- llama --------\n"); + + std::string prompt = all_params.gpt.prompt; + if (prompt.empty()) + { + prompt = "this conversation talks about"; + } + + llama_backend_init(); + llama_numa_init(all_params.gpt.numa); + + llama_model_params model_params = llama_model_params_from_gpt_params(all_params.gpt); + + llama_model *model = llama_load_model_from_file(all_params.gpt.model.c_str(), model_params); + if (model == NULL) + { + LLAMA_LOG_ERROR("%s: unable to load model\n", __func__); + return NULL; + } + + llama_context_params ctx_params = llama_context_params_from_gpt_params(all_params.gpt); + ctx_params.n_ctx = all_params.gpt.n_ctx < 2048 ? 
2048 : all_params.gpt.n_ctx; // we need a longer context size to process image embeddings + + llama_context *ctx_llama = llama_new_context_with_model(model, ctx_params); + + if (ctx_llama == NULL) + { + LLAMA_LOG_ERROR("%s: failed to create the llama_context\n", __func__); + return NULL; + } + + // whisper + LLAMA_LOG_INFO("------- whisper --------\n"); + + whisper_context *ctx_whisper = whisper_init_context(all_params.whisper); + + // projector + LLAMA_LOG_INFO("------- projector --------\n"); + + audio_projector *projector = new audio_projector(); + if (!projector->load_from_gguf(all_params.whisper.model)) + { + fprintf(stderr, "Failed to load model.\n"); + return NULL; + } + + auto *ctx_omni = (struct omni_context *)malloc(sizeof(omni_context)); + + ctx_omni->ctx_llama = ctx_llama; + ctx_omni->ctx_whisper = ctx_whisper; + ctx_omni->model = model; + ctx_omni->projector = projector; + + LLAMA_LOG_INFO("------- omni context initialized --------\n"); + + return ctx_omni; +} + +void omni_free(struct omni_context *ctx_omni) +{ + if (ctx_omni->ctx_whisper) + { + whisper_free(ctx_omni->ctx_whisper); + ctx_omni->ctx_whisper = NULL; + } + if (ctx_omni->projector) + { + ctx_omni->projector->free(); + } + + llama_free(ctx_omni->ctx_llama); + llama_free_model(ctx_omni->model); + llama_backend_free(); +} + +static bool omni_eval_audio_embed(llama_context *ctx_llama, ggml_tensor *audio_embed, int n_batch, int *n_past) +{ + int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + + int n_audio_embed = audio_embed->ne[1]; + GGML_ASSERT(audio_embed->ne[0] == n_embd); + + size_t audio_embed_size = ggml_nbytes(audio_embed); + float *audio_embed_data = (float *)malloc(audio_embed_size); + ggml_backend_tensor_get(audio_embed, audio_embed_data, 0, audio_embed_size); + + for (int i = 0; i < n_audio_embed; i += n_batch) + { + int n_eval = n_audio_embed - i; + if (n_eval > n_batch) + { + n_eval = n_batch; + } + llama_batch batch = { + /* n_tokens */ int32_t(n_eval), + /* token */ nullptr, + /* embd */ (audio_embed_data + i * n_embd), + /* pos */ nullptr, + /* n_seq_id */ nullptr, + /* seq_id */ nullptr, + /* logits */ nullptr, + /* all_pos_0 */ *n_past, + /* all_pos_1 */ 1, + /* all_seq_id */ 0, + }; + if (llama_decode(ctx_llama, batch)) + { + LLAMA_LOG_ERROR("%s : failed to eval\n", __func__); + return false; + } + *n_past += n_eval; + } + return true; +} + +ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params ¶ms) +{ + auto fname_inp = params.whisper.fname_inp[0]; + + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.whisper.diarize)) + { + LLAMA_LOG_ERROR("error: failed to read WAV file '%s'\n", fname_inp.c_str()); + return NULL; + } + + whisper_full_params wparams = get_whisper_inference_params_from_whisper_params(params.whisper); + + if (whisper_encode_wo_cross_parallel(ctx_omni->ctx_whisper, wparams, pcmf32.data(), pcmf32.size(), params.whisper.n_processors) != 0) + { + LLAMA_LOG_ERROR("%s: failed to process audio\n", __func__); + return NULL; + } + + ggml_tensor *embd_enc = whisper_full_get_embd_enc(ctx_omni->ctx_whisper); +#ifdef NEXA_DEBUG + print_ggml_tensor_shape("embd_enc", embd_enc); +#endif + + ggml_tensor *embed_proj = audio_projector_inference(*ctx_omni->projector, embd_enc); +#ifdef NEXA_DEBUG + print_ggml_tensor_shape("embed_proj", embed_proj); +#endif + + return embed_proj; +} + +void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params 
¶ms, const std::string &prompt) +{ + int n_past = 0; + + int n_audio_embed = audio_embed->ne[1]; + GGML_ASSERT(params.gpt.n_predict < 0 || params.gpt.n_predict > n_audio_embed); + + const int max_tgt_len = params.gpt.n_predict < 0 ? 256 + n_audio_embed : params.gpt.n_predict; + std::string system_prompt, user_prompt; + size_t audio_pos = find_audio_token(prompt); + if (audio_pos != std::string::npos) + { + system_prompt = prompt.substr(0, audio_pos); + user_prompt = prompt.substr(audio_pos + std::string(AUDIO_TOKEN).length()); + // LLAMA_LOG_INFO("system_prompt: %s\n", system_prompt.c_str()); + // LLAMA_LOG_INFO("user_prompt: %s\n", user_prompt.c_str()); + } + else + { + system_prompt = "user\nAudio 1: <|audio_bos|>"; + user_prompt = "<|audio_eos|>\n" + prompt + "\nmodel\n"; + } + + eval_string(ctx_omni->ctx_llama, system_prompt.c_str(), params.gpt.n_batch, &n_past, true); + omni_eval_audio_embed(ctx_omni->ctx_llama, audio_embed, params.gpt.n_batch, &n_past); + eval_string(ctx_omni->ctx_llama, user_prompt.c_str(), params.gpt.n_batch, &n_past, false); + + // generate the response + + LOG("\n"); + + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.gpt.sparams); + if (!ctx_sampling) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } + + std::string response = ""; + for (int i = 0; i < max_tgt_len; i++) + { + const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past); + response += tmp; + if (strcmp(tmp, "") == 0) + break; + if (strstr(tmp, "###")) + break; // Yi-VL behavior + printf("%s", tmp); + if (strstr(response.c_str(), "<|im_end|>")) + break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) + if (strstr(response.c_str(), "<|im_start|>")) + break; // Yi-34B llava-1.6 + if (strstr(response.c_str(), "USER:")) + break; // mistral llava-1.6 + + fflush(stdout); + } + + llama_sampling_free(ctx_sampling); + printf("\n"); +} + +void omni_process_full(struct omni_context *ctx_omni, omni_context_params ¶ms) +{ + omni_params all_params = get_omni_params_from_context_params(params); + + ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params); + omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt); +} \ No newline at end of file diff --git a/examples/nexa-omni-audio/omni.h b/examples/nexa-omni-audio/omni.h new file mode 100644 index 000000000..5cbbd52ed --- /dev/null +++ b/examples/nexa-omni-audio/omni.h @@ -0,0 +1,64 @@ +#pragma once + +#include "whisper.h" +#include "llama.h" +#include "grammar-parser.h" +#include "common.h" +#include "common-nexa.h" + +#include +#include + +#include "audio-projector.h" + +#ifdef OMNI_AUDIO_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef OMNI_AUDIO_BUILD +# define OMNI_AUDIO_API __declspec(dllexport) +# else +# define OMNI_AUDIO_API __declspec(dllimport) +# endif +# else +# define OMNI_AUDIO_API __attribute__ ((visibility ("default"))) +# endif +#else +# define OMNI_AUDIO_API +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct omni_context_params +{ + const char *model; + const char *mmproj; + const char *file; + const char *prompt; + int32_t n_gpu_layers; +}; + +struct omni_context +{ + struct whisper_context *ctx_whisper; + struct audio_projector *projector; + struct llama_context *ctx_llama; + struct llama_model *model; +}; + +OMNI_AUDIO_API bool omni_context_params_parse(int argc, char **argv, omni_context_params ¶ms); + +OMNI_AUDIO_API omni_context_params 
omni_context_default_params(); + +OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params ¶ms); + +OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni); + +OMNI_AUDIO_API void omni_process_full( + struct omni_context *ctx_omni, + omni_context_params ¶ms +); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/examples/nexa-omni-audio/whisper-mel-cuda.cu b/examples/nexa-omni-audio/whisper-mel-cuda.cu new file mode 100644 index 000000000..c9f94d379 --- /dev/null +++ b/examples/nexa-omni-audio/whisper-mel-cuda.cu @@ -0,0 +1,364 @@ +#define CUB_IGNORE_DEPRECATED_CPP_DIALECT +#include "whisper-mel-cuda.hpp" +#include "whisper.h" + +#include "common.cuh" +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4324) // added padding +#endif + +namespace { + +static const char* cufftGetErrorString(cufftResult_t res) { + switch (res) { + case CUFFT_SUCCESS: return "The cuFFT operation was successful"; + case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle"; + case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory"; + case CUFFT_INVALID_TYPE: return "No longer used"; + case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter"; + case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error"; + case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU"; + case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize"; + case CUFFT_INVALID_SIZE: return "User specified an invalid transform size"; + case CUFFT_UNALIGNED_DATA: return "No longer used"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call"; + case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation"; + case CUFFT_PARSE_ERROR: return "Internal plan database error"; + case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution"; + case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given."; + case CUFFT_LICENSE_ERROR: return "Used in previous versions."; + case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given."; + default: return "Unknown error"; + } +} + +#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString) + +__global__ void k_fill_stft_input( + const float * padded_samples, + const int n_frames, + const float * hann_window, + float * stft_in +) { + auto y = blockIdx.y * blockDim.y + threadIdx.y; + // if (y >= n_frames) return; + auto x = blockIdx.x * blockDim.x + threadIdx.x; + // if (x >= WHISPER_N_FFT) return; + + auto line = padded_samples + y * WHISPER_HOP_LENGTH; + auto outLine = stft_in + y * WHISPER_N_FFT; + + outLine[x] = line[x] * hann_window[x]; +} + +__global__ void k_calc_magnitudes( + const cuComplex * stft_out, + const int n_frames, + float * magnitudes +) { + auto y = blockIdx.y * blockDim.y + threadIdx.y; + // if (y >= n_frames) return; + auto x = blockIdx.x * blockDim.x + threadIdx.x; + // if (x >= WHISPER_N_FFT_HALF) return; + + auto idx = y * WHISPER_N_FFT_HALF + x; + + auto r = stft_out[idx].x; + auto i = stft_out[idx].y; + magnitudes[idx] = r * r + i * i; +} + +__global__ void k_calc_log_mel( + const float * mel_data, + const int n_mel, + const float * max_val, + float * log_mel +) { + auto x = blockIdx.x * blockDim.x + threadIdx.x; + if (x >= n_mel) return; + + float val = mel_data[x]; + + constexpr float e 
= 1e-10f;
+    if (val < e) val = e;
+
+    val = log10(val);
+
+    const float max = log10(*max_val) - 8.f;
+    if (val < max) val = max;
+
+    log_mel[x] = (val + 4) / 4;
+}
+
+static void fill_stft_input(
+    const float * padded_samples,
+    int n_frames,
+    const float * hann_window,
+    float * stft_in,
+    cudaStream_t stream
+) {
+    dim3 block(WHISPER_N_FFT, 1);
+    dim3 grid(1, n_frames);
+
+    k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
+}
+
+static void calc_magnitudes(
+    const cuComplex * stft_out,
+    int n_frames,
+    float * magnitudes,
+    cudaStream_t stream
+) {
+    dim3 block(WHISPER_N_FFT_HALF, 1);
+    dim3 grid(1, n_frames);
+    k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
+}
+
+constexpr auto LOG_MEL_PREFIX_SIZE = 256;
+
+static void calc_log_mel(
+    const float * mel_data,
+    int n_mel,
+    void * tempStorage,
+    int tempStorageSize,
+    float * log_mel,
+    cudaStream_t stream
+) {
+    float * max_val = reinterpret_cast<float *>(tempStorage);
+    void * maxTemp = reinterpret_cast<uint8_t *>(tempStorage) + LOG_MEL_PREFIX_SIZE;
+
+    size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
+    cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);
+
+    int block = 256;
+    int grid = (n_mel + block - 1) / block;
+
+    k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
+}
+
+class mel_calc_cuda : public whisper_mel_calc {
+    const int m_n_mel;
+
+    ggml_backend_t m_backend = nullptr;
+    int m_device = -1;
+
+    cudaStream_t m_stream = nullptr;
+    cublasHandle_t m_cublas_handle = nullptr;
+
+    float * m_hann_window = nullptr;
+
+    float * m_filters = nullptr;
+
+    // max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
+    int m_n_max_samples = 0;
+
+    size_t m_cufft_workspace_size = 0;
+    void * m_cufft_workspace = nullptr;
+
+    size_t m_log_mel_temp_storage_size = 0;
+    void * m_log_mel_temp_storage = nullptr;
+public:
+    mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
+        : m_n_mel(filters.n_mel)
+        , m_backend(backend)
+    {
+        ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*) m_backend->context;
+        m_device = cuda_ctx->device;
+
+        if (ggml_cuda_info().devices[m_device].cc < 600) {
+            // we've only tested on 6.0 and higher and we've had reports of crashes on 5.0:
+            // https://github.com/ggerganov/whisper.cpp/issues/2230
+            // to be safe forbid anything below 6.0
+            throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
+        }
+
+        ggml_cuda_set_device(m_device);
+
+        if (filters.n_fft != WHISPER_N_FFT_HALF) {
+            throw std::invalid_argument("MelFilters n_fft must be WHISPER_N_FFT_HALF");
+        }
+        assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
+
+        CUDA_CHECK(cudaStreamCreate(&m_stream));
+        CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
+        CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+        CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
+
+        // create Hann window
+        {
+            auto hw = whisper_mel_calc::hann_window();
+            CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
+            CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
+        }
+
+        // fill filters
+        {
+            auto& f = filters.data;
+            CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
+            CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
+        }
+
+        // preallocate working areas enough for the most common cases (<= 30s)
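+        // WHISPER_N_SAMPLES is 30 s of audio at Whisper's fixed 16 kHz input
+        // rate (16000 * 30 = 480000 samples); longer inputs reallocate the
+        // cuFFT workspace and the CUB scratch buffer on demand when
+        // calculate() calls ensure_working_areas() with the actual count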
ensure_working_areas(WHISPER_N_SAMPLES); + } + + ~mel_calc_cuda() { + ggml_cuda_set_device(m_device); + CUDA_CHECK(cudaStreamSynchronize(m_stream)); + CUDA_CHECK(cudaStreamDestroy(m_stream)); + CUDA_CHECK(cudaFree(m_hann_window)); + CUDA_CHECK(cudaFree(m_cufft_workspace)); + CUDA_CHECK(cudaFree(m_filters)); + CUDA_CHECK(cudaFree(m_log_mel_temp_storage)); + } + + void ensure_working_areas(int n_samples) { + if (n_samples <= m_n_max_samples) { + return; + } + + const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT; + const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + + // cufft workspace + { + if (m_cufft_workspace) { + CUDA_CHECK(cudaFree(m_cufft_workspace)); + m_cufft_workspace_size = 0; + m_cufft_workspace = nullptr; + } + CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size)); + CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream)); + } + + // device reduce working area + { + if (m_log_mel_temp_storage) { + CUDA_CHECK(cudaFree(m_log_mel_temp_storage)); + m_log_mel_temp_storage_size = 0; + m_log_mel_temp_storage = nullptr; + } + + const auto max_mels = 160; + + size_t nbytes = 0; + float* temp = nullptr; + cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels); + m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE; + + CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream)); + } + + m_n_max_samples = n_samples; + } + + virtual whisper_mel calculate(whisper_span samples, int /*n_threads*/) override { + ggml_cuda_set_device(m_device); + ensure_working_areas(samples.len); + + const size_t mirror_pad = WHISPER_N_FFT / 2; + const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT; + + // pad + std::vector padded_samples(padded_size); + std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect + std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy + + // fill the rest of the data + // it should canonically be mirrored at the end as well, + // but we just assume the last MEL_FRAME_SIZE/2 samples are zeros + std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f); + + const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + + float * cu_padded_samples = nullptr; + CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream)); + CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream)); + + float * stft_in = nullptr; // contiguous buffer for stft input + CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream)); + + fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream); + + cufftComplex* stft_out; + CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream)); + + cufftHandle plan; + CUFFT_CHECK(cufftCreate(&plan)); + CUFFT_CHECK(cufftSetAutoAllocation(plan, 0)); + { + size_t waSize; + CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize)); + assert(waSize <= m_cufft_workspace_size); + CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace)); + CUFFT_CHECK(cufftSetStream(plan, m_stream)); + } + CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out)); + + const auto n_mag_frames = n_frames - 1; // 
drop last frame + float * magnitudes; + CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream)); + calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream); + + float * mel_data = nullptr; + CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream)); + + const float fone = 1.0f, fzero = 0.0f; + CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, + int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF, + &fone, + magnitudes, WHISPER_N_FFT_HALF, + m_filters, WHISPER_N_FFT_HALF, + &fzero, + mel_data, int(n_mag_frames))); + + whisper_mel ret; + // Calculate semi-padded sample length to ensure compatibility + int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel); + assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float)); + + float* log_mels = reinterpret_cast(ret.tensor->data); + + calc_log_mel( + mel_data, int(m_n_mel * n_mag_frames), + m_log_mel_temp_storage , int(m_log_mel_temp_storage_size), + log_mels, m_stream); + + CUDA_CHECK(cudaStreamSynchronize(m_stream)); + + // cleanup + CUFFT_CHECK(cufftDestroy(plan)); + CUDA_CHECK(cudaFreeAsync(mel_data, m_stream)); + CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream)); + CUDA_CHECK(cudaFreeAsync(stft_out, m_stream)); + CUDA_CHECK(cudaFreeAsync(stft_in, m_stream)); + CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream)); + + return ret; + } +}; + +} + +whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) { + try { + return new mel_calc_cuda(backend, filters); + } + catch (...) { + // TODO: log error (but for this we would have to expose the log state to be accessible here) + return nullptr; + } +} diff --git a/examples/nexa-omni-audio/whisper-mel-cuda.hpp b/examples/nexa-omni-audio/whisper-mel-cuda.hpp new file mode 100644 index 000000000..2acb6505f --- /dev/null +++ b/examples/nexa-omni-audio/whisper-mel-cuda.hpp @@ -0,0 +1,3 @@ +#include "whisper-mel.hpp" + +whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters); diff --git a/examples/nexa-omni-audio/whisper-mel.hpp b/examples/nexa-omni-audio/whisper-mel.hpp new file mode 100644 index 000000000..f4210b41a --- /dev/null +++ b/examples/nexa-omni-audio/whisper-mel.hpp @@ -0,0 +1,34 @@ +#pragma once +#include "ggml-backend.h" +#include + +struct whisper_mel { + int n_len_org = 0; + + ggml_context * ctx = nullptr; + ggml_tensor * tensor = nullptr; + ggml_backend_buffer_t buffer = nullptr; +}; + +void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel); + +void whisper_mel_free(whisper_mel & mel); + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +template +struct whisper_span { + T * data; + int len; +}; + +struct whisper_mel_calc { + virtual ~whisper_mel_calc(); + virtual whisper_mel calculate(whisper_span samples, int n_threads) = 0; + static whisper_span hann_window(); +}; diff --git a/examples/nexa-omni-audio/whisper.cpp b/examples/nexa-omni-audio/whisper.cpp new file mode 100644 index 000000000..7db2c24ea --- /dev/null +++ b/examples/nexa-omni-audio/whisper.cpp @@ -0,0 +1,10034 @@ +#include "whisper.h" + +#ifdef WHISPER_USE_COREML +#include "coreml/whisper-encoder.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#include 
"whisper-mel-cuda.hpp" +#endif + +#ifdef GGML_USE_SYCL +#include "ggml-sycl.h" +#endif + +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + +#ifdef GGML_USE_BLAS +#include "ggml-blas.h" +#endif + +#ifdef WHISPER_USE_OPENVINO +#include "openvino/whisper-openvino-encoder.h" +#endif + +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include "whisper-mel.hpp" + +#include "common-nexa.h" + +#include +#include +#include +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// third-party utilities +// use your favorite implementations +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4244 4267) // possible loss of data +#endif + +#if defined(GGML_BIG_ENDIAN) +#include +#include + +template +static T byteswap(T value) +{ + return std::byteswap(value); +} + +template <> +float byteswap(float value) +{ + return std::bit_cast(byteswap(std::bit_cast(value))); +} + +template +static void byteswap_tensor_data(ggml_tensor *tensor) +{ + T *datum = reinterpret_cast(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) + { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor *tensor) +{ + switch (tensor->type) + { + case GGML_TYPE_I16: + { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F16: + { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_I32: + { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F32: + { + byteswap_tensor_data(tensor); + break; + } + default: + { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do \ + { \ + for (auto &datum : f.data) \ + { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do \ + { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) \ + do \ + { \ + } while (0) +#define BYTESWAP_FILTERS(f) \ + do \ + { \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do \ + { \ + } while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define WHISPER_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +WHISPER_ATTRIBUTE_FORMAT(2, 3) +static void whisper_log_internal(ggml_log_level level, const char *format, ...); +static void whisper_log_callback_default(ggml_log_level level, const char *text, void *user_data); + +#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +// #define WHISPER_DEBUG + +#if defined(WHISPER_DEBUG) +#define WHISPER_LOG_DEBUG(...) whisper_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define WHISPER_LOG_DEBUG(...) 
+#endif + +#define WHISPER_ASSERT(x) \ + do \ + { \ + if (!(x)) \ + { \ + WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +// #define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 8 +#define WHISPER_MAX_NODES 4096 + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph *graph, + std::vector &buf, + int n_threads, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + plan.abort_callback = abort_callback; + plan.abort_callback_data = abort_callback_data; + + if (plan.work_size > 0) + { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + return ggml_graph_compute(graph, &plan); +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph *graph, + int n_threads) +{ + + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) + { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + if (ggml_backend_is_cpu(backend)) + { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } +#ifdef GGML_USE_BLAS + if (ggml_backend_is_blas(backend)) + { + ggml_backend_blas_set_n_threads(backend, n_threads); + } +#endif +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(backend)) + { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + } + + bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS; + ggml_backend_sched_reset(sched); + return t; +} + +// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" +// the idea is to represent the original matrix multiplication: +// +// Z = X @ Y +// +// with the sum of two matrix multiplications: +// +// Z = (X_0 @ Y_0) + (X_1 @ Y_1) +// +// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad" +// and X_1 and Y_1 are the remaining views. 
X_1 and Y_1 end up being small matrices that can be processed with more +// general-purpose kernels +// +static struct ggml_tensor *ggml_mul_mat_pad(struct ggml_context *ctx, struct ggml_tensor *x, struct ggml_tensor *y, int pad = 32) +{ + // use padding only if dimension 0 is at least 8 times larger than the padding + // else we won't get much benefit from the optimization + const int n_pad_req = 8; + + if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) + { + return ggml_mul_mat(ctx, x, y); + } + + struct ggml_tensor *x_0 = ggml_view_3d(ctx, x, (x->ne[0] / pad) * pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0); + struct ggml_tensor *x_1 = ggml_view_3d(ctx, x, x->ne[0] % pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0] * x_0->nb[0]); + + struct ggml_tensor *y_0 = ggml_view_3d(ctx, y, (y->ne[0] / pad) * pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0); + struct ggml_tensor *y_1 = ggml_view_3d(ctx, y, y->ne[0] % pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0] * y_0->nb[0]); + + return ggml_add(ctx, + ggml_mul_mat(ctx, x_0, y_0), + ggml_mul_mat(ctx, x_1, y_1)); +} + +// TODO: check if other platforms can benefit from this optimization +// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly +#if defined(GGML_USE_METAL) +#define ggml_mul_mat ggml_mul_mat_pad +#endif + +// available whisper models +enum e_model +{ + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +}; + +static const std::map g_model_name = { + {MODEL_UNKNOWN, "unknown"}, + {MODEL_TINY, "tiny"}, + {MODEL_BASE, "base"}, + {MODEL_SMALL, "small"}, + {MODEL_MEDIUM, "medium"}, + {MODEL_LARGE, "large"}, +}; + +static const std::map> g_lang = { + {"en", { + 0, + "english", + }}, + {"zh", { + 1, + "chinese", + }}, + {"de", { + 2, + "german", + }}, + {"es", { + 3, + "spanish", + }}, + {"ru", { + 4, + "russian", + }}, + {"ko", { + 5, + "korean", + }}, + {"fr", { + 6, + "french", + }}, + {"ja", { + 7, + "japanese", + }}, + {"pt", { + 8, + "portuguese", + }}, + {"tr", { + 9, + "turkish", + }}, + {"pl", { + 10, + "polish", + }}, + {"ca", { + 11, + "catalan", + }}, + {"nl", { + 12, + "dutch", + }}, + {"ar", { + 13, + "arabic", + }}, + {"sv", { + 14, + "swedish", + }}, + {"it", { + 15, + "italian", + }}, + {"id", { + 16, + "indonesian", + }}, + {"hi", { + 17, + "hindi", + }}, + {"fi", { + 18, + "finnish", + }}, + {"vi", { + 19, + "vietnamese", + }}, + {"he", { + 20, + "hebrew", + }}, + {"uk", { + 21, + "ukrainian", + }}, + {"el", { + 22, + "greek", + }}, + {"ms", { + 23, + "malay", + }}, + {"cs", { + 24, + "czech", + }}, + {"ro", { + 25, + "romanian", + }}, + {"da", { + 26, + "danish", + }}, + {"hu", { + 27, + "hungarian", + }}, + {"ta", { + 28, + "tamil", + }}, + {"no", { + 29, + "norwegian", + }}, + {"th", { + 30, + "thai", + }}, + {"ur", { + 31, + "urdu", + }}, + {"hr", { + 32, + "croatian", + }}, + {"bg", { + 33, + "bulgarian", + }}, + {"lt", { + 34, + "lithuanian", + }}, + {"la", { + 35, + "latin", + }}, + {"mi", { + 36, + "maori", + }}, + {"ml", { + 37, + "malayalam", + }}, + {"cy", { + 38, + "welsh", + }}, + {"sk", { + 39, + "slovak", + }}, + {"te", { + 40, + "telugu", + }}, + {"fa", { + 41, + "persian", + }}, + {"lv", { + 42, + "latvian", + }}, + {"bn", { + 43, + "bengali", + }}, + {"sr", { + 44, + "serbian", + }}, + {"az", { + 45, + "azerbaijani", + }}, + {"sl", { + 46, + "slovenian", + }}, + {"kn", { + 47, + "kannada", + }}, + {"et", { + 48, + "estonian", + }}, + {"mk", { + 49, + "macedonian", + }}, + {"br", { + 50, + "breton", + }}, + {"eu", { 
+ 51, + "basque", + }}, + {"is", { + 52, + "icelandic", + }}, + {"hy", { + 53, + "armenian", + }}, + {"ne", { + 54, + "nepali", + }}, + {"mn", { + 55, + "mongolian", + }}, + {"bs", { + 56, + "bosnian", + }}, + {"kk", { + 57, + "kazakh", + }}, + {"sq", { + 58, + "albanian", + }}, + {"sw", { + 59, + "swahili", + }}, + {"gl", { + 60, + "galician", + }}, + {"mr", { + 61, + "marathi", + }}, + {"pa", { + 62, + "punjabi", + }}, + {"si", { + 63, + "sinhala", + }}, + {"km", { + 64, + "khmer", + }}, + {"sn", { + 65, + "shona", + }}, + {"yo", { + 66, + "yoruba", + }}, + {"so", { + 67, + "somali", + }}, + {"af", { + 68, + "afrikaans", + }}, + {"oc", { + 69, + "occitan", + }}, + {"ka", { + 70, + "georgian", + }}, + {"be", { + 71, + "belarusian", + }}, + {"tg", { + 72, + "tajik", + }}, + {"sd", { + 73, + "sindhi", + }}, + {"gu", { + 74, + "gujarati", + }}, + {"am", { + 75, + "amharic", + }}, + {"yi", { + 76, + "yiddish", + }}, + {"lo", { + 77, + "lao", + }}, + {"uz", { + 78, + "uzbek", + }}, + {"fo", { + 79, + "faroese", + }}, + {"ht", { + 80, + "haitian creole", + }}, + {"ps", { + 81, + "pashto", + }}, + {"tk", { + 82, + "turkmen", + }}, + {"nn", { + 83, + "nynorsk", + }}, + {"mt", { + 84, + "maltese", + }}, + {"sa", { + 85, + "sanskrit", + }}, + {"lb", { + 86, + "luxembourgish", + }}, + {"my", { + 87, + "myanmar", + }}, + {"bo", { + 88, + "tibetan", + }}, + {"tl", { + 89, + "tagalog", + }}, + {"mg", { + 90, + "malagasy", + }}, + {"as", { + 91, + "assamese", + }}, + {"tt", { + 92, + "tatar", + }}, + {"haw", { + 93, + "hawaiian", + }}, + {"ln", { + 94, + "lingala", + }}, + {"ha", { + 95, + "hausa", + }}, + {"ba", { + 96, + "bashkir", + }}, + {"jw", { + 97, + "javanese", + }}, + {"su", { + 98, + "sundanese", + }}, + {"yue", { + 99, + "cantonese", + }}, +}; + +// [EXPERIMENTAL] Token-level timestamps with DTW +static const whisper_ahead g_aheads_tiny_en[] = {{1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4}}; +static const whisper_ahead g_aheads_tiny[] = {{2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5}}; +static const whisper_ahead g_aheads_base_en[] = {{3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7}}; +static const whisper_ahead g_aheads_base[] = {{3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6}}; +static const whisper_ahead g_aheads_small_en[] = {{6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4}}; +static const whisper_ahead g_aheads_small[] = {{5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5}}; +static const whisper_ahead g_aheads_medium_en[] = {{11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12}}; +static const whisper_ahead g_aheads_medium[] = {{13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4}}; +static const whisper_ahead g_aheads_large_v1[] = {{9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15}}; +static const whisper_ahead g_aheads_large_v2[] = {{10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15}}; +static const whisper_ahead g_aheads_large_v3[] = {{7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6}}; + +static const std::map g_aheads{ + {WHISPER_AHEADS_TINY_EN, {8, g_aheads_tiny_en}}, 
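+    // each value is a whisper_aheads aggregate {n_heads, heads}: the number of
+    // entries in, and a pointer to, the matching g_aheads_* table of
+    // {text_layer, head} pairs consumed by the DTW token-level timestamp code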
+ {WHISPER_AHEADS_TINY, {6, g_aheads_tiny}}, + {WHISPER_AHEADS_BASE_EN, {5, g_aheads_base_en}}, + {WHISPER_AHEADS_BASE, {8, g_aheads_base}}, + {WHISPER_AHEADS_SMALL_EN, {19, g_aheads_small_en}}, + {WHISPER_AHEADS_SMALL, {10, g_aheads_small}}, + {WHISPER_AHEADS_MEDIUM_EN, {18, g_aheads_medium_en}}, + {WHISPER_AHEADS_MEDIUM, {6, g_aheads_medium}}, + {WHISPER_AHEADS_LARGE_V1, {9, g_aheads_large_v1}}, + {WHISPER_AHEADS_LARGE_V2, {23, g_aheads_large_v2}}, + {WHISPER_AHEADS_LARGE_V3, {10, g_aheads_large_v3}}, +}; + +static std::vector get_alignment_heads_by_layer(const whisper_context_params &cparams, int il, int32_t n_text_layer, int32_t n_head); + +struct whisper_vocab +{ + using id = int32_t; + using token = std::string; + + int n_vocab = 51864; + + std::map token_to_id; + std::map id_to_token; + + // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 + id token_eot = 50256; + id token_sot = 50257; + // task tokens (used only for multilingual models) + id token_translate = 50357; + id token_transcribe = 50358; + // other special tokens + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_prev = 50360; + id token_nosp = 50361; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps + + bool is_multilingual() const + { + return n_vocab >= 51865; + } + + int num_languages() const + { + return n_vocab - 51765 - (is_multilingual() ? 1 : 0); + } +}; + +struct whisper_segment +{ + int64_t t0; + int64_t t1; + + std::string text; + + std::vector tokens; + + bool speaker_turn_next; +}; + +struct whisper_batch +{ + int32_t n_tokens; + + whisper_token *token; + whisper_pos *pos; + int32_t *n_seq_id; // always 1, here for consistency with llama.cpp + whisper_seq_id **seq_id; // null terminated + int8_t *logits; +}; + +static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) +{ + whisper_batch batch = { + 0, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + }; + + batch.token = (whisper_token *)malloc(sizeof(whisper_token) * (n_tokens)); + batch.pos = (whisper_pos *)malloc(sizeof(whisper_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *)malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (whisper_seq_id **)malloc(sizeof(whisper_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) + { + batch.seq_id[i] = (whisper_seq_id *)malloc(sizeof(whisper_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *)malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void whisper_batch_free(struct whisper_batch batch) +{ + if (batch.token) + free(batch.token); + if (batch.pos) + free(batch.pos); + if (batch.n_seq_id) + free(batch.n_seq_id); + if (batch.seq_id) + { + for (int i = 0; batch.seq_id[i]; ++i) + { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) + free(batch.logits); +} + +static void whisper_batch_prep_legacy(whisper_batch &batch, const whisper_token *tokens, int n_tokens, int n_past, int seq_id) +{ + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) + { + if (tokens) + { + batch.token[i] = tokens[i]; + } + batch.pos[i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + +// replace std::pair by using customized pair struct (reason: std::pair is very slow) +template +struct whisper_pair +{ + A first; + B second; + + // Define a constructor that takes two 
arguments. + whisper_pair(const A &a, const B &b) : first(a), second(b) {} + // Define a constructor that takes no argument. + whisper_pair() : first(A()), second(B()) {} +}; + +// ggml_backend_sched wrapper for whisper usage +struct whisper_sched +{ + ggml_backend_sched_t sched = nullptr; + + std::vector meta; +}; + +static size_t whisper_sched_size(struct whisper_sched &allocr) +{ + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) + { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static bool whisper_sched_graph_init(struct whisper_sched &allocr, std::vector backends, std::function &&get_graph) +{ + auto &sched = allocr.sched; + auto &meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false); + + meta.resize(ggml_tensor_overhead() * WHISPER_MAX_NODES + ggml_graph_overhead()); + + // since there are dependencies between the different graphs, + // we need to allocate them instead of only reserving to get the correct compute buffer size + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) + { + // failed to allocate the compute buffer + WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +// medium +// hparams: { +// 'n_mels': 80, +// 'n_vocab': 51864, +// 'n_audio_ctx': 1500, +// 'n_audio_state': 1024, +// 'n_audio_head': 16, +// 'n_audio_layer': 24, +// 'n_text_ctx': 448, +// 'n_text_state': 1024, +// 'n_text_head': 16, +// 'n_text_layer': 24 +// } +// +// default hparams (Whisper tiny) +struct whisper_hparams +{ + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t ftype = 1; + float eps = 1e-5f; +}; + +// audio encoding layer +struct whisper_layer_encoder +{ + // encoder.blocks.*.attn_ln + struct ggml_tensor *attn_ln_0_w; + struct ggml_tensor *attn_ln_0_b; + + // encoder.blocks.*.attn.out + struct ggml_tensor *attn_ln_1_w; + struct ggml_tensor *attn_ln_1_b; + + // encoder.blocks.*.attn.query + struct ggml_tensor *attn_q_w; + struct ggml_tensor *attn_q_b; + + // encoder.blocks.*.attn.key + struct ggml_tensor *attn_k_w; + + // encoder.blocks.*.attn.value + struct ggml_tensor *attn_v_w; + struct ggml_tensor *attn_v_b; + + // encoder.blocks.*.mlp_ln + struct ggml_tensor *mlp_ln_w; + struct ggml_tensor *mlp_ln_b; + + // encoder.blocks.*.mlp.0 + struct ggml_tensor *mlp_0_w; + struct ggml_tensor *mlp_0_b; + + // encoder.blocks.*.mlp.2 + struct ggml_tensor *mlp_1_w; + struct ggml_tensor *mlp_1_b; +}; + +// token decoding layer +struct whisper_layer_decoder +{ + // decoder.blocks.*.attn_ln + struct ggml_tensor *attn_ln_0_w; + struct ggml_tensor *attn_ln_0_b; + + // decoder.blocks.*.attn.out + struct ggml_tensor *attn_ln_1_w; + struct ggml_tensor *attn_ln_1_b; + + // decoder.blocks.*.attn.query + struct ggml_tensor *attn_q_w; + struct ggml_tensor *attn_q_b; + + // decoder.blocks.*.attn.key + struct ggml_tensor *attn_k_w; + + // decoder.blocks.*.attn.value + struct ggml_tensor *attn_v_w; + struct ggml_tensor *attn_v_b; + + // 
+// medium
+// hparams: {
+// 'n_mels': 80,
+// 'n_vocab': 51864,
+// 'n_audio_ctx': 1500,
+// 'n_audio_state': 1024,
+// 'n_audio_head': 16,
+// 'n_audio_layer': 24,
+// 'n_text_ctx': 448,
+// 'n_text_state': 1024,
+// 'n_text_head': 16,
+// 'n_text_layer': 24
+// }
+//
+// default hparams (Whisper tiny)
+struct whisper_hparams
+{
+    int32_t n_vocab = 51864;
+    int32_t n_audio_ctx = 1500;
+    int32_t n_audio_state = 384;
+    int32_t n_audio_head = 6;
+    int32_t n_audio_layer = 4;
+    int32_t n_text_ctx = 448;
+    int32_t n_text_state = 384;
+    int32_t n_text_head = 6;
+    int32_t n_text_layer = 4;
+    int32_t n_mels = 80;
+    int32_t ftype = 1;
+    float eps = 1e-5f;
+};
+
+// audio encoding layer
+struct whisper_layer_encoder
+{
+    // encoder.blocks.*.attn_ln
+    struct ggml_tensor *attn_ln_0_w;
+    struct ggml_tensor *attn_ln_0_b;
+
+    // encoder.blocks.*.attn.out
+    struct ggml_tensor *attn_ln_1_w;
+    struct ggml_tensor *attn_ln_1_b;
+
+    // encoder.blocks.*.attn.query
+    struct ggml_tensor *attn_q_w;
+    struct ggml_tensor *attn_q_b;
+
+    // encoder.blocks.*.attn.key
+    struct ggml_tensor *attn_k_w;
+
+    // encoder.blocks.*.attn.value
+    struct ggml_tensor *attn_v_w;
+    struct ggml_tensor *attn_v_b;
+
+    // encoder.blocks.*.mlp_ln
+    struct ggml_tensor *mlp_ln_w;
+    struct ggml_tensor *mlp_ln_b;
+
+    // encoder.blocks.*.mlp.0
+    struct ggml_tensor *mlp_0_w;
+    struct ggml_tensor *mlp_0_b;
+
+    // encoder.blocks.*.mlp.2
+    struct ggml_tensor *mlp_1_w;
+    struct ggml_tensor *mlp_1_b;
+};
+
+// token decoding layer
+struct whisper_layer_decoder
+{
+    // decoder.blocks.*.attn_ln
+    struct ggml_tensor *attn_ln_0_w;
+    struct ggml_tensor *attn_ln_0_b;
+
+    // decoder.blocks.*.attn.out
+    struct ggml_tensor *attn_ln_1_w;
+    struct ggml_tensor *attn_ln_1_b;
+
+    // decoder.blocks.*.attn.query
+    struct ggml_tensor *attn_q_w;
+    struct ggml_tensor *attn_q_b;
+
+    // decoder.blocks.*.attn.key
+    struct ggml_tensor *attn_k_w;
+
+    // decoder.blocks.*.attn.value
+    struct ggml_tensor *attn_v_w;
+    struct ggml_tensor *attn_v_b;
+
+    // decoder.blocks.*.cross_attn_ln
+    struct ggml_tensor *cross_attn_ln_0_w;
+    struct ggml_tensor *cross_attn_ln_0_b;
+
+    // decoder.blocks.*.cross_attn.out
+    struct ggml_tensor *cross_attn_ln_1_w;
+    struct ggml_tensor *cross_attn_ln_1_b;
+
+    // decoder.blocks.*.cross_attn.query
+    struct ggml_tensor *cross_attn_q_w;
+    struct ggml_tensor *cross_attn_q_b;
+
+    // decoder.blocks.*.cross_attn.key
+    struct ggml_tensor *cross_attn_k_w;
+
+    // decoder.blocks.*.cross_attn.value
+    struct ggml_tensor *cross_attn_v_w;
+    struct ggml_tensor *cross_attn_v_b;
+
+    // decoder.blocks.*.mlp_ln
+    struct ggml_tensor *mlp_ln_w;
+    struct ggml_tensor *mlp_ln_b;
+
+    // decoder.blocks.*.mlp.0
+    struct ggml_tensor *mlp_0_w;
+    struct ggml_tensor *mlp_0_b;
+
+    // decoder.blocks.*.mlp.2
+    struct ggml_tensor *mlp_1_w;
+    struct ggml_tensor *mlp_1_b;
+};
+
+struct whisper_kv_cell
+{
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id &id) const
+    {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
+struct whisper_kv_cache
+{
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
+    struct ggml_tensor *k;
+    struct ggml_tensor *v;
+
+    struct ggml_context *ctx = nullptr;
+
+    ggml_backend_buffer_t buffer = nullptr;
+};
+
+struct whisper_model
+{
+    e_model type = MODEL_UNKNOWN;
+
+    whisper_hparams hparams;
+    whisper_filters filters;
+
+    // encoder.positional_embedding
+    struct ggml_tensor *e_pe;
+
+    // encoder.conv1
+    struct ggml_tensor *e_conv_1_w;
+    struct ggml_tensor *e_conv_1_b;
+
+    // encoder.conv2
+    struct ggml_tensor *e_conv_2_w;
+    struct ggml_tensor *e_conv_2_b;
+
+    // encoder.ln_post
+    struct ggml_tensor *e_ln_w;
+    struct ggml_tensor *e_ln_b;
+
+    // decoder.positional_embedding
+    struct ggml_tensor *d_pe;
+
+    // decoder.token_embedding
+    struct ggml_tensor *d_te;
+
+    // decoder.ln
+    struct ggml_tensor *d_ln_w;
+    struct ggml_tensor *d_ln_b;
+
+    std::vector<whisper_layer_encoder> layers_encoder;
+    std::vector<whisper_layer_decoder> layers_decoder;
+
+    // ggml context that contains all the meta information about the model tensors
+    struct ggml_context *ctx = nullptr;
+
+    // the model backend data is read-only and can be shared between processors
+    ggml_backend_buffer_t buffer = nullptr;
+
+    // tensors
+    int n_loaded;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+struct whisper_partial_utf8
+{
+    uint32_t value; // bit value so far (unshifted)
+    int n_remain;   // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar
+{
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate
+{
+    whisper_token id;
+    const uint32_t *code_points;
+    whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_sequence
+{
+    std::vector<whisper_token_data> tokens;
+
+    // the accumulated transcription in the current iteration (used to truncate the tokens array)
+    int result_len;
+
+    double sum_logprobs_all; // the sum of the log probabilities of the tokens
+    double sum_logprobs;     // the sum of the log probabilities of the tokens (first result_len tokens)
+    double avg_logprobs;     // the average log probability of the tokens
+    double entropy;          // the entropy of the tokens
+    double score;            // likelihood rank score
+};
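+// The scoring fields above are related as follows (used when ranking best-of
+// candidates):
+//
+//     avg_logprobs = sum_logprobs / result_len
+//
+// e.g. a 10-token result with sum_logprobs = -4.2 gives avg_logprobs = -0.42,
+// which is then compared against the logprob threshold (by default -1.0) to
+// decide whether a decoded segment should be considered failed.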
+// TAGS: WHISPER_DECODER_INIT
+struct whisper_decoder
+{
+    // the currently generated sequence of tokens
+    whisper_sequence sequence;
+
+    // grammar parse state of generated sequence of tokens
+    whisper_grammar grammar;
+
+    int i_batch;    // the index of the token in the current batch
+    int seek_delta; // the window shift found so far based on the decoded timestamp tokens
+
+    bool failed;    // has the current segment failed to decode?
+    bool completed; // has the decoder completed the current segment?
+    bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
+
+    // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+    std::vector<float> probs;
+    std::vector<float> logits;
+    std::vector<float> logprobs;
+
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_token>> logits_id;
+
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
+};
+
+// [EXPERIMENTAL] Token-level timestamps with DTW
+struct whisper_aheads_masks
+{
+    std::vector<struct ggml_tensor *> m; // One mask per text layer.
+    struct ggml_context *ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+};
+
+struct whisper_state
+{
+    int64_t t_sample_us = 0;
+    int64_t t_encode_us = 0;
+    int64_t t_decode_us = 0;
+    int64_t t_batchd_us = 0;
+    int64_t t_prompt_us = 0;
+    int64_t t_mel_us = 0;
+
+    int32_t n_sample = 0; // number of tokens sampled
+    int32_t n_encode = 0; // number of encoder calls
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
+    int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
+    int32_t n_fail_p = 0; // number of logprob threshold failures
+    int32_t n_fail_h = 0; // number of entropy threshold failures
+
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
+    // cross-attention KV cache for the decoders
+    // shared between all decoders
+    whisper_kv_cache kv_cross;
+
+    // padded buffer for flash-attention
+    whisper_kv_cache kv_pad;
+
+    whisper_mel mel;
+    whisper_mel_calc *mel_calc = nullptr;
+    whisper_mel_calc *mel_calc_fallback = nullptr;
+
+    whisper_batch batch;
+
+    whisper_decoder decoders[WHISPER_MAX_DECODERS];
+
+    std::vector<ggml_backend_t> backends;
+
+    // - stores meta info about the intermediate tensors into the `meta` buffers
+    whisper_sched sched_conv;
+    whisper_sched sched_encode;
+    whisper_sched sched_cross;
+    whisper_sched sched_decode;
+
+    // result of the encoder
+    struct ggml_tensor *embd_conv = nullptr;
+    struct ggml_tensor *embd_enc = nullptr;
+
+    // helpers for GPU offloading
+    std::vector<float> inp_mask;
+
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
+    std::vector<float> logits;
+
+    std::vector<whisper_segment> result_all;
+    std::vector<whisper_token> prompt_past;
+
+    int lang_id = 0; // english by default
+
+    std::string path_model; // populated by whisper_init_from_file_with_params()
+
+#ifdef WHISPER_USE_COREML
+    whisper_coreml_context *ctx_coreml = nullptr;
+#endif
+
+#ifdef WHISPER_USE_OPENVINO
+    whisper_openvino_context *ctx_openvino = nullptr;
+#endif
+
+    // [EXPERIMENTAL] token-level timestamps data
+    int64_t t_beg = 0;
+    int64_t t_last = 0;
+
+    whisper_token tid_last;
+
+    std::vector<float> energy; // PCM signal energy
+
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    whisper_aheads_masks aheads_masks;
+    ggml_tensor *aheads_cross_QKs = nullptr;
+    std::vector<float> aheads_cross_QKs_data;
+
+    // [EXPERIMENTAL] speed-up techniques
+    int32_t exp_n_audio_ctx = 0; // 0 - use default
+};
+
+struct whisper_context
+{
+    int64_t t_load_us = 0;
+    int64_t t_start_us = 0;
+
+    ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
+    ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
+
+    whisper_context_params
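+        // user-supplied runtime settings (use_gpu, gpu_device, flash_attn,
+        // dtw_token_timestamps, dtw_aheads_preset, ...), copied at context
+        // creation and consulted throughout this file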
params; + + whisper_model model; + whisper_vocab vocab; + + whisper_state *state = nullptr; + + std::string path_model; // populated by whisper_init_from_file_with_params() +}; + +struct whisper_global +{ + // We save the log callback globally + ggml_log_callback log_callback = whisper_log_callback_default; + void *log_callback_user_data = nullptr; +}; + +static whisper_global g_state; + +template +static void read_safe(whisper_model_loader *loader, T &dest) +{ + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool whisper_kv_cache_init( + struct whisper_kv_cache &cache, + ggml_backend_t backend, + ggml_type wtype, + int64_t n_text_state, + int64_t n_text_layer, + int n_ctx) +{ + const int64_t n_mem = n_text_layer * n_ctx; + const int64_t n_elements = n_text_state * n_mem; + + struct ggml_init_params params = { + /*.mem_size =*/2 * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; + + cache.head = 0; + cache.size = n_ctx; + + cache.cells.clear(); + cache.cells.resize(n_ctx); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) + { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend); + if (!cache.buffer) + { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__); + return false; + } + + ggml_backend_buffer_clear(cache.buffer, 0); + + return true; +} + +static void whisper_kv_cache_free(struct whisper_kv_cache &cache) +{ + ggml_free(cache.ctx); + ggml_backend_buffer_free(cache.buffer); + cache.ctx = nullptr; +} + +static bool whisper_kv_cache_find_slot( + struct whisper_kv_cache &cache, + const struct whisper_batch &batch) +{ + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens > n_ctx) + { + WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } + + uint32_t n_tested = 0; + + while (true) + { + if (cache.head + n_tokens > n_ctx) + { + n_tested += n_ctx - cache.head; + cache.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) + { + if (cache.cells[cache.head + i].pos >= 0) + { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) + { + break; + } + + if (n_tested >= n_ctx) + { + // WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) + { + cache.cells[cache.head + i].pos = batch.pos[i]; + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) + { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } + } + + return true; +} + +// find how many cells are currently in use +static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache &cache) +{ + for (uint32_t i = cache.size - 1; i > 0; --i) + { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) + { + return i + 1; + } + } + + return 1; +} + +static void whisper_kv_cache_clear(struct whisper_kv_cache &cache) +{ + for (int32_t i = 0; i < (int32_t)cache.size; ++i) + { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } + cache.head = 0; +} + +static void whisper_kv_cache_seq_rm( + struct whisper_kv_cache &cache, + whisper_seq_id seq_id, + whisper_pos p0, + whisper_pos p1) +{ + uint32_t new_head = 
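+    // `new_head == cache.size` is a sentinel meaning "no cell freed yet".
+    // Calling convention for this function (mirroring llama.cpp): negative
+    // arguments act as wildcards, e.g.
+    //     whisper_kv_cache_seq_rm(cache, -1, -1, -1); // drop all sequences, full range
+    //     whisper_kv_cache_seq_rm(cache,  0, 10, -1); // drop seq 0 from pos 10 onwards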
cache.size; + + if (p0 < 0) + p0 = 0; + if (p1 < 0) + p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) + { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) + { + if (seq_id < 0) + { + cache.cells[i].seq_id.clear(); + } + else if (cache.cells[i].has_seq_id(seq_id)) + { + cache.cells[i].seq_id.erase(seq_id); + } + else + { + continue; + } + if (cache.cells[i].seq_id.empty()) + { + cache.cells[i].pos = -1; + if (new_head == cache.size) + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) + cache.head = new_head; +} + +static void whisper_kv_cache_seq_cp( + struct whisper_kv_cache &cache, + whisper_seq_id seq_id_src, + whisper_seq_id seq_id_dst, + whisper_pos p0, + whisper_pos p1) +{ + if (p0 < 0) + p0 = 0; + if (p1 < 0) + p1 = std::numeric_limits::max(); + + cache.head = 0; + + for (uint32_t i = 0; i < cache.size; ++i) + { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) + { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +static uint32_t whisper_kv_cache_get_padding(const struct whisper_context &wctx) +{ + if (!wctx.params.flash_attn || !wctx.params.use_gpu) + { + return 1u; + } + +#ifdef GGML_USE_METAL + if (wctx.params.use_gpu) + { + return 32u; + } +#endif + +#ifdef GGML_USE_CUDA + if (wctx.params.use_gpu) + { + return 256u; + } +#endif + + return 1u; +} + +// [EXPERIMENTAL] Token-level timestamps with DTW +static bool aheads_masks_init( + const whisper_context_params &cparams, + const whisper_hparams &hparams, + struct whisper_aheads_masks &aheads_masks, + ggml_backend_t backend) +{ + + const int32_t n_text_layer = hparams.n_text_layer; + const int32_t n_head = hparams.n_text_head; + + // Sanity checks + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) + { + WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__); + return false; + } + else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) + { + if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) + { + WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer); + return false; + } + } + else + { + const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? 
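+        // Illustrative WHISPER_AHEADS_CUSTOM setup (hypothetical layer/head
+        // values) that the checks below would accept:
+        //
+        //     static const whisper_ahead heads[] = { {2, 3}, {3, 1} }; // {n_text_layer, n_head}
+        //     cparams.dtw_aheads_preset = WHISPER_AHEADS_CUSTOM;
+        //     cparams.dtw_aheads = { /*n_heads =*/ 2, heads };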
cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset); + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) + { + if (aheads.n_heads == 0) + { + WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__); + return false; + } + if (aheads.heads == NULL) + { + WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__); + return false; + } + } + for (size_t i = 0; i < aheads.n_heads; ++i) + { + if (aheads.heads[i].n_text_layer >= n_text_layer) + { + WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer); + return false; + } + if (aheads.heads[i].n_text_layer < 0) + { + WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__); + return false; + } + if (aheads.heads[i].n_head >= n_head) + { + WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head); + return false; + } + if (aheads.heads[i].n_head < 0) + { + WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__); + return false; + } + } + } + + struct ggml_init_params params = { + /*.mem_size =*/(size_t) static_cast(n_text_layer) * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; + + aheads_masks.ctx = ggml_init(params); + + if (!aheads_masks.ctx) + { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__); + return false; + } + + for (int64_t il = 0; il < n_text_layer; ++il) + { + auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head); + if (!aheads.empty()) + { + aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size())); + } + else + { + aheads_masks.m.push_back(nullptr); + } + } + + aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend); + if (!aheads_masks.buffer) + { + WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__); + return false; + } + + // Set data on mask tensors + // Since this must be backend agnostic, we write our desired values on mask_data, + // and send it to backend with ggml_backend_tensor_set. + // Each mask in N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment + // heads. Each row of the mask "marks" one alignment head. E.g. 
if some text layer + // has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask + // should read: + // 1 0 0 0 0 0 0 0 0 0 + // 0 0 0 0 0 1 0 0 0 0 + // 0 0 0 0 0 0 1 0 0 0 + std::vector mask_data; + for (int64_t il = 0; il < n_text_layer; ++il) + { + if (aheads_masks.m[il] != nullptr) + { + auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head); + + size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1]; + size_t data_size_bytes = data_size * sizeof(float); + mask_data.resize(data_size); + + std::fill(mask_data.begin(), mask_data.end(), 0); + for (size_t ih = 0; ih < aheads.size(); ++ih) + { + size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0])); + mask_data[pos] = 1.0f; + } + + ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes); + } + } + + if (aheads_masks.m.empty()) + { + WHISPER_LOG_ERROR("%s: \n", __func__); + return false; + } + + return true; +} + +static void aheads_masks_free(struct whisper_aheads_masks &aheads_masks) +{ + ggml_free(aheads_masks.ctx); + ggml_backend_buffer_free(aheads_masks.buffer); + aheads_masks.ctx = nullptr; +} + +static size_t aheads_masks_nbytes(struct whisper_aheads_masks &aheads_masks) +{ + size_t size = 0; + for (size_t i = 0; i < aheads_masks.m.size(); ++i) + { + if (aheads_masks.m[i] != nullptr) + size += ggml_nbytes(aheads_masks.m[i]); + } + return size; +} + +static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params ¶ms) +{ + ggml_backend_t result = NULL; + +#ifdef GGML_USE_CUDA + if (params.use_gpu) + { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + result = ggml_backend_cuda_init(params.gpu_device); + if (!result) + { + WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (params.use_gpu) + { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); + result = ggml_backend_metal_init(); + if (!result) + { + WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); + } + else if (!ggml_backend_metal_supports_family(result, 7)) + { + WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__); + ggml_backend_free(result); + result = NULL; + } + } +#endif + +#ifdef GGML_USE_SYCL + if (params.use_gpu) + { + WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__); + result = ggml_backend_sycl_init(params.gpu_device); + if (!result) + { + WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_VULKAN + if (params.use_gpu) + { + WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__); + result = ggml_backend_vk_init(params.gpu_device); + if (!result) + { + WHISPER_LOG_ERROR("%s: ggml_backend_vk_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_CANN + if (params.use_gpu) + { + WHISPER_LOG_INFO("%s: using CANN backend\n", __func__); + result = ggml_backend_cann_init(params.gpu_device); + if (!result) + { + WHISPER_LOG_ERROR("%s: ggml_backend_cann_init() failed\n", __func__); + } + } +#endif + + GGML_UNUSED(params); + + return result; +} + +static std::vector whisper_backend_init(const whisper_context_params ¶ms) +{ + std::vector result; + + ggml_backend_t backend_gpu = whisper_backend_init_gpu(params); + + if (backend_gpu) + { + result.push_back(backend_gpu); + } + +#ifdef GGML_USE_BLAS + { + WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__); + 
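+        // Resulting backend priority assembled by this function: GPU (if
+        // configured and successfully initialized) first, then BLAS (if built
+        // in), and a CPU backend is always appended last as the fallback, so
+        // backends[0] is the preferred device.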
ggml_backend_t backend_blas = ggml_backend_blas_init();
+        if (!backend_blas)
+        {
+            WHISPER_LOG_ERROR("%s: ggml_backend_blas_init() failed\n", __func__);
+        }
+        else
+        {
+            result.push_back(backend_blas);
+        }
+    }
+#endif
+
+    GGML_UNUSED(params);
+
+    result.push_back(ggml_backend_cpu_init());
+
+    return result;
+}
+
+static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params &params)
+{
+    ggml_backend_buffer_type_t result = nullptr;
+
+    params.use_gpu || (result = ggml_backend_cpu_buffer_type());
+
+#ifdef GGML_USE_CUDA
+    result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
+#endif
+
+#ifdef GGML_USE_METAL
+    result || (result = ggml_backend_metal_buffer_type());
+#endif
+
+#ifdef GGML_USE_SYCL
+    result || (result = ggml_backend_sycl_buffer_type(params.gpu_device));
+#endif
+
+#ifdef GGML_USE_VULKAN
+    result || (result = ggml_backend_vk_buffer_type(params.gpu_device));
+#endif
+
+#ifdef GGML_USE_CANN
+    result || (result = ggml_backend_cann_buffer_type(params.gpu_device));
+#endif
+
+    result || (result = ggml_backend_cpu_buffer_type());
+
+    return result;
+}
+
+// load the model from a ggml file
+//
+// file format:
+//
+// - hparams
+// - pre-computed mel filters
+// - vocab
+// - weights
+//
+// see the convert-pt-to-ggml.py script for details
+//
+static bool whisper_model_load(struct whisper_model_loader *loader, whisper_context &wctx)
+{
+    WHISPER_LOG_INFO("%s: loading model\n", __func__);
+
+    const int64_t t_start_us = ggml_time_us();
+
+    wctx.t_start_us = t_start_us;
+
+    auto &model = wctx.model;
+    auto &vocab = wctx.vocab;
+
+    // verify magic
+    {
+        uint32_t magic;
+        read_safe(loader, magic);
+        if (magic != GGML_FILE_MAGIC)
+        {
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto &hparams = model.hparams;
+
+        read_safe(loader, hparams.n_vocab);
+        read_safe(loader, hparams.n_audio_ctx);
+        read_safe(loader, hparams.n_audio_state);
+        read_safe(loader, hparams.n_audio_head);
+        read_safe(loader, hparams.n_audio_layer);
+        read_safe(loader, hparams.n_text_ctx);
+        read_safe(loader, hparams.n_text_state);
+        read_safe(loader, hparams.n_text_head);
+        read_safe(loader, hparams.n_text_layer);
+        read_safe(loader, hparams.n_mels);
+        read_safe(loader, hparams.ftype);
+
+        // hparams.n_text_layer = 0;
+
+        assert(hparams.n_text_state == hparams.n_audio_state);
+
+        std::string mver = "";
+
+        if (hparams.n_audio_layer == 4)
+        {
+            model.type = e_model::MODEL_TINY;
+        }
+
+        if (hparams.n_audio_layer == 6)
+        {
+            model.type = e_model::MODEL_BASE;
+        }
+
+        if (hparams.n_audio_layer == 12)
+        {
+            model.type = e_model::MODEL_SMALL;
+        }
+
+        if (hparams.n_audio_layer == 24)
+        {
+            model.type = e_model::MODEL_MEDIUM;
+        }
+
+        if (hparams.n_audio_layer == 32)
+        {
+            model.type = e_model::MODEL_LARGE;
+
+            if (hparams.n_vocab == 51866)
+            {
+                mver = " v3";
+            }
+        }
+
+        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
+
+        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
+
+        // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+        // in order to save memory and also to speed up the computation
+        wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype));
+        if (wctx.wtype == GGML_TYPE_COUNT)
+        {
+            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
+            return false;
+        }
+
+        WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
+
WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype); + WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); + } + + // load mel filters + { + auto &filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fft); + + filters.data.resize(filters.n_mel * filters.n_fft); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + // if (n_vocab != model.hparams.n_vocab) { + // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + // } + + std::string word; + std::vector tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) + { + uint32_t len; + read_safe(loader, len); + + if (len > 0) + { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } + else + { + // seems like we have an empty-string token in multi-language models (i = 50256) + // WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + // printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + vocab.n_vocab = model.hparams.n_vocab; + if (vocab.is_multilingual()) + { + vocab.token_eot++; + vocab.token_sot++; + + // account for variable number of language tokens + const int dt = vocab.num_languages() - 98; + + vocab.token_translate += dt; + vocab.token_transcribe += dt; + vocab.token_solm += dt; + vocab.token_prev += dt; + vocab.token_nosp += dt; + vocab.token_not += dt; + vocab.token_beg += dt; + } + + if (n_vocab < model.hparams.n_vocab) + { + WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + for (int i = n_vocab; i < model.hparams.n_vocab; i++) + { + if (i > vocab.token_beg) + { + word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; + } + else if (i == vocab.token_eot) + { + word = "[_EOT_]"; + } + else if (i == vocab.token_sot) + { + word = "[_SOT_]"; + } + else if (i == vocab.token_translate) + { + word = "[_TRANSLATE_]"; + } + else if (i == vocab.token_transcribe) + { + word = "[_TRANSCRIBE_]"; + } + else if (i == vocab.token_solm) + { + word = "[_SOLM_]"; + } + else if (i == vocab.token_prev) + { + word = "[_PREV_]"; + } + else if (i == vocab.token_nosp) + { + word = "[_NOSP_]"; + } + else if (i == vocab.token_not) + { + word = "[_NOT_]"; + } + else if (i == vocab.token_beg) + { + word = "[_BEG_]"; + } + else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) + { + word = "[_LANG_" + std::string(whisper_lang_str(i - 
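+                    // Worked example of the multilingual shift above, for large-v3
+                    // (n_vocab = 51866): num_languages() = 51866 - 51765 - 1 = 100,
+                    // so dt = 100 - 98 = 2, and e.g. token_sot = 50257 + 1 = 50258,
+                    // token_transcribe = 50358 + 2 = 50360, token_beg = 50363 + 2 = 50365.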
vocab.token_sot - 1)) + "]"; + } + else + { + word = "[_extra_token_" + std::to_string(i) + "]"; + } + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages()); + } + + const ggml_type wtype = wctx.wtype; // ggml-medium.bin: GGML_TYPE_F16 + const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type + + // GGML_TYPE_F16 + // e_conv_1_w, e_conv_2_w + // attn_ln_1_w, attn_q_w, attn_k_w, attn_v_w + // mlp_0_w, mlp_1_w + + // create the ggml context + { + const auto &hparams = model.hparams; + + const int n_audio_layer = hparams.n_audio_layer; + const int n_text_layer = hparams.n_text_layer; + + const size_t n_tensors = 10 /* input */ + 15 + 15 * n_audio_layer + 24 * n_text_layer; + + struct ggml_init_params params = { + /*.mem_size =*/n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) + { + WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare tensors for the weights + { + auto &ctx = model.ctx; + + const auto &hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize(n_audio_layer); + model.layers_decoder.resize(n_text_layer); + + // encoder + { + model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + + model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.positional_embedding"] = model.e_pe; + + model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; + model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + + model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; + model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + + model.tensors["encoder.ln_post.weight"] = model.e_ln_w; + model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + + for (int i = 0; i < n_audio_layer; ++i) + { + auto &layer = model.layers_encoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4 * n_audio_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4 * n_audio_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4 * n_audio_state, n_audio_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 
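+                // Note on ggml dimension order: ne[0] is the innermost dimension,
+                // so e.g. e_conv_1_w above is created as [kernel = 3, n_mels, n_audio_state],
+                // the reverse of PyTorch's Conv1d weight shape (n_audio_state, n_mels, 3).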
n_audio_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); + + model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + + model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.positional_embedding"] = model.d_pe; + + model.tensors["decoder.token_embedding.weight"] = model.d_te; + + model.tensors["decoder.ln.weight"] = model.d_ln_w; + model.tensors["decoder.ln.bias"] = model.d_ln_b; + + for (int i = 0; i < n_text_layer; ++i) + { + auto &layer = model.layers_decoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4 * n_text_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4 * n_text_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4 * n_text_state, n_text_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; + } + } + } + + // allocate tensors in the backend buffers + model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params)); + if (!model.buffer) + { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__); + return false; + } + + size_t size_main = ggml_backend_buffer_get_size(model.buffer); + WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6); + + // load weights + { + size_t total_size = 0; + + model.n_loaded = 0; + + std::vector read_buf; + + while (true) + { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) + { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = {1, 1, 1, 1}; + for (int i = 0; i < n_dims; ++i) + { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + // skip the decoder weights + // if (name.find("decoder.") != std::string::npos) + // { + // size_t tensor_size = ggml_type_size(ggml_type(ttype)) * nelements; + // loader->seek(loader->context, tensor_size); // Skip tensor data + // WHISPER_LOG_INFO("%s: Skipping tensor: %s with shape [%d, %d, %d]\n", __func__, name.c_str(), ne[0], ne[1], ne[2]); + // continue; + // } + + // WHISPER_LOG_INFO("%s: Loading tensor: %s with shape [%d, %d, %d]\n", __func__, name.c_str(), ne[0], ne[1], ne[2]); + + if (model.tensors.find(name) == model.tensors.end()) + { + WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + + if (ggml_nelements(tensor) != nelements) + { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], 
expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) + { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) + { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements * bpe); + return false; + } + + // ggml_backend_t backend = wctx.backend; + + // printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str()); + + if (ggml_backend_buffer_is_host(model.buffer)) + { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } + else + { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + // printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6); + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size / 1e6); + + if (model.n_loaded == 0) + { + WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } + else if (model.n_loaded != (int)model.tensors.size()) + { + WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } + } + + ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +static bool whisper_encode_external(const whisper_state &wstate) +{ + GGML_UNUSED(wstate); + +#ifndef WHISPER_USE_COREML + const bool use_coreml = false; +#else + const bool use_coreml = wstate.ctx_coreml != nullptr; +#endif + +#ifndef WHISPER_USE_OPENVINO + const bool use_openvino = false; +#else + const bool use_openvino = wstate.ctx_openvino != nullptr; +#endif + + return use_coreml || use_openvino; +} + +static struct ggml_cgraph *whisper_build_graph_conv( + whisper_context &wctx, + whisper_state &wstate, + const int mel_offset) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
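+    // wstate.exp_n_audio_ctx (0 = use the model default, hparams.n_audio_ctx = 1500)
+    // can shrink the encoder context as a speed-up: e.g. a value of 512 makes the
+    // graphs below consume 2*512 = 1024 mel frames (~10.2 s at Whisper's 10 ms hop)
+    // instead of the full 30 s window.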
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + GGML_UNUSED(n_state); + + const int n_mels = hparams.n_mels; + + struct ggml_init_params params = { + /*.mem_size =*/wstate.sched_conv.meta.size(), + /*.mem_buffer =*/wstate.sched_conv.meta.data(), + /*.no_alloc =*/true, + }; + + struct ggml_context *ctx0 = ggml_init(params); + + ggml_cgraph *gf = ggml_new_graph(ctx0); + + GGML_ASSERT(wstate.mel.tensor); + + ggml_tensor *mel_inp = wstate.mel.tensor; + ggml_set_input(mel_inp); + + ggml_tensor *mel; + if (ggml_nelements(mel_inp) > 0) + { + const int n_len = int(mel_inp->ne[0]); + const int out_s = 2 * n_ctx; + const int i0 = std::min(mel_offset, n_len); + const int i1 = std::min(mel_offset + out_s, n_len); + const int mel_s = i1 - i0; + + assert(mel_inp->type == GGML_TYPE_F32); + assert(mel_inp->ne[1] == n_mels); + + ggml_tensor *cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0)); + + if (mel_s < out_s) + { + mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0); + } + else + { + mel = ggml_cont(ctx0, cur); + } + } + else + { + // empty mel - just create a dummy tensor with the correct size + mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels); + } + + ggml_set_name(mel, "mel"); + + struct ggml_tensor *cur = nullptr; + + if (!whisper_encode_external(wstate)) + { + // convolution + gelu + { + cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); + cur = ggml_add(ctx0, cur, model.e_conv_1_b); + + cur = ggml_gelu(ctx0, cur); + + cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); + cur = ggml_add(ctx0, cur, model.e_conv_2_b); + + cur = ggml_gelu(ctx0, cur); + } + + ggml_set_name(cur, "embd_conv"); + wstate.embd_conv = cur; + } + else + { + ggml_build_forward_expand(gf, mel); + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); + ggml_set_input(cur); // the external encoder will write into this tensor + + ggml_set_name(cur, "embd_enc"); + wstate.embd_enc = cur; + } + + ggml_set_output(cur); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph *whisper_build_graph_encoder( + whisper_context &wctx, + whisper_state &wstate) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
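+    // With the default n_audio_ctx = 1500, the conv stack above has already
+    // downsampled 2*1500 = 3000 mel frames (one 30 s window) by 2x via the
+    // stride-2 second convolution, so the encoder below runs self-attention
+    // over 1500 positions of width n_audio_state.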
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_state_head = n_state / n_head; + + auto &kv_pad = wstate.kv_pad; + + // WHISPER_ASSERT(!!kv_pad.ctx); // only used in flash-attn, commented out for now + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + + struct ggml_init_params params = { + /*.mem_size =*/wstate.sched_encode.meta.size(), + /*.mem_buffer =*/wstate.sched_encode.meta.data(), + /*.no_alloc =*/true, + }; + + struct ggml_context *ctx0 = ggml_init(params); + + ggml_cgraph *gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + + struct ggml_tensor *cur = ggml_view_tensor(ctx0, wstate.embd_conv); + + const float KQscale = 1.0f / sqrtf(float(n_state_head)); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + // static int iter = -1; + // const int n_iter = 1500/n_ctx; + + // iter = (iter + 1) % n_iter; + + // if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + // } + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0] * ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0] * ggml_element_size(model.e_pe) * n_ctx * iter; + + struct ggml_tensor *e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); + + // =================================================================== + + // original: + // cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor *inpL = cur; + + for (int il = 0; il < n_layer; ++il) + { + const auto &layer = model.layers_encoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); + + // Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); + + // note: no bias for Key + struct ggml_tensor *Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + // Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); + + struct ggml_tensor *Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b); + + // ------ + + struct ggml_tensor *Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + if (wctx.params.flash_attn) + { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx * n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx * n_state, 0))); + + struct ggml_tensor *K = + ggml_view_3d(ctx0, kv_pad.k, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k) * n_state, + ggml_element_size(kv_pad.k) * n_state_head, + 0); + + struct ggml_tensor *V = + ggml_view_3d(ctx0, kv_pad.v, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v) * n_state, + ggml_element_size(kv_pad.v) * n_state_head, + 0); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, 
n_state, n_ctx); + } + else + { + struct ggml_tensor *K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + struct ggml_tensor *V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)); + + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor *inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.mlp_ln_w), + layer.mlp_ln_b); + } + +#ifdef WHISPER_USE_FLASH_FF + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_1_b); +#endif + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.e_ln_w), + model.e_ln_b); + } + + ggml_build_forward_expand(gf, cur); + + wstate.embd_enc = cur; + + // ggml_graph_print(gf); + + //////////////////////////////////////////////////////////////////////////// + + // printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); + + ggml_free(ctx0); + + return gf; +} + +// pre-compute cross-attention memory +static struct ggml_cgraph *whisper_build_graph_cross( + whisper_context &wctx, + whisper_state &wstate) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
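+    // The cross-attention K/V computed by this graph depend only on the encoder
+    // output, so they are evaluated once per audio segment and then reused by
+    // every decoder call (stored in kv_cross, shared between all decoders),
+    // instead of being recomputed for each generated token.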
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + + const int n_state_head = n_state / n_head; + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + + struct ggml_init_params params = { + /*.mem_size =*/wstate.sched_cross.meta.size(), + /*.mem_buffer =*/wstate.sched_cross.meta.data(), + /*.no_alloc =*/true, + }; + + struct ggml_context *ctx0 = ggml_init(params); + + ggml_cgraph *gf = ggml_new_graph(ctx0); + + struct ggml_tensor *cur = ggml_view_tensor(ctx0, wstate.embd_enc); + + const float Kscale = pow(float(n_state_head), -0.25); + + for (int il = 0; il < model.hparams.n_text_layer; ++il) + { + auto &layer = model.layers_decoder[il]; + + struct ggml_tensor *Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale(ctx0, Kcross, Kscale); + + struct ggml_tensor *Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + Vcross, + layer.cross_attn_v_b); + + struct ggml_tensor *k; + struct ggml_tensor *v; + + if (wctx.params.flash_attn) + { + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state * n_ctx, + (ggml_element_size(wstate.kv_cross.k) * n_state) * (il * n_ctx_pad)); + + v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state * n_ctx, + (ggml_element_size(wstate.kv_cross.v) * n_state) * (il * n_ctx_pad)); + } + else + { + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state * n_ctx, + (ggml_element_size(wstate.kv_cross.k) * n_state) * (il * n_ctx)); + + v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + (n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il * n_ctx) * ggml_element_size(wstate.kv_cross.v) * n_state); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); + } + + // ggml_graph_print(gf); + + ggml_free(ctx0); + + return gf; +} + +// evaluate the encoder with the given state +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - wctx: the model +// - wstate: the state of the encoder +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. 
audio offset) +// +static bool whisper_encode_internal( + whisper_context &wctx, + whisper_state &wstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ + const int64_t t_start_us = ggml_time_us(); + + // conv + { + auto &sched = wstate.sched_conv.sched; + + ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { + return false; + } + + if (whisper_encode_external(wstate)) + { + ggml_tensor *mel = ggml_graph_get_tensor(gf, "mel"); + assert(mel->ne[1] == wctx.model.hparams.n_mels); + GGML_UNUSED(mel); +#if defined(WHISPER_USE_COREML) + whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *)mel->data, (float *)wstate.embd_enc->data); +#elif defined(WHISPER_USE_OPENVINO) + whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc); +#endif + } + } + + // encoder + if (!whisper_encode_external(wstate)) + { + auto &sched = wstate.sched_encode.sched; + + ggml_cgraph *gf = whisper_build_graph_encoder(wctx, wstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { + return false; + } + } + + // cross + { + auto &sched = wstate.sched_cross.sched; + + ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { + return false; + } + } + + wstate.t_encode_us += ggml_time_us() - t_start_us; + wstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static struct ggml_cgraph *whisper_build_graph_decoder( + whisper_context &wctx, + whisper_state &wstate, + const whisper_batch &batch, + bool save_alignment_heads_QKs, + bool worst_case) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; + + auto &kv_self = wstate.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + + const int n_ctx = kv_self.size; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int n_state_head = n_state / n_head; + + const int n_tokens = batch.n_tokens; + const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + + const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256); + + const int32_t n_kv = worst_case ? n_ctx : kv_self.n; + const int32_t kv_head = worst_case ? 
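+    // worst_case = true is used when the graph is built once for scheduler
+    // measurement: n_kv spans the whole cache and kv_head is placed at its end,
+    // so the reserved compute buffers cover the largest decode the cache allows.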
n_ctx - n_tokens : kv_self.head; + + // WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); + + struct ggml_init_params params = { + /*.mem_size =*/wstate.sched_decode.meta.size(), + /*.mem_buffer =*/wstate.sched_decode.meta.data(), + /*.no_alloc =*/true, + }; + + struct ggml_context *ctx0 = ggml_init(params); + + ggml_cgraph *gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + + struct ggml_tensor *embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(embd, "embd"); + ggml_set_input(embd); + + struct ggml_tensor *position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(position, "position"); + ggml_set_input(position); + + const float KQscale = pow(float(n_state_head), -0.25); + + struct ggml_tensor *KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_set_input(KQ_mask); + + struct ggml_tensor *KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + + // token encoding + position encoding + struct ggml_tensor *cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + + struct ggml_tensor *inpL = cur; + + // [EXPERIMENTAL] Token-level timestamps with DTW + struct ggml_tensor *aheads_cross_QKs = nullptr; + + for (int il = 0; il < n_layer; ++il) + { + const auto &layer = model.layers_decoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + Qcur, + layer.attn_q_b); + + Qcur = ggml_scale(ctx0, Qcur, KQscale); + + // note: no bias for Key + struct ggml_tensor *Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + Kcur = ggml_scale(ctx0, Kcur, KQscale); + + // store key and value to memory + { + struct ggml_tensor *Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + Vcur, + layer.attn_v_b); + + struct ggml_tensor *k; + struct ggml_tensor *v; + + if (wctx.params.flash_attn) + { + k = ggml_view_1d(ctx0, kv_self.k, n_tokens * n_state, + (ggml_element_size(kv_self.k) * n_state) * (il * n_ctx + kv_head)); + + v = ggml_view_1d(ctx0, kv_self.v, n_tokens * n_state, + (ggml_element_size(kv_self.v) * n_state) * (il * n_ctx + kv_head)); + } + else + { + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); + + k = ggml_view_1d(ctx0, kv_self.k, n_tokens * n_state, + (ggml_element_size(kv_self.k) * n_state) * (il * n_ctx + kv_head)); + + v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, + (n_ctx)*ggml_element_size(kv_self.v), + (il * n_ctx) * ggml_element_size(kv_self.v) * n_state + kv_head * ggml_element_size(kv_self.v)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + // ------ + + struct ggml_tensor *Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), + 0, 2, 1, 3); + + struct ggml_tensor *K = + ggml_view_3d(ctx0, kv_self.k, + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.k) * n_state, + ggml_element_size(kv_self.k) * n_state_head, + ggml_element_size(kv_self.k) * n_state * n_ctx * il); + + if (wctx.params.flash_attn) + { + struct ggml_tensor *V = + ggml_view_3d(ctx0, kv_self.v, + 
n_state_head, n_kv, n_head, + ggml_element_size(kv_self.v) * n_state, + ggml_element_size(kv_self.v) * n_state_head, + ggml_element_size(kv_self.v) * n_state * n_ctx * il); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } + else + { + // K * Q + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); + + struct ggml_tensor *V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_state_head, n_head, + n_ctx * ggml_element_size(kv_self.v), + n_ctx * ggml_element_size(kv_self.v) * n_state_head, + n_ctx * ggml_element_size(kv_self.v) * n_state * il); + + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.attn_ln_1_b); + } + + // add the input + struct ggml_tensor *inpCA = ggml_add(ctx0, cur, inpL); + + // norm + { + cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.cross_attn_ln_0_w), + layer.cross_attn_ln_0_b); + } + + // cross-attention + { + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + Qcur, + layer.cross_attn_q_b); + + struct ggml_tensor *Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), + 0, 2, 1, 3); + + if (wctx.params.flash_attn) + { + struct ggml_tensor *Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.k) * n_state, + ggml_element_size(wstate.kv_cross.k) * n_state_head, + ggml_element_size(wstate.kv_cross.k) * n_state * n_audio_ctx_pad * il); + + struct ggml_tensor *Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.v) * n_state, + ggml_element_size(wstate.kv_cross.v) * n_state_head, + ggml_element_size(wstate.kv_cross.v) * n_state * n_audio_ctx_pad * il); + + cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } + else + { + struct ggml_tensor *Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx, n_head, + ggml_element_size(wstate.kv_cross.k) * n_state, + ggml_element_size(wstate.kv_cross.k) * n_state_head, + ggml_element_size(wstate.kv_cross.k) * n_state * n_audio_ctx * il); + + struct ggml_tensor *Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_audio_ctx, n_state_head, n_head, + n_audio_ctx * ggml_element_size(wstate.kv_cross.v), + n_audio_ctx * ggml_element_size(wstate.kv_cross.v) * n_state_head, + n_audio_ctx * ggml_element_size(wstate.kv_cross.v) * n_state * il); + + // ------ + + // K * Q + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, Kcross, Q); + + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps) + { + if (wstate.aheads_masks.m[il] != nullptr) + { + struct ggml_tensor *aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); + aheads_KQs = ggml_transpose(ctx0, 
aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); + if (aheads_cross_QKs == NULL) + { + aheads_cross_QKs = aheads_KQs; + } + else + { + aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2); + } + } + } + + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); + + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.cross_attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.cross_attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpCA); + + struct ggml_tensor *inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.mlp_ln_w), + layer.mlp_ln_b); + } + + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.mlp_1_b); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + model.d_ln_w), + model.d_ln_b); + } + + // compute logits only for the last token + // comment this line to compute logits for all n_tokens + // might be useful in the future + // cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + + struct ggml_tensor *logits = ggml_mul_mat(ctx0, model.d_te, cur); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) + { + aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs); + aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs); + if (save_alignment_heads_QKs) + { + ggml_build_forward_expand(gf, aheads_cross_QKs); + wstate.aheads_cross_QKs = aheads_cross_QKs; + } + } + + ggml_build_forward_expand(gf, logits); + + ggml_free(ctx0); + + return gf; +} + +// evaluate the decoder +// +// given text prompt + audio features -> computes the logits for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode_internal( + whisper_context &wctx, + whisper_state &wstate, + const whisper_batch &batch, + const int n_threads, + bool save_alignment_heads_QKs, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ + const int64_t t_start_us = ggml_time_us(); + + const auto &model = wctx.model; + const auto &hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + const int n_tokens = batch.n_tokens; + + auto &logits_out = wstate.logits; + + struct ggml_tensor *logits; + + // find KV slot for the batch + { + auto &kv_self = wstate.kv_self; + + if (!whisper_kv_cache_find_slot(kv_self, batch)) + { + return false; + } + + const uint32_t pad = whisper_kv_cache_get_padding(wctx); + 
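+        // clarification: kv_self.n is how many KV cells the attention will scan this
+        // step - the highest used cell count, rounded up to the backend padding and
+        // clamped to the cache size. E.g. if pad were 256 and the highest used cell 300,
+        // the line below would yield min(kv_self.size, GGML_PAD(300, 256)) == min(kv_self.size, 512).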
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad))); + + // kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); + // printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); + } + + // decoder + { + auto &sched = wstate.sched_decode.sched; + + ggml_cgraph *gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { + // should never happen as we pre-allocate the memory + return false; + } + + // set the inputs + { + struct ggml_tensor *embd = ggml_graph_get_tensor(gf, "embd"); + ggml_backend_tensor_set(embd, batch.token, 0, n_tokens * ggml_element_size(embd)); + } + + { + struct ggml_tensor *position = ggml_graph_get_tensor(gf, "position"); + for (int i = 0; i < n_tokens; ++i) + { + const int32_t val = batch.pos[i]; + ggml_backend_tensor_set(position, &val, i * sizeof(int32_t), sizeof(int32_t)); + } + } + + { + struct ggml_tensor *KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); + + auto &kv_self = wstate.kv_self; + + const int32_t n_kv = kv_self.n; + + wstate.inp_mask.resize(ggml_nelements(KQ_mask)); + + float *data = wstate.inp_mask.data(); + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) + { + for (int j = 0; j < n_tokens; ++j) + { + const whisper_pos pos = batch.pos[j]; + const whisper_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) + { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) + { + data[h * (n_kv * n_tokens) + j * n_kv + i] = -INFINITY; + } + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) + { + for (int j = 0; j < n_kv; ++j) + { + data[h * (n_kv * n_tokens) + i * n_kv + j] = -INFINITY; + } + } + } + + ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask) * sizeof(float)); + } + + logits = ggml_graph_node(gf, -1); + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { + return false; + } + } + + logits_out.resize(n_tokens * n_vocab); + for (int i = 0; i < n_tokens; i++) + { + if (batch.logits[i] == 0) + { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab * i), sizeof(float) * (n_vocab * i), sizeof(float) * n_vocab); + } + + if (batch.n_tokens > 1) + { + // printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); + } + + if (batch.n_tokens == 1) + { + wstate.t_decode_us += ggml_time_us() - t_start_us; + wstate.n_decode++; + } + else if (batch.n_tokens < 16) + { + wstate.t_batchd_us += ggml_time_us() - t_start_us; + wstate.n_batchd += n_tokens; + } + else + { + wstate.t_prompt_us += ggml_time_us() - t_start_us; + wstate.n_prompt += n_tokens; + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +static std::string to_timestamp(int64_t t, bool comma = false) +{ + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, (int)sec, comma ? 
"," : ".", (int)msec); + + return std::string(buf); +} + +#define SIN_COS_N_COUNT WHISPER_N_FFT +namespace +{ + struct whisper_global_cache + { + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + float sin_vals[SIN_COS_N_COUNT]; + float cos_vals[SIN_COS_N_COUNT]; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + float hann_window[WHISPER_N_FFT]; + + whisper_global_cache() + { + fill_sin_cos_table(); + fill_hann_window(sizeof(hann_window) / sizeof(hann_window[0]), true, hann_window); + } + + void fill_sin_cos_table() + { + for (int i = 0; i < SIN_COS_N_COUNT; i++) + { + double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float *output) + { + int offset = -1; + if (periodic) + { + offset = 0; + } + for (int i = 0; i < length; i++) + { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } + } global_cache; +} + +// Mel spectrogram + +void whisper_mel_init(whisper_mel &mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) +{ + // WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel); + mel.n_len_org = n_len_org; + assert(!mel.ctx); + mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true}); + mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel); + mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend)); + auto alloc = ggml_tallocr_new(mel.buffer); + ggml_tallocr_alloc(&alloc, mel.tensor); +} + +void whisper_mel_free(whisper_mel &mel) +{ + ggml_free(mel.ctx); + ggml_backend_buffer_free(mel.buffer); + + mel.n_len_org = 0; + mel.ctx = nullptr; + mel.tensor = nullptr; + mel.buffer = nullptr; +} + +whisper_mel_calc::~whisper_mel_calc() = default; // export vtable + +whisper_span whisper_mel_calc::hann_window() +{ + return {global_cache.hann_window, WHISPER_N_FFT}; +} + +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float *in, int N, float *out) +{ + const int sin_cos_step = SIN_COS_N_COUNT / N; + + for (int k = 0; k < N; k++) + { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) + { + int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N + re += in[n] * global_cache.cos_vals[idx]; // cos(t) + im -= in[n] * global_cache.sin_vals[idx]; // sin(t) + } + + out[k * 2 + 0] = re; + out[k * 2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float *in, int N, float *out) +{ + if (N == 1) + { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N * 2 == 1) + { + dft(in, N, out); + return; + } + + float *even = in + N; + for (int i = 0; i < half_N; ++i) + { + even[i] = in[2 * i]; + } + float *even_fft = out + 2 * N; + fft(even, half_N, even_fft); + + float *odd = even; + for (int i = 0; i < half_N; ++i) + { + odd[i] = in[2 * i + 1]; + } + float *odd_fft = even_fft + N; + fft(odd, half_N, odd_fft); + + const int sin_cos_step = SIN_COS_N_COUNT / N; + for (int k = 0; k < half_N; k++) + { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = 
global_cache.cos_vals[idx]; // cos(t) + float im = -global_cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; + + out[2 * k + 0] = even_fft[2 * k + 0] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; + + out[2 * (k + half_N) + 0] = even_fft[2 * k + 0] - re * re_odd + im * im_odd; + out[2 * (k + half_N) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } +} + +namespace +{ + + struct whisper_mel_data + { + int n_len; + int n_len_org; + int n_mel; + float *data; + }; + + void log_mel_spectrogram_worker_thread(int ith, const float *hann, const std::vector &samples, + int n_samples, int n_threads, + const whisper_filters &filters, whisper_mel_data &mel) + { + const auto frame_size = WHISPER_N_FFT; + const auto frame_step = WHISPER_HOP_LENGTH; + std::vector fft_in(frame_size * 2, 0.0); + std::vector fft_out(frame_size * 2 * 2 * 2); + int n_fft = filters.n_fft; + int i = ith; + + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + assert(n_fft == 1 + (frame_size / 2)); + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) + { + const int offset = i * frame_step; + + // apply Hann window (~10% faster) + for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) + { + fft_in[j] = hann[j] * samples[offset + j]; + } + // fill the rest with zeros + if (n_samples - offset < frame_size) + { + std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); + } + + // FFT + fft(fft_in.data(), frame_size, fft_out.data()); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < n_fft; j++) + { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) + { + double sum = 0.0; + + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fft - 3; k += 4) + { + sum += + fft_out[k + 0] * filters.data[j * n_fft + k + 0] + + fft_out[k + 1] * filters.data[j * n_fft + k + 1] + + fft_out[k + 2] * filters.data[j * n_fft + k + 2] + + fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + } + + // handle n_fft remainder + for (; k < n_fft; k++) + { + sum += fft_out[k] * filters.data[j * n_fft + k]; + } + + sum = log10(std::max(sum, 1e-10)); + + mel.data[j * mel.n_len + i] = sum; + } + } + + // Otherwise fft_out are all zero + double sum = log10(1e-10); + for (; i < mel.n_len; i += n_threads) + { + for (int j = 0; j < mel.n_mel; j++) + { + mel.data[j * mel.n_len + i] = sum; + } + } + } + + struct mel_calc_cpu : public whisper_mel_calc + { + ggml_backend_t m_backend; + const whisper_filters &m_filters; + mel_calc_cpu(ggml_backend_t backend, const whisper_filters &filters) : m_backend(backend), m_filters(filters) {} + + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 + whisper_mel calculate(whisper_span ssamples, int n_threads) override + { + // Hann window + const float *hann = global_cache.hann_window; + + // Calculate the length of padding + int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; + int64_t stage_2_pad = WHISPER_N_FFT / 2; + + const int n_samples = int(ssamples.len); + const float *samples = ssamples.data; + + // Initialize a vector and copy data from C array to it. 
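+            // resulting layout (sketch): [ 200 reflected samples | n_samples of audio | 480,000 + 200 zeros ]
+            // where 200 == WHISPER_N_FFT / 2 (stage_2_pad) mirrors torch.stft(center=True),
+            // and the 30 s of zeros (stage_1_pad) match Whisper's fixed 30-second analysis window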
+            std::vector<float> samples_padded;
+            samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
+            std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
+
+            // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
+            std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
+
+            // reflective pad 200 samples at the beginning of audio
+            std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
+
+            whisper_mel_data mel;
+            mel.n_mel = m_filters.n_mel;
+            // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
+            // Calculate number of frames + remove the last frame
+            mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+            // Calculate semi-padded sample length to ensure compatibility
+            mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+
+            std::vector<float> host_mel_data;
+
+            whisper_mel ret;
+            whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+            if (ggml_backend_buffer_is_host(ret.buffer))
+            {
+                mel.data = reinterpret_cast<float *>(ret.tensor->data);
+            }
+            else
+            {
+                host_mel_data.resize(mel.n_len * mel.n_mel);
+                mel.data = host_mel_data.data();
+            }
+
+            {
+                std::vector<std::thread> workers(n_threads - 1);
+                for (int iw = 0; iw < n_threads - 1; ++iw)
+                {
+                    workers[iw] = std::thread(
+                        log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
+                        n_samples + stage_2_pad, n_threads,
+                        std::cref(m_filters), std::ref(mel));
+                }
+
+                // main thread
+                log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, n_threads, m_filters, mel);
+
+                for (int iw = 0; iw < n_threads - 1; ++iw)
+                {
+                    workers[iw].join();
+                }
+            }
+
+            // clamping and normalization
+            double mmax = -1e20;
+            for (int i = 0; i < mel.n_mel * mel.n_len; i++)
+            {
+                if (mel.data[i] > mmax)
+                {
+                    mmax = mel.data[i];
+                }
+            }
+
+            mmax -= 8.0;
+
+            for (int i = 0; i < mel.n_mel * mel.n_len; i++)
+            {
+                if (mel.data[i] < mmax)
+                {
+                    mel.data[i] = mmax;
+                }
+
+                mel.data[i] = (mel.data[i] + 4.0) / 4.0;
+            }
+
+            if (!host_mel_data.empty())
+            {
+                // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
+                ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
+            }
+
+            return ret;
+        }
+    };
+}
+
+static whisper_mel_calc *whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters &filters)
+{
+// TODO: disabled because it relies on ggml internals that are no longer accessible (ggml-backend-impl.h, ggml-cuda/common.cuh, ..)
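+// With the guard below compiled out (#if 0), this factory always falls back to the CPU
+// implementation. Typical usage (sketch, mirroring whisper_pcm_to_mel_with_state):
+//
+//     whisper_mel_calc * calc = whisper_mel_calc_create(backend, filters);
+//     whisper_mel mel = calc->calculate({ samples, n_samples }, n_threads);
+//     // ... use mel, then whisper_mel_free(mel) and delete calc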
+// #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
+#if 0
+    if (ggml_backend_is_cuda(backend)) {
+        auto ret = whisper_mel_calc_create_cuda(backend, filters);
+        if (ret) {
+            // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
+            const float warmup[256] = { 0 };
+            ret->calculate({ warmup, 256 }, 1);
+            return ret;
+        }
+    }
+#endif
+
+    // a specialized mel_calc could not be created
+    // fall back to CPU
+    return new mel_calc_cpu(backend, filters);
+}
+
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab &vocab, const std::string &text)
+{
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re))
+        {
+            for (auto x : m)
+            {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<whisper_vocab::id> tokens;
+    for (const auto &word : words)
+    {
+        if (word.empty())
+            continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n)
+        {
+            int j = n;
+            bool found = false;
+            while (j > i)
+            {
+                auto sub = word.substr(i, j - i);
+                auto it = vocab.token_to_id.find(sub);
+                if (it != vocab.token_to_id.end())
+                {
+                    tokens.push_back(it->second);
+                    i = j;
+                    found = true;
+                    break;
+                }
+                --j;
+            }
+            if (!found)
+            {
+                WHISPER_LOG_ERROR("unknown token\n");
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
+//
+// interface implementation
+//
+
+#ifdef WHISPER_USE_COREML
+// replace .bin with -encoder.mlmodelc
+static std::string whisper_get_coreml_path_encoder(std::string path_bin)
+{
+    auto pos = path_bin.rfind('.');
+    if (pos != std::string::npos)
+    {
+        path_bin = path_bin.substr(0, pos);
+    }
+
+    // match "-qx_x"
+    pos = path_bin.rfind('-');
+    if (pos != std::string::npos)
+    {
+        auto sub = path_bin.substr(pos);
+        if (sub.size() == 5 && sub[1] == 'q' && sub[3] == '_')
+        {
+            path_bin = path_bin.substr(0, pos);
+        }
+    }
+
+    path_bin += "-encoder.mlmodelc";
+
+    return path_bin;
+}
+#endif
+
+#ifdef WHISPER_USE_OPENVINO
+// replace .bin with -encoder-openvino.xml
+static std::string whisper_openvino_get_path_encoder(std::string path_bin)
+{
+    auto pos = path_bin.rfind('.');
+    if (pos != std::string::npos)
+    {
+        path_bin = path_bin.substr(0, pos);
+    }
+
+    path_bin += "-encoder-openvino.xml";
+
+    return path_bin;
+}
+
+static std::string whisper_openvino_get_path_cache(std::string path_bin)
+{
+    auto pos = path_bin.rfind('.');
+    if (pos != std::string::npos)
+    {
+        path_bin = path_bin.substr(0, pos);
+    }
+
+    path_bin += "-encoder-openvino-cache";
+
+    return path_bin;
+}
+#endif
+
+struct whisper_state *whisper_init_state(whisper_context *ctx)
+{
+    whisper_state *state = new whisper_state;
+
+    state->backends = whisper_backend_init(ctx->params);
+    if (state->backends.empty())
+    {
+        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
+    state->mel_calc = whisper_mel_calc_create(state->backends[0], ctx->model.filters);
+
+    // init 60s of random mel data
+    {
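+        // 100 mel frames cover 1 s of audio (16 kHz sample rate / 160-sample hop),
+        // so 2 * 100 * WHISPER_CHUNK_SIZE (30 s) = 6000 frames, i.e. ~60 s of mel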
const int n_len = 2 * 100 * WHISPER_CHUNK_SIZE; + const int n_mel = ctx->model.filters.n_mel; + + whisper_mel_free(state->mel); + whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel); + } + + // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx + // in theory, there can be a case where this is not enough, but in practice it should always be enough + const int factor = 3; + + if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256) * factor)) + { + WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v); + WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); + } + + if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) + { + WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); + WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); + } + + if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype, + ctx->model.hparams.n_audio_state, + 1, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) + { + WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v); + WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6); + } + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (ctx->params.dtw_token_timestamps) + { + if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) + { + WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__); + whisper_free_state(state); + return nullptr; + } + const size_t memory_size = aheads_masks_nbytes(state->aheads_masks); + WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size); + } + +#ifdef WHISPER_USE_COREML + const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); + + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); + + state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); + if (!state->ctx_coreml) + { + WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); +#ifndef WHISPER_COREML_ALLOW_FALLBACK + whisper_free_state(state); + return nullptr; +#endif + } + else + { + WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); + } +#endif + + state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); + + state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS); + + // TAGS: WHISPER_DECODER_INIT + state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); + + state->decoders[0].probs.reserve(ctx->vocab.n_vocab); + 
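+    // (logits_id, reserved just below, doubles as the per-language score buffer in
+    // whisper_lang_auto_detect_with_state)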
state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); + state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab); + + state->decoders[0].rng = std::mt19937(0); + + // conv allocator + { + bool ok = whisper_sched_graph_init(state->sched_conv, state->backends, + [&]() + { + return whisper_build_graph_conv(*ctx, *state, 0); + }); + + if (!ok) + { + WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6); + } + + // encoder allocator + if (!whisper_encode_external(*state)) + { + bool ok = whisper_sched_graph_init(state->sched_encode, state->backends, + [&]() + { + return whisper_build_graph_encoder(*ctx, *state); + }); + + if (!ok) + { + WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6); + } + + // cross allocator + { + bool ok = whisper_sched_graph_init(state->sched_cross, state->backends, + [&]() + { + return whisper_build_graph_cross(*ctx, *state); + }); + + if (!ok) + { + WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6); + } + + // decoder allocator + { + bool ok = whisper_sched_graph_init(state->sched_decode, state->backends, + [&]() + { + const auto &hparams = ctx->model.hparams; + + // TODO: make sure this is the worst-case scenario + const int n_tokens = hparams.n_text_ctx; + const int n_past = 0; + + whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0); + + return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true); + }); + + if (!ok) + { + WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +int whisper_ctx_init_openvino_encoder( + struct whisper_context *ctx, + const char *model_path, + const char *device, + const char *cache_dir) +{ +#ifndef WHISPER_USE_OPENVINO + (void)(ctx); + (void)(model_path); + (void)(device); + (void)(cache_dir); + + return 1; +#else + if (!model_path && ctx->path_model.empty()) + { + WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); + return 1; + } + + std::string path_encoder; + if (!model_path) + { + // if model_path is not set, attempt to find it in the same directory as ggml-.bin model + path_encoder = whisper_openvino_get_path_encoder(ctx->path_model); + } + else + { + path_encoder = model_path; + } + + std::string path_cache; + if (!cache_dir) + { + // if cache_dir is not set, set it as a dir residing next to ggml-.bin + path_cache = whisper_openvino_get_path_cache(ctx->path_model); + } + else + { + path_cache = cache_dir; + } + + WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); + + ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str()); + if 
(!ctx->state->ctx_openvino) + { + WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); + return 1; + } + else + { + WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__); + } + + return 0; +#endif +} + +struct whisper_context_params whisper_context_default_params() +{ + struct whisper_context_params result = { + /*.use_gpu =*/true, + /*.flash_attn =*/false, + /*.gpu_device =*/0, + + /*.dtw_token_timestamps =*/false, + /*.dtw_aheads_preset =*/WHISPER_AHEADS_NONE, + /*.dtw_n_top =*/-1, + /*.dtw_aheads =*/{ + /*.n_heads =*/0, + /*.heads =*/NULL, + }, + /*.dtw_mem_size =*/1024 * 1024 * 128, + }; + return result; +} + +struct whisper_context *whisper_init_from_file_with_params_no_state(const char *path_model, struct whisper_context_params params) +{ + WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) + { + WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + whisper_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void *ctx, void *output, size_t read_size) + { + std::ifstream *fin = (std::ifstream *)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.seek = [](void *ctx, size_t offset) + { + std::ifstream *fin = (std::ifstream *)ctx; + fin->seekg(offset, std::ios::cur); + }; + + loader.eof = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; + return fin->eof(); + }; + + loader.close = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; + fin->close(); + }; + + auto ctx = whisper_init_with_params_no_state(&loader, params); + + if (ctx) + { + ctx->path_model = path_model; + } + + return ctx; +} + +struct whisper_context *whisper_init_from_buffer_with_params_no_state(void *buffer, size_t buffer_size, struct whisper_context_params params) +{ + struct buf_context + { + uint8_t *buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = {reinterpret_cast(buffer), buffer_size, 0}; + + WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__); + + whisper_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void *ctx, void *output, size_t read_size) + { + buf_context *buf = reinterpret_cast(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? 
read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void *ctx) + { + buf_context *buf = reinterpret_cast(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) {}; + + return whisper_init_with_params_no_state(&loader, params); +} + +struct whisper_context *whisper_init_with_params_no_state(struct whisper_model_loader *loader, struct whisper_context_params params) +{ + ggml_time_init(); + + if (params.flash_attn && params.dtw_token_timestamps) + { + WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__); + params.dtw_token_timestamps = false; + } + + WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn); + WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps); + + whisper_context *ctx = new whisper_context; + ctx->params = params; + + if (!whisper_model_load(loader, *ctx)) + { + loader->close(loader->context); + WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + return ctx; +} + +struct whisper_context *whisper_init_from_file_with_params(const char *path_model, struct whisper_context_params params) +{ + whisper_context *ctx = whisper_init_from_file_with_params_no_state(path_model, params); + if (!ctx) + { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) + { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context *whisper_init_from_buffer_with_params(void *buffer, size_t buffer_size, struct whisper_context_params params) +{ + whisper_context *ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) + { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) + { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context *whisper_init_with_params(struct whisper_model_loader *loader, struct whisper_context_params params) +{ + whisper_context *ctx = whisper_init_with_params_no_state(loader, params); + if (!ctx) + { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) + { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context *whisper_init_from_file(const char *path_model) +{ + return whisper_init_from_file_with_params(path_model, whisper_context_default_params()); +} + +struct whisper_context *whisper_init_from_buffer(void *buffer, size_t buffer_size) +{ + return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context *whisper_init(struct whisper_model_loader *loader) +{ + return whisper_init_with_params(loader, whisper_context_default_params()); +} + +struct whisper_context *whisper_init_from_file_no_state(const char *path_model) +{ + return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params()); +} + +struct whisper_context *whisper_init_from_buffer_no_state(void *buffer, size_t buffer_size) +{ + return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context *whisper_init_no_state(struct whisper_model_loader 
*loader)
+{
+    return whisper_init_with_params_no_state(loader, whisper_context_default_params());
+}
+
+void whisper_free_state(struct whisper_state *state)
+{
+    if (state)
+    {
+        whisper_kv_cache_free(state->kv_self);
+        whisper_kv_cache_free(state->kv_cross);
+        whisper_kv_cache_free(state->kv_pad);
+
+        whisper_mel_free(state->mel);
+
+        delete state->mel_calc;
+        state->mel_calc = nullptr;
+        delete state->mel_calc_fallback;
+        state->mel_calc_fallback = nullptr;
+
+#ifdef WHISPER_USE_COREML
+        if (state->ctx_coreml != nullptr)
+        {
+            whisper_coreml_free(state->ctx_coreml);
+            state->ctx_coreml = nullptr;
+        }
+#endif
+
+#ifdef WHISPER_USE_OPENVINO
+        if (state->ctx_openvino != nullptr)
+        {
+            whisper_openvino_free(state->ctx_openvino);
+            state->ctx_openvino = nullptr;
+        }
+#endif
+
+        whisper_batch_free(state->batch);
+
+        ggml_backend_sched_free(state->sched_conv.sched);
+        ggml_backend_sched_free(state->sched_encode.sched);
+        ggml_backend_sched_free(state->sched_cross.sched);
+        ggml_backend_sched_free(state->sched_decode.sched);
+
+        for (auto &backend : state->backends)
+        {
+            ggml_backend_free(backend);
+        }
+
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        aheads_masks_free(state->aheads_masks);
+
+        delete state;
+    }
+}
+
+void whisper_free(struct whisper_context *ctx)
+{
+    if (ctx)
+    {
+        ggml_free(ctx->model.ctx);
+
+        ggml_backend_buffer_free(ctx->model.buffer);
+
+        whisper_free_state(ctx->state);
+
+        delete ctx;
+    }
+}
+
+void whisper_free_context_params(struct whisper_context_params *params)
+{
+    if (params)
+    {
+        delete params;
+    }
+}
+
+void whisper_free_params(struct whisper_full_params *params)
+{
+    if (params)
+    {
+        delete params;
+    }
+}
+
+int whisper_pcm_to_mel_with_state(struct whisper_context *ctx, struct whisper_state *state, const float *samples, int n_samples, int n_threads)
+{
+    const int64_t t_start_us = ggml_time_us();
+
+    whisper_mel_free(state->mel);
+    if (n_samples <= 5 * 60 * WHISPER_SAMPLE_RATE)
+    {
+        // calculate mel spectrogram for lengths up to 5 minutes on the optimal mel calculator
+        state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
+    }
+    else
+    {
+        // calculate mel spectrogram for longer audios on the CPU
+        // 1. gpu calculations may use hundreds of megabytes of memory for longer audios so we're being conservative
+        //    with our gpu demands
+        // 2. 
the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation + // taking longer is not a major concern + if (!state->mel_calc_fallback) + { + state->mel_calc_fallback = new mel_calc_cpu(state->backends[0], ctx->model.filters); + } + state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads); + } + + state->t_mel_us += ggml_time_us() - t_start_us; + + // Dump log_mel_spectrogram + //{ + // auto& mel = state->mel; + // std::ofstream outFile("log_mel_spectrogram.json"); + // outFile << "["; + // for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + // outFile << mel.data[i] << ", "; + // } + // outFile << mel.data[mel.data.size() - 1] << "]"; + // outFile.close(); + //} + return 0; +} + +int whisper_pcm_to_mel(struct whisper_context *ctx, const float *samples, int n_samples, int n_threads) +{ + return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int whisper_set_mel_with_state( + struct whisper_context *ctx, + struct whisper_state *state, + const float *data, + int n_len, + int n_mel) +{ + if (n_mel != ctx->model.filters.n_mel) + { + WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + whisper_mel_free(state->mel); + whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel); + + ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor)); + + return 0; +} + +int whisper_set_mel( + struct whisper_context *ctx, + const float *data, + int n_len, + int n_mel) +{ + return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int whisper_encode_with_state(struct whisper_context *ctx, struct whisper_state *state, int offset, int n_threads) +{ + if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) + { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_encode(struct whisper_context *ctx, int offset, int n_threads) +{ + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) + { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_decode_with_state(struct whisper_context *ctx, struct whisper_state *state, const whisper_token *tokens, int n_tokens, int n_past, int n_threads) +{ + whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); + + whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1); + + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) + { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +int whisper_decode(struct whisper_context *ctx, const whisper_token *tokens, int n_tokens, int n_past, int n_threads) +{ + if (ctx->state == nullptr) + { + WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); + return -1; + } + + return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads); +} + +int whisper_tokenize(struct whisper_context *ctx, const char *text, whisper_token *tokens, int n_max_tokens) +{ + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int)res.size()) + { + WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int)res.size(), n_max_tokens); + return -(int)res.size(); + } + + for (int i = 0; i < (int)res.size(); i++) + { + tokens[i] = res[i]; + } + + return res.size(); +} + +int whisper_token_count(struct 
whisper_context *ctx, const char *text) +{ + return -whisper_tokenize(ctx, text, NULL, 0); +} + +int whisper_lang_max_id(void) +{ + auto max_id = 0; + for (const auto &kv : g_lang) + { + max_id = std::max(max_id, kv.second.first); + } + + return max_id; +} + +int whisper_lang_id(const char *lang) +{ + if (!g_lang.count(lang)) + { + for (const auto &kv : g_lang) + { + if (kv.second.second == lang) + { + return kv.second.first; + } + } + + WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang); + return -1; + } + return g_lang.at(lang).first; +} + +const char *whisper_lang_str(int id) +{ + for (const auto &kv : g_lang) + { + if (kv.second.first == id) + { + return kv.first.c_str(); + } + } + + WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id); + return nullptr; +} + +const char *whisper_lang_str_full(int id) +{ + for (const auto &kv : g_lang) + { + if (kv.second.first == id) + { + return kv.second.second.c_str(); + } + } + + WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id); + return nullptr; +} + +int whisper_lang_auto_detect_with_state( + struct whisper_context *ctx, + struct whisper_state *state, + int offset_ms, + int n_threads, + float *lang_probs) +{ + const int seek = offset_ms / 10; + + if (seek < 0) + { + WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + return -1; + } + + if (seek >= state->mel.n_len_org) + { + WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org * 10); + return -2; + } + + // run the encoder + if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) + { + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const std::vector prompt = {whisper_token_sot(ctx)}; + + if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) + { + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + auto &logits_id = state->decoders[0].logits_id; + logits_id.clear(); + + for (const auto &kv : g_lang) + { + const auto token_lang = whisper_token_lang(ctx, kv.second.first); + logits_id.emplace_back(state->logits[token_lang], kv.second.first); + } + + // sort descending + { + using pair_type = std::remove_reference::type::value_type; + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type &a, const pair_type &b) + { return a.first > b.first; }); + } + + // softmax + { + const auto max = logits_id[0].first; + + double sum = 0.0f; + for (auto &kv : logits_id) + { + kv.first = exp(kv.first - max); + sum += kv.first; + } + + for (auto &kv : logits_id) + { + kv.first /= sum; + } + } + + { + for (const auto &prob : logits_id) + { + if (lang_probs) + { + lang_probs[prob.second] = prob.first; + } + + // printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); + } + } + + return logits_id[0].second; +} + +int whisper_lang_auto_detect( + struct whisper_context *ctx, + int offset_ms, + int n_threads, + float *lang_probs) +{ + return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); +} + +int whisper_model_n_vocab(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_vocab; +} + +int whisper_model_n_audio_ctx(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_model_n_audio_state(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_audio_state; +} + +int whisper_model_n_audio_head(struct whisper_context *ctx) +{ + return 
ctx->model.hparams.n_audio_head; +} + +int whisper_model_n_audio_layer(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_audio_layer; +} + +int whisper_model_n_text_ctx(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_text_ctx; +} + +int whisper_model_n_text_state(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_text_state; +} + +int whisper_model_n_text_head(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_text_head; +} + +int whisper_model_n_text_layer(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_text_layer; +} + +int whisper_model_n_mels(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_mels; +} + +int whisper_model_ftype(struct whisper_context *ctx) +{ + return ctx->model.hparams.ftype; +} + +int whisper_model_type(struct whisper_context *ctx) +{ + return ctx->model.type; +} + +const char *whisper_model_type_readable(struct whisper_context *ctx) +{ + switch (ctx->model.type) + { + case e_model::MODEL_TINY: + return "tiny"; + case e_model::MODEL_BASE: + return "base"; + case e_model::MODEL_SMALL: + return "small"; + case e_model::MODEL_MEDIUM: + return "medium"; + case e_model::MODEL_LARGE: + return "large"; + default: + return "unknown"; + } +} + +int whisper_n_len_from_state(struct whisper_state *state) +{ + return state->mel.n_len_org; +} + +int whisper_n_len(struct whisper_context *ctx) +{ + return ctx->state->mel.n_len_org; +} + +int whisper_n_vocab(struct whisper_context *ctx) +{ + return ctx->vocab.n_vocab; +} + +int whisper_n_text_ctx(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_text_ctx; +} + +int whisper_n_audio_ctx(struct whisper_context *ctx) +{ + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_is_multilingual(struct whisper_context *ctx) +{ + return ctx->vocab.is_multilingual() ? 
1 : 0; +} + +float *whisper_get_logits(struct whisper_context *ctx) +{ + return ctx->state->logits.data(); +} + +float *whisper_get_logits_from_state(struct whisper_state *state) +{ + return state->logits.data(); +} + +const char *whisper_token_to_str(struct whisper_context *ctx, whisper_token token) +{ + return ctx->vocab.id_to_token.at(token).c_str(); +} + +whisper_token whisper_token_eot(struct whisper_context *ctx) +{ + return ctx->vocab.token_eot; +} + +whisper_token whisper_token_sot(struct whisper_context *ctx) +{ + return ctx->vocab.token_sot; +} + +whisper_token whisper_token_solm(struct whisper_context *ctx) +{ + return ctx->vocab.token_solm; +} + +whisper_token whisper_token_prev(struct whisper_context *ctx) +{ + return ctx->vocab.token_prev; +} + +whisper_token whisper_token_nosp(struct whisper_context *ctx) +{ + return ctx->vocab.token_nosp; +} + +whisper_token whisper_token_not(struct whisper_context *ctx) +{ + return ctx->vocab.token_not; +} + +whisper_token whisper_token_beg(struct whisper_context *ctx) +{ + return ctx->vocab.token_beg; +} + +whisper_token whisper_token_lang(struct whisper_context *ctx, int lang_id) +{ + return whisper_token_sot(ctx) + 1 + lang_id; +} + +whisper_token whisper_token_translate(struct whisper_context *ctx) +{ + return ctx->vocab.token_translate; +} + +whisper_token whisper_token_transcribe(struct whisper_context *ctx) +{ + return ctx->vocab.token_transcribe; +} + +void whisper_print_timings(struct whisper_context *ctx) +{ + const int64_t t_end_us = ggml_time_us(); + + WHISPER_LOG_INFO("\n"); + WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) + { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_batchd = std::max(1, ctx->state->n_batchd); + const int32_t n_prompt = std::max(1, ctx->state->n_prompt); + + WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); + WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); + } + WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us) / 1000.0f); +} + +void whisper_reset_timings(struct whisper_context *ctx) +{ + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) + { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_batchd_us = 0; + ctx->state->t_prompt_us = 0; + ctx->state->n_sample = 0; + 
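+        // these counters feed the per-stage averages reported by whisper_print_timings()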
ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_batchd = 0; + ctx->state->n_prompt = 0; + } +} + +static int whisper_has_coreml(void) +{ +#ifdef WHISPER_USE_COREML + return 1; +#else + return 0; +#endif +} + +static int whisper_has_openvino(void) +{ +#ifdef WHISPER_USE_OPENVINO + return 1; +#else + return 0; +#endif +} + +const char *whisper_print_system_info(void) +{ + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | "; + s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; + s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; + s += "CANN = " + std::to_string(ggml_cpu_has_cann()); + return s.c_str(); +} + +////////////////////////////////// +// Grammar - ported from llama.cpp +////////////////////////////////// + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`. 
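+// Illustrative example (not from the original source): decoding the first two bytes of
+// the 3-byte encoding of U+20AC '€' (0xE2 0x82 0xAC):
+//
+//     auto r = decode_utf8("\xE2\x82", {0, 0});
+//     // r.first  == {0}         - no complete code point yet, only the terminating 0
+//     // r.second == {0x82, 1}   - accumulated bits so far, 1 continuation byte missing
+//
+// passing r.second as partial_start to a later call resumes decoding mid-codepoint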
+static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+    const char *src,
+    whisper_partial_utf8 partial_start)
+{
+    static const int lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4};
+    const char *pos = src;
+    std::vector<uint32_t> code_points;
+    uint32_t value = partial_start.value;
+    int n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0)
+    {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2)
+        {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{0, -1});
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0)
+    {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0)
+    {
+        uint8_t first_byte = static_cast<uint8_t>(*pos);
+        uint8_t highbits = first_byte >> 4;
+        n_remain = lookup[highbits] - 1;
+
+        if (n_remain < 0)
+        {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{0, n_remain});
+        }
+
+        uint8_t mask = (1 << (7 - n_remain)) - 1;
+        value = first_byte & mask;
+        ++pos;
+        while (*pos != 0 && n_remain > 0)
+        {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0)
+        {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), whisper_partial_utf8{value, n_remain});
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element *pos)
+{
+    switch (pos->type)
+    {
+    case WHISPER_GRETYPE_END:
+        return true; // NOLINT
+    case WHISPER_GRETYPE_ALT:
+        return true; // NOLINT
+    default:
+        return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+    const whisper_grammar_element *pos,
+    const uint32_t chr)
+{
+
+    bool found = false;
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+    do
+    {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER)
+        {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        }
+        else
+        {
+            // exact char match, e.g. 
[a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool whisper_grammar_match_partial_char( + const whisper_grammar_element *pos, + const whisper_partial_utf8 partial_utf8) +{ + + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) + { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) + { + if (n_remain == 2) + { + low = 1 << 11; + } + else if (n_remain == 3) + { + low = 1 << 16; + } + } + + do + { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) + { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) + { + return is_positive_char; + } + pos += 2; + } + else + { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) + { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void whisper_grammar_advance_stack( + const std::vector> &rules, + const std::vector &stack, + std::vector> &new_stacks) +{ + + if (stack.empty()) + { + new_stacks.push_back(stack); + return; + } + + const whisper_grammar_element *pos = stack.back(); + + switch (pos->type) + { + case WHISPER_GRETYPE_RULE_REF: + { + const size_t rule_id = static_cast(pos->value); + const whisper_grammar_element *subpos = rules[rule_id].data(); + do + { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos + 1)) + { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!whisper_grammar_is_end_of_sequence(subpos)) + { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + while (!whisper_grammar_is_end_of_sequence(subpos)) + { + // scan to end of alternate def + subpos++; + } + if (subpos->type == WHISPER_GRETYPE_ALT) + { + // there's another alternate def of this rule to process + subpos++; + } + else + { + break; + } + } while (true); + break; + } + case WHISPER_GRETYPE_CHAR: + case WHISPER_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range + // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + WHISPER_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `whisper_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static 
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+    const std::vector<std::vector<whisper_grammar_element>> &rules,
+    const std::vector<std::vector<const whisper_grammar_element *>> &stacks,
+    const uint32_t chr)
+{
+
+    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+    for (const auto &stack : stacks)
+    {
+        if (stack.empty())
+        {
+            continue;
+        }
+
+        auto match = whisper_grammar_match_char(stack.back(), chr);
+        if (match.first)
+        {
+            const whisper_grammar_element *pos = match.second;
+
+            // update top of stack to next element, if any
+            std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+            if (!whisper_grammar_is_end_of_sequence(pos))
+            {
+                new_stack.push_back(pos);
+            }
+            whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+
+    return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+    const std::vector<std::vector<whisper_grammar_element>> &rules,
+    const std::vector<std::vector<const whisper_grammar_element *>> &stacks,
+    const std::vector<whisper_grammar_candidate> &candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+    const std::vector<std::vector<whisper_grammar_element>> &rules,
+    const std::vector<const whisper_grammar_element *> &stack,
+    const std::vector<whisper_grammar_candidate> &candidates)
+{
+
+    std::vector<whisper_grammar_candidate> rejects;
+
+    if (stack.empty())
+    {
+        for (auto tok : candidates)
+        {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0)
+            {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const whisper_grammar_element *stack_pos = stack.back();
+
+    std::vector<whisper_grammar_candidate> next_candidates;
+    for (auto tok : candidates)
+    {
+        if (*tok.code_points == 0)
+        {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8))
+            {
+                rejects.push_back(tok);
+            }
+        }
+        else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first)
+        {
+            next_candidates.push_back({tok.id, tok.code_points + 1, tok.partial_utf8});
+        }
+        else
+        {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto *stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!whisper_grammar_is_end_of_sequence(stack_pos_after))
+    {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+    whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects)
+    {
+        rejects.push_back({tok.id, tok.code_points - 1, tok.partial_utf8});
+    }
+
+    return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+    const std::vector<std::vector<whisper_grammar_element>> &rules,
+    const std::vector<std::vector<const whisper_grammar_element *>> &stacks,
+    const std::vector<whisper_grammar_candidate> &candidates)
+{
+    if (candidates.empty() || stacks.empty())
+    {
+        return std::vector<whisper_grammar_candidate>();
+    }
+
+    auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i)
+    {
+        rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+static struct whisper_grammar whisper_grammar_init(
+    const whisper_grammar_element **rules,
+    size_t n_rules,
+    size_t i_start_rule)
+{
+    const whisper_grammar_element *pos;
+
+    // copy rule definitions into vectors
+    std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++)
+    {
+        for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++)
+        {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+    pos = rules[i_start_rule];
+    do
+    {
+        std::vector<const whisper_grammar_element *> stack;
+        if (!whisper_grammar_is_end_of_sequence(pos))
+        {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        whisper_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!whisper_grammar_is_end_of_sequence(pos))
+        {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == WHISPER_GRETYPE_ALT)
+        {
+            // there's another alternate def of this rule to process
+            pos++;
+        }
+        else
+        {
+            break;
+        }
+    } while (true);
+
+    return {std::move(vec_rules), std::move(stacks), {}};
+}
+
+static void whisper_suppress_invalid_grammar(
+    whisper_context &ctx,
+    const whisper_full_params &params,
+    std::vector<float> &logits,
+    const whisper_grammar &grammar)
+{
+
+    if (grammar.rules.empty() || grammar.stacks.empty())
+    {
+        return;
+    }
+
+    // bool allow_eot = false;
+    // for (const auto & stack : grammar.stacks) {
+    //     if (stack.empty()) {
+    //         allow_eot = true;
+    //         break;
+    //     }
+    // }
+
+    const whisper_token eot = whisper_token_eot(&ctx);
+
+    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+    std::vector<whisper_grammar_candidate> candidates_grammar;
+
+    for (whisper_token id = 0; id < eot; ++id)
+    {
+        const std::string &text = ctx.vocab.id_to_token[id];
+        if (!text.empty())
+        {
+            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+            candidates_grammar.push_back({id, candidates_decoded.back().first.data(), candidates_decoded.back().second});
+        }
+    }
+
+    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+    for (const auto &reject : rejects)
+    {
+        logits[reject.id] -= params.grammar_penalty;
+    }
+
+    // when the grammar allows a continuation, we penalize the end-of-text token
+    // if (!allow_eot) {
+    //     logits[eot] -= params.grammar_penalty;
+    //}
+    // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
+
+static void whisper_grammar_accept_token(whisper_context &ctx, whisper_grammar &grammar, whisper_token token)
+{
+    if (grammar.rules.empty() || grammar.stacks.empty())
+    {
+        return;
+    }
+
+    // fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+    const std::string &text = ctx.vocab.id_to_token[token];
+
+    if (text.rfind("[_", 0) == 0)
+    {
+        // fprintf(stderr, " (skipped)\n");
+        return;
+    }
+    // fprintf(stderr, "\n");
+
+    // Note terminating 0 in decoded string
+    const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8);
+    const auto &code_points = decoded.first;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it)
+    {
+        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+    }
+    grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
+////////////////////////////////////////////////////////////////////////////
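+// How the grammar pieces above fit together during decoding (editor's
+// summary of the call sites later in this file): whisper_grammar_init()
+// builds the initial stacks from the start rule, whisper_process_logits()
+// calls whisper_suppress_invalid_grammar() to penalize the logits of tokens
+// rejected by the current stacks before sampling, and
+// whisper_grammar_accept_token() advances the stacks with each token that
+// is actually sampled.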
+struct whisper_context_params *whisper_context_default_params_by_ref(void)
+{
+    struct whisper_context_params params = whisper_context_default_params();
+
+    struct whisper_context_params *result = new whisper_context_params();
+    *result = params;
+    return result;
+}
+
+struct whisper_full_params *whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy)
+{
+    struct whisper_full_params params = whisper_full_default_params(strategy);
+
+    struct whisper_full_params *result = new whisper_full_params();
+    *result = params;
+    return result;
+}
+
+struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy)
+{
+    struct whisper_full_params result = {
+        /*.strategy =*/strategy,
+
+        /*.n_threads =*/std::min(4, (int32_t)std::thread::hardware_concurrency()),
+        /*.n_max_text_ctx =*/16384,
+        /*.offset_ms =*/0,
+        /*.duration_ms =*/0,
+
+        /*.translate =*/false,
+        /*.no_context =*/true,
+        /*.no_timestamps =*/false,
+        /*.single_segment =*/false,
+        /*.print_special =*/false,
+        /*.print_progress =*/true,
+        /*.print_realtime =*/false,
+        /*.print_timestamps =*/true,
+
+        /*.token_timestamps =*/false,
+        /*.thold_pt =*/0.01f,
+        /*.thold_ptsum =*/0.01f,
+        /*.max_len =*/0,
+        /*.split_on_word =*/false,
+        /*.max_tokens =*/0,
+
+        /*.debug_mode =*/false,
+        /*.audio_ctx =*/0,
+
+        /*.tdrz_enable =*/false,
+
+        /*.suppress_regex =*/nullptr,
+
+        /*.initial_prompt =*/nullptr,
+        /*.prompt_tokens =*/nullptr,
+        /*.prompt_n_tokens =*/0,
+
+        /*.language =*/"en",
+        /*.detect_language =*/false,
+
+        /*.suppress_blank =*/true,
+        /*.suppress_non_speech_tokens =*/false,
+
+        /*.temperature =*/0.0f,
+        /*.max_initial_ts =*/1.0f,
+        /*.length_penalty =*/-1.0f,
+
+        /*.temperature_inc =*/0.2f,
+        /*.entropy_thold =*/2.4f,
+        /*.logprob_thold =*/-1.0f,
+        /*.no_speech_thold =*/0.6f,
+
+        /*.greedy =*/{
+            /*.best_of =*/-1,
+        },
+
+        /*.beam_search =*/{
+            /*.beam_size =*/-1,
+
+            /*.patience =*/-1.0f,
+        },
+
+        /*.new_segment_callback =*/nullptr,
+        /*.new_segment_callback_user_data =*/nullptr,
+
+        /*.progress_callback =*/nullptr,
+        /*.progress_callback_user_data =*/nullptr,
+
+        /*.encoder_begin_callback =*/nullptr,
+        /*.encoder_begin_callback_user_data =*/nullptr,
+
+        /*.abort_callback =*/nullptr,
+        /*.abort_callback_user_data =*/nullptr,
+
+        /*.logits_filter_callback =*/nullptr,
+        /*.logits_filter_callback_user_data =*/nullptr,
+
+        /*.grammar_rules =*/nullptr,
+        /*.n_grammar_rules =*/0,
+        /*.i_start_rule =*/0,
+        /*.grammar_penalty =*/100.0f,
+    };
+
+    switch (strategy)
+    {
+    case WHISPER_SAMPLING_GREEDY:
+    {
+        result.greedy = {
+            /*.best_of =*/5,
+        };
+    }
+    break;
+    case WHISPER_SAMPLING_BEAM_SEARCH:
+    {
+        result.beam_search = {
+            /*.beam_size =*/5,
+
+            /*.patience =*/-1.0f,
+        };
+    }
+    break;
+    }
+
+    return result;
+}
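+// Usage sketch (editor's illustration; the field names are the ones defined
+// in the struct above):
+//
+//     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+//     wparams.language       = "en";
+//     wparams.n_threads      = 8;
+//     wparams.print_progress = false;
+//     // pass wparams to whisper_full() together with the PCM samples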
+// forward declarations
+static std::vector<float> get_signal_energy(const float *signal, int n_samples, int n_samples_per_half_window);
+static void whisper_exp_compute_token_level_timestamps(
+    struct whisper_context &ctx,
+    struct whisper_state &state,
+    int i_segment,
+    float thold_pt,
+    float thold_ptsum);
+
+static inline bool should_split_on_word(const char *txt, bool split_on_word)
+{
+    if (!split_on_word)
+        return true;
+
+    return txt[0] == ' ';
+}
+
+static void whisper_exp_compute_token_level_timestamps_dtw(
+    struct whisper_context *ctx,
+    struct whisper_state *state,
+    struct whisper_full_params params,
+    int i_segment,
+    size_t n_segments,
+    int seek,
+    int n_frames,
+    int medfilt_width,
+    int n_threads);
+
+// wrap the last segment to max_len characters
+// returns the number of new segments
+static int whisper_wrap_segment(struct whisper_context &ctx, struct whisper_state &state, int max_len, bool split_on_word)
+{
+    auto segment = state.result_all.back();
+
+    int res = 1;
+    int acc = 0;
+
+    std::string text;
+
+    for (int i = 0; i < (int)segment.tokens.size(); i++)
+    {
+        const auto &token = segment.tokens[i];
+        if (token.id >= whisper_token_eot(&ctx))
+        {
+            continue;
+        }
+
+        const auto txt = whisper_token_to_str(&ctx, token.id);
+        const int cur = strlen(txt);
+
+        if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word))
+        {
+            state.result_all.back().text = std::move(text);
+            state.result_all.back().t1 = token.t0;
+            state.result_all.back().tokens.resize(i);
+            state.result_all.back().speaker_turn_next = false;
+
+            state.result_all.push_back({});
+            state.result_all.back().t0 = token.t0;
+            state.result_all.back().t1 = segment.t1;
+
+            // add tokens [i, end] to the new segment
+            state.result_all.back().tokens.insert(
+                state.result_all.back().tokens.end(),
+                segment.tokens.begin() + i,
+                segment.tokens.end());
+
+            state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
+
+            acc = 0;
+            text = "";
+
+            segment = state.result_all.back();
+            i = -1;
+
+            res++;
+        }
+        else
+        {
+            acc += cur;
+            text += txt;
+        }
+    }
+
+    state.result_all.back().text = std::move(text);
+
+    return res;
+}
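+// Worked example (editor's note): with max_len = 10 and a segment whose
+// tokens render as " Hello" " world" " again", the accumulated length first
+// exceeds 10 at " world", which starts with a space, so the segment is split
+// there; the loop then restarts (i = -1) on the new trailing segment and
+// splits once more, producing " Hello" / " world" / " again" and returning 3.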
+static const std::vector<std::string> non_speech_tokens = {
+    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
+    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
+    "♪♪♪", "♩", "♪", "♫", "♬", "♭", "♮", "♯"};
+
+// process the logits for the selected decoder
+// - applies logit filters
+// - computes logprobs and probs
+// TODO: optimize
+static void whisper_process_logits(
+    struct whisper_context &ctx,
+    struct whisper_state &state,
+    struct whisper_decoder &decoder,
+    const struct whisper_full_params params,
+    float temperature)
+{
+    const auto &vocab = ctx.vocab;
+    const auto &tokens_cur = decoder.sequence.tokens;
+
+    const bool is_initial = tokens_cur.size() == 0;
+    const int n_logits = vocab.id_to_token.size();
+
+    WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);
+
+    // extract the logits for the last token
+    // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
+    auto &probs = decoder.probs;
+    auto &logits = decoder.logits;
+    auto &logprobs = decoder.logprobs;
+    {
+        logits.resize(n_logits);
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch * n_logits, n_logits * sizeof(float));
+
+        if (temperature > 0.0f)
+        {
+            for (int i = 0; i < n_logits; i++)
+            {
+                logits[i] /= temperature;
+            }
+        }
+
+        // will be populated a bit later
+        probs.resize(n_logits);
+        logprobs.resize(n_logits);
+    }
+
+    // apply logit filters here
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
+    {
+        // suppress blank
+        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
+        if (params.suppress_blank)
+        {
+            if (is_initial)
+            {
+                logits[vocab.token_eot] = -INFINITY;
+                logits[vocab.token_to_id.at(" ")] = -INFINITY;
+            }
+        }
+
+        // suppress <|notimestamps|> token
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
+        logits[vocab.token_not] = -INFINITY;
+        if (params.no_timestamps)
+        {
+            for (int i = vocab.token_beg; i < n_logits; ++i)
+            {
+                logits[i] = -INFINITY;
+            }
+        }
+
+        // suppress sot and nosp tokens
+        logits[vocab.token_sot] = -INFINITY;
+        logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+
+        // [TDRZ] when tinydiarize is disabled, suppress solm token
+        if (params.tdrz_enable == false)
+        {
+            logits[vocab.token_solm] = -INFINITY;
+        }
+
+        // suppress task tokens
+        logits[vocab.token_translate] = -INFINITY;
+        logits[vocab.token_transcribe] = -INFINITY;
+        logits[vocab.token_prev] = -INFINITY;
+
+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i)
+        {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
+
+        if (params.logits_filter_callback)
+        {
+            params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
+        }
+
+        // suppress any tokens matching a regular expression
+        // ref: https://github.com/openai/whisper/discussions/1041
+        if (params.suppress_regex != nullptr)
+        {
+            std::regex re(params.suppress_regex);
+            for (std::pair<std::string, whisper_token> token_id : vocab.token_to_id)
+            {
+                if (std::regex_match(token_id.first, re))
+                {
+                    logits[token_id.second] = -INFINITY;
+                }
+            }
+        }
+
+        // suppress non-speech tokens
+        // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+        if (params.suppress_non_speech_tokens)
+        {
+            for (const std::string &token : non_speech_tokens)
+            {
+                const std::string suppress_tokens[] = {token, " " + token};
+                for (const std::string &suppress_token : suppress_tokens)
+                {
+                    if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
+                    {
+                        logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
+                    }
+                }
+            }
+
+            // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+            if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
+            {
+                logits[vocab.token_to_id.at(" -")] = -INFINITY;
+            }
+            if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
+            {
+                logits[vocab.token_to_id.at(" '")] = -INFINITY;
+            }
+        }
+
+        // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
+        {
+            const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
+            const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
+
+            // WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+
+            if (last_was_timestamp)
+            {
+                if (penultimate_was_timestamp)
+                {
+                    for (int i = vocab.token_beg; i < n_logits; ++i)
+                    {
+                        logits[i] = -INFINITY;
+                    }
+                }
+                else
+                {
+                    for (int i = 0; i < vocab.token_eot; ++i)
+                    {
+                        logits[i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        // the initial timestamp cannot be larger than max_initial_ts
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+        if (is_initial && params.max_initial_ts > 0.0f)
+        {
+            const float precision = float(WHISPER_CHUNK_SIZE) / ctx.model.hparams.n_audio_ctx;
+            const int tid0 = std::round(params.max_initial_ts / precision);
+
+            for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i)
+            {
+                logits[i] = -INFINITY;
+            }
+        }
+
+        // condition timestamp tokens to be increasing
+        // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
+        if (decoder.has_ts)
+        {
+            const int tid0 = decoder.seek_delta / 2;
+
+            for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i)
+            {
+                logits[i] = -INFINITY;
+            }
+        }
+
+        // populate the logprobs array (log_softmax)
+        {
+            const float logit_max = *std::max_element(logits.begin(), logits.end());
+            float logsumexp = 0.0f;
+            for (int i = 0; i < n_logits; ++i)
+            {
+                if (logits[i] > -INFINITY)
+                {
+                    logsumexp += expf(logits[i] - logit_max);
+                }
+            }
+            logsumexp = logf(logsumexp) + logit_max;
+
+            for (int i = 0; i < n_logits; ++i)
+            {
+                if (logits[i] > -INFINITY)
+                {
+                    logprobs[i] = logits[i] - logsumexp;
+                }
+                else
+                {
+                    logprobs[i] = -INFINITY;
+                }
+            }
+        }
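+        // Editor's note on the block above: this is a numerically stable
+        // log-softmax. With m = max_i logit_i, it computes
+        //   logprob_i = logit_i - (log(sum_j exp(logit_j - m)) + m)
+        // subtracting the maximum first so that expf() never overflows.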
+        // if sum of probability over timestamps is above any other token, sample timestamp
+        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
+        {
+            // logsumexp over timestamps
+            float timestamp_logprob = -INFINITY;
+            {
+                float logsumexp = 0.0f;
+                const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
+                for (int i = vocab.token_beg; i < n_logits; ++i)
+                {
+                    if (logprobs[i] > -INFINITY)
+                    {
+                        logsumexp += expf(logprobs[i] - logprob_max);
+                    }
+                }
+                if (logsumexp > 0.0f)
+                {
+                    timestamp_logprob = logf(logsumexp) + logprob_max;
+                }
+            }
+
+            const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
+
+            // WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+
+            if (timestamp_logprob > max_text_token_logprob)
+            {
+                for (int i = 0; i < vocab.token_beg; ++i)
+                {
+                    logits[i] = -INFINITY;
+                    logprobs[i] = -INFINITY;
+                }
+            }
+            else
+            {
+                if (params.n_grammar_rules > 0)
+                {
+                    whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+                    // populate the logprobs array (log_softmax)
+                    {
+                        const float logit_max = *std::max_element(logits.begin(), logits.end());
+                        float logsumexp = 0.0f;
+                        for (int i = 0; i < n_logits; ++i)
+                        {
+                            if (logits[i] > -INFINITY)
+                            {
+                                logsumexp += expf(logits[i] - logit_max);
+                            }
+                        }
+                        logsumexp = logf(logsumexp) + logit_max;
+
+                        for (int i = 0; i < n_logits; ++i)
+                        {
+                            if (logits[i] > -INFINITY)
+                            {
+                                logprobs[i] = logits[i] - logsumexp;
+                            }
+                            else
+                            {
+                                logprobs[i] = -INFINITY;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // compute probs
+    {
+        for (int i = 0; i < n_logits; ++i)
+        {
+            if (logits[i] == -INFINITY)
+            {
+                probs[i] = 0.0f;
+            }
+            else
+            {
+                probs[i] = expf(logprobs[i]);
+            }
+        }
+    }
//printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif +} + +static bool whisper_sequence_tokens_equal(const whisper_sequence &a, const whisper_sequence &b) +{ + if (a.tokens.size() != b.tokens.size()) + { + return false; + } + // sequences are more likely to diverge at the end + for (int i = a.tokens.size() - 1; i >= 0; i--) + { + if (a.tokens[i].id != b.tokens[i].id) + { + return false; + } + } + return true; +} + +static whisper_token_data whisper_sample_token( + whisper_context &ctx, + const whisper_decoder &decoder, + bool best) +{ + whisper_token_data result = { + 0, + 0, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + -1, + -1, + -1, + 0.0f, + }; + + const auto &vocab = ctx.vocab; + + const auto &probs = decoder.probs; + const auto &logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) + { + if (probs[i] == -INFINITY) + { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) + { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts / (sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + if (best) + { + for (int i = 0; i < n_logits; ++i) + { + if (result.p < probs[i]) + { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + } + else + { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(decoder.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; + } + + if (result.id >= vocab.token_beg) + { + result.tid = result.id; + result.pt = result.p; + } + + return result; +} + +static std::vector whisper_sample_token_topk( + whisper_context &ctx, + whisper_decoder &decoder, + int k) +{ + const auto &vocab = ctx.vocab; + + const auto &probs = decoder.probs; + const auto &logits = decoder.logits; + const auto &logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + auto &logits_id = decoder.logits_id; + + logits_id.resize(n_logits); + for (int i = 0; i < n_logits; ++i) + { + logits_id[i].first = logits[i]; + logits_id[i].second = i; + } + + { + using pair_type = std::remove_reference::type::value_type; + std::partial_sort( + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const pair_type &a, const pair_type &b) + { + return a.first > b.first; + }); + } + + std::vector result; + result.reserve(k); + + whisper_token tid = vocab.token_beg; + + float pt = 0.0; + float ptsum = 0.0; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) + { + if (probs[i] == -INFINITY) + { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) + { + max_ts = probs[i]; + tid = i; + } + } + + pt = max_ts / (sum_ts + 1e-10); + ptsum = sum_ts; + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + for (int i = 0; i < k; ++i) + { + const auto id = dist(decoder.rng); + // printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); + + result.push_back({ + id, + tid, + probs[id], + logprobs[id], + pt, + ptsum, + -1, + -1, + -1, + 0.0f, + }); + + if (result[i].id >= vocab.token_beg) + { + 
+static std::vector<whisper_token_data> whisper_sample_token_topk(
+    whisper_context &ctx,
+    whisper_decoder &decoder,
+    int k)
+{
+    const auto &vocab = ctx.vocab;
+
+    const auto &probs = decoder.probs;
+    const auto &logits = decoder.logits;
+    const auto &logprobs = decoder.logprobs;
+
+    const int n_logits = vocab.n_vocab;
+
+    auto &logits_id = decoder.logits_id;
+
+    logits_id.resize(n_logits);
+    for (int i = 0; i < n_logits; ++i)
+    {
+        logits_id[i].first = logits[i];
+        logits_id[i].second = i;
+    }
+
+    {
+        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+        std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + k, logits_id.end(),
+            [](const pair_type &a, const pair_type &b)
+            {
+                return a.first > b.first;
+            });
+    }
+
+    std::vector<whisper_token_data> result;
+    result.reserve(k);
+
+    whisper_token tid = vocab.token_beg;
+
+    float pt = 0.0;
+    float ptsum = 0.0;
+
+    {
+        double sum_ts = 0.0;
+        double max_ts = 0.0;
+
+        for (int i = vocab.token_beg; i < n_logits; i++)
+        {
+            if (probs[i] == -INFINITY)
+            {
+                continue;
+            }
+
+            sum_ts += probs[i];
+            if (max_ts < probs[i])
+            {
+                max_ts = probs[i];
+                tid = i;
+            }
+        }
+
+        pt = max_ts / (sum_ts + 1e-10);
+        ptsum = sum_ts;
+    }
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+    for (int i = 0; i < k; ++i)
+    {
+        const auto id = dist(decoder.rng);
+        // printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
+
+        result.push_back({
+            id,
+            tid,
+            probs[id],
+            logprobs[id],
+            pt,
+            ptsum,
+            -1,
+            -1,
+            -1,
+            0.0f,
+        });
+
+        if (result[i].id >= vocab.token_beg)
+        {
+            result[i].tid = result[i].id;
+            result[i].pt = result[i].p;
+        }
+    }
+
+    return result;
+}
+
+// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
+static void whisper_sequence_score(
+    const struct whisper_full_params &params,
+    whisper_sequence &sequence)
+{
+    if (sequence.result_len == 0)
+    {
+        return;
+    }
+
+    double result = 0.0f;
+
+    for (int i = 0; i < sequence.result_len; ++i)
+    {
+        result += sequence.tokens[i].plog;
+    }
+
+    sequence.sum_logprobs = result;
+    sequence.avg_logprobs = result / sequence.result_len;
+
+    double penalty = sequence.result_len;
+
+    if (params.length_penalty > 0.0f)
+    {
+        penalty = pow((5.0 + penalty) / 6.0, params.length_penalty);
+    }
+
+    sequence.score = result / penalty;
+
+    // compute the entropy of the sequence of the last 32 tokens
+    {
+        const int n = 32;
+
+        int cnt = 0;
+        double entropy = 0.0f;
+
+        std::map<whisper_token, int> token_counts;
+        for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i)
+        {
+            token_counts[sequence.tokens[i].id]++;
+            cnt++;
+        }
+
+        for (const auto &kv : token_counts)
+        {
+            const auto p = kv.second / (double)cnt;
+            entropy -= p * log(p);
+
+            // WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+        }
+
+        sequence.entropy = entropy;
+    }
+}
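+// Editor's note: the score above follows the Google NMT length penalty used
+// by the OpenAI reference linked above - score = sum_logprobs / penalty with
+//   penalty = ((5 + len) / 6) ^ length_penalty
+// when params.length_penalty > 0, and plain length normalization
+// (penalty = len, i.e. the average logprob) otherwise.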
+int whisper_full_with_state(
+    struct whisper_context *ctx,
+    struct whisper_state *state,
+    struct whisper_full_params params,
+    const float *samples,
+    int n_samples)
+{
+    // clear old results
+    auto &result_all = state->result_all;
+
+    result_all.clear();
+
+    if (n_samples > 0)
+    {
+        // compute log mel spectrogram
+        if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0)
+        {
+            WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
+            return -2;
+        }
+    }
+
+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language)
+    {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
+        if (lang_id < 0)
+        {
+            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+        state->lang_id = lang_id;
+        params.language = whisper_lang_str(lang_id);
+
+        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        if (params.detect_language)
+        {
+            return 0;
+        }
+    }
+
+    if (params.token_timestamps)
+    {
+        state->t_beg = 0;
+        state->t_last = 0;
+        state->tid_last = 0;
+        if (n_samples > 0)
+        {
+            state->energy = get_signal_energy(samples, n_samples, 32);
+        }
+    }
+
+    const int seek_start = params.offset_ms / 10;
+    const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms / 10;
+
+    // if length of spectrogram is less than 1.0s (100 frames), then return
+    // basically don't process anything that is less than 1.0s
+    // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
+    if (seek_end < seek_start + 100)
+    {
+        WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start) * 10);
+        return 0;
+    }
+
+    // a set of temperatures to use
+    // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
+    std::vector<float> temperatures;
+    if (params.temperature_inc > 0.0f)
+    {
+        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc)
+        {
+            temperatures.push_back(t);
+        }
+    }
+    else
+    {
+        temperatures.push_back(params.temperature);
+    }
+
+    // initialize the decoders
+    int n_decoders = 1;
+
+    switch (params.strategy)
+    {
+    case WHISPER_SAMPLING_GREEDY:
+    {
+        n_decoders = params.greedy.best_of;
+    }
+    break;
+    case WHISPER_SAMPLING_BEAM_SEARCH:
+    {
+        n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
+    }
+    break;
+    };
+
+    n_decoders = std::max(1, n_decoders);
+
+    if (n_decoders > WHISPER_MAX_DECODERS)
+    {
+        WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+        return -4;
+    }
+
+    // TAGS: WHISPER_DECODER_INIT
+    for (int j = 1; j < n_decoders; j++)
+    {
+        auto &decoder = state->decoders[j];
+
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
+
+        decoder.probs.resize(ctx->vocab.n_vocab);
+        decoder.logits.resize(ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
+
+        decoder.rng = std::mt19937(0);
+    }
+
+    // the accumulated text context so far
+    auto &prompt_past = state->prompt_past;
+    if (params.no_context)
+    {
+        prompt_past.clear();
+    }
+
+    // prepare prompt
+    {
+        std::vector<whisper_token> prompt_tokens;
+
+        // initial prompt
+        if (!params.prompt_tokens && params.initial_prompt)
+        {
+            prompt_tokens.resize(1024);
+            int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
+            if (n_needed < 0)
+            {
+                prompt_tokens.resize(-n_needed);
+                n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
+            }
+            prompt_tokens.resize(n_needed);
+            params.prompt_tokens = prompt_tokens.data();
+            params.prompt_n_tokens = prompt_tokens.size();
+        }
+
+        // prepend the prompt tokens to the prompt_past
+        if (params.prompt_tokens && params.prompt_n_tokens > 0)
+        {
+            // parse tokens from the pointer
+            for (int i = 0; i < params.prompt_n_tokens; i++)
+            {
+                prompt_past.push_back(params.prompt_tokens[i]);
+            }
+            std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+        }
+    }
+
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx))
+    {
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -5;
+    }
+    state->exp_n_audio_ctx = params.audio_ctx;
+    // these tokens determine the task that will be performed
+    std::vector<whisper_token> prompt_init = {
+        whisper_token_sot(ctx),
+    };
+
+    if (whisper_is_multilingual(ctx))
+    {
+        const int lang_id = whisper_lang_id(params.language);
+        state->lang_id = lang_id;
+        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
+        if (params.translate)
+        {
+            prompt_init.push_back(whisper_token_translate(ctx));
+        }
+        else
+        {
+            prompt_init.push_back(whisper_token_transcribe(ctx));
+        }
+    }
+
+    // first release distilled models require the "no_timestamps" token
+    {
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866;
+        if (is_distil && !params.no_timestamps)
+        {
+            WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__);
+            params.no_timestamps = true;
+        }
+    }
+
+    if (params.no_timestamps)
+    {
+        prompt_init.push_back(whisper_token_not(ctx));
+    }
+
+    int seek = seek_start;
+
+    std::vector<whisper_token> prompt;
+    prompt.reserve(whisper_n_text_ctx(ctx));
+
+    struct beam_candidate
+    {
+        int decoder_idx;
+        int seek_delta;
+
+        bool has_ts;
+
+        whisper_sequence sequence;
+        whisper_grammar grammar;
+    };
+
+    std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
+    std::vector<beam_candidate> beam_candidates;
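+    // Editor's note on bc_per_dec/beam_candidates: during beam search every
+    // decoder proposes beam_size continuations into its own bc_per_dec[j]
+    // list (filled from worker threads without locking, since each thread
+    // owns one list); the lists are then merged into beam_candidates, sorted
+    // by sum_logprobs_all, and the top candidates are copied back into the
+    // decoder slots together with their KV-cache sequences.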
+    // main loop
+    while (true)
+    {
+        if (params.progress_callback)
+        {
+            const int progress_cur = (100 * (seek - seek_start)) / (seek_end - seek_start);
+
+            params.progress_callback(
+                ctx, state, progress_cur, params.progress_callback_user_data);
+        }
+
+        // if only 1 second left, then stop
+        if (seek + 100 >= seek_end)
+        {
+            break;
+        }
+
+        if (params.encoder_begin_callback)
+        {
+            if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false)
+            {
+                WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
+                break;
+            }
+        }
+
+        // encode audio features starting at offset seek
+        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data))
+        {
+            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
+            return -6;
+        }
+
+        // if there is a very short audio segment left to process, we remove any past prompt since it tends
+        // to confuse the decoder and often make it repeat or hallucinate stuff
+        if (seek > seek_start && seek + 500 >= seek_end)
+        {
+            prompt_past.clear();
+        }
+
+        int best_decoder_id = 0;
+
+        for (int it = 0; it < (int)temperatures.size(); ++it)
+        {
+            const float t_cur = temperatures[it];
+
+            int n_decoders_cur = 1;
+
+            switch (params.strategy)
+            {
+            case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+            {
+                if (t_cur > 0.0f)
+                {
+                    n_decoders_cur = params.greedy.best_of;
+                }
+            }
+            break;
+            case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+            {
+                if (t_cur > 0.0f)
+                {
+                    n_decoders_cur = params.greedy.best_of;
+                }
+                else
+                {
+                    n_decoders_cur = params.beam_search.beam_size;
+                }
+            }
+            break;
+            };
+
+            n_decoders_cur = std::max(1, n_decoders_cur);
+
+            WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
+
+            // TAGS: WHISPER_DECODER_INIT
+            for (int j = 0; j < n_decoders_cur; ++j)
+            {
+                auto &decoder = state->decoders[j];
+
+                decoder.sequence.tokens.clear();
+                decoder.sequence.result_len = 0;
+                decoder.sequence.sum_logprobs_all = 0.0;
+                decoder.sequence.sum_logprobs = -INFINITY;
+                decoder.sequence.avg_logprobs = -INFINITY;
+                decoder.sequence.entropy = 0.0;
+                decoder.sequence.score = -INFINITY;
+
+                decoder.seek_delta = 100 * WHISPER_CHUNK_SIZE;
+
+                decoder.failed = false;
+                decoder.completed = false;
+                decoder.has_ts = false;
+
+                if (params.grammar_rules != nullptr)
+                {
+                    decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                }
+                else
+                {
+                    decoder.grammar = {};
+                }
+            }
+
+            // init prompt and kv cache for the current iteration
+            // TODO: do not recompute the prompt if it is the same as previous time
+            {
+                prompt.clear();
+
+                // if we have already generated some text, use it as a prompt to condition the next generation
+                if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0)
+                {
+                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx) / 2), int(prompt_past.size()));
+
+                    prompt = {whisper_token_prev(ctx)};
+                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+                }
+
+                // init new transcription with sot, language (opt) and task tokens
+                prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+                // print the prompt
+                WHISPER_LOG_DEBUG("\n\n");
+                for (int i = 0; i < (int)prompt.size(); i++)
+                {
+                    WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+                }
+                WHISPER_LOG_DEBUG("\n\n");
+
+                whisper_kv_cache_clear(state->kv_self);
+
+                whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data))
+                {
+                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
+                    return -7;
+                }
+
+                {
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    state->decoders[0].i_batch = prompt.size() - 1;
+
+                    whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
+
+                    for (int j = 1; j < n_decoders_cur; ++j)
+                    {
+                        auto &decoder = state->decoders[j];
+
+                        whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
+
+                        memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size() * sizeof(decoder.probs[0]));
+                        memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size() * sizeof(decoder.logits[0]));
+                        memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size() * sizeof(decoder.logprobs[0]));
+                    }
+
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
+                }
+            }
+
+            for (int i = 0, n_max = whisper_n_text_ctx(ctx) / 2 - 4; i < n_max; ++i)
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH)
+                {
+                    for (auto &bc : bc_per_dec)
+                    {
+                        bc.clear();
+                    }
+                }
+
+                // sampling
+                // TODO: avoid memory allocations, optimize, avoid threads?
+                {
+                    std::atomic<int> j_cur(0);
+
+                    auto process = [&]()
+                    {
+                        while (true)
+                        {
+                            const int j = j_cur.fetch_add(1);
+
+                            if (j >= n_decoders_cur)
+                            {
+                                break;
+                            }
+
+                            auto &decoder = state->decoders[j];
+
+                            if (decoder.completed || decoder.failed)
+                            {
+                                continue;
+                            }
+
+                            switch (params.strategy)
+                            {
+                            case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                            {
+                                if (t_cur < 1e-6f)
+                                {
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                }
+                                else
+                                {
+                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                }
+
+                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                            }
+                            break;
+                            case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                            {
+                                const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                for (const auto &token : tokens_new)
+                                {
+                                    bc_per_dec[j].push_back({
+                                        j,
+                                        decoder.seek_delta,
+                                        decoder.has_ts,
+                                        decoder.sequence,
+                                        decoder.grammar,
+                                    });
+                                    bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                    bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                }
+                            }
+                            break;
+                            };
+                        }
+                    };
+
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                    if (n_threads == 1)
+                    {
+                        process();
+                    }
+                    else
+                    {
+                        std::vector<std::thread> threads(n_threads - 1);
+
+                        for (int t = 0; t < n_threads - 1; ++t)
+                        {
+                            threads[t] = std::thread(process);
+                        }
+
+                        process();
+
+                        for (int t = 0; t < n_threads - 1; ++t)
+                        {
+                            threads[t].join();
+                        }
+                    }
+                }
+
+                beam_candidates.clear();
+                for (const auto &bc : bc_per_dec)
+                {
+                    beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+                    if (!bc.empty())
+                    {
+                        state->n_sample += 1;
+                    }
+                }
+
+                // for beam-search, choose the top candidates and update the KV caches
+                if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH)
+                {
+                    std::sort(
+                        beam_candidates.begin(),
+                        beam_candidates.end(),
+                        [](const beam_candidate &a, const beam_candidate &b)
+                        {
+                            if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all)
+                            {
+                                return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
+                            }
+                            return a.decoder_idx < b.decoder_idx;
+                        });
+
+                    uint32_t cur_c = 0;
+
+                    for (int j = 0; j < n_decoders_cur; ++j)
+                    {
+                        auto &decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed)
+                        {
+                            continue;
+                        }
+
+                        if (cur_c >= beam_candidates.size())
+                        {
+                            cur_c = 0;
+                        }
+
+                        auto &cur = beam_candidates[cur_c++];
+
+                        while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0)
+                        {
+                            ++cur_c;
+                        }
+
+                        decoder.seek_delta = cur.seek_delta;
+                        decoder.has_ts = cur.has_ts;
+                        decoder.sequence = cur.sequence;
+                        decoder.grammar = cur.grammar;
+
+                        whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
+
+                        WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
+                                          __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
+                    }
+
+                    for (int j = 0; j < n_decoders_cur; ++j)
+                    {
+                        auto &decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed)
+                        {
+                            continue;
+                        }
+
+                        whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
+                        whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+                        whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
+                    }
+                }
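+                // Editor's note on the cache shuffling above: candidate j may
+                // continue the sequence of a different decoder cur.decoder_idx,
+                // so its KV data is first copied into a scratch sequence id
+                // (WHISPER_MAX_DECODERS + j) and only then moved into slot j;
+                // copying directly into slot j could presumably overwrite a
+                // source sequence that a later candidate still needs.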
+                // update the decoder state
+                // - check if the sequence is completed
+                // - check if the sequence is failed
+                // - update sliding window based on timestamp tokens
+                for (int j = 0; j < n_decoders_cur; ++j)
+                {
+                    auto &decoder = state->decoders[j];
+
+                    if (decoder.completed || decoder.failed)
+                    {
+                        continue;
+                    }
+
+                    auto &has_ts = decoder.has_ts;
+                    auto &failed = decoder.failed;
+                    auto &completed = decoder.completed;
+                    auto &seek_delta = decoder.seek_delta;
+                    auto &result_len = decoder.sequence.result_len;
+
+                    {
+                        const auto &token = decoder.sequence.tokens.back();
+
+                        // timestamp token - update sliding window
+                        if (token.id > whisper_token_beg(ctx))
+                        {
+                            const int seek_delta_new = 2 * (token.id - whisper_token_beg(ctx));
+
+                            // do not allow to go back in time
+                            if (has_ts && seek_delta > seek_delta_new && result_len < i)
+                            {
+                                WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
+                                failed = true; // TODO: maybe this is not a failure ?
+                                continue;
+                            }
+
+                            seek_delta = seek_delta_new;
+                            result_len = i + 1;
+                            has_ts = true;
+                        }
+
+                        whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
+#ifdef WHISPER_DEBUG
+                        {
+                            const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
+                            WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+                                              __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
+                        }
+#endif
+
+                        // end of segment
+                        if (token.id == whisper_token_eot(ctx) || // end of text token
+                            (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+                            (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
+                        )
+                        {
+                            if (result_len == 0 && !params.no_timestamps)
+                            {
+                                if (seek + seek_delta + 100 >= seek_end)
+                                {
+                                    result_len = i + 1;
+                                }
+                                else
+                                {
+                                    WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
+                                    failed = true;
+                                    continue;
+                                }
+                            }
+
+                            if (params.single_segment || params.no_timestamps)
+                            {
+                                result_len = i + 1;
+                                seek_delta = 100 * WHISPER_CHUNK_SIZE;
+                            }
+
+                            WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j);
+                            completed = true;
+                            continue;
+                        }
+
+                        // TESTS: if no tensors are loaded, it means we are running tests
+                        if (ctx->model.n_loaded == 0)
+                        {
+                            seek_delta = 100 * WHISPER_CHUNK_SIZE;
+                            completed = true;
+                            continue;
+                        }
+                    }
+
+                    // sometimes, the decoding can get stuck in a repetition loop
+                    // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
+                    if (i == n_max - 1 && (result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2))
+                    {
+                        WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
+                        failed = true;
+                        continue;
+                    }
+                }
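+                // Editor's note: timestamp tokens are spaced 20 ms apart while
+                // seek is measured in 10 ms mel frames - hence the factor of 2
+                // in seek_delta = 2 * (token.id - whisper_token_beg(ctx)) above.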
+                // check if all decoders have finished (i.e. completed or failed)
+                {
+                    bool completed_all = true;
+
+                    for (int j = 0; j < n_decoders_cur; ++j)
+                    {
+                        auto &decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed)
+                        {
+                            continue;
+                        }
+
+                        completed_all = false;
+                    }
+
+                    if (completed_all)
+                    {
+                        break;
+                    }
+                }
+
+                state->t_sample_us += ggml_time_us() - t_start_sample_us;
+
+                // obtain logits for the next token
+                {
+                    auto &batch = state->batch;
+
+                    batch.n_tokens = 0;
+
+                    const int n_past = prompt.size() + i;
+
+                    for (int j = 0; j < n_decoders_cur; ++j)
+                    {
+                        auto &decoder = state->decoders[j];
+
+                        if (decoder.failed || decoder.completed)
+                        {
+                            continue;
+                        }
+
+                        // WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+                        decoder.i_batch = batch.n_tokens;
+
+                        batch.token[batch.n_tokens] = decoder.sequence.tokens.back().id;
+                        batch.pos[batch.n_tokens] = n_past;
+                        batch.n_seq_id[batch.n_tokens] = 1;
+                        batch.seq_id[batch.n_tokens][0] = j;
+                        batch.logits[batch.n_tokens] = 1;
+                        batch.n_tokens++;
+                    }
+
+                    assert(batch.n_tokens > 0);
+
+                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data))
+                    {
+                        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
+                        return -8;
+                    }
+
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    // TODO: avoid memory allocations, optimize, avoid threads?
+                    {
+                        std::atomic<int> j_cur(0);
+
+                        auto process = [&]()
+                        {
+                            while (true)
+                            {
+                                const int j = j_cur.fetch_add(1);
+
+                                if (j >= n_decoders_cur)
+                                {
+                                    break;
+                                }
+
+                                auto &decoder = state->decoders[j];
+
+                                if (decoder.failed || decoder.completed)
+                                {
+                                    continue;
+                                }
+
+                                whisper_process_logits(*ctx, *state, decoder, params, t_cur);
+                            }
+                        };
+
+                        const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                        if (n_threads == 1)
+                        {
+                            process();
+                        }
+                        else
+                        {
+                            std::vector<std::thread> threads(n_threads - 1);
+
+                            for (int t = 0; t < n_threads - 1; ++t)
+                            {
+                                threads[t] = std::thread(process);
+                            }
+
+                            process();
+
+                            for (int t = 0; t < n_threads - 1; ++t)
+                            {
+                                threads[t].join();
+                            }
+                        }
+                    }
+
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
+                }
+            }
+
+            // rank the resulting sequences and select the best one
+            {
+                double best_score = -INFINITY;
+
+                for (int j = 0; j < n_decoders_cur; ++j)
+                {
+                    auto &decoder = state->decoders[j];
+
+                    if (decoder.failed)
+                    {
+                        continue;
+                    }
+
+                    decoder.sequence.tokens.resize(decoder.sequence.result_len);
+                    whisper_sequence_score(params, decoder.sequence);
+
+                    WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
+                                      __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
+
+                    if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold)
+                    {
+                        WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
+                                          __func__, j, decoder.sequence.entropy, params.entropy_thold);
+
+                        decoder.failed = true;
+                        state->n_fail_h++;
+
+                        continue;
+                    }
+
+                    if (best_score < decoder.sequence.score)
+                    {
+                        best_score = decoder.sequence.score;
+                        best_decoder_id = j;
+                    }
+                }
+
+                WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
+            }
+
+            bool success = true;
+
+            // was the decoding successful for the current temperature?
+            // do fallback only if:
+            // - we are not at the last temperature
+            if (it != (int)temperatures.size() - 1)
+            {
+                const auto &decoder = state->decoders[best_decoder_id];
+
+                if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold)
+                {
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+                    success = false;
+                    state->n_fail_p++;
+                }
+            }
+
+            if (success)
+            {
+                // for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+                //     WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+                // }
+
+                break;
+            }
+
+            WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
+        }
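+        // Editor's note: the loop above implements the temperature fallback of
+        // the OpenAI reference decoder - if the best sequence at a given
+        // temperature fails the entropy/avg_logprobs checks, the chunk is
+        // re-decoded at the next temperature in `temperatures`, and the last
+        // temperature is always accepted.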
+        // output results through a user-provided callback
+        {
+            const auto &best_decoder = state->decoders[best_decoder_id];
+
+            const auto seek_delta = best_decoder.seek_delta;
+            const auto result_len = best_decoder.sequence.result_len;
+
+            const auto &tokens_cur = best_decoder.sequence.tokens;
+
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            const auto n_segments_before = state->result_all.size();
+
+            // WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
+
+            // update prompt_past
+            prompt_past.clear();
+            if (prompt.front() == whisper_token_prev(ctx))
+            {
+                prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
+            }
+
+            for (int i = 0; i < result_len; ++i)
+            {
+                prompt_past.push_back(tokens_cur[i].id);
+            }
+
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0)
+            {
+                int i0 = 0;
+                auto t0 = seek + 2 * (tokens_cur.front().tid - whisper_token_beg(ctx));
+
+                std::string text;
+                bool speaker_turn_next = false;
+
+                for (int i = 0; i < (int)tokens_cur.size(); i++)
+                {
+                    // printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+                    //        ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+                    //        ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+                    if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx))
+                    {
+                        text += whisper_token_to_str(ctx, tokens_cur[i].id);
+                    }
+
+                    // [TDRZ] record if speaker turn was predicted after current segment
+                    if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx))
+                    {
+                        speaker_turn_next = true;
+                    }
+
+                    if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment)
+                    {
+                        const auto t1 = seek + 2 * (tokens_cur[i].tid - whisper_token_beg(ctx));
+
+                        if (!text.empty())
+                        {
+                            const auto tt0 = t0;
+                            const auto tt1 = t1;
+
+                            if (params.print_realtime)
+                            {
+                                if (params.print_timestamps)
+                                {
+                                    printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+                                }
+                                else
+                                {
+                                    printf("%s", text.c_str());
+                                    fflush(stdout);
+                                }
+                            }
+
+                            // printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
+
+                            result_all.push_back({tt0, tt1, text, {}, speaker_turn_next});
+                            for (int j = i0; j <= i; j++)
+                            {
+                                result_all.back().tokens.push_back(tokens_cur[j]);
+                            }
+
+                            int n_new = 1;
+
+                            if (params.token_timestamps)
+                            {
+                                whisper_exp_compute_token_level_timestamps(
+                                    *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                                if (params.max_len > 0)
+                                {
+                                    n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
+                                }
+                            }
+                            if (params.new_segment_callback)
+                            {
+                                params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
+                            }
+                        }
+                        text = "";
+                        while (i < (int)tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx))
+                        {
+                            i++;
+                        }
+                        i--;
+                        t0 = t1;
+                        i0 = i + 1;
+                        speaker_turn_next = false;
+                    }
+                }
+
+                if (!text.empty())
+                {
+                    const auto t1 = seek + seek_delta;
+
+                    const auto tt0 = t0;
+                    const auto tt1 = t1;
+
+                    if (params.print_realtime)
+                    {
+                        if (params.print_timestamps)
+                        {
+                            printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+                        }
+                        else
+                        {
+                            printf("%s", text.c_str());
+                            fflush(stdout);
+                        }
+                    }
+
+                    result_all.push_back({tt0, tt1, text, {}, speaker_turn_next});
+                    for (int j = i0; j < (int)tokens_cur.size(); j++)
+                    {
+                        result_all.back().tokens.push_back(tokens_cur[j]);
+                    }
+
+                    int n_new = 1;
+
+                    if (params.token_timestamps)
+                    {
+                        whisper_exp_compute_token_level_timestamps(
+                            *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                        if (params.max_len > 0)
+                        {
+                            n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
+                        }
+                    }
+                    if (params.new_segment_callback)
+                    {
+                        params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
+                    }
+                }
+            }
+
+            // FIXME: will timestamp offsets be correct?
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            {
+                const auto n_segments = state->result_all.size() - n_segments_before;
+                if (ctx->params.dtw_token_timestamps && n_segments)
+                {
+                    const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
+                    whisper_exp_compute_token_level_timestamps_dtw(
+                        ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                }
+            }
+
+            // update audio window
+            seek += seek_delta;
+
+            WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
+        }
+    }
+
+    return 0;
+}
+
+int whisper_full(
+    struct whisper_context *ctx,
+    struct whisper_full_params params,
+    const float *samples,
+    int n_samples)
+{
+    return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
+}
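+// Usage sketch (editor's illustration; the model path is a placeholder):
+// transcribing a buffer of 16 kHz mono float PCM with the API defined in
+// this file:
+//
+//     struct whisper_context *wctx = whisper_init_from_file_with_params(
+//         "ggml-base.en.bin", whisper_context_default_params());
+//     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+//     if (whisper_full(wctx, wparams, pcm.data(), (int)pcm.size()) == 0) {
+//         for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
+//             printf("%s", whisper_full_get_segment_text(wctx, i));
+//         }
+//     }
+//     whisper_free(wctx);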
+int whisper_full_parallel(
+    struct whisper_context *ctx,
+    struct whisper_full_params params,
+    const float *samples,
+    int n_samples,
+    int n_processors)
+{
+    if (n_processors == 1)
+    {
+        return whisper_full(ctx, params, samples, n_samples);
+    }
+    int ret = 0;
+
+    // prepare separate states for each thread
+    std::vector<whisper_state *> states;
+
+    const int offset_samples = (WHISPER_SAMPLE_RATE * params.offset_ms) / 1000;
+    const int n_samples_per_processor = (n_samples - offset_samples) / n_processors;
+
+    // the calling thread will process the first chunk
+    // while the other threads will process the remaining chunks
+
+    std::vector<std::thread> workers(n_processors - 1);
+    for (int i = 0; i < n_processors - 1; ++i)
+    {
+        // create a new state for each thread
+        states.push_back(whisper_init_state(ctx));
+
+        const int start_samples = offset_samples + (i + 1) * n_samples_per_processor;
+        const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
+
+        auto params_cur = params;
+
+        params_cur.offset_ms = 0;
+        params_cur.print_progress = false;
+        params_cur.print_realtime = false;
+
+        params_cur.new_segment_callback = nullptr;
+        params_cur.new_segment_callback_user_data = nullptr;
+
+        params_cur.progress_callback = nullptr;
+        params_cur.progress_callback_user_data = nullptr;
+
+        workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+    }
+
+    {
+        auto params_cur = params;
+
+        // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
+        params_cur.print_realtime = false;
+
+        // Run the first transformation using default state but only for the first chunk.
+        ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+    }
+
+    for (int i = 0; i < n_processors - 1; ++i)
+    {
+        workers[i].join();
+    }
+
+    const int64_t offset_t = (int64_t)params.offset_ms / 10.0;
+
+    // combine results into result_state->result_all from all other states
+    for (int i = 0; i < n_processors - 1; ++i)
+    {
+        auto &results_i = states[i]->result_all;
+
+        for (auto &result : results_i)
+        {
+            // correct the segment timestamp taking into account the offset
+            result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+            result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
+
+            // make sure that segments are not overlapping
+            if (!ctx->state->result_all.empty())
+            {
+                result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
+            }
+
+            ctx->state->result_all.push_back(std::move(result));
+
+            // call the new_segment_callback for each segment
+            if (params.new_segment_callback)
+            {
+                params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
+            }
+        }
+
+        ctx->state->t_mel_us += states[i]->t_mel_us;
+
+        ctx->state->t_sample_us += states[i]->t_sample_us;
+        ctx->state->t_encode_us += states[i]->t_encode_us;
+        ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_batchd_us += states[i]->t_batchd_us;
+        ctx->state->t_prompt_us += states[i]->t_prompt_us;
+
+        ctx->state->n_sample += states[i]->n_sample;
+        ctx->state->n_encode += states[i]->n_encode;
+        ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_batchd += states[i]->n_batchd;
+        ctx->state->n_prompt += states[i]->n_prompt;
+
+        whisper_free_state(states[i]);
+    }
+
+    // average the timings
+    ctx->state->t_mel_us /= n_processors;
+    ctx->state->t_sample_us /= n_processors;
+    ctx->state->t_encode_us /= n_processors;
+    ctx->state->t_decode_us /= n_processors;
+
+    // print information about the audio boundaries
+    WHISPER_LOG_WARN("\n");
+    WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    for (int i = 0; i < n_processors - 1; ++i)
+    {
+        WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t).c_str());
+    }
+    WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+
+    return ret;
+}
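+// Editor's note on whisper_full_parallel: segment times are stored in 10 ms
+// units, so each worker's results are shifted by
+// 100 * (chunk start in samples) / WHISPER_SAMPLE_RATE plus the user offset
+// before being appended, and t0 is clamped to the previous segment's t1 so
+// that chunks never overlap on the timeline.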
state->lang_id; +} + +int whisper_full_lang_id(struct whisper_context *ctx) +{ + return ctx->state->lang_id; +} + +ggml_tensor *whisper_full_get_embd_conv(struct whisper_context *ctx) +{ + return ctx->state->embd_conv; +} + +ggml_tensor *whisper_full_get_embd_enc(struct whisper_context *ctx) +{ + return ctx->state->embd_enc; +} + +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state *state, int i_segment) +{ + return state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t0(struct whisper_context *ctx, int i_segment) +{ + return ctx->state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t1_from_state(struct whisper_state *state, int i_segment) +{ + return state->result_all[i_segment].t1; +} + +int64_t whisper_full_get_segment_t1(struct whisper_context *ctx, int i_segment) +{ + return ctx->state->result_all[i_segment].t1; +} + +bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state *state, int i_segment) +{ + return state->result_all[i_segment].speaker_turn_next; +} + +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context *ctx, int i_segment) +{ + return ctx->state->result_all[i_segment].speaker_turn_next; +} + +const char *whisper_full_get_segment_text_from_state(struct whisper_state *state, int i_segment) +{ + return state->result_all[i_segment].text.c_str(); +} + +const char *whisper_full_get_segment_text(struct whisper_context *ctx, int i_segment) +{ + return ctx->state->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens_from_state(struct whisper_state *state, int i_segment) +{ + return state->result_all[i_segment].tokens.size(); +} + +int whisper_full_n_tokens(struct whisper_context *ctx, int i_segment) +{ + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char *whisper_full_get_token_text_from_state(struct whisper_context *ctx, struct whisper_state *state, int i_segment, int i_token) +{ + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char *whisper_full_get_token_text(struct whisper_context *ctx, int i_segment, int i_token) +{ + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id_from_state(struct whisper_state *state, int i_segment, int i_token) +{ + return state->result_all[i_segment].tokens[i_token].id; +} + +whisper_token whisper_full_get_token_id(struct whisper_context *ctx, int i_segment, int i_token) +{ + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state *state, int i_segment, int i_token) +{ + return state->result_all[i_segment].tokens[i_token]; +} + +struct whisper_token_data whisper_full_get_token_data(struct whisper_context *ctx, int i_segment, int i_token) +{ + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p_from_state(struct whisper_state *state, int i_segment, int i_token) +{ + return state->result_all[i_segment].tokens[i_token].p; +} + +float whisper_full_get_token_p(struct whisper_context *ctx, int i_segment, int i_token) +{ + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +// ================================================================================================= + +// +// Temporary interface needed for exposing ggml interface +// Will be removed in the future when ggml becomes a separate library +// + +WHISPER_API int 
whisper_bench_memcpy(int n_threads) +{ + fputs(whisper_bench_memcpy_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char *whisper_bench_memcpy_str(int n_threads) +{ + static std::string s; + s = ""; + char strbuf[256]; + + ggml_time_init(); + + size_t n = 20; + size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations + + // 1GB array + const size_t size = arr * 1e6; + + double sum = 0.0; + + // heat-up + { + char *src = (char *)malloc(size); + char *dst = (char *)malloc(size); + + for (size_t i = 0; i < size; i++) + src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + for (size_t i = 0; i < n; i++) + { + const int64_t t0 = ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0) * 1e-6; + + src[rand() % size] = rand() % 256; + } + + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double)(n * size) / (tsum * 1e9)); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) + sum += dst[i]; + } + + free(src); + free(dst); + } + + // single-thread + { + char *src = (char *)malloc(size); + char *dst = (char *)malloc(size); + + for (size_t i = 0; i < size; i++) + src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + for (size_t i = 0; i < n; i++) + { + const int64_t t0 = ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0) * 1e-6; + + src[rand() % size] = rand() % 256; + } + + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double)(n * size) / (tsum * 1e9)); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) + sum += dst[i]; + } + + free(src); + free(dst); + } + + // multi-thread + + for (int32_t k = 1; k <= n_threads; k++) + { + char *src = (char *)malloc(size); + char *dst = (char *)malloc(size); + + for (size_t i = 0; i < size; i++) + src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + auto helper = [&](int th) + { + const int64_t i0 = (th + 0) * size / k; + const int64_t i1 = (th + 1) * size / k; + + for (size_t i = 0; i < n; i++) + { + memcpy(dst + i0, src + i0, i1 - i0); + + src[i0 + rand() % (i1 - i0)] = rand() % 256; + }; + }; + + const int64_t t0 = ggml_time_us(); + + std::vector threads(k - 1); + for (int32_t th = 0; th < k - 1; ++th) + { + threads[th] = std::thread(helper, th); + } + + helper(k - 1); + + for (int32_t th = 0; th < k - 1; ++th) + { + threads[th].join(); + } + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0) * 1e-6; + + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double)(n * size) / (tsum * 1e9), k); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) + sum += dst[i]; + } + + free(src); + free(dst); + } + + snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum); + s += strbuf; + + return s.c_str(); +} + +WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) +{ + fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char *whisper_bench_ggml_mul_mat_str(int n_threads) +{ + static std::string s; + s = ""; + char strbuf[256]; + + ggml_time_init(); + + const int n_max = 128; + + const std::vector sizes = { + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + }; + + const size_t N_max = sizes.back(); + + // a: N*N*sizeof(float) + // 
b: N*N*sizeof(float) + // c: N*N*sizeof(float) + // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) + std::vector buf(3llu * N_max * N_max * sizeof(float) + 3 * ggml_tensor_overhead() + ggml_graph_overhead()); + std::vector work; + + // put a bunch of random data in the buffer + for (size_t i = 0; i < buf.size(); i++) + buf[i] = i; + + for (int j = 0; j < (int)sizes.size(); j++) + { + int n_q4_0 = 0; + int n_q4_1 = 0; + int n_q5_0 = 0; + int n_q5_1 = 0; + int n_q8_0 = 0; + int n_fp16 = 0; + int n_fp32 = 0; + + // GFLOPS/s + double s_q4_0 = 0.0; + double s_q4_1 = 0.0; + double s_q5_0 = 0.0; + double s_q5_1 = 0.0; + double s_q8_0 = 0.0; + double s_fp16 = 0.0; + double s_fp32 = 0.0; + + const size_t N = sizes[j]; + + for (int k = 0; k < 7; ++k) + { + const ggml_type wtype = + k == 0 ? GGML_TYPE_Q4_0 : k == 1 ? GGML_TYPE_Q4_1 + : k == 2 ? GGML_TYPE_Q5_0 + : k == 3 ? GGML_TYPE_Q5_1 + : k == 4 ? GGML_TYPE_Q8_0 + : k == 5 ? GGML_TYPE_F16 + : GGML_TYPE_F32; + + double &s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 + : k == 2 ? s_q5_0 + : k == 3 ? s_q5_1 + : k == 4 ? s_q8_0 + : k == 5 ? s_fp16 + : /*k == 6*/ s_fp32; + int &n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 + : k == 2 ? n_q5_0 + : k == 3 ? n_q5_1 + : k == 4 ? n_q8_0 + : k == 5 ? n_fp16 + : /*k == 6*/ n_fp32; + + struct ggml_init_params gparams = { + /*.mem_size =*/buf.size(), + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/false, + }; + + struct ggml_context *ctx0 = ggml_init(gparams); + + struct ggml_tensor *a = ggml_new_tensor_2d(ctx0, wtype, N, N); + struct ggml_tensor *b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + + struct ggml_tensor *c = ggml_mul_mat(ctx0, a, b); + + struct ggml_cgraph *gf = ggml_new_graph(ctx0); + + ggml_build_forward_expand(gf, c); + + double tsum = 0.0; + + // heat-up + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); + + for (int i = 0; i < n_max; ++i) + { + const int64_t t0 = ggml_time_us(); + + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0) * 1e-6; + n++; + + if (tsum > 1.0 && n >= 3) + { + break; + } + } + + ggml_free(ctx0); + + s = ((2.0 * N * N * N * n) / tsum) * 1e-9; + } + + // Q4_0 | Q4_1 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs)\n", + N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1); + s += strbuf; + + // Q5_0 | Q5_1 | Q8_0 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n", + N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0); + s += strbuf; + + // F16 | F32 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16 %7.1f GFLOPS (%3d runs) | F32 %7.1f GFLOPS (%3d runs)\n", + N, N, s_fp16, n_fp16, s_fp32, n_fp32); + s += strbuf; + } + + return s.c_str(); +} + +// ================================================================================================= + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. 
Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +static int timestamp_to_sample(int64_t t, int n_samples) +{ + return std::max(0, std::min((int)n_samples - 1, (int)((t * WHISPER_SAMPLE_RATE) / 100))); +} + +static int64_t sample_to_timestamp(int i_sample) +{ + return (100ll * i_sample) / WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string &text) +{ + float res = 0.0f; + + for (char c : text) + { + if (c == ' ') + { + res += 0.01f; + } + else if (c == ',') + { + res += 2.00f; + } + else if (c == '.') + { + res += 3.00f; + } + else if (c == '!') + { + res += 3.00f; + } + else if (c == '?') + { + res += 3.00f; + } + else if (c >= '0' && c <= '9') + { + res += 3.00f; + } + else + { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector get_signal_energy(const float *signal, int n_samples, int n_samples_per_half_window) +{ + const int hw = n_samples_per_half_window; + + std::vector result(n_samples); + + for (int i = 0; i < n_samples; i++) + { + float sum = 0; + for (int j = -hw; j <= hw; j++) + { + if (i + j >= 0 && i + j < n_samples) + { + sum += fabs(signal[i + j]); + } + } + result[i] = sum / (2 * hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context &ctx, + struct whisper_state &state, + int i_segment, + float thold_pt, + float thold_ptsum) +{ + auto &segment = state.result_all[i_segment]; + auto &tokens = segment.tokens; + + const int n_samples = state.energy.size(); + + if (n_samples == 0) + { + WHISPER_LOG_ERROR("%s: no signal data available\n", __func__); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int n = tokens.size(); + + if (n == 0) + { + return; + } + + if (n == 1) + { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto &t_beg = state.t_beg; + auto &t_last = state.t_last; + auto &tid_last = state.tid_last; + + for (int j = 0; j < n; ++j) + { + auto &token = tokens[j]; + + if (j == 0) + { + if (token.id == whisper_token_beg(&ctx)) + { + tokens[j].t0 = t0; + tokens[j].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(&ctx); + } + else + { + tokens[j].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2 * (token.tid - whisper_token_beg(&ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) + { + if (j > 0) + { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) + { + while (p1 < n && tokens[p1].t1 < 0) + { + p1++; + } + + if (p1 >= n) + { + p1--; + } + + // printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + + if (p1 > p0) + { + double psum = 0.0; + 
for (int j = p0; j <= p1; j++) + { + psum += tokens[j].vlen; + } + + // printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) + { + const double ct = tokens[j - 1].t0 + dt * tokens[j - 1].vlen / psum; + + tokens[j - 1].t1 = ct; + tokens[j].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) + { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) + { + if (tokens[j].t1 < 0) + { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) + { + if (tokens[j - 1].t1 > tokens[j].t0) + { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE / 8; + + for (int j = 0; j < n; j++) + { + if (tokens[j].id >= whisper_token_eot(&ctx)) + { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) + { + sum += state.energy[k]; + } + + const float thold = 0.5 * sum / ns; + + { + int k = s0; + if (state.energy[k] > thold && j > 0) + { + while (k > 0 && state.energy[k] > thold) + { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) + { + tokens[j].t0 = tokens[j - 1].t1; + } + else + { + s0 = k; + } + } + else + { + while (state.energy[k] < thold && k < s1) + { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (state.energy[k] > thold) + { + while (k < n_samples - 1 && state.energy[k] > thold) + { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) + { + tokens[j].t1 = tokens[j + 1].t0; + } + else + { + s1 = k; + } + } + else + { + while (state.energy[k] < thold && k > s0) + { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + // for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? 
whisper_token_to_str(&ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(&ctx)) { + // continue; + // } + //} +} + +// +// token level timestamps - dtw version +// + +// n_text_layer -> total text layers on model +// n_head -> total heads per text layer on model +static std::vector get_alignment_heads_by_layer(const whisper_context_params &cparams, int il, int n_text_layer, int n_head) +{ + std::vector ret; + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) + { + return ret; + } + else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) + { + if (il >= n_text_layer - cparams.dtw_n_top) + { + for (int32_t i = 0; i < n_head; ++i) + { + ret.push_back(i); + } + } + } + else + { + const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset); + for (size_t i = 0; i < aheads.n_heads; ++i) + { + if (aheads.heads[i].n_text_layer == il) + { + ret.push_back(aheads.heads[i].n_head); + } + } + } + return ret; +} + +// dtw + backtrace to return found path +// based on +// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83 +static ggml_tensor *dtw_and_backtrace(ggml_context *ctx, ggml_tensor *x) +{ + WHISPER_ASSERT(ggml_n_dims(x) == 2); + + int64_t N = x->ne[0]; + int64_t M = x->ne[1]; + struct ggml_tensor *cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1); + struct ggml_tensor *trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1); + + cost = ggml_set_f32(cost, INFINITY); + trace = ggml_set_f32(trace, -1); + ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0); + + // dtw + // supposedly can be optmized by computing diagonals in parallel ? + // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most. + for (int64_t j = 1; j < M + 1; ++j) + { + for (int64_t i = 1; i < N + 1; ++i) + { + float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0); + float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0); + float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0); + + float c; + int32_t t; + if (c0 < c1 && c0 < c2) + { + c = c0; + t = 0; + } + else if (c1 < c0 && c1 < c2) + { + c = c1; + t = 1; + } + else + { + c = c2; + t = 2; + } + + c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c; + ggml_set_f32_nd(cost, i, j, 0, 0, c); + ggml_set_i32_nd(trace, i, j, 0, 0, t); + } + } + + // Backtrace + const int64_t BT_MAX_ROWS = N + M - 1; + struct ggml_tensor *bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2); + // trace[0, :] = 2; + for (int64_t i = 0; i < M + 1; ++i) + ggml_set_i32_nd(trace, 0, i, 0, 0, 2); + // trace[:, 0] = 1; + for (int64_t i = 0; i < N + 1; ++i) + ggml_set_i32_nd(trace, i, 0, 0, 0, 1); + int bt_row_idx = BT_MAX_ROWS - 1; + int64_t i = N; + int64_t j = M; + while (i > 0 || j > 0) + { + ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1); + ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1); + --bt_row_idx; + + int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0); + if (t == 0) + { + --i; + --j; + } + else if (t == 1) + { + --i; + } + else if (t == 2) + { + --j; + } + else + { + WHISPER_ASSERT(0); + } + } + + // FIXME: manual clip/transpose might not be the most efficient way? (e.g. 
use ggml funcs)
+ // Clip + transpose
+ // This might not be entirely necessary for our case, but leaving it for now so output matrix
+ // is identical to dtw on openAI timing.py
+ const int64_t result_n_cols = BT_MAX_ROWS - bt_row_idx - 1;
+ ggml_tensor *r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
+ for (int64_t i = 0; i < 2; ++i)
+ {
+ for (int64_t j = 0; j < result_n_cols; ++j)
+ {
+ int32_t v = ggml_get_i32_nd(bt, j + bt_row_idx + 1, i, 0, 0);
+ ggml_set_i32_nd(r, i, j, 0, 0, v);
+ }
+ }
+
+ return r;
+}
+
+struct median_filter_user_data
+{
+ int filter_width;
+};
+
+static void median_filter(struct ggml_tensor *dst, const struct ggml_tensor *a, int ith, int /*nth*/, void *userdata)
+{
+ if (ith != 0)
+ {
+ return;
+ }
+ int filter_width = ((median_filter_user_data *)userdata)->filter_width;
+ WHISPER_ASSERT(filter_width < a->ne[2]);
+ WHISPER_ASSERT(filter_width % 2);
+ WHISPER_ASSERT(ggml_n_dims(a) == 3);
+ WHISPER_ASSERT(a->type == GGML_TYPE_F32);
+
+ std::vector<float> filter;
+ filter.reserve(filter_width);
+ for (int64_t i = 0; i < a->ne[0]; ++i)
+ {
+ for (int64_t j = 0; j < a->ne[1]; ++j)
+ {
+ for (int64_t k = 0; k < a->ne[2]; ++k)
+ {
+ for (int64_t off = -filter_width / 2; off <= filter_width / 2; ++off)
+ {
+ // "reflect" padding
+ int64_t idx = k + off;
+ if (idx < 0)
+ {
+ idx = -idx;
+ }
+ else if (idx >= a->ne[2])
+ {
+ idx = 2 * (a->ne[2] - 1) - idx;
+ }
+
+ filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
+ }
+ std::sort(filter.begin(), filter.end());
+ const float v = filter[filter.size() / 2];
+ ggml_set_f32_nd(dst, i, j, k, 0, v);
+ filter.clear();
+ }
+ }
+ }
+}
+
+static void whisper_exp_compute_token_level_timestamps_dtw(
+ struct whisper_context *ctx,
+ struct whisper_state *state,
+ struct whisper_full_params params,
+ int i_segment,
+ size_t n_segments,
+ int seek,
+ int n_frames,
+ int medfilt_width,
+ int n_threads)
+{
+ const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
+ WHISPER_ASSERT(medfilt_width % 2);
+ WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
+ WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
+
+ // FIXME: Allocating mem every time we call this func
+ // Our ggml buffer should be pre-allocated somewhere during init and reused
+ // when we call this function
+ struct ggml_init_params gparams = {
+ /*.mem_size =*/ctx->params.dtw_mem_size,
+ /*.mem_buffer =*/NULL,
+ /*.no_alloc =*/false,
+ };
+ struct ggml_context *gctx = ggml_init(gparams);
+
+ // Build token sequence that will be passed to decoder
+ // sot + [lang] + text result + eot
+ std::vector<whisper_token> tokens = {
+ whisper_token_sot(ctx),
+ };
+ if (whisper_is_multilingual(ctx))
+ {
+ const int lang_id = whisper_lang_id(params.language);
+ state->lang_id = lang_id;
+ tokens.push_back(whisper_token_lang(ctx, lang_id));
+ }
+ const size_t sot_sequence_length = tokens.size();
+ tokens.push_back(whisper_token_not(ctx));
+ for (size_t i = i_segment; i < i_segment + n_segments; ++i)
+ {
+ auto &segment = state->result_all[i];
+ for (auto &t : segment.tokens)
+ {
+ // Only text tokens
+ if (t.id < whisper_token_eot(ctx))
+ {
+ tokens.push_back(t.id);
+ }
+ }
+ }
+ tokens.push_back(whisper_token_eot(ctx));
+
+ // Get result tokens, pass them along to decoder to get cross attention QKs
+ // used in timestamping
+ // Decoder already returns only alignment head QKs, already concatenated in
+ // one tensor.
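+ // In effect, the block below is a single teacher-forced decoding pass: the
+ // whole sequence assembled above is evaluated at offset 0 with a fresh KV
+ // cache, the output logits are ignored, and only the cross-attention weights
+ // of the alignment heads (collected into state->aheads_cross_QKs) are kept
+ // for the DTW alignment that follows.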
+ whisper_kv_cache_clear(state->kv_self);
+ whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
+ whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
+ if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr))
+ {
+ WHISPER_LOG_INFO("DECODER FAILED\n");
+ WHISPER_ASSERT(0);
+ }
+ WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
+
+ const auto n_audio_tokens = n_frames / 2;
+ WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
+ WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
+ const auto n_tokens = state->aheads_cross_QKs->ne[0];
+ const auto n_heads = state->aheads_cross_QKs->ne[2];
+
+ // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
+ // tokens (i.e. discarding rows at the end of the tensor)
+ // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
+ // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+ WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
+ WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
+ ggml_tensor *w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+ auto &data = state->aheads_cross_QKs_data;
+ data.resize(n_tokens * n_audio_ctx * n_heads);
+ ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
+ for (int k = 0; k < n_heads; ++k)
+ {
+ for (int j = 0; j < n_audio_tokens; ++j)
+ {
+ memcpy(
+ (char *)w->data + j * w->nb[1] + k * w->nb[2],
+ data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
+ n_tokens * sizeof(float));
+ }
+ }
+
+ // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
+ // we already permuted the N_TOKENS dimension to columns in the last loop, because
+ // ggml_norm operates over columns. Afterwards, permute to a shape that facilitates
+ // the mean operation (after median filter)
+ // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+ // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+ w = ggml_norm(gctx, w, 1e-9f);
+ w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);
+
+ // Pass median filter - this is done over AUDIO_TOKENS dimension.
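+ // (For orientation: at the single call site above, medfilt_width is 7, so each
+ // row of attention weights is smoothed with a 7-tap median over neighbouring
+ // audio positions, using the "reflect" padding implemented in median_filter().)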
+ // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims + // OUT: Same dims + median_filter_user_data mf_user_data = {medfilt_width}; + w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data); + + // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT + // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims + // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims + w = ggml_mean(gctx, w); + w = ggml_scale(gctx, w, -1.0); + w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]); + + // Remove SOT sequence and EOT + // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS + w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]); + + // Compute + struct ggml_cgraph *gf = ggml_new_graph(gctx); + ggml_build_forward_expand(gf, w); + ggml_graph_compute_with_ctx(gctx, gf, n_threads); + + ggml_tensor *alignment = dtw_and_backtrace(gctx, w); + + // Place timestamps on segments + int32_t last_v = 0; + auto seg_i = state->result_all.begin() + i_segment; + auto tok_i = seg_i->tokens.begin(); + for (int i = 0; i < alignment->ne[1]; ++i) + { + int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0); + if (v != last_v) + { + int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0); + int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio + last_v = v; + + // Skip non-text tokens + while (!(tok_i->id < whisper_token_eot(ctx))) + { + ++tok_i; + if (tok_i == seg_i->tokens.end()) + { + ++seg_i; + tok_i = seg_i->tokens.begin(); + } + } + + tok_i->t_dtw = timestamp; + ++tok_i; + if (tok_i == seg_i->tokens.end()) + { + ++seg_i; + tok_i = seg_i->tokens.begin(); + } + } + } + + // Print DTW timestamps + /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) { + auto & segment = state->result_all[i]; + for (auto &t: segment.tokens) { + const char * tok = whisper_token_to_str(ctx, t.id); + fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100); + } + fprintf(stderr, "\n"); + }*/ + + ggml_free(gctx); +} + +void whisper_log_set(ggml_log_callback log_callback, void *user_data) +{ + g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default; + g_state.log_callback_user_data = user_data; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void whisper_log_internal(ggml_log_level level, const char *format, ...) +{ + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) + { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } + else + { + char *buffer2 = new char[len + 1]; + vsnprintf(buffer2, len + 1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void whisper_log_callback_default(ggml_log_level level, const char *text, void *user_data) +{ + (void)level; + (void)user_data; + fputs(text, stderr); + fflush(stderr); +} + +/* Whisper Encode without cross-attention */ +// ==== NEXA AI specific ==== +static struct ggml_cgraph *omni_whisper_build_graph_encoder( + whisper_context &wctx, + whisper_state &wstate) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_state_head = n_state / n_head; + + auto &kv_pad = wstate.kv_pad; + + // WHISPER_ASSERT(!!kv_pad.ctx); // only used in flash-attn, commented out for now + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + + struct ggml_init_params params = { + /*.mem_size =*/wstate.sched_encode.meta.size(), + /*.mem_buffer =*/wstate.sched_encode.meta.data(), + /*.no_alloc =*/true, + }; + + struct ggml_context *ctx0 = ggml_init(params); + + ggml_cgraph *gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + + struct ggml_tensor *cur = ggml_view_tensor(ctx0, wstate.embd_conv); + + const float KQscale = 1.0f / sqrtf(float(n_state_head)); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + // static int iter = -1; + // const int n_iter = 1500/n_ctx; + + // iter = (iter + 1) % n_iter; + + // if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + // } + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0] * ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0] * ggml_element_size(model.e_pe) * n_ctx * iter; + + struct ggml_tensor *e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); + + // =================================================================== + + // original: + // cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor *inpL = cur; + + for (int il = 0; il < n_layer; ++il) + { + const auto &layer = model.layers_encoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); + + // Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); + + // note: no bias for Key + struct ggml_tensor *Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + // Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); + + struct ggml_tensor *Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b); + + // ------ + + struct ggml_tensor *Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + if (wctx.params.flash_attn) + { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx * n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx * n_state, 0))); + + struct ggml_tensor *K = + ggml_view_3d(ctx0, kv_pad.k, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k) * n_state, + ggml_element_size(kv_pad.k) * n_state_head, + 0); + + struct ggml_tensor *V = + ggml_view_3d(ctx0, kv_pad.v, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v) * n_state, + ggml_element_size(kv_pad.v) * n_state_head, + 0); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, 
n_state, n_ctx); + } + else + { + struct ggml_tensor *K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + struct ggml_tensor *V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)); + + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor *inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.mlp_ln_w), + layer.mlp_ln_b); + } + +#ifdef WHISPER_USE_FLASH_FF + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_1_b); +#endif + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // average pooling + // HACK: without ggml_cpy it will cause segmentation fault in ggml_backend_sched_graph_compute + cur = ggml_cpy(ctx0, + ggml_permute(ctx0, cur, 1, 0, 2, 3), // [ 1024 1500 1 1 ] -> [ 1500 1024 1 1 ] + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state) + ); + cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0); // [ 1500 1024 1 1 ] -> [ 750 1024 1 1 ] + cur = ggml_cpy(ctx0, + ggml_permute(ctx0, cur, 1, 0, 2, 3), // [ 750 1024 1 1 ] -> [ 1024 750 1 1 ] + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx / 2) + ); + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.e_ln_w), + model.e_ln_b); + } + + ggml_build_forward_expand(gf, cur); + + wstate.embd_enc = cur; + + // ggml_graph_print(gf); + + //////////////////////////////////////////////////////////////////////////// + + // printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); + + ggml_free(ctx0); + + return gf; +} + +struct whisper_state *whisper_encoder_init_state(whisper_context *ctx) +{ + whisper_state *state = new whisper_state; + + state->backends = whisper_backend_init(ctx->params); + if (state->backends.empty()) + { + WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); + whisper_free_state(state); + return nullptr; + } + + state->mel_calc = whisper_mel_calc_create(state->backends[0], ctx->model.filters); + + // init 60s of random mel data + { + const int n_len = 2 * 100 * WHISPER_CHUNK_SIZE; + const int n_mel = 
ctx->model.filters.n_mel; + + whisper_mel_free(state->mel); + whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel); + } + +#ifdef WHISPER_USE_COREML + const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); + + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); + + state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); + if (!state->ctx_coreml) + { + WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); +#ifndef WHISPER_COREML_ALLOW_FALLBACK + whisper_free_state(state); + return nullptr; +#endif + } + else + { + WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); + } +#endif + + state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS); + + // conv allocator + { + bool ok = whisper_sched_graph_init(state->sched_conv, state->backends, + [&]() + { + return whisper_build_graph_conv(*ctx, *state, 0); + }); + + if (!ok) + { + WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6); + } + + // encoder allocator + if (!whisper_encode_external(*state)) + { + bool ok = whisper_sched_graph_init(state->sched_encode, state->backends, + [&]() + { + return omni_whisper_build_graph_encoder(*ctx, *state); + }); + + if (!ok) + { + WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6); + } + + return state; +} + +static bool whisper_encoder_load(struct whisper_model_loader *loader, whisper_context &wctx, const char *path_model) +{ + WHISPER_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + // Initialize GGUF context + ggml_context *meta = nullptr; + gguf_context *gguf_ctx = gguf_init_from_file(path_model, {true, &meta}); + + if (!gguf_ctx) + { + WHISPER_LOG_ERROR("%s: failed to initialize GGUF context\n", __func__); + return false; + } + + auto &model = wctx.model; + + // load hparams + { + auto &hparams = model.hparams; + + hparams.n_audio_ctx = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "max_source_positions")); + hparams.n_audio_state = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "d_model")); + hparams.n_audio_head = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "encoder_attention_heads")); + hparams.n_audio_layer = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "encoder_layers")); + hparams.n_mels = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "n_mel")); + + std::string mver = ""; + + if (hparams.n_audio_layer == 4) + { + model.type = e_model::MODEL_TINY; + } + + if (hparams.n_audio_layer == 6) + { + model.type = e_model::MODEL_BASE; + } + + if (hparams.n_audio_layer == 12) + { + model.type = e_model::MODEL_SMALL; + } + + if (hparams.n_audio_layer == 24) + { + model.type = e_model::MODEL_MEDIUM; + } + + if (hparams.n_audio_layer == 32) + { + model.type = e_model::MODEL_LARGE; + + if (hparams.n_vocab == 51866) + { + mver = " v3"; + } + } + + WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + 
WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + } + + // create the ggml context + const size_t n_tensors = gguf_get_n_tensors(gguf_ctx); + + struct ggml_init_params params = { + /*.mem_size =*/(n_tensors + 3) * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) + { + WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__); + return false; + } + + // Open the GGUF file for reading tensor data + std::ifstream fin(path_model, std::ios::binary); + if (!fin) + return fprintf(stderr, "%s: cannot open model file for loading tensors\n", __func__), gguf_free(gguf_ctx), false; + + // Create tensor structures in the GGML context + for (int i = 0; i < n_tensors; ++i) + { + const char *name = gguf_get_tensor_name(gguf_ctx, i); + // WHISPER_LOG_DEBUG("%s: Loading tensor: %s\n", __func__, name); + ggml_tensor *t = ggml_dup_tensor(model.ctx, ggml_get_tensor(meta, name)); + ggml_set_name(t, name); + } + + // allocate tensors in the backend buffers + model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params)); + if (!model.buffer) + { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__); + return false; + } + + // load tensors + { + + size_t total_size = 0; + model.n_loaded = 0; + + for (int i = 0; i < n_tensors; ++i) + { + const char *name = gguf_get_tensor_name(gguf_ctx, i); + ggml_tensor *tensor = ggml_get_tensor(model.ctx, name); + + if (!tensor) + { + WHISPER_LOG_ERROR("%s: failed to get tensor %s\n", __func__, name); + gguf_free(gguf_ctx); + return false; + } + + model.tensors[name] = tensor; + + #ifdef WHISPER_DEBUG + print_ggml_tensor_shape(name, tensor); + #endif + + int num_bytes = ggml_nbytes(tensor); + + // seek to the tensor's data offset in the GGUF file + fin.seekg(gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, i), std::ios::beg); + + if (ggml_backend_buffer_is_host(model.buffer)) + fin.read(reinterpret_cast(tensor->data), num_bytes); + else + { + std::vector read_buf(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(tensor, read_buf.data(), 0, num_bytes); + } + + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size / 1e6); + + if (model.n_loaded == 0) + { + WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } + else if (model.n_loaded != (int)model.tensors.size()) + { + WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } + } + + // load mel filters + { + auto &filters = wctx.model.filters; + + filters.n_mel = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "n_mel")); + filters.n_fft = gguf_get_val_i32(gguf_ctx, gguf_find_key(gguf_ctx, "n_fft")); + + filters.data.resize(filters.n_mel * filters.n_fft); + + ggml_tensor *mel_filters_data = ggml_get_tensor(model.ctx, "mel_filters_data"); + if (ggml_backend_buffer_is_host(model.buffer)) + memcpy(filters.data.data(), mel_filters_data->data, filters.data.size() * sizeof(float)); + else + { + ggml_backend_tensor_get(mel_filters_data, filters.data.data(), 0, ggml_nbytes(mel_filters_data)); + } + 
BYTESWAP_FILTERS(filters); + } + + // map tensors + { + + const auto &hparams = model.hparams; + + const int n_audio_layer = hparams.n_audio_layer; + + model.layers_encoder.resize(n_audio_layer); + + // encoder + { + model.e_pe = model.tensors["audio_tower.embed_positions.weight"]; + + model.e_conv_1_w = model.tensors["audio_tower.conv1.weight"]; + model.tensors["audio_tower.conv1.bias"] = ggml_reshape_2d(model.ctx, model.tensors["audio_tower.conv1.bias"], 1, hparams.n_audio_state); // [ 1024 ] -> [ 1 1024 ] + model.e_conv_1_b = model.tensors["audio_tower.conv1.bias"]; + + model.e_conv_2_w = model.tensors["audio_tower.conv2.weight"]; + model.tensors["audio_tower.conv2.bias"] = ggml_reshape_2d(model.ctx, model.tensors["audio_tower.conv2.bias"], 1, hparams.n_audio_state); // [ 1024 ] -> [ 1 1024 ] + model.e_conv_2_b = model.tensors["audio_tower.conv2.bias"]; + + model.e_ln_w = model.tensors["audio_tower.layer_norm.weight"]; + model.e_ln_b = model.tensors["audio_tower.layer_norm.bias"]; + + for (int i = 0; i < n_audio_layer; ++i) + { + auto &layer = model.layers_encoder[i]; + + layer.mlp_ln_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".final_layer_norm.weight"]; + layer.mlp_ln_b = model.tensors["audio_tower.layers." + std::to_string(i) + ".final_layer_norm.bias"]; + + layer.mlp_0_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".fc1.weight"]; + layer.mlp_0_b = model.tensors["audio_tower.layers." + std::to_string(i) + ".fc1.bias"]; + + layer.mlp_1_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".fc2.weight"]; + layer.mlp_1_b = model.tensors["audio_tower.layers." + std::to_string(i) + ".fc2.bias"]; + + layer.attn_ln_0_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn_layer_norm.weight"]; + layer.attn_ln_0_b = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn_layer_norm.bias"]; + + layer.attn_q_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn.q_proj.weight"]; + layer.attn_q_b = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn.q_proj.bias"]; + + layer.attn_k_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn.k_proj.weight"]; + + layer.attn_v_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn.v_proj.weight"]; + layer.attn_v_b = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn.v_proj.bias"]; + + layer.attn_ln_1_w = model.tensors["audio_tower.layers." + std::to_string(i) + ".self_attn.out_proj.weight"]; + layer.attn_ln_1_b = model.tensors["audio_tower.layers." 
+ std::to_string(i) + ".self_attn.out_proj.bias"]; + } + } + } + + size_t size_main = ggml_backend_buffer_get_size(model.buffer); + WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6); + + ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +struct whisper_context *whisper_encoder_init_with_params_no_state(struct whisper_model_loader *loader, struct whisper_context_params params, const char *path_model) +{ + ggml_time_init(); + + if (params.flash_attn && params.dtw_token_timestamps) + { + WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__); + params.dtw_token_timestamps = false; + } + + WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn); + WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps); + + whisper_context *ctx = new whisper_context; + ctx->params = params; + + if (!whisper_encoder_load(loader, *ctx, path_model)) + { + loader->close(loader->context); + WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + return ctx; +} + +struct whisper_context *whisper_encoder_init_from_file_with_params_no_state(const char *path_model, struct whisper_context_params params) +{ + WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) + { + WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + whisper_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void *ctx, void *output, size_t read_size) + { + std::ifstream *fin = (std::ifstream *)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.seek = [](void *ctx, size_t offset) + { + std::ifstream *fin = (std::ifstream *)ctx; + fin->seekg(offset, std::ios::cur); + }; + + loader.eof = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; + return fin->eof(); + }; + + loader.close = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; + fin->close(); + }; + + auto ctx = whisper_encoder_init_with_params_no_state(&loader, params, path_model); + + if (ctx) + { + ctx->path_model = path_model; + } + + return ctx; +} + +struct whisper_context *whisper_encoder_init_from_file_with_params(const char *path_model, struct whisper_context_params params) +{ + whisper_context *ctx = whisper_encoder_init_from_file_with_params_no_state(path_model, params); + if (!ctx) + { + return nullptr; + } + + ctx->state = whisper_encoder_init_state(ctx); + if (!ctx->state) + { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +static bool whisper_encode_wo_cross_internal( + whisper_context &wctx, + whisper_state &wstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ + const int64_t t_start_us = ggml_time_us(); + + // conv + { + auto &sched = 
wstate.sched_conv.sched; + + ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { + return false; + } + + if (whisper_encode_external(wstate)) + { + ggml_tensor *mel = ggml_graph_get_tensor(gf, "mel"); + assert(mel->ne[1] == wctx.model.hparams.n_mels); + GGML_UNUSED(mel); +#if defined(WHISPER_USE_COREML) + whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *)mel->data, (float *)wstate.embd_enc->data); +#elif defined(WHISPER_USE_OPENVINO) + whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc); +#endif + } + } + + // encoder + if (!whisper_encode_external(wstate)) + { + auto &sched = wstate.sched_encode.sched; + + ggml_cgraph *gf = omni_whisper_build_graph_encoder(wctx, wstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { + return false; + } + } + + wstate.t_encode_us += ggml_time_us() - t_start_us; + wstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +int whisper_encode_wo_cross(struct whisper_context *ctx, int offset, int n_threads) +{ + if (!whisper_encode_wo_cross_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) + { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_encode_wo_cross_with_state( + struct whisper_context *ctx, + struct whisper_state *state, + struct whisper_full_params params, + const float *samples, + int n_samples) +{ + // clear old results + auto &result_all = state->result_all; + + result_all.clear(); + + if (n_samples > 0) + { + // compute log mel spectrogram + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) + { + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + // auto-detect language if not specified + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) + { + std::vector probs(whisper_lang_max_id() + 1, 0.0f); + + const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); + if (lang_id < 0) + { + WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__); + return -3; + } + state->lang_id = lang_id; + params.language = whisper_lang_str(lang_id); + + WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + if (params.detect_language) + { + return 0; + } + } + + if (params.token_timestamps) + { + state->t_beg = 0; + state->t_last = 0; + state->tid_last = 0; + if (n_samples > 0) + { + state->energy = get_signal_energy(samples, n_samples, 32); + } + } + + const int seek_start = params.offset_ms / 10; + const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms / 10; + + // if length of spectrogram is less than 1.0s (100 frames), then return + // basically don't process anything that is less than 1.0s + // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 + if (seek_end < seek_start + 100) + { + WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. 
consider padding the input audio with silence\n", __func__, (seek_end - seek_start) * 10); + return 0; + } + + // overwrite audio_ctx, max allowed is hparams.n_audio_ctx + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) + { + WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + return -5; + } + state->exp_n_audio_ctx = params.audio_ctx; + + // first release distilled models require the "no_timestamps" token + { + const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866; + if (is_distil && !params.no_timestamps) + { + WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__); + params.no_timestamps = true; + } + } + + int seek = seek_start; + + // main loop + while (true) + { + if (params.progress_callback) + { + const int progress_cur = (100 * (seek - seek_start)) / (seek_end - seek_start); + + params.progress_callback( + ctx, state, progress_cur, params.progress_callback_user_data); + } + + // if only 1 second left, then stop + if (seek + 100 >= seek_end) + { + break; + } + + if (params.encoder_begin_callback) + { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) + { + WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + break; + } + } + + // encode audio features starting at offset seek + if (!whisper_encode_wo_cross_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) + { + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + { + int seek_delta = 100 * WHISPER_CHUNK_SIZE; + // update audio window + seek += seek_delta; + + WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); + } + } + + return 0; +} + +int whisper_encode_wo_cross( + struct whisper_context *ctx, + struct whisper_full_params params, + const float *samples, + int n_samples) +{ + return whisper_encode_wo_cross_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int whisper_encode_wo_cross_parallel( + struct whisper_context *ctx, + struct whisper_full_params params, + const float *samples, + int n_samples, + int n_processors) +{ + if (n_processors == 1) + { + return whisper_encode_wo_cross(ctx, params, samples, n_samples); + } + int ret = 0; + + // prepare separate states for each thread + std::vector states; + + const int offset_samples = (WHISPER_SAMPLE_RATE * params.offset_ms) / 1000; + const int n_samples_per_processor = (n_samples - offset_samples) / n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) + { + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + + const int start_samples = offset_samples + (i + 1) * n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? 
n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_encode_wo_cross_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); + } + + { + auto params_cur = params; + + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + // Run the first transformation using default state but only for the first chunk. + ret = whisper_encode_wo_cross_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) + { + workers[i].join(); + } + + const int64_t offset_t = (int64_t)params.offset_ms / 10.0; + + // combine results into result_state->result_all from all other states + for (int i = 0; i < n_processors - 1; ++i) + { + auto &results_i = states[i]->result_all; + + for (auto &result : results_i) + { + // correct the segment timestamp taking into account the offset + result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (!ctx->state->result_all.empty()) + { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); + } + + ctx->state->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) + { + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); + } + } + + ctx->state->t_mel_us += states[i]->t_mel_us; + + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; + ctx->state->n_prompt += states[i]->n_prompt; + + whisper_free_state(states[i]); + } + + // average the timings + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; + + // print information about the audio boundaries + WHISPER_LOG_WARN("\n"); + WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + for (int i = 0; i < n_processors - 1; ++i) + { + WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t).c_str()); + } + WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__); + + return ret; +} + +bool is_wav_buffer(const std::string buf) { + // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format + // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html + if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) 
!= "WAVE") { + return false; + } + + uint32_t chunk_size = *reinterpret_cast(buf.data() + 4); + if (chunk_size + 8 != buf.size()) { + return false; + } + + return true; +} + +bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { + drwav wav; + std::vector wav_data; // used for pipe input from stdin or ffmpeg decoding output + + if (fname == "-") { + { + #ifdef _WIN32 + _setmode(_fileno(stdin), _O_BINARY); + #endif + + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return false; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } + else if (is_wav_buffer(fname)) { + if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from fname buffer\n"); + return false; + } + } + else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { +#if defined(WHISPER_FFMPEG) + if (ffmpeg_decode_audio(fname, wav_data) != 0) { + fprintf(stderr, "error: failed to ffmpeg decode '%s' \n", fname.c_str()); + return false; + } + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to read wav data as wav \n"); + return false; + } +#else + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); + return false; +#endif + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str()); + drwav_uninit(&wav); + return false; + } + + if (stereo && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str()); + drwav_uninit(&wav); + return false; + } + + if (wav.sampleRate != COMMON_SAMPLE_RATE) { + fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000); + drwav_uninit(&wav); + return false; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str()); + drwav_uninit(&wav); + return false; + } + + const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector<int16_t> pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (stereo) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + + return true; +} diff --git a/examples/nexa-omni-audio/whisper.h b/examples/nexa-omni-audio/whisper.h new file mode 100644 index 000000000..3e3891f78 --- /dev/null +++ b/examples/nexa-omni-audio/whisper.h @@ -0,0 +1,686 @@ +#ifndef WHISPER_H +#define WHISPER_H + +#include "ggml.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> +#include <string> +#include <vector> + +#ifdef __GNUC__ +# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define WHISPER_DEPRECATED(func, hint) func +#endif + +#ifdef WHISPER_SHARED +# ifdef _WIN32 +# ifdef WHISPER_BUILD +# define WHISPER_API __declspec(dllexport) +# else +# define WHISPER_API __declspec(dllimport) +# endif +# else +# define WHISPER_API __attribute__ ((visibility ("default"))) +# endif +#else +# define WHISPER_API +#endif + +#define WHISPER_SAMPLE_RATE 16000 +#define WHISPER_N_FFT 400 +#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1) +#define WHISPER_HOP_LENGTH 160 +#define WHISPER_CHUNK_SIZE 30 +#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE) + +#define COMMON_SAMPLE_RATE 16000 // Common sample rate for audio processing (16kHz) + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // The following interface is thread-safe as long as the same whisper_context is not used by multiple threads + // concurrently. + // + // Basic usage: + // + // #include "whisper.h" + // + // ... + // + // whisper_context_params cparams = whisper_context_default_params(); + // + // struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams); + // + // whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + // + // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + // fprintf(stderr, "failed to process audio\n"); + // return 7; + // } + // + // const int n_segments = whisper_full_n_segments(ctx); + // for (int i = 0; i < n_segments; ++i) { + // const char * text = whisper_full_get_segment_text(ctx, i); + // printf("%s", text); + // } + // + // whisper_free(ctx); + // + // ... + // + // This is a demonstration of the most straightforward usage of the library. + // "pcmf32" contains the RAW audio data in 32-bit floating point format. + // + // The interface also allows for more fine-grained control over the computation, but it requires a deeper + // understanding of how the model works. 
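+ //
+ // A minimal sketch of that fine-grained path (not a tested program; it reuses
+ // the "pcmf32" buffer from the example above and an arbitrary thread count of 4):
+ //
+ //     whisper_pcm_to_mel(ctx, pcmf32.data(), pcmf32.size(), 4); // PCM -> log mel spectrogram
+ //     whisper_encode(ctx, 0, 4);                                // run the encoder
+ //
+ //     whisper_token tokens[1] = { whisper_token_sot(ctx) };     // start-of-transcript token
+ //     whisper_decode(ctx, tokens, 1, 0, 4);                     // one decoder step
+ //
+ //     float * logits = whisper_get_logits(ctx);                 // n_tokens x n_vocab, sample from the last row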
+ // + + struct whisper_context; + struct whisper_state; + struct whisper_full_params; + + typedef int32_t whisper_pos; + typedef int32_t whisper_token; + typedef int32_t whisper_seq_id; + + enum whisper_alignment_heads_preset { + WHISPER_AHEADS_NONE, + WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers + WHISPER_AHEADS_CUSTOM, + WHISPER_AHEADS_TINY_EN, + WHISPER_AHEADS_TINY, + WHISPER_AHEADS_BASE_EN, + WHISPER_AHEADS_BASE, + WHISPER_AHEADS_SMALL_EN, + WHISPER_AHEADS_SMALL, + WHISPER_AHEADS_MEDIUM_EN, + WHISPER_AHEADS_MEDIUM, + WHISPER_AHEADS_LARGE_V1, + WHISPER_AHEADS_LARGE_V2, + WHISPER_AHEADS_LARGE_V3, + }; + + typedef struct whisper_ahead { + int n_text_layer; + int n_head; + } whisper_ahead; + + typedef struct whisper_aheads { + size_t n_heads; + const whisper_ahead * heads; + } whisper_aheads; + + struct whisper_context_params { + bool use_gpu; + bool flash_attn; + int gpu_device; // CUDA device + + // [EXPERIMENTAL] Token-level timestamps with DTW + bool dtw_token_timestamps; + enum whisper_alignment_heads_preset dtw_aheads_preset; + + int dtw_n_top; + struct whisper_aheads dtw_aheads; + + size_t dtw_mem_size; // TODO: remove + }; + + typedef struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float plog; // log probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + // [EXPERIMENTAL] Token-level timestamps with DTW + // do not use if you haven't computed token-level timestamps with dtw + // Roughly corresponds to the moment in audio in which the token was output + int64_t t_dtw; + + float vlen; // voice length of the token + } whisper_token_data; + + typedef struct whisper_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + void (*seek)(void * ctx, size_t offset); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } whisper_model_loader; + + // grammar element type + enum whisper_gretype { + // end of rule definition + WHISPER_GRETYPE_END = 0, + + // start of alternate definition for rule + WHISPER_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + WHISPER_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + WHISPER_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + WHISPER_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding WHISPER_GRETYPE_CHAR or WHISPER_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + WHISPER_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding WHISPER_GRETYPE_CHAR or + // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + WHISPER_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct whisper_grammar_element { + enum whisper_gretype type; + uint32_t value; // Unicode code point or rule ID + } whisper_grammar_element; + + // Various functions for loading a ggml whisper model. + // Allocate (almost) all memory needed for the model. 
+ // Return NULL on failure + WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) + WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params); + + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model), + "use whisper_init_from_file_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader), + "use whisper_init_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model), + "use whisper_init_from_file_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader), + "use whisper_init_with_params_no_state instead" + ); + + WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + + // Given a context, enable use of OpenVINO for encode inference. + // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr, + // the path will be generated from the ggml model path that was passed + // in to whisper_init_from_file. For example, if 'path_model' was + // "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be + // assumed to be "/path/to/ggml-base.en-encoder-openvino.xml". + // device: OpenVINO device to run inference on ("CPU", "GPU", etc.) + // cache_dir: Optional cache directory that can speed up init time, especially for + // GPU, by caching compiled 'blobs' there. + // Set to nullptr if not used. + // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1. 
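+ // For illustration only (a sketch; passing nullptr for the model path and cache
+ // directory uses the documented defaults):
+ //
+ //     whisper_ctx_init_openvino_encoder(ctx, nullptr, "CPU", nullptr);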
+ WHISPER_API int whisper_ctx_init_openvino_encoder( + struct whisper_context * ctx, + const char * model_path, + const char * device, + const char * cache_dir); + + // Frees all allocated memory + WHISPER_API void whisper_free (struct whisper_context * ctx); + WHISPER_API void whisper_free_state(struct whisper_state * state); + WHISPER_API void whisper_free_params(struct whisper_full_params * params); + WHISPER_API void whisper_free_context_params(struct whisper_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. + // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 80 + // Returns 0 on success + WHISPER_API int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel); + + WHISPER_API int whisper_set_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + WHISPER_API int whisper_encode( + struct whisper_context * ctx, + int offset, + int n_threads); + + WHISPER_API int whisper_encode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset, + int n_threads); + + // Run the Whisper decoder to obtain the logits and probabilities for the next token. + // Make sure to call whisper_encode() first. + // tokens + n_tokens is the provided context for the decoder. + // n_past is the number of tokens to use from previous decoder calls. + // Returns 0 on success + // TODO: add support for multiple decoders + WHISPER_API int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + WHISPER_API int whisper_decode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + WHISPER_API int whisper_tokenize( + struct whisper_context * ctx, + const char * text, + whisper_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0) + WHISPER_API int whisper_token_count(struct whisper_context * ctx, const char * text); + + // Largest language id (i.e. 
number of available languages - 1) + WHISPER_API int whisper_lang_max_id(void); + + // Return the id of the specified language, returns -1 if not found + // Examples: + // "de" -> 2 + // "german" -> 2 + WHISPER_API int whisper_lang_id(const char * lang); + + // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str(int id); + + // Return the full name of the specified language id (e.g. 2 -> "german"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str_full(int id); + + // Use mel data at offset_ms to try and auto-detect the spoken language + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first + // Returns the top language id or negative on failure + // If not null, fills the lang_probs array with the probabilities of all languages + // The array must be whisper_lang_max_id() + 1 in size + // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 + WHISPER_API int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx); + + WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx); + WHISPER_API int whisper_model_ftype (struct whisper_context * ctx); + WHISPER_API int whisper_model_type (struct whisper_context * ctx); + + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits (struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state); + + // Token Id -> String. 
Uses the vocabulary in the provided context + WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); + WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx); + + + // Special tokens + WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); + + // Task tokens + WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx); + + // Performance information from the default state. + WHISPER_API void whisper_print_timings(struct whisper_context * ctx); + WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + + // Print system information + WHISPER_API const char * whisper_print_system_info(void); + + //////////////////////////////////////////////////////////////////////////// + + // Available sampling strategies + enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder + WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder + }; + + // Text segment callback + // Called on every newly generated text segment + // Use the whisper_full_...() functions to obtain the text segments + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); + + // Logits filter callback + // Can be used to modify the logits before sampling + // If not NULL, called after applying temperature to logits + typedef void (*whisper_logits_filter_callback)( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token_data * tokens, + int n_tokens, + float * logits, + void * user_data); + + // Parameters for the whisper_full() function + // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + // whisper_full_default_params() + struct whisper_full_params { + enum whisper_sampling_strategy strategy; + + int n_threads; + int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool translate; + bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool no_timestamps; // do not generate timestamps + bool single_segment; // force single segment output (useful for streaming) + bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.) 
+ bool print_progress; // print progress information + bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) + bool print_timestamps; // print timestamps for each text segment when printing realtime + + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + bool split_on_word; // split on word rather than on token (when used with max_len) + int max_tokens; // max tokens per segment (0 = no limit) + + // [EXPERIMENTAL] speed-up techniques + // note: these can significantly reduce the quality of the output + bool debug_mode; // enable debug mode to get extra info (e.g. dump the log mel spectrogram) + int audio_ctx; // overwrite the audio context size (0 = use default) + + // [EXPERIMENTAL] [TDRZ] tinydiarize + bool tdrz_enable; // enable tinydiarize speaker turn detection + + // A regular expression that matches tokens to suppress + const char * suppress_regex; + + // tokens to provide to the whisper decoder as initial prompt + // these are prepended to any existing text context from a previous call + // use whisper_tokenize() to convert text to tokens + // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224) + const char * initial_prompt; + const whisper_token * prompt_tokens; + int prompt_n_tokens; + + // for auto-detection, set to nullptr, "" or "auto" + const char * language; + bool detect_language; + + // common decoding parameters: + bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + + float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 + float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 + + // fallback parameters + // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 + float temperature_inc; + float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" + float logprob_thold; + float no_speech_thold; // TODO: not implemented + + struct { + int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 + } greedy; + + struct { + int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 + + float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf + } beam_search; + + // called for every newly generated text segment + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called on each progress update + whisper_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called each time before ggml computation starts + ggml_abort_callback abort_callback; + void * 
abort_callback_user_data; + + // called by each decoder to filter obtained logits + whisper_logits_filter_callback logits_filter_callback; + void * logits_filter_callback_user_data; + + const whisper_grammar_element ** grammar_rules; + size_t n_grammar_rules; + size_t i_start_rule; + float grammar_penalty; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params() + WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void); + WHISPER_API struct whisper_context_params whisper_context_default_params (void); + WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); + WHISPER_API struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + // Uses the specified decoding strategy to obtain the text. + WHISPER_API int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); + + WHISPER_API int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() + // Result is stored in the default state of the context + // Not thread safe if executed in parallel on the same context. + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); + + // Number of generated text segments + // A segment can be a few words, a sentence, or even a paragraph. 
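+ //
+ // A minimal sketch of iterating the results after a successful whisper_full()
+ // call (error handling omitted; t0/t1 are in units of 10 ms):
+ //
+ //     const int n_segments = whisper_full_n_segments(ctx);
+ //     for (int i = 0; i < n_segments; ++i) {
+ //         printf("[%d -> %d] %s\n",
+ //             (int) whisper_full_get_segment_t0(ctx, i),
+ //             (int) whisper_full_get_segment_t1(ctx, i),
+ //             whisper_full_get_segment_text(ctx, i));
+ //     }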
+ WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx); + WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state); + + // Language id associated with the context's default state + WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + + // Language id associated with the provided state + WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state); + + // Get the embedding tensor + WHISPER_API ggml_tensor * whisper_full_get_embd_conv(struct whisper_context * ctx); + WHISPER_API ggml_tensor * whisper_full_get_embd_enc(struct whisper_context * ctx); + + // Get the start and end time of the specified segment + WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment); + + WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + + // Get whether the next segment is predicted as a speaker turn + WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); + WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment); + + // Get the text of the specified segment + WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); + WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); + + // Get number of tokens in the specified segment + WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment); + WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token); + + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment + // This contains probabilities, timestamps, etc. 
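+ // For example (a sketch; "i" is a segment index as in the loop above, and the
+ // t0/t1 fields are only meaningful when token_timestamps was enabled in the params):
+ //
+ //     for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+ //         const whisper_token_data td = whisper_full_get_token_data(ctx, i, j);
+ //         printf("%s (p = %.2f)\n", whisper_full_get_token_text(ctx, i, j), td.p);
+ //     }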
+ WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token); + + //////////////////////////////////////////////////////////////////////////// + + // Temporary helpers needed for exposing ggml interface + + WHISPER_API int whisper_bench_memcpy (int n_threads); + WHISPER_API const char * whisper_bench_memcpy_str (int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads); + WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads); + + // Control logging output; default behavior is to print to stderr + + WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data); + + /* Whisper encode without cross-attention */ + + WHISPER_API struct whisper_context * whisper_encoder_init_from_file_with_params(const char * path_model, struct whisper_context_params params); + + WHISPER_API struct whisper_state * whisper_encoder_init_state(struct whisper_context * ctx); + + WHISPER_API int whisper_encode_wo_cross( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); + + WHISPER_API int whisper_encode_wo_cross_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); + + WHISPER_API bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo); + +#ifdef __cplusplus +} +#endif + +#endif
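+
+// A rough end-to-end sketch of the encoder-only path declared above (not a
+// tested program; the file names are placeholders). It loads a WAV file, runs
+// the cross-attention-free encoder, and fetches the embedding tensor, e.g. to
+// hand off to the audio projector:
+//
+//     std::vector<float> pcmf32;
+//     std::vector<std::vector<float>> pcmf32s;
+//     read_wav("audio.wav", pcmf32, pcmf32s, false);
+//
+//     struct whisper_context_params cparams = whisper_context_default_params();
+//     struct whisper_context * ctx = whisper_encoder_init_from_file_with_params("encoder.bin", cparams);
+//
+//     struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+//     whisper_encode_wo_cross_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), 1);
+//
+//     struct ggml_tensor * embd = whisper_full_get_embd_enc(ctx);
+//     whisper_free(ctx);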