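"""Dispatch prompts to the configured LLM backend.

Routes a prompt to OpenAI, Hugging Face, Google Gemini, or a generic
API/finetune backend based on the provider and mode in an LMConfig.
"""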
import argparse
from typing import Any

try:
    from vertexai.preview.generative_models import Image

    from llms import generate_from_gemini_completion
except ImportError:
    print(
        "Google Cloud not set up, skipping import of "
        "vertexai.preview.generative_models.Image and "
        "llms.generate_from_gemini_completion"
    )

from llms import (
    generate_from_huggingface_completion,
    generate_from_openai_chat_completion,
    generate_from_openai_completion,
    generate_with_api,
    lm_config,
)

APIInput = str | list[Any] | dict[str, Any]

def call_llm(
    lm_config: lm_config.LMConfig,
    prompt: APIInput,
) -> str:
    """Send the prompt to the provider/mode named in lm_config and return the text response."""
    response: str
    if lm_config.provider == "openai":
        if lm_config.mode == "chat":
            # Chat mode expects a list of message dicts.
            assert isinstance(prompt, list)
            response = generate_from_openai_chat_completion(
                messages=prompt,
                model=lm_config.model,
                temperature=lm_config.gen_config["temperature"],
                top_p=lm_config.gen_config["top_p"],
                context_length=lm_config.gen_config["context_length"],
                max_tokens=lm_config.gen_config["max_tokens"],
                stop_token=None,
            )
        elif lm_config.mode == "completion":
            # Completion mode expects a single prompt string.
            assert isinstance(prompt, str)
            response = generate_from_openai_completion(
                prompt=prompt,
                engine=lm_config.model,
                temperature=lm_config.gen_config["temperature"],
                max_tokens=lm_config.gen_config["max_tokens"],
                top_p=lm_config.gen_config["top_p"],
                stop_token=lm_config.gen_config["stop_token"],
            )
        else:
            raise ValueError(
                f"OpenAI models do not support mode {lm_config.mode}"
            )
    elif lm_config.provider == "huggingface":
        assert isinstance(prompt, str)
        response = generate_from_huggingface_completion(
            prompt=prompt,
            model_endpoint=lm_config.gen_config["model_endpoint"],
            temperature=lm_config.gen_config["temperature"],
            top_p=lm_config.gen_config["top_p"],
            stop_sequences=lm_config.gen_config["stop_sequences"],
            max_new_tokens=lm_config.gen_config["max_new_tokens"],
        )
    elif lm_config.provider == "google":
        # Gemini accepts multimodal prompts: a list of strings and/or Images.
        assert isinstance(prompt, list)
        assert all(isinstance(p, (str, Image)) for p in prompt)
        response = generate_from_gemini_completion(
            prompt=prompt,
            engine=lm_config.model,
            temperature=lm_config.gen_config["temperature"],
            max_tokens=lm_config.gen_config["max_tokens"],
            top_p=lm_config.gen_config["top_p"],
        )
    elif lm_config.provider in ["api", "finetune"]:
        args = {
            "temperature": lm_config.gen_config["temperature"],  # openai, gemini, claude
            "max_tokens": lm_config.gen_config["max_tokens"],  # openai, gemini, claude
            "top_k": lm_config.gen_config["top_p"],  # qwen
        }
        response = generate_with_api(prompt, lm_config.model, args)
    else:
        raise NotImplementedError(
            f"Provider {lm_config.provider} not implemented"
        )

    return response
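

# Hedged usage sketch (not part of the original module): one way call_llm might be
# invoked for an OpenAI chat model. The LMConfig constructor arguments and the
# gen_config keys below are assumptions inferred from the fields read above, and
# running this requires valid OpenAI credentials in the environment.
if __name__ == "__main__":
    example_config = lm_config.LMConfig(
        provider="openai",
        model="gpt-3.5-turbo",
        mode="chat",
    )
    example_config.gen_config.update(
        {
            "temperature": 0.7,
            "top_p": 0.9,
            "context_length": 4096,
            "max_tokens": 256,
        }
    )
    messages = [{"role": "user", "content": "Say hello."}]
    print(call_llm(example_config, messages))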