fix tests and examples

This commit is contained in:
slaren 2024-11-11 23:44:27 +01:00
parent 4428593487
commit 8768c7c45a
7 changed files with 37 additions and 53 deletions

View file

@ -774,13 +774,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
struct test {
static const std::string build_commit;
static const int build_number;
static const bool cuda;
static const bool vulkan;
static const bool kompute;
static const bool metal;
static const bool sycl;
static const bool gpu_blas;
static const bool blas;
static const std::string cpu_info;
static const std::string gpu_info;
std::string model_filename;
@ -793,7 +786,6 @@ struct test {
std::string cpu_mask;
bool cpu_strict;
int poll;
bool has_rpc;
ggml_type type_k;
ggml_type type_v;
int n_gpu_layers;
@ -822,7 +814,6 @@ struct test {
cpu_mask = inst.cpu_mask;
cpu_strict = inst.cpu_strict;
poll = inst.poll;
has_rpc = !inst.rpc_servers.empty();
type_k = inst.type_k;
type_v = inst.type_v;
n_gpu_layers = inst.n_gpu_layers;
@ -881,7 +872,6 @@ struct test {
static const std::vector<std::string> & get_fields() {
static const std::vector<std::string> fields = {
"build_commit", "build_number",
"cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_ubatch",
@ -908,8 +898,7 @@ struct test {
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
if (field == "f16_kv" || field == "no_kv_offload" ||
field == "cpu_strict" ||
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
return BOOL;
@ -938,8 +927,6 @@ struct test {
}
std::vector<std::string> values = {
build_commit, std::to_string(build_number),
std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan),
std::to_string(metal), std::to_string(sycl), std::to_string(has_rpc), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_ubatch),
@ -967,13 +954,6 @@ struct test {
const std::string test::build_commit = LLAMA_COMMIT;
const int test::build_number = LLAMA_BUILD_NUMBER;
const bool test::cuda = !!ggml_cpu_has_cuda();
const bool test::vulkan = !!ggml_cpu_has_vulkan();
const bool test::kompute = !!ggml_cpu_has_kompute();
const bool test::metal = !!ggml_cpu_has_metal();
const bool test::gpu_blas = !!ggml_cpu_has_gpublas();
const bool test::blas = !!ggml_cpu_has_blas();
const bool test::sycl = !!ggml_cpu_has_sycl();
const std::string test::cpu_info = get_cpu_info();
const std::string test::gpu_info = get_gpu_info();
@ -1268,9 +1248,6 @@ struct markdown_printer : public printer {
value = buf;
} else if (field == "backend") {
value = test::get_backend();
if (t.has_rpc) {
value += "+RPC";
}
} else if (field == "test") {
if (t.n_prompt > 0 && t.n_gen == 0) {
snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);

View file

@ -142,7 +142,7 @@ static bool tensor_is_contiguous(const struct ggml_tensor * tensor) {
}
static void test_roundtrip_on_chunk(
const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits & qfns, bool use_reference,
const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference,
float * input_scratch, char * quantized_scratch, float * output_scratch, error_stats & stats
) {
if (layer->type == GGML_TYPE_F16) {
@ -156,7 +156,7 @@ static void test_roundtrip_on_chunk(
if (use_reference) {
qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size);
} else {
qfns.from_float(input_scratch, quantized_scratch, chunk_size);
qfns_cpu.from_float(input_scratch, quantized_scratch, chunk_size);
}
qfns.to_float(quantized_scratch, output_scratch, chunk_size);
@ -166,7 +166,7 @@ static void test_roundtrip_on_chunk(
// Run quantization function for a single layer and update error stats
static void test_roundtrip_on_layer(
std::string & name, bool print_layer_stats, const ggml_type_traits & qfns, bool use_reference,
std::string & name, bool print_layer_stats, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference,
const ggml_tensor * layer, std::vector<float> & input_scratch, std::vector<char> & quantized_scratch,
std::vector<float> & output_scratch, error_stats & total_error, int max_thread = 0
) {
@ -187,13 +187,13 @@ static void test_roundtrip_on_layer(
int num_chunks = (nelements + chunk_size - 1)/chunk_size;
if (num_chunks < 2 || max_thread < 2) {
test_roundtrip_on_chunk(layer, 0, nelements, qfns, use_reference, input_scratch_ptr, quantized_scratch.data(),
test_roundtrip_on_chunk(layer, 0, nelements, qfns, qfns_cpu, use_reference, input_scratch_ptr, quantized_scratch.data(),
output_scratch.data(), print_layer_stats ? layer_error : total_error);
} else {
auto & stats = print_layer_stats ? layer_error : total_error;
std::mutex mutex;
uint64_t counter = 0;
auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr,
auto compute = [&mutex, &counter, &stats, &qfns, &qfns_cpu, nelements, layer, use_reference, input_scratch_ptr,
&quantized_scratch, &output_scratch, chunk_size] () {
error_stats local_stats {};
while (true) {
@ -205,7 +205,7 @@ static void test_roundtrip_on_layer(
}
lock.unlock();
uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset;
test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset,
test_roundtrip_on_chunk(layer, offset, chunk, qfns, qfns_cpu, use_reference, input_scratch_ptr + offset,
quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats);
}
};
@ -371,8 +371,9 @@ int main(int argc, char ** argv) {
if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) {
continue;
}
const auto * qfns = ggml_get_type_traits(type);
if (qfns->from_float && qfns->to_float) {
const auto * qfns = ggml_get_type_traits(type);
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
if (qfns_cpu->from_float && qfns->to_float) {
if (params.verbose) {
printf("testing %s ...\n", ggml_type_name(type));
}
@ -393,7 +394,7 @@ int main(int argc, char ** argv) {
test_roundtrip_on_layer(
layer_name,
params.per_layer_stats,
*qfns,
*qfns, *qfns_cpu,
params.reference,
kv_tensor.second,
input_scratch,

View file

@ -19,6 +19,8 @@ if (BLAS_FOUND)
target_include_directories(ggml-blas PRIVATE . ..)
if (${GGML_BLAS_VENDOR} MATCHES "Apple")
add_compile_definitions(ACCELERATE_NEW_LAPACK)
add_compile_definitions(ACCELERATE_LAPACK_ILP64)
add_compile_definitions(GGML_BLAS_USE_ACCELERATE)
elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
# BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.

View file

@ -4931,6 +4931,8 @@ void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y,
quantize_iq2_s(x, y, 1, k, NULL);
}
// =============================== data validation
static bool validate_float(float f, size_t i) {
if (isinf(f)) {
fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);

View file

@ -263,9 +263,9 @@ int main(int argc, char** argv) {
// Note, we do not include this in the timing as in practical application
// we already have the quantized model weights.
if (useQ4_1) {
funcs->from_float(x1.data(), q41.data(), kVecSize);
funcs_cpu->from_float(x1.data(), q41.data(), kVecSize);
} else {
funcs->from_float(x1.data(), q40.data(), kVecSize);
funcs_cpu->from_float(x1.data(), q40.data(), kVecSize);
}
// Now measure time the dot product needs using the "scalar" version above
@ -284,7 +284,7 @@ int main(int argc, char** argv) {
dot_q4_q8(kVecSize, &result, q40.data(), q8.data());
}
else {
const auto * vdot = ggml_get_type_traits(funcs_cpu->vec_dot_type);
const auto * vdot = ggml_get_type_traits_cpu(funcs_cpu->vec_dot_type);
vdot->from_float(y1.data(), q8.data(), kVecSize);
if (useQ4_1) funcs_cpu->vec_dot(kVecSize, &result, 0, q41.data(), 0, q8.data(), 0, 1);
else funcs_cpu->vec_dot(kVecSize, &result, 0, q40.data(), 0, q8.data(), 0, 1);

View file

@ -45,22 +45,23 @@ static float array_rmse(const float * a1, const float * a2, size_t n) {
}
// Total quantization error on test data
static float total_quantization_error(const ggml_type_traits * qfns, size_t test_size, const float * test_data) {
static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
std::vector<uint8_t> tmp_q(2*test_size);
std::vector<float> tmp_out(test_size);
qfns->from_float(test_data, tmp_q.data(), test_size);
qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
return array_rmse(test_data, tmp_out.data(), test_size);
}
// Total quantization error on test data
static float reference_quantization_error(const ggml_type_traits * qfns, size_t test_size, const float * test_data) {
static float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
std::vector<uint8_t> tmp_q(2*test_size);
std::vector<float> tmp_out(test_size);
std::vector<float> tmp_out_ref(test_size);
qfns->from_float(test_data, tmp_q.data(), test_size);
// FIXME: why is done twice?
qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
qfns->from_float_ref(test_data, tmp_q.data(), test_size);
@ -84,9 +85,9 @@ static float dot_product_error(
std::vector<uint8_t> tmp_q1(2*test_size);
std::vector<uint8_t> tmp_q2(2*test_size);
const auto * vdot = ggml_get_type_traits(qfns_cpu->vec_dot_type);
const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
qfns->from_float(test_data1, tmp_q1.data(), test_size);
qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);
vdot->from_float(test_data2, tmp_q2.data(), test_size);
float result = INFINITY;
@ -145,8 +146,8 @@ int main(int argc, char * argv[]) {
printf("Testing %s\n", ggml_type_name((ggml_type) i));
ggml_quantize_init(ei);
if (qfns->from_float && qfns->to_float) {
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
if (qfns_cpu->from_float && qfns->to_float) {
const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
const float max_quantization_error =
type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
@ -161,7 +162,7 @@ int main(int argc, char * argv[]) {
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
}
const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
num_failed += failed;
if (failed || verbose) {

View file

@ -123,9 +123,10 @@ static void usage(char * argv[]) {
printf(" --type TYPE set test type as");
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
ggml_type type = (ggml_type) i;
const auto * qfns = ggml_get_type_traits(type);
const auto * qfns = ggml_get_type_traits(type);
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
if (ggml_type_name(type) != NULL) {
if (qfns->from_float && qfns->to_float) {
if (qfns_cpu->from_float && qfns->to_float) {
printf(" %s", ggml_type_name(type));
}
}
@ -277,7 +278,7 @@ int main(int argc, char * argv[]) {
continue;
}
if (qfns->from_float && qfns->to_float) {
if (qfns_cpu->from_float && qfns->to_float) {
printf("%s\n", ggml_type_name(type));
ggml_quantize_init(type);
@ -301,7 +302,7 @@ int main(int argc, char * argv[]) {
for (size_t size : params.test_sizes) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float {
qfns->from_float(test_data1, test_q1, size);
qfns_cpu->from_float(test_data1, test_q1, size);
return test_q1[0];
};
size_t quantized_size = ggml_row_size(type, size);
@ -312,7 +313,7 @@ int main(int argc, char * argv[]) {
if (params.op_dequantize_row_q) {
printf(" dequantize_row_q\n");
qfns->from_float(test_data1, test_q1, largest);
qfns_cpu->from_float(test_data1, test_q1, largest);
for (size_t size : params.test_sizes) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float {
@ -330,7 +331,7 @@ int main(int argc, char * argv[]) {
for (size_t size : params.test_sizes) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float {
const auto * vdot = ggml_get_type_traits(qfns_cpu->vec_dot_type);
const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
vdot->from_float(test_data1, test_q1, size);
return test_q1[0];
};
@ -342,8 +343,8 @@ int main(int argc, char * argv[]) {
if (params.op_vec_dot_q) {
printf(" vec_dot_q\n");
qfns->from_float(test_data1, test_q1, largest);
qfns->from_float(test_data2, test_q2, largest);
qfns_cpu->from_float(test_data1, test_q1, largest);
qfns_cpu->from_float(test_data2, test_q2, largest);
for (size_t size : params.test_sizes) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float {