agent: add --greedy, --top-p, --top-k options

This commit is contained in:
ochafik 2025-01-19 02:07:06 +00:00
parent c207fdcde6
commit 0401a83b9b

View file

@@ -63,7 +63,10 @@ async def main(
     system: Optional[str] = None,
     verbose: bool = False,
     cache_prompt: bool = True,
-    temperature: Optional[int] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    top_k: Optional[int] = None,
+    greedy: bool = False,
     seed: Optional[int] = None,
     interactive: bool = True,
     provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
@@ -80,6 +83,14 @@ async def main(
     api_key = os.environ.get(provider_info['api_key_env'])
     tool_map, tools = await discover_tools(tool_endpoints or [], verbose)

+    if greedy:
+        if temperature is None:
+            temperature = 0.0
+        if top_k is None:
+            top_k = 1
+        if top_p is None:
+            top_p = 0.0
+
     if think:
         tools.append({
@@ -129,6 +140,8 @@ async def main(
         model=model,
         tools=tools,
         temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
         seed=seed,
     )
     if provider == 'llama.cpp':