"""base class for evaluation"""
|
|
# answer string match
|
|
import importlib
|
|
import json
|
|
import re
|
|
import time
|
|
import urllib
|
|
from pathlib import Path
|
|
from typing import Any, Optional, Tuple, Union
|
|
from urllib.parse import urljoin
|
|
|
|
import evaluate # type: ignore[import]
|
|
import requests
|
|
from beartype import beartype
|
|
from beartype.door import is_bearable
|
|
from nltk.tokenize import word_tokenize # type: ignore
|
|
from PIL import Image
|
|
from playwright.sync_api import CDPSession, Page
|
|
|
|
from browser_env.actions import Action
|
|
from browser_env.utils import StateInfo
|
|
from evaluation_harness import image_utils
|
|
from evaluation_harness.helper_functions import (
|
|
PseudoPage,
|
|
get_query_text,
|
|
get_query_text_lowercase,
|
|
gitlab_get_project_memeber_role,
|
|
llm_fuzzy_match,
|
|
llm_ua_match,
|
|
reddit_get_latest_comment_content_by_username,
|
|
reddit_get_latest_comment_obj_by_username,
|
|
reddit_get_parent_comment_username_of_latest_comment_by_username,
|
|
reddit_get_post_url,
|
|
shopping_get_latest_order_url,
|
|
shopping_get_num_reviews,
|
|
shopping_get_order_product_name_list,
|
|
shopping_get_order_product_option,
|
|
shopping_get_order_product_quantity,
|
|
shopping_get_product_attributes,
|
|
shopping_get_product_price,
|
|
shopping_get_rating_as_percentage,
|
|
shopping_get_sku_latest_review_author,
|
|
shopping_get_sku_latest_review_rating,
|
|
shopping_get_sku_latest_review_text,
|
|
)
|
|
|
|
Trajectory = list[Union[Action, StateInfo]]
|
|
|
|
|
|
@beartype
class Evaluator(object):
    def __init__(self, eval_tag: str = "") -> None:
        self.eval_tag = eval_tag

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage,
    ) -> float:
        raise NotImplementedError

    @staticmethod
    def get_last_action(trajectory: Trajectory) -> Action:
        # The trajectory must end with an action (add a fake stop action if needed).
        if not trajectory or not is_bearable(trajectory[-1], Action):
            raise ValueError(
                "The last element of trajectory should be an action, add a fake stop action if needed"
            )
        return trajectory[-1]  # type: ignore[return-value]

    @staticmethod
    def get_last_state(trajectory: Trajectory) -> StateInfo:
        # The second-to-last element must be the state observed before the final action.
        if len(trajectory) < 2 or not is_bearable(trajectory[-2], StateInfo):
            raise ValueError(
                "The second last element of trajectory should be a state, add a fake stop action if needed"
            )
        return trajectory[-2]  # type: ignore[return-value]


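# Shape of a trajectory consumed by the evaluators (a sketch inferred from the
# accessors above): states and actions alternate and the sequence must end
# with an action, e.g.
#
#   trajectory = [state_0, action_0, state_1, action_1, ..., state_n, stop_action]
#
# so get_last_action returns stop_action and get_last_state returns state_n.

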
@beartype
class NumericEvaluator(Evaluator):
    """Check if the numerical relationship is correct"""

    @staticmethod
    @beartype
    def str_2_int(s: str) -> Optional[int]:
        try:
            s = s.strip()
            if "," in s:
                s = s.replace(",", "")
            return int(s)
        except ValueError:
            # Return None if the string cannot be converted to int
            print(f"[NumericEvaluator error]: Cannot convert {s} to int")
            return None

    @staticmethod
    @beartype
    def compare_inequality(
        value: Union[int, float], inequality: str, tol: float = 1e-8
    ) -> bool:
        """
        Compare a value (int or float) against an inequality string.

        Args:
        - value (int/float): The value to be compared.
        - inequality (str): Inequality in the form of "< 700", ">= 300", etc.
        - tol (float): Tolerance for floating point comparisons.

        Returns:
        - bool: True if the value satisfies the inequality, False otherwise.
        """
        # Extract the operator and the number from the inequality string
        ops = {
            "<=": lambda x, y: x <= y + tol,
            ">=": lambda x, y: x >= y - tol,
            "==": lambda x, y: abs(x - y) <= tol,
            "<": lambda x, y: x < y + tol,
            ">": lambda x, y: x > y - tol,
        }

        for op, func in ops.items():
            if op in inequality:
                _, num = inequality.split(op)
                return func(value, float(num.strip()))

        raise ValueError(f"Invalid inequality string: {inequality}")


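# Illustrative behaviour of NumericEvaluator (a minimal sketch; the values
# below follow directly from the logic above, not from recorded test cases):
#
#   NumericEvaluator.str_2_int("1,000")                  # -> 1000
#   NumericEvaluator.str_2_int("ten")                    # -> None (prints an error)
#   NumericEvaluator.compare_inequality(650, "< 700")    # -> True
#   NumericEvaluator.compare_inequality(299, ">= 300")   # -> False

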
@beartype
class StringEvaluator(Evaluator):
    """Check whether the answer is correct with:
    exact match: the answer is exactly the same as the reference answer
    must include: each phrase in the reference answer must be included in the answer
    fuzzy match: the answer is similar to the reference answer, judged by an LLM
    """

    @staticmethod
    @beartype
    def clean_answer(answer: str) -> str:
        # Strip a single pair of surrounding quotes and lowercase the answer.
        if answer.startswith("'") and answer.endswith("'"):
            answer = answer[1:-1]
        elif answer.startswith('"') and answer.endswith('"'):
            answer = answer[1:-1]
        return answer.lower()

    @staticmethod
    @beartype
    def exact_match(ref: str, pred: Union[str, int]) -> float:
        if isinstance(pred, int):
            pred = str(pred)
        return float(
            StringEvaluator.clean_answer(pred)
            == StringEvaluator.clean_answer(ref)
        )

    @staticmethod
    @beartype
    def must_include(ref: str, pred: str) -> float:
        clean_ref = StringEvaluator.clean_answer(ref)
        clean_pred = StringEvaluator.clean_answer(pred)
        # Tokenize the answer if the reference is a single word
        # to prevent false positives (e.g., a bare "0" matching as a substring).
        if len(word_tokenize(clean_ref)) == 1:
            tok_pred = word_tokenize(clean_pred)
            # Also split slash-joined tokens (e.g., "a/b") into sub-tokens.
            for token in tok_pred:
                if "/" in token:
                    sub_tokens = token.split("/")
                    tok_pred.extend(sub_tokens)
            return float(clean_ref in tok_pred)
        else:
            return float(clean_ref in clean_pred)

    @staticmethod
    @beartype
    def must_exclude(ref: str, pred: str) -> float:
        """Returns 1 if ref is not included in pred, and 0 otherwise"""
        clean_ref = StringEvaluator.clean_answer(ref)
        clean_pred = StringEvaluator.clean_answer(pred)
        # Tokenize the answer if the reference is a single word
        # to prevent false positives (e.g., a bare "0").
        if len(word_tokenize(clean_ref)) == 1:
            tok_pred = word_tokenize(clean_pred)
            return float(clean_ref not in tok_pred)
        else:
            return float(clean_ref not in clean_pred)

    @staticmethod
    @beartype
    def fuzzy_match(ref: str, pred: str, intent: str) -> float:
        return llm_fuzzy_match(pred, ref, intent)

    @staticmethod
    @beartype
    def ua_match(ref: str, pred: str, intent: str) -> float:
        return llm_ua_match(pred, ref, intent)

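    # Illustrative behaviour of the string helpers (a minimal sketch; the
    # inputs are made-up examples, the outputs follow from the code above):
    #
    #   StringEvaluator.clean_answer('"Final Answer"')               # -> 'final answer'
    #   StringEvaluator.exact_match(ref="42", pred=42)               # -> 1.0
    #   StringEvaluator.must_include(ref="0", pred="0/10 reviews")   # -> 1.0 (single-word refs match tokens)
    #   StringEvaluator.must_exclude(ref="error", pred="all good")   # -> 1.0
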
    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        last_action = self.get_last_action(trajectory)
        pred = self.clean_answer(last_action["answer"])

        score = 1.0
        for approach, value in configs["eval"]["reference_answers"].items():
            match approach:
                case "exact_match":
                    score *= self.exact_match(ref=value, pred=pred)
                case "required_values":
                    required_values = value
                    assert isinstance(required_values, list)
                    # Convert the answer to an int without clobbering the
                    # string prediction used by the other approaches.
                    pred_int = NumericEvaluator.str_2_int(pred)
                    if pred_int is None:
                        score = 0.0
                    else:
                        for v in required_values:
                            value_or = v.split(" |OR| ")
                            score *= any(
                                [
                                    NumericEvaluator.compare_inequality(
                                        pred_int, value
                                    )
                                    for value in value_or
                                ]
                            )
                case "must_include":
                    assert isinstance(value, list)
                    for must_value in value:
                        value_or = must_value.split(" |OR| ")
                        score *= any(
                            [
                                self.must_include(ref=v, pred=pred)
                                for v in value_or
                            ]
                        )
                case "must_exclude":
                    assert isinstance(value, list)
                    for must_excl_value in value:
                        score *= self.must_exclude(
                            ref=must_excl_value, pred=pred
                        )
                case "one_of":
                    assert isinstance(value, list)
                    found = False
                    for one_of_value in value:
                        one_of_value = self.clean_answer(one_of_value)
                        if one_of_value in pred:
                            found = True
                            break
                    score = score * found
                case "fuzzy_match":
                    intent = configs["intent"]
                    if value == "N/A":
                        # If the instruction only asks the model to generate N/A
                        # when encountering an unachievable task, without more
                        # concrete reasons, accept an exact "N/A" answer.
                        score *= self.exact_match(ref=value, pred=pred)
                        # If the instruction also asks the model to generate the
                        # reason why the task is unachievable, use the LLM-based
                        # unachievable-task match; this should be the default as
                        # it prevents false-positive "N/A" answers.
                        if score != 1:
                            score = 1.0 * self.ua_match(
                                intent=configs["intent"],
                                ref=configs["eval"]["string_note"],
                                pred=pred,
                            )
                    else:
                        assert isinstance(value, list)
                        reference = ", ".join(value)
                        score *= self.fuzzy_match(
                            ref=reference, pred=pred, intent=intent
                        )
        return score


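# Sketch of the "eval" block a StringEvaluator config may contain (the keys
# are the ones read above; the concrete values are hypothetical):
#
#   "eval": {
#     "eval_types": ["string_match"],
#     "reference_answers": {
#       "exact_match": "4.5",
#       "must_include": ["blue |OR| navy", "cotton"],
#       "must_exclude": ["out of stock"],
#       "fuzzy_match": ["The order arrived on May 3"]
#     }
#   }
#
# Alternatives are separated by " |OR| "; every listed approach must pass for
# the score to stay 1.0, since the per-approach scores are multiplied.

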
@beartype
class StringSoftEvaluator(Evaluator):
    """Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        last_action = self.get_last_action(trajectory)
        pred = last_action["answer"]
        ref = configs["eval"]["reference_answers"]
        # rouge
        m = evaluate.load("rouge")
        rouge = m.compute(predictions=[pred], references=[ref])
        return float(rouge["rouge1"])


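# Note: unlike the exact-match evaluators above, this produces a soft score in
# [0, 1]; "rouge1" is a unigram-overlap F-measure, so partially correct
# answers receive partial credit.

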
@beartype
class URLExactEvaluator(Evaluator):
    """Check whether the URL exactly matches any of the reference URLs"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        def clean_url(url: str) -> str:
            url = str(url)
            # Replace http://localhost with http://127.0.0.1 to keep things consistent across evals.
            url = url.replace("localhost", "127.0.0.1")
            if url.endswith("/"):
                url = url[:-1]
            return url

        pred = clean_url(page.url)
        ref_urls = configs["eval"]["reference_url"].split(" |OR| ")
        ref_urls = [clean_url(url) for url in ref_urls]
        matching_rule = configs["eval"].get("url_note", "EXACT")
        if matching_rule == "EXACT":
            if pred in ref_urls:
                return 1.0
            else:
                return 0.0
        elif matching_rule == "GOLD in PRED":
            if any([ref in pred for ref in ref_urls]):
                return 1.0
            else:
                return 0.0
        else:
            raise ValueError(f"Unknown matching rule: {matching_rule}")


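# Sketch of the URL fields this evaluator reads (hypothetical values):
#
#   "eval": {
#     "eval_types": ["url_match"],
#     "reference_url": "http://127.0.0.1:7770/sales/order |OR| http://127.0.0.1:7770/sales/order/history",
#     "url_note": "GOLD in PRED"
#   }
#
# With "EXACT" (the default), the cleaned final URL must equal one reference;
# with "GOLD in PRED", it only needs to contain one of them as a substring.

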
@beartype
class HTMLContentExactEvaluator(Evaluator):
    """Check whether the required contents appear on the page"""

    @staticmethod
    @beartype
    def fuzzy_match(ref: str, pred: str, intent: str) -> float:
        return llm_fuzzy_match(pred, ref, intent)

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        targets = configs["eval"]["program_html"]

        score = 1.0
        for target in targets:
            target_url: str = target["url"]  # which url to check
            if target_url.startswith("func"):
                func = target_url.split("func:")[1]
                func = func.replace("__last_url__", page.url)
                target_url = eval(func)

            locator: str = target["locator"]  # js element locator

            # navigate to that url
            if target_url != "last":
                page.goto(target_url)
                time.sleep(3)  # TODO [shuyanzh]: fix this hard-coded sleep

            # empty, use the full page
            if not locator.strip():
                selected_element = page.content()
            # use JS to select the element
            elif locator.startswith("document.") or locator.startswith(
                "[...document."
            ):
                if "prep_actions" in target:
                    try:
                        for prep_action in target["prep_actions"]:
                            page.evaluate(f"() => {prep_action}")
                    except Exception:
                        pass
                try:
                    selected_element = str(page.evaluate(f"() => {locator}"))
                    if not selected_element:
                        selected_element = ""
                except Exception:
                    # the page is wrong, return empty
                    selected_element = ""
            elif locator.startswith("lambda:"):
                try:
                    locator = locator.removeprefix("lambda:")
                    selected_element = page.evaluate(locator)
                    if not selected_element:
                        selected_element = None
                except Exception:
                    # the page is wrong, return empty
                    selected_element = None
            # run program to call API
            elif locator.startswith("func:"):  # a helper function
                func = locator.split("func:")[1]
                func = func.replace("__page__", "page")
                selected_element = eval(func)
            else:
                raise ValueError(f"Unknown locator: {locator}")

            # If the selected element is None, then the page is wrong
            if selected_element is None:
                score = 0.0
                break

            if "exact_match" in target["required_contents"]:
                required_contents = target["required_contents"]["exact_match"]
                score *= StringEvaluator.exact_match(
                    ref=required_contents, pred=selected_element
                )
            elif "must_include" in target["required_contents"]:
                required_contents = target["required_contents"]["must_include"]
                assert isinstance(required_contents, list)
                for content in required_contents:
                    content_or = content.split(" |OR| ")
                    score *= any(
                        [
                            StringEvaluator.must_include(
                                ref=content, pred=selected_element
                            )
                            for content in content_or
                        ]
                    )
            elif "must_exclude" in target["required_contents"]:
                required_contents = target["required_contents"]["must_exclude"]
                assert isinstance(required_contents, list)
                for content in required_contents:
                    assert " |OR| " not in content
                    score *= StringEvaluator.must_exclude(
                        content, pred=selected_element
                    )
            elif "required_values" in target["required_contents"]:
                required_values = target["required_contents"][
                    "required_values"
                ]
                assert isinstance(required_values, list)
                if isinstance(selected_element, str):
                    selected_element = NumericEvaluator.str_2_int(
                        selected_element
                    )
                if selected_element is None:
                    score = 0.0
                else:
                    for value in required_values:
                        value_or = value.split(" |OR| ")
                        score *= any(
                            [
                                NumericEvaluator.compare_inequality(
                                    selected_element, value
                                )
                                for value in value_or
                            ]
                        )
            elif "fuzzy_match" in target["required_contents"]:
                required_contents = target["required_contents"]["fuzzy_match"]
                intent = configs["intent"]

                assert isinstance(required_contents, list)
                reference = ", ".join(required_contents)
                score *= self.fuzzy_match(
                    ref=reference, pred=selected_element, intent=intent
                )
            else:
                raise ValueError(
                    f"Unknown required_contents: {target['required_contents'].keys()}"
                )

        return score


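# Sketch of a "program_html" target consumed above (the keys are the ones read
# by the loop; the concrete values are hypothetical):
#
#   {
#     "url": "last",
#     "locator": "document.querySelector('.order-status').outerText",
#     "required_contents": {"must_include": ["complete |OR| processing"]}
#   }
#
# "url" may be "last" (stay on the current page), a literal URL, or a "func:"
# expression; "locator" may be empty (full page HTML), a JS expression, a
# "lambda:" JS function body, or a "func:" helper call.

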
@beartype
class PageImageEvaluator(Evaluator):
    """Check whether the answer is correct by querying a vision model."""

    def __init__(self, captioning_fn):
        self.captioning_fn = captioning_fn
        # Default to 0.8 as the SSIM threshold for image similarity, to account
        # for compression, resizing, etc. This might be too generous, but we
        # bias towards minimizing false negatives.
        self.ssim_threshold = 0.8

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        for query in configs["eval"]["page_image_query"]:
            locator: str = query["eval_image_class"]
            target_url: str = query["eval_image_url"]
            if target_url.startswith("func"):
                func = target_url.split("func:")[1]
                func = func.replace("__last_url__", page.url)
                target_url = eval(func)

            # navigate to that url
            if target_url != "last":
                page.goto(target_url)
                time.sleep(3)  # TODO(jykoh): fix this hard-coded sleep

            # empty, use all images on the page
            if not locator.strip():
                images = page.get_by_role("img").all()
            # use a CSS class selector to find the images
            elif locator.startswith("."):
                # Get all img children under the locator
                elements = page.query_selector_all(locator)
                images = []
                for element in elements:
                    is_img = element.evaluate(
                        'element => element.tagName === "IMG"'
                    )
                    if is_img:
                        images.append(element)
                    else:
                        images.extend(element.query_selector_all("img"))
            else:
                raise ValueError(f"Unknown locator: {locator}")

            if images == []:
                return 0.0

            all_image_pixels = []
            for image in images:
                try:
                    # Get the image from its URL.
                    image_url = image.get_attribute("src")
                    if not image_url.startswith(
                        ("http://", "https://", "www.")
                    ):
                        image_url = urljoin(page.url, image_url)
                    image = Image.open(
                        requests.get(image_url, stream=True).raw
                    )
                    all_image_pixels.append(image)
                except Exception as e:
                    print("[WARNING]: ", e)

            score = 1.0
            if all_image_pixels == []:
                return 0.0
            else:
                # Run the VQA eval on the image elements.
                eval_vqas = query.get("eval_vqa", [])
                assert (
                    len(eval_vqas) > 0 or "eval_fuzzy_image_match" in query
                ), "eval_vqa must have at least 1 question or eval_fuzzy_image_match must be provided"
                for qa in eval_vqas:
                    question, answer = qa["question"], qa["answer"]
                    prompt = f"Q: {question} A:"
                    pred_ans = self.captioning_fn(
                        all_image_pixels, [prompt] * len(all_image_pixels)
                    )
                    score *= float(
                        any(
                            [answer.lower() in ans.lower() for ans in pred_ans]
                        )
                    )

                if "eval_fuzzy_image_match" in query:
                    ssim_threshold = query.get(
                        "ssim_threshold", self.ssim_threshold
                    )
                    exact_match_imgs = query["eval_fuzzy_image_match"].split(
                        " |OR| "
                    )
                    all_exact_match_pixels = []

                    for exact_match_img in exact_match_imgs:
                        if exact_match_img.startswith("http"):
                            exact_match_pixels = Image.open(
                                requests.get(exact_match_img, stream=True).raw
                            )
                        else:
                            exact_match_pixels = Image.open(exact_match_img)
                        all_exact_match_pixels.append(exact_match_pixels)

                    # Check if any of the images on the page match
                    found_exact_match = False
                    for exact_match_pixels in all_exact_match_pixels:
                        for image_pixels in all_image_pixels:
                            ssim = image_utils.get_image_ssim(
                                image_pixels, exact_match_pixels
                            )
                            if ssim > ssim_threshold:
                                found_exact_match = True
                                break
                    score *= float(found_exact_match)

        return score


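# Sketch of a "page_image_query" entry consumed above (the keys are the ones
# read by the loop; the concrete values are hypothetical):
#
#   {
#     "eval_image_url": "last",
#     "eval_image_class": ".product-image",
#     "eval_vqa": [{"question": "What animal is shown?", "answer": "dog"}],
#     "eval_fuzzy_image_match": "path/to/reference.png"
#   }
#
# Each VQA answer must appear in the captioning model's reply for at least one
# image, and the fuzzy image match requires SSIM above the threshold against
# at least one reference image.

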
class EvaluatorComb:
    def __init__(self, evaluators: list[Evaluator]) -> None:
        self.evaluators = evaluators

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | PseudoPage,
    ) -> float:
        score = 1.0
        for evaluator in self.evaluators:
            cur_score = evaluator(trajectory, config_file, page)
            score *= cur_score
        return score


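# Because the per-evaluator scores are multiplied, a task only receives full
# credit when every configured evaluator returns 1.0; a single 0.0 zeroes the
# combined score.

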
@beartype
def evaluator_router(
    config_file: Path | str, captioning_fn=None
) -> EvaluatorComb:
    """Router to get the evaluator class"""
    with open(config_file, "r") as f:
        configs = json.load(f)

    eval_types = configs["eval"]["eval_types"]
    evaluators: list[Evaluator] = []
    for eval_type in eval_types:
        match eval_type:
            case "string_match":
                evaluators.append(StringEvaluator())
            case "url_match":
                evaluators.append(URLExactEvaluator())
            case "program_html":
                evaluators.append(HTMLContentExactEvaluator())
            case "page_image_query":
                evaluators.append(PageImageEvaluator(captioning_fn))
            case _:
                raise ValueError(f"eval_type {eval_type} is not supported")

    return EvaluatorComb(evaluators)


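# Minimal usage sketch (hedged; "config.json", the trajectory, and the
# captioning function are placeholders, not values defined in this module):
#
#   evaluator = evaluator_router("config.json", captioning_fn=my_captioning_fn)
#   score = evaluator(trajectory=trajectory, config_file="config.json", page=page)
#
# where the trajectory ends with a stop action and `page` is the live
# Playwright Page (or PseudoPage) the agent finished on.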