AgentOccam/llms/utils.py

import argparse
from typing import Any

from transformers import AutoTokenizer, AutoModel

from llms import (
    generate_from_huggingface_completion,
    generate_from_openai_chat_completion,
    generate_from_openai_completion,
    lm_config,
)

APIInput = str | list[Any] | dict[str, Any]

# Module-level handles reserved for a locally loaded HuggingFace model and
# tokenizer; they are declared global in call_llm but not assigned in the
# code paths shown here.
model = None
tokenizer = None
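

# Expected lm_config.gen_config keys, inferred from the calls below (not an
# exhaustive contract): "temperature" and "top_p" for every provider;
# "context_length" and "max_tokens" for OpenAI chat; "max_tokens" and
# "stop_token" for OpenAI completion; "model_endpoint", "stop_sequences",
# and "max_new_tokens" for HuggingFace endpoints.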
def call_llm(
    lm_config: lm_config.LMConfig,
    prompt: APIInput,
) -> str:
    """Dispatch `prompt` to the provider named in `lm_config` and return
    the generated text. (The parameter shadows the imported `lm_config`
    module inside the function body.)"""
    global model
    global tokenizer

    response: str
    if lm_config.provider == "openai":
        if lm_config.mode == "chat":
            # Chat mode expects a list of role/content message dicts.
            assert isinstance(prompt, list)
            response = generate_from_openai_chat_completion(
                messages=prompt,
                model=lm_config.model,
                temperature=lm_config.gen_config["temperature"],
                top_p=lm_config.gen_config["top_p"],
                context_length=lm_config.gen_config["context_length"],
                max_tokens=lm_config.gen_config["max_tokens"],
                stop_token=None,
            )
        elif lm_config.mode == "completion":
            # Completion mode expects a single prompt string.
            assert isinstance(prompt, str)
            response = generate_from_openai_completion(
                prompt=prompt,
                engine=lm_config.model,
                temperature=lm_config.gen_config["temperature"],
                max_tokens=lm_config.gen_config["max_tokens"],
                top_p=lm_config.gen_config["top_p"],
                stop_token=lm_config.gen_config["stop_token"],
            )
        else:
            raise ValueError(
                f"OpenAI models do not support mode {lm_config.mode}"
            )
    elif lm_config.provider == "huggingface":
        # HuggingFace text-generation endpoints take a single prompt string.
        assert isinstance(prompt, str)
        response = generate_from_huggingface_completion(
            prompt=prompt,
            model_endpoint=lm_config.gen_config["model_endpoint"],
            temperature=lm_config.gen_config["temperature"],
            top_p=lm_config.gen_config["top_p"],
            stop_sequences=lm_config.gen_config["stop_sequences"],
            max_new_tokens=lm_config.gen_config["max_new_tokens"],
        )
    else:
        raise NotImplementedError(
            f"Provider {lm_config.provider} not implemented"
        )
    return response
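

# Usage sketch (illustrative only, not part of the original module). The
# LMConfig constructor arguments below are assumptions about its interface,
# chosen to match the gen_config keys the chat branch reads:
#
#   config = lm_config.LMConfig(
#       provider="openai",
#       model="gpt-4",
#       mode="chat",
#       gen_config={
#           "temperature": 0.0,
#           "top_p": 1.0,
#           "context_length": 8192,
#           "max_tokens": 512,
#       },
#   )
#   answer = call_llm(config, [{"role": "user", "content": "Hello"}])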