remove unused compute buffer 3
This commit is contained in:
parent
6e3f95bf06
commit
e05e4414ac
1 changed file with 27 additions and 47 deletions
|
@ -1932,11 +1932,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
||||||
void * compute_buf_0,
|
void * compute_buf_0,
|
||||||
void * compute_buf_1,
|
void * compute_buf_1,
|
||||||
void * compute_buf_2,
|
void * compute_buf_2,
|
||||||
void * compute_buf_3,
|
|
||||||
size_t size_buf_0,
|
size_t size_buf_0,
|
||||||
size_t size_buf_1,
|
size_t size_buf_1,
|
||||||
size_t size_buf_2,
|
size_t size_buf_2,
|
||||||
size_t size_buf_3,
|
|
||||||
const int n_tokens,
|
const int n_tokens,
|
||||||
const int n_batch) {
|
const int n_batch) {
|
||||||
|
|
||||||
|
@ -1966,16 +1964,14 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
||||||
bool track_max_mem = true;
|
bool track_max_mem = true;
|
||||||
|
|
||||||
int last_buf = -1;
|
int last_buf = -1;
|
||||||
size_t buf_offs[4] = { 0, 0, 0, 0 };
|
size_t buf_offs[3] = { 0, 0, 0 };
|
||||||
size_t buf_size[4] = { size_buf_0,
|
size_t buf_size[3] = { size_buf_0,
|
||||||
size_buf_1,
|
size_buf_1,
|
||||||
size_buf_2,
|
size_buf_2 };
|
||||||
size_buf_3 };
|
void * buf_data[3] = { compute_buf_0,
|
||||||
void * buf_data[4] = { compute_buf_0,
|
|
||||||
compute_buf_1,
|
compute_buf_1,
|
||||||
compute_buf_2,
|
compute_buf_2 };
|
||||||
compute_buf_3 };
|
size_t buf_maxs[3] = { 0, 0, 0 };
|
||||||
size_t buf_maxs[4] = { 0, 0, 0, 0 };
|
|
||||||
|
|
||||||
auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
|
auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
|
||||||
size_t last_offs = 0;
|
size_t last_offs = 0;
|
||||||
|
@ -2083,7 +2079,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
||||||
clr_buf(0);
|
clr_buf(0);
|
||||||
clr_buf(1);
|
clr_buf(1);
|
||||||
clr_buf(2);
|
clr_buf(2);
|
||||||
clr_buf(3);
|
|
||||||
|
|
||||||
use_buf(-1);
|
use_buf(-1);
|
||||||
|
|
||||||
|
@ -2112,22 +2107,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
||||||
|
|
||||||
// example for 16 layers:
|
// example for 16 layers:
|
||||||
// inp ~ implicit zeroth checkpoint == input
|
// inp ~ implicit zeroth checkpoint == input
|
||||||
// L00 f 4b
|
// L00 f 4b [
|
||||||
// L01 f 4b
|
// L01 f 4b 4th second forward pass
|
||||||
// L02 f 4b
|
// L02 f 4b
|
||||||
// L03 fc4b first checkpoint
|
// L03 fc4b ] first checkpoint
|
||||||
// L04 f 3b
|
// L04 f 3b [
|
||||||
// L05 f 3b
|
// L05 f 3b 3rd second forward pass
|
||||||
// L06 f 3b
|
// L06 f 3b
|
||||||
// L07 fc3b second checkpoint
|
// L07 fc3b ] second checkpoint
|
||||||
// L08 f 2b
|
// L08 f 2b [
|
||||||
// L09 f 2b
|
// L09 f 2b 2nd second forward pass
|
||||||
// L10 f 2b
|
// L10 f 2b
|
||||||
// L11 fc2b third checkpoint
|
// L11 fc2b ] third checkpoint
|
||||||
// L12 f 1b
|
// L12 f 1b [
|
||||||
// L13 f 1b
|
// L13 f 1b 1st second forward pass
|
||||||
// L14 f 1b
|
// L14 f 1b
|
||||||
// L15 f 1b
|
// L15 f 1b ]
|
||||||
|
|
||||||
// need to remember these for the backward pass
|
// need to remember these for the backward pass
|
||||||
std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
|
std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
|
||||||
|
@ -2162,7 +2157,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
||||||
|
|
||||||
struct ggml_tensor * cur = t01;
|
struct ggml_tensor * cur = t01;
|
||||||
|
|
||||||
|
|
||||||
int chk_idx = 0;
|
int chk_idx = 0;
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct my_llama_layer & layer = model->layers[il];
|
struct my_llama_layer & layer = model->layers[il];
|
||||||
|
@ -2455,13 +2449,11 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
||||||
clr_buf(0);
|
clr_buf(0);
|
||||||
clr_buf(1);
|
clr_buf(1);
|
||||||
clr_buf(2);
|
clr_buf(2);
|
||||||
clr_buf(3);
|
|
||||||
|
|
||||||
if (track_max_mem) {
|
if (track_max_mem) {
|
||||||
printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
|
printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
|
||||||
printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
|
printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
|
||||||
printf("%s: max size compute buf2: %zu\n", __func__, buf_maxs[2]);
|
printf("%s: max size compute buf2: %zu\n", __func__, buf_maxs[2]);
|
||||||
printf("%s: max size compute buf3: %zu\n", __func__, buf_maxs[3]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// now that all grads are created, set the graph leafs and grads
|
// now that all grads are created, set the graph leafs and grads
|
||||||
|
@ -3434,7 +3426,6 @@ struct train_params get_default_train_params() {
|
||||||
params.mem_compute0_gb = 8;
|
params.mem_compute0_gb = 8;
|
||||||
params.mem_compute1_gb = 1;
|
params.mem_compute1_gb = 1;
|
||||||
params.mem_compute2_gb = 2;
|
params.mem_compute2_gb = 2;
|
||||||
params.mem_compute3_gb = 1;
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3486,7 +3477,6 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
|
||||||
fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
|
fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
|
||||||
fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
|
fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
|
||||||
fprintf(stderr, " --mem-compute2 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
|
fprintf(stderr, " --mem-compute2 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
|
||||||
fprintf(stderr, " --mem-compute3 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute3_gb);
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3724,12 +3714,6 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->mem_compute2_gb = std::stoi(argv[i]);
|
params->mem_compute2_gb = std::stoi(argv[i]);
|
||||||
} else if (arg == "--mem-compute3") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params->mem_compute3_gb = std::stoi(argv[i]);
|
|
||||||
} else if (arg == "-h" || arg == "--help") {
|
} else if (arg == "-h" || arg == "--help") {
|
||||||
train_print_usage(argc, argv, &default_params);
|
train_print_usage(argc, argv, &default_params);
|
||||||
exit(0);
|
exit(0);
|
||||||
|
@ -3892,11 +3876,9 @@ int main(int argc, char ** argv) {
|
||||||
size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
|
size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
|
||||||
size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
|
size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
|
||||||
size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
|
size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
|
||||||
size_t size_buf_3 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute3_gb);
|
|
||||||
uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
|
uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
|
||||||
uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
|
uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
|
||||||
uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
|
uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
|
||||||
uint8_t * compute_buf_3 = new uint8_t[size_buf_3];
|
|
||||||
|
|
||||||
GGML_ASSERT(n_tokens < (int) train_tokens.size());
|
GGML_ASSERT(n_tokens < (int) train_tokens.size());
|
||||||
std::vector<int> train_samples;
|
std::vector<int> train_samples;
|
||||||
|
@ -3924,9 +3906,9 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_init_params cparams = {
|
struct ggml_init_params cparams = {
|
||||||
/*.mem_size =*/ compute_size,
|
compute_size, // mem_size
|
||||||
/*.mem_buffer =*/ compute_addr,
|
compute_addr, // mem_buffer
|
||||||
/*.no_alloc =*/ false,
|
false, // no_alloc
|
||||||
};
|
};
|
||||||
struct ggml_context * ctx0 = ggml_init(cparams);
|
struct ggml_context * ctx0 = ggml_init(cparams);
|
||||||
|
|
||||||
|
@ -3960,8 +3942,8 @@ int main(int argc, char ** argv) {
|
||||||
&model, ctx0,
|
&model, ctx0,
|
||||||
gf, gb,
|
gf, gb,
|
||||||
&logits, tokens_input, target_probs,
|
&logits, tokens_input, target_probs,
|
||||||
compute_buf_0, compute_buf_1, compute_buf_2, compute_buf_3,
|
compute_buf_0, compute_buf_1, compute_buf_2,
|
||||||
size_buf_0, size_buf_1, size_buf_2, size_buf_3,
|
size_buf_0, size_buf_1, size_buf_2,
|
||||||
n_tokens, n_batch);
|
n_tokens, n_batch);
|
||||||
} else if (params.use_scratch) {
|
} else if (params.use_scratch) {
|
||||||
loss = forward_batch_wo_cache_flash_attn_train(
|
loss = forward_batch_wo_cache_flash_attn_train(
|
||||||
|
@ -4082,9 +4064,9 @@ int main(int argc, char ** argv) {
|
||||||
printf("---\n");
|
printf("---\n");
|
||||||
for (int i=0; i<n_gen; ++i) {
|
for (int i=0; i<n_gen; ++i) {
|
||||||
struct ggml_init_params cparams = {
|
struct ggml_init_params cparams = {
|
||||||
/*.mem_size =*/ compute_size,
|
compute_size, // .mem_size
|
||||||
/*.mem_buffer =*/ compute_addr,
|
compute_addr, // .mem_buffer
|
||||||
/*.no_alloc =*/ false,
|
false, // .no_alloc
|
||||||
};
|
};
|
||||||
struct ggml_context * ctx0 = ggml_init(cparams);
|
struct ggml_context * ctx0 = ggml_init(cparams);
|
||||||
|
|
||||||
|
@ -4120,10 +4102,8 @@ int main(int argc, char ** argv) {
|
||||||
delete[] compute_addr;
|
delete[] compute_addr;
|
||||||
delete[] compute_buf_0;
|
delete[] compute_buf_0;
|
||||||
delete[] compute_buf_1;
|
delete[] compute_buf_1;
|
||||||
|
ggml_free(model.ctx);
|
||||||
llama_free(lctx);
|
llama_free(lctx);
|
||||||
llama_free_model(lmodel);
|
llama_free_model(lmodel);
|
||||||
ggml_free(model.ctx);
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue