add open LMMs evaluation
This commit is contained in:
parent
8cac24e27d
commit
64b8e24e0a
|
@ -6,26 +6,8 @@ gpt-4o-2024-05-13:
|
||||||
model: "gpt-4o-2024-05-13"
|
model: "gpt-4o-2024-05-13"
|
||||||
max_tokens: 512
|
max_tokens: 512
|
||||||
|
|
||||||
gpt-3.5-turbo-0613:
|
glm-4v:
|
||||||
import: "./openai-chat.yaml"
|
import: "./finetuned_agent.yaml"
|
||||||
parameters:
|
parameters:
|
||||||
name: "gpt-3.5-turbo-0613"
|
name: "glm-4v"
|
||||||
body:
|
url: "http://"
|
||||||
model: "gpt-3.5-turbo-0613"
|
|
||||||
max_tokens: 512
|
|
||||||
|
|
||||||
text-davinci-003:
|
|
||||||
import: "./openai-text.yaml"
|
|
||||||
parameters:
|
|
||||||
name: "text-davinci-003"
|
|
||||||
body:
|
|
||||||
model: "text-davinci-003"
|
|
||||||
max_tokens: 512
|
|
||||||
|
|
||||||
text-davinci-002:
|
|
||||||
import: "./openai-text.yaml"
|
|
||||||
parameters:
|
|
||||||
name: "text-davinci-002"
|
|
||||||
body:
|
|
||||||
model: "text-davinci-002"
|
|
||||||
max_tokens: 512
|
|
||||||
|
|
9
configs/agents/finetuned_agent.yaml
Normal file
9
configs/agents/finetuned_agent.yaml
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
module: "src.client.agents.HTTPAgent"
|
||||||
|
parameters:
|
||||||
|
prompter:
|
||||||
|
name: "prompt_string"
|
||||||
|
args:
|
||||||
|
suffix: "<|assistant|>\n"
|
||||||
|
agent_format: "<|assistant|>\n{content}\n\n"
|
||||||
|
text_context_limit: 9776
|
||||||
|
return_format: "{response[response]}"
|
|
@ -1,7 +1,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
import math
|
||||||
import requests
|
import requests
|
||||||
from urllib3.exceptions import InsecureRequestWarning
|
from urllib3.exceptions import InsecureRequestWarning
|
||||||
|
|
||||||
|
@ -102,26 +102,78 @@ class Prompter:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prompt_string(
|
def prompt_string(
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
suffix: str = "AGENT:",
|
suffix: str = "<|agent|>\n",
|
||||||
system_format: str = "SYSTEM: {content}\n\n###\n\n",
|
system_format: str = "<|system|>\n{content}\n\n",
|
||||||
user_format: str = "USER: {content}\n\n",
|
user_format: str = "<|user|>\n{content}\n\n",
|
||||||
agent_format: str = "AGENT: {content}\n\n",
|
agent_format: str = "<|agent|>\n{content}\n\n",
|
||||||
prompt_key: str = "prompt",
|
prompt_key: str = "prompt",
|
||||||
|
image_key: str = "image",
|
||||||
|
text_context_limit: int = 9776
|
||||||
):
|
):
|
||||||
# todo (YG): current it assumes the content is always a string, but it can be a list for multimodal support
|
def prompter(messages: List[Dict]):
|
||||||
def prompter(messages: List[Dict[str, str]]):
|
nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key, image_key, text_context_limit
|
||||||
nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key
|
|
||||||
prompt = prefix
|
def text_token_estimation(text: str):
|
||||||
for item in messages:
|
token_count = 0
|
||||||
|
text = text.replace("\n\n", "\n").replace("```", "`").replace("##", "#")
|
||||||
|
for char in text:
|
||||||
|
if not char.isalnum() and char != " ":
|
||||||
|
token_count += 1
|
||||||
|
words = text.split()
|
||||||
|
for word in words:
|
||||||
|
char_count = 0
|
||||||
|
for char in word:
|
||||||
|
if char.isalnum():
|
||||||
|
char_count += 1
|
||||||
|
token_count += math.ceil(char_count / 6)
|
||||||
|
return token_count
|
||||||
|
|
||||||
|
def item_to_prompt(item: Dict):
|
||||||
|
prompt, images = "", []
|
||||||
if item["role"] == "system":
|
if item["role"] == "system":
|
||||||
prompt += system_format.format(content=item["content"])
|
prompt += system_format.format(content=item["content"])
|
||||||
elif item["role"] == "user":
|
elif item["role"] == "user":
|
||||||
prompt += user_format.format(content=item["content"])
|
if isinstance(item["content"], str):
|
||||||
|
prompt += user_format.format(content=item["content"])
|
||||||
|
elif isinstance(item["content"], list):
|
||||||
|
text_str = ""
|
||||||
|
for content in item["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
text_str += content["text"]
|
||||||
|
else:
|
||||||
|
images.append(content["image_url"]["url"].split("file://")[-1])
|
||||||
|
prompt += user_format.format(content=text_str)
|
||||||
else:
|
else:
|
||||||
prompt += agent_format.format(content=item["content"])
|
prompt += agent_format.format(content=item["content"])
|
||||||
|
return prompt, images
|
||||||
|
|
||||||
|
prompt = prefix
|
||||||
|
images = []
|
||||||
|
for item in messages[:5]:
|
||||||
|
item_prompt, item_images = item_to_prompt(item)
|
||||||
|
prompt += item_prompt
|
||||||
|
images += item_images
|
||||||
|
if len(messages) > 5:
|
||||||
|
prompt_tail, images_tail = item_to_prompt(messages[-1])
|
||||||
|
for index in range(len(messages)-2, 4, -2):
|
||||||
|
agent_item = messages[index]
|
||||||
|
agent_prompt, agent_images = item_to_prompt(agent_item)
|
||||||
|
user_item = messages[index-1]
|
||||||
|
user_prompt, user_images = item_to_prompt(user_item)
|
||||||
|
if text_token_estimation(prompt + user_prompt + agent_prompt + prompt_tail) > text_context_limit:
|
||||||
|
prompt_tail = "\n\n** Earlier trajectory has been truncated **\n\n" + prompt_tail
|
||||||
|
break
|
||||||
|
prompt_tail = user_prompt + agent_prompt + prompt_tail
|
||||||
|
images_tail = user_images + agent_images + images_tail
|
||||||
|
prompt += prompt_tail
|
||||||
|
images += images_tail
|
||||||
prompt += suffix
|
prompt += suffix
|
||||||
print(prompt)
|
if len(images) != 1:
|
||||||
return {prompt_key: prompt}
|
raise Exception("Only one image is supported")
|
||||||
|
return [
|
||||||
|
{prompt_key: prompt},
|
||||||
|
{image_key: open(images[0], "rb")}
|
||||||
|
]
|
||||||
|
|
||||||
return prompter
|
return prompter
|
||||||
|
|
||||||
|
@ -175,6 +227,7 @@ class HTTPAgent(AgentClient):
|
||||||
self,
|
self,
|
||||||
url,
|
url,
|
||||||
proxies=None,
|
proxies=None,
|
||||||
|
data=None,
|
||||||
body=None,
|
body=None,
|
||||||
headers=None,
|
headers=None,
|
||||||
return_format="{response}",
|
return_format="{response}",
|
||||||
|
@ -185,9 +238,11 @@ class HTTPAgent(AgentClient):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.proxies = proxies or {}
|
self.proxies = proxies or {}
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
|
self.data = data or {}
|
||||||
self.body = body or {}
|
self.body = body or {}
|
||||||
self.return_format = return_format
|
self.return_format = return_format
|
||||||
self.prompter = Prompter.get_prompter(prompter)
|
self.prompter = Prompter.get_prompter(prompter)
|
||||||
|
self.prompter_type = prompter.get("name", "role_content_dict")
|
||||||
if not self.url:
|
if not self.url:
|
||||||
raise Exception("Please set 'url' parameter")
|
raise Exception("Please set 'url' parameter")
|
||||||
|
|
||||||
|
@ -195,15 +250,27 @@ class HTTPAgent(AgentClient):
|
||||||
return self.prompter(history)
|
return self.prompter(history)
|
||||||
|
|
||||||
def inference(self, history: List[dict]) -> str:
|
def inference(self, history: List[dict]) -> str:
|
||||||
history = replace_image_url(history, keep_path=False, throw_details=False)
|
if self.prompter_type == "role_content_dict":
|
||||||
|
history = replace_image_url(history, keep_path=False, throw_details=False)
|
||||||
|
else:
|
||||||
|
history = replace_image_url(history, keep_path=True, throw_details=True)
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
try:
|
try:
|
||||||
body = self.body.copy()
|
if self.prompter_type == "role_content_dict":
|
||||||
body.update(self._handle_history(history))
|
body = self.body.copy()
|
||||||
with no_ssl_verification():
|
body.update(self._handle_history(history))
|
||||||
resp = requests.post(
|
with no_ssl_verification():
|
||||||
self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=120
|
resp = requests.post(
|
||||||
)
|
self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=180
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
messages = self._handle_history(history)
|
||||||
|
data = self.data.copy()
|
||||||
|
data.update(messages[0])
|
||||||
|
with no_ssl_verification():
|
||||||
|
resp = requests.post(
|
||||||
|
self.url, data=data, files=messages[1], headers=self.headers, proxies=self.proxies, timeout=180
|
||||||
|
)
|
||||||
# print(resp.status_code, resp.text)
|
# print(resp.status_code, resp.text)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
# print(resp.text)
|
# print(resp.text)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user