add openai LMMs evaluation
This commit is contained in:
parent
7e334299f6
commit
8cac24e27d
|
@ -1,6 +1,6 @@
|
|||
module: src.client.agents.HTTPAgent
|
||||
parameters:
|
||||
url: https://one-api.glm.ai/v1/chat/completions # https://api.openai.com/v1/chat/completions
|
||||
url: https://api.openai.com/v1/chat/completions
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer
|
||||
|
|
|
@ -3,22 +3,20 @@ import: definition.yaml
|
|||
concurrency:
|
||||
task:
|
||||
# css-std: 5
|
||||
omnigibson-std: 6
|
||||
# minecraft-std: 6
|
||||
# omnigibson-std: 4
|
||||
minecraft-std: 4
|
||||
agent:
|
||||
# gpt-3.5-turbo-0613: 1 # use this
|
||||
gpt-4o-2024-05-13: 6
|
||||
gpt-4o-2024-05-13: 4
|
||||
|
||||
assignments: # List[Assignment] | Assignment
|
||||
- agent: # "task": List[str] | str , "agent": List[str] | str
|
||||
# - gpt-3.5-turbo-0613
|
||||
- gpt-4o-2024-05-13
|
||||
task:
|
||||
# - css-std
|
||||
- omnigibson-std
|
||||
# - minecraft-std
|
||||
# - omnigibson-std
|
||||
- minecraft-std
|
||||
|
||||
# output: "outputs/{TIMESTAMP}"
|
||||
# output: "outputs/css_test"
|
||||
output: "outputs/omnigibson"
|
||||
# output: "outputs/minecraft"
|
||||
# output: "outputs/omnigibson"
|
||||
output: "outputs/minecraft"
|
||||
|
|
|
@ -3,5 +3,5 @@ definition:
|
|||
|
||||
start:
|
||||
# css-std: 1
|
||||
omnigibson-std: 6
|
||||
# minecraft-std: 6
|
||||
# omnigibson-std: 4
|
||||
minecraft-std: 4
|
||||
|
|
|
@ -31,3 +31,4 @@ minecraft-std:
|
|||
"7": 2
|
||||
data_dir: "data/minecraft"
|
||||
output_dir: "outputs/minecraft"
|
||||
docker_image: "tianjiezhang/vab_minecraft:latest"
|
||||
|
|
|
@ -23,3 +23,4 @@ omnigibson-std:
|
|||
"7": 1
|
||||
data_dir: "data/omnigibson"
|
||||
output_dir: "outputs/omnigibson"
|
||||
docker_image: "tianjiezhang/vab_omnigibson:latest"
|
||||
|
|
11
scripts/minecraft_weights.sh
Normal file
11
scripts/minecraft_weights.sh
Normal file
|
@ -0,0 +1,11 @@
|
|||
# Download the pretrained weights used by the Minecraft task.
# Run from the repository root: every path below is relative to it.

# Create one directory per model family under data/minecraft.
mkdir -p data/minecraft/mineclip
mkdir -p data/minecraft/steve1
mkdir -p data/minecraft/vpt

# VPT base model.
wget https://openaipublic.blob.core.windows.net/minecraft-rl/models/2x.model -P data/minecraft/vpt

# MineCLIP attention weights.
# NOTE(review): these are Google Drive links fetched with plain curl; confirm
# the saved files are the real weights and not an HTML interstitial page.
curl -L -o data/minecraft/mineclip/attn.pth "https://drive.google.com/uc?export=download&id=1uaZM1ZLBz2dZWcn85rZmjP7LV6Sg5PZW"

# STEVE-1 policy weights.
curl -L -o data/minecraft/steve1/steve1.weights "https://drive.google.com/uc?id=1E3fd_-H1rRZqMkUKHfiMhx-ppLLehQPI"

