from typing import Callable, List, Optional

import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from transformers import (
    Blip2ForConditionalGeneration,
    Blip2Processor,
)


def get_captioning_fn(
    device, dtype, model_name: str = "Salesforce/blip2-flan-t5-xl"
) -> Callable:
    """Load a BLIP-2 captioning model and return a batch-captioning function."""
    if "blip2" in model_name:
        captioning_processor = Blip2Processor.from_pretrained(model_name)
        captioning_model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=dtype
        )
    else:
        raise NotImplementedError("Only BLIP-2 models are currently supported")
    captioning_model.to(device)

    def caption_images(
        images: List[Image.Image],
        prompt: Optional[List[str]] = None,
        max_new_tokens: int = 32,
    ) -> List[str]:
        if prompt is None:
            # No prompt given: perform unconditional captioning.
            inputs = captioning_processor(
                images=images, return_tensors="pt"
            ).to(device, dtype)
        else:
            # Prompted captioning / VQA: prompt is a list of strings, one per image.
            assert len(images) == len(prompt), (
                "Number of images and prompts must match, got {} and {}".format(
                    len(images), len(prompt)
                )
            )
            inputs = captioning_processor(
                images=images, text=prompt, return_tensors="pt"
            ).to(device, dtype)

        generated_ids = captioning_model.generate(
            **inputs, max_new_tokens=max_new_tokens
        )
        captions = captioning_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return captions

    return caption_images


def get_image_ssim(imageA, imageB):
    """Compute the SSIM between two PIL images, resized to a common size."""
    # Resize both images to the larger of the two sizes along each dimension.
    new_size = (
        max(imageA.size[0], imageB.size[0]),
        max(imageA.size[1], imageB.size[1]),
    )
    imageA = imageA.resize(new_size, Image.LANCZOS)
    imageB = imageB.resize(new_size, Image.LANCZOS)

    # Convert the images to grayscale numpy arrays for SSIM computation.
    grayA = np.array(imageA.convert("L"))
    grayB = np.array(imageB.convert("L"))

    # Compute the Structural Similarity Index (SSIM) between the two images.
    score, _ = ssim(grayA, grayB, full=True)
    return score
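

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assumes torch
# is installed and that "image_a.jpg" / "image_b.jpg" are hypothetical local
# files; the BLIP-2 checkpoint is downloaded from the Hub on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    caption_fn = get_captioning_fn(device, dtype)

    image_a = Image.open("image_a.jpg")  # hypothetical input file
    image_b = Image.open("image_b.jpg")  # hypothetical input file

    # Unconditional captioning: no prompts supplied.
    print(caption_fn([image_a, image_b]))

    # Prompted captioning / VQA: exactly one prompt per image.
    print(caption_fn([image_a], prompt=["Question: what is in the image? Answer:"]))

    # SSIM score in [-1, 1]; values near 1 indicate visually similar images.
    print(get_image_ssim(image_a, image_b))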