# webrl/VAB-WebArena-Lite/evaluation_harness/image_utils.py
"""Image utilities for the evaluation harness: BLIP-2 captioning and SSIM."""
from typing import Callable, List, Optional

import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from transformers import (
    Blip2ForConditionalGeneration,
    Blip2Processor,
)


def get_captioning_fn(
    device, dtype, model_name: str = "Salesforce/blip2-flan-t5-xl"
) -> Callable:
    """Return a function that captions batches of PIL images with BLIP-2."""
    if "blip2" in model_name:
        captioning_processor = Blip2Processor.from_pretrained(model_name)
        captioning_model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=dtype
        )
    else:
        raise NotImplementedError(
            "Only BLIP-2 models are currently supported"
        )
    captioning_model.to(device)

    def caption_images(
        images: List[Image.Image],
        prompt: Optional[List[str]] = None,
        max_new_tokens: int = 32,
    ) -> List[str]:
        if prompt is None:
            # No prompt: generate a free-form caption for each image.
            inputs = captioning_processor(
                images=images, return_tensors="pt"
            ).to(device, dtype)
            generated_ids = captioning_model.generate(
                **inputs, max_new_tokens=max_new_tokens
            )
            captions = captioning_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
        else:
            # Prompted captioning / VQA: `prompt` holds one string per image,
            # and generation is conditioned on it.
            assert len(images) == len(prompt), (
                "Number of images and prompts must match, got {} and {}".format(
                    len(images), len(prompt)
                )
            )
            inputs = captioning_processor(
                images=images, text=prompt, return_tensors="pt"
            ).to(device, dtype)
            generated_ids = captioning_model.generate(
                **inputs, max_new_tokens=max_new_tokens
            )
            captions = captioning_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
        return captions

    return caption_images
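

# Usage sketch (illustrative, not part of the original harness). Builds a
# captioner and calls it with and without a prompt; "screenshot.png" is a
# hypothetical file, and the "Question: ... Answer:" format follows the usual
# BLIP-2 prompting convention.
#
#   import torch
#   caption_fn = get_captioning_fn("cuda", torch.float16)
#   img = Image.open("screenshot.png")
#   caption_fn([img])                                      # free-form caption
#   caption_fn([img], prompt=["Question: what page is shown? Answer:"])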


def get_image_ssim(imageA: Image.Image, imageB: Image.Image) -> float:
    """Compute grayscale SSIM between two images, resized to a common size."""
    # Resize both images to the larger of the two sizes in each dimension
    new_size = (
        max(imageA.size[0], imageB.size[0]),
        max(imageA.size[1], imageB.size[1]),
    )
    imageA = imageA.resize(new_size, Image.LANCZOS)
    imageB = imageB.resize(new_size, Image.LANCZOS)

    # Convert to grayscale numpy arrays for SSIM computation
    grayA = np.array(imageA.convert("L"))
    grayB = np.array(imageB.convert("L"))

    # Compute the Structural Similarity Index (SSIM) between the two images
    score, _ = ssim(grayA, grayB, full=True)
    return score
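

# Minimal smoke test (an illustrative sketch, not part of the harness):
# "before.png" and "after.png" are hypothetical screenshot files, and running
# this downloads the BLIP-2 weights from the Hugging Face hub on first use.
if __name__ == "__main__":
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    caption_fn = get_captioning_fn(device, dtype)

    image_a = Image.open("before.png")
    image_b = Image.open("after.png")
    print("caption:", caption_fn([image_a])[0])
    print("ssim:", get_image_ssim(image_a, image_b))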