# STEVE-1 prior weights. Saved next to the policy weights in the directory
# created above (data/minecraft/steve1), not in the never-created data/weights.
curl -L -o data/minecraft/steve1/steve1_prior.pt "https://drive.google.com/uc?id=1OdX5wiybK8jALVfP5_dEo0CWm9BQbDES"
|
72
scripts/omnigibson_download.py
Normal file
72
scripts/omnigibson_download.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def print_user_agreement():
    """Print the BEHAVIOR Data Bundle end-user license agreement to stdout.

    The text is emitted in a single ``print`` call; the pieces below are
    concatenated without any separator, so each piece carries its own
    newlines.
    """
    agreement_parts = (
        'You are downloading dataset from Behavior-1K and OmniGibson. Here is the user agreement of OmniGibson 0.2.1\n',
        '\n\nBEHAVIOR DATA BUNDLE END USER LICENSE AGREEMENT\n',
        'Last revision: December 8, 2022\n',
        'This License Agreement is for the BEHAVIOR Data Bundle (“Data”). It works with OmniGibson (“Software”) which is a software stack licensed under the MIT License, provided in this repository: https://github.com/StanfordVL/OmniGibson. The license agreements for OmniGibson and the Data are independent. This BEHAVIOR Data Bundle contains artwork and images (“Third Party Content”) from third parties with restrictions on redistribution. It requires measures to protect the Third Party Content which we have taken such as encryption and the inclusion of restrictions on any reverse engineering and use. Recipient is granted the right to use the Data under the following terms and conditions of this License Agreement (“Agreement”):\n\n',
        '1. Use of the Data is permitted after responding "Yes" to this agreement. A decryption key will be installed automatically.\n',
        '2. Data may only be used for non-commercial academic research. You may not use a Data for any other purpose.\n',
        '3. The Data has been encrypted. You are strictly prohibited from extracting any Data from OmniGibson or reverse engineering.\n',
        '4. You may only use the Data within OmniGibson.\n',
        '5. You may not redistribute the key or any other Data or elements in whole or part.\n',
        '6. THE DATA AND SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE DATA OR SOFTWARE OR THE USE OR OTHER DEALINGS IN THE DATA OR SOFTWARE.\n\n',
    )
    print("".join(agreement_parts))
|
||||
|
||||
|
||||
def print_omniverse_agreement():
    """Print the notice that the NVIDIA Omniverse EULA must be accepted."""
    notice_parts = (
        "The NVIDIA Omniverse License Agreement (EULA) must be accepted before\n",
        "Omniverse Kit can start. The license terms for this product can be viewed at\n",
        "https://docs.omniverse.nvidia.com/app_isaacsim/common/NVIDIA_Omniverse_License_Agreement.html\n",
    )
    # Joined with no separator: every part already ends in its own newline.
    print("".join(notice_parts))
|
||||
|
||||
|
||||
def main():
    """Download every asset the OmniGibson task needs into ``data/omnigibson``.

    Each step is skipped when its target path already exists, so the script
    can be re-run after a partial download. The steps are, in order:

    1. clone the VAB OmniGibson data repository,
    2. download and unpack the GoogleNews word vectors,
    3. create the ``isaac-sim`` and ``datasets`` directories,
    4. download the dataset decryption key,
    5. download and extract ``og_dataset`` (after the user accepts the
       license shown by ``print_user_agreement``),
    6. download and extract ``og_assets``,
    7. replace the stock scenes with the ones from ``vab_omnigibson_scenes``,
    8. print the Omniverse EULA notice.

    Must be run from the repository root: all paths are relative. External
    commands (``git``, ``wget``, ``gzip``, ``tar``) are run without checking
    their exit status, as before.
    """
    # Function-scope imports: only the scene-replacement step needs them.
    import glob
    import shutil

    root_dir = "data/omnigibson"
    datasets_dir = f"{root_dir}/datasets"

    os.makedirs(root_dir, exist_ok=True)

    if not os.path.exists(f"{root_dir}/activity_definitions"):
        subprocess.run(["git", "clone", "https://www.modelscope.cn/datasets/VisualAgentBench/VAB-OmniGibson-Data.git", root_dir])

    if not os.path.exists(f"{root_dir}/GoogleNews-vectors-negative300.bin"):
        subprocess.run(["wget", "https://huggingface.co/LoganKilpatrick/GoogleNews-vectors-negative300/resolve/main/GoogleNews-vectors-negative300.bin.gz", "-O", f"{root_dir}/GoogleNews-vectors-negative300.bin.gz"])
        subprocess.run(["gzip", "-d", f"{root_dir}/GoogleNews-vectors-negative300.bin.gz"])

    os.makedirs(f"{root_dir}/isaac-sim", exist_ok=True)
    os.makedirs(datasets_dir, exist_ok=True)

    if not os.path.exists(f"{datasets_dir}/omnigibson.key"):
        # Plain-text key URL. This is exactly the string the previous
        # obfuscated one-line expression evaluated to (59 characters).
        key_url = "https://storage.googleapis.com/gibson_scenes/omnigibson.key"
        subprocess.run(["wget", key_url, "-O", f"{datasets_dir}/omnigibson.key"])

    if not os.path.exists(f"{datasets_dir}/og_dataset"):
        print("\n")
        print_user_agreement()
        # Keep asking until the user answers exactly "y".
        while input("Do you agree to the above terms for using OmniGibson dataset? [y/n]") != "y":
            print("You need to agree to the terms for using OmniGibson dataset.")
        subprocess.run(["wget", "https://storage.googleapis.com/gibson_scenes/og_dataset.tar.gz", "-O", f"{datasets_dir}/og_dataset.tar.gz"])
        # Extract from, and into, the same "datasets" directory the archive
        # was downloaded to (the old code pointed at a non-existent "dataset").
        subprocess.run(["tar", "-xvf", f"{datasets_dir}/og_dataset.tar.gz", "-C", datasets_dir])

    if not os.path.exists(f"{datasets_dir}/assets"):
        subprocess.run(["wget", "https://storage.googleapis.com/gibson_scenes/og_assets.tar.gz", "-O", f"{datasets_dir}/og_assets.tar.gz"])
        subprocess.run(["tar", "-xvf", f"{datasets_dir}/og_assets.tar.gz", "-C", datasets_dir])

    scenes_dir = f"{datasets_dir}/og_dataset/scenes"
    # The presence of this stock scene file is the marker that the scenes
    # have not been replaced yet.
    if os.path.exists(f"{scenes_dir}/Beechwood_0_int/json/Beechwood_0_int_best.json"):
        # subprocess.run with an argv list does not go through a shell, so a
        # literal "*" is never expanded; expand the wildcards in Python.
        vab_scenes = glob.glob(f"{root_dir}/vab_omnigibson_scenes/*")
        if vab_scenes:
            for stale in glob.glob(f"{scenes_dir}/*"):
                if os.path.isdir(stale) and not os.path.islink(stale):
                    shutil.rmtree(stale)
                else:
                    os.remove(stale)
            for scene in vab_scenes:
                shutil.move(scene, scenes_dir)
        else:
            # Do not delete the stock scenes when there is nothing to put in
            # their place.
            print(f"No scenes found in {root_dir}/vab_omnigibson_scenes; keeping the existing scenes.")

    print_omniverse_agreement()


