add open LMMs evaluation

mistyreed63849 2024-08-28 09:47:21 +08:00
parent 8cac24e27d
commit 64b8e24e0a
3 changed files with 100 additions and 42 deletions

View File

@@ -6,26 +6,8 @@ gpt-4o-2024-05-13:
       model: "gpt-4o-2024-05-13"
       max_tokens: 512
-gpt-3.5-turbo-0613:
-  import: "./openai-chat.yaml"
-  parameters:
-    name: "gpt-3.5-turbo-0613"
-    body:
-      model: "gpt-3.5-turbo-0613"
-      max_tokens: 512
-text-davinci-003:
-  import: "./openai-text.yaml"
-  parameters:
-    name: "text-davinci-003"
-    body:
-      model: "text-davinci-003"
-      max_tokens: 512
-text-davinci-002:
-  import: "./openai-text.yaml"
-  parameters:
-    name: "text-davinci-002"
-    body:
-      model: "text-davinci-002"
-      max_tokens: 512
+glm-4v:
+  import: "./finetuned_agent.yaml"
+  parameters:
+    name: "glm-4v"
+    url: "http://"

View File

@@ -0,0 +1,9 @@
+module: "src.client.agents.HTTPAgent"
+parameters:
+  prompter:
+    name: "prompt_string"
+    args:
+      suffix: "<|assistant|>\n"
+      agent_format: "<|assistant|>\n{content}\n\n"
+      text_context_limit: 9776
+  return_format: "{response[response]}"
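The new finetuned_agent.yaml binds the glm-4v entry to the HTTPAgent client with the prompt_string prompter. Its return_format template "{response[response]}" implies that the agent formats the parsed JSON reply under the name response; the consuming code is not part of this diff, so the sketch below only demonstrates the str.format mechanics under that assumption.

# Sketch (assumption, not shown in this diff): how a return_format of
# "{response[response]}" would extract the text from a parsed JSON reply
# when applied as str.format(response=resp.json()).
reply = {"response": "click(130, 254)"}      # hypothetical server reply
return_format = "{response[response]}"
print(return_format.format(response=reply))  # -> click(130, 254)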

View File

@@ -1,7 +1,7 @@
 import contextlib
 import time
 import warnings
+import math
 import requests
 from urllib3.exceptions import InsecureRequestWarning
@@ -102,26 +102,78 @@ class Prompter:
     @staticmethod
     def prompt_string(
         prefix: str = "",
-        suffix: str = "AGENT:",
-        system_format: str = "SYSTEM: {content}\n\n###\n\n",
-        user_format: str = "USER: {content}\n\n",
-        agent_format: str = "AGENT: {content}\n\n",
+        suffix: str = "<|agent|>\n",
+        system_format: str = "<|system|>\n{content}\n\n",
+        user_format: str = "<|user|>\n{content}\n\n",
+        agent_format: str = "<|agent|>\n{content}\n\n",
         prompt_key: str = "prompt",
+        image_key: str = "image",
+        text_context_limit: int = 9776
     ):
-        # todo (YG): current it assumes the content is always a string, but it can be a list for multimodal support
-        def prompter(messages: List[Dict[str, str]]):
-            nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key
-            prompt = prefix
-            for item in messages:
+        def prompter(messages: List[Dict]):
+            nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key, image_key, text_context_limit
+
+            def text_token_estimation(text: str):
+                token_count = 0
+                text = text.replace("\n\n", "\n").replace("```", "`").replace("##", "#")
+                for char in text:
+                    if not char.isalnum() and char != " ":
+                        token_count += 1
+                words = text.split()
+                for word in words:
+                    char_count = 0
+                    for char in word:
+                        if char.isalnum():
+                            char_count += 1
+                    token_count += math.ceil(char_count / 6)
+                return token_count
+
+            def item_to_prompt(item: Dict):
+                prompt, images = "", []
                 if item["role"] == "system":
                     prompt += system_format.format(content=item["content"])
                 elif item["role"] == "user":
-                    prompt += user_format.format(content=item["content"])
+                    if isinstance(item["content"], str):
+                        prompt += user_format.format(content=item["content"])
+                    elif isinstance(item["content"], list):
+                        text_str = ""
+                        for content in item["content"]:
+                            if content["type"] == "text":
+                                text_str += content["text"]
+                            else:
+                                images.append(content["image_url"]["url"].split("file://")[-1])
+                        prompt += user_format.format(content=text_str)
                 else:
                     prompt += agent_format.format(content=item["content"])
+                return prompt, images
+
+            prompt = prefix
+            images = []
+            for item in messages[:5]:
+                item_prompt, item_images = item_to_prompt(item)
+                prompt += item_prompt
+                images += item_images
+            if len(messages) > 5:
+                prompt_tail, images_tail = item_to_prompt(messages[-1])
+                for index in range(len(messages) - 2, 4, -2):
+                    agent_item = messages[index]
+                    agent_prompt, agent_images = item_to_prompt(agent_item)
+                    user_item = messages[index - 1]
+                    user_prompt, user_images = item_to_prompt(user_item)
+                    if text_token_estimation(prompt + user_prompt + agent_prompt + prompt_tail) > text_context_limit:
+                        prompt_tail = "\n\n** Earlier trajectory has been truncated **\n\n" + prompt_tail
+                        break
+                    prompt_tail = user_prompt + agent_prompt + prompt_tail
+                    images_tail = user_images + agent_images + images_tail
+                prompt += prompt_tail
+                images += images_tail
             prompt += suffix
-            print(prompt)
-            return {prompt_key: prompt}
+            if len(images) != 1:
+                raise Exception("Only one image is supported")
+            return [
+                {prompt_key: prompt},
+                {image_key: open(images[0], "rb")}
+            ]
         return prompter
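The rewritten prompt_string prompter flattens a multimodal history into one text prompt plus exactly one image file. It always keeps the first five messages and the latest message, then prepends earlier user/agent pairs from newest to oldest until text_token_estimation (roughly one token per non-alphanumeric, non-space character plus ceil(alphanumeric characters / 6) per word) would exceed text_context_limit, at which point it inserts the "** Earlier trajectory has been truncated **" marker. A minimal usage sketch, assuming Prompter is importable from src.client.agents.http_agent (the module path is a guess) and using a throwaway image file so the final open() succeeds:

# Minimal sketch of driving the new prompter directly; the import path is an
# assumption, and the temp file only exists so open(images[0], "rb") works.
import tempfile

from src.client.agents.http_agent import Prompter  # assumed module path

with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
    image_path = f.name

prompter = Prompter.prompt_string(
    suffix="<|assistant|>\n",
    agent_format="<|assistant|>\n{content}\n\n",
    text_context_limit=9776,
)
history = [
    {"role": "system", "content": "You are a GUI agent."},
    {"role": "user", "content": [
        {"type": "text", "text": "Click the search box."},
        {"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
    ]},
]
text_part, file_part = prompter(history)
# text_part -> {"prompt": "<|system|>\n...\n\n<|user|>\n...\n\n<|assistant|>\n"}
# file_part -> {"image": <open handle on image_path>}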
@@ -175,6 +227,7 @@ class HTTPAgent(AgentClient):
         self,
         url,
         proxies=None,
+        data=None,
         body=None,
         headers=None,
         return_format="{response}",
@@ -185,9 +238,11 @@ class HTTPAgent(AgentClient):
         self.url = url
         self.proxies = proxies or {}
         self.headers = headers or {}
+        self.data = data or {}
         self.body = body or {}
         self.return_format = return_format
         self.prompter = Prompter.get_prompter(prompter)
+        self.prompter_type = prompter.get("name", "role_content_dict")

         if not self.url:
             raise Exception("Please set 'url' parameter")
@@ -195,14 +250,26 @@ class HTTPAgent(AgentClient):
         return self.prompter(history)

     def inference(self, history: List[dict]) -> str:
-        history = replace_image_url(history, keep_path=False, throw_details=False)
+        if self.prompter_type == "role_content_dict":
+            history = replace_image_url(history, keep_path=False, throw_details=False)
+        else:
+            history = replace_image_url(history, keep_path=True, throw_details=True)
         for _ in range(5):
             try:
-                body = self.body.copy()
-                body.update(self._handle_history(history))
-                with no_ssl_verification():
-                    resp = requests.post(
-                        self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=120
-                    )
+                if self.prompter_type == "role_content_dict":
+                    body = self.body.copy()
+                    body.update(self._handle_history(history))
+                    with no_ssl_verification():
+                        resp = requests.post(
+                            self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=180
+                        )
+                else:
+                    messages = self._handle_history(history)
+                    data = self.data.copy()
+                    data.update(messages[0])
+                    with no_ssl_verification():
+                        resp = requests.post(
+                            self.url, data=data, files=messages[1], headers=self.headers, proxies=self.proxies, timeout=180
+                        )
                 # print(resp.status_code, resp.text)
                 if resp.status_code != 200:
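inference() now branches on prompter_type: role_content_dict prompters keep the original JSON POST, while prompt_string prompters send the prompt as form fields and the single image as a multipart file upload. A rough sketch of the two request shapes, with placeholder names and values rather than anything taken from this repository:

# Illustrative sketch of the two request shapes used above; URLs, payloads,
# and the helper names are placeholders.
import io

import requests

def post_json(url: str, body: dict) -> requests.Response:
    # role_content_dict prompters: the merged body goes out as a JSON document.
    return requests.post(url, json=body, timeout=180)

def post_multipart(url: str, data: dict, image_bytes: bytes) -> requests.Response:
    # prompt_string prompters: text fields travel as form data and the single
    # image as a multipart upload keyed by the configured image_key ("image").
    files = {"image": ("screenshot.png", io.BytesIO(image_bytes), "image/png")}
    return requests.post(url, data=data, files=files, timeout=180)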