integrate unified training function which may use memory allocator

the unified training function also supports arguments whether to use flash attention and/or gradient checkpointing
This commit is contained in:
xaedes 2023-08-14 18:12:58 +02:00
parent 4ed096c6b0
commit 865c4cd3c1
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -4391,6 +4391,12 @@ int main(int argc, char ** argv) {
uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
ggml_allocr * alloc = NULL;
if (params.use_alloc) {
static const size_t tensor_alignment = 32;
alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
}
GGML_ASSERT(n_tokens < (int) train_tokens.size());
std::vector<int> train_samples;
train_samples.push_back(0);
@ -4437,33 +4443,48 @@ int main(int argc, char ** argv) {
};
struct ggml_context * ctx0 = ggml_init(cparams);
ggml_set_no_alloc(ctx0, false);
// don't use alloc for input tensors, so we can safely fill them with data
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
//struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
ggml_set_no_alloc(ctx0, (alloc != NULL));
if (alloc) {
ggml_allocr_reset(alloc);
}
opt_cb_data.tokens_input = tokens_input;
opt_cb_data.target_logits = target_logits;
opt_cb_data.target_probs = target_probs;
int n_past = 0;
struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
memset(gfbuf->data, 0, ggml_nbytes(gfbuf));
memset(gbbuf->data, 0, ggml_nbytes(gbbuf));
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_cgraph * gb = ggml_new_graph(ctx0);
struct ggml_cgraph * gb_tmp = (params.use_unified || params.use_alloc)
? ggml_new_graph(ctx0)
: NULL;
GGML_ASSERT(n_past == 0);
struct ggml_tensor * loss = NULL;
struct ggml_tensor * logits = NULL;
if (params.use_checkpointing) {
if (params.use_alloc || params.use_unified) {
loss = llama_build_train_graphs(
&model, alloc, ctx0,
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.use_flash,
params.use_checkpointing
);
} else if (params.use_checkpointing) {
loss = forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
&model, ctx0,
gf, gb,
@ -4641,6 +4662,10 @@ int main(int argc, char ** argv) {
}
}
if (alloc) {
ggml_allocr_free(alloc);
}
delete[] compute_addr;
delete[] compute_buf_0;
delete[] compute_buf_1;