# Run the download workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
@ -196,7 +196,7 @@ class HTTPAgent(AgentClient):
|
|||
|
||||
def inference(self, history: List[dict]) -> str:
|
||||
history = replace_image_url(history, keep_path=False, throw_details=False)
|
||||
for _ in range(3):
|
||||
for _ in range(5):
|
||||
try:
|
||||
body = self.body.copy()
|
||||
body.update(self._handle_history(history))
|
||||
|
|
|
@ -15,6 +15,7 @@ class Container():
|
|||
data_dir = "data/minecraft"
|
||||
vab_source_dir = ""
|
||||
max_round = 100
|
||||
docker_image = ""
|
||||
|
||||
def use_port(self):
|
||||
time.sleep(random.random())
|
||||
|
@ -83,7 +84,7 @@ class Container():
|
|||
"PYTHONPYCACHEPREFIX": "/tmp"
|
||||
}
|
||||
self.container = self.client.containers.run(
|
||||
"vab-minecraft:latest",
|
||||
self.docker_image,
|
||||
f"xvfb-run -a python main.py --task_name {task} --port {self.port} --max_round {self.max_round}",
|
||||
environment=environment,
|
||||
volumes=volumes,
|
||||
|
@ -101,6 +102,9 @@ class Container():
|
|||
|
||||
async def execute(self, session):
|
||||
while True:
|
||||
self.container.reload()
|
||||
if self.container.status == "exited":
|
||||
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR)
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect(("localhost", self.port))
|
||||
|
@ -207,7 +211,8 @@ class Container():
|
|||
def close(self):
|
||||
try:
|
||||
time.sleep(12)
|
||||
self.container.stop(timeout=36)
|
||||
if self.container.status != "exited":
|
||||
self.container.stop(timeout=24)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import asyncio
|
||||
|
||||
class Minecraft(Task):
|
||||
def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, **configs):
|
||||
def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, docker_image, **configs):
|
||||
super().__init__(**configs)
|
||||
self.vab_source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vab_minecraft_src")
|
||||
Container.available_devices = available_devices
|
||||
|
@ -29,6 +29,7 @@ class Minecraft(Task):
|
|||
os.makedirs(Container.output_dir)
|
||||
Container.max_round = max_round
|
||||
Container.vab_source_dir = self.vab_source_path
|
||||
Container.docker_image = docker_image
|
||||
with open(f"{self.vab_source_path}/jarvis/assets/tasks/formated_tasks_v1.3.json", "r") as f:
|
||||
tasks = json.load(f)
|
||||
self.tasks = [t["task"] for t in tasks]
|
||||
|
@ -40,14 +41,14 @@ class Minecraft(Task):
|
|||
try:
|
||||
container = Container(self.tasks[index])
|
||||
session.clear()
|
||||
session.inject({"role": "user", "content": SYSTEM_PROMPT})
|
||||
session.inject({"role": "system", "content": SYSTEM_PROMPT})
|
||||
while True:
|
||||
result = await container.execute(session)
|
||||
if result.status != SampleStatus.RUNNING:
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR)
|
||||
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR, result={"error": e})
|
||||
finally:
|
||||
try:
|
||||
container.close()
|
||||
|
@ -58,28 +59,29 @@ class Minecraft(Task):
|
|||
average_reward = 0
|
||||
success_count = 0
|
||||
for result in results:
|
||||
if result.result:
|
||||
if isinstance(result.result, Dict) and "success" in result.result:
|
||||
if result.result["success"]:
|
||||
success_count += 1
|
||||
average_reward += 1
|
||||
else:
|
||||
initial_reward = result.result["initial_reward"]
|
||||
final_reward = result.result["final_reward"]
|
||||
reward = (final_reward - initial_reward) / (1.0 - initial_reward)
|
||||
average_reward += max(0, final_reward)
|
||||
return {
|
||||
"total_count": len(results),
|
||||
"success_count": success_count,
|
||||
"success_rate": success_count / len(results),
|
||||
"average_reward": average_reward / len(results)
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
available_ports = [12000, 12001, 12002]
|
||||
available_devices = {"0":6, "1":6, "2":6}
|
||||
available_ports = [11000, 11001, 11002]
|
||||
available_devices = {"9":6, "1":6, "2":6}
|
||||
data_dir = "data/minecraft"
|
||||
output_dir = "outputs/minecraft"
|
||||
max_round = 100
|
||||
task = Minecraft(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, name="Minecraft-std")
|
||||
docker_image = "tianjiezhang/vab_minecraft:latest"
|
||||
task = Minecraft(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, docker_image=docker_image, name="Minecraft-std")
|
||||
print(Container.available_devices)
|
||||
print(Container.available_ports)
|
||||
session = Session()
|
||||
|
|
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from jarvis.steveI.path import OPENAI_CLIP_PATH
|
||||
|
||||
# disable HuggingFace warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
@ -17,7 +16,7 @@ def get_model_full_name(model_name):
|
|||
elif model_name in ["distilbert", "distilbert-base-uncased"]:
|
||||
model_full_name = "distilbert-base-uncased"
|
||||
elif model_name in ["clip", "openai/clip-vit-base-patch16"]:
|
||||
model_full_name = OPENAI_CLIP_PATH
|
||||
model_full_name = "openai/clip-vit-base-patch16"
|
||||
else:
|
||||
model_full_name = model_name
|
||||
return model_full_name
|
||||
|
|
|
@ -4,4 +4,3 @@ VPT_MODEL_PATH = Path(__file__).parent / "weights" / "vpt" / "2x.model"
|
|||
VPT_WEIGHT_PATH = Path(__file__).parent / "weights" / "steve1" / "steve1.weights"
|
||||
PRIOR_WEIGHT_PATH = Path(__file__).parent / "weights" / "steve1" / "steve1_prior.pt"
|
||||
MINECLIP_WEIGHT_PATH = Path(__file__).parent / "weights" / "mineclip" / "attn.pth"
|
||||
OPENAI_CLIP_PATH = Path(__file__).parent / "weights" / "clip-vit-base-patch16"
|
|
@ -23,6 +23,7 @@ class Container():
|
|||
vab_source_dir = ""
|
||||
modified_omnigibson_src = ""
|
||||
max_round = 100
|
||||
docker_image = ""
|
||||
|
||||
def use_port(self):
|
||||
time.sleep(random.random())
|
||||
|
@ -146,7 +147,7 @@ class Container():
|
|||
"OMNIGIBSON_HEADLESS": "1"
|
||||
}
|
||||
self.container = self.client.containers.run(
|
||||
"vab-omnigibson:latest",
|
||||
self.docker_image,
|
||||
f"python main.py --task {task[0]} --scene {task[1]} --port {self.port} --max_round {self.max_round}",
|
||||
environment=environment,
|
||||
volumes=volumes,
|
||||
|
@ -164,6 +165,9 @@ class Container():
|
|||
|
||||
async def execute(self, session):
|
||||
while True:
|
||||
self.container.reload()
|
||||
if self.container.status == "exited":
|
||||
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR)
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect(("localhost", self.port))
|
||||
|
@ -219,7 +223,7 @@ class Container():
|
|||
]
|
||||
}
|
||||
)
|
||||
print(text_prompt)
|
||||
# print(text_prompt)
|
||||
message = await session.action()
|
||||
print(message)
|
||||
if message.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
|
||||
|
@ -252,6 +256,7 @@ class Container():
|
|||
def close(self):
    """Stop this task's Docker container, ignoring any failure.

    Waits four seconds, refreshes the container state, and stops the
    container unless Docker already reports it as exited. Cleanup is
    best-effort: any error raised while doing so is swallowed.
    """
    try:
        time.sleep(4)
        # Refresh the state before reading ``status``, as execute() does.
        self.container.reload()
        # A finished container reports "exited" (the value execute() checks);
        # the previous comparison against "exit" could never match.
        if self.container.status != "exited":
            self.container.stop(timeout=24)
    except Exception:
        # Best-effort cleanup: the container may already be gone.
        pass
