add open LMMs evaluation
This commit is contained in:
parent
8cac24e27d
commit
64b8e24e0a
|
@ -6,26 +6,8 @@ gpt-4o-2024-05-13:
|
|||
model: "gpt-4o-2024-05-13"
|
||||
max_tokens: 512
|
||||
|
||||
gpt-3.5-turbo-0613:
|
||||
import: "./openai-chat.yaml"
|
||||
glm-4v:
|
||||
import: "./finetuned_agent.yaml"
|
||||
parameters:
|
||||
name: "gpt-3.5-turbo-0613"
|
||||
body:
|
||||
model: "gpt-3.5-turbo-0613"
|
||||
max_tokens: 512
|
||||
|
||||
text-davinci-003:
|
||||
import: "./openai-text.yaml"
|
||||
parameters:
|
||||
name: "text-davinci-003"
|
||||
body:
|
||||
model: "text-davinci-003"
|
||||
max_tokens: 512
|
||||
|
||||
text-davinci-002:
|
||||
import: "./openai-text.yaml"
|
||||
parameters:
|
||||
name: "text-davinci-002"
|
||||
body:
|
||||
model: "text-davinci-002"
|
||||
max_tokens: 512
|
||||
name: "glm-4v"
|
||||
url: "http://"
|
||||
|
|
9
configs/agents/finetuned_agent.yaml
Normal file
9
configs/agents/finetuned_agent.yaml
Normal file
|
@ -0,0 +1,9 @@
|
|||
module: "src.client.agents.HTTPAgent"
|
||||
parameters:
|
||||
prompter:
|
||||
name: "prompt_string"
|
||||
args:
|
||||
suffix: "<|assistant|>\n"
|
||||
agent_format: "<|assistant|>\n{content}\n\n"
|
||||
text_context_limit: 9776
|
||||
return_format: "{response[response]}"
|
|
@ -1,7 +1,7 @@
|
|||
import contextlib
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import math
|
||||
import requests
|
||||
from urllib3.exceptions import InsecureRequestWarning
|
||||
|
||||
|
@ -102,26 +102,78 @@ class Prompter:
|
|||
@staticmethod
|
||||
def prompt_string(
|
||||
prefix: str = "",
|
||||
suffix: str = "AGENT:",
|
||||
system_format: str = "SYSTEM: {content}\n\n###\n\n",
|
||||
user_format: str = "USER: {content}\n\n",
|
||||
agent_format: str = "AGENT: {content}\n\n",
|
||||
suffix: str = "<|agent|>\n",
|
||||
system_format: str = "<|system|>\n{content}\n\n",
|
||||
user_format: str = "<|user|>\n{content}\n\n",
|
||||
agent_format: str = "<|agent|>\n{content}\n\n",
|
||||
prompt_key: str = "prompt",
|
||||
image_key: str = "image",
|
||||
text_context_limit: int = 9776
|
||||
):
|
||||
# TODO (YG): currently this assumes the content is always a string, but it can be a list for multimodal support
|
||||
def prompter(messages: List[Dict[str, str]]):
|
||||
nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key
|
||||
prompt = prefix
|
||||
for item in messages:
|
||||
def prompter(messages: List[Dict]):
|
||||
nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key, image_key, text_context_limit
|
||||
|
||||
def text_token_estimation(text: str):
|
||||
token_count = 0
|
||||
text = text.replace("\n\n", "\n").replace("```", "`").replace("##", "#")
|
||||
for char in text:
|
||||
if not char.isalnum() and char != " ":
|
||||
token_count += 1
|
||||
words = text.split()
|
||||
for word in words:
|
||||
char_count = 0
|
||||
for char in word:
|
||||
if char.isalnum():
|
||||
char_count += 1
|
||||
token_count += math.ceil(char_count / 6)
|
||||
return token_count
|
||||
|
||||
def item_to_prompt(item: Dict):
|
||||
prompt, images = "", []
|
||||
if item["role"] == "system":
|
||||
prompt += system_format.format(content=item["content"])
|
||||
elif item["role"] == "user":
|
||||
prompt += user_format.format(content=item["content"])
|
||||
if isinstance(item["content"], str):
|
||||
prompt += user_format.format(content=item["content"])
|
||||
elif isinstance(item["content"], list):
|
||||
text_str = ""
|
||||
for content in item["content"]:
|
||||
if content["type"] == "text":
|
||||
text_str += content["text"]
|
||||
else:
|
||||
images.append(content["image_url"]["url"].split("file://")[-1])
|
||||
prompt += user_format.format(content=text_str)
|
||||
else:
|
||||
prompt += agent_format.format(content=item["content"])
|
||||
return prompt, images
|
||||
|
||||
prompt = prefix
|
||||
images = []
|
||||
for item in messages[:5]:
|
||||
item_prompt, item_images = item_to_prompt(item)
|
||||
prompt += item_prompt
|
||||
images += item_images
|
||||
if len(messages) > 5:
|
||||
prompt_tail, images_tail = item_to_prompt(messages[-1])
|
||||
for index in range(len(messages)-2, 4, -2):
|
||||
agent_item = messages[index]
|
||||
agent_prompt, agent_images = item_to_prompt(agent_item)
|
||||
user_item = messages[index-1]
|
||||
user_prompt, user_images = item_to_prompt(user_item)
|
||||
if text_token_estimation(prompt + user_prompt + agent_prompt + prompt_tail) > text_context_limit:
|
||||
prompt_tail = "\n\n** Earlier trajectory has been truncated **\n\n" + prompt_tail
|
||||
break
|
||||
prompt_tail = user_prompt + agent_prompt + prompt_tail
|
||||
images_tail = user_images + agent_images + images_tail
|
||||
prompt += prompt_tail
|
||||
images += images_tail
|
||||
prompt += suffix
|
||||
print(prompt)
|
||||
return {prompt_key: prompt}
|
||||
if len(images) != 1:
|
||||
raise Exception("Only one image is supported")
|
||||
return [
|
||||
{prompt_key: prompt},
|
||||
{image_key: open(images[0], "rb")}
|
||||
]
|
||||
|
||||
return prompter
|
||||
|
||||
|
@ -175,6 +227,7 @@ class HTTPAgent(AgentClient):
|
|||
self,
|
||||
url,
|
||||
proxies=None,
|
||||
data=None,
|
||||
body=None,
|
||||
headers=None,
|
||||
return_format="{response}",
|
||||
|
@ -185,9 +238,11 @@ class HTTPAgent(AgentClient):
|
|||
self.url = url
|
||||
self.proxies = proxies or {}
|
||||
self.headers = headers or {}
|
||||
self.data = data or {}
|
||||
self.body = body or {}
|
||||
self.return_format = return_format
|
||||
self.prompter = Prompter.get_prompter(prompter)
|
||||
self.prompter_type = prompter.get("name", "role_content_dict")
|
||||
if not self.url:
|
||||
raise Exception("Please set 'url' parameter")
|
||||
|
||||
|
@ -195,15 +250,27 @@ class HTTPAgent(AgentClient):
|
|||
return self.prompter(history)
|
||||
|
||||
def inference(self, history: List[dict]) -> str:
|
||||
history = replace_image_url(history, keep_path=False, throw_details=False)
|
||||
if self.prompter_type == "role_content_dict":
|
||||
history = replace_image_url(history, keep_path=False, throw_details=False)
|
||||
else:
|
||||
history = replace_image_url(history, keep_path=True, throw_details=True)
|
||||
for _ in range(5):
|
||||
try:
|
||||
body = self.body.copy()
|
||||
body.update(self._handle_history(history))
|
||||
with no_ssl_verification():
|
||||
resp = requests.post(
|
||||
self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=120
|
||||
)
|
||||
if self.prompter_type == "role_content_dict":
|
||||
body = self.body.copy()
|
||||
body.update(self._handle_history(history))
|
||||
with no_ssl_verification():
|
||||
resp = requests.post(
|
||||
self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=180
|
||||
)
|
||||
else:
|
||||
messages = self._handle_history(history)
|
||||
data = self.data.copy()
|
||||
data.update(messages[0])
|
||||
with no_ssl_verification():
|
||||
resp = requests.post(
|
||||
self.url, data=data, files=messages[1], headers=self.headers, proxies=self.proxies, timeout=180
|
||||
)
|
||||
# print(resp.status_code, resp.text)
|
||||
if resp.status_code != 200:
|
||||
# print(resp.text)
|
||||
|
|
Loading…
Reference in New Issue
Block a user