Merge branch 'master' into concedo_experimental

# Conflicts:
#	.devops/tools.sh
#	.gitignore
#	CMakeLists.txt
#	Makefile
#	README.md
This commit is contained in:
Concedo 2023-12-01 23:46:59 +08:00
commit 4f40c226a0
9 changed files with 224 additions and 65 deletions

View file

@ -11,8 +11,13 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git")
if(NOT IS_DIRECTORY "${GIT_DIR}") if(NOT IS_DIRECTORY "${GIT_DIR}")
file(READ ${GIT_DIR} REAL_GIT_DIR_LINK) file(READ ${GIT_DIR} REAL_GIT_DIR_LINK)
string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK}) string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK})
string(FIND "${REAL_GIT_DIR}" "/" SLASH_POS)
if (SLASH_POS EQUAL 0)
set(GIT_DIR "${REAL_GIT_DIR}")
else()
set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${REAL_GIT_DIR}") set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${REAL_GIT_DIR}")
endif() endif()
endif()
set(GIT_INDEX "${GIT_DIR}/index") set(GIT_INDEX "${GIT_DIR}/index")
else() else()

View file

@ -267,7 +267,7 @@ class Params:
n_ctx = 2048 n_ctx = 2048
return Params( return Params(
n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), n_vocab = model["tok_embeddings.weight"].shape[0],
n_embd = config["dim"], n_embd = config["dim"],
n_layer = config["n_layers"], n_layer = config["n_layers"],
n_ctx = n_ctx, n_ctx = n_ctx,

View file

@ -1,4 +1,4 @@
This is a swift clone of `examples/batched`. This is a swift clone of `examples/batched`.
$ `make` $ `make`
$ `./swift MODEL_PATH [PROMPT] [PARALLEL]` $ `./batched_swift MODEL_PATH [PROMPT] [PARALLEL]`

View file

@ -5,7 +5,7 @@ import json
import torch import torch
import numpy as np import numpy as np
from gguf import * from gguf import *
from transformers import CLIPModel, CLIPProcessor from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
TEXT = "clip.text" TEXT = "clip.text"
VISION = "clip.vision" VISION = "clip.vision"
@ -78,11 +78,19 @@ ap.add_argument("--text-only", action="store_true", required=False,
help="Save a text-only model. It can't be used to encode images") help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False, ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts") help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip_model_is_vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values") ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values") ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image_mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image_std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
# with proper
args = ap.parse_args() args = ap.parse_args()
@ -96,13 +104,20 @@ if args.use_f32:
# output in the same directory as the model if output_dir is None # output in the same directory as the model if output_dir is None
dir_model = args.model_dir dir_model = args.model_dir
if args.clip_model_is_vision:
vocab = None
tokens = None
else:
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f: with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
vocab = json.load(f) vocab = json.load(f)
tokens = [key for key in vocab] tokens = [key for key in vocab]
with open(dir_model + "/config.json", "r", encoding="utf-8") as f: with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
if args.clip_model_is_vision:
v_hparams = config
t_hparams = None
else:
v_hparams = config["vision_config"] v_hparams = config["vision_config"]
t_hparams = config["text_config"] t_hparams = config["text_config"]
@ -117,7 +132,10 @@ ftype = 1
if args.use_f32: if args.use_f32:
ftype = 0 ftype = 0
if args.clip_model_is_vision:
model = CLIPVisionModel.from_pretrained(dir_model)
processor = None
else:
model = CLIPModel.from_pretrained(dir_model) model = CLIPModel.from_pretrained(dir_model)
processor = CLIPProcessor.from_pretrained(dir_model) processor = CLIPProcessor.from_pretrained(dir_model)
@ -128,13 +146,13 @@ has_llava_projector = False
if args.text_only: if args.text_only:
fname_middle = "text-" fname_middle = "text-"
has_vision_encoder = False has_vision_encoder = False
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
elif args.llava_projector is not None: elif args.llava_projector is not None:
fname_middle = "mmproj-" fname_middle = "mmproj-"
has_text_encoder = False has_text_encoder = False
has_llava_projector = True has_llava_projector = True
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
else: else:
fname_middle = "" fname_middle = ""
@ -182,8 +200,12 @@ if has_vision_encoder:
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
image_mean = processor.image_processor.image_mean if args.image_mean is None else args.image_mean if processor is not None:
image_std = processor.image_processor.image_std if args.image_std is None else args.image_std image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
else:
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
image_std = args.image_std if args.image_std is not None else default_image_std
fout.add_array("clip.vision.image_mean", image_mean) fout.add_array("clip.vision.image_mean", image_mean)
fout.add_array("clip.vision.image_std", image_std) fout.add_array("clip.vision.image_std", image_std)

