webrl/VAB-WebArena-Lite/new/lm_config.py
2024-10-20 00:10:34 +08:00

58 lines
2.1 KiB
Python

"""Config for language models."""
from __future__ import annotations
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class LMConfig:
"""A config for a language model.
Attributes:
provider: The name of the API provider.
model: The name of the model.
model_cls: The Python class corresponding to the model, mostly for
Hugging Face transformers.
tokenizer_cls: The Python class corresponding to the tokenizer, mostly
for Hugging Face transformers.
mode: The mode of the API calls, e.g., "chat" or "generation".
"""
provider: str
model: str
model_cls: type | None = None
tokenizer_cls: type | None = None
mode: str | None = None
gen_config: dict[str, Any] = dataclasses.field(default_factory=dict)
def construct_llm_config(args: argparse.Namespace) -> LMConfig:
llm_config = LMConfig(
provider=args.provider, model=args.model, mode=args.mode
)
if args.provider in ["openai", "google", "api", "finetune"]:
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["context_length"] = args.context_length
llm_config.gen_config["max_tokens"] = args.max_tokens
llm_config.gen_config["stop_token"] = args.stop_token
llm_config.gen_config["max_obs_length"] = args.max_obs_length
llm_config.gen_config["max_retry"] = args.max_retry
elif args.provider == "huggingface":
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["max_new_tokens"] = args.max_tokens
llm_config.gen_config["stop_sequences"] = (
[args.stop_token] if args.stop_token else None
)
llm_config.gen_config["max_obs_length"] = args.max_obs_length
llm_config.gen_config["model_endpoint"] = args.model_endpoint
llm_config.gen_config["max_retry"] = args.max_retry
else:
raise NotImplementedError(f"provider {args.provider} not implemented")
return llm_config