Add an f16 case to ggml_add_cast_impl and llama_build_lora_finetune_graphs

Andrew Godfrey 2023-10-23 18:31:06 -07:00
parent 19097c97a8
commit 7cbf5b282c
2 changed files with 2 additions and 2 deletions


@@ -652,7 +652,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type)) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
             return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
         } else if (a->type == GGML_TYPE_F32) {
             return ggml_add(ctx, a, b);
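To make the effect of this hunk concrete, here is a minimal sketch of the add_to_f32 helper after the change; the final fallback branch is an assumption for illustration, not copied from the source file:

    // Sketch only: dispatch used inside llama_build_lora_finetune_graphs (requires ggml.h).
    auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
            // quantized or f16 base weight: add b and cast the result up to f32
            return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
        } else if (a->type == GGML_TYPE_F32) {
            // f32 base weight: a plain add already yields an f32 result
            return ggml_add(ctx, a, b);
        } else {
            // hypothetical fallback: other tensor types are not supported here
            GGML_ASSERT(false && "unsupported tensor type for finetuning");
            return (struct ggml_tensor *) NULL;
        }
    };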

ggml.c

@@ -5636,7 +5636,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
     // TODO: support less-strict constraint
     // GGML_ASSERT(ggml_can_repeat(b, a));
     GGML_ASSERT(ggml_can_repeat_rows(b, a));
-    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+    GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
     bool is_node = false;
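With the relaxed assertion, ggml_add_cast now accepts an f16 tensor for a, not just quantized types. A hedged usage sketch follows; the tensor shapes and context size are illustrative assumptions:

    // Illustrative only: add an f32 tensor b onto an f16 tensor a, producing an f32 result.
    #include "ggml.h"

    struct ggml_init_params params = { 32*1024*1024, NULL, false }; // assumed context size
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 1024, 1024); // e.g. a frozen f16 weight
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 1024); // e.g. an f32 LoRA delta

    // Before this commit the assertion rejected f16 `a`; now this builds an f32 add node.
    struct ggml_tensor * out = ggml_add_cast(ctx, a, b, GGML_TYPE_F32);

    ggml_free(ctx);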