diff --git a/examples/agent/run.py b/examples/agent/run.py index 3330f1b7a..bc47a8756 100644 --- a/examples/agent/run.py +++ b/examples/agent/run.py @@ -63,7 +63,10 @@ async def main( system: Optional[str] = None, verbose: bool = False, cache_prompt: bool = True, - temperature: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + greedy: bool = False, seed: Optional[int] = None, interactive: bool = True, provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp', @@ -80,6 +83,14 @@ async def main( api_key = os.environ.get(provider_info['api_key_env']) tool_map, tools = await discover_tools(tool_endpoints or [], verbose) + + if greedy: + if temperature is None: + temperature = 0.0 + if top_k is None: + top_k = 1 + if top_p is None: + top_p = 0.0 if think: tools.append({ @@ -129,6 +140,8 @@ async def main( model=model, tools=tools, temperature=temperature, + top_p=top_p, + top_k=top_k, seed=seed, ) if provider == 'llama.cpp':