|
||||
|
|
|
@ -7,7 +7,7 @@ import os
|
|||
import asyncio
|
||||
|
||||
class OmniGibson(Task):
|
||||
def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, **configs):
|
||||
def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, docker_image, **configs):
|
||||
super().__init__(**configs)
|
||||
self.vab_source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vab_omnigibson_src")
|
||||
modified_omnigibson_src = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modified_omnigibson_src")
|
||||
|
@ -30,6 +30,7 @@ class OmniGibson(Task):
|
|||
Container.max_round = max_round
|
||||
Container.vab_source_dir = self.vab_source_path
|
||||
Container.modified_omnigibson_src = modified_omnigibson_src
|
||||
Container.docker_image = docker_image
|
||||
with open(f"{self.vab_source_path}/task/tasks.txt", "r") as f:
|
||||
tasks = f.read()
|
||||
self.tasks = eval(tasks)
|
||||
|
@ -41,14 +42,14 @@ class OmniGibson(Task):
|
|||
try:
|
||||
container = Container(self.tasks[index])
|
||||
session.clear()
|
||||
session.inject({"role": "user", "content": SYSTEM_PROMPT})
|
||||
session.inject({"role": "system", "content": SYSTEM_PROMPT})
|
||||
while True:
|
||||
result = await container.execute(session)
|
||||
if result.status != SampleStatus.RUNNING:
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR)
|
||||
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR, result={"error": e})
|
||||
finally:
|
||||
try:
|
||||
container.close()
|
||||
|
@ -59,7 +60,7 @@ class OmniGibson(Task):
|
|||
average_reward = 0
|
||||
success_count = 0
|
||||
for result in results:
|
||||
if result.result:
|
||||
if isinstance(result.result, Dict) and "success" in result.result:
|
||||
if result.result["success"]:
|
||||
success_count += 1
|
||||
average_reward += 1
|
||||
|
@ -69,6 +70,8 @@ class OmniGibson(Task):
|
|||
reward = (final_reward - initial_reward) / (1.0 - initial_reward)
|
||||
average_reward += max(0, reward)
|
||||
return {
|
||||
"total_count": len(results),
|
||||
"success_count": success_count,
|
||||
"success_rate": success_count / len(results),
|
||||
"average_reward": average_reward / len(results)
|
||||
}
|
||||
|
@ -76,11 +79,12 @@ class OmniGibson(Task):
|
|||
|
||||
async def main():
|
||||
available_ports = [12000, 12001, 12002]
|
||||
available_devices = {"0":1, "1":1, "2":1}
|
||||
available_devices = {"9":1, "1":1, "2":1}
|
||||
data_dir = "data/omnigibson"
|
||||
output_dir = "outputs/omnigibson"
|
||||
max_round = 100
|
||||
task = OmniGibson(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, name="OmniGibson-std")
|
||||
docker_image = "vab_omnigibson:latest"
|
||||
task = OmniGibson(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, docker_image=docker_image, name="OmniGibson-std")
|
||||
print(Container.available_devices)
|
||||
print(Container.available_ports)
|
||||
session = Session()
|
||||
|
|
|
@ -41,6 +41,6 @@ class Assignment(BaseModel):
|
|||
|
||||
|
||||
class ChatHistoryItem(BaseModel):
|
||||
role: Literal["user", "agent"]
|
||||
role: Literal["user", "agent", "system"]
|
||||
# content: str
|
||||
content: Union[str, List[Dict]]
|
||||
|
|
Loading…
Reference in New Issue
Block a user