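"""Route a prompt to the configured LLM provider (OpenAI chat/completion or a
HuggingFace inference endpoint) and return the generated text."""
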
import argparse
from typing import Any

from transformers import AutoTokenizer, AutoModel

from llms import (
    generate_from_huggingface_completion,
    generate_from_openai_chat_completion,
    generate_from_openai_completion,
    lm_config,
)

# A prompt is either a raw string, a list of chat messages, or a structured payload.
APIInput = str | list[Any] | dict[str, Any]

# Module-level model and tokenizer handles, shared across call_llm invocations
# (declared global inside call_llm below).
model = None
tokenizer = None


def call_llm(
    lm_config: lm_config.LMConfig,
    prompt: APIInput,
) -> str:
    """Dispatch `prompt` to the provider configured in `lm_config` and return the response text."""
    global model
    global tokenizer

    response: str

    if lm_config.provider == "openai":
        if lm_config.mode == "chat":
            # Chat mode expects a list of message dicts.
            assert isinstance(prompt, list)
            response = generate_from_openai_chat_completion(
                messages=prompt,
                model=lm_config.model,
                temperature=lm_config.gen_config["temperature"],
                top_p=lm_config.gen_config["top_p"],
                context_length=lm_config.gen_config["context_length"],
                max_tokens=lm_config.gen_config["max_tokens"],
                stop_token=None,
            )
        elif lm_config.mode == "completion":
            # Completion mode expects a plain string prompt.
            assert isinstance(prompt, str)
            response = generate_from_openai_completion(
                prompt=prompt,
                engine=lm_config.model,
                temperature=lm_config.gen_config["temperature"],
                max_tokens=lm_config.gen_config["max_tokens"],
                top_p=lm_config.gen_config["top_p"],
                stop_token=lm_config.gen_config["stop_token"],
            )
        else:
            raise ValueError(
                f"OpenAI models do not support mode {lm_config.mode}"
            )
    elif lm_config.provider == "huggingface":
        # HuggingFace generation goes through a remote inference endpoint.
        assert isinstance(prompt, str)
        response = generate_from_huggingface_completion(
            prompt=prompt,
            model_endpoint=lm_config.gen_config["model_endpoint"],
            temperature=lm_config.gen_config["temperature"],
            top_p=lm_config.gen_config["top_p"],
            stop_sequences=lm_config.gen_config["stop_sequences"],
            max_new_tokens=lm_config.gen_config["max_new_tokens"],
        )
    else:
        raise NotImplementedError(
            f"Provider {lm_config.provider} not implemented"
        )

    return response
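

# Usage sketch: one way call_llm might be invoked. The LMConfig construction
# below is hypothetical; it assumes only the provider/model/mode/gen_config
# attributes read above. Adjust to the actual definition in llms.lm_config.
if __name__ == "__main__":
    config = lm_config.LMConfig(  # hypothetical constructor arguments
        provider="openai",
        model="gpt-3.5-turbo",
        mode="chat",
        gen_config={
            "temperature": 0.7,
            "top_p": 0.9,
            "context_length": 4096,
            "max_tokens": 256,
        },
    )
    print(call_llm(config, [{"role": "user", "content": "Hello!"}]))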