Merge remote-tracking branch 'origin/master' into jinja
commit 40db78963b
66 changed files with 2877 additions and 1247 deletions
@@ -41,7 +41,7 @@ echo PASS
 echo
 
 # 2b. Test the sharded model is loading properly
-$MAIN --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32
 echo PASS
 echo
 
@@ -51,7 +51,7 @@ echo PASS
 echo
 
 # 3b. Test the merged model is loading properly
-$MAIN --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32
 echo PASS
 echo
 
@@ -61,7 +61,7 @@ echo PASS
 echo
 
 # 4b. Test the sharded model is loading properly
-$MAIN --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32
 echo PASS
 echo
 
@@ -71,7 +71,7 @@ echo
 #echo
 
 # 5b. Test the merged model is loading properly
-#$MAIN --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32
+#$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32
 #echo PASS
 #echo
 
@@ -81,7 +81,7 @@ echo PASS
 echo
 
 # 6b. Test the sharded model is loading properly
-$MAIN --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32
 echo PASS
 echo
 
@@ -683,7 +683,7 @@ struct cmd_params_instance {
     bool cpu_strict;
     int poll;
     int n_gpu_layers;
-    std::string rpc_servers;
+    std::string rpc_servers_str;
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
@@ -696,8 +696,37 @@ struct cmd_params_instance {
         llama_model_params mparams = llama_model_default_params();
 
         mparams.n_gpu_layers = n_gpu_layers;
-        if (!rpc_servers.empty()) {
-            mparams.rpc_servers = rpc_servers.c_str();
+        if (!rpc_servers_str.empty()) {
+            auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
+
+            // add RPC devices
+            if (!rpc_servers.empty()) {
+                ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+                if (!rpc_reg) {
+                    fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
+                    exit(1);
+                }
+
+                typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+                ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+                if (!ggml_backend_rpc_add_device_fn) {
+                    fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
+                    exit(1);
+                }
+                static std::vector<ggml_backend_dev_t> devices;
+                devices.clear();
+                for (const std::string & server : rpc_servers) {
+                    ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+                    if (dev) {
+                        devices.push_back(dev);
+                    } else {
+                        fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                        exit(1);
+                    }
+                }
+                devices.push_back(nullptr);
+                mparams.devices = devices.data();
+            }
         }
         mparams.split_mode = split_mode;
         mparams.main_gpu = main_gpu;
@@ -708,7 +737,7 @@ struct cmd_params_instance {
     }
 
     bool equal_mparams(const cmd_params_instance & other) const {
-        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers &&
+        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
                split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
                tensor_split == other.tensor_split;
     }
 
@@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
         jlong context_pointer,
         jlong batch_pointer,
         jstring jtext,
+        jboolean format_chat,
         jint n_len
     ) {
 
@@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
 
-    const auto tokens_list = common_tokenize(context, text, 1);
+    bool parse_special = (format_chat == JNI_TRUE);
+    const auto tokens_list = common_tokenize(context, text, true, parse_special);
 
     auto n_ctx = llama_n_ctx(context);
     auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
 
     for (auto id : tokens_list) {
-        LOGi("%s", common_token_to_piece(context, id).c_str());
+        LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
    }
 
     common_batch_clear(*batch);
 
@@ -65,6 +65,7 @@ class LLamaAndroid {
        context: Long,
        batch: Long,
        text: String,
+       formatChat: Boolean,
        nLen: Int
    ): Int
 
@@ -115,10 +116,10 @@ class LLamaAndroid {
            }
        }
 
-    fun send(message: String): Flow<String> = flow {
+    fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
        when (val state = threadLocalState.get()) {
            is State.Loaded -> {
-                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+                val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
                while (ncur.value <= nlen) {
                    val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
                    if (str == null) {
 
@@ -47,7 +47,7 @@ echo PASS
 echo
 
 # 3a. Test the requanted model is loading properly
-$MAIN --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32
 echo PASS
 echo
 
@@ -57,7 +57,7 @@ echo PASS
 echo
 
 # 4b. Test the requanted model is loading properly
-$MAIN --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32
+$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32
 echo PASS
 echo
 
@@ -78,3 +78,40 @@ play the audio:
 $ aplay output.wav
 ```
 
+### Running the example with llama-server
+Running this example with `llama-server` is also possible and requires two
+server instances to be started. One will serve the LLM model and the other
+will serve the voice decoder model.
+
+The LLM model server can be started with the following command:
+```console
+$ ./build/bin/llama-server -m ./models/outetts-0.2-0.5B-q8_0.gguf --port 8020
+```
+
+And the voice decoder model server can be started using:
+```console
+./build/bin/llama-server -m ./models/wavtokenizer-large-75-f16.gguf --port 8021 --embeddings --pooling none
+```
+
+Then we can run [tts-outetts.py](tts-outetts.py) to generate the audio.
+
+First create a virtual environment for Python and install the required
+dependencies (this is only required to be done once):
+```console
+$ python3 -m venv venv
+$ source venv/bin/activate
+(venv) pip install requests numpy
+```
+
+And then run the Python script using:
+```console
+(venv) python ./examples/tts/tts-outetts.py http://localhost:8020 http://localhost:8021 "Hello world"
+spectrogram generated: n_codes: 90, n_embd: 1282
+converting to audio ...
+audio generated: 28800 samples
+audio written to file "output.wav"
+```
+And to play the audio we can again use aplay or any other media player:
+```console
+$ aplay output.wav
+```
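The README hunk above wires two `llama-server` instances together: the LLM on port 8020 produces the audio-code tokens and the WavTokenizer decoder on port 8021 turns them into spectrogram embeddings. As a rough, non-authoritative sketch of that client flow (the endpoint paths `/completion` and `/embeddings`, the payload fields, and the helper names `generate_codes`/`decode_codes` are assumptions for illustration; `tts-outetts.py` is the reference implementation):

```python
# Hedged sketch of the two-server flow from the README above.
# Endpoint paths and payload fields are assumptions, not a spec;
# see tts-outetts.py for the real client logic.
import requests

LLM_URL = "http://localhost:8020"  # outetts LLM server
DEC_URL = "http://localhost:8021"  # wavtokenizer decoder server (--embeddings --pooling none)


def generate_codes(prompt: str) -> str:
    # the LLM continuation is expected to contain the audio-code tokens
    r = requests.post(f"{LLM_URL}/completion", json={"prompt": prompt, "n_predict": 1024})
    r.raise_for_status()
    return r.json()["content"]


def decode_codes(codes: str) -> list:
    # with --embeddings --pooling none the server returns one vector per token,
    # i.e. the spectrogram frames that embd_to_audio() consumes
    r = requests.post(f"{DEC_URL}/embeddings", json={"content": codes})
    r.raise_for_status()
    return r.json()
```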
@@ -3,6 +3,121 @@ import sys
 #import struct
 import requests
 import re
+import struct
+import numpy as np
+from concurrent.futures import ThreadPoolExecutor
+
+
+def fill_hann_window(size, periodic=True):
+    if periodic:
+        return np.hanning(size + 1)[:-1]
+    return np.hanning(size)
+
+
+def irfft(n_fft, complex_input):
+    return np.fft.irfft(complex_input, n=n_fft)
+
+
+def fold(buffer, n_out, n_win, n_hop, n_pad):
+    result = np.zeros(n_out)
+    n_frames = len(buffer) // n_win
+
+    for i in range(n_frames):
+        start = i * n_hop
+        end = start + n_win
+        result[start:end] += buffer[i * n_win:(i + 1) * n_win]
+
+    return result[n_pad:-n_pad] if n_pad > 0 else result
+
+
+def process_frame(args):
+    l, n_fft, ST, hann = args
+    frame = irfft(n_fft, ST[l])
+    frame = frame * hann
+    hann2 = hann * hann
+    return frame, hann2
+
+
+def embd_to_audio(embd, n_codes, n_embd, n_thread=4):
+    embd = np.asarray(embd, dtype=np.float32).reshape(n_codes, n_embd)
+
+    n_fft = 1280
+    n_hop = 320
+    n_win = 1280
+    n_pad = (n_win - n_hop) // 2
+    n_out = (n_codes - 1) * n_hop + n_win
+
+    hann = fill_hann_window(n_fft, True)
+
+    E = np.zeros((n_embd, n_codes), dtype=np.float32)
+    for l in range(n_codes):
+        for k in range(n_embd):
+            E[k, l] = embd[l, k]
+
+    half_embd = n_embd // 2
+    S = np.zeros((n_codes, half_embd + 1), dtype=np.complex64)
+
+    for k in range(half_embd):
+        for l in range(n_codes):
+            mag = E[k, l]
+            phi = E[k + half_embd, l]
+
+            mag = np.clip(np.exp(mag), 0, 1e2)
+            S[l, k] = mag * np.exp(1j * phi)
+
+    res = np.zeros(n_codes * n_fft)
+    hann2_buffer = np.zeros(n_codes * n_fft)
+
+    with ThreadPoolExecutor(max_workers=n_thread) as executor:
+        args = [(l, n_fft, S, hann) for l in range(n_codes)]
+        results = list(executor.map(process_frame, args))
+
+        for l, (frame, hann2) in enumerate(results):
+            res[l*n_fft:(l+1)*n_fft] = frame
+            hann2_buffer[l*n_fft:(l+1)*n_fft] = hann2
+
+    audio = fold(res, n_out, n_win, n_hop, n_pad)
+    env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad)
+
+    mask = env > 1e-10
+    audio[mask] /= env[mask]
+
+    return audio
+
+
+def save_wav(filename, audio_data, sample_rate):
+    num_channels = 1
+    bits_per_sample = 16
+    bytes_per_sample = bits_per_sample // 8
+    data_size = len(audio_data) * bytes_per_sample
+    byte_rate = sample_rate * num_channels * bytes_per_sample
+    block_align = num_channels * bytes_per_sample
+    chunk_size = 36 + data_size  # 36 = size of header minus first 8 bytes
+
+    header = struct.pack(
+        '<4sI4s4sIHHIIHH4sI',
+        b'RIFF',
+        chunk_size,
+        b'WAVE',
+        b'fmt ',
+        16,  # fmt chunk size
+        1,   # audio format (PCM)
+        num_channels,
+        sample_rate,
+        byte_rate,
+        block_align,
+        bits_per_sample,
+        b'data',
+        data_size
+    )
+
+    audio_data = np.clip(audio_data * 32767, -32768, 32767)
+    pcm_data = audio_data.astype(np.int16)
+
+    with open(filename, 'wb') as f:
+        f.write(header)
+        f.write(pcm_data.tobytes())
+
 
 def process_text(text: str):
     text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
@@ -170,6 +285,15 @@ n_embd = len(embd[0])
 print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))
 
 # post-process the spectrogram to convert to audio
-# TODO: see the tts.cpp:embd_to_audio() and implement it in Python
 print('converting to audio ...')
-print('TODO: see the tts.cpp:embd_to_audio() and implement it in Python')
+audio = embd_to_audio(embd, n_codes, n_embd)
+print('audio generated: %d samples' % len(audio))
+
+filename = "output.wav"
+sample_rate = 24000 # sampling rate
+
+# zero out first 0.25 seconds
+audio[:24000 // 4] = 0.0
+
+save_wav(filename, audio, sample_rate)
+print('audio written to file "%s"' % filename)
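For readers skimming the new Python helpers, a minimal sketch of how `embd_to_audio()` and `save_wav()` fit together, assuming the functions above are pasted into (or imported by) the same file. It feeds random data in place of a real model spectrogram, so the output is noise; the point is only the plumbing and the shapes, which follow the script above (90 frames of 1282 values, 24 kHz output). The file name `noise.wav` is hypothetical.

```python
# Minimal usage sketch for embd_to_audio() / save_wav() defined above,
# fed with random data instead of a real model spectrogram.
import numpy as np

n_codes = 90    # number of frames (matches the README example output)
n_embd = 1282   # per frame: first half log-magnitudes, second half phases

rng = np.random.default_rng(0)
embd = rng.standard_normal((n_codes, n_embd)).astype(np.float32)

audio = embd_to_audio(embd, n_codes, n_embd)        # overlap-add ISTFT reconstruction
print('audio generated: %d samples' % len(audio))   # 90 * 320 = 28800 samples after padding is trimmed

save_wav("noise.wav", audio, 24000)                 # 16-bit mono PCM at 24 kHz
```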