This commit is contained in:
Jared Van Bortel 2024-01-10 11:29:04 -05:00
parent 3773e1afe7
commit 1eb8804c18
18 changed files with 2183 additions and 2081 deletions

30
ggml.c
View file

@ -2336,6 +2336,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
}
void ggml_free(struct ggml_context * ctx) {
if (ctx == NULL) {
return;
}
// make this function thread safe
ggml_critical_section_start();
@ -4351,6 +4355,23 @@ struct ggml_tensor * ggml_cpy_inplace(
return ggml_cpy_impl(ctx, a, b, true);
}
struct ggml_tensor * ggml_cast(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_type type) {
bool is_node = false;
struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
ggml_format_name(result, "%s (copy)", a->name);
result->op = GGML_OP_CPY;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = result;
return result;
}
// ggml_cont
static struct ggml_tensor * ggml_cont_impl(
@ -14851,7 +14872,7 @@ size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tenso
return i;
}
static struct ggml_hash_set ggml_hash_set_new(size_t size) {
struct ggml_hash_set ggml_hash_set_new(size_t size) {
size = ggml_hash_size(size);
struct ggml_hash_set result;
result.size = size;
@ -16600,7 +16621,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
return GGML_EXIT_SUCCESS;
}
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}
@ -16662,14 +16683,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_MUL_MAT_ID:
{
cur = 0;
const struct ggml_tensor * src0 = node->src[2];
const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
const int n_as = ggml_get_op_params_i32(node, 1);
cur = GGML_PAD(cur, sizeof(int64_t)); // align
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
} break;