enable qwen2-audio to work E2E

parent c7b912bdca
commit f0d1c4fa1c

17 changed files with 13786 additions and 3 deletions
@@ -53,4 +53,5 @@ else()
     # add_subdirectory(speculative)
     # add_subdirectory(tokenize)
     add_subdirectory(nexa-omni-audio)
+    add_subdirectory(qwen2-audio)
 endif()
@@ -41,9 +41,15 @@ if(BUILD_SHARED_LIBS)

    add_library(${OMNI_AUDIO_LIB}_shared SHARED $<TARGET_OBJECTS:${OMNI_AUDIO_LIB}>)
    target_link_libraries(${OMNI_AUDIO_LIB}_shared PRIVATE ggml_llama common ${WHISPER_LIB})
    # NEXA AI: the two blocks below are required for the Nexa SDK to export the shared library to the correct location
    set_target_properties(${OMNI_AUDIO_LIB}_shared PROPERTIES
        PUBLIC_HEADER omni.h
        POSITION_INDEPENDENT_CODE ON
        OUTPUT_NAME "${OMNI_AUDIO_LIB}"
    )
    install(TARGETS ${OMNI_AUDIO_LIB}_shared
        LIBRARY
        PUBLIC_HEADER DESTINATION include
    )

    # Add OMNI_AUDIO_SHARED definition when building the shared library
@@ -54,9 +54,9 @@ From the root directory of the repo, run commands below:
 ```
 
 ```shell
-./build/bin/nexa-omni-cli \
-    --model /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/gemma2-2b.gguf \
-    --mmproj /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/nano-omni-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
+./build/bin/nexa-qwen2-cli \
+    --model /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/qwen2-audio/qwen2/Qwen2-7.8B-F16.gguf \
+    --mmproj /home/azureuser/zack/ggml-project-apollo/llama.cpp/examples/audio-lm-python/qwen2-audio-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
     --file /home/azureuser/zack/ggml-project-apollo/examples/whisper/samples/jfk.wav \
     --prompt "this conversation talks about" \
     --n-gpu-layers 27 # offload all 27 layers of gemma2 model to GPU
examples/qwen2-audio/CMakeLists.txt (new file, 54 lines)
@@ -0,0 +1,54 @@
# whisper

# Find the Threads package
find_package(Threads REQUIRED)

# build nexa-whisper-utils
set(WHISPER_LIB nexa-whisper-utils-qwen2)
add_library(${WHISPER_LIB} OBJECT
    whisper.cpp
)
target_link_libraries(${WHISPER_LIB} PRIVATE ggml_llama common Threads::Threads)

# add nexa-qwen2-audio-lib library
set(QWEN2_AUDIO_LIB nexa-qwen2-audio-lib)
add_library(${QWEN2_AUDIO_LIB} OBJECT
    qwen2.cpp
    qwen2.h
    audio-projector.cpp
    audio-projector.h
)
target_link_libraries(${QWEN2_AUDIO_LIB} PRIVATE ggml_llama common ${WHISPER_LIB})

# build the nexa-qwen2-cli
add_executable(nexa-qwen2-cli qwen2-cli.cpp)
target_link_libraries(nexa-qwen2-cli PRIVATE ggml_llama common Threads::Threads ${WHISPER_LIB} ${QWEN2_AUDIO_LIB})

# If BUILD_SHARED_LIBS is ON, also build a shared library
if(BUILD_SHARED_LIBS)
    message(STATUS "Building shared libraries")

    set_target_properties(${WHISPER_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)
    set_target_properties(${QWEN2_AUDIO_LIB} PROPERTIES POSITION_INDEPENDENT_CODE ON)

    add_library(${QWEN2_AUDIO_LIB}_shared SHARED $<TARGET_OBJECTS:${QWEN2_AUDIO_LIB}>)
    target_link_libraries(${QWEN2_AUDIO_LIB}_shared PRIVATE ggml_llama common ${WHISPER_LIB})
    # NEXA AI: the two blocks below are required for the Nexa SDK to export the shared library to the correct location
    set_target_properties(${QWEN2_AUDIO_LIB}_shared PROPERTIES
        PUBLIC_HEADER qwen2.h
        POSITION_INDEPENDENT_CODE ON
        OUTPUT_NAME "${QWEN2_AUDIO_LIB}"
    )
    install(TARGETS ${QWEN2_AUDIO_LIB}_shared
        LIBRARY
        PUBLIC_HEADER DESTINATION include
    )

    # Add QWEN2_AUDIO_SHARED definition when building the shared library
    target_compile_definitions(${QWEN2_AUDIO_LIB}_shared PRIVATE QWEN2_AUDIO_SHARED WHISPER_SHARED)

    # Ensure all symbols are exported on Windows
    if(MSVC)
        set_target_properties(${QWEN2_AUDIO_LIB}_shared PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
    endif()
endif()
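The `QWEN2_AUDIO_SHARED` compile definition typically drives an export macro inside `qwen2.h`, which is not shown in this diff. The following is only a hedged sketch of the usual pattern; the `QWEN2_API` and `QWEN2_AUDIO_BUILD` names are hypothetical, not taken from this commit:

```cpp
// Hypothetical sketch of the export-macro pattern a define like
// QWEN2_AUDIO_SHARED usually enables; qwen2.h itself is not part of this diff.
#if defined(_WIN32) && defined(QWEN2_AUDIO_SHARED)
#  ifdef QWEN2_AUDIO_BUILD                 // hypothetical: set while compiling the DLL itself
#    define QWEN2_API __declspec(dllexport)
#  else
#    define QWEN2_API __declspec(dllimport)
#  endif
#else
#  define QWEN2_API                        // static builds and non-Windows: no decoration
#endif
```

With `WINDOWS_EXPORT_ALL_SYMBOLS ON` set above, MSVC exports everything regardless, so such a macro mainly matters for the import side and for non-MSVC Windows toolchains.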
examples/qwen2-audio/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
## Run Qwen2-audio

From the root directory of the repo, run the commands below:

```shell
./build/bin/nexa-qwen2-cli \
    --model /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/gemma2-2b.gguf \
    --mmproj /home/azureuser/zack/ggml-project-apollo/llama.cpp.origin/examples/nano-omni-audio/nano-omni-instruct.mel-filters-audio_tower-multi_modal_projector.gguf \
    --file /home/azureuser/zack/ggml-project-apollo/examples/whisper/samples/jfk.wav \
    --prompt "this conversation talks about" \
    --n-gpu-layers 27 # offload all 27 layers of the gemma2 model to the GPU
```
examples/qwen2-audio/audio-projector.cpp (new file, 37 lines)
@@ -0,0 +1,37 @@
#include "audio-projector.h"
#include "common-nexa.h"

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <vector>

struct ggml_tensor *audio_projector_inference(audio_projector &model, std::vector<float> &audio_feature_data)
{
    // Build the computation graph for inference
    struct ggml_cgraph *gf = model.build_graph();
    // Allocate the graph tensors
    ggml_gallocr_alloc_graph(model.compute_alloc, gf);

    // Set the input data
    struct ggml_tensor *input = ggml_graph_get_tensor(gf, "input");
    ggml_backend_tensor_set(input, audio_feature_data.data(), 0, audio_feature_data.size() * sizeof(float));

    model.set_n_threads(0);

    // Execute the graph on the backend
    ggml_backend_graph_compute(model.backend, gf);

    // Return the output tensor (last node in the graph)
    return ggml_graph_get_tensor(gf, "output");
}

struct ggml_tensor *audio_projector_inference(audio_projector &model, struct ggml_tensor *audio_feature_tensor)
{
    // Copy the input tensor to host memory, then dispatch to the overload above
    std::vector<float> data(ggml_nelements(audio_feature_tensor));
    ggml_backend_tensor_get(audio_feature_tensor, data.data(), 0, ggml_nbytes(audio_feature_tensor));

    return audio_projector_inference(model, data);
}
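A minimal usage sketch for the two overloads above, assuming the projector's hparams and weights were already loaded through the `NexaBaseModel` facilities from `common-nexa.h` (the loading path is not part of this diff):

```cpp
// Sketch only: shows how the two audio_projector_inference overloads relate.
#include "audio-projector.h"
#include <vector>

void project_features_sketch(audio_projector & model,
                             std::vector<float> & whisper_embd /* flattened encoder output */) {
    // Overload 1: feed raw float data; the graph expects
    // d_model x (max_source_positions / 2) values, matching the input tensor.
    struct ggml_tensor * projected = audio_projector_inference(model, whisper_embd);

    // Overload 2 would accept a ggml_tensor directly (e.g. the encoder output
    // tensor) and copies it to host memory before calling overload 1.
    (void) projected;
}
```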
examples/qwen2-audio/audio-projector.h (new file, 67 lines)
@@ -0,0 +1,67 @@
#pragma once

#include "ggml.h"
#include "common-nexa.h"

#include <vector>

//
// Audio Projector
//

struct audio_projector : public NexaBaseModel
{

    audio_projector() : NexaBaseModel()
    {
        this->hparam_names = {
            "max_source_positions",
            "d_model",
        };
        this->tensor_names = {
            "multi_modal_projector.linear.weight",
            "multi_modal_projector.linear.bias",
        };
    }

    struct ggml_cgraph *build_graph() override
    {
        const int MAX_NODES = 64;
        size_t buf_size = ggml_tensor_overhead() * MAX_NODES + ggml_graph_overhead_custom(MAX_NODES, false);
        static std::vector<uint8_t> buf(buf_size);

        // Create temporary GGML context for building the graph
        struct ggml_init_params params = {
            /*.mem_size   =*/ buf_size,
            /*.mem_buffer =*/ buf.data(),
            /*.no_alloc   =*/ true, // Memory will be allocated later
        };
        struct ggml_context *ctx0 = ggml_init(params);
        struct ggml_cgraph *gf = ggml_new_graph_custom(ctx0, MAX_NODES, false); // Create new graph

        // Create input tensor
        struct ggml_tensor *input = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
                                                       std::get<int32_t>(hparams["d_model"]),
                                                       std::get<int32_t>(hparams["max_source_positions"]) / 2);
        ggml_set_name(input, "input");
        ggml_set_input(input); // Mark tensor as input

        // weight * input + bias
        struct ggml_tensor *cur = ggml_mul_mat(ctx0, tensors["multi_modal_projector.linear.weight"], input);
        cur = ggml_add(ctx0, cur, tensors["multi_modal_projector.linear.bias"]);

        // Set the final output
        ggml_set_name(cur, "output");
        ggml_set_output(cur);

        ggml_build_forward_expand(gf, cur); // Expand graph with operations

        ggml_free(ctx0); // Free temporary context

        return gf;
    }
};

struct ggml_tensor *audio_projector_inference(audio_projector &model, std::vector<float> &audio_feature_data);

struct ggml_tensor *audio_projector_inference(audio_projector &model, struct ggml_tensor *audio_feature_tensor);
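The graph above is a single affine layer. As a documentation aid (not part of the diff), here is a scalar reference of the computation it wires up, assuming ggml's usual `ggml_mul_mat` convention of contracting over the first (contiguous) dimension:

```cpp
// Scalar reference for what build_graph() computes, for documentation only.
// With input x of shape [d_model, n_pos] and weight W of shape [d_model, d_out]:
//   out[o][p] = bias[o] + sum_i W[i][o] * x[i][p]
#include <cstddef>
#include <vector>

std::vector<float> linear_projection_reference(
        const std::vector<float> & W,    // d_model * d_out, element (i, o) at W[i + o*d_model]
        const std::vector<float> & bias, // d_out
        const std::vector<float> & x,    // d_model * n_pos, element (i, p) at x[i + p*d_model]
        int d_model, int d_out, int n_pos) {
    std::vector<float> out(static_cast<size_t>(d_out) * n_pos);
    for (int p = 0; p < n_pos; ++p) {
        for (int o = 0; o < d_out; ++o) {
            float acc = bias[o];
            for (int i = 0; i < d_model; ++i) {
                acc += W[i + static_cast<size_t>(o) * d_model] *
                       x[i + static_cast<size_t>(p) * d_model];
            }
            out[o + static_cast<size_t>(p) * d_out] = acc;
        }
    }
    return out;
}
```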
examples/qwen2-audio/ggml-cpu-impl.h (new file, 614 lines)
@@ -0,0 +1,614 @@
#pragma once

// GGML CPU internal header

#include "ggml.h"
#include "ggml-impl.h"
#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
//#include <stddef.h>
#include <stdbool.h>
#include <string.h> // memcpy
#include <math.h>   // fabsf

#ifdef __cplusplus
extern "C" {
#endif

#if defined(_MSC_VER)

#define m512bh(p) p
#define m512i(p) p

#else

#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)

#endif

/**
 * Converts brain16 to float32.
 *
 * The bfloat16 floating point format has the following structure:
 *
 *       ┌sign
 *       │
 *       │   ┌exponent
 *       │   │
 *       │   │      ┌mantissa
 *       │   │      │
 *       │┌──┴───┐┌─┴───┐
 *     0b0000000000000000 brain16
 *
 * Since bf16 has the same number of exponent bits as a 32bit float,
 * encoding and decoding numbers becomes relatively straightforward.
 *
 *       ┌sign
 *       │
 *       │   ┌exponent
 *       │   │
 *       │   │      ┌mantissa
 *       │   │      │
 *       │┌──┴───┐┌─┴───────────────────┐
 *     0b00000000000000000000000000000000 IEEE binary32
 *
 * For comparison, the standard fp16 format has fewer exponent bits.
 *
 *       ┌sign
 *       │
 *       │  ┌exponent
 *       │  │
 *       │  │    ┌mantissa
 *       │  │    │
 *       │┌─┴─┐┌─┴──────┐
 *     0b0000000000000000 IEEE binary16
 *
 * @see IEEE 754-2008
 */
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
    union {
        float f;
        uint32_t i;
    } u;
    u.i = (uint32_t)h.bits << 16;
    return u.f;
}

/**
 * Converts float32 to brain16.
 *
 * This is binary identical with Google Brain float conversion.
 * Floats shall round to nearest even, and NANs shall be quiet.
 * Subnormals aren't flushed to zero, except perhaps when used.
 * This code should vectorize nicely if using modern compilers.
 */
static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
    ggml_bf16_t h;
    union {
        float f;
        uint32_t i;
    } u;
    u.f = s;
    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
        h.bits = (u.i >> 16) | 64; /* force to quiet */
        return h;
    }
    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
    return h;
}

#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
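A quick check of the bf16 helpers above (a sketch, not part of the diff; `ggml_bf16_t` comes from ggml.h). bf16 keeps only the top 16 bits of an IEEE binary32, so values that fit in its 7 explicit mantissa bits round-trip exactly:

```cpp
// Sketch exercising ggml_compute_fp32_to_bf16 / ggml_compute_bf16_to_fp32 above.
#include <cassert>

static inline void bf16_roundtrip_sketch(void) {
    ggml_bf16_t h = ggml_compute_fp32_to_bf16(1.0f);
    assert(ggml_compute_bf16_to_fp32(h) == 1.0f);  // exact: 1.0 fits in bf16's mantissa
}
```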
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
#ifndef __FMA__
#define __FMA__
#endif
#ifndef __F16C__
#define __F16C__
#endif
#endif

// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
#ifndef __SSE3__
#define __SSE3__
#endif
#ifndef __SSSE3__
#define __SSSE3__
#endif
#endif

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#include <sys/prctl.h>
#endif

// 16-bit float
// on Arm, we use __fp16
// on x86, we use uint16_t
#if defined(__ARM_NEON)

// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
//
#include <arm_neon.h>

#ifdef _MSC_VER

typedef uint16_t ggml_fp16_internal_t;

#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }

#else

typedef __fp16 ggml_fp16_internal_t;

#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }

#endif // _MSC_VER

#if !defined(__aarch64__)

// 32-bit ARM compatibility

// vaddlvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
// vzip1_u8
// vzip2_u8

inline static int32_t vaddlvq_s16(int16x8_t v) {
    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
}

inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
    return vcombine_s16(a0, b0);
}

inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
    return vcombine_s32(a0, b0);
}

inline static int32_t vaddvq_s32(int32x4_t v) {
    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}

inline static float vaddvq_f32(float32x4_t v) {
    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
}

inline static float vmaxvq_f32(float32x4_t v) {
    return
        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
}

inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
    int32x4_t res;

    res[0] = roundf(vgetq_lane_f32(v, 0));
    res[1] = roundf(vgetq_lane_f32(v, 1));
    res[2] = roundf(vgetq_lane_f32(v, 2));
    res[3] = roundf(vgetq_lane_f32(v, 3));

    return res;
}

inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
    uint8x8_t res;

    res[0] = a[0]; res[1] = b[0];
    res[2] = a[1]; res[3] = b[1];
    res[4] = a[2]; res[5] = b[2];
    res[6] = a[3]; res[7] = b[3];

    return res;
}

inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
    uint8x8_t res;

    res[0] = a[4]; res[1] = b[4];
    res[2] = a[5]; res[3] = b[5];
    res[4] = a[6]; res[5] = b[6];
    res[6] = a[7]; res[7] = b[7];

    return res;
}

// vld1q_s16_x2
// vld1q_u8_x2
// vld1q_u8_x4
// vld1q_s8_x2
// vld1q_s8_x4
// TODO: double-check these work correctly

typedef struct ggml_int16x8x2_t {
    int16x8_t val[2];
} ggml_int16x8x2_t;

inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
    ggml_int16x8x2_t res;

    res.val[0] = vld1q_s16(ptr + 0);
    res.val[1] = vld1q_s16(ptr + 8);

    return res;
}

typedef struct ggml_uint8x16x2_t {
    uint8x16_t val[2];
} ggml_uint8x16x2_t;

inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
    ggml_uint8x16x2_t res;

    res.val[0] = vld1q_u8(ptr + 0);
    res.val[1] = vld1q_u8(ptr + 16);

    return res;
}

typedef struct ggml_uint8x16x4_t {
    uint8x16_t val[4];
} ggml_uint8x16x4_t;

inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
    ggml_uint8x16x4_t res;

    res.val[0] = vld1q_u8(ptr + 0);
    res.val[1] = vld1q_u8(ptr + 16);
    res.val[2] = vld1q_u8(ptr + 32);
    res.val[3] = vld1q_u8(ptr + 48);

    return res;
}

typedef struct ggml_int8x16x2_t {
    int8x16_t val[2];
} ggml_int8x16x2_t;

inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
    ggml_int8x16x2_t res;

    res.val[0] = vld1q_s8(ptr + 0);
    res.val[1] = vld1q_s8(ptr + 16);

    return res;
}

typedef struct ggml_int8x16x4_t {
    int8x16_t val[4];
} ggml_int8x16x4_t;

inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
    ggml_int8x16x4_t res;

    res.val[0] = vld1q_s8(ptr + 0);
    res.val[1] = vld1q_s8(ptr + 16);
    res.val[2] = vld1q_s8(ptr + 32);
    res.val[3] = vld1q_s8(ptr + 48);

    return res;
}

// NOTE: not tested
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
    int8x16_t res;

    res[ 0] = a[b[ 0]];
    res[ 1] = a[b[ 1]];
    res[ 2] = a[b[ 2]];
    res[ 3] = a[b[ 3]];
    res[ 4] = a[b[ 4]];
    res[ 5] = a[b[ 5]];
    res[ 6] = a[b[ 6]];
    res[ 7] = a[b[ 7]];
    res[ 8] = a[b[ 8]];
    res[ 9] = a[b[ 9]];
    res[10] = a[b[10]];
    res[11] = a[b[11]];
    res[12] = a[b[12]];
    res[13] = a[b[13]];
    res[14] = a[b[14]];
    res[15] = a[b[15]];

    return res;
}

// NOTE: not tested
inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
    uint8x16_t res;

    res[ 0] = a[b[ 0]];
    res[ 1] = a[b[ 1]];
    res[ 2] = a[b[ 2]];
    res[ 3] = a[b[ 3]];
    res[ 4] = a[b[ 4]];
    res[ 5] = a[b[ 5]];
    res[ 6] = a[b[ 6]];
    res[ 7] = a[b[ 7]];
    res[ 8] = a[b[ 8]];
    res[ 9] = a[b[ 9]];
    res[10] = a[b[10]];
    res[11] = a[b[11]];
    res[12] = a[b[12]];
    res[13] = a[b[13]];
    res[14] = a[b[14]];
    res[15] = a[b[15]];

    return res;
}

#else

#define ggml_int16x8x2_t  int16x8x2_t
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_int8x16x2_t  int8x16x2_t
#define ggml_int8x16x4_t  int8x16x4_t

#define ggml_vld1q_s16_x2 vld1q_s16_x2
#define ggml_vld1q_u8_x2  vld1q_u8_x2
#define ggml_vld1q_u8_x4  vld1q_u8_x4
#define ggml_vld1q_s8_x2  vld1q_s8_x2
#define ggml_vld1q_s8_x4  vld1q_s8_x4
#define ggml_vqtbl1q_s8   vqtbl1q_s8
#define ggml_vqtbl1q_u8   vqtbl1q_u8

#endif // !defined(__aarch64__)

#if !defined(__ARM_FEATURE_DOTPROD)

inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));

    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}

#else

#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)

#endif // !defined(__ARM_FEATURE_DOTPROD)

#endif // defined(__ARM_NEON)

#if defined(__ARM_NEON) && !defined(_MSC_VER)

#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)

#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)

static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
    ggml_fp16_internal_t tmp;
    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
    return (float)tmp;
}

static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
    ggml_fp16_t res;
    ggml_fp16_internal_t tmp = f;
    memcpy(&res, &tmp, sizeof(ggml_fp16_t));
    return res;
}

#else

#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#ifdef __POWER9_VECTOR__
#include <altivec.h>
#undef bool
#define bool _Bool
#else
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif

#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
#endif

#if defined(__loongarch64)
#if defined(__loongarch_asx)
#include <lasxintrin.h>
#endif
#if defined(__loongarch_sx)
#include <lsxintrin.h>
#endif
#endif

#if defined(__loongarch_asx)

typedef union {
    int32_t i;
    float f;
} ft_union;

/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(float val) {
    ft_union fi_tmpval = {.f = val};
    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}

static __m256 __lasx_xvreplfr2vr_s(float val) {
    ft_union fi_tmpval = {.f = val};
    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
#endif

#ifdef __F16C__

#ifdef _MSC_VER
#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
#else
#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
#endif

#elif defined(__POWER9_VECTOR__)

#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
/* the inline asm below is about 12% faster than the lookup method */
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)

static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
    register float f;
    register double d;
    __asm__(
        "mtfprd %0,%2\n"
        "xscvhpdp %0,%0\n"
        "frsp %1,%0\n" :
        /* temp */ "=d"(d),
        /* out */  "=f"(f):
        /* in */   "r"(h));
    return f;
}

static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
    register double d;
    register ggml_fp16_t r;
    __asm__( /* xscvdphp can work on double or single precision */
        "xscvdphp %0,%2\n"
        "mffprd %1,%0\n" :
        /* temp */ "=d"(d),
        /* out */  "=r"(r):
        /* in */   "f"(f));
    return r;
}

#else

// FP16 <-> FP32
// ref: https://github.com/Maratyszcza/FP16

static inline float fp32_from_bits(uint32_t w) {
    union {
        uint32_t as_bits;
        float as_value;
    } fp32;
    fp32.as_bits = w;
    return fp32.as_value;
}

static inline uint32_t fp32_to_bits(float f) {
    union {
        float as_value;
        uint32_t as_bits;
    } fp32;
    fp32.as_value = f;
    return fp32.as_bits;
}

static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
    const uint32_t w = (uint32_t) h << 16;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t two_w = w + w;

    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
    const float exp_scale = 0x1.0p-112f;
#else
    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
#endif
    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;

    const uint32_t magic_mask = UINT32_C(126) << 23;
    const float magic_bias = 0.5f;
    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;

    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
    const uint32_t result = sign |
        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
    return fp32_from_bits(result);
}

static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
    const float scale_to_inf = 0x1.0p+112f;
    const float scale_to_zero = 0x1.0p-110f;
#else
    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
#endif
    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;

    const uint32_t w = fp32_to_bits(f);
    const uint32_t shl1_w = w + w;
    const uint32_t sign = w & UINT32_C(0x80000000);
    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
    if (bias < UINT32_C(0x71000000)) {
        bias = UINT32_C(0x71000000);
    }

    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
    const uint32_t bits = fp32_to_bits(base);
    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
    const uint32_t nonsign = exp_bits + mantissa_bits;
    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
}

#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)

#endif // __F16C__

#endif // defined(__ARM_NEON) && !defined(_MSC_VER)
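A small sketch exercising the conversion macros defined above (not part of the diff). By this point in the header, every branch (NEON, F16C, POWER9, and the scalar fallback) has defined both `GGML_COMPUTE_FP32_TO_FP16` and `GGML_COMPUTE_FP16_TO_FP32`:

```cpp
// Sketch: values exactly representable in binary16 round-trip without loss.
#include <cassert>

static inline void fp16_roundtrip_sketch(void) {
    ggml_fp16_t h = GGML_COMPUTE_FP32_TO_FP16(0.5f);
    assert(GGML_COMPUTE_FP16_TO_FP32(h) == 0.5f);  // 0.5f fits in binary16 exactly
    // Finite values beyond binary16's range (~65504) saturate to infinity.
}
```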
#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE

// precomputed f32 table for f16 (256 KB)
// defined in ggml.c, initialized in ggml_init()
extern float ggml_table_f32_f16[1 << 16];

// On ARM NEON, it is faster to convert FP16 values directly than to call into
// ggml_lookup_fp16_to_fp32, so GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 are
// defined elsewhere for NEON. The same is true for POWER9.
#if !defined(GGML_FP16_TO_FP32)
inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
    uint16_t s;
    memcpy(&s, &f, sizeof(uint16_t));
    return ggml_table_f32_f16[s];
}

#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif

#if !defined(GGML_FP32_TO_FP16)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
#endif

#ifdef __cplusplus
}
#endif
examples/qwen2-audio/main-encode.cpp (new file, 916 lines)
@@ -0,0 +1,916 @@
#include "common.h"
#include "common-nexa.h"

#include "whisper.h"
#include "grammar-parser.h"

#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <cstring>

#if defined(_MSC_VER)
#pragma warning(disable : 4244 4267) // possible loss of data
#endif

// helper function to replace substrings
static void replace_all(std::string &s, const std::string &search, const std::string &replace)
{
    for (size_t pos = 0;; pos += replace.length())
    {
        pos = s.find(search, pos);
        if (pos == std::string::npos)
            break;
        s.erase(pos, search.length());
        s.insert(pos, replace);
    }
}
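A quick usage sketch for the helper above (not part of the diff):

```cpp
// Sketch: replace_all() rewrites every occurrence of `search` in place.
#include <string>

static void replace_all_sketch() {
    std::string s = "a--b--c";
    replace_all(s, "--", "/");
    // s == "a/b/c"; advancing pos by replace.length() after each hit keeps the
    // scan moving past the inserted text, so the loop terminates even when
    // `replace` contains `search`.
}
```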
// command-line parameters
struct whisper_params
{
    int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
    int32_t n_processors = 1;
    int32_t offset_t_ms = 0;
    int32_t offset_n = 0;
    int32_t duration_ms = 0;
    int32_t progress_step = 5;
    int32_t max_context = -1;
    int32_t max_len = 0;
    int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
    int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
    int32_t audio_ctx = 0;

    float word_thold = 0.01f;
    float entropy_thold = 2.40f;
    float logprob_thold = -1.00f;
    float grammar_penalty = 100.0f;
    float temperature = 0.0f;
    float temperature_inc = 0.2f;

    bool debug_mode = false;
    bool translate = false;
    bool detect_language = false;
    bool diarize = false;
    bool tinydiarize = false;
    bool split_on_word = false;
    bool no_fallback = false;
    bool no_prints = false;
    bool print_special = false;
    bool print_colors = false;
    bool print_progress = false;
    bool no_timestamps = false;
    bool log_score = false;
    bool use_gpu = true;
    bool flash_attn = false;

    std::string language = "en";
    std::string prompt;
    std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
    std::string model = "models/ggml-base.en.bin";
    std::string grammar;
    std::string grammar_rule;

    // [TDRZ] speaker turn string
    std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line

    // A regular expression that matches tokens to suppress
    std::string suppress_regex;

    std::string openvino_encode_device = "CPU";

    std::string dtw = "";

    std::vector<std::string> fname_inp = {};
    std::vector<std::string> fname_out = {};

    grammar_parser::parse_state grammar_parsed;
};

static void whisper_print_usage(int argc, char **argv, const whisper_params &params);

static char *whisper_param_turn_lowercase(char *in)
{
    int string_len = strlen(in);
    for (int i = 0; i < string_len; i++)
    {
        *(in + i) = tolower((unsigned char)*(in + i));
    }
    return in;
}

static bool whisper_params_parse(int argc, char **argv, whisper_params &params)
{
    for (int i = 1; i < argc; i++)
    {
        std::string arg = argv[i];

        if (arg == "-")
        {
            params.fname_inp.push_back(arg);
            continue;
        }

        if (arg[0] != '-')
        {
            params.fname_inp.push_back(arg);
            continue;
        }

        if (arg == "-h" || arg == "--help")
        {
            whisper_print_usage(argc, argv, params);
            exit(0);
        }
        else if (arg == "-t" || arg == "--threads")
        {
            params.n_threads = std::stoi(argv[++i]);
        }
        else if (arg == "-p" || arg == "--processors")
        {
            params.n_processors = std::stoi(argv[++i]);
        }
        else if (arg == "-ot" || arg == "--offset-t")
        {
            params.offset_t_ms = std::stoi(argv[++i]);
        }
        else if (arg == "-on" || arg == "--offset-n")
        {
            params.offset_n = std::stoi(argv[++i]);
        }
        else if (arg == "-d" || arg == "--duration")
        {
            params.duration_ms = std::stoi(argv[++i]);
        }
        else if (arg == "-mc" || arg == "--max-context")
        {
            params.max_context = std::stoi(argv[++i]);
        }
        else if (arg == "-ml" || arg == "--max-len")
        {
            params.max_len = std::stoi(argv[++i]);
        }
        else if (arg == "-bo" || arg == "--best-of")
        {
            params.best_of = std::stoi(argv[++i]);
        }
        else if (arg == "-bs" || arg == "--beam-size")
        {
            params.beam_size = std::stoi(argv[++i]);
        }
        else if (arg == "-ac" || arg == "--audio-ctx")
        {
            params.audio_ctx = std::stoi(argv[++i]);
        }
        else if (arg == "-wt" || arg == "--word-thold")
        {
            params.word_thold = std::stof(argv[++i]);
        }
        else if (arg == "-et" || arg == "--entropy-thold")
        {
            params.entropy_thold = std::stof(argv[++i]);
        }
        else if (arg == "-lpt" || arg == "--logprob-thold")
        {
            params.logprob_thold = std::stof(argv[++i]);
        }
        else if (arg == "-tp" || arg == "--temperature")
        {
            params.temperature = std::stof(argv[++i]);
        }
        else if (arg == "-tpi" || arg == "--temperature-inc")
        {
            params.temperature_inc = std::stof(argv[++i]);
        }
        else if (arg == "-debug" || arg == "--debug-mode")
        {
            params.debug_mode = true;
        }
        else if (arg == "-tr" || arg == "--translate")
        {
            params.translate = true;
        }
        else if (arg == "-di" || arg == "--diarize")
        {
            params.diarize = true;
        }
        else if (arg == "-tdrz" || arg == "--tinydiarize")
        {
            params.tinydiarize = true;
        }
        else if (arg == "-sow" || arg == "--split-on-word")
        {
            params.split_on_word = true;
        }
        else if (arg == "-nf" || arg == "--no-fallback")
        {
            params.no_fallback = true;
        }
        else if (arg == "-fp" || arg == "--font-path")
        {
            params.font_path = argv[++i];
        }
        else if (arg == "-np" || arg == "--no-prints")
        {
            params.no_prints = true;
        }
        else if (arg == "-ps" || arg == "--print-special")
        {
            params.print_special = true;
        }
        else if (arg == "-pc" || arg == "--print-colors")
        {
            params.print_colors = true;
        }
        else if (arg == "-pp" || arg == "--print-progress")
        {
            params.print_progress = true;
        }
        else if (arg == "-nt" || arg == "--no-timestamps")
        {
            params.no_timestamps = true;
        }
        else if (arg == "-l" || arg == "--language")
        {
            params.language = whisper_param_turn_lowercase(argv[++i]);
        }
        else if (arg == "-dl" || arg == "--detect-language")
        {
            params.detect_language = true;
        }
        else if (arg == "--prompt")
        {
            params.prompt = argv[++i];
        }
        else if (arg == "-m" || arg == "--model")
        {
            params.model = argv[++i];
        }
        else if (arg == "-f" || arg == "--file")
        {
            params.fname_inp.emplace_back(argv[++i]);
        }
        else if (arg == "-oved" || arg == "--ov-e-device")
        {
            params.openvino_encode_device = argv[++i];
        }
        else if (arg == "-dtw" || arg == "--dtw")
        {
            params.dtw = argv[++i];
        }
        else if (arg == "-ls" || arg == "--log-score")
        {
            params.log_score = true;
        }
        else if (arg == "-ng" || arg == "--no-gpu")
        {
            params.use_gpu = false;
        }
        else if (arg == "-fa" || arg == "--flash-attn")
        {
            params.flash_attn = true;
        }
        else if (arg == "--suppress-regex")
        {
            params.suppress_regex = argv[++i];
        }
        else if (arg == "--grammar")
        {
            params.grammar = argv[++i];
        }
        else if (arg == "--grammar-rule")
        {
            params.grammar_rule = argv[++i];
        }
        else if (arg == "--grammar-penalty")
        {
            params.grammar_penalty = std::stof(argv[++i]);
        }
        else
        {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            whisper_print_usage(argc, argv, params);
            exit(0);
        }
    }

    return true;
}

static void whisper_print_usage(int /*argc*/, char **argv, const whisper_params &params)
{
    fprintf(stderr, "\n");
    fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, " -h, --help [default] show this help message and exit\n");
    fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
    fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
    fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
    fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
    fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
    fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
    fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
    fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
    fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
    fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
    fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
    fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
    fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
    fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
    fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
    fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n", params.temperature_inc);
    fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
    fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
    fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
    fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
    fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
    fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
    fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
    fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
    fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
    fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
    fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
    fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
    fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
    fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
    fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
    fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
    fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
    fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
    fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false");
    fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
    fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
    fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
    fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
    fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
    fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
    fprintf(stderr, "\n");
}

struct whisper_print_user_data
{
    const whisper_params *params;

    const std::vector<std::vector<float>> *pcmf32s;
    int progress_prev;
};

static std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false)
{
    std::string speaker = "";
    const int64_t n_samples = pcmf32s[0].size();

    const int64_t is0 = timestamp_to_sample(t0, n_samples, WHISPER_SAMPLE_RATE);
    const int64_t is1 = timestamp_to_sample(t1, n_samples, WHISPER_SAMPLE_RATE);

    double energy0 = 0.0f;
    double energy1 = 0.0f;

    for (int64_t j = is0; j < is1; j++)
    {
        energy0 += fabs(pcmf32s[0][j]);
        energy1 += fabs(pcmf32s[1][j]);
    }

    if (energy0 > 1.1 * energy1)
    {
        speaker = "0";
    }
    else if (energy1 > 1.1 * energy0)
    {
        speaker = "1";
    }
    else
    {
        speaker = "?";
    }

    // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());

    if (!id_only)
    {
        speaker.insert(0, "(speaker ");
        speaker.append(")");
    }

    return speaker;
}

static void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void *user_data)
{
    int progress_step = ((whisper_print_user_data *)user_data)->params->progress_step;
    int *progress_prev = &(((whisper_print_user_data *)user_data)->progress_prev);
    if (progress >= *progress_prev + progress_step)
    {
        *progress_prev += progress_step;
        fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
    }
}

static void whisper_print_segment_callback(struct whisper_context *ctx, struct whisper_state * /*state*/, int n_new, void *user_data)
{
    const auto &params = *((whisper_print_user_data *)user_data)->params;
    const auto &pcmf32s = *((whisper_print_user_data *)user_data)->pcmf32s;

    const int n_segments = whisper_full_n_segments(ctx);

    std::string speaker = "";

    int64_t t0 = 0;
    int64_t t1 = 0;

    // print the last n_new segments
    const int s0 = n_segments - n_new;

    if (s0 == 0)
    {
        printf("\n");
    }

    for (int i = s0; i < n_segments; i++)
    {
        if (!params.no_timestamps || params.diarize)
        {
            t0 = whisper_full_get_segment_t0(ctx, i);
            t1 = whisper_full_get_segment_t1(ctx, i);
        }

        if (!params.no_timestamps)
        {
            printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
        }

        if (params.diarize && pcmf32s.size() == 2)
        {
            speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
        }

        if (params.print_colors)
        {
            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j)
            {
                if (params.print_special == false)
                {
                    const whisper_token id = whisper_full_get_token_id(ctx, i, j);
                    if (id >= whisper_token_eot(ctx))
                    {
                        continue;
                    }
                }

                const char *text = whisper_full_get_token_text(ctx, i, j);
                const float p = whisper_full_get_token_p(ctx, i, j);

                const int col = std::max(0, std::min((int)k_colors.size() - 1, (int)(std::pow(p, 3) * float(k_colors.size()))));

                printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
            }
        }
        else
        {
            const char *text = whisper_full_get_segment_text(ctx, i);

            printf("%s%s", speaker.c_str(), text);
        }

        if (params.tinydiarize)
        {
            if (whisper_full_get_segment_speaker_turn_next(ctx, i))
            {
                printf("%s", params.tdrz_speaker_turn.c_str());
            }
        }

        // with timestamps or speakers: each segment on new line
        if (!params.no_timestamps || params.diarize)
        {
            printf("\n");
        }

        fflush(stdout);
    }
}

static char *escape_double_quotes_and_backslashes(const char *str)
{
    if (str == NULL)
    {
        return NULL;
    }

    size_t escaped_length = strlen(str) + 1;

    for (size_t i = 0; str[i] != '\0'; i++)
    {
        if (str[i] == '"' || str[i] == '\\')
        {
            escaped_length++;
        }
    }

    char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
    if (escaped == NULL)
    {
        return NULL;
    }

    size_t pos = 0;
    for (size_t i = 0; str[i] != '\0'; i++)
    {
        if (str[i] == '"' || str[i] == '\\')
        {
            escaped[pos++] = '\\';
        }
        escaped[pos++] = str[i];
    }

    // no need to NUL-terminate: calloc() already zeroed the buffer

    return escaped;
}

// a double quote is escaped by doubling it (RFC 4180)
static char *escape_double_quotes_in_csv(const char *str)
{
    if (str == NULL)
    {
        return NULL;
    }

    size_t escaped_length = strlen(str) + 1;

    for (size_t i = 0; str[i] != '\0'; i++)
    {
        if (str[i] == '"')
        {
            escaped_length++;
        }
    }

    char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed
    if (escaped == NULL)
    {
        return NULL;
    }

    size_t pos = 0;
    for (size_t i = 0; str[i] != '\0'; i++)
    {
        if (str[i] == '"')
        {
            escaped[pos++] = '"';
        }
        escaped[pos++] = str[i];
    }

    // no need to NUL-terminate: calloc() already zeroed the buffer

    return escaped;
}
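A usage sketch for the CSV helper above (not part of the diff; it would live in the same translation unit since the function is `static`):

```cpp
// Sketch: RFC 4180 escaping doubles embedded quotes; the caller frees the result.
#include <cstdio>
#include <cstdlib>

static void csv_escape_sketch(void) {
    char *escaped = escape_double_quotes_in_csv("say \"hi\"");
    if (escaped != NULL) {
        printf("%s\n", escaped);  // prints: say ""hi""
        free(escaped);
    }
}
```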
static void cb_log_disable(enum ggml_log_level, const char *, void *) {}

int main(int argc, char **argv)
{
    whisper_params params;

    // If the only argument starts with "@", read arguments line-by-line
    // from the given file.
    std::vector<std::string> vec_args;
    if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@')
    {
        // Save the name of the executable.
        vec_args.push_back(argv[0]);

        // Open the response file.
        char const *rspfile = argv[1] + sizeof(char);
        std::ifstream fin(rspfile);
        if (fin.is_open() == false)
        {
            fprintf(stderr, "error: response file '%s' not found\n", rspfile);
            return 1;
        }

        // Read the entire response file.
        std::string line;
        while (std::getline(fin, line))
        {
            vec_args.push_back(line);
        }

        // Use the contents of the response file as the command-line arguments.
        argc = static_cast<int>(vec_args.size());
        argv = static_cast<char **>(alloca(argc * sizeof(char *)));
        for (int i = 0; i < argc; ++i)
        {
            argv[i] = const_cast<char *>(vec_args[i].c_str());
        }
    }
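For reference, a sketch of how the response-file branch above is used (the file and binary names here are hypothetical, not from this commit):

```cpp
// Given a file "args.txt" containing one argument per line:
//
//   -m
//   models/ggml-base.en.bin
//   -f
//   samples/jfk.wav
//
// the binary can then be invoked with a single "@"-prefixed argument:
//
//   ./<binary> @args.txt
//
// which the block above expands back into a normal argc/argv pair.
```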
    if (whisper_params_parse(argc, argv, params) == false)
    {
        whisper_print_usage(argc, argv, params);
        return 1;
    }

    // remove non-existent files
    for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();)
    {
        const auto fname_inp = it->c_str();

        if (*it != "-" && !is_file_exist(fname_inp))
        {
            fprintf(stderr, "error: input file not found '%s'\n", fname_inp);
            it = params.fname_inp.erase(it);
            continue;
        }

        it++;
    }

    if (params.fname_inp.empty())
    {
        fprintf(stderr, "error: no input files specified\n");
        whisper_print_usage(argc, argv, params);
        return 2;
    }

    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1)
    {
        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
        whisper_print_usage(argc, argv, params);
        exit(0);
    }

    if (params.diarize && params.tinydiarize)
    {
        fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
        whisper_print_usage(argc, argv, params);
        exit(0);
    }

    if (params.no_prints)
    {
        whisper_log_set(cb_log_disable, NULL);
    }

    // whisper init

    struct whisper_context_params cparams = whisper_context_default_params();

    cparams.use_gpu = params.use_gpu;
    cparams.flash_attn = params.flash_attn;

    if (!params.dtw.empty())
    {
        cparams.dtw_token_timestamps = true;
        cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;

        if (params.dtw == "tiny")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
        if (params.dtw == "tiny.en")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
        if (params.dtw == "base")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
        if (params.dtw == "base.en")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
        if (params.dtw == "small")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
        if (params.dtw == "small.en")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
        if (params.dtw == "medium")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
        if (params.dtw == "medium.en")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
        if (params.dtw == "large.v1")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
        if (params.dtw == "large.v2")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
        if (params.dtw == "large.v3")
            cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;

        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE)
        {
            fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
            return 3;
        }
    }
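A design note, not part of the diff: the preset if-chain above could equally be table-driven. A sketch, assuming the `whisper_alignment_heads_preset` enum name from whisper.h:

```cpp
// Sketch of a table-driven alternative to the if-chain (same presets, same result).
#include <map>
#include <string>

static whisper_alignment_heads_preset dtw_preset_from_name_sketch(const std::string &name) {
    static const std::map<std::string, whisper_alignment_heads_preset> presets = {
        {"tiny",     WHISPER_AHEADS_TINY},     {"tiny.en",   WHISPER_AHEADS_TINY_EN},
        {"base",     WHISPER_AHEADS_BASE},     {"base.en",   WHISPER_AHEADS_BASE_EN},
        {"small",    WHISPER_AHEADS_SMALL},    {"small.en",  WHISPER_AHEADS_SMALL_EN},
        {"medium",   WHISPER_AHEADS_MEDIUM},   {"medium.en", WHISPER_AHEADS_MEDIUM_EN},
        {"large.v1", WHISPER_AHEADS_LARGE_V1}, {"large.v2",  WHISPER_AHEADS_LARGE_V2},
        {"large.v3", WHISPER_AHEADS_LARGE_V3},
    };
    auto it = presets.find(name);
    return it != presets.end() ? it->second : WHISPER_AHEADS_NONE;  // NONE signals "unknown preset"
}
```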
    struct whisper_context *ctx = whisper_encoder_init_from_file_with_params(params.model.c_str(), cparams);

    if (ctx == nullptr)
    {
        fprintf(stderr, "error: failed to initialize whisper context\n");
        return 3;
    }

    // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
    whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);

    if (!params.grammar.empty())
    {
        auto &grammar = params.grammar_parsed;
        if (is_file_exist(params.grammar.c_str()))
        {
            // read grammar from file
            std::ifstream ifs(params.grammar.c_str());
            const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
            grammar = grammar_parser::parse(txt.c_str());
        }
        else
        {
            // read grammar from string
            grammar = grammar_parser::parse(params.grammar.c_str());
        }

        // will be empty (default) if there are parse errors
        if (grammar.rules.empty())
        {
            fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
            return 4;
        }
        else
        {
            fprintf(stderr, "%s: grammar:\n", __func__);
            grammar_parser::print_grammar(stderr, grammar);
            fprintf(stderr, "\n");
        }
    }

    for (int f = 0; f < (int)params.fname_inp.size(); ++f)
    {
        const auto fname_inp = params.fname_inp[f];
        const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];

        std::vector<float> pcmf32;               // mono-channel F32 PCM
        std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

        if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize))
        {
            fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
            continue;
        }

        if (!whisper_is_multilingual(ctx)) // TODO: something off here
        {
            if (params.language != "en" || params.translate)
            {
                params.language = "en";
                params.translate = false;
                fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
            }
        }
        if (params.detect_language)
        {
            params.language = "auto";
        }

        if (!params.no_prints)
        {
            // print system information
            fprintf(stderr, "\n");
            fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
                    params.n_threads * params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());

            // print some info about the processing
            fprintf(stderr, "\n");
            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
                    __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size()) / WHISPER_SAMPLE_RATE,
                    params.n_threads, params.n_processors, params.beam_size, params.best_of,
                    params.language.c_str(),
                    params.translate ? "translate" : "transcribe",
                    params.tinydiarize ? "tdrz = 1, " : "",
                    params.no_timestamps ? 0 : 1);

            fprintf(stderr, "\n");
        }

        // run the inference
        {
            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

            const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
            wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;

            wparams.print_realtime = false;
            wparams.print_progress = params.print_progress;
            wparams.print_timestamps = !params.no_timestamps;
            wparams.print_special = params.print_special;
            wparams.translate = params.translate;
            wparams.language = params.language.c_str();
            wparams.detect_language = params.detect_language;
            wparams.n_threads = params.n_threads;
            wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
            wparams.offset_ms = params.offset_t_ms;
            wparams.duration_ms = params.duration_ms;

            wparams.token_timestamps = params.max_len > 0;
            wparams.thold_pt = params.word_thold;
            wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
            wparams.split_on_word = params.split_on_word;
            wparams.audio_ctx = params.audio_ctx;

            wparams.debug_mode = params.debug_mode;

            wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

            wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();

            wparams.initial_prompt = params.prompt.c_str();

            wparams.greedy.best_of = params.best_of;
            wparams.beam_search.beam_size = params.beam_size;

            wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc;
            wparams.temperature = params.temperature;

            wparams.entropy_thold = params.entropy_thold;
            wparams.logprob_thold = params.logprob_thold;

            wparams.no_timestamps = params.no_timestamps;

            whisper_print_user_data user_data = {&params, &pcmf32s, 0};

            const auto &grammar_parsed = params.grammar_parsed;
            auto grammar_rules = grammar_parsed.c_rules();

            if (use_grammar)
            {
                if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end())
                {
                    fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
                }
                else
                {
                    wparams.grammar_rules = grammar_rules.data();
                    wparams.n_grammar_rules = grammar_rules.size();
                    wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
                    wparams.grammar_penalty = params.grammar_penalty;
                }
            }

            // this callback is called on each new segment
            if (!wparams.print_realtime)
            {
                wparams.new_segment_callback = whisper_print_segment_callback;
                wparams.new_segment_callback_user_data = &user_data;
            }

            if (wparams.print_progress)
            {
                wparams.progress_callback = whisper_print_progress_callback;
                wparams.progress_callback_user_data = &user_data;
            }

            // examples for abort mechanism
            // in the examples below, we do not abort the processing, but we could if the flag is set to true

            // the callback is called before every encoder run - if it returns false, the processing is aborted
            {
                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race

                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void *user_data)
                {
                    bool is_aborted = *(bool *)user_data;
                    return !is_aborted;
                };
                wparams.encoder_begin_callback_user_data = &is_aborted;
            }

            // the callback is called before every computation - if it returns true, the computation is aborted
            {
                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race

                wparams.abort_callback = [](void *user_data)
                {
                    bool is_aborted = *(bool *)user_data;
                    return is_aborted;
                };
                wparams.abort_callback_user_data = &is_aborted;
            }
||||
|
||||
if (whisper_encode_wo_cross_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||
return 10;
|
||||
}
|
||||
|
||||
ggml_tensor *embd_enc = whisper_full_get_embd_enc(ctx);
|
||||
// print_ggml_tensor("embd_enc", embd_enc, true);
|
||||
print_ggml_tensor_shape("embd_enc", embd_enc);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (!params.no_prints)
|
||||
{
|
||||
whisper_print_timings(ctx);
|
||||
}
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
examples/qwen2-audio/qwen2-cli.cpp (new file)
@@ -0,0 +1,19 @@
#include "qwen2.h"

int main(int argc, char **argv)
{
    omni_context_params ctx_params = omni_context_default_params();
    if (!omni_context_params_parse(argc, argv, ctx_params))
    {
        return 1;
    }

    omni_context *ctx_omni = omni_init_context(ctx_params);

    omni_process_full(ctx_omni, ctx_params);

    omni_free(ctx_omni);

    return 0;
}
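
The CLI above is a thin wrapper over the exported entry points, so the library can also be driven without the argv parser. A minimal sketch (not part of the commit; the paths are placeholders, and it adds the NULL check that omni_init_context's failure paths suggest a caller should perform):

#include "qwen2.h"

int main()
{
    omni_context_params ctx_params = omni_context_default_params();
    ctx_params.model  = "llm.gguf";         // placeholder: Qwen2 LLM weights
    ctx_params.mmproj = "audio-tower.gguf"; // placeholder: Whisper encoder + projector weights
    ctx_params.file   = "input.wav";        // placeholder: 16 kHz WAV input
    ctx_params.prompt = "this conversation talks about";

    omni_context *ctx_omni = omni_init_context(ctx_params);
    if (ctx_omni == NULL) // omni_init_context returns NULL on failure
    {
        return 1;
    }
    omni_process_full(ctx_omni, ctx_params);
    omni_free(ctx_omni);
    return 0;
}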
examples/qwen2-audio/qwen2.cpp (new file)
@@ -0,0 +1,872 @@
#include "qwen2.h"
#include "audio-projector.h"
#include "common-nexa.h"

#include "whisper.h"
#include "llama.h"
#include "common.h"
#include "log.h"
// #include "arg.h"
#include "sampling.h"
#include "llama-impl.h"

#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#include <cctype> // for tolower() used in whisper_param_turn_lowercase()

//
// Constants
//

static const char *AUDIO_TOKEN = "<|AUDIO|>";

//
// Whisper
//

struct whisper_params
{
    int32_t n_threads     = std::min(4, (int32_t)std::thread::hardware_concurrency());
    int32_t n_processors  = 1;
    int32_t offset_t_ms   = 0;
    int32_t offset_n      = 0;
    int32_t duration_ms   = 0;
    int32_t progress_step = 5;
    int32_t max_context   = -1;
    int32_t max_len       = 0;
    int32_t best_of       = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
    int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
    int32_t audio_ctx     = 0;

    float word_thold      = 0.01f;
    float entropy_thold   = 2.40f;
    float logprob_thold   = -1.00f;
    float grammar_penalty = 100.0f;
    float temperature     = 0.0f;
    float temperature_inc = 0.2f;

    bool debug_mode      = false;
    bool translate       = false;
    bool detect_language = false;
    bool diarize         = false;
    bool tinydiarize     = false;
    bool split_on_word   = false;
    bool no_fallback     = false;
    bool no_prints       = false;
    bool print_special   = false;
    bool print_colors    = false;
    bool print_progress  = false;
    bool no_timestamps   = false;
    bool log_score       = false;
    bool use_gpu         = true;
    bool flash_attn      = false;

    std::string language  = "en";
    std::string prompt;
    std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
    std::string model     = "models/ggml-base.en.bin";
    std::string grammar;
    std::string grammar_rule;

    // [TDRZ] speaker turn string
    std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line

    // A regular expression that matches tokens to suppress
    std::string suppress_regex;

    std::string openvino_encode_device = "CPU";

    std::string dtw = "";

    std::vector<std::string> fname_inp = {};
    std::vector<std::string> fname_out = {};

    grammar_parser::parse_state grammar_parsed;
};

static void whisper_print_usage(int argc, char **argv, const whisper_params &params);

static char *whisper_param_turn_lowercase(char *in)
{
    int string_len = strlen(in);
    for (int i = 0; i < string_len; i++)
    {
        *(in + i) = tolower((unsigned char)*(in + i));
    }
    return in;
}

static bool whisper_params_parse(int argc, char **argv, whisper_params &params)
{
    for (int i = 1; i < argc; i++)
    {
        std::string arg = argv[i];

        if (arg == "-")
        {
            params.fname_inp.push_back(arg);
            continue;
        }

        if (arg[0] != '-')
        {
            // params.fname_inp.push_back(arg);
            continue;
        }

        if (arg == "-h" || arg == "--help")
        {
            whisper_print_usage(argc, argv, params);
            exit(0);
        }
        else if (arg == "-t"     || arg == "--threads")          { params.n_threads       = std::stoi(argv[++i]); }
        else if (arg == "-p"     || arg == "--processors")       { params.n_processors    = std::stoi(argv[++i]); }
        else if (arg == "-ot"    || arg == "--offset-t")         { params.offset_t_ms     = std::stoi(argv[++i]); }
        else if (arg == "-on"    || arg == "--offset-n")         { params.offset_n        = std::stoi(argv[++i]); }
        else if (arg == "-d"     || arg == "--duration")         { params.duration_ms     = std::stoi(argv[++i]); }
        else if (arg == "-mc"    || arg == "--max-context")      { params.max_context     = std::stoi(argv[++i]); }
        else if (arg == "-ml"    || arg == "--max-len")          { params.max_len         = std::stoi(argv[++i]); }
        else if (arg == "-bo"    || arg == "--best-of")          { params.best_of         = std::stoi(argv[++i]); }
        else if (arg == "-bs"    || arg == "--beam-size")        { params.beam_size       = std::stoi(argv[++i]); }
        else if (arg == "-ac"    || arg == "--audio-ctx")        { params.audio_ctx       = std::stoi(argv[++i]); }
        else if (arg == "-wt"    || arg == "--word-thold")       { params.word_thold      = std::stof(argv[++i]); }
        else if (arg == "-et"    || arg == "--entropy-thold")    { params.entropy_thold   = std::stof(argv[++i]); }
        else if (arg == "-lpt"   || arg == "--logprob-thold")    { params.logprob_thold   = std::stof(argv[++i]); }
        else if (arg == "-tp"    || arg == "--temperature")      { params.temperature     = std::stof(argv[++i]); }
        else if (arg == "-tpi"   || arg == "--temperature-inc")  { params.temperature_inc = std::stof(argv[++i]); }
        else if (arg == "-debug" || arg == "--debug-mode")       { params.debug_mode      = true; }
        else if (arg == "-tr"    || arg == "--translate")        { params.translate       = true; }
        else if (arg == "-di"    || arg == "--diarize")          { params.diarize         = true; }
        else if (arg == "-tdrz"  || arg == "--tinydiarize")      { params.tinydiarize     = true; }
        else if (arg == "-sow"   || arg == "--split-on-word")    { params.split_on_word   = true; }
        else if (arg == "-nf"    || arg == "--no-fallback")      { params.no_fallback     = true; }
        else if (arg == "-fp"    || arg == "--font-path")        { params.font_path       = argv[++i]; }
        else if (arg == "-np"    || arg == "--no-prints")        { params.no_prints       = true; }
        else if (arg == "-ps"    || arg == "--print-special")    { params.print_special   = true; }
        else if (arg == "-pc"    || arg == "--print-colors")     { params.print_colors    = true; }
        else if (arg == "-pp"    || arg == "--print-progress")   { params.print_progress  = true; }
        else if (arg == "-nt"    || arg == "--no-timestamps")    { params.no_timestamps   = true; }
        else if (arg == "-l"     || arg == "--language")         { params.language        = whisper_param_turn_lowercase(argv[++i]); }
        else if (arg == "-dl"    || arg == "--detect-language")  { params.detect_language = true; }
        else if (arg == "--prompt")                              { params.prompt          = argv[++i]; }
        else if (arg == "-m"     || arg == "--model")            { params.model           = argv[++i]; }
        else if (arg == "-f"     || arg == "--file")             { params.fname_inp.emplace_back(argv[++i]); }
        else if (arg == "-oved"  || arg == "--ov-e-device")      { params.openvino_encode_device = argv[++i]; }
        else if (arg == "-dtw"   || arg == "--dtw")              { params.dtw             = argv[++i]; }
        else if (arg == "-ls"    || arg == "--log-score")        { params.log_score       = true; }
        else if (arg == "-ng"    || arg == "--no-gpu")           { params.use_gpu         = false; }
        else if (arg == "-fa"    || arg == "--flash-attn")       { params.flash_attn      = true; }
        else if (arg == "--suppress-regex")                      { params.suppress_regex  = argv[++i]; }
        else if (arg == "--grammar")                             { params.grammar         = argv[++i]; }
        else if (arg == "--grammar-rule")                        { params.grammar_rule    = argv[++i]; }
        else if (arg == "--grammar-penalty")                     { params.grammar_penalty = std::stof(argv[++i]); }
        else if (arg == "--mmproj")
        {
            // parsed by the gpt/omni argument parsers; the value that follows
            // does not start with '-' and is skipped by the positional branch above
            continue; // params.mmproj = argv[++i];
        }
        else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers")
        {
            // parsed by the gpt/omni argument parsers, see the --mmproj note above
            continue; // params.n_gpu_layers = std::stoi(argv[++i]);
        }
        else
        {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            whisper_print_usage(argc, argv, params);
            exit(0);
        }
    }

    return true;
}

static void whisper_print_usage(int /*argc*/, char **argv, const whisper_params &params)
{
    fprintf(stderr, "\n");
    fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h,        --help              [default] show this help message and exit\n");
    fprintf(stderr, "  -t N,      --threads N         [%-7d] number of threads to use during computation\n", params.n_threads);
    fprintf(stderr, "  -p N,      --processors N      [%-7d] number of processors to use during computation\n", params.n_processors);
    fprintf(stderr, "  -ot N,     --offset-t N        [%-7d] time offset in milliseconds\n", params.offset_t_ms);
    fprintf(stderr, "  -on N,     --offset-n N        [%-7d] segment index offset\n", params.offset_n);
    fprintf(stderr, "  -d  N,     --duration N        [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
    fprintf(stderr, "  -mc N,     --max-context N     [%-7d] maximum number of text context tokens to store\n", params.max_context);
    fprintf(stderr, "  -ml N,     --max-len N         [%-7d] maximum segment length in characters\n", params.max_len);
    fprintf(stderr, "  -sow,      --split-on-word     [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
    fprintf(stderr, "  -bo N,     --best-of N         [%-7d] number of best candidates to keep\n", params.best_of);
    fprintf(stderr, "  -bs N,     --beam-size N       [%-7d] beam size for beam search\n", params.beam_size);
    fprintf(stderr, "  -ac N,     --audio-ctx N       [%-7d] audio context size (0 - all)\n", params.audio_ctx);
    fprintf(stderr, "  -wt N,     --word-thold N      [%-7.2f] word timestamp probability threshold\n", params.word_thold);
    fprintf(stderr, "  -et N,     --entropy-thold N   [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
    fprintf(stderr, "  -lpt N,    --logprob-thold N   [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
    fprintf(stderr, "  -tp,       --temperature N     [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
    fprintf(stderr, "  -tpi,      --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n", params.temperature_inc);
    fprintf(stderr, "  -debug,    --debug-mode        [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
    fprintf(stderr, "  -tr,       --translate         [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
    fprintf(stderr, "  -di,       --diarize           [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
    fprintf(stderr, "  -tdrz,     --tinydiarize       [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
    fprintf(stderr, "  -nf,       --no-fallback       [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
    fprintf(stderr, "  -fp,       --font-path         [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
    fprintf(stderr, "  -np,       --no-prints         [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
    fprintf(stderr, "  -ps,       --print-special     [%-7s] print special tokens\n", params.print_special ? "true" : "false");
    fprintf(stderr, "  -pc,       --print-colors      [%-7s] print colors\n", params.print_colors ? "true" : "false");
    fprintf(stderr, "  -pp,       --print-progress    [%-7s] print progress\n", params.print_progress ? "true" : "false");
    fprintf(stderr, "  -nt,       --no-timestamps     [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
    fprintf(stderr, "  -l LANG,   --language LANG     [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
    fprintf(stderr, "  -dl,       --detect-language   [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
    fprintf(stderr, "             --prompt PROMPT     [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
    fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n", params.model.c_str());
    fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n", "");
    fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
    fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n", params.dtw.c_str());
    fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n", params.log_score ? "true" : "false");
    fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
    fprintf(stderr, "  -fa,       --flash-attn        [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
    fprintf(stderr, "             --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
    fprintf(stderr, "             --grammar GRAMMAR      [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
    fprintf(stderr, "             --grammar-rule RULE    [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
    fprintf(stderr, "             --grammar-penalty N    [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
    fprintf(stderr, "\n");
}

struct whisper_context *whisper_init_context(whisper_params &params)
{
    struct whisper_context_params cparams = whisper_context_default_params();

    cparams.use_gpu    = params.use_gpu;
    cparams.flash_attn = params.flash_attn;

    if (!params.dtw.empty())
    {
        cparams.dtw_token_timestamps = true;
        cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;

        if (params.dtw == "tiny")      cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
        if (params.dtw == "tiny.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
        if (params.dtw == "base")      cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
        if (params.dtw == "base.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
        if (params.dtw == "small")     cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
        if (params.dtw == "small.en")  cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
        if (params.dtw == "medium")    cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
        if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
        if (params.dtw == "large.v1")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
        if (params.dtw == "large.v2")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
        if (params.dtw == "large.v3")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;

        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE)
        {
            fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
            return NULL;
        }
    }

    struct whisper_context *ctx = whisper_encoder_init_from_file_with_params(params.model.c_str(), cparams);
    if (ctx == nullptr)
    {
        fprintf(stderr, "error: failed to initialize whisper context\n");
        return NULL;
    }

    // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
    whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);

    return ctx;
}

static struct whisper_full_params get_whisper_inference_params_from_whisper_params(whisper_params &params)
{
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
    wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;

    wparams.print_realtime   = false;
    wparams.print_progress   = params.print_progress;
    wparams.print_timestamps = !params.no_timestamps;
    wparams.print_special    = params.print_special;
    wparams.translate        = params.translate;
    wparams.language         = params.language.c_str();
    wparams.detect_language  = params.detect_language;
    wparams.n_threads        = params.n_threads;
    wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
    wparams.offset_ms        = params.offset_t_ms;
    wparams.duration_ms      = params.duration_ms;

    wparams.token_timestamps = params.max_len > 0;
    wparams.thold_pt         = params.word_thold;
    wparams.max_len          = params.max_len == 0 ? 60 : params.max_len;
    wparams.split_on_word    = params.split_on_word;
    wparams.audio_ctx        = params.audio_ctx;

    wparams.debug_mode       = params.debug_mode;

    wparams.tdrz_enable      = params.tinydiarize; // [TDRZ]

    wparams.suppress_regex   = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();

    wparams.initial_prompt   = params.prompt.c_str();

    wparams.greedy.best_of        = params.best_of;
    wparams.beam_search.beam_size = params.beam_size;

    wparams.temperature_inc  = params.no_fallback ? 0.0f : params.temperature_inc;
    wparams.temperature      = params.temperature;

    wparams.entropy_thold    = params.entropy_thold;
    wparams.logprob_thold    = params.logprob_thold;

    wparams.no_timestamps    = params.no_timestamps;

    return wparams;
}

//
// Omni
//

static void omni_print_usage(int, char **argv)
{
    LOG("\n example usage:\n");
    LOG("\n %s --model <omni/ggml-model.gguf> --mmproj <whisper/model-f16.gguf> --file <path/to/an/audio.wav> [-p \"describe the audio in detail.\"]\n", argv[0]);
    LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
}

bool omni_context_params_parse(int argc, char **argv, omni_context_params &params)
{
    for (int i = 1; i < argc; i++)
    {
        std::string arg = argv[i];

        if (arg[0] != '-')
        {
            continue;
        }

        if (arg == "-h" || arg == "--help")
        {
            omni_print_usage(argc, argv);
            exit(0);
        }
        if (arg == "--prompt")
        {
            params.prompt = argv[++i];
        }
        else if (arg == "-m" || arg == "--model")
        {
            params.model = argv[++i];
        }
        else if (arg == "-f" || arg == "--file")
        {
            params.file = argv[++i];
        }
        else if (arg == "--mmproj")
        {
            params.mmproj = argv[++i];
        }
        else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers")
        {
            params.n_gpu_layers = std::stoi(argv[++i]);
        }
        else
        {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            omni_print_usage(argc, argv);
            exit(0);
        }
    }

    return true;
}

omni_context_params omni_context_default_params()
{
    omni_context_params params = {
        .model        = "",
        .mmproj       = "",
        .file         = "",
        .prompt       = "this conversation talks about",
        .n_gpu_layers = -1,
    };
    return params;
}

struct omni_params
{
    gpt_params gpt;
    whisper_params whisper;
};

bool omni_params_parse(int argc, char **argv, omni_params &params)
{
    if (!gpt_params_parse(argc, argv, params.gpt))
    {
        return false;
    }

    if (!whisper_params_parse(argc, argv, params.whisper))
    {
        whisper_print_usage(argc, argv, params.whisper);
        return false;
    }

    if (params.gpt.model.empty() || params.gpt.mmproj.empty() || params.whisper.fname_inp.empty())
    {
        omni_print_usage(argc, argv);
        return false;
    }

    // the Whisper encoder (audio tower + projector) weights are loaded from the --mmproj file
    params.whisper.model = params.gpt.mmproj;

    return true;
}

static omni_params get_omni_params_from_context_params(omni_context_params &params)
{
    omni_params all_params = {
        .gpt = {
            .n_gpu_layers = params.n_gpu_layers,
            .model        = params.model,
            .prompt       = params.prompt,
        },
        .whisper = {
            .model     = params.mmproj,
            .fname_inp = {params.file},
        },
    };

    if (all_params.gpt.n_threads <= 0)
    {
        all_params.gpt.n_threads = std::thread::hardware_concurrency();
    }
    return all_params;
}

static bool eval_tokens(struct llama_context *ctx_llama, std::vector<llama_token> tokens, int n_batch, int *n_past)
{
    int N = (int)tokens.size();
    for (int i = 0; i < N; i += n_batch)
    {
        int n_eval = (int)tokens.size() - i;
        if (n_eval > n_batch)
        {
            n_eval = n_batch;
        }
        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0)))
        {
            LLAMA_LOG_ERROR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
            return false;
        }
        *n_past += n_eval;
    }
    return true;
}

static bool eval_id(struct llama_context *ctx_llama, int id, int *n_past)
{
    std::vector<llama_token> tokens;
    tokens.push_back(id);
    return eval_tokens(ctx_llama, tokens, 1, n_past);
}

static bool eval_string(struct llama_context *ctx_llama, const char *str, int n_batch, int *n_past, bool add_bos)
{
    std::string str2 = str;
    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
    // propagate decode failures instead of unconditionally reporting success
    return eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
}

static const char * sample(struct llama_sampling_context * ctx_sampling,
                           struct llama_context * ctx_llama,
                           int * n_past) {
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
    // the returned pointer refers to this function-local static, so it is only
    // valid until the next call and makes sample() non-reentrant
    static std::string ret;
    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
        ret = "</s>";
    } else {
        ret = llama_token_to_piece(ctx_llama, id);
    }
    eval_id(ctx_llama, id, n_past);
    return ret.c_str();
}

static size_t find_audio_token(const std::string &prompt)
{
    return prompt.find(AUDIO_TOKEN);
}

struct omni_context *omni_init_context(omni_context_params &params)
{
    omni_params all_params = get_omni_params_from_context_params(params);

    // llama
    LLAMA_LOG_INFO("------- llama --------\n");

    std::string prompt = all_params.gpt.prompt;
    if (prompt.empty())
    {
        prompt = "this conversation talks about";
    }

    llama_backend_init();
    llama_numa_init(all_params.gpt.numa);

    llama_model_params model_params = llama_model_params_from_gpt_params(all_params.gpt);

    llama_model *model = llama_load_model_from_file(all_params.gpt.model.c_str(), model_params);
    if (model == NULL)
    {
        LLAMA_LOG_ERROR("%s: unable to load model\n", __func__);
        return NULL;
    }

    llama_context_params ctx_params = llama_context_params_from_gpt_params(all_params.gpt);
    ctx_params.n_ctx = all_params.gpt.n_ctx < 2048 ? 2048 : all_params.gpt.n_ctx; // we need a longer context size to process audio embeddings

    llama_context *ctx_llama = llama_new_context_with_model(model, ctx_params);

    if (ctx_llama == NULL)
    {
        LLAMA_LOG_ERROR("%s: failed to create the llama_context\n", __func__);
        return NULL;
    }

    // whisper
    LLAMA_LOG_INFO("------- whisper --------\n");

    whisper_context *ctx_whisper = whisper_init_context(all_params.whisper);
    if (ctx_whisper == NULL)
    {
        LLAMA_LOG_ERROR("%s: failed to initialize the whisper context\n", __func__);
        return NULL;
    }

    // projector
    LLAMA_LOG_INFO("------- projector --------\n");

    audio_projector *projector = new audio_projector();
    if (!projector->load_from_gguf(all_params.whisper.model))
    {
        fprintf(stderr, "Failed to load model.\n");
        return NULL;
    }

    auto *ctx_omni = (struct omni_context *)malloc(sizeof(omni_context));

    ctx_omni->ctx_llama = ctx_llama;
    ctx_omni->ctx_whisper = ctx_whisper;
    ctx_omni->model = model;
    ctx_omni->projector = projector;

    LLAMA_LOG_INFO("------- qwen2 context initialized --------\n");

    return ctx_omni;
}

void omni_free(struct omni_context *ctx_omni)
{
    if (ctx_omni->ctx_whisper)
    {
        whisper_free(ctx_omni->ctx_whisper);
        ctx_omni->ctx_whisper = NULL;
    }
    if (ctx_omni->projector)
    {
        ctx_omni->projector->free();
    }

    llama_free(ctx_omni->ctx_llama);
    llama_free_model(ctx_omni->model);
    llama_backend_free();
}

static bool omni_eval_audio_embed(llama_context *ctx_llama, ggml_tensor *audio_embed, int n_batch, int *n_past)
{
    int n_embd = llama_n_embd(llama_get_model(ctx_llama));

    int n_audio_embed = audio_embed->ne[1];
    GGML_ASSERT(audio_embed->ne[0] == n_embd);

    size_t audio_embed_size = ggml_nbytes(audio_embed);
    float *audio_embed_data = (float *)malloc(audio_embed_size);
    ggml_backend_tensor_get(audio_embed, audio_embed_data, 0, audio_embed_size);

    for (int i = 0; i < n_audio_embed; i += n_batch)
    {
        int n_eval = n_audio_embed - i;
        if (n_eval > n_batch)
        {
            n_eval = n_batch;
        }
        // token == nullptr with embd set makes llama_decode consume raw
        // embeddings directly instead of looking up token ids
        llama_batch batch = {
            /* n_tokens   */ int32_t(n_eval),
            /* token      */ nullptr,
            /* embd       */ (audio_embed_data + i * n_embd),
            /* pos        */ nullptr,
            /* n_seq_id   */ nullptr,
            /* seq_id     */ nullptr,
            /* logits     */ nullptr,
            /* all_pos_0  */ *n_past,
            /* all_pos_1  */ 1,
            /* all_seq_id */ 0,
        };
        if (llama_decode(ctx_llama, batch))
        {
            LLAMA_LOG_ERROR("%s : failed to eval\n", __func__);
            free(audio_embed_data); // avoid leaking the host copy on failure
            return false;
        }
        *n_past += n_eval;
    }
    free(audio_embed_data); // the host copy is no longer needed
    return true;
}

ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params &params)
{
    auto fname_inp = params.whisper.fname_inp[0];

    std::vector<float> pcmf32;               // mono-channel F32 PCM
    std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

    if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.whisper.diarize))
    {
        LLAMA_LOG_ERROR("error: failed to read WAV file '%s'\n", fname_inp.c_str());
        return NULL;
    }

    whisper_full_params wparams = get_whisper_inference_params_from_whisper_params(params.whisper);

    if (whisper_encode_wo_cross_parallel(ctx_omni->ctx_whisper, wparams, pcmf32.data(), pcmf32.size(), params.whisper.n_processors) != 0)
    {
        LLAMA_LOG_ERROR("%s: failed to process audio\n", __func__);
        return NULL;
    }

    ggml_tensor *embd_enc = whisper_full_get_embd_enc(ctx_omni->ctx_whisper);
#ifdef NEXA_DEBUG
    print_ggml_tensor_shape("embd_enc", embd_enc);
#endif

    ggml_tensor *embed_proj = audio_projector_inference(*ctx_omni->projector, embd_enc);
#ifdef NEXA_DEBUG
    print_ggml_tensor_shape("embed_proj", embed_proj);
#endif

    return embed_proj;
}

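// pipeline recap (a descriptive note, restating the code above rather than
// adding behavior): omni_process_audio() turns the WAV into mel features, runs
// the Whisper encoder to obtain embd_enc, then applies the multimodal projector
// so that embed_proj rows match the LLM embedding width; that projected tensor
// is what omni_process_prompt() below splices into the token stream
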
void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
{
    int n_past = 0;

    int n_audio_embed = audio_embed->ne[1];
    GGML_ASSERT(params.gpt.n_predict < 0 || params.gpt.n_predict > n_audio_embed);

    const int max_tgt_len = params.gpt.n_predict < 0 ? 256 + n_audio_embed : params.gpt.n_predict;
    std::string system_prompt, user_prompt;
    size_t audio_pos = find_audio_token(prompt);
    if (audio_pos != std::string::npos)
    {
        system_prompt = prompt.substr(0, audio_pos);
        user_prompt = prompt.substr(audio_pos + std::string(AUDIO_TOKEN).length());
        // LLAMA_LOG_INFO("system_prompt: %s\n", system_prompt.c_str());
        // LLAMA_LOG_INFO("user_prompt: %s\n", user_prompt.c_str());
    }
    else
    {
        // NEXA AI : major difference with nano-omni is in the prompt handling
        // template from: https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct
        //   <|im_start|>system
        //   You are a helpful assistant.<|im_end|>
        //   <|im_start|>user
        //   Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
        //   What's that sound?<|im_end|>
        //   <|im_start|>assistant
        system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>";
        user_prompt = "<|audio_eos|>\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
    }

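    // layout of the final sequence fed to the LLM (a sketch of what the three
    // calls below produce, not additional code): tokenize(system_prompt) ++
    // projected audio embeddings (one slot per encoder frame) ++
    // tokenize(user_prompt); the audio embeddings stand in for the <|AUDIO|>
    // placeholder between <|audio_bos|> and <|audio_eos|>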
    eval_string(ctx_omni->ctx_llama, system_prompt.c_str(), params.gpt.n_batch, &n_past, true);
    omni_eval_audio_embed(ctx_omni->ctx_llama, audio_embed, params.gpt.n_batch, &n_past);
    eval_string(ctx_omni->ctx_llama, user_prompt.c_str(), params.gpt.n_batch, &n_past, false);

    // generate the response

    LOG("\n");

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.gpt.sparams);
    if (!ctx_sampling) {
        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
        exit(1);
    }

    std::string response = "";
    for (int i = 0; i < max_tgt_len; i++)
    {
        const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past);
        response += tmp;
        if (strcmp(tmp, "</s>") == 0)
            break;
        if (strstr(tmp, "###"))
            break; // Yi-VL behavior
        printf("%s", tmp);
        if (strstr(response.c_str(), "<|im_end|>"))
            break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
        if (strstr(response.c_str(), "<|im_start|>"))
            break; // Yi-34B llava-1.6
        if (strstr(response.c_str(), "USER:"))
            break; // mistral llava-1.6

        fflush(stdout);
    }

    llama_sampling_free(ctx_sampling);
    printf("\n");
}

void omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
{
    omni_params all_params = get_omni_params_from_context_params(params);

    ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params);
    if (audio_embed == NULL)
    {
        LLAMA_LOG_ERROR("%s: failed to process audio\n", __func__);
        return;
    }
    omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
}
examples/qwen2-audio/qwen2.h (new file)
@@ -0,0 +1,64 @@
#pragma once

#include "whisper.h"
#include "llama.h"
#include "grammar-parser.h"
#include "common.h"
#include "common-nexa.h"

#include <string>
#include <thread>

#include "audio-projector.h"

#ifdef OMNI_AUDIO_SHARED
#    if defined(_WIN32) && !defined(__MINGW32__)
#        ifdef OMNI_AUDIO_BUILD
#            define OMNI_AUDIO_API __declspec(dllexport)
#        else
#            define OMNI_AUDIO_API __declspec(dllimport)
#        endif
#    else
#        define OMNI_AUDIO_API __attribute__ ((visibility ("default")))
#    endif
#else
#    define OMNI_AUDIO_API
#endif

#ifdef __cplusplus
extern "C" {
#endif

struct omni_context_params
{
    const char *model;
    const char *mmproj;
    const char *file;
    const char *prompt;
    int32_t n_gpu_layers;
};

struct omni_context
{
    struct whisper_context *ctx_whisper;
    struct audio_projector *projector;
    struct llama_context *ctx_llama;
    struct llama_model *model;
};

OMNI_AUDIO_API bool omni_context_params_parse(int argc, char **argv, omni_context_params &params);

OMNI_AUDIO_API omni_context_params omni_context_default_params();

OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &params);

OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);

OMNI_AUDIO_API void omni_process_full(
    struct omni_context *ctx_omni,
    omni_context_params &params
);

#ifdef __cplusplus
}
#endif
examples/qwen2-audio/whisper-mel-cuda.cu (new file)
@@ -0,0 +1,364 @@
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
#include "whisper-mel-cuda.hpp"
#include "whisper.h"

#include "common.cuh"
#include <ggml-backend.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cufft.h>
#include <cublas_v2.h>
#include <cuComplex.h>
#include <cub/device/device_reduce.cuh>
#include <device_launch_parameters.h>

#include <algorithm>

#if defined(_MSC_VER)
#pragma warning(disable: 4324) // added padding
#endif

namespace {

static const char* cufftGetErrorString(cufftResult_t res) {
    switch (res) {
        case CUFFT_SUCCESS: return "The cuFFT operation was successful";
        case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
        case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
        case CUFFT_INVALID_TYPE: return "No longer used";
        case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
        case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
        case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
        case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
        case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
        case CUFFT_UNALIGNED_DATA: return "No longer used";
        case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
        case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
        case CUFFT_PARSE_ERROR: return "Internal plan database error";
        case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
        case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
        case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
        case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
        default: return "Unknown error";
    }
}

#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)

__global__ void k_fill_stft_input(
    const float * padded_samples,
    const int n_frames,
    const float * hann_window,
    float * stft_in
) {
    auto y = blockIdx.y * blockDim.y + threadIdx.y;
    // if (y >= n_frames) return;
    auto x = blockIdx.x * blockDim.x + threadIdx.x;
    // if (x >= WHISPER_N_FFT) return;

    auto line = padded_samples + y * WHISPER_HOP_LENGTH;
    auto outLine = stft_in + y * WHISPER_N_FFT;

    outLine[x] = line[x] * hann_window[x];
}

__global__ void k_calc_magnitudes(
    const cuComplex * stft_out,
    const int n_frames,
    float * magnitudes
) {
    auto y = blockIdx.y * blockDim.y + threadIdx.y;
    // if (y >= n_frames) return;
    auto x = blockIdx.x * blockDim.x + threadIdx.x;
    // if (x >= WHISPER_N_FFT_HALF) return;

    auto idx = y * WHISPER_N_FFT_HALF + x;

    auto r = stft_out[idx].x;
    auto i = stft_out[idx].y;
    magnitudes[idx] = r * r + i * i;
}

__global__ void k_calc_log_mel(
    const float * mel_data,
    const int n_mel,
    const float * max_val,
    float * log_mel
) {
    auto x = blockIdx.x * blockDim.x + threadIdx.x;
    if (x >= n_mel) return;

    float val = mel_data[x];

    constexpr float e = 1e-10f;
    if (val < e) val = e;

    val = log10(val);

    const float max = log10(*max_val) - 8.f;
    if (val < max) val = max;

    log_mel[x] = (val + 4) / 4;
}

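// the kernel above implements Whisper's log-mel normalization; in formula form
// (a restatement of the code, not a change): m = log10(max(mel, 1e-10)),
// m = max(m, log10(global_max) - 8), and finally log_mel = (m + 4) / 4
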
static void fill_stft_input(
    const float * padded_samples,
    int n_frames,
    const float * hann_window,
    float * stft_in,
    cudaStream_t stream
) {
    dim3 block(WHISPER_N_FFT, 1);
    dim3 grid(1, n_frames);

    k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
}

static void calc_magnitudes(
    const cuComplex * stft_out,
    int n_frames,
    float * magnitudes,
    cudaStream_t stream
) {
    dim3 block(WHISPER_N_FFT_HALF, 1);
    dim3 grid(1, n_frames);
    k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
}

constexpr auto LOG_MEL_PREFIX_SIZE = 256;

static void calc_log_mel(
    const float * mel_data,
    int n_mel,
    void * tempStorage,
    int tempStorageSize,
    float * log_mel,
    cudaStream_t stream
) {
    float * max_val = reinterpret_cast<float *>(tempStorage);
    void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;

    size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
    cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);

    int block = 256;
    int grid = (n_mel + block - 1) / block;

    k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
}

class mel_calc_cuda : public whisper_mel_calc {
    const int m_n_mel;

    ggml_backend_t m_backend = nullptr;
    int m_device = -1;

    cudaStream_t m_stream = nullptr;
    cublasHandle_t m_cublas_handle = nullptr;

    float * m_hann_window = nullptr;

    float * m_filters = nullptr;

    // max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
    int m_n_max_samples = 0;

    size_t m_cufft_workspace_size = 0;
    void * m_cufft_workspace = nullptr;

    size_t m_log_mel_temp_storage_size = 0;
    void * m_log_mel_temp_storage = nullptr;
public:
    mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
        : m_n_mel(filters.n_mel)
        , m_backend(backend)
    {
        ggml_backend_cuda_context* cuda_ctx = (ggml_backend_cuda_context*) m_backend->context;
        m_device = cuda_ctx->device;

        if (ggml_cuda_info().devices[m_device].cc < 600) {
            // we've only tested on 6.0 and higher and we've had reports of crashes on 5.0:
            // https://github.com/ggerganov/whisper.cpp/issues/2230
            // to be safe forbid anything below 6.0
            throw std::runtime_error("CUDA compute capability 6.0 or higher is required");
        }

        ggml_cuda_set_device(m_device);

        if (filters.n_fft != WHISPER_N_FFT_HALF) {
            throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
        }
        assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);

        CUDA_CHECK(cudaStreamCreate(&m_stream));
        CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
        CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
        CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));

        // create Hann window
        {
            auto hw = whisper_mel_calc::hann_window();
            CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
            CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
        }

        // fill filters
        {
            auto& f = filters.data;
            CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
            CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
        }

        // preallocate working areas enough for the most common cases (<= 30s)
        ensure_working_areas(WHISPER_N_SAMPLES);
    }

    ~mel_calc_cuda() {
        ggml_cuda_set_device(m_device);
        CUDA_CHECK(cudaStreamSynchronize(m_stream));
        CUDA_CHECK(cudaStreamDestroy(m_stream));
        CUDA_CHECK(cudaFree(m_hann_window));
        CUDA_CHECK(cudaFree(m_cufft_workspace));
        CUDA_CHECK(cudaFree(m_filters));
        CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
    }

    void ensure_working_areas(int n_samples) {
        if (n_samples <= m_n_max_samples) {
            return;
        }

        const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT;
        const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;

        // cufft workspace
        {
            if (m_cufft_workspace) {
                CUDA_CHECK(cudaFree(m_cufft_workspace));
                m_cufft_workspace_size = 0;
                m_cufft_workspace = nullptr;
            }
            CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size));
            CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
        }

        // device reduce working area
        {
            if (m_log_mel_temp_storage) {
                CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
                m_log_mel_temp_storage_size = 0;
                m_log_mel_temp_storage = nullptr;
            }

            const auto max_mels = 160;

            size_t nbytes = 0;
            float* temp = nullptr;
            cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels);
            m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE;

            CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
        }

        m_n_max_samples = n_samples;
    }

    virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
        ggml_cuda_set_device(m_device);
        ensure_working_areas(samples.len);

        const size_t mirror_pad = WHISPER_N_FFT / 2;
        const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;

        // pad
        std::vector<float> padded_samples(padded_size);
        std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect
        std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad);   // copy

        // fill the rest of the data
        // it should canonically be mirrored at the end as well,
        // but we just assume the last MEL_FRAME_SIZE/2 samples are zeros
        std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f);

        const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;

        float * cu_padded_samples = nullptr;
        CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
        CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));

        float * stft_in = nullptr; // contiguous buffer for stft input
        CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));

        fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);

        cufftComplex* stft_out;
        CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));

        cufftHandle plan;
        CUFFT_CHECK(cufftCreate(&plan));
        CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
        {
            size_t waSize;
            CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
            assert(waSize <= m_cufft_workspace_size);
            CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
            CUFFT_CHECK(cufftSetStream(plan, m_stream));
        }
        CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));

        const auto n_mag_frames = n_frames - 1; // drop last frame
        float * magnitudes;
        CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
        calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);

        float * mel_data = nullptr;
        CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));

        const float fone = 1.0f, fzero = 0.0f;
        CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
                                 int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
                                 &fone,
                                 magnitudes, WHISPER_N_FFT_HALF,
                                 m_filters, WHISPER_N_FFT_HALF,
                                 &fzero,
                                 mel_data, int(n_mag_frames)));

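        // reading of the GEMM above (stated for clarity, not taken from the
        // source): with magnitudes stored column-major as
        // [WHISPER_N_FFT_HALF x n_mag_frames] and filters as
        // [WHISPER_N_FFT_HALF x n_mel], the transposed product yields
        // mel_data(frame, mel) = sum_k magnitudes(k, frame) * filters(k, mel),
        // i.e. the mel filter bank applied to every frame in one call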
        whisper_mel ret;
        // Calculate semi-padded sample length to ensure compatibility
        int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
        whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
        assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));

        float* log_mels = reinterpret_cast<float*>(ret.tensor->data);

        calc_log_mel(
            mel_data, int(m_n_mel * n_mag_frames),
            m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
            log_mels, m_stream);

        CUDA_CHECK(cudaStreamSynchronize(m_stream));

        // cleanup
        CUFFT_CHECK(cufftDestroy(plan));
        CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
        CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
        CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
        CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
        CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));

        return ret;
    }
};

}

whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
    try {
        return new mel_calc_cuda(backend, filters);
    }
    catch (...) {
        // TODO: log error (but for this we would have to expose the log state to be accessible here)
        return nullptr;
    }
}
examples/qwen2-audio/whisper-mel-cuda.hpp (new file)
@@ -0,0 +1,3 @@
#include "whisper-mel.hpp"

whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters);
examples/qwen2-audio/whisper-mel.hpp (new file)
@@ -0,0 +1,34 @@
#pragma once
#include "ggml-backend.h"
#include <vector>

struct whisper_mel {
    int n_len_org = 0;

    ggml_context * ctx = nullptr;
    ggml_tensor * tensor = nullptr;
    ggml_backend_buffer_t buffer = nullptr;
};

void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);

void whisper_mel_free(whisper_mel & mel);

struct whisper_filters {
    int32_t n_mel;
    int32_t n_fft;

    std::vector<float> data;
};

template <typename T>
struct whisper_span {
    T * data;
    int len;
};

struct whisper_mel_calc {
    virtual ~whisper_mel_calc();
    virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) = 0;
    static whisper_span<const float> hann_window();
};
examples/qwen2-audio/whisper.cpp (new file, 10034 lines; diff suppressed because it is too large)
examples/qwen2-audio/whisper.h (new file)
@@ -0,0 +1,686 @@
#ifndef WHISPER_H
#define WHISPER_H

#include "ggml.h"

#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#include <string>
#include <vector>

#ifdef __GNUC__
#    define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
#    define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
#    define WHISPER_DEPRECATED(func, hint) func
#endif

#ifdef WHISPER_SHARED
#    ifdef _WIN32
#        ifdef WHISPER_BUILD
#            define WHISPER_API __declspec(dllexport)
#        else
#            define WHISPER_API __declspec(dllimport)
#        endif
#    else
#        define WHISPER_API __attribute__ ((visibility ("default")))
#    endif
#else
#    define WHISPER_API
#endif

#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT       400
#define WHISPER_N_FFT_HALF  (WHISPER_N_FFT / 2 + 1)
#define WHISPER_HOP_LENGTH  160
#define WHISPER_CHUNK_SIZE  30
#define WHISPER_N_SAMPLES   (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE)

#define COMMON_SAMPLE_RATE 16000 // Common sample rate for audio processing (16kHz)

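// arithmetic implied by the constants above (spelled out for reference, not
// new behavior): a full 30 s chunk is WHISPER_N_SAMPLES = 16000 * 30 = 480000
// samples, and with a hop of 160 samples the STFT advances 100 frames per second
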
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//
|
||||
// C interface
|
||||
//
|
||||
// The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads
|
||||
// concurrently.
|
||||
//
|
||||
// Basic usage:
|
||||
//
|
||||
// #include "whisper.h"
|
||||
//
|
||||
// ...
|
||||
//
|
||||
// whisper_context_params cparams = whisper_context_default_params();
|
||||
//
|
||||
// struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams);
|
||||
//
|
||||
// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
// fprintf(stderr, "failed to process audio\n");
|
||||
// return 7;
|
||||
// }
|
||||
//
|
||||
// const int n_segments = whisper_full_n_segments(ctx);
|
||||
// for (int i = 0; i < n_segments; ++i) {
|
||||
// const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
// printf("%s", text);
|
||||
// }
|
||||
//
|
||||
// whisper_free(ctx);
|
||||
//
|
||||
// ...
|
||||
//
|
||||
// This is a demonstration of the most straightforward usage of the library.
|
||||
// "pcmf32" contains the RAW audio data in 32-bit floating point format.
|
||||
//
|
||||
// The interface also allows for more fine-grained control over the computation, but it requires a deeper
|
||||
// understanding of how the model works.
|
||||
//
|
||||
|
||||
struct whisper_context;
struct whisper_state;
struct whisper_full_params;

typedef int32_t whisper_pos;
typedef int32_t whisper_token;
typedef int32_t whisper_seq_id;

enum whisper_alignment_heads_preset {
    WHISPER_AHEADS_NONE,
    WHISPER_AHEADS_N_TOP_MOST,  // All heads from the N-top-most text-layers
    WHISPER_AHEADS_CUSTOM,
    WHISPER_AHEADS_TINY_EN,
    WHISPER_AHEADS_TINY,
    WHISPER_AHEADS_BASE_EN,
    WHISPER_AHEADS_BASE,
    WHISPER_AHEADS_SMALL_EN,
    WHISPER_AHEADS_SMALL,
    WHISPER_AHEADS_MEDIUM_EN,
    WHISPER_AHEADS_MEDIUM,
    WHISPER_AHEADS_LARGE_V1,
    WHISPER_AHEADS_LARGE_V2,
    WHISPER_AHEADS_LARGE_V3,
};

typedef struct whisper_ahead {
    int n_text_layer;
    int n_head;
} whisper_ahead;

typedef struct whisper_aheads {
    size_t n_heads;
    const whisper_ahead * heads;
} whisper_aheads;

struct whisper_context_params {
    bool use_gpu;
    bool flash_attn;
    int  gpu_device; // CUDA device

    // [EXPERIMENTAL] Token-level timestamps with DTW
    bool dtw_token_timestamps;
    enum whisper_alignment_heads_preset dtw_aheads_preset;

    int dtw_n_top;
    struct whisper_aheads dtw_aheads;

    size_t dtw_mem_size; // TODO: remove
};

typedef struct whisper_token_data {
    whisper_token id;  // token id
    whisper_token tid; // forced timestamp token id

    float p;     // probability of the token
    float plog;  // log probability of the token
    float pt;    // probability of the timestamp token
    float ptsum; // sum of probabilities of all timestamp tokens

    // token-level timestamp data
    // do not use if you haven't computed token-level timestamps
    int64_t t0; // start time of the token
    int64_t t1; // end time of the token

    // [EXPERIMENTAL] Token-level timestamps with DTW
    // do not use if you haven't computed token-level timestamps with dtw
    // Roughly corresponds to the moment in audio in which the token was output
    int64_t t_dtw;

    float vlen; // voice length of the token
} whisper_token_data;

typedef struct whisper_model_loader {
    void * context;

    size_t (*read)(void * ctx, void * output, size_t read_size);
    void   (*seek)(void * ctx, size_t offset);
    bool   (*eof)(void * ctx);
    void   (*close)(void * ctx);
} whisper_model_loader;

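// A minimal sketch of a whisper_model_loader backed by stdio. The helper
// names and the model path are assumptions for illustration; each callback
// receives the loader's `context` pointer, here a FILE * (not compiled).
#if 0
#include <stdio.h>

static size_t loader_read(void * ctx, void * output, size_t read_size) {
    return fread(output, 1, read_size, (FILE *) ctx);
}
static void loader_seek(void * ctx, size_t offset) {
    fseek((FILE *) ctx, (long) offset, SEEK_SET);
}
static bool loader_eof(void * ctx) {
    return feof((FILE *) ctx) != 0;
}
static void loader_close(void * ctx) {
    fclose((FILE *) ctx);
}

FILE * f = fopen("/path/to/model.bin", "rb");
whisper_model_loader loader = { f, loader_read, loader_seek, loader_eof, loader_close };
struct whisper_context * wctx = whisper_init_with_params(&loader, whisper_context_default_params());
#endif
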
// grammar element type
enum whisper_gretype {
    // end of rule definition
    WHISPER_GRETYPE_END            = 0,

    // start of alternate definition for rule
    WHISPER_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule
    WHISPER_GRETYPE_RULE_REF       = 2,

    // terminal element: character (code point)
    WHISPER_GRETYPE_CHAR           = 3,

    // inverse char(s) ([^a], [^a-b] [^abc])
    WHISPER_GRETYPE_CHAR_NOT       = 4,

    // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding WHISPER_GRETYPE_CHAR or
    // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    WHISPER_GRETYPE_CHAR_ALT       = 6,
};

typedef struct whisper_grammar_element {
    enum whisper_gretype type;
    uint32_t             value; // Unicode code point or rule ID
} whisper_grammar_element;

// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file_with_params  (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params            (struct whisper_model_loader * loader, struct whisper_context_params params);

// These are the same as the above, but the internal state of the context is not allocated automatically
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state  (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params_no_state            (struct whisper_model_loader * loader, struct whisper_context_params params);

WHISPER_DEPRECATED(
    WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
    "use whisper_init_from_file_with_params instead"
);
WHISPER_DEPRECATED(
    WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size),
    "use whisper_init_from_buffer_with_params instead"
);
WHISPER_DEPRECATED(
    WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader),
    "use whisper_init_with_params instead"
);
WHISPER_DEPRECATED(
    WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model),
    "use whisper_init_from_file_with_params_no_state instead"
);
WHISPER_DEPRECATED(
    WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size),
    "use whisper_init_from_buffer_with_params_no_state instead"
);
WHISPER_DEPRECATED(
    WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader),
    "use whisper_init_with_params_no_state instead"
);

WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);

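// Sketch: the "no_state" variants let several states share one set of model
// weights, e.g. one state per worker thread. The path is a placeholder.
#if 0
struct whisper_context * wctx = whisper_init_from_file_with_params_no_state(
        "/path/to/ggml-base.en.bin", whisper_context_default_params());

struct whisper_state * st = whisper_init_state(wctx); // one per thread/stream
// ... use the *_with_state() functions with (wctx, st) ...
whisper_free_state(st);
whisper_free(wctx);
#endif
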
// Given a context, enable use of OpenVINO for encode inference.
// model_path: Optional path to OpenVINO encoder IR model. If set to nullptr,
//             the path will be generated from the ggml model path that was passed
//             in to whisper_init_from_file. For example, if 'path_model' was
//             "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be
//             assumed to be "/path/to/ggml-base.en-encoder-openvino.xml".
// device: OpenVINO device to run inference on ("CPU", "GPU", etc.)
// cache_dir: Optional cache directory that can speed up init time, especially for
//            GPU, by caching compiled 'blobs' there.
//            Set to nullptr if not used.
// Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1.
WHISPER_API int whisper_ctx_init_openvino_encoder(
        struct whisper_context * ctx,
        const char * model_path,
        const char * device,
        const char * cache_dir);

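// Usage sketch: derive the IR path and skip the cache by passing nullptr,
// as documented above (illustrative fragment, not compiled).
#if 0
if (whisper_ctx_init_openvino_encoder(wctx, nullptr, "CPU", nullptr) != 0) {
    // OpenVINO not enabled in this build, or initialization failed
}
#endif
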
// Frees all allocated memory
WHISPER_API void whisper_free               (struct whisper_context * ctx);
WHISPER_API void whisper_free_state         (struct whisper_state * state);
WHISPER_API void whisper_free_params        (struct whisper_full_params * params);
WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);

// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel(
        struct whisper_context * ctx,
        const float * samples,
        int n_samples,
        int n_threads);

WHISPER_API int whisper_pcm_to_mel_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        const float * samples,
        int n_samples,
        int n_threads);

// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
WHISPER_API int whisper_set_mel(
        struct whisper_context * ctx,
        const float * data,
        int n_len,
        int n_mel);

WHISPER_API int whisper_set_mel_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        const float * data,
        int n_len,
        int n_mel);

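// Sketch of supplying a precomputed log mel spectrogram instead of PCM.
// The buffer and frame count are placeholders; per the note above, n_mel
// must be 80 (illustrative fragment, not compiled).
#if 0
const int n_mel = 80;
const int n_len = 3000; // number of mel frames, placeholder
std::vector<float> mel((size_t) n_len * n_mel, 0.0f); // fill with your own log mel data

if (whisper_set_mel(wctx, mel.data(), n_len, n_mel) != 0) {
    // invalid mel dimensions
}
#endif
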
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
WHISPER_API int whisper_encode(
        struct whisper_context * ctx,
        int offset,
        int n_threads);

WHISPER_API int whisper_encode_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        int offset,
        int n_threads);

// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
        struct whisper_context * ctx,
        const whisper_token * tokens,
        int n_tokens,
        int n_past,
        int n_threads);

WHISPER_API int whisper_decode_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        const whisper_token * tokens,
        int n_tokens,
        int n_past,
        int n_threads);

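// A hedged sketch of driving the encoder/decoder manually instead of using
// whisper_full(). Token handling is heavily simplified for illustration;
// a real decoding loop samples from the logits and feeds tokens back in.
#if 0
whisper_pcm_to_mel(wctx, pcmf32.data(), (int) pcmf32.size(), 4);
whisper_encode(wctx, 0, 4);

whisper_token tokens[1] = { whisper_token_sot(wctx) };
if (whisper_decode(wctx, tokens, 1, 0, 4) == 0) {
    const float * logits = whisper_get_logits(wctx); // scores for the next token
}
#endif
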
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
        struct whisper_context * ctx,
        const char * text,
        whisper_token * tokens,
        int n_max_tokens);

// Return the number of tokens in the provided text
// Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
WHISPER_API int whisper_token_count(struct whisper_context * ctx, const char * text);

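// Sketch of tokenizing a prompt, sizing the buffer with whisper_token_count()
// (illustrative fragment, not compiled).
#if 0
const char * prompt = "this conversation talks about";
std::vector<whisper_token> toks(whisper_token_count(wctx, prompt));

const int n = whisper_tokenize(wctx, prompt, toks.data(), (int) toks.size());
if (n < 0) {
    // failure: -n tokens would have been returned
}
#endif
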
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id(void);

// Return the id of the specified language, returns -1 if not found
// Examples:
//   "de" -> 2
//   "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);

// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);

// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str_full(int id);

// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whisper_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
        struct whisper_context * ctx,
        int offset_ms,
        int n_threads,
        float * lang_probs);

WHISPER_API int whisper_lang_auto_detect_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        int offset_ms,
        int n_threads,
        float * lang_probs);

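// Sketch of auto-detecting the spoken language once the mel is computed;
// the probability array is sized whisper_lang_max_id() + 1, as required
// (illustrative fragment, not compiled).
#if 0
std::vector<float> lang_probs(whisper_lang_max_id() + 1);
const int lang_id = whisper_lang_auto_detect(wctx, 0, 4, lang_probs.data());
if (lang_id >= 0) {
    printf("detected language: %s (p = %.3f)\n", whisper_lang_str(lang_id), lang_probs[lang_id]);
}
#endif
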
WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
WHISPER_API int whisper_n_vocab         (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx      (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx     (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);

WHISPER_API int whisper_model_n_vocab      (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_ctx  (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_ctx   (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_head  (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_mels       (struct whisper_context * ctx);
WHISPER_API int whisper_model_ftype        (struct whisper_context * ctx);
WHISPER_API int whisper_model_type         (struct whisper_context * ctx);

// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits           (struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);

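// Sketch of reading the logits of the most recent token after a
// whisper_decode() call; n_tokens here refers to the count passed to that
// call (an assumption of this fragment, not compiled).
#if 0
const int     n_vocab = whisper_n_vocab(wctx);
const float * logits  = whisper_get_logits(wctx);
const float * last    = logits + (size_t) (n_tokens - 1) * n_vocab; // last row = newest token
#endif
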
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);

// Special tokens
WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);

// Task tokens
WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);

// Performance information from the default state.
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);

// Print system information
WHISPER_API const char * whisper_print_system_info(void);

////////////////////////////////////////////////////////////////////////////

// Available sampling strategies
enum whisper_sampling_strategy {
    WHISPER_SAMPLING_GREEDY,      // similar to OpenAI's GreedyDecoder
    WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
};

// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);

// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);

// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);

// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
        struct whisper_context * ctx,
        struct whisper_state * state,
        const whisper_token_data * tokens,
        int n_tokens,
        float * logits,
        void * user_data);

// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
struct whisper_full_params {
    enum whisper_sampling_strategy strategy;

    int n_threads;
    int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
    int offset_ms;      // start offset in ms
    int duration_ms;    // audio duration to process in ms

    bool translate;
    bool no_context;       // do not use past transcription (if any) as initial prompt for the decoder
    bool no_timestamps;    // do not generate timestamps
    bool single_segment;   // force single segment output (useful for streaming)
    bool print_special;    // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
    bool print_progress;   // print progress information
    bool print_realtime;   // print results from within whisper.cpp (avoid it, use callback instead)
    bool print_timestamps; // print timestamps for each text segment when printing realtime

    // [EXPERIMENTAL] token-level timestamps
    bool  token_timestamps; // enable token-level timestamps
    float thold_pt;         // timestamp token probability threshold (~0.01)
    float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
    int   max_len;          // max segment length in characters
    bool  split_on_word;    // split on word rather than on token (when used with max_len)
    int   max_tokens;       // max tokens per segment (0 = no limit)

    // [EXPERIMENTAL] speed-up techniques
    // note: these can significantly reduce the quality of the output
    bool debug_mode; // enable debug mode to get extra info (e.g. dump the log mel spectrogram)
    int  audio_ctx;  // overwrite the audio context size (0 = use default)

    // [EXPERIMENTAL] [TDRZ] tinydiarize
    bool tdrz_enable; // enable tinydiarize speaker turn detection

    // A regular expression that matches tokens to suppress
    const char * suppress_regex;

    // tokens to provide to the whisper decoder as initial prompt
    // these are prepended to any existing text context from a previous call
    // use whisper_tokenize() to convert text to tokens
    // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
    const char * initial_prompt;
    const whisper_token * prompt_tokens;
    int prompt_n_tokens;

    // for auto-detection, set to nullptr, "" or "auto"
    const char * language;
    bool detect_language;

    // common decoding parameters:
    bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
    bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253

    float temperature;    // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
    float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
    float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267

    // fallback parameters
    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
    float temperature_inc;
    float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
    float logprob_thold;
    float no_speech_thold; // TODO: not implemented

    struct {
        int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
    } greedy;

    struct {
        int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265

        float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
    } beam_search;

    // called for every newly generated text segment
    whisper_new_segment_callback new_segment_callback;
    void * new_segment_callback_user_data;

    // called on each progress update
    whisper_progress_callback progress_callback;
    void * progress_callback_user_data;

    // called each time before the encoder starts
    whisper_encoder_begin_callback encoder_begin_callback;
    void * encoder_begin_callback_user_data;

    // called each time before ggml computation starts
    ggml_abort_callback abort_callback;
    void * abort_callback_user_data;

    // called by each decoder to filter obtained logits
    whisper_logits_filter_callback logits_filter_callback;
    void * logits_filter_callback_user_data;

    const whisper_grammar_element ** grammar_rules;
    size_t                           n_grammar_rules;
    size_t                           i_start_rule;
    float                            grammar_penalty;
};

// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void);
WHISPER_API struct whisper_context_params   whisper_context_default_params       (void);
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
WHISPER_API struct whisper_full_params   whisper_full_default_params       (enum whisper_sampling_strategy strategy);

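// Sketch of customizing the defaults before calling whisper_full(); the
// specific values are illustrative, not recommendations (not compiled).
#if 0
struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
wparams.beam_search.beam_size = 5;
wparams.n_threads             = 8;
wparams.language              = "auto"; // enables language auto-detection
wparams.token_timestamps      = true;
wparams.max_len               = 60;     // max segment length in characters
#endif
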
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Not thread safe for same context
// Uses the specified decoding strategy to obtain the text.
WHISPER_API int whisper_full(
        struct whisper_context * ctx,
        struct whisper_full_params params,
        const float * samples,
        int n_samples);

WHISPER_API int whisper_full_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        struct whisper_full_params params,
        const float * samples,
        int n_samples);

// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
WHISPER_API int whisper_full_parallel(
        struct whisper_context * ctx,
        struct whisper_full_params params,
        const float * samples,
        int n_samples,
        int n_processors);

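// Sketch of the chunked variant; per the note above, accuracy can degrade
// near chunk boundaries, so reserve it for long inputs (not compiled).
#if 0
if (whisper_full_parallel(wctx, wparams, pcmf32.data(), (int) pcmf32.size(), 2) != 0) {
    fprintf(stderr, "failed to process audio\n");
}
#endif
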
// Number of generated text segments
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments           (struct whisper_context * ctx);
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);

// Language id associated with the context's default state
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);

// Language id associated with the provided state
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);

// Get the embedding tensor
WHISPER_API ggml_tensor * whisper_full_get_embd_conv(struct whisper_context * ctx);
WHISPER_API ggml_tensor * whisper_full_get_embd_enc (struct whisper_context * ctx);

// Get the start and end time of the specified segment
WHISPER_API int64_t whisper_full_get_segment_t0           (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);

WHISPER_API int64_t whisper_full_get_segment_t1           (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);

// Get whether the next segment is predicted as a speaker turn
WHISPER_API bool whisper_full_get_segment_speaker_turn_next           (struct whisper_context * ctx, int i_segment);
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);

// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text           (struct whisper_context * ctx, int i_segment);
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);

// Get number of tokens in the specified segment
WHISPER_API int whisper_full_n_tokens           (struct whisper_context * ctx, int i_segment);
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);

// Get the token text of the specified token in the specified segment
WHISPER_API const char * whisper_full_get_token_text           (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);

WHISPER_API whisper_token whisper_full_get_token_id           (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);

// Get token data for the specified token in the specified segment
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data           (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);

// Get the probability of the specified token in the specified segment
WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);

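// Sketch of walking segments and per-token data; the timestamps are only
// meaningful if token_timestamps was enabled in the parameters (not compiled).
#if 0
for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
    for (int j = 0; j < whisper_full_n_tokens(wctx, i); ++j) {
        const whisper_token_data td = whisper_full_get_token_data(wctx, i, j);
        printf("%s [p=%.2f, t0=%lld, t1=%lld]\n",
               whisper_full_get_token_text(wctx, i, j),
               td.p, (long long) td.t0, (long long) td.t1);
    }
}
#endif
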
////////////////////////////////////////////////////////////////////////////

// Temporary helpers needed for exposing ggml interface

WHISPER_API int          whisper_bench_memcpy          (int n_threads);
WHISPER_API const char * whisper_bench_memcpy_str      (int n_threads);
WHISPER_API int          whisper_bench_ggml_mul_mat    (int n_threads);
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);

// Control logging output; default behavior is to print to stderr
WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);

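// Sketch of redirecting logging; assumes ggml's callback signature
// (level, text, user_data) and is illustrative only (not compiled).
#if 0
static void my_log(enum ggml_log_level level, const char * text, void * user_data) {
    (void) level; (void) user_data;
    fputs(text, stdout);
}

whisper_log_set(my_log, NULL);
#endif
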
/* Whisper Encode without cross-attention */

WHISPER_API struct whisper_context * whisper_encoder_init_from_file_with_params(const char * path_model, struct whisper_context_params params);

WHISPER_API struct whisper_state * whisper_encoder_init_state(struct whisper_context * ctx);

WHISPER_API int whisper_encode_wo_cross(
        struct whisper_context * ctx,
        int offset,
        int n_threads);

WHISPER_API int whisper_encode_wo_cross_parallel(
        struct whisper_context * ctx,
        struct whisper_full_params params,
        const float * samples,
        int n_samples,
        int n_processors);

WHISPER_API bool read_wav(const std::string & fname, std::vector<float> & pcmf32, std::vector<std::vector<float>> & pcmf32s, bool stereo);

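// Sketch tying the pieces together with the read_wav() helper declared
// above; the file path is a placeholder (not compiled).
#if 0
std::vector<float> pcmf32;               // mono samples
std::vector<std::vector<float>> pcmf32s; // per-channel samples when stereo
if (read_wav("samples/jfk.wav", pcmf32, pcmf32s, false)) {
    whisper_full(wctx, wparams, pcmf32.data(), (int) pcmf32.size());
}
#endif
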
#ifdef __cplusplus
}
#endif

#endif