integrate unified training function, which can use the memory allocator
the unified training function also accepts arguments controlling whether to use flash attention and/or gradient checkpointing
This commit is contained in:
parent 4ed096c6b0
commit 865c4cd3c1
1 changed file with 34 additions and 9 deletions
@@ -4391,6 +4391,12 @@ int main(int argc, char ** argv) {
     uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
     uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
+
+    ggml_allocr * alloc = NULL;
+    if (params.use_alloc) {
+        static const size_t tensor_alignment = 32;
+        alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
+    }
 
     GGML_ASSERT(n_tokens < (int) train_tokens.size());
     std::vector<int> train_samples;
     train_samples.push_back(0);
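For context, a minimal sketch of the ggml_allocr lifecycle this hunk introduces: the allocator places tensor data inside a caller-owned buffer and is reset between iterations. The wrapper function name is illustrative, not part of the commit; only the ggml-alloc calls themselves are taken from the API.

#include "ggml.h"
#include "ggml-alloc.h"

// Hypothetical wrapper: create an allocator over a caller-owned buffer,
// reuse it across iterations, then release it.
static void allocr_lifecycle_sketch(uint8_t * buf, size_t buf_size) {
    static const size_t tensor_alignment = 32;

    // The allocator hands out tensor data from `buf`; it does not own `buf`.
    struct ggml_allocr * alloc = ggml_allocr_new(buf, buf_size, tensor_alignment);

    // ... build graphs whose intermediate tensors are placed by `alloc` ...

    // Reset before each training iteration so the same buffer is reused.
    ggml_allocr_reset(alloc);

    // Free the allocator itself; the caller still owns and frees `buf`.
    ggml_allocr_free(alloc);
}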
@@ -4437,33 +4443,48 @@ int main(int argc, char ** argv) {
     };
     struct ggml_context * ctx0 = ggml_init(cparams);
+
+    ggml_set_no_alloc(ctx0, false);
+
+    // don't use alloc for input tensors, so we can safely fill them with data
     struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
     //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
     struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
     struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
     struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+
+    ggml_set_no_alloc(ctx0, (alloc != NULL));
+
+    if (alloc) {
+        ggml_allocr_reset(alloc);
+    }
+
     opt_cb_data.tokens_input = tokens_input;
     opt_cb_data.target_logits = target_logits;
     opt_cb_data.target_probs = target_probs;
 
     int n_past = 0;
 
-    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-    memset(gfbuf->data, 0, ggml_nbytes(gfbuf));
-    memset(gbbuf->data, 0, ggml_nbytes(gbbuf));
-
-    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
-    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+    struct ggml_cgraph * gb = ggml_new_graph(ctx0);
+    struct ggml_cgraph * gb_tmp = (params.use_unified || params.use_alloc)
+        ? ggml_new_graph(ctx0)
+        : NULL;
 
     GGML_ASSERT(n_past == 0);
 
     struct ggml_tensor * loss = NULL;
     struct ggml_tensor * logits = NULL;
 
-    if (params.use_checkpointing) {
+    if (params.use_alloc || params.use_unified) {
+        loss = llama_build_train_graphs(
+            &model, alloc, ctx0,
+            gf, gb, gb_tmp,
+            &logits, tokens_input, target_probs,
+            n_tokens, n_batch,
+            params.use_flash,
+            params.use_checkpointing
+        );
+    } else if (params.use_checkpointing) {
         loss = forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
             &model, ctx0,
             gf, gb,
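The no_alloc toggling in this hunk follows a common ggml pattern: input tensors are created with real backing memory up front so they can be filled with data, and the context is then switched so later tensors are left for the allocator to place. A minimal sketch under that reading, with a hypothetical helper name and only the ggml calls taken from the API:

#include "ggml.h"
#include "ggml-alloc.h"

// Hypothetical helper: inputs own their data, intermediates are placed
// by the allocator when one is in use.
static struct ggml_tensor * make_intermediate(
        struct ggml_context * ctx0, struct ggml_allocr * alloc,
        int n_tokens, int n_batch) {
    // Inputs first: real data pointers, safe to fill with token ids etc.
    ggml_set_no_alloc(ctx0, false);
    struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
    (void) tokens_input; // would be filled with data here

    // Then intermediates: created with data == NULL when alloc is active.
    ggml_set_no_alloc(ctx0, (alloc != NULL));
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_batch);
    if (alloc) {
        ggml_allocr_alloc(alloc, t); // assign t an offset inside the compute buffer
    }
    return t;
}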
@@ -4641,6 +4662,10 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (alloc) {
+        ggml_allocr_free(alloc);
+    }
+
     delete[] compute_addr;
     delete[] compute_buf_0;
     delete[] compute_buf_1;
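This commit sizes compute_buf_0 up front. An alternative the ggml-alloc API supports, not used here and sketched only as an assumption, is to run a sizing pass with a measure allocator before allocating the real buffer:

#include "ggml.h"
#include "ggml-alloc.h"

// Hypothetical sizing pass: a measure allocator walks the graph without
// real memory and reports how large the compute buffer would need to be.
static size_t measure_compute_buf_size(struct ggml_cgraph * gf) {
    static const size_t tensor_alignment = 32;
    struct ggml_allocr * measure = ggml_allocr_new_measure(tensor_alignment);
    size_t size_needed = ggml_allocr_alloc_graph(measure, gf) + tensor_alignment;
    ggml_allocr_free(measure);
    return size_needed; // caller would then allocate a buffer of this size
}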