View file

@ -101,6 +101,12 @@ static void sigint_handler(int signo) {
} }
#endif #endif
static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
LOG_TEE("%s", text);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
g_params = &params; g_params = &params;
@ -114,6 +120,7 @@ int main(int argc, char ** argv) {
log_set_target(log_filename_generator("main", "log")); log_set_target(log_filename_generator("main", "log"));
LOG_TEE("Log start\n"); LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv); log_dump_cmdline(argc, argv);
llama_log_set(llama_log_callback_logTee, nullptr);
#endif // LOG_DISABLE_LOGS #endif // LOG_DISABLE_LOGS
// TODO: Dump params ? // TODO: Dump params ?

View file

@ -11,10 +11,10 @@ app = Flask(__name__)
slot_id = -1 slot_id = -1
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.") parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ") parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>") parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080') parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="") parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
@ -34,19 +34,19 @@ def is_present(json, key):
#convert chat to prompt #convert chat to prompt
def convert_chat(messages): def convert_chat(messages):
prompt = "" + args.chat_prompt.replace("\\n", "\n")
system_n = args.system_name.replace("\\n", "\n") system_n = args.system_name
user_n = args.user_name.replace("\\n", "\n") user_n = args.user_name
ai_n = args.ai_name.replace("\\n", "\n") ai_n = args.ai_name
stop = args.stop.replace("\\n", "\n") stop = args.stop
prompt = "" + args.chat_prompt + stop
for line in messages: for line in messages:
if (line["role"] == "system"): if (line["role"] == "system"):
prompt += f"{system_n}{line['content']}" prompt += f"{system_n}{line['content']}{stop}"
if (line["role"] == "user"): if (line["role"] == "user"):
prompt += f"{user_n}{line['content']}" prompt += f"{user_n}{line['content']}{stop}"
if (line["role"] == "assistant"): if (line["role"] == "assistant"):
prompt += f"{ai_n}{line['content']}{stop}" prompt += f"{ai_n}{line['content']}{stop}"
prompt += ai_n.rstrip() prompt += ai_n.rstrip()
@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
} }
] ]
} }
slot_id = data["slot_id"] slot_id = data.get("slot_id")
if (chat): if (chat):
if (start): if (start):
resData["choices"][0]["delta"] = { resData["choices"][0]["delta"] = {
@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
return resData return resData
@app.route('/chat/completions', methods=['POST']) @app.route('/chat/completions', methods=['POST', 'OPTIONS'])
@app.route('/v1/chat/completions', methods=['POST']) @app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
def chat_completions(): def chat_completions():
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
return Response(status=403) return Response(status=403)
if request.method == 'OPTIONS':
return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
body = request.get_json() body = request.get_json()
stream = False stream = False
tokenize = False tokenize = False
@ -177,20 +179,22 @@ def chat_completions():
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
time_now = int(time.time()) time_now = int(time.time())
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
yield 'data: {}\n'.format(json.dumps(resData)) yield 'data: {}\n\n'.format(json.dumps(resData))
for line in data.iter_lines(): for line in data.iter_lines():
if line: if line:
decoded_line = line.decode('utf-8') decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
yield 'data: {}\n'.format(json.dumps(resData)) yield 'data: {}\n\n'.format(json.dumps(resData))
return Response(generate(), mimetype='text/event-stream') return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
@app.route('/completions', methods=['POST']) @app.route('/completions', methods=['POST', 'OPTIONS'])
@app.route('/v1/completions', methods=['POST']) @app.route('/v1/completions', methods=['POST', 'OPTIONS'])
def completion(): def completion():
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
return Response(status=403) return Response(status=403)
if request.method == 'OPTIONS':
return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
body = request.get_json() body = request.get_json()
stream = False stream = False
tokenize = False tokenize = False
@ -216,8 +220,8 @@ def completion():
if line: if line:
decoded_line = line.decode('utf-8') decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
yield 'data: {}\n'.format(json.dumps(resData)) yield 'data: {}\n\n'.format(json.dumps(resData))
return Response(generate(), mimetype='text/event-stream') return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
if __name__ == '__main__': if __name__ == '__main__':
app.run(args.host, port=args.port) app.run(args.host, port=args.port)

View file

@ -156,15 +156,23 @@ struct task_server {
json data; json data;
bool infill_mode = false; bool infill_mode = false;
bool embedding_mode = false; bool embedding_mode = false;
int multitask_id = -1;
}; };
struct task_result { struct task_result {
int id; int id;
int multitask_id = -1;
bool stop; bool stop;
bool error; bool error;
json result_json; json result_json;
}; };
struct task_multi {
int id;
std::set<int> subtasks_remaining{};
std::vector<task_result> results{};
};
// TODO: can become bool if we can't find use of more states // TODO: can become bool if we can't find use of more states
enum slot_state enum slot_state
{ {
@ -407,6 +415,9 @@ struct llama_client_slot
double t_prompt_processing; // ms double t_prompt_processing; // ms
double t_token_generation; // ms double t_token_generation; // ms
// multitasks
int multitask_id = -1;
void reset() { void reset() {
num_prompt_tokens = 0; num_prompt_tokens = 0;
generated_text = ""; generated_text = "";
@ -530,7 +541,8 @@ struct llama_server_context
std::vector<task_server> queue_tasks; std::vector<task_server> queue_tasks;
std::vector<task_result> queue_results; std::vector<task_result> queue_results;
std::mutex mutex_tasks; std::vector<task_multi> queue_multitasks;
std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
std::mutex mutex_results; std::mutex mutex_results;
~llama_server_context() ~llama_server_context()
@ -1113,17 +1125,40 @@ struct llama_server_context
return slot.images.size() > 0; return slot.images.size() > 0;
} }
void send_error(int id, std::string error) void send_error(task_server& task, std::string error)
{ {
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = id; res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false; res.stop = false;
res.error = true; res.error = true;
res.result_json = { { "content", error } }; res.result_json = { { "content", error } };
queue_results.push_back(res); queue_results.push_back(res);
} }
void add_multi_task(int id, std::vector<int>& sub_ids)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
task_multi multi;
multi.id = id;
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
queue_multitasks.push_back(multi);
}
void update_multi_task(int multitask_id, int subtask_id, task_result& result)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
for (auto& multitask : queue_multitasks)
{
if (multitask.id == multitask_id)
{
multitask.subtasks_remaining.erase(subtask_id);
multitask.results.push_back(result);
}
}
}
json get_model_props() json get_model_props()
{ {
return get_formated_generation(slots[0]); return get_formated_generation(slots[0]);
@ -1168,6 +1203,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false; res.error = false;
res.stop = false; res.stop = false;
@ -1207,6 +1243,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false; res.error = false;
res.stop = true; res.stop = true;
@ -1252,6 +1289,12 @@ struct llama_server_context
res.result_json["model"] = slot.oaicompat_model; res.result_json["model"] = slot.oaicompat_model;
} }
// parent multitask, if any, needs to be updated
if (slot.multitask_id != -1)
{
update_multi_task(slot.multitask_id, slot.task_id, res);
}
queue_results.push_back(res); queue_results.push_back(res);
} }
@ -1260,6 +1303,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results); std::lock_guard<std::mutex> lock(mutex_results);
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
res.error = false; res.error = false;
res.stop = true; res.stop = true;
@ -1286,9 +1330,9 @@ struct llama_server_context
queue_results.push_back(res); queue_results.push_back(res);
} }
int request_completion(json data, bool infill, bool embedding) int request_completion(json data, bool infill, bool embedding, int multitask_id)
{ {
std::lock_guard<std::mutex> lock(mutex_tasks); std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task; task_server task;
task.id = id_gen++; task.id = id_gen++;
task.target_id = 0; task.target_id = 0;
@ -1296,6 +1340,16 @@ struct llama_server_context
task.infill_mode = infill; task.infill_mode = infill;
task.embedding_mode = embedding; task.embedding_mode = embedding;
task.type = COMPLETION_TASK; task.type = COMPLETION_TASK;
task.multitask_id = multitask_id;
// when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.at("prompt").size() > 1)
{
lock.unlock(); // entering new func scope
return split_multiprompt_task(task);
}
// otherwise, it's a single-prompt task, we actually queue it
queue_tasks.push_back(task); queue_tasks.push_back(task);
return task.id; return task.id;
} }
@ -1314,8 +1368,17 @@ struct llama_server_context
for (int i = 0; i < (int) queue_results.size(); i++) for (int i = 0; i < (int) queue_results.size(); i++)
{ {
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
if (queue_results[i].multitask_id == task_id)
{
update_multi_task(task_id, queue_results[i].id, queue_results[i]);
queue_results.erase(queue_results.begin() + i);
continue;
}
if (queue_results[i].id == task_id) if (queue_results[i].id == task_id)
{ {
assert(queue_results[i].multitask_id == -1);
task_result res = queue_results[i]; task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i); queue_results.erase(queue_results.begin() + i);
return res; return res;
@ -1405,6 +1468,27 @@ struct llama_server_context
queue_tasks.push_back(task); queue_tasks.push_back(task);
} }
int split_multiprompt_task(task_server& multiprompt_task)
{
auto prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);
int multitask_id = id_gen++;
std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++)
{
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, etc.)
subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
}
// queue up the multitask so we can track its subtask progression
add_multi_task(multitask_id, subtask_ids);
return multitask_id;
}
void process_tasks() void process_tasks()
{ {
std::lock_guard<std::mutex> lock(mutex_tasks); std::lock_guard<std::mutex> lock(mutex_tasks);
@ -1420,7 +1504,7 @@ struct llama_server_context
{ {
LOG_TEE("slot unavailable\n"); LOG_TEE("slot unavailable\n");
// send error result // send error result
send_error(task.id, "slot unavailable"); send_error(task, "slot unavailable");
return; return;
} }
@ -1434,11 +1518,12 @@ struct llama_server_context
slot->infill = task.infill_mode; slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode; slot->embedding = task.embedding_mode;
slot->task_id = task.id; slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
if (!launch_slot_with_data(slot, task.data)) if (!launch_slot_with_data(slot, task.data))
{ {
// send error result // send error result
send_error(task.id, "internal_error"); send_error(task, "internal_error");
break; break;
} }
} break; } break;
@ -1454,6 +1539,38 @@ struct llama_server_context
} break; } break;
} }
} }
// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
auto queue_iterator = queue_multitasks.begin();
while (queue_iterator != queue_multitasks.end())
{
if (queue_iterator->subtasks_remaining.empty())
{
// all subtasks done == multitask is done
task_result aggregate_result;
aggregate_result.id = queue_iterator->id;
aggregate_result.stop = true;
aggregate_result.error = false;
// collect json results into one json result
std::vector<json> result_jsons;
for (auto& subres : queue_iterator->results)
{
result_jsons.push_back(subres.result_json);
aggregate_result.error = aggregate_result.error && subres.error;
}
aggregate_result.result_json = json{ "results", result_jsons };
std::lock_guard<std::mutex> lock(mutex_results);
queue_results.push_back(aggregate_result);
queue_iterator = queue_multitasks.erase(queue_iterator);
}
else
{
++queue_iterator;
}
}
} }
bool update_slots() { bool update_slots() {
@ -1845,6 +1962,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
printf(" --log-disable disables logging to a file.\n");
printf("\n"); printf("\n");
} }
@ -2199,6 +2317,11 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.mmproj = argv[i]; params.mmproj = argv[i];
} }
else if (arg == "--log-disable")
{
log_set_target(stdout);
LOG_INFO("logging to file is disabled.", {});
}
else else
{ {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@ -2597,7 +2720,7 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
json data = json::parse(req.body); json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false); const int task_id = llama.request_completion(data, false, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
task_result result = llama.next_result(task_id); task_result result = llama.next_result(task_id);
@ -2686,7 +2809,7 @@ int main(int argc, char **argv)
{ {
json data = oaicompat_completion_params_parse(json::parse(req.body)); json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false); const int task_id = llama.request_completion(data, false, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
@ -2755,7 +2878,7 @@ int main(int argc, char **argv)
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
json data = json::parse(req.body); json data = json::parse(req.body);
const int task_id = llama.request_completion(data, true, false); const int task_id = llama.request_completion(data, true, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
task_result result = llama.next_result(task_id); task_result result = llama.next_result(task_id);
@ -2859,7 +2982,7 @@ int main(int argc, char **argv)
{ {
prompt = ""; prompt = "";
} }
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true); const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
task_result result = llama.next_result(task_id); task_result result = llama.next_result(task_id);
return res.set_content(result.result_json.dump(), "application/json"); return res.set_content(result.result_json.dump(), "application/json");
}); });

View file

@ -1,21 +1,19 @@
#include "ggml.h"
#include "ggml-opencl.h" #include "ggml-opencl.h"
#include <array> #include <array>
#include <atomic> #include <atomic>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <limits>
#define CL_TARGET_OPENCL_VERSION 110 #define CL_TARGET_OPENCL_VERSION 110
#include <clblast.h> #include <clblast.h>
#include <clblast_c.h> #include <clblast_c.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "ggml.h"
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif

View file

@ -47,7 +47,6 @@
#endif #endif
#include <windows.h> #include <windows.h>
#include <io.h> #include <io.h>
#include <stdio.h> // for _fseeki64
#endif #endif
#include <algorithm> #include <algorithm>
@ -7290,6 +7289,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
// Replace the data in candidates with the new_candidates data // Replace the data in candidates with the new_candidates data
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
candidates->size = new_candidates.size(); candidates->size = new_candidates.size();
candidates->sorted = false;
if (ctx) { if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;